TorchRec Pipeline Schedule Analysis

Analyzing upstream TorchRec training pipelines through SWPipeline's Task / Schedule / Stream / Stage lens

torchrec/distributed/train_pipeline/train_pipelines.py

Overview — 11 Pipeline Types in Upstream TorchRec

Source: meta-pytorch/torchrec main branch

Upstream TorchRec's train_pipelines.py defines 1 ABC base class + 11 concrete implementations. This page uses SWPipeline's unified concepts — PipelineTask (what), TaskSchedule (stage / stream / thread_group / globally_ordered), and dependencies — to interpret each Pipeline's scheduling strategy.
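The SWPipeline vocabulary used throughout this page can be sketched as two small dataclasses. This is a minimal sketch: the field names follow this page's examples, not necessarily the real SWPipeline API.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

@dataclass(frozen=True)
class TaskSchedule:
    stage: int                      # pipeline stage the task belongs to
    stream: str                     # stream the task is enqueued on
    globally_ordered: bool = False  # e.g. collectives needing a fixed global enqueue order

@dataclass
class PipelinePlan:
    schedule: Dict[str, TaskSchedule]
    intra_iter_deps: List[Tuple[str, str]] = field(default_factory=list)  # (task, prerequisite), same iteration
    inter_iter_deps: List[Tuple[str, str]] = field(default_factory=list)  # (task@i, prerequisite@i-1)
    pipeline_depth: int = 1         # max batches in flight

# TrainPipelineBase's plan, reduced to two tasks:
plan = PipelinePlan(
    schedule={
        "H2D":     TaskSchedule(stage=0, stream="memcpy"),
        "Forward": TaskSchedule(stage=1, stream="default"),
    },
    intra_iter_deps=[("Forward", "H2D")],
    pipeline_depth=2,
)
```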

 #   Class                                Type      Streams       Depth  Description
 0   TrainPipeline                        ABC       —             —      Abstract base class defining the progress() interface
 1   TrainPipelineBase                    Train     2             2      Simplest 2-stage H2D pipeline
 2   TrainPipelinePT2                     Compiled  1             1      PT2 compiled, no pipelining
 3   TrainPipelineSparseDist              Train     3             3      Classic 3-stage: H2D ‖ InputDist ‖ FWD/BWD
 4   TrainPipelineSparseDistLite          Train     2             2      Memory-efficient 2-stage, InputDist in critical path
 5   TrainPipelineFusedSparseDist         Train     3–4           3      Separate emb lookup stream, overlaps with optimizer
 6   TrainPipelineSemiSync                Train     3             4      Semi-sync, fully overlaps Emb A2A with Forward
 7   PrefetchTrainPipelineSparseDist      Train     4             3      4-stage with cache prefetch
 8   EvalPipelineSparseDist               Eval      3             2      Eval, background DataLoadingThread H2D
 9   EvalPipelineFusedSparseDist          Eval      3–4           3      Eval variant of the Fused pipeline
10   StagedTrainPipeline                  Generic   User-defined  N      Generic N-stage declarative pipeline
11   TrainPipelineSparseDistCompAutograd  Compiled  3             3      SparseDist + Compiled Autograd

Streams referenced throughout: default_stream, memcpy_stream, data_dist_stream, prefetch_stream, emb_lookup_stream

Inheritance hierarchy:

TrainPipeline (ABC)
├── TrainPipelineBase
│   └── TrainPipelinePT2
├── TrainPipelineSparseDist
│   ├── TrainPipelineSparseDistLite
│   ├── TrainPipelineFusedSparseDist
│   │   └── EvalPipelineFusedSparseDist
│   ├── TrainPipelineSemiSync
│   ├── PrefetchTrainPipelineSparseDist
│   ├── EvalPipelineSparseDist
│   └── TrainPipelineSparseDistCompAutograd
└── StagedTrainPipeline
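The hierarchy above can be encoded as a child → parent map and queried programmatically. A throwaway sketch (the `PARENT`/`descendants` helpers are hypothetical, but the class names are the real upstream ones); note that seven classes ultimately derive from TrainPipelineSparseDist — its six direct children plus the transitive EvalPipelineFusedSparseDist.

```python
# Child → parent map, transcribed from the inheritance tree above.
PARENT = {
    "TrainPipelineBase": "TrainPipeline",
    "TrainPipelinePT2": "TrainPipelineBase",
    "TrainPipelineSparseDist": "TrainPipeline",
    "TrainPipelineSparseDistLite": "TrainPipelineSparseDist",
    "TrainPipelineFusedSparseDist": "TrainPipelineSparseDist",
    "EvalPipelineFusedSparseDist": "TrainPipelineFusedSparseDist",
    "TrainPipelineSemiSync": "TrainPipelineSparseDist",
    "PrefetchTrainPipelineSparseDist": "TrainPipelineSparseDist",
    "EvalPipelineSparseDist": "TrainPipelineSparseDist",
    "TrainPipelineSparseDistCompAutograd": "TrainPipelineSparseDist",
    "StagedTrainPipeline": "TrainPipeline",
}

def descendants(cls: str) -> set:
    """All transitive subclasses of `cls` in the map above."""
    kids = {c for c, p in PARENT.items() if p == cls}
    return kids | {d for c in kids for d in descendants(c)}

assert len(descendants("TrainPipelineSparseDist")) == 7  # 6 direct + EvalPipelineFusedSparseDist
assert len(descendants("TrainPipeline")) == 11           # all concrete pipelines
```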

1. TrainPipelineBase

Simplest 2-stage pipeline — H2D overlaps with FWD/BWD

SWPipeline Schedule

# pipeline_depth = 2 (cur_batch + next_batch)
# 2 streams: memcpy_stream + default_stream

PipelinePlan(
  schedule = {
    "H2D":            TaskSchedule(stage=0, stream=memcpy),      # stage 0: first step as the batch enters the pipeline
    "ZeroGrad":       TaskSchedule(stage=1, stream=default),
    "WaitBatch":      TaskSchedule(stage=1, stream=default),     # wait memcpy → default
    "Forward":        TaskSchedule(stage=1, stream=default),
    "Backward":       TaskSchedule(stage=1, stream=default),
    "OptimizerStep":  TaskSchedule(stage=1, stream=default),     # stage 1: last step of batch processing
  },
  intra_iter_deps = [
    ("WaitBatch",     "H2D"),            # cross-stage: memcpy → default stream sync
    ("WaitBatch",     "ZeroGrad"),       # stage=1 ordering
    ("Forward",       "WaitBatch"),      # stage=1 ordering
    ("Backward",      "Forward"),        # stage=1 ordering
    ("OptimizerStep", "Backward"),       # stage=1 ordering
  ],
  inter_iter_deps = [
    ("Forward", "OptimizerStep"),       # Forward(i) reads weights updated by OptimizerStep(i-1)
  ],                                          # same-stream (default) → FIFO already guarantees GPU order
  pipeline_depth = 2,
)

print_schedule(5) Output

   #  Task               Thread   Stream        | P0    P1    P2    P3    P4
  --  -----------------  -------  ------------  + ----- ----- ----- ----- -----
   0  ZeroGrad           default  default       |  --   i0    i1    i2    i3
   1  WaitBatch          default  default       |  --   i0    i1    i2    i3
   2  Forward            default  default       |  --   i0    i1    i2    i3
   3  Backward           default  default       |  --   i0    i1    i2    i3
   4  OptimizerStep      default  default       |  --   i0    i1    i2    i3
   5  H2D                default  memcpy        | i0    i1    i2    i3    i4
# pipeline_depth=2: H2D(i+1) runs on the memcpy stream in parallel with the current batch's processing
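The diagonal pattern in these print_schedule tables follows one rule: at pipeline step P, a task at stage s processes batch P − s, and idles ("--") while P < s. A tiny sketch (`schedule_row` is a hypothetical helper, not part of SWPipeline):

```python
def schedule_row(stage: int, steps: int) -> list:
    """Batch index processed at each pipeline step by a task at `stage`.
    At step P a stage-s task handles batch P - s; before that it idles ('--')."""
    return [f"i{p - stage}" if p >= stage else "--" for p in range(steps)]

# Reproduces the TrainPipelineBase table above (pipeline_depth=2):
print(schedule_row(1, 5))  # stage-1 tasks (ZeroGrad..OptimizerStep) → ['--', 'i0', 'i1', 'i2', 'i3']
print(schedule_row(0, 5))  # stage-0 task (H2D)                      → ['i0', 'i1', 'i2', 'i3', 'i4']
```

The same formula generates the deeper tables below: a depth-3 plan just adds a stage-2 row that starts two steps late.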

CPU Submission Order (single progress call)

##  progress(iter i):                      batch    stream
  1. ZeroGrad                                [i]      default
  2. WaitBatch                               [i]      default (wait memcpy→default)
  3. Forward                                 [i]      default
  4. Backward                                [i]      default
  5. H2D                                     [i+1]    memcpy  ← not submitted until after Backward
  6. OptimizerStep                           [i]      default  ← H2D runs on memcpy in parallel with Opt
TrainPipelineBase Timeline (figure): H2D(i+1) starts after Backward(i), overlapping only with Optimizer(i).
Key feature: No FX rewrite of ShardedModule, no InputDist split. Suitable for non-sharded or minimal scenarios. H2D(i+1) starts after Backward(i), only overlapping with Optimizer(i) on a separate stream.

2. TrainPipelinePT2

PyTorch 2 compiled pipeline — torch.compile acceleration, single stream, no overlap

SWPipeline Schedule

# pipeline_depth = 1 (no pipelining)
# 1 stream: default_stream only
# Model compiled via torch.compile on compile_on_iter (default: 3rd iter)

PipelinePlan(
  schedule = {
    "LoadBatch":       TaskSchedule(stage=0, stream=default),  # next(dataloader_iter)
    "H2D":             TaskSchedule(stage=0, stream=default),  # _to_device(non_blocking=False)
    "InputTransform":  TaskSchedule(stage=0, stream=default),  # PT2 input hints
    "ZeroGrad":        TaskSchedule(stage=0, stream=default),
    "Forward":         TaskSchedule(stage=0, stream=default),  # compiled model(batch)
    "Backward":        TaskSchedule(stage=0, stream=default),
    "OptimizerStep":   TaskSchedule(stage=0, stream=default),
  },
  pipeline_depth = 1,  # fully serial, no overlap
)
Note: H2D uses non_blocking=False (synchronous copy), so there is zero stream overlap. Optimization relies entirely on torch.compile graph compilation, which triggers at compile_on_iter (default: 3rd iteration). Inherits from TrainPipelineBase.
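The compile-on-iter pattern can be sketched without torch: run eager for a few warm-up iterations, then swap in the compiled callable. `LazyCompiledModel` is a hypothetical stand-in; the real pipeline calls torch.compile on the model.

```python
class LazyCompiledModel:
    """Sketch of the compile-on-iter pattern (hypothetical stand-in for torch.compile)."""

    def __init__(self, fn, compile_on_iter: int = 3):
        self._fn = fn
        self._compile_on_iter = compile_on_iter
        self._iter = 0
        self.compiled = False

    def __call__(self, x):
        self._iter += 1
        # Run eager for the first iterations so input shapes/guards can settle,
        # then "compile" once on iteration `compile_on_iter`.
        if not self.compiled and self._iter == self._compile_on_iter:
            # real code would do: self._fn = torch.compile(self._fn)
            self.compiled = True
        return self._fn(x)

m = LazyCompiledModel(lambda x: x * 2)
results = [(m(i), m.compiled) for i in range(1, 5)]
# compilation flag flips on the 3rd call, matching the default compile_on_iter
```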

3. TrainPipelineSparseDist

Classic 3-stage: H2D ‖ InputDist ‖ Forward/Backward — TorchRec's core pipeline

SWPipeline Schedule

# pipeline_depth = 3 (max 3 batches in flight: batches[0], [1], [2])
# 3 streams: memcpy, data_dist, default
# fill_pipeline loads 2 batches; progress() enqueue_batch adds 3rd
# Requires FX rewrite via _rewrite_model (PipelinedForward)

PipelinePlan(
  schedule = {
    "H2D":              TaskSchedule(stage=0, stream=memcpy),         # stage 0: batch entry (batches[2])
    "InputDistStart":   TaskSchedule(stage=1, stream=data_dist, globally_ordered=True),  # stage 1 (batches[1])
    "InputDistWait":    TaskSchedule(stage=1, stream=data_dist),    # stage 1 (batches[1])
    "ZeroGrad":         TaskSchedule(stage=2, stream=default),
    "WaitBatch":        TaskSchedule(stage=2, stream=default),       # stage 2 (batches[0])
    "Forward":          TaskSchedule(stage=2, stream=default),       # stage 2: batch exit (batches[0])
    "Backward":         TaskSchedule(stage=2, stream=default),
    "OptimizerStep":    TaskSchedule(stage=2, stream=default),
  },
  intra_iter_deps = [
    ("InputDistStart", "H2D"),            # cross-stage: memcpy → data_dist stream sync
    ("InputDistWait",  "InputDistStart"),  # stage=1 ordering
    ("WaitBatch",      "InputDistWait"),   # cross-stage: data_dist → default stream sync
    ("Forward",        "InputDistWait"),   # cross-stage: data_dist → default stream sync
    ("WaitBatch",      "ZeroGrad"),        # stage=2 ordering
    ("Forward",        "WaitBatch"),       # stage=2 ordering
    ("Backward",       "Forward"),         # stage=2 ordering
    ("OptimizerStep",  "Backward"),        # stage=2 ordering
  ],
  inter_iter_deps = [
    ("Forward", "OptimizerStep"),        # Forward(i) reads weights updated by OptimizerStep(i-1)
  ],                                           # same-stream (default) → FIFO already guarantees GPU order
  pipeline_depth = 3,
)

print_schedule(5) Output

   #  Task               Thread   Stream        | P0    P1    P2    P3    P4
  --  -----------------  -------  ------------  + ----- ----- ----- ----- -----
   0  ZeroGrad           default  default       |  --    --   i0    i1    i2
   1  WaitBatch          default  default       |  --    --   i0    i1    i2
   2  Forward            default  default       |  --    --   i0    i1    i2
   3  Backward           default  default       |  --    --   i0    i1    i2
   4  OptimizerStep      default  default       |  --    --   i0    i1    i2
   5  InputDistStart     default  data_dist     |  --   i0    i1    i2    i3
   6  InputDistWait      default  data_dist     |  --   i0    i1    i2    i3
   7  H2D                default  memcpy        | i0    i1    i2    i3    i4
# pipeline_depth=3: 3 streams (memcpy, data_dist, default) run in parallel

CPU Submission Order (single progress call)

##  progress(iter i):                      batch    stream
  1. ZeroGrad                                [i]      default
  2. WaitBatch                               [i]      default (wait data_dist→default)
  3. InputDistStart                          [i+1]    data_dist  ← start splits A2A
  4. H2D                                     [i+2]    memcpy     ← enqueue_batch
  5. Forward (EmbLookup runs inside)         [i]      default    ← PipelinedForward
  6. InputDistWait                           [i+1]    data_dist  ← wait splits A2A
  7. Backward                                [i]      default
  8. OptimizerStep                           [i]      default
# Note: with enqueue_batch_after_forward=True, H2D is submitted after Forward
# to keep H2D from competing with UVM embedding lookups for PCIe bandwidth
TrainPipelineSparseDist Timeline (figure): H2D(i+2), InputDist(i+1), and Forward/Backward(i) proceed in parallel on three streams.
Core mechanism: Uses _rewrite_model to FX-trace ShardedModule, extracting input_dist to a separate stream. Up to 3 batches in flight (source comment: "max pipelined batches == 3"): batches[0] runs FWD/BWD, batches[1] runs InputDist, batches[2] runs H2D — all 3 in parallel on different streams. This is TorchRec's core pipeline — 6 other training pipelines inherit from it.
Note: Embedding Lookup (compute_and_output_dist) runs inside Forward, auto-triggered by PipelinedForward. FusedSparseDist further optimizes this by extracting Emb Lookup to a separate stream — see comparison below.

4. TrainPipelineSparseDistLite

Memory-efficient variant — InputDist stays in the critical path, only H2D overlaps

SWPipeline Schedule

# pipeline_depth = 2 (max 2 batches: fill 1, enqueue 1 in progress)
# 2 streams: memcpy + default (data_dist_stream set to None!)
# fill_pipeline fills only 1 batch (vs 2 in full SDD)

PipelinePlan(
  schedule = {
    "H2D":              TaskSchedule(stage=0, stream=memcpy),     # stage 0: batch entry
    "ZeroGrad":         TaskSchedule(stage=1, stream=default),
    "WaitBatch":        TaskSchedule(stage=1, stream=default),
    "InputDistStart":   TaskSchedule(stage=1, stream=default),   # on default stream!
    "InputDistWait":    TaskSchedule(stage=1, stream=default),   # on default stream!
    "Forward":          TaskSchedule(stage=1, stream=default),
    "Backward":         TaskSchedule(stage=1, stream=default),
    "OptimizerStep":    TaskSchedule(stage=1, stream=default),   # stage 1: batch exit
  },
  intra_iter_deps = [
    ("WaitBatch",      "H2D"),            # cross-stage: memcpy → default stream sync
    ("WaitBatch",      "ZeroGrad"),       # stage=1 ordering
    ("InputDistStart", "WaitBatch"),      # stage=1 ordering (default stream)
    ("InputDistWait",  "InputDistStart"), # stage=1 ordering
    ("Forward",        "InputDistWait"),  # stage=1 ordering
    ("Backward",       "Forward"),        # stage=1 ordering
    ("OptimizerStep",  "Backward"),       # stage=1 ordering
  ],
  inter_iter_deps = [
    ("Forward", "OptimizerStep"),       # Forward(i) reads weights updated by OptimizerStep(i-1)
  ],                                          # same-stream (default) → FIFO already guarantees GPU order
  pipeline_depth = 2,
)

print_schedule(5) Output

   #  Task               Thread   Stream        | P0    P1    P2    P3    P4
  --  -----------------  -------  ------------  + ----- ----- ----- ----- -----
   0  ZeroGrad           default  default       |  --   i0    i1    i2    i3
   1  WaitBatch          default  default       |  --   i0    i1    i2    i3
   2  InputDistStart     default  default       |  --   i0    i1    i2    i3
   3  InputDistWait      default  default       |  --   i0    i1    i2    i3
   4  Forward            default  default       |  --   i0    i1    i2    i3
   5  Backward           default  default       |  --   i0    i1    i2    i3
   6  OptimizerStep      default  default       |  --   i0    i1    i2    i3
   7  H2D                default  memcpy        | i0    i1    i2    i3    i4
# pipeline_depth=2: only H2D overlaps (on the memcpy stream) with the previous iteration's Opt
# all stage=1 tasks execute serially on the default stream

CPU Submission Order (single progress call)

##  progress(iter i):                      batch    stream
  1. ZeroGrad                                [i]      default
  2. WaitBatch                               [i]      default (wait memcpy→default)
  3. InputDistStart+Wait                     [i]      default  ← critical path! data_dist=None
  4. Forward                                 [i]      default
  5. Backward                                [i]      default
  6. H2D                                     [i+1]    memcpy   ← enqueue_batch, between BWD and OPT
  7. OptimizerStep                           [i]      default  ← H2D runs on memcpy in parallel with Opt
Design point: Key difference — self._data_dist_stream = None. InputDist executes on the default stream (critical path); only H2D uses a separate memcpy_stream. fill_pipeline loads only 1 batch (vs 2 in full SDD), and progress()'s enqueue_batch brings it to 2 → max 2 batches in flight. Memory overhead is ~2/3 of SparseDist (2 batches vs 3), with a ~4-5% QPS improvement. Ideal for memory-constrained deployments.
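The "~2/3 memory" figure is just batches-in-flight arithmetic. A rough model (illustrative numbers; it ignores activations and sparse-dist buffers):

```python
def inflight_input_memory(batches_in_flight: int, bytes_per_batch: int) -> int:
    # Each in-flight batch holds a full device-side copy of its inputs.
    return batches_in_flight * bytes_per_batch

per_batch = 512 * 1024 * 1024  # assume 512 MiB of inputs per batch (illustrative)
sdd = inflight_input_memory(3, per_batch)   # TrainPipelineSparseDist: 3 batches in flight
lite = inflight_input_memory(2, per_batch)  # TrainPipelineSparseDistLite: 2 batches in flight
ratio = lite / sdd                          # → 2/3: Lite frees one full batch's inputs
```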

5. TrainPipelineFusedSparseDist

Separate Emb Lookup stream — overlaps with the previous iteration's Optimizer

SWPipeline Schedule

# pipeline_depth = 3 (max 3 batches in flight, same as SparseDist)
# source: "max pipelined batches == 3 (capacity)"
# 3-4 streams: memcpy, data_dist, emb_lookup (default=data_dist), default
# Uses InSyncEmbeddingPipelinedForward (not PipelinedForward)

PipelinePlan(
  schedule = {
    "H2D":              TaskSchedule(stage=0, stream=memcpy),         # stage 0: batch entry (batches[2])
    "InputDistStart":   TaskSchedule(stage=1, stream=data_dist, globally_ordered=True),  # stage 1 (batches[1])
    "InputDistWait":    TaskSchedule(stage=1, stream=data_dist),    # stage 1 (batches[1])
    "EmbLookup":        TaskSchedule(stage=2, stream=emb_lookup),   # stage 2 (batches[0]), compute_and_output_dist
    "ZeroGrad":         TaskSchedule(stage=2, stream=default),
    "WaitBatch":        TaskSchedule(stage=2, stream=default),
    "Forward":          TaskSchedule(stage=2, stream=default),       # stage 2: dense-only forward (batch exit)
    "Backward":         TaskSchedule(stage=2, stream=default),
    "OptimizerStep":    TaskSchedule(stage=2, stream=default),
  },
  intra_iter_deps = [
    ("InputDistStart", "H2D"),           # cross-stage: memcpy → data_dist stream sync
    ("InputDistWait",  "InputDistStart"), # stage=1 ordering
    ("EmbLookup",      "InputDistWait"),  # cross-stage: data_dist → emb_lookup stream sync
    ("Forward",        "EmbLookup"),      # cross-stream: emb_lookup → default
    ("WaitBatch",      "ZeroGrad"),       # stage=2 ordering
    ("Forward",        "WaitBatch"),      # stage=2 ordering
    ("Backward",       "Forward"),        # stage=2 ordering
    ("OptimizerStep",  "Backward"),       # stage=2 ordering
  ],
  inter_iter_deps = [
    ("EmbLookup",      "Backward"),       # EmbLookup(i) reads emb weights updated by Backward(i-1) (TBE fused)
    ("Forward",        "OptimizerStep"), # Forward(i) reads dense weights updated by OptimizerStep(i-1)
  ],                                          # EmbLookup↔Backward crosses streams (emb_lookup↔default) and must be declared explicitly
  pipeline_depth = 3,
)

print_schedule(5) Output

   #  Task               Thread   Stream        | P0    P1    P2    P3    P4
  --  -----------------  -------  ------------  + ----- ----- ----- ----- -----
   0  EmbLookup          default  emb_lookup    |  --    --   i0    i1    i2
   1  ZeroGrad           default  default       |  --    --   i0    i1    i2
   2  WaitBatch          default  default       |  --    --   i0    i1    i2
   3  Forward            default  default       |  --    --   i0    i1    i2
   4  Backward           default  default       |  --    --   i0    i1    i2
   5  OptimizerStep      default  default       |  --    --   i0    i1    i2
   6  InputDistStart     default  data_dist     |  --   i0    i1    i2    i3
   7  InputDistWait      default  data_dist     |  --   i0    i1    i2    i3
   8  H2D                default  memcpy        | i0    i1    i2    i3    i4
# EmbLookup executes independently on the emb_lookup stream, in parallel with ZeroGrad/WaitBatch on the default stream

CPU Submission Order (single progress call)

##  progress(iter i):                      batch    stream
  1. EmbLookup                               [i]      emb_lookup ← at the very front of progress!
  2. ZeroGrad                                [i]      default
  3. WaitBatch                               [i]      default (wait data_dist→default)
  4. InputDistStart                          [i+1]    data_dist
  5. H2D                                     [i+2]    memcpy
  6. Forward (dense-only, waits emb_lookup)  [i]      default
  7. InputDistWait                           [i+1]    data_dist
  8. Backward                                [i]      default
  9. OptimizerStep                           [i]      default
# with embedding_lookup_after_data_dist=True, EmbLookup comes after H2D


Key Differences from TrainPipelineSparseDist

Both have the same pipeline depth (3 batches in flight). The critical difference is when and where Embedding Lookup (compute_and_output_dist) executes:

Dimension              SparseDist                                         FusedSparseDist
PipelinedForward type  PipelinedForward                                   InSyncEmbeddingPipelinedForward
Emb Lookup location    Inside model_fwd(), auto-executed by               Before model_fwd(), explicitly called via
                       PipelinedForward, on default_stream                start_embedding_lookup(), on emb_lookup_stream
Emb Lookup timing      At the start of Forward(i) (serial with Forward)   At the very beginning of progress(), before ZeroGrad
Can overlap with       Nothing (Emb Lookup is on Forward's critical path) EmbLookup(i) overlaps previous Optimizer(i-1) tail on a different stream
Forward content        Full forward (emb lookup + dense)                  Dense-only forward (emb results injected by InSyncEmbeddingPipelinedForward)
Extra stream           None                                               emb_lookup_stream (default: reuses data_dist, or "new" / "current")

Execution order comparison (CPU submission order):

CPU Submission Order Comparison

## SparseDist progress():
  1. ZeroGrad                                  # default_stream
  2. WaitBatch                                 # wait data_dist → default
  3. InputDistStart(i+1)                       # data_dist_stream
  4. H2D(i+2)                                  # memcpy_stream
  5. Forward(i) ← Emb Lookup runs inside here  # default_stream
  6. InputDistWait(i+1)                        # data_dist_stream
  7. Backward(i)                               # default_stream
  8. Optimizer(i)                              # default_stream
  9. Dequeue

## FusedSparseDist progress():
  1. EmbLookup(i)                              # emb_lookup_stream ← hoisted to the front!
  2. ZeroGrad                                  # default_stream
  3. WaitBatch                                 # wait data_dist → default
  4. InputDistStart(i+1)                       # data_dist_stream
  5. H2D(i+2)                                  # memcpy_stream
  6. Forward(i) ← dense only, emb results ready # default_stream, first waits on emb_lookup_stream
  7. InputDistWait(i+1)                        # data_dist_stream
  8. Backward(i)                               # default_stream
  9. Optimizer(i)                              # default_stream
 10. Dequeue
EmbLookup Overlap (figure, Fused-only): EmbLookup(i) on the emb_lookup stream overlaps the tail of Optimizer(i-1) on the default stream.
Fused-TBE assumption: FBGEMM's Fused-TBE updates embedding weights during backward (not optimizer step), so EmbLookup(i) only needs Backward(i-1) to complete — it doesn't have to wait for Optimizer(i-1). This is the fundamental reason FusedSparseDist can start EmbLookup earlier. If the model uses Feature Processors or non-fused optimizers, this assumption breaks and results may be incorrect. This Pipeline is still marked experimental.

emb_lookup_stream options:
  • "data_dist" (default): reuses data_dist_stream, reduces CUDA Caching Allocator memory reservation for extra streams
  • "new": fresh stream, maximizes overlap potential
  • "current": default_stream, degrades to SparseDist-equivalent (no overlap)
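The three options boil down to a small dispatch. A hypothetical `resolve_emb_lookup_stream` helper (the option names mirror the documented config; the resolution logic is a sketch of the trade-off, not the upstream implementation):

```python
def resolve_emb_lookup_stream(option: str, streams: dict):
    """Map the emb_lookup_stream config to an actual stream object."""
    if option == "data_dist":   # default: reuse data_dist, saving the allocator's per-stream reservation
        return streams["data_dist"]
    if option == "new":         # dedicated stream: maximum overlap potential
        return streams.setdefault("emb_lookup", object())
    if option == "current":     # default stream: no overlap, SparseDist-equivalent behavior
        return streams["default"]
    raise ValueError(f"unknown emb_lookup_stream option: {option!r}")

streams = {"default": object(), "data_dist": object()}
assert resolve_emb_lookup_stream("data_dist", streams) is streams["data_dist"]
assert resolve_emb_lookup_stream("current", streams) is streams["default"]
assert resolve_emb_lookup_stream("new", streams) is streams["emb_lookup"]  # created on demand
```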

6. TrainPipelineSemiSync

Semi-synchronous training — predicts with B-2 parameters, fully overlaps Emb A2A with Forward

SWPipeline Schedule

# pipeline_depth = 4 (4 batches in flight: i, i+1, i+2, i+3)
# 3 streams: memcpy, data_dist, default
# Uses EmbeddingPipelinedForward (detaches emb from autograd)
# H2D(i+3) → InputDist(i+2) → EmbLookup(i+1) → Forward/Backward(i)

PipelinePlan(
  schedule = {
    "H2D":              TaskSchedule(stage=0, stream=memcpy),         # stage 0: batch[i+3]
    "InputDistStart":   TaskSchedule(stage=1, stream=data_dist, globally_ordered=True),  # stage 1: batch[i+2]
    "InputDistWait":    TaskSchedule(stage=1, stream=data_dist),    # stage 1: batch[i+2]
    "EmbLookup":        TaskSchedule(stage=2, stream=default),      # stage 2: batch[i+1], start_embedding_lookup
    "ZeroGrad":         TaskSchedule(stage=3, stream=default),      # stage 3: batch[i]
    "Forward":          TaskSchedule(stage=3, stream=default),      # stage 3: uses detached emb tensors
    "Backward":         TaskSchedule(stage=3, stream=default),      # stage 3: dense backward only
    "EmbBackward":      TaskSchedule(stage=3, stream=default),      # stage 3: torch.autograd.backward
    "OptimizerStep":    TaskSchedule(stage=3, stream=default),      # stage 3: batch exit
  },
  intra_iter_deps = [
    ("InputDistStart", "H2D"),           # cross-stage: memcpy → data_dist stream sync
    ("InputDistWait",  "InputDistStart"), # stage=1 ordering
    ("EmbLookup",      "InputDistWait"),  # cross-stage: data_dist → default stream sync
    ("Forward",        "EmbLookup"),      # cross-stage: emb ready before forward
    ("Forward",        "ZeroGrad"),       # stage=3 ordering
    ("Backward",       "Forward"),        # stage=3 ordering
    ("EmbBackward",    "Backward"),       # stage=3 ordering
    ("OptimizerStep",  "EmbBackward"),    # stage=3 ordering
  ],
  inter_iter_deps = [
    ("EmbLookup",      "Backward"),       # EmbLookup(i) reads emb weights updated by Backward(i-1) (TBE fused)
  ],                                          # same-stream (default) → FIFO already guarantees GPU order
  # Note: Forward(i) uses B-2 parameters → logically depends on OptimizerStep(i-2); k=2 exceeds SWPipeline's supported range
  pipeline_depth = 4,
)

print_schedule(6) Output

   #  Task               Thread   Stream        | P0    P1    P2    P3    P4    P5
  --  -----------------  -------  ------------  + ----- ----- ----- ----- ----- -----
   0  ZeroGrad           default  default       |  --    --    --   i0    i1    i2
   1  Forward            default  default       |  --    --    --   i0    i1    i2
   2  Backward           default  default       |  --    --    --   i0    i1    i2
   3  EmbBackward        default  default       |  --    --    --   i0    i1    i2
   4  OptimizerStep      default  default       |  --    --    --   i0    i1    i2
   5  EmbLookup          default  default       |  --    --   i0    i1    i2    i3
   6  InputDistStart     default  data_dist     |  --   i0    i1    i2    i3    i4
   7  InputDistWait      default  data_dist     |  --   i0    i1    i2    i3    i4
   8  H2D                default  memcpy        | i0    i1    i2    i3    i4    i5
# pipeline_depth=4: P0-P2 are fill_pipeline; the first result emerges at P3
# stage 0: H2D | stage 1: InputDist | stage 2: EmbLookup | stage 3: Fwd/Bwd/Opt

CPU Submission Order (single progress call, semi-sync mode)

##  SWPipeline enqueue order (stage-descending):  batch    stream     stage
  1. ZeroGrad                                    [i]      default    ← stage 3
  2. Forward (emb done by last call's EmbLookup)  [i]      default    ← stage 3
  3. Backward (dense only)                       [i]      default    ← stage 3
  4. EmbBackward                                [i]      default    ← stage 3
  5. OptimizerStep                              [i]      default    ← stage 3
  6. EmbLookup                                  [i+1]    default    ← stage 2
  7. InputDistStart                             [i+2]    data_dist  ← stage 1
  8. InputDistWait                              [i+2]    data_dist  ← stage 1
  9. H2D                                        [i+3]    memcpy     ← stage 0
# Note: pipeline_depth=4, so 4 batches are in flight.
# Forward(i) consumes embeddings from EmbLookup(i), submitted in the previous call.
# In sync mode, EmbLookup comes after Backward+OptimizerStep and
# must wait for the previous Backward to finish before it can start safely.
SemiSync Core Idea (figure): the model predicts with B-2 parameters; the Emb A2A fully overlaps Forward.
Convergence impact: Each batch predicts using B-2 (not B-1) parameters. The first start_batch (default 900) batches use synchronous mode. With stash_gradients=True, gradients are stored to preserve "true" semi-sync semantics. Embedding backward runs separately (torch.autograd.backward), decoupled from the dense backward.
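The B-2 staleness falls out of the submission order above: EmbLookup(i+1) is launched while Backward/OptimizerStep(i) are still in flight, so it snapshots parameters last updated at step i−1, and Forward(i+1) therefore consumes step-(i−1) parameters one iteration later. A toy simulation (pure-Python sketch; `snapshot[0] = -2` just seeds the fill_pipeline warm-up lookup):

```python
param_version = -1   # index of the last applied OptimizerStep (none yet)
snapshot = {0: -2}   # EmbLookup(0) happened during fill_pipeline, before any optimizer step
staleness = []

for i in range(8):
    # Within iteration i, EmbLookup(i+1) is issued before OptimizerStep(i) lands:
    snapshot[i + 1] = param_version
    # Forward(i) consumes the parameters snapshotted by EmbLookup(i):
    staleness.append(i - snapshot[i])
    # OptimizerStep(i) completes after the lookup has already run:
    param_version = i

# Every Forward(i) saw parameters from step i-2 → "B-2" prediction.
assert all(s == 2 for s in staleness)
```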

7. PrefetchTrainPipelineSparseDist

4-stage with Cache Prefetch — hides HBM cache miss latency

SWPipeline Schedule

# pipeline_depth = 3 (3 batches: _batch_i, _batch_ip1, _batch_ip2)
# 4 streams: memcpy, data_dist, prefetch, default
# Uses PrefetchPipelinedForward

PipelinePlan(
  schedule = {
    "H2D":              TaskSchedule(stage=0, stream=memcpy),         # stage 0: batch entry
    "InputDistStart":   TaskSchedule(stage=0, stream=data_dist, globally_ordered=True),
    "InputDistWait":    TaskSchedule(stage=1, stream=data_dist),
    "EmbPrefetch":      TaskSchedule(stage=1, stream=prefetch),      # _prefetch_embeddings
    "ZeroGrad":         TaskSchedule(stage=2, stream=default),
    "WaitBatch":        TaskSchedule(stage=2, stream=default),       # wait prefetch_stream
    "Forward":          TaskSchedule(stage=2, stream=default),
    "Backward":         TaskSchedule(stage=2, stream=default),
    "OptimizerStep":    TaskSchedule(stage=2, stream=default),       # stage 2: batch exit
  },
  intra_iter_deps = [
    ("InputDistStart",  "H2D"),            # cross-stage: memcpy → data_dist stream sync
    ("InputDistWait",   "InputDistStart"), # stage=1 ordering (cross-stage)
    ("EmbPrefetch",     "InputDistWait"),  # cross-stream: data_dist → prefetch
    ("WaitBatch",       "EmbPrefetch"),    # cross-stage: prefetch → default stream sync
    ("WaitBatch",       "ZeroGrad"),       # stage=2 ordering
    ("Forward",         "WaitBatch"),      # stage=2 ordering
    ("Backward",        "Forward"),        # stage=2 ordering
    ("OptimizerStep",   "Backward"),       # stage=2 ordering
  ],
  inter_iter_deps = [
    ("EmbPrefetch",    "Forward"),        # EmbPrefetch(i) must run after Forward(i-1) has consumed its prefetch results
    ("Forward",        "OptimizerStep"), # Forward(i) reads weights updated by OptimizerStep(i-1)
  ],                                          # EmbPrefetch↔Forward crosses streams (prefetch↔default) and must be declared explicitly
  pipeline_depth = 3,
)

print_schedule(5) Output

   #  Task               Thread   Stream        | P0    P1    P2    P3    P4
  --  -----------------  -------  ------------  + ----- ----- ----- ----- -----
   0  ZeroGrad           default  default       |  --    --   i0    i1    i2
   1  WaitBatch          default  default       |  --    --   i0    i1    i2
   2  Forward            default  default       |  --    --   i0    i1    i2
   3  Backward           default  default       |  --    --   i0    i1    i2
   4  OptimizerStep      default  default       |  --    --   i0    i1    i2
   5  InputDistWait      default  data_dist     |  --   i0    i1    i2    i3
   6  EmbPrefetch        default  prefetch      |  --   i0    i1    i2    i3
   7  H2D                default  memcpy        | i0    i1    i2    i3    i4
   8  InputDistStart     default  data_dist     | i0    i1    i2    i3    i4
# stage 0: H2D + InputDistStart (two different streams, same stage)
# stage 1: InputDistWait + EmbPrefetch
# stage 2: ZeroGrad → WaitBatch → Forward → Backward → OptimizerStep
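The staggered batch indices in the table follow one rule: a task at stage s processes batch P − s at CPU tick P, and idles while P < s. A small illustrative sketch (plain Python; task names come from the plan above, the helper function is hypothetical):

```python
def batch_at_tick(stage: int, tick: int):
    """Which batch a task at `stage` handles at CPU tick `tick` (None = idle)."""
    return tick - stage if tick >= stage else None

# stage assignments from the schedule above
stages = {"H2D": 0, "InputDistStart": 0, "InputDistWait": 1,
          "EmbPrefetch": 1, "Forward": 2, "OptimizerStep": 2}

for task, s in stages.items():
    row = ["i%d" % b if (b := batch_at_tick(s, t)) is not None else "--"
           for t in range(5)]
    print(f"{task:15s} {' '.join(row)}")
# e.g. Forward (stage 2) prints: -- -- i0 i1 i2, matching the table above
```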

CPU Submission Order (single progress() call)

##  progress(iter i):                      batch    stream
  1. ZeroGrad                                [i]      default
  2. WaitBatch                               [i]      default (wait prefetch→default)
  3. H2D                                     [i+2]    memcpy     ← _copy_batch_to_gpu
  4. InputDistWait                           [i+1]    data_dist  ← result of previous iteration's InputDistStart
  5. Forward                                 [i]      default    ← PrefetchPipelinedForward
  6. EmbPrefetch                             [i+1]    prefetch   ← _prefetch_embeddings
  7. Backward                                [i]      default
  8. OptimizerStep                           [i]      default
  9. InputDistStart                          [i+2]    data_dist  ← _start_sparse_data_dist
 10. Slide: i←ip1, ip1←ip2

progress() single-iteration execution order:

  1. ZeroGrad
  2. WaitBatch       → wait_for_batch(_batch_i, prefetch_stream)
  3. H2D             → _copy_batch_to_gpu(_batch_ip2)
  4. InputDistWait   → _wait_sparse_data_dist()
  5. Forward         → model(_batch_i)
  6. EmbPrefetch     → _prefetch(_batch_ip1) on prefetch_stream
  7. Backward        → loss.backward()
  8. OptimizerStep
  9. InputDistStart  → _start_sparse_data_dist(_batch_ip2)
 10. Slide window    → _batch_i ← _batch_ip1, _batch_ip1 ← _batch_ip2
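The window slide in step 10 can be modeled without torch: three slots hold batches at successive pipeline depths, and each progress() call consumes the oldest, shifts, and refills the newest. A simplified sketch (plain Python; attribute names echo the deprecated members, and the real pipeline's fill and H2D details are elided):

```python
class SlidingWindow:
    """Models the _batch_i / _batch_ip1 / _batch_ip2 window of the prefetch pipeline."""
    def __init__(self, dataloader_iter):
        self._it = dataloader_iter
        # Fill: three batches in flight at different pipeline depths.
        self.batch_i = next(self._it)     # ready for Forward/Backward
        self.batch_ip1 = next(self._it)   # input dist / prefetch in flight
        self.batch_ip2 = next(self._it)   # H2D in flight

    def progress(self):
        out = self.batch_i                # batch trained this iteration
        nxt = next(self._it, None)        # load the next batch (None when exhausted)
        # Slide the window: i ← ip1, ip1 ← ip2, ip2 ← freshly loaded
        self.batch_i, self.batch_ip1, self.batch_ip2 = (
            self.batch_ip1, self.batch_ip2, nxt)
        return out

w = SlidingWindow(iter(range(6)))
print([w.progress() for _ in range(4)])   # → [0, 1, 2, 3]
```

The three live slots are exactly the "memory overhead: 3 batches" noted below.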
Core value: A dedicated prefetch_stream overlaps cache prefetch with Forward. For large-scale Embeddings using UVM_CACHING, cache-miss latency can reach tens of milliseconds; prefetch hides it completely. Memory overhead: 3 batches. Note that this pipeline still uses the deprecated _batch_i/_batch_ip1/_batch_ip2 member variables.

8–9. Eval Pipelines

Evaluation pipelines — background thread H2D + no Backward/Optimizer

EvalPipelineSparseDist uses DataLoadingThread for background H2D; re-initializes pipelined modules in every progress().

EvalPipelineFusedSparseDist inherits TrainPipelineFusedSparseDist, wraps forward in torch.no_grad().

EvalPipelineSparseDist Schedule

# pipeline_depth = 2, background thread H2D
# Unique: uses DataLoadingThread (separate Python thread) for H2D

PipelinePlan(
  schedule = {
    "H2D":              TaskSchedule(stage=0, stream=memcpy, thread_group="loader"), # stage 0: batch 入口 (background thread)
    "InputDistStart":   TaskSchedule(stage=1, stream=data_dist, globally_ordered=True),
    "InputDistWait":    TaskSchedule(stage=1, stream=data_dist),
    "WaitBatch":        TaskSchedule(stage=1, stream=default),
    "Forward":          TaskSchedule(stage=1, stream=default),  # stage 1: batch 出口 (no backward/optimizer)
  },
  pipeline_depth = 2,
)
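The producer-consumer shape of DataLoadingThread can be sketched with stdlib threading alone. This is a hypothetical simplification: in the real pipeline the background thread also issues the H2D copy on the memcpy stream, which is reduced here to a queue put:

```python
import queue
import threading

class DataLoadingThread(threading.Thread):
    """Background producer: prefetches batches so progress() never blocks on I/O."""
    def __init__(self, dataloader_iter, max_prefetch=2):
        super().__init__(daemon=True)
        self._it = dataloader_iter
        self._q = queue.Queue(maxsize=max_prefetch)  # bounds in-flight batches

    def run(self):
        for batch in self._it:
            self._q.put(batch)   # real pipeline: copy to GPU here (memcpy stream)
        self._q.put(None)        # sentinel: dataloader exhausted

    def get_next_batch(self):
        return self._q.get()     # blocks until the loader thread has a batch ready

t = DataLoadingThread(iter(range(3)))
t.start()
batches = []
while (b := t.get_next_batch()) is not None:
    batches.append(b)
t.join()
print(batches)   # → [0, 1, 2]
```

The bounded queue is what keeps memory overhead constant: the loader thread blocks once max_prefetch batches are waiting.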

10. StagedTrainPipeline

Generic declarative N-stage pipeline — user-defined Stage list

SWPipeline Schedule

# pipeline_depth = N (= number of stages)
# Each PipelineStage has: name, runnable, stream, fill_callback, data_exhausted_callback
# Forward/Backward/Optimizer run OUTSIDE the pipeline (in user's train loop)

# Example: 2-stage pipeline for data_copy + gpu_postproc
StagedTrainPipeline(
  pipeline_stages = [
    PipelineStage(name="data_copy",    runnable=get_h2d_func("cuda"), stream=Stream()),
    PipelineStage(name="gpu_postproc", runnable=gpu_postproc,          stream=Stream()),
  ]
)

# Usage: pipeline outputs the processed batch, user handles FWD/BWD/OPT
while batch := pipeline.progress(dataloader_iter):
    optimizer.zero_grad()
    loss, pred = model(batch)
    loss.backward()
    optimizer.step()

Key features:

  • Declarative: Users declare pipeline via PipelineStage list, no subclassing needed
  • Pre-forward only: Handles stages before forward (H2D, post-processing, SDD), FWD/BWD/OPT run in user's train loop
  • Event sync: Inter-stage sync via torch.cuda.Event, not stream.wait_stream
  • Fill pipeline: Initialization fills pipeline: "batch 0 runs N stages, batch 1 runs N-1 stages, ..."
  • Flush support: set_flush(True) drains all in-flight batches then invokes callback
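The fill behavior ("batch 0 runs N stages, batch 1 runs N−1 stages, ...") falls out of a slot-per-stage design. A toy sketch (plain Python, no CUDA events; stages are ordinary callables, and, like the real pipeline during fill, progress() returns None until a batch has cleared every stage):

```python
class StagedPipeline:
    """N-stage software pipeline: fills up, then emits one output per progress()."""
    def __init__(self, stages):
        self._stages = stages                 # list of callables, applied in order
        self._slots = [None] * len(stages)    # slot k holds a batch past stages 0..k

    def progress(self, dataloader_iter):
        out = self._slots[-1]                 # batch that has cleared every stage
        # Advance in-flight batches one stage each, back to front.
        for k in range(len(self._stages) - 1, 0, -1):
            prev = self._slots[k - 1]
            self._slots[k] = self._stages[k](prev) if prev is not None else None
        nxt = next(dataloader_iter, None)
        self._slots[0] = self._stages[0](nxt) if nxt is not None else None
        return out

# Stand-ins for data_copy and gpu_postproc runnables
p = StagedPipeline([lambda x: x * 10, lambda x: x + 1])
it = iter(range(3))
outs = [p.progress(it) for _ in range(5)]
print(outs)   # → [None, None, 1, 11, 21]  (first N outputs are None while filling)
```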
Comparison with SWPipeline: StagedTrainPipeline is TorchRec's closest fit to a declarative pipeline design, but it only manages pre-forward stages and has no shortcut / profiling / DeclaredIO mechanisms. SWPipeline covers the full train loop (including FWD/BWD/OPT) and adds automatic profiling and zero-overhead shortcut.

11. TrainPipelineSparseDistCompAutograd

SparseDist + Compiled Autograd — compiled backend accelerates backward

SWPipeline Schedule

# Same as TrainPipelineSparseDist, but wraps forward/backward
# in torch._dynamo.compiled_autograd context
# Enables inductor reordering passes for compute-comm overlap:
#   "sink_waits", "raise_comms", "reorder_compute_for_overlap"

PipelinePlan(
  schedule = {
    # identical to TrainPipelineSparseDist (stages corrected for SWPipeline convention)
    "H2D":              TaskSchedule(stage=0, stream=memcpy),         # stage 0: batch 入口
    "InputDistStart":   TaskSchedule(stage=1, stream=data_dist),
    "InputDistWait":    TaskSchedule(stage=1, stream=data_dist),
    "ZeroGrad":         TaskSchedule(stage=2, stream=default),
    "WaitBatch":        TaskSchedule(stage=2, stream=default),
    "Forward":          TaskSchedule(stage=2, stream=default),       # compiled_autograd ctx
    "Backward":         TaskSchedule(stage=2, stream=default),       # compiled_autograd ctx
    "OptimizerStep":    TaskSchedule(stage=2, stream=default),       # stage 2: batch 出口
  },
  intra_iter_deps = [
    ("InputDistStart", "H2D"),            # cross-stage: memcpy → data_dist stream sync
    ("InputDistWait",  "InputDistStart"),  # stage=1 ordering
    ("WaitBatch",      "InputDistWait"),   # cross-stage: data_dist → default stream sync
    ("Forward",        "InputDistWait"),   # cross-stage: data_dist → default stream sync
    ("WaitBatch",      "ZeroGrad"),        # stage=2 ordering
    ("Forward",        "WaitBatch"),       # stage=2 ordering
    ("Backward",       "Forward"),         # stage=2 ordering
    ("OptimizerStep",  "Backward"),        # stage=2 ordering
  ],
  inter_iter_deps = [
    ("Forward", "OptimizerStep"),        # Forward(i) 读取 OptimizerStep(i-1) 更新的权重
  ],                                           # same-stream (default) → FIFO 已保证 GPU 顺序
  pipeline_depth = 3,
)

print_schedule(5) Output

   #  Task               Thread   Stream        | P0    P1    P2    P3    P4
  --  -----------------  -------  ------------  + ----- ----- ----- ----- -----
   0  ZeroGrad           default  default       |  --    --   i0    i1    i2
   1  WaitBatch          default  default       |  --    --   i0    i1    i2
   2  Forward            default  default       |  --    --   i0    i1    i2
   3  Backward           default  default       |  --    --   i0    i1    i2
   4  OptimizerStep      default  default       |  --    --   i0    i1    i2
   5  InputDistStart     default  data_dist     |  --   i0    i1    i2    i3
   6  InputDistWait      default  data_dist     |  --   i0    i1    i2    i3
   7  H2D                default  memcpy        | i0    i1    i2    i3    i4
# Identical to SparseDist; the only difference is that Forward+Backward run inside the compiled_autograd context

CPU Submission Order (same as SparseDist, but Forward+Backward in compiled_autograd context)

##  progress(iter i): identical to SparseDist   batch    stream
  1. ZeroGrad                                [i]      default
  2. WaitBatch                               [i]      default
  3. InputDistStart                          [i+1]    data_dist
  4. H2D                                     [i+2]    memcpy
  5. Forward (compiled_autograd ctx)           [i]      default
  6. InputDistWait                           [i+1]    data_dist
  7. Backward (compiled_autograd ctx)          [i]      default
  8. OptimizerStep                           [i]      default
Compilation optimization: Compiles the backward graph via torch._dynamo.compiled_autograd with the inductor backend. Enables the reorder_for_compute_comm_overlap passes so the compiler schedules compute-comm overlap automatically. The first iteration runs under nullcontext() (no compilation); compilation is enabled from the second iteration on.
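The "first iteration uncompiled, then compiled" switch is a small reusable pattern. A hedged sketch (plain Python; make_backward_ctx_factory and fake_compiled_ctx are hypothetical names, and the real pipeline chooses between contextlib.nullcontext() and the compiled-autograd context):

```python
from contextlib import contextmanager, nullcontext

def make_backward_ctx_factory(compiled_ctx_factory):
    """Per-iteration context factory: no-op on iteration 0, compiled afterwards."""
    state = {"iters": 0}
    def factory():
        # Mirror the pipeline's behavior: skip compilation on the very first iteration.
        ctx = nullcontext() if state["iters"] == 0 else compiled_ctx_factory()
        state["iters"] += 1
        return ctx
    return factory

# Stand-in for the compiled-autograd context; records each time it is entered.
entered = []

@contextmanager
def fake_compiled_ctx():
    entered.append(True)
    yield

factory = make_backward_ctx_factory(fake_compiled_ctx)
for _ in range(3):
    with factory():
        pass   # Forward + Backward would run here
print(len(entered))   # → 2 (iterations 1 and 2 ran under the compiled context)
```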

Comparison Matrix

Key schedule properties of the Pipeline types at a glance

Dimension    Base    PT2            SparseDist  SDLite      Fused       SemiSync         Prefetch    EvalSD      Staged   CompAuto
-----------  ------  -------------  ----------  ----------  ----------  ---------------  ----------  ----------  -------  -----------------
Streams      2       1              3           2           3–4         3                4           3           N        3
Depth        2       1              3           2           3           3                3           2           N        3
Threads      1       1              1           1           1           1                1           1+bg        1        1
FX Rewrite   —       —              ✓           ✓           ✓           ✓                ✓           ✓           —        ✓
H2D overlap  OPT     —              FWD         OPT         FWD         FWD              FWD         bg          Custom   FWD
InputDist    —       —              Sep stream  Crit path   Sep stream  Event sync       Sep stream  Sep stream  —        Sep stream
EmbLookup    —       —              In FWD      In FWD      Sep stream  Sep stage        In FWD      In FWD      —        In FWD
Prefetch     —       —              —           —           —           —                ✓           —           —        —
Sync         stream  —              stream      stream      stream      Event            stream      stream      Event    stream
Compiled     —       torch.compile  —           —           —           —                —           —           —        compiled_autograd
Backward     —       —              —           —           —           Split Emb/Dense  —           —           —        —
SWPipeline advantage: All 11 pipelines above hard-code scheduling in progress() — changing overlap requires rewriting the entire method. SWPipeline extracts scheduling into declarative PipelinePlan, only the Plan changes. Additionally:
  • Zero-overhead Shortcut — caches Task outputs, skips Tasks that don't need re-execution
  • TaskProfiler — automatically measures each Task's exposed time for optimal stage partitioning
  • DeclaredIO — explicitly declares Task side effects, framework auto-manages capture/restore
  • Multi-thread + NCCL Sequencer — safe multi-thread multi-stream execution with built-in NCCL submission ordering
See SWPipeline Design Docs and SWPipeline API.