I had long been curious about how the various training-parallelism schemes are actually implemented, but frameworks like DeepSpeed are too large to dig through. I finally came across a Hugging Face training implementation written for teaching purposes. This post records what the code for each kind of parallelism looks like in that small framework.
DP
The idea of DP: every node keeps a full copy of the weights, and the dataset is sharded across these nodes. After each backward pass, an all-reduce synchronizes the gradients before the weights are updated.
The DataLoader:
self.sampler = DistributedSampler(
    self.tokenized_dataset,
    num_replicas=pgm.process_group_manager.dp_world_size,
    rank=pgm.process_group_manager.dp_rank,
    seed=seed,
    shuffle=False,
)
super().__init__(
    self.tokenized_dataset,
    batch_size=micro_batch_size,
    collate_fn=self.collate_batch,
    pin_memory=True,
    num_workers=num_workers,
    sampler=self.sampler,
    shuffle=False,
)
The key point is to plug a DistributedSampler into an otherwise ordinary DataLoader, so that in a multi-process setting the DataLoader hands each GPU an even, non-overlapping share of the data.
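As a standalone illustration of what the sampler does (a toy independent of the framework; no process group is needed when num_replicas and rank are passed explicitly):

from torch.utils.data import DistributedSampler

# Each DP rank sees a disjoint, evenly sized slice of the dataset.
dataset = list(range(10))
for rank in range(2):
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
    print(rank, list(sampler))
# 0 [0, 2, 4, 6, 8]
# 1 [1, 3, 5, 7, 9]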
Implementation 1: Naive
class DataParallelNaive(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.require_backward_grad_sync = True
        self.register_backward_hook(self._allreduce_grads)

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def register_backward_hook(self, hook):
        for p in self.module.parameters():
            if p.requires_grad is True:
                p.register_hook(hook)

    def _allreduce_grads(self, grad):
        if self.require_backward_grad_sync:
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.dp_group)
            grad /= pgm.process_group_manager.dp_world_size
        return grad
In my view the two most important pieces are:
- A hook is registered on every parameter: once a parameter's gradient has been computed, the hook checks whether this is the last backward pass of the accumulation window; if so, it runs an all-reduce to synchronize the gradient (see the sketch of the driving loop below).
- The whole DataParallelNaive class is an nn.Module, which makes it convenient to call. In essence it just registers hooks at model-initialization time; you could achieve the same thing without the wrapper class (for example, in some init-style function).
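For context, a minimal sketch (not code from the framework) of how a training loop might drive require_backward_grad_sync under gradient accumulation, following the description above; base_model, dataloader, compute_loss and optimizer are placeholders:

# Sketch only: skip the all-reduce on all but the last micro-batch of a step,
# so the gradient synchronization runs once per optimizer step.
model = DataParallelNaive(base_model)
for micro_batches in dataloader:              # one optimizer step per outer iteration
    for i, batch in enumerate(micro_batches):
        model.require_backward_grad_sync = (i == len(micro_batches) - 1)
        loss = compute_loss(model(batch))
        loss.backward()
    optimizer.step()
    optimizer.zero_grad()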
Implementation 2: Bucket
Motivation: in the naive version the communication granularity is a single parameter, i.e. each parameter triggers its own all-reduce as soon as its gradient is ready, so the communication overhead is large. If we instead group several parameters into a Bucket and only communicate once all parameters in the bucket have their gradients, the number of communication calls drops substantially (a rough estimate of the savings follows the step list below). The core steps of the implementation:
- Build a Bucket that stores the parameters and their gradients
- Build a BucketManager that creates and schedules the buckets
- Build a DataParallel wrapper around the model
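Before stepping through the code, a rough feel for the savings (a standalone toy calculation, not part of the framework; fp32 gradients and a 25 MB bucket cap assumed):

import torch.nn as nn

# Toy estimate: the naive scheme launches one all-reduce per parameter tensor,
# the bucketed scheme roughly one per bucket (parameters are not split across
# buckets, so the real count can be slightly higher).
model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(24)])
bucket_cap_mb, bytes_per_grad = 25, 4                         # fp32 gradients
bucket_size = bucket_cap_mb * 1024 * 1024 // bytes_per_grad   # elements per bucket

naive_allreduces = sum(1 for _ in model.parameters())         # 48
total_numel = sum(p.numel() for p in model.parameters())
bucketed_allreduces = -(-total_numel // bucket_size)          # ceil division -> 4
print(naive_allreduces, bucketed_allreduces)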
Step1:
class Bucket:
    def __init__(self, params: List[nn.Parameter], grad_data: torch.Tensor, process_group: dist.ProcessGroup):
        self.params = set(params)
        self.params_with_grad_ready = set()
        self.process_group = process_group
        self.group_size = dist.get_world_size(group=self.process_group)
        self.grad_data = grad_data
        self.handle = None  # handle of the in-flight async all-reduce (see sync_grad below)
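The excerpt above stops at the constructor; the methods the rest of the post relies on (sync_grad is shown further down, and mark_param_as_ready / wait are called from the hooks) are omitted. A rough sketch of what they plausibly look like, inferred from how they are used here rather than copied from the framework:

    # (continuing class Bucket; a sketch inferred from usage, not the framework's exact code)
    def mark_param_as_ready(self, param: nn.Parameter) -> None:
        # Called by the backward hook once this param's gradient is accumulated
        # into grad_data; when the whole bucket is ready, launch the async all-reduce.
        self.params_with_grad_ready.add(param)
        if len(self.params_with_grad_ready) == len(self.params):
            self.sync_grad()

    def wait(self) -> None:
        # Block until the in-flight async all-reduce (if any) has completed.
        if self.handle is not None:
            self.handle.wait()
            self.handle = None

    def reset(self) -> None:
        # Prepare the bucket for the next backward pass.
        self.params_with_grad_ready.clear()
        self.grad_data.zero_()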
Step2:
class BucketManager:
    def __init__(self, params: List[nn.Parameter], bucket_size: int, process_group: dist.ProcessGroup, grad_type: torch.dtype, device: str):
        self.params = list(params)
        self.bucket_size = bucket_size
        self.process_group = process_group
        self.grad_type = grad_type
        self.device = device
        self.params_to_bucket_location = {}
        self.grad_data_list = []
        self.buckets = []
        self._post_backward_callback_set = False
        self._init_buckets()

    def _init_buckets(self):
        cur_bucket_idx = 0
        cur_bucket_size = 0
        # Distribute all params in different buckets
        for param in self.params:
            if not param.requires_grad:
                continue
            if cur_bucket_size == 0:
                self.params_to_bucket_location[param] = (0, param.numel(), cur_bucket_idx)
                cur_bucket_size = param.numel()
                continue
            if cur_bucket_size + param.numel() > self.bucket_size:
                cur_bucket_idx += 1
                self.params_to_bucket_location[param] = (0, param.numel(), cur_bucket_idx)
                cur_bucket_size = param.numel()
            else:
                self.params_to_bucket_location[param] = (cur_bucket_size, cur_bucket_size + param.numel(), cur_bucket_idx)
                cur_bucket_size += param.numel()
        buckets_sizes = [0] * (cur_bucket_idx + 1)
        buckets_to_params = [[] for _ in range(cur_bucket_idx + 1)]
        for param, (_, end, bucket_idx) in self.params_to_bucket_location.items():
            buckets_sizes[bucket_idx] = max(buckets_sizes[bucket_idx], end)
            buckets_to_params[bucket_idx].append(param)
        for i in range(len(buckets_sizes)):
            self.grad_data_list.append(torch.zeros(buckets_sizes[i], dtype=self.grad_type, device=self.device))
            self.buckets.append(Bucket(buckets_to_params[i], self.grad_data_list[i], self.process_group))
        # TODO: understand why reverse
        for param in self.params[::-1]:
            if not param.requires_grad:
                continue
            start_idx, end_idx, bucket_idx = self.params_to_bucket_location[param]
            param.main_grad = self._get_view_from_tensor(self.grad_data_list[bucket_idx], param.shape, start_idx, end_idx)

    def _get_view_from_tensor(self, tensor: torch.Tensor, shape: torch.Size, start: int, end: int):
        return tensor[start:end].view(shape)
The core reason this reduces the number of communication calls is that each entry of self.grad_data_list is one large flat tensor, which is then view-ed into small slices that back the individual parameters' gradients (param.main_grad). At all-reduce time, only that one large tensor per bucket needs to be communicated:
# In the Bucket class:
    def sync_grad(self):
        assert self.handle is None
        self.grad_data /= self.group_size
        self.handle = dist.all_reduce(self.grad_data, group=self.process_group, async_op=True)
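Similarly, the BucketManager methods used by the hooks in Step 3 (wait and mark_param_as_ready) are not shown above. A plausible sketch, again inferred from usage rather than taken from the framework:

    # (continuing class BucketManager; sketch inferred from usage)
    def wait(self):
        # Wait for every bucket's outstanding all-reduce before main_grad is read.
        for bucket in self.buckets:
            bucket.wait()

    def mark_param_as_ready(self, param: nn.Parameter):
        # Forward the notification to whichever bucket owns this parameter.
        _, _, bucket_idx = self.params_to_bucket_location[param]
        self.buckets[bucket_idx].mark_param_as_ready(param)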
Step3:
class DataParallelBucket(nn.Module):
    def __init__(self, module, device, bucket_cap_mb=25, grad_type=torch.float32):
        super().__init__()
        self.module = module
        self.need_grad_sync = True
        grad_size = 2  # bytes per gradient element (assumes 2-byte bf16/fp16 gradients)
        bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size  # number of elements per bucket
        self.bucket_manager = BucketManager(module.parameters(), bucket_size, pgm.process_group_manager.dp_group, grad_type, device)
        self.register_backward_hook()
        self._post_backward_callback_set = False  # read by the hook below

    def register_backward_hook(self):
        self.grad_acc = []
        for param in self.module.parameters():
            # expand_as returns a non-leaf view whose grad_fn leads to the
            # parameter's AccumulateGrad node; the hook is attached to that node.
            param_tmp = param.expand_as(param)
            grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]
            grad_acc_fn.register_hook(self._make_grad_hook(param, self.bucket_manager))
            self.grad_acc.append(grad_acc_fn)  # keep a reference so the node is not garbage-collected

    def _make_grad_hook(self, param: nn.Parameter, bucket_manager: BucketManager):
        def hook(*unused):
            if param.requires_grad:
                assert param.grad is not None
                # Accumulate the gradient manually into the bucket's flat buffer
                param.main_grad.add_(param.grad.data)
                param.grad = None
                if self.need_grad_sync:
                    # Queue the post-backward callback only once per backward pass
                    if not self._post_backward_callback_set:
                        torch.autograd.Variable._execution_engine.queue_callback(self._post_backward)
                        self._post_backward_callback_set = True
                    bucket_manager.mark_param_as_ready(param)
        return hook

    def _post_backward(self):
        self.bucket_manager.wait()
        self._post_backward_callback_set = False
        for param in self.module.parameters():
            if param.requires_grad:
                param.grad = param.main_grad.to(param.dtype)
The core here is the hook-registration logic. Compared with the naive version, which registers hooks directly on the parameters, extra machinery is needed so that gradients can be managed per bucket and so that gradients and parameters of different precisions can work together; that is exactly what register_backward_hook, _make_grad_hook and _post_backward above do.
First, why can't we simply use the following code instead of going through all these hoops?
def register_backward_hook(self):
    for param in self.module.parameters():
        param.register_hook(partial(self.hook, param, self.bucket_manager))

def hook(self, param, bucket_manager, grad):
    if param.requires_grad:
        param.main_grad.add_(grad)
        if self.need_grad_sync:
            # Only set once
            if not self._post_backward_callback_set:
                torch.autograd.Variable._execution_engine.queue_callback(self._post_backward)
                self._post_backward_callback_set = True
            bucket_manager.mark_param_as_ready(param)
    return None
After repeated experiments, my tentative conclusion is that this simpler approach only works some of the time. My understanding of the reason: the timing of param.register_hook and grad_acc_fn.register_hook differs. The former is called after one of the gradient contributions for the parameter has been computed, before it has been written into .grad; the latter is only called once all gradient contributions for the parameter have been computed and accumulated into .grad.
Consider the following scenario: when a parameter is shared, its gradient is built up by accumulation. With the first approach, after the parameter's first gradient contribution is computed and before it is written to .grad, the hook fires and marks the parameter as ready; when the second contribution is computed, the hook fires again and marks it ready a second time, which breaks the bookkeeping. The second approach only marks the parameter as ready after all of its gradient contributions have been computed and accumulated, which avoids the problem. The second piece worth unpacking is this line:
torch.autograd.Variable._execution_engine.queue_callback(self._post_backward)
This line mainly works around the fact that PyTorch's built-in autograd accumulation does not support mixing gradient precisions. A function registered through torch.autograd.Variable._execution_engine.queue_callback is invoked only after the gradients of all parameters have been computed. With that in mind, re-read the comment from the code:
This hook serves two main purposes:
1. PyTorch does not natively support gradient accumulation with mixed precision.
2. After gradient accumulation, it flags parameters as ready for synchronization.
Points 1 and 2 correspond exactly to the two parts of the code above.
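A tiny standalone example (no distributed setup needed) of the queue_callback behavior used here: a callback queued from inside a backward hook only runs once the whole backward pass has finished, by which point every .grad has been written:

import torch

x = torch.ones(3, requires_grad=True)
y = torch.ones(3, requires_grad=True)

def post_backward():
    # Runs after the entire backward pass: both gradients are already populated.
    print("callback:", x.grad, y.grad)

def queue_once(grad):
    torch.autograd.Variable._execution_engine.queue_callback(post_backward)
    print("hook: x.grad is still", x.grad)  # the hook sees the incoming grad before it lands in .grad
    return grad

x.register_hook(queue_once)
((x * 2).sum() + (y * 3).sum()).backward()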
TP
The idea of TP: split a weight tensor by columns or rows and place the shards on different GPUs. This lowers not only the memory for the weights but also for the gradients, activations and optimizer states.
Communication primitives:
TP relies on three communication primitives, implemented here as the autograd functions Copy (a broadcast-style identity in the forward pass), Reduce and Gather:
class Reduce(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        if pgm.process_group_manager.tp_world_size == 1:
            return input
        # Based on the official PyTorch implementation
        output = input.clone()
        dist.all_reduce(output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class Copy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, grad_output):
        if pgm.process_group_manager.tp_world_size == 1:
            return grad_output
        # Based on the official PyTorch implementation
        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
        return grad_output


class Gather(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        if pgm.process_group_manager.tp_world_size == 1:
            return input
        last_dim = input.dim() - 1
        # Based on the official PyTorch implementation
        input = input.contiguous()
        # all_gather needs a list of per-rank output buffers in addition to the local input
        output = [torch.empty_like(input) for _ in range(pgm.process_group_manager.tp_world_size)]
        output[pgm.process_group_manager.tp_rank] = input
        dist.all_gather(output, input, group=pgm.process_group_manager.tp_group)
        final_output = torch.cat(output, dim=last_dim)
        return final_output

    @staticmethod
    def backward(ctx, grad_output):
        if pgm.process_group_manager.tp_world_size == 1:
            return grad_output
        last_dim = grad_output.dim() - 1
        assert grad_output.size()[last_dim] % pgm.process_group_manager.tp_world_size == 0
        last_dim_size = grad_output.size()[last_dim] // pgm.process_group_manager.tp_world_size
        chunks = torch.split(grad_output, last_dim_size, dim=last_dim)
        return chunks[pgm.process_group_manager.tp_rank].contiguous()
Sharding: column-wise and row-wise splits
class ColumnParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, bias: bool = False, gather_output: bool = False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.gather_output = gather_output
        self.tp_world_size = pgm.process_group_manager.tp_world_size
        self.tp_rank = pgm.process_group_manager.tp_rank
        assert out_features % self.tp_world_size == 0
        self.output_size_per_partition = out_features // self.tp_world_size
        # x: (batch_size, in_features), w: (in_features, out_features)
        # torch linear computes x @ w^T
        self.weight = nn.Parameter(torch.Tensor(self.output_size_per_partition, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(self.output_size_per_partition))
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        if self.tp_world_size == 1:
            # U(-sqrt(k), sqrt(k))
            k = 1 / self.weight.size(1)
            bound = math.sqrt(k)
            torch.nn.init.uniform_(self.weight, -bound, bound)
            return
        master_weight = torch.empty(self.out_features, self.in_features, dtype=self.weight.dtype, requires_grad=False)
        # Calculate the bound from the master weight's input dimension: U(-sqrt(k), sqrt(k))
        k = 1 / master_weight.size(1)
        bound = math.sqrt(k)
        torch.nn.init.uniform_(master_weight, -bound, bound)
        weight_list = torch.split(master_weight, self.output_size_per_partition, dim=0)
        self.weight.data = weight_list[self.tp_rank].contiguous()

    def forward(self, input):
        # Copy is an identity in the forward pass; in the backward pass it
        # all-reduces the input gradient, since every TP rank contributes a
        # partial gradient w.r.t. the replicated input.
        input_parallel = Copy.apply(input)
        output = F.linear(input_parallel, self.weight, self.bias)
        if self.gather_output:
            output = Gather.apply(output)
        return output
class RowParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, bias: bool = False, gather_output: bool = False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.gather_output = gather_output
        self.tp_world_size = pgm.process_group_manager.tp_world_size
        self.tp_rank = pgm.process_group_manager.tp_rank
        assert in_features % self.tp_world_size == 0
        self.in_size_per_partition = in_features // self.tp_world_size
        # x: (batch_size, in_features), w: (in_features, out_features)
        # torch linear computes x @ w^T
        self.weight = nn.Parameter(torch.Tensor(self.out_features, self.in_size_per_partition))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(self.out_features))
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        if self.tp_world_size == 1:
            # U(-sqrt(k), sqrt(k))
            k = 1 / self.weight.size(1)
            bound = math.sqrt(k)
            torch.nn.init.uniform_(self.weight, -bound, bound)
            return
        master_weight = torch.empty(self.out_features, self.in_features, dtype=self.weight.dtype, requires_grad=False)
        # Calculate the bound from the master weight's input dimension: U(-sqrt(k), sqrt(k))
        k = 1 / master_weight.size(1)
        bound = math.sqrt(k)
        torch.nn.init.uniform_(master_weight, -bound, bound)
        weight_list = torch.split(master_weight, self.in_size_per_partition, dim=1)
        self.weight.data = weight_list[self.tp_rank].contiguous()

    def forward(self, input):
        output_parallel = F.linear(input, self.weight)
        output = Reduce.apply(output_parallel)
        return output if self.bias is None else output + self.bias
The common pattern in both cases: when the layer is constructed, it allocates only this GPU's own shard of the weight, and then relies on the communication primitives above to keep the activations and gradients consistent across ranks.
Sharding inside the Transformer:
MLP:

So for the two MLP weights, the first is split by columns and the second by rows:
module_linear_name_stype_mapping_list = [
    ("mlp", "c_fc", "column"),
    ("mlp", "c_proj", "row"),
]
for layer in model.transformer.h:
    for module_name, linear_proj_name, style in module_linear_name_stype_mapping_list:
        _replace_module(getattr(layer, module_name), linear_proj_name, style)
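A quick single-process sanity check of why "column split, then row split" works: each rank computes a partial product from its own shards, and summing the partials (which is what the all-reduce in RowParallelLinear does) recovers the full result. Everything below is a toy, independent of the framework:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
tp = 2
X = torch.randn(4, 8)      # (batch, in_features)
W1 = torch.randn(8, 16)    # c_fc-like weight, split by columns across ranks
W2 = torch.randn(16, 8)    # c_proj-like weight, split by rows across ranks

full = F.gelu(X @ W1) @ W2
partials = []
for rank in range(tp):
    cols = slice(rank * (16 // tp), (rank + 1) * (16 // tp))
    # Each "rank": local column shard, elementwise activation, local row shard
    partials.append(F.gelu(X @ W1[:, cols]) @ W2[cols, :])

print(torch.allclose(sum(partials), full, atol=1e-5))  # True: the all-reduce is just this sum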
Attention:

As the figure shows, the Q, K and V projection matrices are split by columns, and the final output projection is split by rows:
module_linear_name_stype_mapping_list = [
    ("attn", "q_attn", "column"),
    ("attn", "k_attn", "column"),
    ("attn", "v_attn", "column"),
    ("attn", "c_proj", "row"),
]
for layer in model.transformer.h:
    for module_name, linear_proj_name, style in module_linear_name_stype_mapping_list:
        _replace_module(getattr(layer, module_name), linear_proj_name, style)
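The helper _replace_module used in these two snippets is not shown in this excerpt. A hypothetical version, just to make the flow concrete (the framework's real helper may well differ, for example it might shard the existing weights instead of re-initializing them):

def _replace_module(parent_module, linear_name, style):
    # Hypothetical sketch: swap an existing linear projection for its
    # tensor-parallel counterpart with the same overall dimensions.
    old = getattr(parent_module, linear_name)
    in_features, out_features = old.in_features, old.out_features
    has_bias = old.bias is not None
    if style == "column":
        new = ColumnParallelLinear(in_features, out_features, bias=has_bias)
    elif style == "row":
        new = RowParallelLinear(in_features, out_features, bias=has_bias)
    else:
        raise ValueError(f"unknown tensor-parallel style: {style}")
    setattr(parent_module, linear_name, new)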
PP
The idea of PP: when a model is too large, split it by layers into several segments and place each segment on a different GPU.
Communication:
# PP communication
STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"

def pipeline_communication(operation, device, dtype, tensor=None, shape=None):
    global STEP
    global VERBOSE
    if operation == "recv_forward":
        if pgm.process_group_manager.pp_is_first_stage:
            return None
        tensor = torch.empty(shape, requires_grad=True, device=device, dtype=dtype)
        src = pgm.process_group_manager.pp_prev_rank
    elif operation == "send_forward":
        if pgm.process_group_manager.pp_is_last_stage:
            return
        dest = pgm.process_group_manager.pp_next_rank
    elif operation == "recv_backward":
        if pgm.process_group_manager.pp_is_last_stage:
            return None
        tensor = torch.empty(shape, requires_grad=True, device=device, dtype=dtype)
        src = pgm.process_group_manager.pp_next_rank
    elif operation == "send_backward":
        if pgm.process_group_manager.pp_is_first_stage:
            return
        dest = pgm.process_group_manager.pp_prev_rank
    is_send = operation.startswith("send")
    peer_rank = dest if is_send else src
    # TODO: why process group is not needed here
    op = dist.P2POp(dist.isend if is_send else dist.irecv, tensor, peer_rank)
    if VERBOSE:
        print(f"{operation} | {'sending' if is_send else 'receiving'} {operation.split('_')[1]} "
              f"{pgm.process_group_manager.pp_rank} {'→' if is_send else '←'} {peer_rank} | "
              f"STEP:{STEP} | RANK:{pgm.process_group_manager.pp_rank}")
    [req.wait() for req in dist.batch_isend_irecv([op])]
    torch.cuda.synchronize()
    if VERBOSE: STEP += 1
    return tensor if not is_send else None
# End of PP communication
Unlike TP's collective communication, PP communication is entirely point-to-point.
Model construction:
class PipelineParallel(nn.Module):
    def __init__(self, model, config):
        super().__init__()
        layer_distribution = self._distribute_layers(config.n_layer)
        self.embedding = model.embedding if pgm.process_group_manager.pp_is_first_stage else nn.Identity()
        self.layers = nn.ModuleList([model.transformer.h[i] for i in layer_distribution])
        self.final_norm = model.transformer.ln_f if pgm.process_group_manager.pp_is_last_stage else nn.Identity()
        self.final_proj = model.lm_head if pgm.process_group_manager.pp_is_last_stage else nn.Identity()

    def _distribute_layers(self, n_layer):
        layer_per_gpu = [n_layer // pgm.process_group_manager.pp_world_size + (1 if i < n_layer % pgm.process_group_manager.pp_world_size else 0) for i in range(pgm.process_group_manager.pp_world_size)]
        start_layer = sum(layer_per_gpu[:pgm.process_group_manager.pp_rank])
        return list(range(start_layer, start_layer + layer_per_gpu[pgm.process_group_manager.pp_rank]))

    def forward(self, input_ids, hidden_states):
        x = hidden_states if hidden_states is not None else input_ids
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x)
        x = self.final_norm(x)
        return self.final_proj(x)

    '''
    Backward flow between two adjacent stages A -> B:
    1. Run loss.backward() on the last stage, which yields d(loss)/d(output_B).
    2. Call stage B's backward with output_A, output_B and d(loss)/d(output_B);
       this produces d(loss)/d(output_A), i.e. output_A.grad.
    3. Feed output_A.grad to stage A as its backward input and keep propagating.
    '''
    def backward(self, input_tensor, output_tensor, output_tensor_grad):
        if input_tensor is not None:
            input_tensor.retain_grad()
        if output_tensor_grad is None:
            output_tensor_grad = torch.ones_like(output_tensor, memory_format=torch.preserve_format)
        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad, retain_graph=False, create_graph=False)
        return input_tensor.grad if input_tensor is not None else None
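A single-process toy of this stage-wise backward (plain PyTorch only, independent of the framework): detach the boundary activation as if it had arrived over the network, run the downstream stage, then chain the gradient back exactly as PipelineParallel.backward does.

import torch
import torch.nn as nn

stage_a, stage_b = nn.Linear(4, 4), nn.Linear(4, 1)   # two "pipeline stages"
x = torch.randn(2, 4)

out_a = stage_a(x)
recv = out_a.detach().requires_grad_(True)   # what stage B "receives"
loss = stage_b(recv).sum()

loss.backward()                # stage B's backward produces recv.grad, i.e. d(loss)/d(out_a)
out_a.backward(recv.grad)      # feed it to stage A as grad_tensors
print(stage_a.weight.grad is not None)       # True: gradients flowed across the cut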
A few things here strike me as particularly neat:
- The _distribute_layers function is elegantly written and cleanly solves the problem of dividing the layers as evenly as possible (see the quick check below).
- The per-stage model is built cleverly: using nn.Identity() for the pieces a stage does not own means forward needs no extra branching.
- backward takes some careful thought to understand.
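A quick check of the layer-splitting arithmetic outside the framework, 13 layers over 4 stages (the process-group lookups are replaced by plain arguments):

def distribute(n_layer, pp_world_size, pp_rank):
    # Same arithmetic as _distribute_layers above.
    layer_per_gpu = [n_layer // pp_world_size + (1 if i < n_layer % pp_world_size else 0)
                     for i in range(pp_world_size)]
    start = sum(layer_per_gpu[:pp_rank])
    return list(range(start, start + layer_per_gpu[pp_rank]))

for rank in range(4):
    print(rank, distribute(13, 4, rank))
# 0 [0, 1, 2, 3]
# 1 [4, 5, 6]
# 2 [7, 8, 9]
# 3 [10, 11, 12]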
AFAB:
def train_step_pipeline_afab(model, dataloader, tensor_shape, device, dtype):
    logging_loss = 0.0
    input_tensors, output_tensors = [], []
    requires_grad_sync = pgm.process_group_manager.dp_world_size > 1

    # === All Forward ===
    for _ in range(dataloader.grad_acc_steps):
        input_tensor = pipeline_communication("recv_forward", device=device, dtype=dtype, shape=tensor_shape)
        batch = next(dataloader)
        batch['hidden_states'] = input_tensor.to(device) if input_tensor is not None else input_tensor
        output_tensor = model.forward(input_ids=batch["input_ids"].to(device), hidden_states=batch["hidden_states"])
        pipeline_communication("send_forward", device=device, dtype=dtype, tensor=output_tensor)
        targets = batch["target_ids"].to(device)
        if pgm.process_group_manager.pp_is_last_stage:
            output_tensor = F.cross_entropy(output_tensor.view(-1, output_tensor.size(-1)), targets.view(-1), reduction='mean')
            logging_loss += output_tensor.item() / dataloader.grad_acc_steps
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    # === All Backward ===
    for i in range(dataloader.grad_acc_steps):
        if requires_grad_sync:
            is_last_iteration = (i == dataloader.grad_acc_steps - 1)
            model.require_backward_grad_sync = is_last_iteration
        output_tensor_grad = pipeline_communication("recv_backward", device=device, dtype=dtype, shape=tensor_shape)
        input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
        input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
        pipeline_communication("send_backward", device=device, dtype=dtype, tensor=input_tensor_grad)
    return logging_loss
Illustration:

It helps to trace the schedule by hand to fully understand the steps (assume 4 devices):
STEP 1: GPU 0 receives the first micro-batch, while the other GPUs block waiting to receive.

| GPU 0 | GPU 1 | GPU 2 | GPU 3 |
| --- | --- | --- | --- |
| batch 1 | wait for recv | wait for recv | wait for recv |

STEP 2: GPU 0 sends its processed batch 1 to the next stage.
STEP 3: GPU 0 receives the second micro-batch, GPU 1 receives batch 1 processed by GPU 0, and the remaining GPUs still block waiting to receive.

| GPU 0 | GPU 1 | GPU 2 | GPU 3 |
| --- | --- | --- | --- |
| batch 2 | batch 1 | wait for recv | wait for recv |

STEP 4: GPU 0 and GPU 1 both send their outputs to the next stage.
STEP 5: GPU 0 receives the third micro-batch, GPU 1 receives batch 2 processed by GPU 0, GPU 2 receives batch 1 processed by GPU 1, and the last GPU still blocks waiting to receive.

| GPU 0 | GPU 1 | GPU 2 | GPU 3 |
| --- | --- | --- | --- |
| batch 3 | batch 2 | batch 1 | wait for recv |

I'll stop the manual trace here; the remaining details are in the code. The backward phase proceeds analogously.