跳转至

Dataset 与 DataLoader

OneFlow 的 DatasetDataLoader 的行为与 PyTorch 的是一致的,都是为了让数据集管理与模型训练解耦。

Dataset 类用于定义如何读取数据。对于常见的计算机视觉数据集(如 FashionMNIST),可以直接使用 FlowVision 库的 datasets 模块提供的数据集类,可以帮助我们自动下载并加载一些流行的数据集,这些类都间接继承了 Dataset 类。对于其他数据集,可以通过继承 Dataset 类来自定义数据集类。

DataLoaderDataset 封装为迭代器,方便训练时遍历并操作数据。

import matplotlib.pyplot as plt

import oneflow as flow
import oneflow.nn as nn
from oneflow.utils.data import Dataset
from flowvision import datasets
from flowvision import transforms
上面导入的 flowvision.transforms 提供了一些对图像数据进行变换的操作(如 ToTensor 可以将 PIL 图像或 NumPy 数组转换为张量),可以在数据集类中直接使用。

使用 FlowVision 加载数据集

以下的例子展示了如何使用 flowvision.datasets 加载 FashionMNIST 数据集。

我们向 FashionMNIST 类传入以下参数: - root:数据集存放的路径 - trainTrue 代表下载训练集、False 代表下载测试集 - download=True: 如果 root 路径下数据集不存在,则从网络下载 - transforms:指定的数据转换方式

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
    source_url="https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/",
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
    source_url="https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/",
)

第一次运行,会下载数据集,输出:

Downloading https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/train-images-idx3-ubyte.gz
Downloading https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
26422272/? [00:02<00:00, 8090800.72it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/train-labels-idx1-ubyte.gz
Downloading https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
29696/? [00:00<00:00, 806948.09it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/t10k-images-idx3-ubyte.gz
Downloading https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
4422656/? [00:00<00:00, 19237994.98it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/t10k-labels-idx1-ubyte.gz
Downloading https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
6144/? [00:00<00:00, 152710.85it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

遍历数据

Dataset 对象,可以像 list 一样,用下标索引,比如 training_data[index]。 以下的例子,随机访问 training_data 中的9个图片,并显示。

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
from random import randint
for i in range(1, cols * rows + 1):
    sample_idx = randint(0, len(training_data))
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze().numpy(), cmap="gray")
plt.show()

fashionMNIST

自定义 Dataset

通过继承 oneflow.utils.data.Dataset 可以实现自定义 Dataset,自定义 Dataset 同样可以配合下一节介绍的 Dataloader 使用,简化数据处理的流程。

以下的例子展示了如何实现一个自定义 Dataset,它的关键步骤是:

  • 继承 oneflow.utils.data.Dataset
  • 实现类的 __len__ 方法,返回结果通常为该数据集中的样本数量
  • 实现类的 __getitem__ 方法,它的返回值对应了用户(或框架)调用 dataset_obj[idx] 时得到的结果
import numpy as np
class CustomDataset(Dataset):
    raw_data_x = np.array([[1, 2], [2, 3], [4, 6], [3, 1]], dtype=np.float32)
    raw_label = np.array([[8], [13], [26], [9]], dtype=np.float32)

    def __init__(self, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(raw_label)

    def __getitem__(self, idx):
        x = CustomDataset.raw_data_x[idx]
        label = CustomDataset.raw_label[idx]
        if self.transform:
            x = self.transform(x)
        if self.target_transform:
            label = self.target_transform(label)
        return x, label

custom_dataset = CustomDataset()
print(custom_dataset[0])
print(custom_dataset[1])

输出:

(array([1., 2.], dtype=float32), array([8.], dtype=float32))
(array([2., 3.], dtype=float32), array([13.], dtype=float32))

使用 DataLoader

利用 Dataset 可以一次获取一条样本数据。但是在训练中,往往有其它的需求,如:一次读取 batch size 份数据;1轮 epoch 训练后,数据重新打乱(reshuffle)等。

这时候,使用 DataLoader 即可。 DataLoader 可以将 Dataset 封装为迭代器,方便训练循环中获取数据。如以下例子:

  • batch_size=64 : 指定一次迭代返回的数据 batch size
  • shuffle :是否要随机打乱数据的顺序
from oneflow.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
x, label = next(iter(train_dataloader))
print(f"shape of x:{x.shape}, shape of label: {label.shape}")

输出:

shape of x:flow.Size([64, 1, 28, 28]), shape of label: flow.Size([64])
img = x[0].squeeze().numpy()
label = label[0]
plt.imshow(img, cmap="gray")
plt.show()
print(label)

输出:(随机输出一张图片)

dataloader item

tensor(9, dtype=oneflow.int64)

自然我们也可以在训练的循环中,使用 DataLoader 迭代器:

for x, label in train_dataloader:
    print(x.shape, label.shape)
    # training...
Back to top