Dataset 与 DataLoader¶
OneFlow 的 Dataset
与 DataLoader
的行为与 PyTorch 的是一致的,都是为了让数据集管理与模型训练解耦。
Dataset
类用于定义如何读取数据。对于常见的计算机视觉数据集(如 FashionMNIST),可以直接使用 FlowVision 库的 datasets
模块提供的数据集类,可以帮助我们自动下载并加载一些流行的数据集,这些类都间接继承了 Dataset
类。对于其他数据集,可以通过继承 Dataset
类来自定义数据集类。
DataLoader
将 Dataset
封装为迭代器,方便训练时遍历并操作数据。
import matplotlib.pyplot as plt
import oneflow as flow
import oneflow.nn as nn
from oneflow.utils.data import Dataset
from flowvision import datasets
from flowvision import transforms
ToTensor
可以将 PIL 图像或 NumPy 数组转换为张量),可以在数据集类中直接使用。
使用 FlowVision 加载数据集¶
以下的例子展示了如何使用 flowvision.datasets
加载 FashionMNIST 数据集。
我们向 FashionMNIST
类传入以下参数:
- root
:数据集存放的路径
- train
: True
代表下载训练集、False
代表下载测试集
- download=True
: 如果 root
路径下数据集不存在,则从网络下载
- transforms
:指定的数据转换方式
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=transforms.ToTensor(),
source_url="https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/",
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=transforms.ToTensor(),
source_url="https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/",
)
第一次运行,会下载数据集,输出:
Downloading https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/train-images-idx3-ubyte.gz
Downloading https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
26422272/? [00:02<00:00, 8090800.72it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/train-labels-idx1-ubyte.gz
Downloading https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
29696/? [00:00<00:00, 806948.09it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/t10k-images-idx3-ubyte.gz
Downloading https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
4422656/? [00:00<00:00, 19237994.98it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/t10k-labels-idx1-ubyte.gz
Downloading https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
6144/? [00:00<00:00, 152710.85it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
遍历数据¶
Dataset
对象,可以像 list
一样,用下标索引,比如 training_data[index]
。
以下的例子,随机访问 training_data
中的9个图片,并显示。
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
from random import randint
for i in range(1, cols * rows + 1):
sample_idx = randint(0, len(training_data))
img, label = training_data[sample_idx]
figure.add_subplot(rows, cols, i)
plt.title(labels_map[label])
plt.axis("off")
plt.imshow(img.squeeze().numpy(), cmap="gray")
plt.show()
自定义 Dataset¶
通过继承 oneflow.utils.data.Dataset 可以实现自定义 Dataset
,自定义 Dataset
同样可以配合下一节介绍的 Dataloader
使用,简化数据处理的流程。
以下的例子展示了如何实现一个自定义 Dataset
,它的关键步骤是:
- 继承
oneflow.utils.data.Dataset
- 实现类的
__len__
方法,返回结果通常为该数据集中的样本数量 - 实现类的
__getitem__
方法,它的返回值对应了用户(或框架)调用dataset_obj[idx]
时得到的结果
import numpy as np
class CustomDataset(Dataset):
raw_data_x = np.array([[1, 2], [2, 3], [4, 6], [3, 1]], dtype=np.float32)
raw_label = np.array([[8], [13], [26], [9]], dtype=np.float32)
def __init__(self, transform=None, target_transform=None):
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(raw_label)
def __getitem__(self, idx):
x = CustomDataset.raw_data_x[idx]
label = CustomDataset.raw_label[idx]
if self.transform:
x = self.transform(x)
if self.target_transform:
label = self.target_transform(label)
return x, label
custom_dataset = CustomDataset()
print(custom_dataset[0])
print(custom_dataset[1])
输出:
(array([1., 2.], dtype=float32), array([8.], dtype=float32))
(array([2., 3.], dtype=float32), array([13.], dtype=float32))
使用 DataLoader¶
利用 Dataset 可以一次获取一条样本数据。但是在训练中,往往有其它的需求,如:一次读取 batch size 份数据;1轮 epoch 训练后,数据重新打乱(reshuffle)等。
这时候,使用 DataLoader
即可。 DataLoader
可以将 Dataset
封装为迭代器,方便训练循环中获取数据。如以下例子:
batch_size=64
: 指定一次迭代返回的数据 batch sizeshuffle
:是否要随机打乱数据的顺序
from oneflow.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
x, label = next(iter(train_dataloader))
print(f"shape of x:{x.shape}, shape of label: {label.shape}")
输出:
shape of x:flow.Size([64, 1, 28, 28]), shape of label: flow.Size([64])
img = x[0].squeeze().numpy()
label = label[0]
plt.imshow(img, cmap="gray")
plt.show()
print(label)
输出:(随机输出一张图片)
tensor(9, dtype=oneflow.int64)
自然我们也可以在训练的循环中,使用 DataLoader
迭代器:
for x, label in train_dataloader:
print(x.shape, label.shape)
# training...