按照批次加载数据
Dataset
torch.utils.data.Dataset
是一个抽象类,表示数据集。用户需要自定义类来继承这个抽象类,并实现两个核心方法:__len__()
和__getitem__()
__len__()
:返回数据集的总大小
__getitem__(index)
:根据索引返回一个数据点
1 2 3 4 5 6 7 8 9 10 11 12 from torch.utils.data import Datasetclass CustomDataset (Dataset ): def __init__ (self, data, labels ): self .data = data self .labels = labels def __len__ (self ): return len (self .data) def __getitem__ (self, index ): return self .data[index], self .labels[index]
如果data
和labels
为列表或数组,
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 from torch.utils.data import Datasetimport torchclass CustomDataset (Dataset ): def __init__ (self, data, labels ): self .data = data self .labels = labels def __len__ (self ): return len (self .data) def __getitem__ (self, index ): X = torch.tensor(self .data[index], dtype=torch.float ) y = torch.tensor(self .labels[index], dtype=torch.float ) return X, y
对Dataset
类 ,可以使用dataset[i]
的方式使用索引i
访问数据,比如:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 import torchimport torch.nn as nnfrom torch.utils.data import TensorDataset, DataLoader true_w = torch.tensor([2 , 3.2 ]).reshape((-1 , 1 )) true_b = torch.tensor(6.6 ) features = torch.normal(0 , 4 , (3000 , 2 )) labels = torch.matmul(features, true_w) + true_b labels += torch.normal(0 , 0.2 , labels.shape) dataset = TensorDataset(features, labels) dataloader = DataLoader(dataset, batch_size=128 , shuffle=True )print (dataset[0 ])
输出结果:
1 (tensor([ 7.8515 , -0.8850 ]), tensor([19.6474 ]))
DataLoader
torch.utils.data.DataLoader
用于封装Dataset
,提供批量加载、打乱数据、多线程加载等功能
1 2 3 4 from torch.utils.data import DataLoader dataset = CustomDataset(data, labels) dataloader = DataLoader(dataset, batch_size=32 , shuffle=True , num_workers=4 )
batch_size
:每个批次的大小
shuffle
:是否在每个epoch开始时打乱数据
num_workers
:用于数据加载的子进程数
dataset
:数据集对象
drop_last
:默认为False
。指定在数据集的大小不能被批大小batch_size
整除时,是否要丢弃最后一个不完整的批次。
加载数据:
1 2 3 for epoch in range (1 ): for batch_index, (inputs, labels) in enumerate (dataloader): pass
batch_index
:
(inputs, labels)
:得到的批次特征和标签
enumerate
是python的内置函数,用于将可迭代对象(如列表、元组、字符串等)组合为一个索引序列,同时列出数据和数据下标
如果不需要获取批次的索引,可以不使用enumerate
:
1 2 3 for epoch in range (1 ): for inputs, labels in dataloader: print (f'inputs = {inputs} , labels = {labels} ' )
TensorDataset
torch.utils.data.TensorDataset
,无需自定义Dataset
,适用于数据已经完全加载到内存中的情况。如果数据集太大无法一次性加载到内存,则需要考虑其他方法,比如用自定义的Dataset
类进行懒加载(在__getitem__()
中使用打开文件返回数据)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 import torchfrom torch.utils.data import TensorDataset, DataLoader features = torch.randn(100 , 10 ) true_w = torch.normal(2 , 3 , (10 , 1 )) true_b = torch.tensor(0.5 ) labels = torch.matmul(features, true_w) + true_b labels += torch.randn(labels.shape) dataset = TensorDataset(features, labels) dataloader = DataLoader(dataset, batch_size=32 , shuffle=True , num_workers=4 )for epoch in range (1 ): for x, y in dataloader: print (x, y)
未解决的问题:
Dataset
自定义数据集,DataLoader
无法使用num_workers
设置多进程,但是如果使用TensorDataset
,DataLoader
可以设置num_workers