跳转至

数据并行训练

常见的分布式并行策略 一文中介绍了数据并行的特点。

OneFlow 提供了 oneflow.nn.parallel.DistributedDataParallel 模块及 launcher,可以让用户几乎不用对单机单卡脚本做修改,就能地进行数据并行训练。

可以用以下命令快速体验 OneFlow 的数据并行:

wget https://docs.oneflow.org/master/code/parallelism/ddp_train.py #下载脚本
python3 -m oneflow.distributed.launch --nproc_per_node 2 ./ddp_train.py #数据并行训练

输出:

50/500 loss:0.004111831542104483
50/500 loss:0.00025336415274068713
...
500/500 loss:6.184563972055912e-11
500/500 loss:4.547473508864641e-12

w:tensor([[2.0000],
        [3.0000]], device='cuda:1', dtype=oneflow.float32,
       grad_fn=<accumulate_grad>)

w:tensor([[2.0000],
        [3.0000]], device='cuda:0', dtype=oneflow.float32,
       grad_fn=<accumulate_grad>)

代码

点击以下 “Code” 可以展开以上运行脚本的代码。

Code
import oneflow as flow
from oneflow.nn.parallel import DistributedDataParallel as ddp

train_x = [
    flow.tensor([[1, 2], [2, 3]], dtype=flow.float32),
    flow.tensor([[4, 6], [3, 1]], dtype=flow.float32),
]
train_y = [
    flow.tensor([[8], [13]], dtype=flow.float32),
    flow.tensor([[26], [9]], dtype=flow.float32),
]


class Model(flow.nn.Module):
    def __init__(self):
        super().__init__()
        self.lr = 0.01
        self.iter_count = 500
        self.w = flow.nn.Parameter(flow.tensor([[0], [0]], dtype=flow.float32))

    def forward(self, x):
        x = flow.matmul(x, self.w)
        return x


m = Model().to("cuda")
m = ddp(m)
loss = flow.nn.MSELoss(reduction="sum")
optimizer = flow.optim.SGD(m.parameters(), m.lr)

for i in range(0, m.iter_count):
    rank = flow.env.get_rank()
    x = train_x[rank].to("cuda")
    y = train_y[rank].to("cuda")

    y_pred = m(x)
    l = loss(y_pred, y)
    if (i + 1) % 50 == 0:
        print(f"{i+1}/{m.iter_count} loss:{l}")

    optimizer.zero_grad()
    l.backward()
    optimizer.step()

print(f"\nw:{m.w}")

可以发现,数据并行的训练代码,与单机单卡脚本的不同只有2个:

  • 使用 DistributedDataParallel 处理一下 module 对象(m = ddp(m))
  • 使用 get_rank 获取当前设备编号,并针对设备分发数据

然后使用 launcher 启动脚本,把剩下的一切都交给 OneFlow,让分布式训练,像单机单卡训练一样简单:

python3 -m oneflow.distributed.launch --nproc_per_node 2 ./ddp_train.py

DistributedSampler

本文为了简化问题,突出 DistributedDataParallel,因此使用的数据是手工分发的。在实际应用中,可以直接使用 DistributedSampler 配合数据并行使用。

DistributedSampler 会在每个进程中实例化 Dataloader,每个 Dataloader 实例会加载完整数据的一部分,自动完成数据的分发。

Back to top