参考：Datasets & DataLoaders — PyTorch Tutorials 2.0.1+cu117 documentation
# Load a dataset
#
# FashionMNIST is loaded twice through torchvision's dataset class:
#   root      -- directory the data files are stored in ("data" here)
#   train     -- True selects the training split, False the test split
#   download  -- fetch the files if they are not already present under root
#   transform -- ToTensor() converts each sample image into a tensor
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
# Create a custom dataset
#
# A custom Dataset subclass implements __init__, __len__ and __getitem__.
import os
import pandas as pd
from torchvision.io import read_image


class CustomImageDataset(Dataset):
    """Image dataset whose samples are listed in an annotations CSV file.

    Each row of the CSV provides the image file name (relative to
    ``img_dir``) in column 0 and the label in column 1.

    Args:
        annotations_file: Path to the annotations CSV file.
        img_dir: Directory that contains the image files.
        transform: Optional callable applied to every image.
        target_transform: Optional callable applied to every label.
        csv_header: Forwarded to ``pd.read_csv`` as ``header``. The default
            ``"infer"`` treats the first row as column names, so a CSV
            without a header row silently loses its first sample; pass
            ``csv_header=None`` for header-less files.
    """

    def __init__(
        self,
        annotations_file,
        img_dir,
        transform=None,
        target_transform=None,
        csv_header="infer",
    ):
        self.img_labels = pd.read_csv(annotations_file, header=csv_header)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of samples (rows read from the annotations CSV)."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for the row at position ``idx``.

        The image is read from ``img_dir`` with ``read_image``. The optional
        ``transform`` / ``target_transform`` are applied before returning.
        """
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        # Explicit None checks: a transform is applied whenever one was given,
        # regardless of how the callable object evaluates in a boolean context.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label
DataLoaders 1 2 3 from torch.utils.data import DataLoadertrain_dataloader = DataLoader(training_data, batch_size=64 , shuffle=True ) test_dataloader = DataLoader(test_data, batch_size=64 , shuffle=Tru