简单总结:
- forward分为两步:scatter数据 以及 replicate模型
- scatter数据 用 to函数
- replicate模型 先allocate(这里用NCCL.broadcast实现), 再浅拷贝, 最后一一赋值
- parallel_apply
- 核心就是启动多个线程,同时model.forward()
- 最后gather输出 (用Tensor.copy_实现, 先把一个大tensor spilt再一一赋值, 隐式concat)
- 输出被gather后,在GPU0上算loss
- scatter loss(实现一样 to函数)
- 分别计算梯度,gather 梯度(使用reduce_add)实现
- 更新模型
- 重复整个过程
DDP 实现
- 使用torch.mutiprocessing创建进程