zzm

pytorch数据集读取
一、自带数据集读取利用torchvision.datasets可以自动下载一些预置数据集,如果数据集本地存在,则会...
扫描右侧二维码阅读全文
07
2019/02

pytorch数据集读取

一、自带数据集读取

利用torchvision.datasets可以自动下载一些预置数据集;如果数据集已经存在于本地(root指定的目录下),则直接使用本地数据,不会重复下载。

import torch
import torchvision
import torchvision.transforms as transforms
# Preprocessing applied to every image: convert to a tensor, then normalize
# each of the three channels with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR10 training split, stored under ./data (downloaded when requested by
# download=True), served in shuffled mini-batches of 4 by 2 worker processes.
trainset = torchvision.datasets.CIFAR10(
    root='data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)

# CIFAR10 test split, read in a fixed order (no shuffling).
testset = torchvision.datasets.CIFAR10(
    root='data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
)

二、自定义数据集读取

例如需要读取mnist目录下的数据集
其中./mnist/train.txt和./mnist/test.txt分别保存了训练集和测试集图片的位置、标签。
例如train.txt内容如下:

./mnist//train/0.jpg tensor(5)
./mnist//train/1.jpg tensor(0)
./mnist//train/2.jpg tensor(4)
./mnist//train/3.jpg tensor(1)
./mnist//train/4.jpg tensor(9)
./mnist//train/5.jpg tensor(2)
./mnist//train/6.jpg tensor(1)
./mnist//train/7.jpg tensor(3)
./mnist//train/8.jpg tensor(1)
./mnist//train/9.jpg tensor(4)
...

则可以通过自定义类继承torch.utils.data.dataset.Dataset,实现三个方法__init__、__getitem__和__len__来读取数据集。

import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
class MyDataset(Dataset):
    """Image dataset described by a plain-text index file.

    Every non-blank line of the index file holds an image path, whitespace,
    and a label written either as ``tensor(<int>)`` or as a bare integer,
    for example ``./mnist//train/0.jpg tensor(5)``.

    Attributes:
        data: list of ``[image_path, label_digits]`` pairs, both strings.
        transformations: transform applied to each opened image.
    """

    # Text that wraps the integer label in the index file.
    _LABEL_PREFIX = "tensor("
    _DIGITS = "0123456789"

    def __init__(self, base_path):
        """Load the index file at ``base_path`` and set up the transform.

        Blank lines are skipped.

        Raises:
            ValueError: if a non-blank line carries no parsable label.
        """
        self.data = []
        with open(base_path) as fp:
            # Iterate the file lazily instead of materialising every line.
            for line in fp:
                entry = self._parse_line(line)
                if entry is not None:
                    self.data.append(entry)

        # Transform applied to each image when it is fetched.
        self.transformations = \
            transforms.Compose([transforms.ToTensor()])

    @classmethod
    def _parse_line(cls, line):
        """Split one index line into ``[path, label_digits]``.

        Returns ``None`` for a blank line. The label keeps every leading
        digit, so multi-digit class ids are not truncated.

        Raises:
            ValueError: if the label is missing or does not start with digits.
        """
        fields = line.split(None, 1)
        if not fields:
            return None
        if len(fields) < 2:
            raise ValueError("missing label in index line: %r" % (line,))

        path, text = fields
        if text.startswith(cls._LABEL_PREFIX):
            text = text[len(cls._LABEL_PREFIX):]

        # Keep the leading run of digits; anything after it (the closing
        # parenthesis, a newline, extra tensor attributes) is ignored.
        end = 0
        while end < len(text) and text[end] in cls._DIGITS:
            end += 1
        if end == 0:
            raise ValueError("cannot parse label in index line: %r" % (line,))
        return [path, text[:end]]

    def __getitem__(self, index):
        """Return ``(transformed_image, int_label)`` for the sample at ``index``."""
        path, label = self.data[index]
        # The context manager closes the image file once the transform has
        # consumed it, so file handles are not left open between samples.
        with Image.open(path) as image:
            img = self.transformations(image)
        return img, int(label)

    def __len__(self):
        """Return the number of samples listed in the index file."""
        return len(self.data)

# Number of samples per mini-batch, shared by both loaders.
batch_size = 64

# Datasets built from the training and test index files.
train_data = MyDataset("./mnist/train.txt")
test_data = MyDataset("./mnist/test.txt")

# Training batches are shuffled; test batches keep the index-file order.
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=batch_size, shuffle=False)
Last modification:February 8th, 2019 at 12:37 pm

Leave a Comment