GPipe 流水线并行源码

目录

  1. GPipe Pipeline 概述
  1. 核心模块解析
  1. 关键函数与类
  1. 跨设备同步实现
  1. 总结

GPipe Pipeline 概述

GPipe 是一种管道并行(pipeline parallelism)实现,通常用于分布式深度学习中大模型的训练。它将模型划分为多个分区,每个分区放在不同的设备上,数据在设备之间逐步传递以完成计算。GPipe 使用检查点和数据依赖管理来优化显存利用并确保计算顺序。

核心模块解析

Fork 和 Join 的作用

forkjoin 用于在计算图中建立伪依赖关系,确保反向传播的顺序。具体来说:
  • fork:在前向传播中创建分叉,生成一个虚拟张量 phony,用于在计算图中保持依赖。
  • join:将输出张量与 phony 合并,确保反向传播时遵循指定的顺序。
这种机制保证了反向传播的顺序正确,特别适合在多设备或多分区的流水线并行中使用。

依赖控制与同步

  • forkjoin 将分区的计算按顺序排布,避免资源竞争。
  • 在反向传播时,join 通过 phony 依赖,确保在 B 之前计算 A 的梯度,从而实现同步。
  • 示例:通过 fork(A) 创建依赖,join(B, phony) 确保 B 的梯度计算在 A 之后进行。

任务创建与执行流程

compute 方法中,任务通过 in_queues 分发到各个设备的工作线程,由 forkjoin 管理依赖关系,确保前向与反向传播的顺序一致。
  1. 前向传播:创建任务,分配到指定设备的工作线程。
  1. 反向传播join 插入的伪依赖控制顺序,保证同步执行。

关键函数与类

GPipe

class GPipe(Module):
    """Wraps an arbitrary :class:`nn.Sequential <torch.nn.Sequential>` module
    to train on GPipe_. If the module requires lots of memory, GPipe will be
    very efficient.

    Args:
        module (torch.nn.Sequential):
            sequential module to be parallelized
        balance (ints):
            list of number of layers in each partition

    Keyword Args:
        devices (iterable of devices):
            devices to use (default: all CUDA devices)
        chunks (int):
            number of micro-batches (default: ``1``)
        checkpoint (str):
            when to enable checkpointing, one of ``'always'``,
            ``'except_last'``, or ``'never'`` (default: ``'except_last'``)
        deferred_batch_norm (bool):
            whether to use deferred BatchNorm moving statistics (default:
            :data:`False`, see :ref:`Deferred Batch Normalization` for more
            details)

    Raises:
        TypeError:
            the module is not a :class:`nn.Sequential <torch.nn.Sequential>`.
        ValueError:
            invalid arguments, or wrong balance
        IndexError:
            the number of devices is fewer than the number of partitions.
    """

    def __init__(self,
                 module: nn.Sequential,
                 balance: Optional[Iterable[int]] = None,
                 *,
                 devices: Optional[Devices] = None,
                 chunks: int = 1,
                 checkpoint: str = 'except_last',
                 deferred_batch_norm: bool = False,
                 ) -> None:
        super().__init__()

        # 初始化与验证模块
        if balance is None:
            raise ValueError('balance is required')
        if chunks <= 0:
            raise ValueError('number of chunks must be positive integer')
        if checkpoint not in ['always', 'except_last', 'never']:
            raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")

        verify_module(module)
        verify_skippables(module)

        # 设置分区与设备
        if devices is None:
            devices = range(torch.cuda.device_count())
        devices = [torch.device(d) for d in devices]
        self.partitions, self.balance, self.devices = split_module(module, balance, devices)

        self.chunks = chunks
        self.checkpoint = checkpoint

        # 初始化复制流与跳跃布局
        self._copy_streams: List[List[AbstractStream]] = []
        self._skip_layout = inspect_skip_layout(self.partitions)
  • 作用GPipe 类用于将 nn.Sequential 模块划分为多个分区,以便在多个设备上并行训练。
  • 参数:包括模型分区的 balance 列表、使用的设备、微批次数量等。
  • 关键点
    • 检查点机制:减少显存占用。
    • 设备管理:将分区分配到指定的设备。
    • 跳跃布局:用于跳跃连接的管理。

forkjoin 函数

def fork(input: Tensor) -> Tuple[Tensor, Tensor]:
    if torch.is_grad_enabled() and input.requires_grad:
        input, phony = Fork.apply(input)
    else:
        phony = get_phony(input.device, requires_grad=False)
    return input, phony
  • 作用:创建伪依赖,在计算图中分叉出 phony 张量。
  • 返回input 分支和 phony(作为依赖的占位符)。
def join(input: Tensor, phony: Tensor) -> Tensor:
    if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad):
        input = Join.apply(input, phony)
    return input
  • 作用:将 inputphony 合并,使得反向传播时 B 必须等待 A
  • 目的:控制依赖顺序,确保前向传播依赖关系在反向传播中得到维护。

record_stream 函数

def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
    if is_cuda(stream):
        tensor = tensor.new_empty([0]).set_(tensor.storage())
        tensor.record_stream(as_cuda(stream))
  • 作用:确保张量内存不会在流操作完成前被释放。
  • 工作原理:将 tensor 的内存分配与特定 stream 关联,特别适用于异步 CUDA 操作。

depend 函数

def depend(fork_from: Batch, join_to: Batch) -> None:
    fork_from[0], phony = fork(fork_from[0])
    join_to[0] = join(join_to[0], phony)
  • 作用:建立两个 Batch 对象的依赖关系,确保反向传播的顺序。
  • 过程:通过 forkjoin 创建依赖,保证在反向传播时 join_to 的计算顺序在 fork_from 之后。

fence 函数

def fence(self, schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals]) -> None:
    for i, j in schedule:
        if i != 0:
            depend(batches[i-1], batches[i])

        next_stream = copy_streams[j][i]
        if j != 0:
            prev_stream = copy_streams[j-1][i]
            copy(batches[i], prev_stream, next_stream)
  • 作用:同步微批次数据传输,确保设备间的依赖。
  • 细节:利用 depend 和流控制保证数据在不同分区的有序传输。

compute 函数

def compute(self, schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], in_queues: List[InQueue], out_queues: List[OutQueue]) -> None:
    for i, j in schedule:
        if checkpoint:
            chk = Checkpointing(function, batch)
            task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
        in_queues[j].put(task)

    for i, j in schedule:
        ok, payload = out_queues[j].get()
        if j != n-1:
            wait(batch, streams[j], copy_streams[j][i])
        with use_device(devices[j]):
            task.finalize(batch)
  • 作用:并行执行任务,控制检查点机制和流同步。
  • 过程:任务分发到 in_queues,按 forkjoin 控制依赖顺序并执行。

worker 函数

def worker(in_queue: InQueue, out_queue: OutQueue, device: torch.device, grad_mode: bool) -> None:
    """The main loop of a worker thread."""
    torch.set_grad_enabled(grad_mode)

    with use_device(device):
        while True:
            task = in_queue.get()

            if task is None:
                break

            try:
                batch = task.compute()
            except Exception:
                exc_info = cast(ExcInfo, sys.exc_info())
                out_queue.put((False, exc_info))
                continue

            out_queue.put((True, (task, batch)))

    done = (False, None)
    out_queue.put(done)
  • 作用worker 函数是每个设备上运行的工作线程的主循环,负责从 in_queue 中获取任务,执行计算并将结果发送到 out_queue
  • 流程
      1. in_queue 中获取任务。
      1. 执行计算并处理异常。
      1. 将结果或异常发送到 out_queue

spawn_workers 函数

def spawn_workers(devices: List[torch.device]) -> Generator[Tuple[List[InQueue], List[OutQueue]], None, None]:
    """Spawns worker threads. A worker thread is bound to a device."""
    in_queues: List[InQueue] = []
    out_queues: List[OutQueue] = []

    # Spawn workers.
    workers: Dict[torch.device, Tuple[InQueue, OutQueue]] = {}

    def normalize_device(device: torch.device) -> torch.device:
        if device.type == 'cuda' and device.index is None:
            return torch.device('cuda', index=torch.cuda.current_device())
        if device.type == 'cpu' and device.index is not None:
            return torch.device('cpu')
        return device

    for device in devices:
        device = normalize_device(device)

        try:
            in_queue, out_queue = workers[device]
        except KeyError:
            in_queue = Queue()
            out_queue = Queue()
            workers[device] = (in_queue, out_queue)

            t = Thread(
                target=worker,
                args=(in_queue, out_queue, device, torch.is_grad_enabled()),
                daemon=True,
            )
            t.start()

        in_queues.append(in_queue)
        out_queues.append(out_queue)

    try:
        yield (in_queues, out_queues)
    finally:
        # Close workers.
        for in_queue in set(in_queues):
            in_queue.put(None)

        # Join running workers.
        running = set(out_queues)
        while running:
            out_queue = running.pop()
            ok, payload = out_queue.get()

            done = (False, None)
            if (ok, payload) == done:
                continue

            running.add(out_queue)
  • 作用spawn_workers 函数用于为每个设备启动一个工作线程。
  • 工作流程
      1. 为每个设备创建输入和输出队列(in_queueout_queue)。
      1. 启动工作线程,线程通过 worker 函数执行任务。
      1. yield 后确保所有工作线程正确关闭。

Worker 线程通信

  • 队列机制
    • 每个设备对应一个输入队列(in_queue)和一个输出队列(out_queue)。
    • 任务通过 in_queue 分发给设备上的工作线程,执行完成后,结果通过 out_queue 返回。
  • 通信流程
      1. 主线程将任务放入 in_queue
      1. 工作线程从 in_queue 中取出任务并执行。
      1. 计算完成后,结果通过 out_queue 发送回主线程。
      1. 当所有任务完成时,向 in_queue 发送 None 以终止工作线程。

跨设备同步实现

GPipe 通过以下机制实现跨设备的同步:
  1. 复制流管理 (record_stream)
      • 使用 record_stream 函数确保张量在复制过程中不会被释放。每个设备都有独立的 CUDA 流,这些流用于管理设备间的数据复制和同步。
  1. 依赖控制 (forkjoin)
      • 使用 forkjoin 函数在前向和反向传播中建立伪依赖关系,确保梯度计算顺序正确,避免

总结

GPipe 提供了一种高效的管道并行方案,特别适用于大模型的分布式训练。通过使用 forkjoin、工作线程和检查点机制,GPipe 能够有效管理设备间的数据传输、同步和显存占用。worker 线程和 spawn_workers 函数为不同设备上的并行计算提供了强有力的支持,通过队列实现任务的分发和结果的收集,有效地实现了跨设备的计算同步与通信。