# Analyzing upstream TorchRec training pipelines through SWPipeline's Task / Schedule / Stream / Stage lens
Upstream TorchRec's `torchrec/distributed/train_pipeline/train_pipelines.py` defines 1 ABC base class + 11 concrete implementations. This page uses SWPipeline's unified concepts — PipelineTask (what), TaskSchedule (stage / stream / thread_group / globally_ordered), and dependencies — to interpret each pipeline's scheduling strategy.
| # | Class | Type | Streams | Depth | Description |
|---|---|---|---|---|---|
| 0 | TrainPipeline | ABC | — | — | ABC defining the progress() interface |
| 1 | TrainPipelineBase | Train | 2 | 2 | Simplest 2-stage H2D pipeline |
| 2 | TrainPipelinePT2 | Compiled | 1 | 1 | PT2 compiled, no pipelining |
| 3 | TrainPipelineSparseDist | Train | 3 | 3 | Classic 3-stage: H2D ‖ InputDist ‖ FWD/BWD |
| 4 | TrainPipelineSparseDistLite | Train | 2 | 2 | Memory-efficient 2-stage, InputDist on the critical path |
| 5 | TrainPipelineFusedSparseDist | Train | 3–4 | 3 | Separate emb-lookup stream, overlaps with optimizer |
| 6 | TrainPipelineSemiSync | Train | 3 | 4 | Semi-sync, fully overlaps Emb A2A with Forward |
| 7 | PrefetchTrainPipelineSparseDist | Train | 4 | 3 | 4-stage with cache prefetch |
| 8 | EvalPipelineSparseDist | Eval | 3 | 2 | Eval, background DataLoadingThread H2D |
| 9 | EvalPipelineFusedSparseDist | Eval | 3–4 | 3 | Eval variant of the Fused pipeline |
| 10 | StagedTrainPipeline | Generic | User-defined | N | Generic N-stage declarative pipeline |
| 11 | TrainPipelineSparseDistCompAutograd | Compiled | 3 | 3 | SparseDist + Compiled Autograd |
Inheritance hierarchy:
```
TrainPipeline (ABC)
├── TrainPipelineBase
│   └── TrainPipelinePT2
├── TrainPipelineSparseDist
│   ├── TrainPipelineSparseDistLite
│   ├── TrainPipelineFusedSparseDist
│   │   └── EvalPipelineFusedSparseDist
│   ├── TrainPipelineSemiSync
│   ├── PrefetchTrainPipelineSparseDist
│   ├── EvalPipelineSparseDist
│   └── TrainPipelineSparseDistCompAutograd
└── StagedTrainPipeline
```
## 1. TrainPipelineBase

```python
# pipeline_depth = 2 (cur_batch + next_batch)
# 2 streams: memcpy_stream + default_stream
PipelinePlan(
    schedule = {
        "H2D":           TaskSchedule(stage=0, stream=memcpy),   # stage 0: first step of a batch entering the pipeline
        "ZeroGrad":      TaskSchedule(stage=1, stream=default),
        "WaitBatch":     TaskSchedule(stage=1, stream=default),  # wait memcpy → default
        "Forward":       TaskSchedule(stage=1, stream=default),
        "Backward":      TaskSchedule(stage=1, stream=default),
        "OptimizerStep": TaskSchedule(stage=1, stream=default),  # stage 1: last step of batch processing
    },
    intra_iter_deps = [
        ("WaitBatch", "H2D"),           # cross-stage: memcpy → default stream sync
        ("WaitBatch", "ZeroGrad"),      # stage=1 ordering
        ("Forward", "WaitBatch"),       # stage=1 ordering
        ("Backward", "Forward"),        # stage=1 ordering
        ("OptimizerStep", "Backward"),  # stage=1 ordering
    ],
    inter_iter_deps = [
        ("Forward", "OptimizerStep"),   # Forward(i) reads weights updated by OptimizerStep(i-1)
    ],
    # same-stream (default) → FIFO already guarantees GPU ordering
    pipeline_depth = 2,
)
```
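A plan like this can be sanity-checked by topologically sorting its `intra_iter_deps`: the `(task, prerequisite)` pairs must form a DAG or no valid per-iteration submission order exists. A minimal sketch, assuming only the dependency pairs above — `submission_order` is a hypothetical helper, not part of SWPipeline or TorchRec:

```python
# Hypothetical sanity check for a PipelinePlan's intra-iteration dependencies.
# graphlib.TopologicalSorter yields prerequisites before their dependents.
from graphlib import TopologicalSorter

intra_iter_deps = [
    ("WaitBatch", "H2D"),
    ("WaitBatch", "ZeroGrad"),
    ("Forward", "WaitBatch"),
    ("Backward", "Forward"),
    ("OptimizerStep", "Backward"),
]

def submission_order(deps):
    """Return one valid intra-iteration submission order (prerequisites first)."""
    graph = {}
    for task, prereq in deps:
        graph.setdefault(task, set()).add(prereq)
        graph.setdefault(prereq, set())          # prerequisites are nodes too
    return list(TopologicalSorter(graph).static_order())

order = submission_order(intra_iter_deps)
# Every prerequisite must appear before its dependent task.
for task, prereq in intra_iter_deps:
    assert order.index(prereq) < order.index(task)
```

A cycle (e.g. adding `("H2D", "OptimizerStep")` to the intra-iteration deps) would raise `graphlib.CycleError` here, which is exactly the failure mode such a check should surface early.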
print_schedule(5) output:

```
#  Task               Thread   Stream       |  P0     P1     P2     P3     P4
-- -----------------  -------  ------------ +  -----  -----  -----  -----  -----
 0 ZeroGrad           default  default      |  --     i0     i1     i2     i3
 1 WaitBatch          default  default      |  --     i0     i1     i2     i3
 2 Forward            default  default      |  --     i0     i1     i2     i3
 3 Backward           default  default      |  --     i0     i1     i2     i3
 4 OptimizerStep      default  default      |  --     i0     i1     i2     i3
 5 H2D                default  memcpy       |  i0     i1     i2     i3     i4

# pipeline_depth=2: H2D(i+1) runs on the memcpy stream in parallel with the current batch's processing
```
```
## progress(iter i):          batch   stream
1. ZeroGrad                   [i]     default
2. WaitBatch                  [i]     default  (wait memcpy→default)
3. Forward                    [i]     default
4. Backward                   [i]     default
5. H2D                        [i+1]   memcpy   ← submitted only after Backward
6. OptimizerStep              [i]     default  ← H2D runs on memcpy in parallel with Opt
```
Key feature: no FX rewrite of ShardedModule, no InputDist split. Suitable for non-sharded or minimal scenarios. H2D(i+1) starts only after Backward(i) and overlaps only with Optimizer(i), on a separate stream.
## 2. TrainPipelinePT2

```python
# pipeline_depth = 1 (no pipelining)
# 1 stream: default_stream only
# Model compiled via torch.compile on compile_on_iter (default: 3rd iter)
PipelinePlan(
    schedule = {
        "LoadBatch":      TaskSchedule(stage=0, stream=default),  # next(dataloader_iter)
        "H2D":            TaskSchedule(stage=0, stream=default),  # _to_device(non_blocking=False)
        "InputTransform": TaskSchedule(stage=0, stream=default),  # PT2 input hints
        "ZeroGrad":       TaskSchedule(stage=0, stream=default),
        "Forward":        TaskSchedule(stage=0, stream=default),  # compiled model(batch)
        "Backward":       TaskSchedule(stage=0, stream=default),
        "OptimizerStep":  TaskSchedule(stage=0, stream=default),
    },
    pipeline_depth = 1,  # fully serial, no overlap
)
```
Note: H2D uses non_blocking=False (synchronous copy), so there is zero stream overlap. Optimization relies entirely on torch.compile graph compilation, which triggers at compile_on_iter (default: 3rd iteration). Inherits from TrainPipelineBase.
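The compile_on_iter behavior can be sketched as a small wrapper: run eagerly for the first N−1 iterations (so real input shapes are observed), then swap in the compiled function. This is a hypothetical illustration of the pattern, not TorchRec's implementation; `CompileOnIter` and `fake_compile` are made-up names, and in real use the `compile_fn` argument would be `torch.compile`:

```python
# Sketch of the "compile on the Nth iteration" pattern.
class CompileOnIter:
    def __init__(self, fn, compile_fn, compile_on_iter=3):
        self._fn = fn
        self._compile_fn = compile_fn        # e.g. torch.compile in real use
        self._compile_on_iter = compile_on_iter
        self._iter = 0
        self._compiled = None

    def __call__(self, batch):
        self._iter += 1
        if self._compiled is None and self._iter >= self._compile_on_iter:
            self._compiled = self._compile_fn(self._fn)   # compile exactly once
        active = self._compiled if self._compiled is not None else self._fn
        return active(batch)

calls = []
def fake_compile(fn):
    calls.append("compiled")
    return fn

step = CompileOnIter(lambda b: b * 2, fake_compile, compile_on_iter=3)
assert [step(1), step(2), step(3)] == [2, 4, 6]
assert calls == ["compiled"]    # compilation triggered exactly once, at iteration 3
```

Deferring compilation a few iterations lets dynamic-shape hints stabilize before the graph is captured.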
## 3. TrainPipelineSparseDist

```python
# pipeline_depth = 3 (max 3 batches in flight: batches[0], [1], [2])
# 3 streams: memcpy, data_dist, default
# fill_pipeline loads 2 batches; progress() enqueue_batch adds the 3rd
# Requires FX rewrite via _rewrite_model (PipelinedForward)
PipelinePlan(
    schedule = {
        "H2D":            TaskSchedule(stage=0, stream=memcpy),     # stage 0: batch entry (batches[2])
        "InputDistStart": TaskSchedule(stage=1, stream=data_dist, globally_ordered=True),  # stage 1 (batches[1])
        "InputDistWait":  TaskSchedule(stage=1, stream=data_dist),  # stage 1 (batches[1])
        "ZeroGrad":       TaskSchedule(stage=2, stream=default),
        "WaitBatch":      TaskSchedule(stage=2, stream=default),    # stage 2 (batches[0])
        "Forward":        TaskSchedule(stage=2, stream=default),    # stage 2: batch exit (batches[0])
        "Backward":       TaskSchedule(stage=2, stream=default),
        "OptimizerStep":  TaskSchedule(stage=2, stream=default),
    },
    intra_iter_deps = [
        ("InputDistStart", "H2D"),            # cross-stage: memcpy → data_dist stream sync
        ("InputDistWait", "InputDistStart"),  # stage=1 ordering
        ("WaitBatch", "InputDistWait"),       # cross-stage: data_dist → default stream sync
        ("Forward", "InputDistWait"),         # cross-stage: data_dist → default stream sync
        ("WaitBatch", "ZeroGrad"),            # stage=2 ordering
        ("Forward", "WaitBatch"),             # stage=2 ordering
        ("Backward", "Forward"),              # stage=2 ordering
        ("OptimizerStep", "Backward"),        # stage=2 ordering
    ],
    inter_iter_deps = [
        ("Forward", "OptimizerStep"),         # Forward(i) reads weights updated by OptimizerStep(i-1)
    ],
    # same-stream (default) → FIFO already guarantees GPU ordering
    pipeline_depth = 3,
)
```
print_schedule(5) output:

```
#  Task               Thread   Stream       |  P0     P1     P2     P3     P4
-- -----------------  -------  ------------ +  -----  -----  -----  -----  -----
 0 ZeroGrad           default  default      |  --     --     i0     i1     i2
 1 WaitBatch          default  default      |  --     --     i0     i1     i2
 2 Forward            default  default      |  --     --     i0     i1     i2
 3 Backward           default  default      |  --     --     i0     i1     i2
 4 OptimizerStep      default  default      |  --     --     i0     i1     i2
 5 InputDistStart     default  data_dist    |  --     i0     i1     i2     i3
 6 InputDistWait      default  data_dist    |  --     i0     i1     i2     i3
 7 H2D                default  memcpy       |  i0     i1     i2     i3     i4

# pipeline_depth=3: 3 streams (memcpy, data_dist, default) run in parallel
```
```
## progress(iter i):                   batch   stream
1. ZeroGrad                            [i]     default
2. WaitBatch                           [i]     default    (wait data_dist→default)
3. InputDistStart                      [i+1]   data_dist  ← start splits A2A
4. H2D                                 [i+2]   memcpy     ← enqueue_batch
5. Forward (EmbLookup runs inside)     [i]     default    ← PipelinedForward
6. InputDistWait                       [i+1]   data_dist  ← wait splits A2A
7. Backward                            [i]     default
8. OptimizerStep                       [i]     default

# Note: with enqueue_batch_after_forward=True, H2D is submitted after Forward
# to keep H2D from competing with UVM embedding lookup for PCIe bandwidth
```
Core mechanism: uses _rewrite_model to FX-trace ShardedModule, extracting input_dist to a separate stream. Up to 3 batches in flight (source comment: "max pipelined batches == 3"): batches[0] runs FWD/BWD, batches[1] runs InputDist, batches[2] runs H2D — all three in parallel on different streams. This is TorchRec's core pipeline; 6 other training pipelines inherit from it. Embedding lookup (compute_and_output_dist) runs inside Forward, auto-triggered by PipelinedForward. FusedSparseDist further optimizes this by extracting the emb lookup to a separate stream and running it earlier — see the comparison below.
## 4. TrainPipelineSparseDistLite

```python
# pipeline_depth = 2 (max 2 batches: fill 1, enqueue 1 in progress)
# 2 streams: memcpy + default (data_dist_stream set to None!)
# fill_pipeline fills only 1 batch (vs 2 in full SDD)
PipelinePlan(
    schedule = {
        "H2D":            TaskSchedule(stage=0, stream=memcpy),   # stage 0: batch entry
        "ZeroGrad":       TaskSchedule(stage=1, stream=default),
        "WaitBatch":      TaskSchedule(stage=1, stream=default),
        "InputDistStart": TaskSchedule(stage=1, stream=default),  # on the default stream!
        "InputDistWait":  TaskSchedule(stage=1, stream=default),  # on the default stream!
        "Forward":        TaskSchedule(stage=1, stream=default),
        "Backward":       TaskSchedule(stage=1, stream=default),
        "OptimizerStep":  TaskSchedule(stage=1, stream=default),  # stage 1: batch exit
    },
    intra_iter_deps = [
        ("WaitBatch", "H2D"),                 # cross-stage: memcpy → default stream sync
        ("WaitBatch", "ZeroGrad"),            # stage=1 ordering
        ("InputDistStart", "WaitBatch"),      # stage=1 ordering (default stream)
        ("InputDistWait", "InputDistStart"),  # stage=1 ordering
        ("Forward", "InputDistWait"),         # stage=1 ordering
        ("Backward", "Forward"),              # stage=1 ordering
        ("OptimizerStep", "Backward"),        # stage=1 ordering
    ],
    inter_iter_deps = [
        ("Forward", "OptimizerStep"),         # Forward(i) reads weights updated by OptimizerStep(i-1)
    ],
    # same-stream (default) → FIFO already guarantees GPU ordering
    pipeline_depth = 2,
)
```
print_schedule(5) output:

```
#  Task               Thread   Stream       |  P0     P1     P2     P3     P4
-- -----------------  -------  ------------ +  -----  -----  -----  -----  -----
 0 ZeroGrad           default  default      |  --     i0     i1     i2     i3
 1 WaitBatch          default  default      |  --     i0     i1     i2     i3
 2 InputDistStart     default  default      |  --     i0     i1     i2     i3
 3 InputDistWait      default  default      |  --     i0     i1     i2     i3
 4 Forward            default  default      |  --     i0     i1     i2     i3
 5 Backward           default  default      |  --     i0     i1     i2     i3
 6 OptimizerStep      default  default      |  --     i0     i1     i2     i3
 7 H2D                default  memcpy       |  i0     i1     i2     i3     i4

# pipeline_depth=2: only H2D (on the memcpy stream) overlaps with the previous iteration's Opt
# All stage=1 tasks run serially on the default stream
```
```
## progress(iter i):          batch   stream
1. ZeroGrad                   [i]     default
2. WaitBatch                  [i]     default  (wait memcpy→default)
3. InputDistStart+Wait        [i]     default  ← critical path! data_dist=None
4. Forward                    [i]     default
5. Backward                   [i]     default
6. H2D                        [i+1]   memcpy   ← enqueue_batch, between BWD and OPT
7. OptimizerStep              [i]     default  ← H2D runs on memcpy in parallel with Opt
```
Design point: the key difference is self._data_dist_stream = None. InputDist executes on the default stream (critical path); only H2D uses a separate memcpy_stream. fill_pipeline loads only 1 batch (vs 2 in full SDD), and enqueue_batch in progress() brings it to 2 → max 2 batches in flight. Memory overhead is ~2/3 of SparseDist (2 batches vs 3), with ~4–5% QPS improvement. Ideal for memory-constrained deployments.
## 5. TrainPipelineFusedSparseDist

```python
# pipeline_depth = 3 (max 3 batches in flight, same as SparseDist)
# source: "max pipelined batches == 3 (capacity)"
# 3-4 streams: memcpy, data_dist, emb_lookup (default = data_dist), default
# Uses InSyncEmbeddingPipelinedForward (not PipelinedForward)
PipelinePlan(
    schedule = {
        "H2D":            TaskSchedule(stage=0, stream=memcpy),      # stage 0: batch entry (batches[2])
        "InputDistStart": TaskSchedule(stage=1, stream=data_dist, globally_ordered=True),  # stage 1 (batches[1])
        "InputDistWait":  TaskSchedule(stage=1, stream=data_dist),   # stage 1 (batches[1])
        "EmbLookup":      TaskSchedule(stage=2, stream=emb_lookup),  # stage 2 (batches[0]), compute_and_output_dist
        "ZeroGrad":       TaskSchedule(stage=2, stream=default),
        "WaitBatch":      TaskSchedule(stage=2, stream=default),
        "Forward":        TaskSchedule(stage=2, stream=default),     # stage 2: dense-only forward (batch exit)
        "Backward":       TaskSchedule(stage=2, stream=default),
        "OptimizerStep":  TaskSchedule(stage=2, stream=default),
    },
    intra_iter_deps = [
        ("InputDistStart", "H2D"),            # cross-stage: memcpy → data_dist stream sync
        ("InputDistWait", "InputDistStart"),  # stage=1 ordering
        ("EmbLookup", "InputDistWait"),       # cross-stage: data_dist → emb_lookup stream sync
        ("Forward", "EmbLookup"),             # cross-stream: emb_lookup → default
        ("WaitBatch", "ZeroGrad"),            # stage=2 ordering
        ("Forward", "WaitBatch"),             # stage=2 ordering
        ("Backward", "Forward"),              # stage=2 ordering
        ("OptimizerStep", "Backward"),        # stage=2 ordering
    ],
    inter_iter_deps = [
        ("EmbLookup", "Backward"),            # EmbLookup(i) reads emb weights updated by Backward(i-1) (TBE fused)
        ("Forward", "OptimizerStep"),         # Forward(i) reads dense weights updated by OptimizerStep(i-1)
    ],
    # EmbLookup↔Backward is cross-stream (emb_lookup↔default), so it must be declared explicitly
    pipeline_depth = 3,
)
```
print_schedule(5) output:

```
#  Task               Thread   Stream       |  P0     P1     P2     P3     P4
-- -----------------  -------  ------------ +  -----  -----  -----  -----  -----
 0 EmbLookup          default  emb_lookup   |  --     --     i0     i1     i2
 1 ZeroGrad           default  default      |  --     --     i0     i1     i2
 2 WaitBatch          default  default      |  --     --     i0     i1     i2
 3 Forward            default  default      |  --     --     i0     i1     i2
 4 Backward           default  default      |  --     --     i0     i1     i2
 5 OptimizerStep      default  default      |  --     --     i0     i1     i2
 6 InputDistStart     default  data_dist    |  --     i0     i1     i2     i3
 7 InputDistWait      default  data_dist    |  --     i0     i1     i2     i3
 8 H2D                default  memcpy       |  i0     i1     i2     i3     i4

# EmbLookup runs independently on the emb_lookup stream, in parallel with ZeroGrad/WaitBatch on the default stream
```
```
## progress(iter i):                          batch   stream
1. EmbLookup                                  [i]     emb_lookup  ← at the very front of progress!
2. ZeroGrad                                   [i]     default
3. WaitBatch                                  [i]     default     (wait data_dist→default)
4. InputDistStart                             [i+1]   data_dist
5. H2D                                        [i+2]   memcpy
6. Forward (dense-only, waits on emb_lookup)  [i]     default
7. InputDistWait                              [i+1]   data_dist
8. Backward                                   [i]     default
9. OptimizerStep                              [i]     default

# With embedding_lookup_after_data_dist=True, EmbLookup comes after H2D
```
## SparseDist vs FusedSparseDist

Both have the same pipeline depth (3 batches in flight). The critical difference is when and where Embedding Lookup (compute_and_output_dist) executes:
| Dimension | SparseDist | FusedSparseDist |
|---|---|---|
| PipelinedForward type | PipelinedForward | InSyncEmbeddingPipelinedForward |
| Emb Lookup location | Inside model_fwd(), auto-executed by PipelinedForward, on default_stream | Before model_fwd(), explicitly called via start_embedding_lookup(), on emb_lookup_stream |
| Emb Lookup timing | At the start of Forward(i) (serial with Forward) | At the very beginning of progress(), before ZeroGrad |
| Can overlap with | Nothing (Emb Lookup is on Forward's critical path) | EmbLookup(i) can overlap with the tail of Optimizer(i-1) on a different stream |
| Forward content | Full forward (emb lookup + dense) | Dense-only forward (emb results injected by InSyncEmbeddingPipelinedForward) |
| Extra stream | None | emb_lookup_stream (default: reuses data_dist; configurable as "new" or "current") |
Execution order comparison (CPU submission order):
```
## SparseDist progress():
1. ZeroGrad                 # default_stream
2. WaitBatch                # wait data_dist → default
3. InputDistStart(i+1)      # data_dist_stream
4. H2D(i+2)                 # memcpy_stream
5. Forward(i)               # default_stream ← Emb Lookup executes inside here
6. InputDistWait(i+1)       # data_dist_stream
7. Backward(i)              # default_stream
8. Optimizer(i)             # default_stream
9. Dequeue

## FusedSparseDist progress():
1. EmbLookup(i)             # emb_lookup_stream ← hoisted to the very front!
2. ZeroGrad                 # default_stream
3. WaitBatch                # wait data_dist → default
4. InputDistStart(i+1)      # data_dist_stream
5. H2D(i+2)                 # memcpy_stream
6. Forward(i)               # default_stream, waits on emb_lookup_stream first ← dense only, emb results ready
7. InputDistWait(i+1)       # data_dist_stream
8. Backward(i)              # default_stream
9. Optimizer(i)             # default_stream
10. Dequeue
```
"data_dist"(默认):复用 data_dist_stream,减少 CUDA Caching Allocator 为额外 stream 预留显存"new":新建独立 stream,最大化重叠潜力"current":使用 default_stream,退化为与 SparseDist 等效(无重叠)"data_dist" (default): reuses data_dist_stream, reduces CUDA Caching Allocator memory reservation for extra streams"new": fresh stream, maximizes overlap potential"current": default_stream, degrades to SparseDist-equivalent (no overlap)# pipeline_depth = 4 (4 batches in flight: i, i+1, i+2, i+3) # 3 streams: memcpy, data_dist, default # Uses EmbeddingPipelinedForward (detaches emb from autograd) # H2D(i+3) → InputDist(i+2) → EmbLookup(i+1) → Forward/Backward(i) PipelinePlan( schedule = { "H2D": TaskSchedule(stage=0, stream=memcpy), # stage 0: batch[i+3] "InputDistStart": TaskSchedule(stage=1, stream=data_dist, globally_ordered=True), # stage 1: batch[i+2] "InputDistWait": TaskSchedule(stage=1, stream=data_dist), # stage 1: batch[i+2] "EmbLookup": TaskSchedule(stage=2, stream=default), # stage 2: batch[i+1], start_embedding_lookup "ZeroGrad": TaskSchedule(stage=3, stream=default), # stage 3: batch[i] "Forward": TaskSchedule(stage=3, stream=default), # stage 3: uses detached emb tensors "Backward": TaskSchedule(stage=3, stream=default), # stage 3: dense backward only "EmbBackward": TaskSchedule(stage=3, stream=default), # stage 3: torch.autograd.backward "OptimizerStep": TaskSchedule(stage=3, stream=default), # stage 3: batch 出口 }, intra_iter_deps = [ ("InputDistStart", "H2D"), # cross-stage: memcpy → data_dist stream sync ("InputDistWait", "InputDistStart"), # stage=1 ordering ("EmbLookup", "InputDistWait"), # cross-stage: data_dist → default stream sync ("Forward", "EmbLookup"), # cross-stage: emb ready before forward ("Forward", "ZeroGrad"), # stage=3 ordering ("Backward", "Forward"), # stage=3 ordering ("EmbBackward", "Backward"), # stage=3 ordering ("OptimizerStep", "EmbBackward"), # stage=3 ordering ], inter_iter_deps = [ ("EmbLookup", "Backward"), # EmbLookup(i) 读取 Backward(i-1) 更新的 emb 权重 (TBE fused) ], # 
same-stream (default) → FIFO 已保证 GPU 顺序 # 注意: Forward(i) 使用 B-2 参数 → 逻辑上依赖 OptimizerStep(i-2),k=2 超出 SWPipeline 支持范围 pipeline_depth = 4, )
print_schedule(6) output:

```
#  Task               Thread   Stream       |  P0     P1     P2     P3     P4     P5
-- -----------------  -------  ------------ +  -----  -----  -----  -----  -----  -----
 0 ZeroGrad           default  default      |  --     --     --     i0     i1     i2
 1 Forward            default  default      |  --     --     --     i0     i1     i2
 2 Backward           default  default      |  --     --     --     i0     i1     i2
 3 EmbBackward        default  default      |  --     --     --     i0     i1     i2
 4 OptimizerStep      default  default      |  --     --     --     i0     i1     i2
 5 EmbLookup          default  default      |  --     --     i0     i1     i2     i3
 6 InputDistStart     default  data_dist    |  --     i0     i1     i2     i3     i4
 7 InputDistWait      default  data_dist    |  --     i0     i1     i2     i3     i4
 8 H2D                default  memcpy       |  i0     i1     i2     i3     i4     i5

# pipeline_depth=4: P0–P2 are fill_pipeline; the first result comes out at P3
# stage 0: H2D | stage 1: InputDist | stage 2: EmbLookup | stage 3: Fwd/Bwd/Opt
```
```
## SWPipeline enqueue order (stage-descending):   batch   stream     stage
1. ZeroGrad                                       [i]     default    ← stage 3
2. Forward (emb done by last round's EmbLookup)   [i]     default    ← stage 3
3. Backward (dense only)                          [i]     default    ← stage 3
4. EmbBackward                                    [i]     default    ← stage 3
5. OptimizerStep                                  [i]     default    ← stage 3
6. EmbLookup                                      [i+1]   default    ← stage 2
7. InputDistStart                                 [i+2]   data_dist  ← stage 1
8. InputDistWait                                  [i+2]   data_dist  ← stage 1
9. H2D                                            [i+3]   memcpy     ← stage 0

# Note: pipeline_depth=4, so 4 batches are in flight.
# The embeddings Forward(i) uses come from EmbLookup(i) (submitted in the previous round).
# Sync mode: EmbLookup comes after Backward+OptimizerStep, and must wait for the
# previous round's Backward to finish before it can safely start.
```
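The stage-descending submission order in the listing above (oldest in-flight batch first, newest last) can be derived directly from the schedule's stage numbers. A small sketch, assuming stage-descending order with stable ordering inside a stage — `enqueue_order` is a hypothetical helper, not SWPipeline's API:

```python
# Derive the per-iteration enqueue order from task → stage mappings.
# Python's sorted() is stable, so tasks within a stage keep insertion order.
schedule = {
    "H2D": 0,
    "InputDistStart": 1, "InputDistWait": 1,
    "EmbLookup": 2,
    "ZeroGrad": 3, "Forward": 3, "Backward": 3, "EmbBackward": 3, "OptimizerStep": 3,
}

def enqueue_order(schedule):
    """Stage-descending submission order, stable within each stage."""
    return sorted(schedule, key=lambda task: -schedule[task])

order = enqueue_order(schedule)
assert order[:5] == ["ZeroGrad", "Forward", "Backward", "EmbBackward", "OptimizerStep"]
assert order[-1] == "H2D"      # the newest batch's H2D is submitted last
```

This reproduces steps 1–9 of the listing exactly: all stage-3 tasks, then EmbLookup (stage 2), InputDist (stage 1), and finally H2D (stage 0).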
Convergence impact: each batch predicts with B-2 (not B-1) parameters. The first start_batch (default 900) batches run in synchronous mode. With stash_gradients=True, gradients are stashed to preserve "true" semi-sync semantics. Embedding backward runs separately (torch.autograd.backward), decoupled from the dense backward.
## 7. PrefetchTrainPipelineSparseDist

```python
# pipeline_depth = 3 (3 batches: _batch_i, _batch_ip1, _batch_ip2)
# 4 streams: memcpy, data_dist, prefetch, default
# Uses PrefetchPipelinedForward
PipelinePlan(
    schedule = {
        "H2D":            TaskSchedule(stage=0, stream=memcpy),    # stage 0: batch entry
        "InputDistStart": TaskSchedule(stage=0, stream=data_dist, globally_ordered=True),
        "InputDistWait":  TaskSchedule(stage=1, stream=data_dist),
        "EmbPrefetch":    TaskSchedule(stage=1, stream=prefetch),  # _prefetch_embeddings
        "ZeroGrad":       TaskSchedule(stage=2, stream=default),
        "WaitBatch":      TaskSchedule(stage=2, stream=default),   # wait prefetch_stream
        "Forward":        TaskSchedule(stage=2, stream=default),
        "Backward":       TaskSchedule(stage=2, stream=default),
        "OptimizerStep":  TaskSchedule(stage=2, stream=default),   # stage 2: batch exit
    },
    intra_iter_deps = [
        ("InputDistStart", "H2D"),            # cross-stage: memcpy → data_dist stream sync
        ("InputDistWait", "InputDistStart"),  # stage=1 ordering (cross-stage)
        ("EmbPrefetch", "InputDistWait"),     # cross-stream: data_dist → prefetch
        ("WaitBatch", "EmbPrefetch"),         # cross-stage: prefetch → default stream sync
        ("WaitBatch", "ZeroGrad"),            # stage=2 ordering
        ("Forward", "WaitBatch"),             # stage=2 ordering
        ("Backward", "Forward"),              # stage=2 ordering
        ("OptimizerStep", "Backward"),        # stage=2 ordering
    ],
    inter_iter_deps = [
        ("EmbPrefetch", "Forward"),           # EmbPrefetch(i) must wait until Forward(i-1) has consumed the prefetch results
        ("Forward", "OptimizerStep"),         # Forward(i) reads weights updated by OptimizerStep(i-1)
    ],
    # EmbPrefetch↔Forward is cross-stream (prefetch↔default), so it must be declared explicitly
    pipeline_depth = 3,
)
```
print_schedule(5) output:

```
#  Task               Thread   Stream       |  P0     P1     P2     P3     P4
-- -----------------  -------  ------------ +  -----  -----  -----  -----  -----
 0 ZeroGrad           default  default      |  --     --     i0     i1     i2
 1 WaitBatch          default  default      |  --     --     i0     i1     i2
 2 Forward            default  default      |  --     --     i0     i1     i2
 3 Backward           default  default      |  --     --     i0     i1     i2
 4 OptimizerStep      default  default      |  --     --     i0     i1     i2
 5 InputDistWait      default  data_dist    |  --     i0     i1     i2     i3
 6 EmbPrefetch        default  prefetch     |  --     i0     i1     i2     i3
 7 H2D                default  memcpy       |  i0     i1     i2     i3     i4
 8 InputDistStart     default  data_dist    |  i0     i1     i2     i3     i4

# stage 0: H2D + InputDistStart (two different streams, same stage)
# stage 1: InputDistWait + EmbPrefetch
# stage 2: ZeroGrad → WaitBatch → Forward → Backward → OptimizerStep
```
```
## progress(iter i):          batch   stream
1. ZeroGrad                   [i]     default
2. WaitBatch                  [i]     default    (wait prefetch→default)
3. H2D                        [i+2]   memcpy     ← _copy_batch_to_gpu
4. InputDistWait              [i+1]   data_dist  ← result of last round's InputDistStart
5. Forward                    [i]     default    ← PrefetchPipelinedForward
6. EmbPrefetch                [i+1]   prefetch   ← _prefetch_embeddings
7. Backward                   [i]     default
8. OptimizerStep              [i]     default
9. InputDistStart             [i+2]   data_dist  ← _start_sparse_data_dist
10. Slide: i←ip1, ip1←ip2
```
progress() single-iteration execution order:

1. wait_for_batch(_batch_i, prefetch_stream)
2. _copy_batch_to_gpu(_batch_ip2)
3. _wait_sparse_data_dist()
4. model(_batch_i)
5. _prefetch(_batch_ip1) on prefetch_stream
6. loss.backward()
7. _start_sparse_data_dist(_batch_ip2)
8. _batch_i ← _batch_ip1, _batch_ip1 ← _batch_ip2
Core value: a dedicated prefetch_stream overlaps cache prefetch with Forward. For large-scale embeddings using UVM_CACHING, cache-miss latency can reach tens of ms; prefetch hides it entirely. Memory overhead: 3 batches. This pipeline still uses the deprecated _batch_i/_batch_ip1/_batch_ip2 member variables.
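The benefit is easy to see in a toy cache model (hypothetical and deliberately simplified — not TorchRec's UVM cache): without prefetch, every cold row is a miss paid on the critical path; with prefetch, batch i+1's rows are fetched while batch i computes, so by the time batch i+1 runs, its lookups hit.

```python
# Toy cache model: count misses that land on the compute critical path.
def run(batches, prefetch=True):
    cache, critical_misses = set(), 0
    pending = None
    for i, rows in enumerate(batches):
        if pending is not None:
            cache |= pending                    # prefetch finished during the previous batch's compute
            pending = None
        critical_misses += sum(r not in cache for r in rows)
        cache |= set(rows)                      # rows are resident after being fetched
        if prefetch and i + 1 < len(batches):
            pending = set(batches[i + 1])       # overlaps with this batch's FWD/BWD
    return critical_misses

batches = [[1, 2], [3, 4], [5, 6]]
assert run(batches, prefetch=False) == 6   # every row is a cold miss on the critical path
assert run(batches, prefetch=True) == 2    # only the first (unprefetched) batch misses
```

The model ignores eviction and capacity, which is exactly what the inter-iteration `("EmbPrefetch", "Forward")` dependency above guards against in the real pipeline.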
## 8–9. Eval pipelines

EvalPipelineSparseDist uses a DataLoadingThread to do H2D on a background thread and re-initializes the pipelined modules in every progress().

EvalPipelineFusedSparseDist inherits TrainPipelineFusedSparseDist and wraps forward in torch.no_grad().
```python
# pipeline_depth = 2, background-thread H2D
# Unique: uses DataLoadingThread (a separate Python thread) for H2D
PipelinePlan(
    schedule = {
        "H2D":            TaskSchedule(stage=0, stream=memcpy, thread_group="loader"),  # stage 0: batch entry (background thread)
        "InputDistStart": TaskSchedule(stage=1, stream=data_dist, globally_ordered=True),
        "InputDistWait":  TaskSchedule(stage=1, stream=data_dist),
        "WaitBatch":      TaskSchedule(stage=1, stream=default),
        "Forward":        TaskSchedule(stage=1, stream=default),  # stage 1: batch exit (no backward/optimizer)
    },
    pipeline_depth = 2,
)
```
## 10. StagedTrainPipeline

```python
# pipeline_depth = N (= number of stages)
# Each PipelineStage has: name, runnable, stream, fill_callback, data_exhausted_callback
# Forward/Backward/Optimizer run OUTSIDE the pipeline (in the user's train loop)
# Example: 2-stage pipeline for data_copy + gpu_postproc
StagedTrainPipeline(
    pipeline_stages = [
        PipelineStage(name="data_copy",    runnable=get_h2d_func("cuda"), stream=Stream()),
        PipelineStage(name="gpu_postproc", runnable=gpu_postproc,         stream=Stream()),
    ]
)

# Usage: the pipeline outputs the processed batch; the user handles FWD/BWD/OPT
while batch := pipeline.progress(dataloader_iter):
    optimizer.zero_grad()
    loss, pred = model(batch)
    loss.backward()
    optimizer.step()
```
Key features:

- Declares the pipeline as a list of PipelineStage objects; no subclassing needed
- Synchronizes with torch.cuda.Event, not stream.wait_stream
- set_flush(True) drains all in-flight batches, then invokes the callback
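The declarative N-stage idea can be sketched synchronously in a few lines (hypothetical `ToyStagedPipeline`, illustrative only — the real StagedTrainPipeline overlaps stages on CUDA streams with Event-based sync): each stage is a named callable, and every `progress()` advances all in-flight items by one stage, returning a fully processed item once the pipeline has filled.

```python
# Minimal synchronous sketch of a generic N-stage declarative pipeline.
class ToyStagedPipeline:
    def __init__(self, stages):
        self._stages = stages                    # list of (name, fn) pairs
        self._slots = [None] * len(stages)       # slot k holds the output of stage k

    def progress(self, it):
        done = self._slots[-1]                   # output produced by the previous call
        for k in range(len(self._stages) - 1, 0, -1):  # shift downstream slots first
            prev = self._slots[k - 1]
            self._slots[k] = self._stages[k][1](prev) if prev is not None else None
        try:
            self._slots[0] = self._stages[0][1](next(it))   # admit a new item
        except StopIteration:
            self._slots[0] = None                # source exhausted: start draining
        return done

pipe = ToyStagedPipeline([("data_copy", lambda b: b), ("postproc", lambda b: b * 10)])
it = iter(range(3))
out = [pipe.progress(it) for _ in range(5)]
assert out == [None, None, 0, 10, 20]   # two fill steps, then one output per call
```

The `None` returns during the first calls mirror fill_pipeline, and the trailing calls after the iterator is exhausted mirror the flush/drain behavior described above.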
Comparison with SWPipeline: StagedTrainPipeline is the closest thing in TorchRec to a declarative pipeline design, but it only manages pre-forward stages and lacks shortcut / profiling / DeclaredIO mechanisms. SWPipeline covers the full train loop (including FWD/BWD/OPT) and provides automatic profiling and a zero-overhead shortcut.
## 11. TrainPipelineSparseDistCompAutograd

```python
# Same as TrainPipelineSparseDist, but wraps forward/backward
# in the torch._dynamo.compiled_autograd context
# Enables inductor reordering passes for compute-comm overlap:
#   "sink_waits", "raise_comms", "reorder_compute_for_overlap"
PipelinePlan(
    schedule = {
        # identical to TrainPipelineSparseDist (stages corrected for the SWPipeline convention)
        "H2D":            TaskSchedule(stage=0, stream=memcpy),     # stage 0: batch entry
        "InputDistStart": TaskSchedule(stage=1, stream=data_dist),
        "InputDistWait":  TaskSchedule(stage=1, stream=data_dist),
        "ZeroGrad":       TaskSchedule(stage=2, stream=default),
        "WaitBatch":      TaskSchedule(stage=2, stream=default),
        "Forward":        TaskSchedule(stage=2, stream=default),    # compiled_autograd ctx
        "Backward":       TaskSchedule(stage=2, stream=default),    # compiled_autograd ctx
        "OptimizerStep":  TaskSchedule(stage=2, stream=default),    # stage 2: batch exit
    },
    intra_iter_deps = [
        ("InputDistStart", "H2D"),            # cross-stage: memcpy → data_dist stream sync
        ("InputDistWait", "InputDistStart"),  # stage=1 ordering
        ("WaitBatch", "InputDistWait"),       # cross-stage: data_dist → default stream sync
        ("Forward", "InputDistWait"),         # cross-stage: data_dist → default stream sync
        ("WaitBatch", "ZeroGrad"),            # stage=2 ordering
        ("Forward", "WaitBatch"),             # stage=2 ordering
        ("Backward", "Forward"),              # stage=2 ordering
        ("OptimizerStep", "Backward"),        # stage=2 ordering
    ],
    inter_iter_deps = [
        ("Forward", "OptimizerStep"),         # Forward(i) reads weights updated by OptimizerStep(i-1)
    ],
    # same-stream (default) → FIFO already guarantees GPU ordering
    pipeline_depth = 3,
)
```
print_schedule(5) output:

```
#  Task               Thread   Stream       |  P0     P1     P2     P3     P4
-- -----------------  -------  ------------ +  -----  -----  -----  -----  -----
 0 ZeroGrad           default  default      |  --     --     i0     i1     i2
 1 WaitBatch          default  default      |  --     --     i0     i1     i2
 2 Forward            default  default      |  --     --     i0     i1     i2
 3 Backward           default  default      |  --     --     i0     i1     i2
 4 OptimizerStep      default  default      |  --     --     i0     i1     i2
 5 InputDistStart     default  data_dist    |  --     i0     i1     i2     i3
 6 InputDistWait      default  data_dist    |  --     i0     i1     i2     i3
 7 H2D                default  memcpy       |  i0     i1     i2     i3     i4

# Identical to SparseDist; the only difference is that Forward+Backward run inside the compiled_autograd context
```
```
## progress(iter i): identical to SparseDist
1. ZeroGrad                          [i]     default
2. WaitBatch                         [i]     default
3. InputDistStart                    [i+1]   data_dist
4. H2D                               [i+2]   memcpy
5. Forward  (compiled_autograd ctx)  [i]     default
6. InputDistWait                     [i+1]   data_dist
7. Backward (compiled_autograd ctx)  [i]     default
8. OptimizerStep                     [i]     default
```
Compilation optimization: uses torch._dynamo.compiled_autograd with the inductor backend to compile the backward graph. Enables the reorder_for_compute_comm_overlap passes so the compiler automatically schedules compute-comm overlap. The first iteration uses nullcontext() (no compilation); compilation is enabled afterwards.
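The "eager first iteration, compiled afterwards" behavior is a context-selection pattern that can be sketched without torch (hypothetical names; in the real pipeline the non-null branch enters the torch._dynamo.compiled_autograd context):

```python
# Sketch: pick nullcontext for iteration 0, a real context afterwards.
from contextlib import contextmanager, nullcontext

events = []

@contextmanager
def compiled_autograd_ctx():
    events.append("enter")          # stand-in for entering the real compiled-autograd context
    yield
    events.append("exit")

def autograd_context(iteration, first_iter_eager=True):
    if first_iter_eager and iteration == 0:
        return nullcontext()        # warm-up: run eagerly, no compilation
    return compiled_autograd_ctx()

for it in range(3):
    with autograd_context(it):
        pass                        # forward + backward would run here

assert events == ["enter", "exit", "enter", "exit"]   # iterations 1 and 2 only
```

Returning `nullcontext()` keeps the call site identical across iterations: the train loop always writes `with autograd_context(it):`, and only the selected context changes.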
## Summary comparison

| Dimension | Base | PT2 | SparseDist | SDLite | Fused | SemiSync | Prefetch | EvalSD | Staged | CompAuto |
|---|---|---|---|---|---|---|---|---|---|---|
| Streams | 2 | 1 | 3 | 2 | 3–4 | 3 | 4 | 3 | N | 3 |
| Depth | 2 | 1 | 3 | 2 | 3 | 4 | 3 | 2 | N | 3 |
| Threads | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1+bg | 1 | 1 |
| FX Rewrite | — | — | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | — | ✓ |
| H2D overlap | OPT | — | FWD | OPT | FWD | FWD | FWD | bg | Custom | FWD |
| InputDist | — | — | Separate stream | Critical path | Separate stream | Event sync | Separate stream | Separate stream | — | Separate stream |
| EmbLookup | — | — | In FWD | In FWD | Separate stream | Separate stage | In FWD | In FWD | — | In FWD |
| Prefetch | — | — | — | — | — | — | ✓ | — | — | — |
| Sync | stream | — | stream | stream | stream | Event | stream | stream | Event | stream |
| Compiled | — | torch.compile | — | — | — | — | — | — | — | compiled_autograd |
| Backward | ✓ | ✓ | ✓ | ✓ | ✓ | Split Emb/Dense | ✓ | — | — | ✓ |
In upstream TorchRec, scheduling logic is hard-coded inside each pipeline's progress() method — changing the overlap strategy requires rewriting the entire method. SWPipeline lifts scheduling into a declarative PipelinePlan: only the Plan changes, while the Task functions stay the same. Additionally: