Picotron Study Notes

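I had long wanted to know how the various training-parallelism techniques are implemented, but frameworks like DeepSpeed are too big to read through. I finally found Picotron, a small training implementation from Hugging Face built for teaching. These notes record what each kind of parallelism looks like in code in this small framework.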

DP

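Concept: every node holds a full copy of the model weights, and the dataset is sharded so that each node is fed a different part of it. After backpropagation, one all-reduce synchronizes the gradients, and then the weights are updated.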

DataLoader:

self.sampler = DistributedSampler(
    self.tokenized_dataset, 
    num_replicas=pgm.process_group_manager.dp_world_size, 
    rank=pgm.process_group_manager.dp_rank, 
    seed=seed,
    shuffle=False
)

super().__init__(
    self.tokenized_dataset,
    batch_size=micro_batch_size,
    collate_fn=self.collate_batch, 
    pin_memory=True, 
    num_workers=num_workers, 
    sampler=self.sampler,
    shuffle=False,
)
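The key is to plug a DistributedSampler into an ordinary DataLoader, so that in a multi-process setting each GPU receives an equal, non-overlapping share of the data. (With the default drop_last=False, DistributedSampler pads by repeating a few samples when the dataset size is not divisible by the number of replicas.)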
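A minimal pure-Python sketch of how DistributedSampler assigns indices when shuffle=False and drop_last=False: the index list is padded to a multiple of num_replicas, then each rank takes a strided slice. `shard_indices` is a hypothetical helper written to mirror that behavior; it is not part of PyTorch.

```python
import math


def shard_indices(dataset_len, num_replicas, rank):
    """Mirror DistributedSampler's index assignment for shuffle=False, drop_last=False."""
    indices = list(range(dataset_len))
    num_samples = math.ceil(dataset_len / num_replicas)  # samples per rank
    total_size = num_samples * num_replicas
    padding = total_size - len(indices)
    if padding > 0:
        # pad by wrapping around to the start so every rank gets the same count
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # strided slice: rank r takes r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]


for rank in range(4):
    print(rank, shard_indices(10, num_replicas=4, rank=rank))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

Indices 0 and 1 show up twice because the 10 samples are padded to 12 so that every rank gets 3; when the length divides evenly, the shards are disjoint and cover the dataset exactly once.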

Implementation 1: Naive

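import torch.distributed as dist
import torch.nn as nn
import picotron.process_group_manager as pgm  # Picotron's global process-group manager

class DataParallelNaive(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
        # Set to False on all but the last gradient-accumulation step to skip the all-reduce
        self.require_backward_grad_sync = True
        self.register_backward_hook(self._allreduce_grads)
    
    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)
    
    def register_backward_hook(self, hook):
        # Overrides nn.Module.register_backward_hook: attach a tensor hook to every trainable parameter
        for p in self.module.parameters():
            if p.requires_grad:
                p.register_hook(hook)
                
    def _allreduce_grads(self, grad):
        if self.require_backward_grad_sync:
            # Sum the gradient across DP ranks, then divide to get the average
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.dp_group)
            grad /= pgm.process_group_manager.dp_world_size
        return grad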
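In my view, two design points matter most:
  1. A hook is registered on every parameter. Once that parameter's gradient has been computed, the hook checks whether gradient synchronization is required (require_backward_grad_sync, i.e. whether this is the last backward pass of gradient accumulation); if so, it all-reduces the gradient and averages it over the DP world size.
  2. The whole DataParallel class is an nn.Module, which makes it convenient to use. In essence it just registers the hooks when the model is wrapped; the same thing could be done without the class, e.g. with an init-style helper function.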
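To see why all-reduce with SUM followed by division by dp_world_size is the right update, here is a pure-Python simulation with no real communication (`all_reduce_mean`, `grad_mse` and `simulate_dp_step` are hypothetical helpers). It assumes a mean-reduced loss and equal-sized shards; under those assumptions the averaged gradient equals the gradient over the full batch.

```python
def all_reduce_mean(per_rank_grads):
    """Simulated all-reduce(SUM) followed by division by world size, as in _allreduce_grads."""
    world_size = len(per_rank_grads)
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    averaged = [s / world_size for s in summed]
    # after an all-reduce every rank holds the same result
    return [list(averaged) for _ in range(world_size)]


def grad_mse(w, xs, ys):
    # d/dw of mean((w * x - y) ** 2) over one shard
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def simulate_dp_step(w, xs, ys, world_size):
    # strided sharding, like DistributedSampler with shuffle=False
    shards = [(xs[r::world_size], ys[r::world_size]) for r in range(world_size)]
    local_grads = [[grad_mse(w, sx, sy)] for sx, sy in shards]
    return all_reduce_mean(local_grads)


xs, ys = [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]
synced = simulate_dp_step(0.5, xs, ys, world_size=2)
print(synced)                 # [[-22.5], [-22.5]]: every rank sees the same gradient
print(grad_mse(0.5, xs, ys))  # -22.5: gradient over the full batch
```

The two shards produce -15.0 and -30.0 locally; only after averaging do both ranks hold the full-batch value, so their weights stay identical after the optimizer step.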

Implementation 2: Bucket

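Motivation: in the naive version, communication happens per parameter: every parameter is all-reduced as soon as its gradient has been computed, which makes the communication overhead large. If several parameters are grouped into a bucket and communicated together once every gradient in the bucket has been computed, the number of communication calls drops substantially.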
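A back-of-the-envelope way to see the benefit, assuming the common latency-bandwidth (alpha-beta) cost model in which every collective pays a fixed start-up cost alpha plus beta per element. The sizes and costs below are made up for illustration, and `bucketize` is a simplified, hypothetical packing helper. The total volume stays the same, but bucketing pays alpha far fewer times.

```python
def comm_time(call_sizes, alpha, beta):
    # alpha-beta model: each collective costs alpha + beta * (number of elements)
    return sum(alpha + beta * n for n in call_sizes)


def bucketize(param_sizes, bucket_size):
    # greedily pack consecutive parameters until the next one would overflow the bucket
    buckets, cur = [], 0
    for n in param_sizes:
        if cur and cur + n > bucket_size:
            buckets.append(cur)
            cur = 0
        cur += n
    if cur:
        buckets.append(cur)
    return buckets


param_sizes = [1000] * 100  # 100 parameters with 1000 elements each (illustrative)
alpha, beta = 10.0, 0.01    # illustrative latency and per-element costs
naive = comm_time(param_sizes, alpha, beta)                        # 100 calls
bucketed = comm_time(bucketize(param_sizes, 25_000), alpha, beta)  # 4 calls
print(naive, bucketed)
```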
 
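Core implementation steps:
  1. Build Bucket, which stores parameters and their gradients
  2. Build BucketManager, which initializes and schedules the buckets
  3. Build DataParallel, which wraps the model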

Step1:

class Bucket:
    def __init__(self, params: List[nn.Parameter], grad_data: torch.Tensor, process_group: dist.ProcessGroup):
        self.params = set(params)
        self.params_with_grad_ready = set()
        self.process_group = process_group
        self.group_size = dist.get_world_size(group=self.process_group)
        self.grad_data = grad_data
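The readiness bookkeeping that the rest of the Bucket class (not shown here) builds on params_with_grad_ready can be sketched in pure Python. `SimBucket` is a hypothetical stand-in, and the behavior is an assumption: a bucket launches its asynchronous all-reduce exactly once, when every parameter in it has reported ready. Here sync_grad only counts calls instead of communicating.

```python
class SimBucket:
    """Hypothetical stand-in for Bucket that only tracks readiness."""

    def __init__(self, params):
        self.params = set(params)
        self.params_with_grad_ready = set()
        self.sync_calls = 0

    def mark_param_as_ready(self, param):
        # a parameter must belong to the bucket and may be marked only once per backward pass
        assert param in self.params and param not in self.params_with_grad_ready
        self.params_with_grad_ready.add(param)
        if self.params_with_grad_ready == self.params:
            self.sync_grad()

    def sync_grad(self):
        self.sync_calls += 1  # stands in for one all_reduce on the bucket's grad_data

    def reset(self):
        self.params_with_grad_ready.clear()


bucket = SimBucket(["w1", "w2", "w3"])
for name in ["w3", "w1", "w2"]:  # gradients arrive in arbitrary order
    bucket.mark_param_as_ready(name)
print(bucket.sync_calls)  # 1: one collective for the whole bucket
```

The assertion against marking the same parameter twice corresponds to the failure mode the notes discuss later for shared parameters.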

Step2:

class BucketManager:
    def __init__(self, params: List[nn.Parameter], bucket_size: int, process_group: dist.ProcessGroup, grad_type: torch.dtype, device: str):
        self.params = list(params)
        self.bucket_size = bucket_size
        self.process_group = process_group
        self.grad_type = grad_type
        self.device = device

        self.params_to_bucket_location = {}
        self.grad_data_list = []
        self.buckets = []
        self._post_backward_callback_set = False

        self._init_buckets()

    def _init_buckets(self):
        cur_bucket_idx = 0
        cur_bucket_size = 0

        # Distribute all params in different buckets
        for param in self.params:
            if not param.requires_grad:
                continue

            if cur_bucket_size == 0:
                self.params_to_bucket_location[param] = (0, param.numel(), cur_bucket_idx)
                cur_bucket_size = param.numel()
                continue
            
            if cur_bucket_size + param.numel() > self.bucket_size:
                cur_bucket_idx += 1
                self.params_to_bucket_location[param] = (0, param.numel(), cur_bucket_idx)
                cur_bucket_size = param.numel()
            else:
                self.params_to_bucket_location[param] = (cur_bucket_size, cur_bucket_size + param.numel(), cur_bucket_idx)
                cur_bucket_size += param.numel()

        buckets_sizes = [0] * (cur_bucket_idx + 1)
        buckets_to_params = [[] for _ in range(cur_bucket_idx + 1)]
        for param, (_, end, bucket_idx) in self.params_to_bucket_location.items():
            buckets_sizes[bucket_idx] = max(buckets_sizes[bucket_idx], end)
            buckets_to_params[bucket_idx].append(param)

        for i in range(len(buckets_sizes)):
            self.grad_data_list.append(torch.zeros(buckets_sizes[i], dtype=self.grad_type, device=self.device))
            self.buckets.append(Bucket(buckets_to_params[i], self.grad_data_list[i], self.process_group))

        # TODO: understand why reverse
        for param in self.params[::-1]:
            if not param.requires_grad:
                continue
            start_idx, end_idx, bucket_idx = self.params_to_bucket_location[param]
            param.main_grad = self._get_view_from_tensor(self.grad_data_list[bucket_idx], param.shape, start_idx, end_idx)
    
    def _get_view_from_tensor(self, tensor: torch.Tensor, shape: torch.Size, start: int, end: int):
        return tensor[start:end].view(shape)
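The key to why this reduces the number of collectives: each entry of self.grad_data_list is one large contiguous tensor, which is then viewed as smaller pieces that serve as the individual parameters' gradients (main_grad). The all-reduce therefore only has to be issued once, on the whole large tensor: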
# In Bucket class:
def sync_grad(self):
    assert self.handle is None
    self.grad_data /= self.group_size
    self.handle = dist.all_reduce(self.grad_data, group=self.process_group, async_op=True)
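The layout logic of _init_buckets can be reproduced in pure Python to see the offsets it produces, and Python's memoryview gives the same "view into one flat buffer" behavior as tensor[start:end].view(shape). `assign_buckets` and `build_flat_buffers` are hypothetical helpers that mirror the code above for 1-D parameters; array('d') stands in for the gradient tensor.

```python
from array import array


def assign_buckets(param_sizes, bucket_size):
    """Greedy layout mirroring BucketManager._init_buckets: (start, end, bucket_idx) per param."""
    locations = []
    cur_idx, cur_size = 0, 0
    for n in param_sizes:
        if cur_size == 0:                    # first parameter
            locations.append((0, n, cur_idx))
            cur_size = n
        elif cur_size + n > bucket_size:     # would overflow: open a new bucket
            cur_idx += 1
            locations.append((0, n, cur_idx))
            cur_size = n
        else:                                # append to the current bucket
            locations.append((cur_size, cur_size + n, cur_idx))
            cur_size += n
    return locations


def build_flat_buffers(locations):
    num_buckets = max(idx for _, _, idx in locations) + 1
    sizes = [0] * num_buckets
    for _, end, idx in locations:
        sizes[idx] = max(sizes[idx], end)
    buffers = [array("d", [0.0] * s) for s in sizes]
    # each parameter's "main_grad" is a view into its bucket's flat buffer
    views = [memoryview(buffers[idx])[start:end] for start, end, idx in locations]
    return buffers, views


locations = assign_buckets([4, 3, 5, 2], bucket_size=8)
print(locations)  # [(0, 4, 0), (4, 7, 0), (0, 5, 1), (5, 7, 1)]
buffers, main_grads = build_flat_buffers(locations)
main_grads[1][0] = 1.5  # write through the second parameter's view...
print(buffers[0][4])    # ...and the flat bucket buffer sees it: 1.5
```

Four parameters end up in two buffers, so a single all-reduce per buffer covers all of them.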

Step3:

class DataParallelBucket(nn.Module):
    def __init__(self, module, device, bucket_cap_mb = 25, grad_type = torch.float32):
        super().__init__()
        self.module = module
        self.need_grad_sync = True
        self._post_backward_callback_set = False
        # Bytes per gradient element, derived from grad_type (4 for float32, 2 for bfloat16)
        grad_size = torch.empty((), dtype=grad_type).element_size()

        bucket_size = bucket_cap_mb * 1024 * 1024 // grad_size
        self.bucket_manager = BucketManager(module.parameters(), bucket_size, pgm.process_group_manager.dp_group, grad_type, device)
        self.register_backward_hook()

    def register_backward_hook(self):
        self.grad_acc = []
        for param in self.module.parameters():
            if not param.requires_grad:
                continue
            # expand_as creates a trivial view whose grad_fn points back to the
            # parameter's AccumulateGrad node; the hook goes on that node
            param_tmp = param.expand_as(param)
            grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]
            grad_acc_fn.register_hook(self._make_grad_hook(param, self.bucket_manager))
            # Keep a reference so the AccumulateGrad node (and its hook) stays alive
            self.grad_acc.append(grad_acc_fn)

    def _make_grad_hook(self, param: nn.Parameter, bucket_manager: BucketManager):

        def hook(*unused):
            if param.requires_grad:
                assert param.grad is not None
                # Accumulate the gradient manually into main_grad (stored in grad_type)
                param.main_grad.add_(param.grad.data)
                param.grad = None

                if self.need_grad_sync:
                    # Queue the post-backward callback only once per backward pass
                    if not self._post_backward_callback_set:
                        torch.autograd.Variable._execution_engine.queue_callback(self._post_backward)
                        self._post_backward_callback_set = True

                    bucket_manager.mark_param_as_ready(param)
        return hook
    
    def _post_backward(self):
        # Wait for all outstanding bucket all-reduces, then expose the result as param.grad
        self.bucket_manager.wait()
        self._post_backward_callback_set = False
        for param in self.module.parameters():
            if param.requires_grad:
                param.grad = param.main_grad.to(param.dtype)
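The core here is the hook-registration logic. Compared with the naive version, which registers hooks directly on the parameters, extra work is needed so that gradients are managed per bucket and so that gradients and parameters in different precisions can work together; this is what register_backward_hook, _make_grad_hook and _post_backward above do.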
First, why not simply use the following code instead of taking this roundabout route?
from functools import partial

# In DataParallelBucket (the simpler alternative):
def register_backward_hook(self):
    for param in self.module.parameters():
        param.register_hook(partial(self.hook, param, self.bucket_manager))

def hook(self, param, bucket_manager, grad):
    if param.requires_grad:
        param.main_grad.add_(grad)

        if self.need_grad_sync:
            # Only set once
            if not self._post_backward_callback_set:
                torch.autograd.Variable._execution_engine.queue_callback(self._post_backward)
                self._post_backward_callback_set = True

            bucket_manager.mark_param_as_ready(param)
    return None
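After repeated experiments, my tentative conclusion is that the simple version above only works some of the time. My understanding of why: param.register_hook and grad_acc_fn.register_hook fire at different moments. The former is called after one gradient computation for the parameter, before that gradient has been written into .grad; the latter (a hook on the parameter's AccumulateGrad node) is called only after all of the parameter's gradient contributions have been computed, accumulated and written into .grad. Consider a shared parameter, whose gradient is accumulated from several contributions. With the first approach, once the first contribution is computed (and before it is written into .grad), the hook marks the parameter as ready; when the second contribution is computed, the hook marks it as ready again, which is an error. The second approach marks the parameter as ready only after all of its gradient contributions have been accumulated, which avoids the problem.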
Second, this line:
torch.autograd.Variable._execution_engine.queue_callback(self._post_backward)
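This line works around the fact that PyTorch's built-in autograd gradient accumulation does not support mixing precisions between gradients and parameters (e.g. accumulating a bf16 gradient into an fp32 buffer). A function registered with torch.autograd.Variable._execution_engine.queue_callback is called once the gradients of all parameters have been computed, i.e. at the end of the current backward pass. Here that callback (_post_backward) waits for the bucket all-reduces and copies main_grad back into param.grad in the parameter's dtype.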
Now look again at the comment on this hook in the original code:
This hook serves two main purposes:
1. PyTorch does not natively support gradient accumulation with mixed precision.
2. After gradient accumulation, it flags parameters as ready for synchronization.
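Point 1 corresponds to the manual accumulation into main_grad together with the queue_callback that writes the result back to param.grad; point 2 corresponds to the AccumulateGrad hook that marks each parameter as ready.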
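Why accumulate into a higher-precision main_grad at all? A small pure-Python experiment shows the effect: repeatedly adding a small gradient to a bfloat16 accumulator stalls once the increment is below half the spacing between representable values near the running sum, while a float32 accumulator keeps growing. bfloat16 is emulated here by rounding the float32 bit pattern to its top 16 bits (round-to-nearest-even, NaN ignored); `to_bf16`, `to_f32` and `accumulate` are hypothetical helpers.

```python
import struct


def to_f32(x):
    # round a Python float to the nearest float32
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x):
    # emulate bfloat16: keep the top 16 bits of the float32 pattern, rounding to nearest even
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def accumulate(round_fn, start, grad, steps):
    acc = round_fn(start)
    g = round_fn(grad)
    for _ in range(steps):
        acc = round_fn(acc + g)  # each sum is rounded back to the accumulator's precision
    return acc


acc_bf16 = accumulate(to_bf16, 1.0, 1e-4, 10_000)
acc_fp32 = accumulate(to_f32, 1.0, 1e-4, 10_000)
print(acc_bf16)  # 1.0: every 1e-4 update is rounded away
print(acc_fp32)  # close to 2.0
```

Near 1.0 the bfloat16 spacing is 2**-7, so adding about 1e-4 rounds straight back to 1.0; this is the kind of loss that a float32 main_grad avoids.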

TP

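Concept: split a weight tensor by columns or by rows and place the shards on different GPUs, each of which computes its own part. This reduces the memory needed not only for the weights but also for their gradients and optimizer states, as well as for the activations inside the sharded region.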

Communication primitives:

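TP mainly uses three communication operators, each implemented as a torch.autograd.Function: Copy (broadcast-like: identity in forward, all-reduce of the gradient in backward), Reduce (all-reduce in forward, identity in backward) and Gather (all-gather along the last dimension in forward, keep this rank's chunk in backward):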
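import torch
import torch.distributed as dist
import picotron.process_group_manager as pgm  # Picotron's global process-group manager

class Reduce(torch.autograd.Function):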
    @staticmethod
    def forward(ctx, input):
        if pgm.process_group_manager.tp_world_size == 1:
            return input
        
        # Based on PyTorch official
        output = input.clone()
        dist.all_reduce(output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
        return output
    
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output
    
class Copy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return input
    
    @staticmethod
    def backward(ctx, grad_output):
        if pgm.process_group_manager.tp_world_size == 1:
            return grad_output
        
        # Based on PyTorch official
        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
        return grad_output
    
class Gather(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        if pgm.process_group_manager.tp_world_size == 1:
            return input
        last_dim = input.dim() - 1
        # Based on PyTorch official
        input = input.contiguous()
        output = [torch.empty_like(input) for _ in range(pgm.process_group_manager.tp_world_size)]
        output[pgm.process_group_manager.tp_rank] = input
        # all_gather(tensor_list, tensor): tensor_list receives one shard per rank,
        # tensor is this rank's own contribution
        dist.all_gather(output, input, group=pgm.process_group_manager.tp_group)
        final_output = torch.cat(output, dim=last_dim)
        return final_output
    
    @staticmethod
    def backward(ctx, grad_output):
        if pgm.process_group_manager.tp_world_size == 1:
            return grad_output
        last_dim = grad_output.dim() - 1
        assert grad_output.size()[last_dim] % pgm.process_group_manager.tp_world_size == 0
        last_dim_size = grad_output.size()[last_dim] // pgm.process_group_manager.tp_world_size
        chunks = torch.split(grad_output, last_dim_size, dim=last_dim)
        return chunks[pgm.process_group_manager.tp_rank].contiguous()
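A pure-Python simulation of what the three operators do to per-rank values; each list entry is one rank's 1-D tensor, and all helper names are hypothetical. Copy and Reduce are conjugates: one is the identity in forward and an all-reduce in backward, the other is the reverse.

```python
def all_reduce_sum(per_rank):
    # every rank ends up with the elementwise sum over ranks
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]


# Copy: identity in forward, all-reduce in backward
def copy_forward(per_rank_x):
    return [list(x) for x in per_rank_x]


def copy_backward(per_rank_grad):
    return all_reduce_sum(per_rank_grad)


# Reduce: all-reduce in forward, identity in backward
def reduce_forward(per_rank_partial):
    return all_reduce_sum(per_rank_partial)


def reduce_backward(per_rank_grad):
    return [list(g) for g in per_rank_grad]


# Gather: all-gather + concat along the last dim in forward, keep own chunk in backward
def gather_forward(per_rank_shard):
    full = [v for shard in per_rank_shard for v in shard]
    return [list(full) for _ in per_rank_shard]


def gather_backward(per_rank_grad):
    world_size = len(per_rank_grad)
    chunks = []
    for rank, g in enumerate(per_rank_grad):
        assert len(g) % world_size == 0
        k = len(g) // world_size
        chunks.append(g[rank * k:(rank + 1) * k])
    return chunks


print(reduce_forward([[1, 2], [3, 4]]))         # [[4, 6], [4, 6]]
print(gather_forward([[1, 2], [3, 4]]))         # [[1, 2, 3, 4], [1, 2, 3, 4]]
print(gather_backward([[10, 20, 30, 40]] * 2))  # [[10, 20], [30, 40]]
```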

Sharding: row-wise and column-wise

class ColumnParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, bias: bool = False, gather_output: bool = False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.gather_output = gather_output

        self.tp_world_size = pgm.process_group_manager.tp_world_size
        self.tp_rank = pgm.process_group_manager.tp_rank

        assert out_features % self.tp_world_size == 0
        self.output_size_per_partition = out_features // self.tp_world_size

        # F.linear(x, weight) computes x @ weight^T, with weight: (out_features, in_features).
        # Column parallel: each rank keeps out_features // tp_world_size rows of weight.
        self.weight = nn.Parameter(torch.Tensor(self.output_size_per_partition, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(self.output_size_per_partition))
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)
        
        self.reset_parameters()

    def reset_parameters(self):
        if self.tp_world_size == 1:
            #  U(-sqrt(k), sqrt(k))
            k = 1 / self.weight.size(1)
            bound = math.sqrt(k)
            torch.nn.init.uniform_(self.weight, -bound, bound)
            return
    
        master_weight = torch.empty(self.out_features, self.in_features, dtype=self.weight.dtype, requires_grad=False)
        # Calculate bound based on master weight's input dimension. U(-sqrt(k), sqrt(k))
        k = 1 / master_weight.size(1)
        bound = math.sqrt(k)
        torch.nn.init.uniform_(master_weight, -bound, bound)
        
        weight_list = torch.split(master_weight, self.output_size_per_partition, dim=0)
        self.weight.data = weight_list[self.tp_rank].contiguous()

    def forward(self, input):
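        # Copy: identity in forward; in backward it all-reduces the input gradient,
        # because each rank only computes the part coming from its own output shard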
        input_parallel = Copy.apply(input)
        output = F.linear(input_parallel, self.weight, self.bias)
        if self.gather_output:
            output = Gather.apply(output)
        return output

class RowParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, bias: bool = False, gather_output: bool = False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.gather_output = gather_output

        self.tp_world_size = pgm.process_group_manager.tp_world_size
        self.tp_rank = pgm.process_group_manager.tp_rank

        assert in_features % self.tp_world_size == 0
        self.in_size_per_partition = in_features // self.tp_world_size

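        # F.linear(x, weight) computes x @ weight^T, with weight: (out_features, in_features).
        # Row parallel: each rank keeps in_features // tp_world_size columns of weight
        # and expects the matching slice of the input.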
        self.weight = nn.Parameter(torch.Tensor(self.out_features, self.in_size_per_partition))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(self.out_features))
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)
        
        self.reset_parameters()

    def reset_parameters(self):
        if self.tp_world_size == 1:
            #  U(-sqrt(k), sqrt(k))
            k = 1 / self.weight.size(1)
            bound = math.sqrt(k)
            torch.nn.init.uniform_(self.weight, -bound, bound)
            return
    
        master_weight = torch.empty(self.out_features, self.in_features, dtype=self.weight.dtype, requires_grad=False)
        # Calculate bound based on master weight's input dimension. U(-sqrt(k), sqrt(k))
        k = 1 / master_weight.size(1)
        bound = math.sqrt(k)
        torch.nn.init.uniform_(master_weight, -bound, bound)
        
        weight_list = torch.split(master_weight, self.in_size_per_partition, dim=1)
        self.weight.data = weight_list[self.tp_rank].contiguous()

    def forward(self, input):
        output_parallel = F.linear(input, self.weight)
        output = Reduce.apply(output_parallel)
        return output if self.bias is None else output + self.bias
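Both layers follow the same idea: when the layer is created, each GPU allocates only its own shard of the weight, and the communication operators above then combine the activations in forward and the gradients in backward.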
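A pure-Python check that both splits reproduce an ordinary linear layer, using the F.linear convention y = x @ W^T with W of shape (out_features, in_features). `linear`, `column_parallel` and `row_parallel` are hypothetical helpers; the "ranks" are simulated in one process.

```python
def linear(x, w):
    # y = x @ w^T, same convention as F.linear; x: [in_features], w: [out_features][in_features]
    return [sum(xi * wi for xi, wi in zip(x, row)) for row in w]


def column_parallel(x, w, world_size):
    # each rank owns out_features // world_size rows of w (split along dim 0)
    k = len(w) // world_size
    outs = [linear(x, w[r * k:(r + 1) * k]) for r in range(world_size)]  # input replicated (Copy)
    return [v for out in outs for v in out]                              # Gather: concatenate


def row_parallel(x, w, world_size):
    # each rank owns in_features // world_size columns of w and the matching slice of x
    k = len(x) // world_size
    partials = [
        linear(x[r * k:(r + 1) * k], [row[r * k:(r + 1) * k] for row in w])
        for r in range(world_size)
    ]
    return [sum(vals) for vals in zip(*partials)]                        # Reduce: all-reduce (sum)


x = [1, 2, 3, 4]
w = [[1, 0, 2, -1], [0, 1, 1, 1], [2, -2, 0, 3], [1, 1, 1, 1]]
print(linear(x, w))              # [3, 9, 10, 10]
print(column_parallel(x, w, 2))  # [3, 9, 10, 10]
print(row_parallel(x, w, 2))     # [3, 9, 10, 10]
```

The column split produces disjoint slices of the output (hence Gather), while the row split produces full-size partial sums (hence Reduce).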

Sharding inside a Transformer:

MLP:

Therefore, of the two MLP weights, the first (c_fc) is split by columns and the second (c_proj) by rows:
    module_linear_name_stype_mapping_list = [
        ("mlp", "c_fc", "column"),
        ("mlp", "c_proj", "row"),
    ]

    for layer in model.transformer.h:
        for module_name, linear_proj_name, style in module_linear_name_stype_mapping_list:
            _replace_module(getattr(layer, module_name), linear_proj_name, style)
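Why column-then-row: with c_fc split by columns, each rank holds complete hidden units, so the elementwise activation can be applied locally, and c_proj split by rows only needs one all-reduce at the end. A pure-Python check follows; ReLU stands in for GPT-2's GELU (an assumption, since any elementwise activation behaves the same here), and the helper names are hypothetical.

```python
def linear(x, w):
    # y = x @ w^T (F.linear convention)
    return [sum(xi * wi for xi, wi in zip(x, row)) for row in w]


def relu(v):
    # stand-in for GELU: any elementwise activation works the same way
    return [max(0, a) for a in v]


def mlp_reference(x, w_fc, w_proj):
    return linear(relu(linear(x, w_fc)), w_proj)


def mlp_tensor_parallel(x, w_fc, w_proj, world_size):
    k = len(w_fc) // world_size  # hidden units per rank
    partials = []
    for r in range(world_size):
        h = relu(linear(x, w_fc[r * k:(r + 1) * k]))                      # c_fc: column shard, no comm
        partials.append(linear(h, [row[r * k:(r + 1) * k] for row in w_proj]))  # c_proj: row shard
    return [sum(vals) for vals in zip(*partials)]                         # one all-reduce at the end


x = [1, 2, 3]
w_fc = [[1, 0, -1], [2, 1, 0], [-1, 1, 1], [0, -2, 1]]
w_proj = [[1, 2, 3, 4], [-1, 0, 1, 2]]
print(mlp_reference(x, w_fc, w_proj))           # [20, 4]
print(mlp_tensor_parallel(x, w_fc, w_proj, 2))  # [20, 4]
```

If c_fc were split by rows instead, each rank would hold only a partial sum of the pre-activation, and since relu(a + b) differs from relu(a) + relu(b) in general, an extra all-reduce would be needed before the activation.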

Attention:

In attention, the Q, K and V projection matrices are split by columns and the final output projection (c_proj) by rows:
    module_linear_name_stype_mapping_list = [
        ("attn", "q_attn", "column"),
        ("attn", "k_attn", "column"),
        ("attn", "v_attn", "column"),
        ("attn", "c_proj", "row"),
    ]

    for layer in model.transformer.h:
        for module_name, linear_proj_name, style in module_linear_name_stype_mapping_list:
            _replace_module(getattr(layer, module_name), linear_proj_name, style)
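Column-splitting Q, K and V amounts to splitting the attention heads across TP ranks, provided the projection's output features are laid out head after head and n_head is divisible by tp_world_size (both are assumptions of this sketch; otherwise a head would straddle two ranks). `heads_on_rank` is a hypothetical helper.

```python
def heads_on_rank(n_head, head_dim, tp_world_size, tp_rank):
    # each head must land entirely on one rank
    assert n_head % tp_world_size == 0
    out_features = n_head * head_dim
    per_rank = out_features // tp_world_size  # columns of Q/K/V owned by this rank
    start, end = tp_rank * per_rank, (tp_rank + 1) * per_rank
    # output feature j belongs to head j // head_dim
    return sorted({j // head_dim for j in range(start, end)})


print(heads_on_rank(n_head=12, head_dim=64, tp_world_size=4, tp_rank=1))  # [3, 4, 5]
```

Each rank can then run attention for its own heads without communication, and the row-split c_proj combines the heads with a single all-reduce.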

PP

Concept: when a model is too large, cut it by layers into several consecutive stages and place each stage on a different GPU.

Communication:

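import os
import torch
import torch.distributed as dist
import picotron.process_group_manager as pgm  # Picotron's global process-group manager

# PP communication
STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"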
def pipeline_communication(operation, device, dtype, tensor = None, shape = None):
    global STEP
    global VERBOSE

    if operation == "recv_forward":
        if pgm.process_group_manager.pp_is_first_stage:
            return None
        tensor = torch.empty(shape, requires_grad=True, device=device, dtype=dtype)
        src = pgm.process_group_manager.pp_prev_rank
    elif operation == "send_forward":
        if pgm.process_group_manager.pp_is_last_stage:
            return
        dest = pgm.process_group_manager.pp_next_rank
    elif operation == "recv_backward":
        if pgm.process_group_manager.pp_is_last_stage:
            return None
        tensor = torch.empty(shape, requires_grad=True, device=device, dtype=dtype)
        src = pgm.process_group_manager.pp_next_rank
    elif operation == "send_backward":
        if pgm.process_group_manager.pp_is_first_stage:
            return
        dest = pgm.process_group_manager.pp_prev_rank

    is_send = operation.startswith("send")
    peer_rank = dest if is_send else src

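    # P2POp's peer is a global rank and its group defaults to the default (world) group,
    # so no process group is passed; this relies on pp_prev_rank / pp_next_rank being global ranks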
    op = dist.P2POp(dist.isend if is_send else dist.irecv, tensor, peer_rank)

    if VERBOSE: 
        print(f"{operation} | {'sending' if is_send else 'receiving'} {operation.split('_')[1]} "
              f"{pgm.process_group_manager.pp_rank} {'→' if is_send else '←'} {peer_rank} | "
              f"STEP:{STEP} | RANK:{pgm.process_group_manager.pp_rank}")
        
    for req in dist.batch_isend_irecv([op]):
        req.wait()
    torch.cuda.synchronize()

    if VERBOSE: STEP += 1

    return tensor if not is_send else None
# End PP communication
Unlike TP, which relies on collectives, all PP communication is point-to-point.
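The four operations pair up across neighboring stages. A pure-Python sketch of which peer each operation talks to, using stage indices 0..pp_world_size-1 in place of the real pp_prev_rank / pp_next_rank (`p2p_peer` is a hypothetical helper):

```python
def p2p_peer(operation, pp_rank, pp_world_size):
    """Return the stage a pipeline_communication operation talks to, or None for a no-op."""
    is_first = pp_rank == 0
    is_last = pp_rank == pp_world_size - 1
    if operation == "recv_forward":   # activations come from the previous stage
        return None if is_first else pp_rank - 1
    if operation == "send_forward":   # activations go to the next stage
        return None if is_last else pp_rank + 1
    if operation == "recv_backward":  # gradients come from the next stage
        return None if is_last else pp_rank + 1
    if operation == "send_backward":  # gradients go to the previous stage
        return None if is_first else pp_rank - 1
    raise ValueError(f"unknown operation: {operation}")


for op in ["recv_forward", "send_forward", "recv_backward", "send_backward"]:
    print(op, [p2p_peer(op, r, 4) for r in range(4)])
```

Every send on one stage must be matched by the corresponding receive on its neighbor; an unmatched send or receive would leave req.wait() blocked.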

Model construction:

class PipelineParallel(nn.Module):
    def __init__(self, model, config):
        super().__init__()
        layer_distribution = self._distribute_layers(config.n_layer)
        self.embedding = model.embedding if pgm.process_group_manager.pp_is_first_stage else nn.Identity()
        self.layers = nn.ModuleList([model.transformer.h[i] for i in layer_distribution])
        self.final_norm = model.transformer.ln_f if pgm.process_group_manager.pp_is_last_stage else nn.Identity()
        self.final_proj = model.lm_head if pgm.process_group_manager.pp_is_last_stage else nn.Identity()

    def _distribute_layers(self, n_layer):
        pp_world_size = pgm.process_group_manager.pp_world_size
        pp_rank = pgm.process_group_manager.pp_rank
        # Every stage gets n_layer // pp_world_size layers; the first n_layer % pp_world_size stages get one extra
        layer_per_gpu = [n_layer // pp_world_size + (1 if i < n_layer % pp_world_size else 0) for i in range(pp_world_size)]
        start_layer = sum(layer_per_gpu[:pp_rank])
        return list(range(start_layer, start_layer + layer_per_gpu[pp_rank]))
    
    def forward(self, input_ids, hidden_states):
        x = hidden_states if hidden_states is not None else input_ids
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x)
        x = self.final_norm(x)
        return self.final_proj(x)

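    '''
    Backward flow (stage A followed by stage B):

    Call loss.backward() to get d(loss)/d(output_B).

    Call stage B's backward with output_A, output_B and d(loss)/d(output_B); this yields d(loss)/d(output_A), i.e. output_A.grad.

    Pass output_A.grad to stage A as the input of its backward pass and continue computing stage A's gradients.
    '''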
    def backward(self, input_tensor, output_tensor, output_tensor_grad):
        if input_tensor is not None:
            input_tensor.retain_grad()
        
        if output_tensor_grad is None:
            output_tensor_grad = torch.ones_like(output_tensor, memory_format=torch.preserve_format)
        
        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad, retain_graph=False, create_graph=False)

        return input_tensor.grad if input_tensor is not None else None
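A few things I find clever here:
  1. _distribute_layers splits the layers as evenly as possible: every stage gets n_layer // pp_world_size layers, and the first n_layer % pp_world_size stages get one extra.
  2. Stages that do not own the embedding, final norm or LM head use nn.Identity() in their place, so forward needs no extra branching.
  3. backward deserves a careful read: it runs autograd from this stage's output using the gradient received from the next stage (or ones on the last stage, whose output is the loss), and returns the gradient of this stage's input so it can be sent to the previous stage.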
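A standalone copy of the _distribute_layers logic, with pgm replaced by explicit arguments (`distribute_layers` is a hypothetical name), makes the split easy to check: 10 layers on 4 stages become 3, 3, 2 and 2 contiguous layers.

```python
def distribute_layers(n_layer, pp_world_size, pp_rank):
    # every stage gets n_layer // pp_world_size layers; the first n_layer % pp_world_size get one more
    counts = [n_layer // pp_world_size + (1 if i < n_layer % pp_world_size else 0)
              for i in range(pp_world_size)]
    start = sum(counts[:pp_rank])
    return list(range(start, start + counts[pp_rank]))


for rank in range(4):
    print(rank, distribute_layers(10, pp_world_size=4, pp_rank=rank))
# 0 [0, 1, 2]
# 1 [3, 4, 5]
# 2 [6, 7]
# 3 [8, 9]
```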

AFAB (All Forward, All Backward):

def train_step_pipeline_afab(model, dataloader, tensor_shape, device, dtype):
    logging_loss = 0.0
    input_tensors, output_tensors = [], []
    requires_grad_sync = pgm.process_group_manager.dp_world_size > 1

    # === All Forward ===
    for _ in range(dataloader.grad_acc_steps):
        input_tensor = pipeline_communication("recv_forward", device=device, dtype=dtype, shape=tensor_shape)
        batch = next(dataloader)
        batch['hidden_states'] = input_tensor.to(device) if input_tensor is not None else input_tensor

        output_tensor = model.forward(input_ids=batch["input_ids"].to(device), hidden_states=batch["hidden_states"])
        pipeline_communication("send_forward", device=device, dtype=dtype, tensor=output_tensor)

        targets = batch["target_ids"].to(device)
        if pgm.process_group_manager.pp_is_last_stage:
            output_tensor = F.cross_entropy(output_tensor.view(-1, output_tensor.size(-1)), targets.view(-1), reduction='mean')
            logging_loss += output_tensor.item() / dataloader.grad_acc_steps

        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    # === All Backward ===
    for i in range(dataloader.grad_acc_steps):
        if requires_grad_sync:
            is_last_iteration = (i == dataloader.grad_acc_steps - 1)
            model.require_backward_grad_sync = is_last_iteration
        output_tensor_grad = pipeline_communication("recv_backward", device=device, dtype=dtype, shape=tensor_shape)
        input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
        input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
        pipeline_communication("send_backward", device=device, dtype=dtype, tensor=input_tensor_grad)

    return logging_loss
Stepping through the schedule by hand is the way to fully understand it (assume 4 devices):
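STEP 1: GPU 0 gets the first batch; the other GPUs are blocked because they have not received anything yet.
GPU 0   | GPU 1         | GPU 2         | GPU 3
batch 1 | wait for recv | wait for recv | wait for recv
STEP 2: GPU 0 sends its processed batch 1 to the next stage.
STEP 3: GPU 0 gets the second batch, GPU 1 receives batch 1 as processed by GPU 0, and the other GPUs are still blocked.
GPU 0   | GPU 1   | GPU 2         | GPU 3
batch 2 | batch 1 | wait for recv | wait for recv
STEP 4: GPU 0 and GPU 1 both send their outputs to the next stage.
STEP 5: GPU 0 gets the third batch, GPU 1 receives batch 2 as processed by GPU 0, GPU 2 receives batch 1 as processed by GPU 1, and GPU 3 is still blocked.
GPU 0   | GPU 1   | GPU 2   | GPU 3
batch 3 | batch 2 | batch 1 | wait for recv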
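I won't continue the walkthrough; the remaining details are in the code. The backward pass works the same way in the opposite direction, starting from the last stage.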
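An idealized version of this walkthrough can be generated in pure Python, assuming every stage takes one time unit per micro-batch and communication is free (both simplifications). Row t of the timeline lists which micro-batch each GPU runs in forward at time t; None marks an idle slot (a pipeline bubble). `afab_forward_timeline` and `bubble_fraction` are hypothetical helpers.

```python
def afab_forward_timeline(num_stages, num_microbatches):
    """Row t maps each stage to the 1-based micro-batch it runs at time t, or None if idle."""
    total_steps = num_microbatches + num_stages - 1
    timeline = []
    for t in range(total_steps):
        row = []
        for stage in range(num_stages):
            m = t - stage  # a micro-batch reaches stage s after passing through s earlier stages
            row.append(m + 1 if 0 <= m < num_microbatches else None)
        timeline.append(row)
    return timeline


def bubble_fraction(num_stages, num_microbatches):
    # idle slots / total slots within the forward (or, symmetrically, the backward) phase
    total = num_stages * (num_microbatches + num_stages - 1)
    busy = num_stages * num_microbatches
    return (total - busy) / total


for t, row in enumerate(afab_forward_timeline(num_stages=4, num_microbatches=3)):
    print(t, row)
print(bubble_fraction(4, 3))  # 0.5
```

Rows 0 to 2 reproduce STEP 1, 3 and 5 above. Under these assumptions, S stages and M micro-batches need M + S - 1 time units per phase, so the idle fraction is (S - 1) / (M + S - 1); this is why AFAB benefits from many micro-batches (a large grad_acc_steps).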