数据处理与加载

按照批次加载数据

Dataset

torch.utils.data.Dataset是一个抽象类,表示数据集。用户需要自定义类来继承这个抽象类,并实现两个核心方法:__len__()__getitem__()

  • __len__():返回数据集的总大小
  • __getitem__(index):根据索引返回一个数据点
1
2
3
4
5
6
7
8
9
10
11
12
from torch.utils.data import Dataset

class CustomDataset(Dataset):
def __init__(self, data, labels): # data和labels为tensor类型数据
self.data = data
self.labels = labels

def __len__(self):
return len(self.data)

def __getitem__(self, index):
return self.data[index], self.labels[index]

如果datalabels为列表或数组,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from torch.utils.data import Dataset
import torch

class CustomDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels

def __len__(self):
return len(self.data)

def __getitem__(self, index):
X = torch.tensor(self.data[index], dtype=torch.float)
y = torch.tensor(self.labels[index], dtype=torch.float)
return X, y

Dataset,可以使用dataset[i]的方式使用索引i访问数据,比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

# 多元线性回归中真实的w和b
true_w = torch.tensor([2, 3.2]).reshape((-1, 1))
true_b = torch.tensor(6.6)

# 生成特征和标签,并添加正态分布的噪声
features = torch.normal(0, 4, (3000, 2))
labels = torch.matmul(features, true_w) + true_b
labels += torch.normal(0, 0.2, labels.shape)

# 按批次使用数据,batch_size=128
dataset = TensorDataset(features, labels)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

print(dataset[0])

输出结果:

1
(tensor([ 7.8515, -0.8850]), tensor([19.6474]))

DataLoader

torch.utils.data.DataLoader用于封装Dataset,提供批量加载、打乱数据、多线程加载等功能

1
2
3
4
from torch.utils.data import DataLoader

dataset = CustomDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
  • batch_size:每个批次的大小
  • shuffle:是否在每个epoch开始时打乱数据
  • num_workers:用于数据加载的子进程数
  • dataset:数据集对象
  • drop_last:默认为False。指定在数据集的大小不能被批大小batch_size整除时,是否要丢弃最后一个不完整的批次。

加载数据:

1
2
3
for epoch in range(1):
for batch_index, (inputs, labels) in enumerate(dataloader):
pass
  • batch_index
  • (inputs, labels):得到的批次特征和标签

enumerate是python的内置函数,用于将可迭代对象(如列表、元组、字符串等)组合为一个索引序列,同时列出数据和数据下标

如果不需要获取批次的索引,可以不使用enumerate

1
2
3
for epoch in range(1):
for inputs, labels in dataloader:
print(f'inputs = {inputs}, labels = {labels}')

TensorDataset

torch.utils.data.TensorDataset,无需自定义Dataset,适用于数据已经完全加载到内存中的情况。如果数据集太大无法一次性加载到内存,则需要考虑其他方法,比如用自定义的Dataset类进行懒加载(在__getitem__()中使用打开文件返回数据)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
from torch.utils.data import TensorDataset, DataLoader

# 随机生成数据集features, labels
features = torch.randn(100, 10) # 100个样本,每个样本10个特征
true_w = torch.normal(2, 3, (10, 1))
true_b = torch.tensor(0.5)
labels = torch.matmul(features, true_w) + true_b
labels += torch.randn(labels.shape)

# 用TensorDataset构建数据集
dataset = TensorDataset(features, labels)

# 使用DataLoader按批次加载数据集
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
for epoch in range(1):
for x, y in dataloader:
print(x, y)

未解决的问题:

  1. Dataset自定义数据集,DataLoader无法使用num_workers设置多进程,但是如果使用TensorDatasetDataLoader可以设置num_workers

数据处理与加载
https://blog.shinebook.net/2025/03/06/人工智能/pytorch/数据处理与加载/
作者
X
发布于
2025年3月6日
许可协议