目录
GPipe Pipeline 概述
GPipe
是一种管道并行(pipeline parallelism)实现,通常用于分布式深度学习中大模型的训练。它将模型划分为多个分区,每个分区放在不同的设备上,数据在设备之间逐步传递以完成计算。GPipe
使用检查点和数据依赖管理来优化显存利用并确保计算顺序。核心模块解析
Fork 和 Join 的作用
fork
和 join
用于在计算图中建立伪依赖关系,确保反向传播的顺序。具体来说:- fork:在前向传播中创建分叉,生成一个虚拟张量
phony
,用于在计算图中保持依赖。
- join:将输出张量与
phony
合并,确保反向传播时遵循指定的顺序。
这种机制保证了反向传播的顺序正确,特别适合在多设备或多分区的流水线并行中使用。
依赖控制与同步
fork
和join
将分区的计算按顺序排布,避免资源竞争。
- 在反向传播时,
join
通过phony
依赖,确保在B
之前计算A
的梯度,从而实现同步。
- 示例:通过
fork(A)
创建依赖,join(B, phony)
确保B
的梯度计算在A
之后进行。
任务创建与执行流程
compute
方法中,任务通过 in_queues
分发到各个设备的工作线程,由 fork
和 join
管理依赖关系,确保前向与反向传播的顺序一致。- 前向传播:创建任务,分配到指定设备的工作线程。
- 反向传播:
join
插入的伪依赖控制顺序,保证同步执行。
关键函数与类
GPipe
类
class GPipe(Module):
"""Wraps an arbitrary :class:`nn.Sequential <torch.nn.Sequential>` module
to train on GPipe_. If the module requires lots of memory, GPipe will be
very efficient.
Args:
module (torch.nn.Sequential):
sequential module to be parallelized
balance (ints):
list of number of layers in each partition
Keyword Args:
devices (iterable of devices):
devices to use (default: all CUDA devices)
chunks (int):
number of micro-batches (default: ``1``)
checkpoint (str):
when to enable checkpointing, one of ``'always'``,
``'except_last'``, or ``'never'`` (default: ``'except_last'``)
deferred_batch_norm (bool):
whether to use deferred BatchNorm moving statistics (default:
:data:`False`, see :ref:`Deferred Batch Normalization` for more
details)
Raises:
TypeError:
the module is not a :class:`nn.Sequential <torch.nn.Sequential>`.
ValueError:
invalid arguments, or wrong balance
IndexError:
the number of devices is fewer than the number of partitions.
"""
def __init__(self,
module: nn.Sequential,
balance: Optional[Iterable[int]] = None,
*,
devices: Optional[Devices] = None,
chunks: int = 1,
checkpoint: str = 'except_last',
deferred_batch_norm: bool = False,
) -> None:
super().__init__()
# 初始化与验证模块
if balance is None:
raise ValueError('balance is required')
if chunks <= 0:
raise ValueError('number of chunks must be positive integer')
if checkpoint not in ['always', 'except_last', 'never']:
raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")
verify_module(module)
verify_skippables(module)
# 设置分区与设备
if devices is None:
devices = range(torch.cuda.device_count())
devices = [torch.device(d) for d in devices]
self.partitions, self.balance, self.devices = split_module(module, balance, devices)
self.chunks = chunks
self.checkpoint = checkpoint
# 初始化复制流与跳跃布局
self._copy_streams: List[List[AbstractStream]] = []
self._skip_layout = inspect_skip_layout(self.partitions)
- 作用:
GPipe
类用于将nn.Sequential
模块划分为多个分区,以便在多个设备上并行训练。
- 参数:包括模型分区的
balance
列表、使用的设备、微批次数量等。
- 关键点:
- 检查点机制:减少显存占用。
- 设备管理:将分区分配到指定的设备。
- 跳跃布局:用于跳跃连接的管理。
fork
与 join
函数
def fork(input: Tensor) -> Tuple[Tensor, Tensor]:
if torch.is_grad_enabled() and input.requires_grad:
input, phony = Fork.apply(input)
else:
phony = get_phony(input.device, requires_grad=False)
return input, phony
- 作用:创建伪依赖,在计算图中分叉出
phony
张量。
- 返回:
input
分支和phony
(作为依赖的占位符)。
def join(input: Tensor, phony: Tensor) -> Tensor:
if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad):
input = Join.apply(input, phony)
return input
- 作用:将
input
与phony
合并,使得反向传播时B
必须等待A
。
- 目的:控制依赖顺序,确保前向传播依赖关系在反向传播中得到维护。
record_stream
函数
def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
if is_cuda(stream):
tensor = tensor.new_empty([0]).set_(tensor.storage())
tensor.record_stream(as_cuda(stream))
- 作用:确保张量内存不会在流操作完成前被释放。
- 工作原理:将
tensor
的内存分配与特定stream
关联,特别适用于异步 CUDA 操作。
depend
函数
def depend(fork_from: Batch, join_to: Batch) -> None:
fork_from[0], phony = fork(fork_from[0])
join_to[0] = join(join_to[0], phony)
- 作用:建立两个
Batch
对象的依赖关系,确保反向传播的顺序。
- 过程:通过
fork
和join
创建依赖,保证在反向传播时join_to
的计算顺序在fork_from
之后。
fence
函数
def fence(self, schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals]) -> None:
for i, j in schedule:
if i != 0:
depend(batches[i-1], batches[i])
next_stream = copy_streams[j][i]
if j != 0:
prev_stream = copy_streams[j-1][i]
copy(batches[i], prev_stream, next_stream)
- 作用:同步微批次数据传输,确保设备间的依赖。
- 细节:利用
depend
和流控制保证数据在不同分区的有序传输。
compute
函数
def compute(self, schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], in_queues: List[InQueue], out_queues: List[OutQueue]) -> None:
for i, j in schedule:
if checkpoint:
chk = Checkpointing(function, batch)
task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
in_queues[j].put(task)
for i, j in schedule:
ok, payload = out_queues[j].get()
if j != n-1:
wait(batch, streams[j], copy_streams[j][i])
with use_device(devices[j]):
task.finalize(batch)
- 作用:并行执行任务,控制检查点机制和流同步。
- 过程:任务分发到
in_queues
,按fork
和join
控制依赖顺序并执行。
worker
函数
def worker(in_queue: InQueue, out_queue: OutQueue, device: torch.device, grad_mode: bool) -> None:
"""The main loop of a worker thread."""
torch.set_grad_enabled(grad_mode)
with use_device(device):
while True:
task = in_queue.get()
if task is None:
break
try:
batch = task.compute()
except Exception:
exc_info = cast(ExcInfo, sys.exc_info())
out_queue.put((False, exc_info))
continue
out_queue.put((True, (task, batch)))
done = (False, None)
out_queue.put(done)
- 作用:
worker
函数是每个设备上运行的工作线程的主循环,负责从in_queue
中获取任务,执行计算并将结果发送到out_queue
。
- 流程:
- 从
in_queue
中获取任务。 - 执行计算并处理异常。
- 将结果或异常发送到
out_queue
。
spawn_workers
函数
def spawn_workers(devices: List[torch.device]) -> Generator[Tuple[List[InQueue], List[OutQueue]], None, None]:
"""Spawns worker threads. A worker thread is bound to a device."""
in_queues: List[InQueue] = []
out_queues: List[OutQueue] = []
# Spawn workers.
workers: Dict[torch.device, Tuple[InQueue, OutQueue]] = {}
def normalize_device(device: torch.device) -> torch.device:
if device.type == 'cuda' and device.index is None:
return torch.device('cuda', index=torch.cuda.current_device())
if device.type == 'cpu' and device.index is not None:
return torch.device('cpu')
return device
for device in devices:
device = normalize_device(device)
try:
in_queue, out_queue = workers[device]
except KeyError:
in_queue = Queue()
out_queue = Queue()
workers[device] = (in_queue, out_queue)
t = Thread(
target=worker,
args=(in_queue, out_queue, device, torch.is_grad_enabled()),
daemon=True,
)
t.start()
in_queues.append(in_queue)
out_queues.append(out_queue)
try:
yield (in_queues, out_queues)
finally:
# Close workers.
for in_queue in set(in_queues):
in_queue.put(None)
# Join running workers.
running = set(out_queues)
while running:
out_queue = running.pop()
ok, payload = out_queue.get()
done = (False, None)
if (ok, payload) == done:
continue
running.add(out_queue)
- 作用:
spawn_workers
函数用于为每个设备启动一个工作线程。
- 工作流程:
- 为每个设备创建输入和输出队列(
in_queue
和out_queue
)。 - 启动工作线程,线程通过
worker
函数执行任务。 - 在
yield
后确保所有工作线程正确关闭。
Worker 线程通信
- 队列机制:
- 每个设备对应一个输入队列(
in_queue
)和一个输出队列(out_queue
)。 - 任务通过
in_queue
分发给设备上的工作线程,执行完成后,结果通过out_queue
返回。
- 通信流程:
- 主线程将任务放入
in_queue
。 - 工作线程从
in_queue
中取出任务并执行。 - 计算完成后,结果通过
out_queue
发送回主线程。 - 当所有任务完成时,向
in_queue
发送None
以终止工作线程。
跨设备同步实现
GPipe
通过以下机制实现跨设备的同步:- 复制流管理 (
record_stream
): - 使用
record_stream
函数确保张量在复制过程中不会被释放。每个设备都有独立的 CUDA 流,这些流用于管理设备间的数据复制和同步。
- 依赖控制 (
fork
和join
): - 使用
fork
和join
函数在前向和反向传播中建立伪依赖关系,确保梯度计算顺序正确,避免
总结
GPipe
提供了一种高效的管道并行方案,特别适用于大模型的分布式训练。通过使用 fork
、join
、工作线程和检查点机制,GPipe
能够有效管理设备间的数据传输、同步和显存占用。worker
线程和 spawn_workers
函数为不同设备上的并行计算提供了强有力的支持,通过队列实现任务的分发和结果的收集,有效地实现了跨设备的计算同步与通信。