DP Pytorch实现

简单总结:
  1. forward分为两步:scatter数据 以及 replicate模型
  • scatter数据 用 to函数
  • replicate模型 先allocate(这里用NCCL.broadcast实现), 再浅拷贝, 最后一一赋值
  1. parallel_apply
  • 核心就是启动多个线程,同时model.forward()
  • 最后gather输出 (用Tensor.copy_实现, 先把一个大tensor spilt再一一赋值, 隐式concat)
  1. 输出被gather后,在GPU0上算loss
  1. scatter loss(实现一样 to函数)
  1. 分别计算梯度,gather 梯度(使用reduce_add)实现
  1. 更新模型
  1. 重复整个过程
 

DDP 实现

  1. 使用torch.mutiprocessing创建进程