Pipeline Deep Dive — Initialization & Rewrite

NVIDIA recsys-examples · TrainPipelineSparseDist · PrefetchTrainPipelineSparseDist · SWPipeline

⭐ GitHub: NVIDIA/recsys-examples

Overview — Initialization & Rewrite of Two Pipelines

recsys-examples uses model rewrite to replace ShardedModule's forward with pipelined versions

recsys-examples provides two TorchRec training pipelines. When the first batch arrives, they use _rewrite_model to replace ShardedModule.forward with pipelined versions, enabling input_dist (AllToAll communication) to overlap with the previous batch's forward / backward.

  • __init__
  • _rewrite_model
  • fill_pipeline
  • progress() (steady state)
📖 Forward Split Mechanism: The pipeline's core is splitting ShardedModule.forward() into input_dist + compute_and_output_dist, executing AllToAll communication ahead of time. For details on Context data flow and torch.fx tracing → Forward Hijack Deep Dive
Property | TrainPipelineSparseDist | PrefetchTrainPipelineSparseDist
Pipeline stages | 3: ① H2D → ② input_dist → ③ fwd/bwd | 4: ① H2D → ② input_dist → ③ prefetch → ④ fwd/bwd
CUDA streams | memcpy_stream, data_dist_stream | memcpy_stream, data_dist_stream, prefetch_stream, default_stream
Pipeline depth | 2 batches | 3 batches (batch_i, batch_ip1, batch_ip2)
Context | TrainPipelineContext | PrefetchTrainPipelineContext
Forward replacement | PipelinedForward | PrefetchPipelinedForward
Forward data source | input_dist_tensors_requests → request.wait() in fwd | module_input_post_prefetch → data already prefetched
Use case | Standard FBGEMM static embedding | DynamicEmb (needs prefetch for cache lookup)

INIT _rewrite_model — Model Rewrite Deep Dive

_rewrite_model is the core of pipeline initialization — it replaces all ShardedModule.forward in the model with pipelined versions (PipelinedForward or PrefetchPipelinedForward). This runs only once when the first batch arrives, invoked by _pipeline_model().

_pipeline_model() → _rewrite_model() → _override_input_dist_forwards()
_rewrite_model(model, context, dist_stream, pipelined_forward) — sequence transcription (utils.py:1257):
Step 1: Unwrap model layers: DMP → DDP → Float16Module → nn.Module
Step 2: Collect all ShardedModule instances: named_modules() → {name: ShardedModule} dict
Step 3: Build ArgInfo (hack path): scan batch attrs → find the KJT field name
Step 4: For each ShardedModule, replace forward: original_forwards.append(module.forward); module.forward = PipelinedForward(name, ArgInfo, module, ctx, stream) — returns (pipelined_modules, model, original_forwards, postprocs)
Step 5: Initialize input dists: start_sparse_data_dist(batch, context) → module.input_dist(ctx, kjt) on data_dist_stream → AllToAll #1
Step 6: Override KJTAllToAll forwards: _override_input_dist_forwards() fuses splits: KJTAllToAll.forward → KJTAllToAllForward (supports fused splits)
utils.py — _rewrite_model(), _override_input_dist_forwards()
train_pipeline.py — _pipeline_model(), _init_pipelined_modules()
Why two steps? Step 5 (start_sparse_data_dist) must call module.input_dist() once because ShardedModule's _input_dists attribute is lazily initialized. Only after this first invocation can Step 6's _override_input_dist_forwards find and replace the internal KJTAllToAll forward.
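The forward-replacement core of Step 4 can be sketched in a few lines. This is a minimal stand-alone sketch, not the real recsys-examples code: `PipelinedForwardStub` and `rewrite_model` are hypothetical names, and the context is simplified to a plain dict keyed by module name.

```python
from typing import Callable, Dict

class PipelinedForwardStub:
    """Hypothetical stand-in for PipelinedForward: instead of taking the KJT
    as an argument, it reads the pre-distributed input from the context."""

    def __init__(self, name: str, module, context: dict):
        self._name = name
        self._module = module
        self._context = context

    def __call__(self, *args, **kwargs):
        # Pop the input that an earlier pipeline stage stored under this name.
        data = self._context.pop(self._name)
        return self._module.compute_and_output_dist(data)

def rewrite_model(sharded_modules: Dict[str, object], context: dict) -> Dict[str, Callable]:
    """Sketch of Step 4: swap each ShardedModule's forward for a pipelined
    version, keeping the originals so the pipeline can be torn down later."""
    original_forwards: Dict[str, Callable] = {}
    for name, module in sharded_modules.items():
        original_forwards[name] = module.forward
        module.forward = PipelinedForwardStub(name, module, context)
    return original_forwards
```

After the rewrite, calling the module's forward consumes the context entry that an earlier stage produced, which is exactly the contract the real PipelinedForward enforces.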

FX _rewrite_model — torch.fx Tracing Detail

_rewrite_model has two paths: the FX tracing path (using torch.fx.Tracer for symbolic execution to derive argument mappings) and the hack path (mod_directly=True, scanning batch attributes to find the KJT field directly). Production code currently uses the hack path because it is simpler and more robust. The diagram below shows the full FX path logic and how the hack path bypasses FX tracing.
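The hack path's attribute scan can be sketched as follows. This is an illustrative sketch, not the real implementation: `FakeKJT`, `ArgInfo`, and `build_arg_info_hack` are hypothetical names, and the type match here is a plain `isinstance` check standing in for matching the first parameter type of `input_dist()`.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import List, Optional

class FakeKJT:
    """Stands in for KeyedJaggedTensor; the real code matches against the
    type of input_dist()'s first parameter."""

@dataclass
class ArgInfo:
    # Mirrors the shape described above: which batch attribute feeds the module.
    input_attrs: List[str]
    is_getitems: List[bool]
    name: Optional[str]

def build_arg_info_hack(batch, kjt_type=FakeKJT, sm_name: str = "sparse_arch") -> ArgInfo:
    """Hack-path sketch: scan the batch's attributes and pick the one whose
    type matches the KJT type expected by input_dist()."""
    for attr_name, value in vars(batch).items():
        if isinstance(value, kjt_type):
            return ArgInfo(input_attrs=[attr_name], is_getitems=[False], name=sm_name)
    raise ValueError("no KJT-typed attribute found on batch")
```

Because the match is by attribute type rather than by FX graph structure, this works even on models that torch.fx cannot trace.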

_rewrite_model — FX Trace Path vs Hack Path
FX Trace Path (mod_directly=False):
  • Step 1: Unwrap model layers (DMP → DDP → Float16Module → nn.Module); named_modules() → collect all ShardedModule instances into a {name: ShardedModule} dict
  • _get_leaf_module_names() → trace depth control
  • Tracer.trace(model, concrete_args={batch}); is_leaf_module(ShardedModule) → True (don't trace into); ShardedModule / FSDP become leaf modules → opaque placeholders in the FX graph → produces an FX Graph (symbolic IR)
  • For each node where op == "call_module" and the target is in sharded_modules: _get_node_args(node) → ArgInfo(input_attrs, is_getitems, name), mapping FX node inputs to KJT fields; returns a List[ArgInfo] per ShardedModule
  • For each ShardedModule: save the original (original_forwards[name] = module.forward), then module.forward = PipelinedForward(name, arg_info_list, module, ctx, dist_stream)

⚡ Hack Path (mod_directly=True) — ACTIVE PATH. Skips FX tracing entirely. For each ShardedModule:
  1. Scan batch.__dict__.values() → find the attribute whose type matches the first param of input_dist()
  2. Build ArgInfo(input_attrs=[attr_name], is_getitems=[False], name=sm_name) directly
  3. module.forward = PipelinedForward(name, [simple_ArgInfo], module, ctx, dist_stream)
  4. Same forward replacement, but ArgInfo is constructed by direct attribute matching, not FX symbolic execution

Why the hack path wins: no torch.fx dependency → avoids trace failures on complex models (e.g., control flow, dynamic shapes); the simpler ArgInfo (single KJT attr) covers all production use cases in recsys-examples.

Summary — _rewrite_model output: pipelined_modules: List[ShardedModule] • original_forwards: Dict[str, Callable] • postprocs: Dict[str, PipelinePostProc]. Every ShardedModule.forward is now PipelinedForward — it reads from ctx instead of taking a KJT directly.

PipelinedForward.__call__ contract: 1. Pop input_dist_tensors_requests[name] from context → Awaitable; 2. request.wait() on dist_stream → get the redistributed KJT; 3. compute_and_output_dist(ctx, data).
utils.py:1257 — _rewrite_model(), _get_node_args(), _get_node_args_helper()

BASE fill_pipeline() — Base Pipeline Initialization

TrainPipelineSparseDist fills 2 batches on the first progress() call, completing rewrite + input_dist for batch_i (AllToAll #1 initiate + wait → AllToAll #2/#3 initiate).

TrainPipelineSparseDist.fill_pipeline() — pipeline depth = 2 batches
fill_pipeline() — sequence transcription:
1. batch_i: next(dataloader_iter) → batch_i (CPU); _to_device(batch_i) [H2D async on memcpy_stream]
2. _init_pipelined_modules → _pipeline_model → _rewrite_model (first time only)
3. start_sparse_data_dist(batch_i, ctx_i): module.input_dist() → AllToAll #1 on data_dist_stream
4. wait_sparse_data_dist(ctx_i): AllToAll #1 done → #2/#3 started
5. batch_i+1: next(dataloader_iter) → batch_i+1 (CPU); _to_device(batch_i+1) [H2D async]
✅ Pipeline filled: batches = [batch_i, batch_i+1] — batch_i: input_dist ✓ (AllToAll #2/#3 in flight); batch_i+1: H2D ✓
⚠ The wait for AllToAll #2/#3 (wait_sparse sub-stage 2) is deferred into PipelinedForward.__call__() inside forward.
train_pipeline.py:260 — fill_pipeline()
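The fill sequence above can be sketched as a short driver. This is a hedged sketch: `pipeline` is a hypothetical object with one method per stage, and the method names are illustrative stand-ins for the real ones in train_pipeline.py.

```python
def fill_pipeline(dataloader_iter, pipeline) -> list:
    """Sketch of the fill_pipeline() step order: prefetch two batches so
    batch_i has its input_dist in flight and batch_i+1 sits on the GPU."""
    batch_i = pipeline.to_device(next(dataloader_iter))    # H2D copy (async, memcpy stream)
    pipeline.init_pipelined_modules(batch_i)               # first time only: _rewrite_model
    pipeline.start_sparse_data_dist(batch_i)               # AllToAll #1 launched
    pipeline.wait_sparse_data_dist()                       # #1 done -> #2/#3 launched
    batch_ip1 = pipeline.to_device(next(dataloader_iter))  # second batch: H2D only
    return [batch_i, batch_ip1]
```

The key property is the asymmetry: only the oldest batch has gone through input_dist, while the newer batch has only been copied to the device.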

PREFETCH _fill_pipeline() — Prefetch Pipeline Initialization

PrefetchTrainPipelineSparseDist fills 2 batches, but additionally completes wait_sparse + prefetch (DynamicEmb cache lookup) for batch_1, then starts input_dist for batch_2 — one extra prefetch stage compared to the Base Pipeline.

PrefetchTrainPipelineSparseDist._fill_pipeline() — note the additional prefetch stage
_fill_pipeline() — sequence transcription:
1. batch_1 (self._batch_i): next(dataloader_iter); _to_device + shuffle [H2D on memcpy_stream]
2. _init_pipelined_modules → _rewrite_model: forward = PrefetchPipelinedForward(...)
3. _start_sparse_data_dist(batch_1): AllToAll #1 (splits)
4. _wait_sparse_data_dist(): #1 done → #2 + #3 started
5. 🔑 _prefetch(batch_1) — unique to DynamicEmb: _prefetch_embeddings(batch_1, context, ...); request.wait() — AllToAll #2/#3 complete; sharded_module.prefetch(ctx, data): hash lookup → cache → slot_indices; ctx.module_input_post_prefetch[name] = data; ctx.module_contexts_post_prefetch[name] = ctx
6. batch_2 (self._batch_ip1): next(dataloader_iter); _to_device + shuffle [H2D]; _start_sparse_data_dist(batch_2)
✅ Pipeline filled: batch_1: H2D ✓ input_dist ✓ wait_sparse ✓ prefetch ✓ • batch_2: H2D ✓ input_dist started
✅ When the first progress() enters forward, DynamicEmb's slot_indices are already ready — PrefetchPipelinedForward is zero-wait.
train_pipeline.py:543 — PrefetchTrainPipelineSparseDist._fill_pipeline()

FWD Forward Replacement Comparison

Both Forward classes inherit from BaseForward. The key difference is where the data comes from and when waiting occurs in __call__.

PipelinedForward

def __call__(self, *input, **kwargs):
    # Take the awaitable from the context
    request = self._context \
        .input_dist_tensors_requests.pop(name)

    # ⚠ Wait inside forward for AllToAll to complete
    with cuda.stream(self._stream):
        data = request.wait()  # blocking!

    # Synchronize streams
    current_stream.wait_stream(self._stream)
    data.record_stream(current_stream)

    return module.compute_and_output_dist(ctx, data)

PrefetchPipelinedForward

def __call__(self, *input, **kwargs):
    # ✅ Take the already-prefetched data directly
    data = self._context \
        .module_input_post_prefetch.pop(name)
    ctx = self._context \
        .module_contexts_post_prefetch.pop(name)

    # Synchronize prefetch_stream → default
    current_stream.wait_stream(self._stream)
    data.record_stream(current_stream)

    return module.compute_and_output_dist(ctx, data)
Forward Call Chain Comparison — left = Base, right = Prefetch
PipelinedForward.__call__():
① context.input_dist_tensors_requests.pop(name)
② with cuda.stream(dist_stream): data = request.wait() ← ⚠ waits on AllToAll #2/#3
③ current_stream.wait_stream(dist_stream)
④ module.compute_and_output_dist(ctx, data)
⚠ If AllToAll is not done, forward blocks here.

PrefetchPipelinedForward.__call__():
① data = context.module_input_post_prefetch.pop(name)
② ctx = context.module_contexts_post_prefetch.pop(name)
③ current_stream.wait_stream(prefetch_stream)
④ module.compute_and_output_dist(ctx, data)
✅ Zero wait! Data is already ready from the prefetch stage — DynamicEmb slot_indices were computed on prefetch_stream.

CTX Context Dataclass Comparison — utils.py:90

TrainPipelineContext

@dataclass
class TrainPipelineContext:
  # AllToAll #1 awaitable
  input_dist_splits_requests: Dict[str, Awaitable]

  # AllToAll #2+#3 awaitable
  input_dist_tensors_requests: Dict[str, Awaitable]

  # ShardedModule contexts
  module_contexts: Dict[str, Multistreamable]

  # Fused splits for batched AllToAll
  fused_splits_awaitables: List[...]

  # Postproc cache
  postproc_fwd_results: Dict[str, Any]

  events: List[torch.Event]
  index: Optional[int]
  version: int  # 0=legacy, 1=current

PrefetchTrainPipelineContext

@dataclass
class PrefetchTrainPipelineContext(
    TrainPipelineContext
):
  # ↓↓↓ NEW: prefetch outputs ↓↓↓

  # KJTList data after prefetch
  module_input_post_prefetch:
      Dict[str, Multistreamable]

  # module context after prefetch
  module_contexts_post_prefetch:
      Dict[str, Multistreamable]

  # (deprecated v0 fields)
  module_input_post_prefetch_next_batch: ...
  module_contexts_post_prefetch_next_batch: ...
Data flow: In the Base pipeline, input_dist_tensors_requests is popped and wait()ed in PipelinedForward.__call__. In the Prefetch pipeline, _prefetch_embeddings() first consumes input_dist_tensors_requests, runs prefetch, and stores the results in module_input_post_prefetch, which is then consumed by PrefetchPipelinedForward.__call__.
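The two consumption patterns can be reduced to a toy sketch over plain dicts. This is illustrative only: `ToyAwaitable` and the three functions are hypothetical names, and the dicts stand in for the context fields named above.

```python
class ToyAwaitable:
    """Toy awaitable: wait() blocks in the real code and returns the
    redistributed KJT; here it just returns a stored value."""
    def __init__(self, value):
        self._value = value
    def wait(self):
        return self._value

def base_forward_fetch(tensors_requests: dict, name: str):
    # Base pipeline: forward itself pops the awaitable and blocks on wait().
    return tensors_requests.pop(name).wait()

def prefetch_stage(tensors_requests: dict, post_prefetch: dict, name: str):
    # Prefetch pipeline: the prefetch stage absorbs the wait() ...
    post_prefetch[name] = tensors_requests.pop(name).wait()

def prefetch_forward_fetch(post_prefetch: dict, name: str):
    # ... so forward only pops already-materialized data (zero wait).
    return post_prefetch.pop(name)
```

Both paths deliver the same data to compute_and_output_dist; the difference is which stage pays for the wait().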

CTX Context Data Flow Deep Dive

TrainPipelineContext bridges data between pipeline stages. Each batch owns an independent context instance; data flows between stages via context dictionaries. The diagram below shows how context fields are populated and consumed at each stage in both the Base and Prefetch pipelines.

Context Data Flow Diagram
Base Pipeline — TrainPipelineContext data flow:
① _start_data_dist(batch, ctx) on data_dist_stream: for each sharded module, awaitable = sm.input_dist(sm_ctx, kjt); WRITES input_dist_splits_requests[name] = awaitable
② wait_sparse_data_dist(ctx): READS fused_splits_awaitables; fused_awaitable.wait() → AllToAll #1 done; for each sm: request = input_dist(sm_ctx, data); WRITES input_dist_tensors_requests[name] = request and module_contexts[name] = sm_ctx (copied from _next_batch)
③ PipelinedForward.__call__: POPS request = input_dist_tensors_requests.pop(name); with cuda.stream(dist_stream): data = request.wait() ← ⚠ blocks!; ctx = module_contexts.pop(name) → compute_and_output_dist(ctx, data)
⚠ Base pipeline bottleneck: request.wait() in forward blocks the GPU default stream until AllToAll #2/#3 completes on data_dist_stream.
Base context lifecycle: _start_data_dist → splits_requests → wait_sparse_data_dist → tensors_requests → PipelinedForward wait() + compute. module_contexts: written by wait_sparse → popped by PipelinedForward; module_contexts_next_batch: temp buffer between stages.

Prefetch Pipeline — PrefetchTrainPipelineContext data flow:
① _start_data_dist + ② wait_sparse_data_dist: same as Base — populates input_dist_tensors_requests + module_contexts (AllToAll #2/#3 complete)
③ _prefetch_embeddings(batch, ctx, pipelined_modules) on prefetch_stream: POPS (consumes the AllToAll result) data = input_dist_tensors_requests[name].wait(); sm.prefetch(ctx=sm_ctx, dist_input=data) — DynamicEmb: hash lookup → HKV cache → slot_indices; WRITES module_input_post_prefetch[name] = data and module_contexts_post_prefetch[name] = sm_ctx
④ PrefetchPipelinedForward.__call__: POPS data = module_input_post_prefetch.pop(name) and ctx = module_contexts_post_prefetch.pop(name); current_stream.wait_stream(prefetch_stream); data.record_stream(current_stream) — ✅ zero wait! data and slot_indices were already computed during the prefetch stage → compute_and_output_dist(ctx, data)
Prefetch context lifecycle: start_dist → wait_sparse → tensors_req → prefetch_emb → post_prefetch → PrefetchPipelinedFwd (0-wait).
Key insight: _prefetch_embeddings moves the wait() from forward onto prefetch_stream, so forward never blocks for AllToAll.
Comparison: Base waits on AllToAll inside forward → GPU idle; Prefetch completes the wait + cache lookup before forward → zero-wait forward.
Field read/write map: input_dist_splits_requests: W=_start_data_dist, R=wait_sparse | input_dist_tensors_requests: W=wait_sparse, R=PipelinedFwd (base) / _prefetch_emb (prefetch) | module_contexts: W=wait_sparse, R=PipelinedFwd | module_input_post_prefetch: W=_prefetch_emb, R=PrefetchPipelinedFwd | module_contexts_post_prefetch: W=_prefetch_emb, R=PrefetchPipelinedFwd
utils.py:90 — TrainPipelineContext, PrefetchTrainPipelineContext
utils.py:1117 — _start_data_dist(), _prefetch_embeddings(), PipelinedForward, PrefetchPipelinedForward

CMP Steady-State progress() Comparison

JaggedMegatron* 生产变体为例,展示稳态下每次 progress() 的执行步骤。

Using JaggedMegatron* production variants as examples, showing the steady-state execution steps of each progress() call.

Step | JaggedMegatronTrainPipelineSparseDist | JaggedMegatronPrefetchTrainPipelineSparseDist
1 | zero_grad() | zero_grad()
2 | wait_for_batch(batch_i) | wait_prefetch_async(batch_i) — CPU waits for the background-thread prefetch to finish
3 | start_sparse_data_dist(batch_ip1) | copy_batch_to_gpu(batch_ip2) [H2D]
4 | copy_batch_to_gpu(batch_ip2) [H2D] | wait_sparse_data_dist()
5 | wait_sparse_data_dist(batch_ip1) | Shuffle Phase 1 (AllGather workloads)
6 | Shuffle Phase 1 (AllGather + submit KK) | forward(batch_i) — PrefetchPipelinedForward: zero wait
7 | forward(batch_i) — PipelinedForward: wait AllToAll #2/#3 | _start_prefetch_async(batch_ip1) — async prefetch on a background thread
8 | Shuffle Phase 2 (finish_shuffle) | Shuffle Phase 2 (finish_shuffle)
9 | loss AllReduce + backward + optimizer | loss AllReduce + backward + optimizer
10 | dequeue_batch() | _start_sparse_data_dist(batch_ip2)
11 | batch_i ← batch_ip1, batch_ip1 ← batch_ip2
Key differences:
  • AllToAll wait timing: Base waits inside forward (Step 7); Prefetch waits in a background thread (Step 2, wait_prefetch_async)
  • Prefetch: none in Base; Prefetch submits async after forward (Step 7), overlapping with backward
  • input_dist start timing: Base before forward (Step 3); Prefetch after backward (Step 10)
  • Pipeline depth: Base = 2 batches; Prefetch = 3 batches
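The base variant's steady-state step order can be sketched as a driver loop body. This is a hedged sketch: `pipeline` is a hypothetical object, the method names are illustrative, and the shuffle phases are omitted for brevity.

```python
def progress(pipeline, dataloader_iter):
    """Sketch of the base variant's steady-state step order from the table."""
    pipeline.zero_grad()                                           # Step 1
    pipeline.wait_for_batch(pipeline.batch_i)                      # Step 2
    pipeline.start_sparse_data_dist(pipeline.batch_ip1)            # Step 3: overlap next input_dist
    batch_ip2 = pipeline.copy_batch_to_gpu(next(dataloader_iter))  # Step 4: H2D
    pipeline.wait_sparse_data_dist(pipeline.batch_ip1)             # Step 5
    loss = pipeline.forward(pipeline.batch_i)                      # Step 7: PipelinedForward waits AllToAll #2/#3
    pipeline.backward_and_step(loss)                               # Step 9
    # Step 11: rotate the pipeline window
    pipeline.batch_i, pipeline.batch_ip1 = pipeline.batch_ip1, batch_ip2
```

The rotation at the end is what keeps the 2-deep window moving: batch_ip1 (with its input_dist already in flight) becomes the next iteration's batch_i.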

SWPipeline — Software Pipeline Framework

Decomposes a training iteration into a declarative Task DAG, overlapping across iterations via stages

Full docs have moved: SWPipeline's complete documentation (data classes, DeclaredIO design, full API reference, Shortcut mechanism, TaskProfiler) is now at sw_pipeline_overview.html. This section retains only the overview and the SWSerialTrainPipeline adapter context.

SWPipeline is a new general-purpose software pipeline framework in recsys-examples that decomposes the hard-coded stages described above into a declarative Task DAG. It cleanly separates "what to compute" (PipelineTask) from "how to schedule" (PipelinePlan) — the same Tasks can run with different Plans to switch between serial / pipelined modes.

  • PipelineTask — Schedulable computation unit (name + fn + io)
  • TaskSchedule — Per-task scheduling properties (stage, stream, thread_group, globally_ordered)
  • PipelinePlan — Complete scheduling plan (schedule + deps + cross_iter_deps + depth)
  • SWPipeline — Execution engine: worker threads, two-phase sync, shortcut caching
  • DeclaredIO — Side effect declaration (see dedicated docs)
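The data classes above can be sketched as plain dataclasses. This is a minimal sketch of the shapes described in the bullets (field names follow the doc; defaults and the `stream: object` type are simplifications — the real stream field holds a CUDA stream).

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple

@dataclass(eq=False)
class PipelineTask:
    """Schedulable unit: identity is the name alone, so Tasks can key dicts."""
    name: str
    fn: Optional[Callable[[Any], None]] = None  # fn(IterContext) -> None

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        return isinstance(other, PipelineTask) and self.name == other.name

@dataclass
class TaskSchedule:
    stage: int = 0
    stream: Optional[object] = None      # a CUDA stream in the real code
    thread_group: str = "default"
    globally_ordered: bool = False

@dataclass
class PipelinePlan:
    schedule: Dict[PipelineTask, TaskSchedule] = field(default_factory=dict)
    deps: List[Tuple[PipelineTask, PipelineTask]] = field(default_factory=list)
    cross_iter_deps: List[Tuple[PipelineTask, PipelineTask]] = field(default_factory=list)
    pipeline_depth: int = 2
```

Name-based identity is the point: a Plan can reference Tasks by name without holding the exact same objects, which is what lets the same Task set be re-scheduled by different Plans.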
Class Relationship Diagram — sw_pipeline.py + sw_train_pipeline.py
PipelineTask: name: str; fn: Callable[[IterContext], None]; __hash__ → hash(name); __eq__ → name == other.name
TaskSchedule: stage: int = 0; stream: Optional[Stream]; thread_group: str = "default"; globally_ordered: bool
PipelinePlan: schedule: Dict[Task, Sched]; deps: List[Tuple[Task, Task]]; cross_iter_deps: List[…]; pipeline_depth: int = 2
IterContext: batch: Any; iter_idx: int; plus arbitrary attrs set by tasks
SWPipeline state: _defs: Dict[str, PipelineTask]; _stage_map / _stream_map / _thread_group_map; _topo_order / _enqueue_order; _cuda_events / _cpu_signals (pre-allocated per slot); _work_queues: Dict[str, Queue]; _sequencer: _SubmissionSequencer
SWPipeline methods: fill_pipeline(data_iter) ← prefetch depth batches, start workers; progress(data_iter) ← retire oldest, enqueue next period; run_one_serial_iter(batch, idx) ← single serial iteration; enable_shortcut / disable_shortcut; _submit_task(iter_idx, task_def) ← core execution with two-phase sync
SWSerialTrainPipeline (has-a SWPipeline): _sw_pipeline: SWPipeline; _pipeline_ctx: PrefetchCtx; _pipelined_modules: List[ShardedModule]; _build_pipeline() → SWPipeline; _ensure_initialized(batch); progress(dataloader_iter); registered as "jagged_sw_serial"

SWSerialTrainPipeline — 11-Task Dependency Graph

One training iteration decomposed into 11 PipelineTasks; in serial mode all run at stage 0

_build_pipeline() constructs 11 Tasks with their dependencies. Serial mode (depth=1) executes everything on stage 0 / the default stream / the main thread. DenseForward (6) and EmbBackward (9) are no-ops in serial mode, reserved for future pipelined variants (embedding/dense forward/backward split).

11-Task Dependency DAG — _build_pipeline(), all stage 0, depth 1
DATA: H2DAndShuffle (_to_device + batch_shuffler) → EmbInputDistStart (_start_data_dist, AllToAll) → EmbInputDistWait (fused_splits_awaitables.wait())
EMBEDDING: EmbPrefetch (_prefetch_embeddings, DynamicEmb cache) → EmbForward (model(batch) + zero_grad)
COMPUTE: DenseForward (no-op, reserved) → LossPostprocess (loss allreduce + collective_assert)
GRADIENT: DenseBackward (loss.backward()) → EmbBackward (no-op, reserved) → FinalizeGrads (finalize_model_grads, DDP AllReduce)
UPDATE: OptimStep
Legend: edges are intra-iteration dependencies; tasks are either active or no-op (reserved).

Two-Phase Synchronization Protocol

CPU threading.Event + CUDA Event prevents the wait-before-record race


SWPipeline's multi-threaded execution introduces a classic race: if Task B calls stream.wait_event(event) before Task A completes event.record(), wait returns immediately (event not recorded) → GPU execution order breaks.

  • Phase 1 (CPU): Upstream Task sets threading.Event after completion; downstream cpu_signal.wait() ensures CUDA event is recorded
  • Phase 2 (GPU): After CPU wait returns, calls stream.wait_event(cuda_event) to establish correct GPU dependency
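The two phases can be simulated CPU-only. This is a hedged sketch, not the SWPipeline code: `TwoPhaseHandshake` is a hypothetical name, and a boolean flag stands in for the recorded torch.cuda.Event (Phase 2 would be a real `stream.wait_event` call).

```python
import threading
import time

class TwoPhaseHandshake:
    """CPU-only sketch: a threading.Event (Phase 1) guards a 'CUDA event
    recorded' flag (Phase 2 stand-in), so the consumer can never observe
    the event before the producer has recorded it."""

    def __init__(self):
        self.cpu_signal = threading.Event()
        self.cuda_recorded = False  # stands in for a recorded torch.cuda.Event

    def producer_finish(self):
        # Producer order matters: record the CUDA event, *then* set the signal.
        self.cuda_recorded = True   # ~ cuda_event.record(stream)
        self.cpu_signal.set()       # Phase 1 publish

    def consumer_wait(self):
        self.cpu_signal.wait()      # Phase 1: ensure the record has happened
        # Phase 2 would be stream.wait_event(cuda_event); here we just check.
        assert self.cuda_recorded, "wait-before-record race"
```

Without Phase 1 the consumer could reach the GPU wait while `cuda_recorded` is still False — exactly the no-op-wait race the protocol closes.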
Two-Phase Sync — Sequence Diagram (_submit_task() internals)
Worker Thread A: ① with cuda.stream(X): fn(ctx); ② cuda_event[A].record(stream_X); ③ cpu_signal[A].set()
Worker Thread B: Phase 1 (CPU wait): cpu_signal[A].wait() → OK; Phase 2 (GPU wait): stream_Y.wait_event(cuda_event[A]) establishes the GPU dependency; ⑤ with cuda.stream(Y): fn(ctx); ⑥ cuda_event[B].record(stream_Y); ⑦ cpu_signal[B].set()
Key: without Phase 1, Thread B might call stream_Y.wait_event() before Thread A records → the wait becomes a no-op → data race on the GPU.

_ensure_initialized — First Batch Triggers Model Rewrite

Called inside the EmbInputDistStart task; reuses _rewrite_model + PrefetchPipelinedForward

_ensure_initialized Sequence (first batch only)
run_one_serial_iter(batch, 0) → H2DAndShuffle → exec EmbInputDistStart: initialized? No → _rewrite_model(): collect ShardedModules; build ArgInfo → PrefetchPipelinedFwd; init input_dist; _start_data_dist(batch_0); _override_input_dist_forwards; _initialized = True → continue: ③ EmbInputDistWait → … → ⑪ OptimStep
dist_stream=None (serial) → everything runs on the default stream. PrefetchPipelinedForward is installed for zero-wait embedding lookup.
Difference from the old pipelines: the old pipelines explicitly prefetch 2-3 batches in fill_pipeline(). SWSerial (depth=1) needs no prefetch — initialization completes within the first batch via _ensure_initialized. Future pipelined variants (depth>1) will use SWPipeline.fill_pipeline() for multi-batch prefetch.

Shortcut — Task Caching & Skip Mechanism

SWPipeline's built-in task-level snapshot/replay engine for profiling and debugging

📖 Complete Shortcut documentation (three execution modes, _GraftGrad gradient grafting, safe shortcut conditions) and the API reference have moved to sw_pipeline_overview.html § Shortcut.


Shortcut is one of SWPipeline's core infrastructure features. It allows skipping a Task's computation at runtime, replaying previously cached IterContext outputs instead. This is not a simple "delete" — it must:

  • Faithfully replay all side effects on IterContext (added/modified/deleted attrs)
  • Maintain autograd graph connectivity: downstream backward propagates grads to upstream params
  • Still run _SubmissionSequencer for globally_ordered Tasks to avoid cross-rank NCCL deadlocks
  • All ranks must enable/disable the same Tasks synchronously
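The snapshot/diff half of this machinery (Mode 2 caching and Mode 1 replay, described below) can be sketched over plain attributes. This is an illustrative sketch: `capture_task` / `apply_shortcut` are hypothetical names, and Tensor detach/clone handling and _GraftGrad are omitted.

```python
from types import SimpleNamespace

def capture_task(ctx, fn):
    """Mode 2 sketch: snapshot ctx attrs, run the task, then diff to learn
    what it produced and what it deleted."""
    before = dict(vars(ctx))
    fn(ctx)
    after = vars(ctx)
    produced = {k: v for k, v in after.items() if k not in before or before[k] is not v}
    deleted = frozenset(k for k in before if k not in after)
    return produced, deleted

def apply_shortcut(ctx, produced, deleted):
    """Mode 1 sketch: replay the cached side effects instead of running fn."""
    for k in deleted:
        if hasattr(ctx, k):
            delattr(ctx, k)
    for k, v in produced.items():
        setattr(ctx, k, v)
```

The diff captures exactly the side effects on the context, so replaying it on a fresh iteration's ctx reproduces the task's visible output without running the compute.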

Three-Phase Execution — _exec_task()

_exec_task Decision Flow — sw_pipeline.py:391–411
_exec_task(task_def, ctx) decision flow:
  • name ∉ _shortcut_tasks → Mode 3: Normal — task_def.fn(ctx), zero extra overhead
  • name ∈ _shortcut_tasks, name ∈ _shortcut_cache (cache hit) → Mode 1: Shortcut — _apply_shortcut(ctx, name): replay cached values + _GraftGrad gradient grafting
  • name ∈ _shortcut_tasks, not yet cached (first time) → Mode 2: Caching — ① before = snapshot of ctx attrs; ② task_def.fn(ctx) (normal execution); ③ _capture_and_cache(ctx, name, before): diff before/after → cache produced + deleted

_apply_shortcut in detail:
1. Fetch (produced, deleted, input_attr_names, zero_grads) from the cache
2. Collect upstream inputs: [getattr(ctx, attr) for attr in input_attr_names]
3. Delete ctx attrs the Task had deleted (for k in deleted: delattr)
4. Restore produced values: _restore_val() → detach (zero-copy)
5. For Tensors with requires_grad: _GraftGrad.apply(restored, zeros, *inputs) → forward returns restored unchanged; backward passes grad_output to restored and zeros to the upstream inputs

Special handling for globally_ordered Tasks: even when shortcut skips fn(ctx), _SubmissionSequencer still runs execute_ordered(), keeping cross-rank NCCL call order consistent → no deadlock. Only the compute is skipped, never the ordering.

Cache Data Structure

_shortcut_cache[name] is a 5-tuple:

Field | Type | Meaning
produced | Dict[str, Any] | Attributes added or modified on IterContext by the Task. Tensors are stored as (detached_clone, requires_grad) tuples.
deleted | frozenset | Attribute names deleted from IterContext by the Task.
input_attr_names | tuple | Tensor attrs existing before the Task, unchanged after, with requires_grad=True. These are the "upstream inputs" _GraftGrad needs to maintain the grad chain.
zero_grads | Tuple[Tensor, ...] | Pre-allocated zero tensors for each upstream input, returned directly in _GraftGrad.backward to avoid repeated allocations.
io_cached | List[Any] | Each DeclaredIO.capture() return value, recursively cached via _cache_val. During _apply_shortcut, the corresponding dio.restore(_restore_val(cached)) writes the external state back. Empty if the Task has no io.

_GraftGrad — Autograd Gradient Grafting


This is the most elegant part of Shortcut. When a Task is shortcut, its output Tensors are detach()ed leaf nodes restored from cache — disconnected from the upstream computation graph. If downstream loss.backward() needs gradient flow through here, it breaks.

_GraftGrad is a custom torch.autograd.Function acting as a "gradient bridge":

  • Forward: identity — returns restored tensor unchanged
  • Backward: passes grad_output to restored (downstream gets normal grads); passes pre-allocated zeros to all upstream inputs (triggers their backward chains, but zeros mean no parameter gradient impact)

This ensures: even with Task skipped, upstream params still participate in the backward graph (no "unused parameter" errors), but contribute zero gradients (semantically correct for "Task didn't actually compute").
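A minimal approximation of this gradient bridge, assuming PyTorch is available. `GraftGrad` is an illustrative name for a sketch of the behavior described above, not the real sw_pipeline.py implementation.

```python
import torch

class GraftGrad(torch.autograd.Function):
    """Sketch of _GraftGrad: identity in forward; in backward, route
    grad_output to the restored tensor and pre-allocated zeros to each
    upstream input (so they join the graph with zero contribution)."""

    @staticmethod
    def forward(ctx, restored, zero_grads, *upstream_inputs):
        ctx.zero_grads = zero_grads      # tuple of pre-allocated zero tensors
        return restored

    @staticmethod
    def backward(ctx, grad_output):
        # One grad per forward input: restored, zero_grads (non-tensor), *upstream_inputs
        return (grad_output, None, *ctx.zero_grads)
```

Grafting the restored cache value onto the upstream inputs keeps them inside the backward graph while guaranteeing their gradient contribution is exactly zero.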

_GraftGrad Autograd Graph — sw_pipeline.py:317–342
Normal Task (left) vs Shortcut Task with _GraftGrad (right):
  • Normal execution: upstream input (requires_grad) → task_def.fn(ctx) → ctx.output → loss; ∂L/∂input ≠ 0
  • Shortcut + _GraftGrad: upstream input and the cached output (detached leaf) feed _GraftGrad.apply(restored, zero_grads, *upstream) → ctx.output (= restored, unchanged) → loss; backward sends grad_output to restored and zeros upstream → ∂L/∂input = 0; fn(ctx) is SKIPPED

Which Tasks Can Be Safely Shortcut?

Condition | Shortcuttable? | Reason
Pure ctx compute Task (no external side effects) | ✅ | Outputs can be detach+clone cached; _GraftGrad maintains grads
globally_ordered Task (with NCCL) | ✅ | _SubmissionSequencer still fires for ordering; only fn() is skipped
Tasks with external side effects (pipeline_ctx, module state, etc.) | ✅ with DeclaredIO | The framework auto-captures/restores external state, so these can be shortcut independently
Non-replayable Tasks (e.g. checkpoint writes) | ❌ | Side effects are irreversible and cannot be replayed via capture/restore

DeclaredIO — 管理 Task 的外部副作用DeclaredIO — Managing Task Side Effects

将隐式副作用 (implicit side effects) 提升为显式声明 (declared effects),让框架自动处理 shortcut 缓存/恢复Elevate implicit side effects into declared effects so the framework can automatically handle shortcut capture/restore

📖 DeclaredIO 的完整设计文档(独立示例、执行流程图、概念总结)已迁移至 sw_pipeline_overview.html § DeclaredIO。此处保留在推荐系统中的具体应用说明。 📖 Complete DeclaredIO documentation (standalone example, execution flow diagram, concept summary) has moved to sw_pipeline_overview.html § DeclaredIO. This section retains the recommendation system-specific usage details.

问题:Shortcut 的盲区The Problem: Shortcut's Blind Spot

Shortcut 通过 diff IterContext 的属性来缓存/恢复 task 的输出。但许多 task 有框架看不见的副作用 (side effects)——它们读写了 ctx 之外的共享状态:

  • PrefetchTrainPipelineContext(TorchRec 的 pipeline 共享状态:fused_splits_awaitables、input_dist_tensors_requests、module_input_post_prefetch 等)
  • Module 内部状态(如 SplitPrefetchPipelinedForward._cached_awaitable

如果不处理这些副作用,shortcut 一个 task 后,下游 task 会从共享状态中读到过期数据,导致 AssertionError 或 "backward through graph a second time" 等错误。

Shortcut diffs IterContext attributes to cache/restore a task's output. But many tasks have side effects invisible to the framework — they read/write shared state outside of ctx:

  • PrefetchTrainPipelineContext (TorchRec pipeline shared state: fused_splits_awaitables, input_dist_tensors_requests, module_input_post_prefetch, etc.)
  • Module internal state (e.g. SplitPrefetchPipelinedForward._cached_awaitable)

Without handling these, shortcutting a task causes downstream tasks to read stale data from shared state, leading to AssertionError or "backward through graph a second time" errors.
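The failure mode is easy to reproduce in miniature. The sketch below is hypothetical (plain dicts stand in for `pipeline_ctx` and `Awaitable`s): a task's ctx diff is cached and replayed, but its side effect on shared state is not, so the downstream consumer breaks on the next iteration.

```python
# Hypothetical minimal reproduction of the "blind spot": shortcut replays
# ctx attributes, but knows nothing about state outside the ctx.
class Ctx:
    pass


shared = {}   # stands in for pipeline_ctx state living outside the IterContext
cache = {}


def emb_dist_start(ctx, step):
    ctx.started = step
    shared["awaitable"] = f"a2a@{step}"   # side effect: invisible to the ctx diff


def emb_dist_wait(ctx):
    ctx.kjt = shared.pop("awaitable")     # downstream consumes the shared entry


# Iteration 0: both tasks run; the framework caches emb_dist_start's ctx diff.
ctx0 = Ctx()
emb_dist_start(ctx0, 0)
cache["emb_dist_start"] = {"started": ctx0.started}
emb_dist_wait(ctx0)

# Iteration 1: shortcut emb_dist_start -- replay the cached ctx attrs only.
ctx1 = Ctx()
for k, v in cache["emb_dist_start"].items():
    setattr(ctx1, k, v)

try:
    emb_dist_wait(ctx1)    # shared["awaitable"] was never re-populated
    failed = False
except KeyError:
    failed = True          # downstream breaks on stale/missing shared state
```

In the real pipeline this surfaces as `AssertionError` or "backward through graph a second time" rather than `KeyError`, but the cause is the same: an unreplayed side effect.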

旧方案 vs 新方案Before vs After

旧:手动 MirrorBefore: Manual Mirror

# Task B: writes shared state
def encode(ctx):
    ctx.embedding = encoder(ctx.features)
    shared.buffer["emb"] = ctx.embedding
    ctx._mirror_buffer = dict(shared.buffer)  # ← manually copy into ctx

# Task C: reads shared state
def decode(ctx):
    shared.buffer = dict(ctx._mirror_buffer)  # ← manually restore
    ctx.output = decoder(shared.buffer["emb"])

每个 task 都要写 mirror 代码;下游必须知道上游的命名约定;新增 task 时 mirror 链脆弱。

Every task needs mirror code; downstream must know upstream's naming conventions; mirror chain is fragile when adding new tasks.

新:DeclaredIOAfter: DeclaredIO

# Task B: pure business logic
def encode(ctx):
    ctx.embedding = encoder(ctx.features)
    shared.buffer["emb"] = ctx.embedding  # normal write

# Task C: pure business logic
def decode(ctx):
    ctx.output = decoder(shared.buffer["emb"])  # normal read

# Side effects are declared where the task is defined
PipelineTask("encode", encode, io=[
    DeclaredIO(
        capture=lambda: dict(shared.buffer),
        restore=lambda s: shared.buffer.update(s),
    ),
])

DeclaredIO 数据结构DeclaredIO Data Structure

DeclaredIO 是一个简单的 dataclass,声明 task 对外部状态的读写合约:

DeclaredIO is a simple dataclass declaring the task's read/write contract with external state:

字段Field | 类型Type | 含义Meaning
capture | Callable[[], Any] | Task 执行后调用。快照外部状态,返回可缓存的值。返回值经 _cache_val 递归 detach+clone 所有 Tensor。Called after the Task runs. Snapshots external state and returns a cacheable value; the return value goes through _cache_val (recursive detach+clone of all Tensors).
restore | Callable[[Any], None] | Shortcut 时调用。接收 _restore_val 处理后的值(全新 detached Tensor),写回外部状态,确保下游 task 能找到正确数据。Called during shortcut replay. Receives the value processed by _restore_val (fresh detached tensors) and writes it back to external state so downstream tasks find correct data.
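The contract can be sketched as a small dataclass. This is a hedged, self-contained illustration (plain `copy.deepcopy` stands in for `_cache_val`/`_restore_val`, and `shared_buffer` is a hypothetical piece of external state, not the real pipeline_ctx):

```python
import copy
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class DeclaredIO:
    capture: Callable[[], Any]      # snapshot external state after the task runs
    restore: Callable[[Any], None]  # write a cached snapshot back on shortcut replay


# Hypothetical external state (stands in for pipeline_ctx / module state).
shared_buffer = {}

io = DeclaredIO(
    capture=lambda: copy.deepcopy(shared_buffer),
    restore=lambda snap: (shared_buffer.clear(), shared_buffer.update(snap)),
)

shared_buffer["emb"] = [1.0, 2.0]   # the task's side effect
snapshot = io.capture()             # framework caches this (via _cache_val in the real code)
shared_buffer.clear()               # state later consumed / drifted
io.restore(snapshot)                # shortcut replay writes it back
```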

执行流程Execution Flow

_exec_task 中 DeclaredIO 的集成DeclaredIO Integration in _exec_task
Shortcut 缓存迭代 (左) vs 回放迭代 (右)Caching Iteration (left) vs Replay Iteration (right)

缓存迭代 (首次)Caching Iteration (first time):
  1. before_ids = {k: id(v) for k, v in vars(ctx)}
  2. task.fn(ctx) 正常执行 normal run
  3. ctx diff (框架自动 framework auto): produced = {k: _cache_val(v) for k, v in after if k not in before or id(v) != before[k]}
  4. DeclaredIO capture (框架自动 framework auto): io_cached = [_cache_val(dio.capture()) for dio in task.io]
  5. 存入缓存 Store in cache: _shortcut_cache[name] = (produced, deleted, input_attrs, zero_grads, io_cached)
  → ctx 属性已更新,外部状态已被 fn 的副作用更新;下游 task 正常读取两者 ctx attrs updated and external state updated by fn's side effects; downstream reads both normally

回放迭代 (Shortcut)Replay Iteration (shortcut):
  1. task.fn(ctx) SKIP
  2. ctx 回放 (框架自动 framework auto): for k in deleted: delattr(ctx, k); for k, v in produced: setattr(ctx, k, _restore_val(v) | _GraftGrad)
  3. DeclaredIO restore (框架自动 framework auto): for dio, cached in zip(task.io, io_cached): dio.restore(_restore_val(cached))
  → ctx 与外部状态均从缓存恢复;下游 task 正常读取两者 ✓ ctx and external state restored from cache; downstream reads both normally ✓

核心:task function 只写业务逻辑 Key: the task fn contains only business logic. 副作用的生命周期管理完全由框架通过 DeclaredIO 驱动 Side-effect lifecycle is fully managed by the framework via DeclaredIO → 无需手动 mirror 代码,新增 task 只需声明 io=[] no manual mirror code; new tasks just declare io=[].
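The two iterations above can be distilled into one function. This is a hypothetical sketch of `_exec_task` under simplifying assumptions (`copy.deepcopy` stands in for `_cache_val`/`_restore_val` and the `_GraftGrad` wrapping; deleted-attr handling and `zero_grads` are omitted; DeclaredIO is modeled as a plain dict):

```python
import copy


def exec_task(name, fn, io_list, ctx, cache, shortcut=False):
    """Distilled _exec_task: cache the ctx diff + DeclaredIO snapshots on the
    first run; on shortcut, skip fn and replay both."""
    if shortcut and name in cache:
        produced, io_cached = cache[name]
        for k, v in produced.items():                 # ctx replay
            setattr(ctx, k, copy.deepcopy(v))         # stands in for _restore_val (+ _GraftGrad)
        for dio, snap in zip(io_list, io_cached):     # DeclaredIO restore
            dio["restore"](copy.deepcopy(snap))
        return
    before = {k: id(v) for k, v in vars(ctx).items()}
    fn(ctx)                                           # normal run
    produced = {k: copy.deepcopy(v)                   # ctx diff (stands in for _cache_val)
                for k, v in vars(ctx).items()
                if k not in before or id(v) != before[k]}
    io_cached = [copy.deepcopy(dio["capture"]()) for dio in io_list]
    cache[name] = (produced, io_cached)


class Ctx:
    pass


shared = {}


def emb_prefetch(ctx):                 # hypothetical task with a side effect
    ctx.out = 7
    shared["prefetched"] = 7           # written outside ctx


io = [{"capture": lambda: dict(shared),
       "restore": lambda s: (shared.clear(), shared.update(s))}]

cache = {}
exec_task("prefetch", emb_prefetch, io, Ctx(), cache)                # caching iteration
shared.clear()                                                       # state consumed downstream
ctx2 = Ctx()
exec_task("prefetch", emb_prefetch, io, ctx2, cache, shortcut=True)  # replay iteration
```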

概念总结Concept Summary

概念Concept | 职责Responsibility | 谁写Who writes it
ctx 属性attributes | Task 之间的显式数据流Explicit data flow between Tasks | 框架自动 diff/replayFramework auto diff/replay
DeclaredIO | Task 的外部副作用(暗通道)Task's external side effects (hidden channels) | 用户在 Task 定义时声明User declares at Task definition
capture() | side effect → 可缓存快照 cacheable snapshot | 用户提供回调User-provided callback
restore() | 缓存快照 → 写回 side effect cached snapshot → write back side effect | 用户提供回调User-provided callback
_cache_val/_restore_val | 递归 detach+clone Tensor(含嵌套结构)Recursive tensor detach+clone (nested structures) | 框架内部Framework internal
shortcut 决策shortcut decision | 哪些 Task 走快捷路径Which Tasks use shortcut path | Profiler / 用户 APIProfiler / User API
设计哲学:DeclaredIO 本质上是一个轻量级的 effect declaration(效果声明)系统——不依赖类型系统,而是在运行时通过 capture/restore 回调让框架管理 side effect 的生命周期。核心原则:task function 只负责计算,side effect 的生命周期管理交给框架。 Design philosophy: DeclaredIO is essentially a lightweight effect declaration system — instead of relying on a type system, it uses runtime capture/restore callbacks to let the framework manage side effect lifecycles. Core principle: task functions handle only computation; side effect lifecycle management is the framework's job.

SWSerialTrainPipeline 中的实际应用Real-World Usage in SWSerialTrainPipeline

在推荐系统 pipeline 中,Task 2–5 都有外部副作用(读写 PrefetchTrainPipelineContextSplitPrefetchPipelinedForward 内部状态)。通过 DeclaredIO,每个 task function 只写业务逻辑:

In the recommendation pipeline, Tasks 2–5 all have side effects (read/write PrefetchTrainPipelineContext and SplitPrefetchPipelinedForward internal state). With DeclaredIO, each task function contains only business logic:

Task | 外部副作用Side Effect | DeclaredIO.capture | DeclaredIO.restore
EmbInputDistStart | pipeline_ctx.fused_splits_awaitables, pipeline_ctx.module_contexts_next_batch | 快照两者(Awaitable 包装为 _IdempotentAwaitable)Snapshot both (wrap Awaitable in _IdempotentAwaitable) | 写回 pipeline_ctxWrite back to pipeline_ctx
EmbInputDistWait | pipeline_ctx.input_dist_tensors_requests, pipeline_ctx.module_contexts | 快照两者Snapshot both | 写回 pipeline_ctxWrite back to pipeline_ctx
EmbPrefetch | pipeline_ctx.module_input_post_prefetch, pipeline_ctx.module_contexts_post_prefetch | 快照两者Snapshot both | 写回 pipeline_ctxWrite back to pipeline_ctx
MPEmbForward | fwd._cached_awaitable(Module 内部状态)(module internal state) | fwd.get_embedding_result() | fwd.set_embedding_result(cached)
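As an illustration, the EmbPrefetch row could be declared roughly as below. This is a hedged sketch, not the actual recsys-examples code: `SimpleNamespace` stands in for `PrefetchTrainPipelineContext` (only the two attribute names from the table are used), and the DeclaredIO declaration is modeled as a plain dict.

```python
from types import SimpleNamespace

# Stub standing in for PrefetchTrainPipelineContext (attr names from the table above).
pipeline_ctx = SimpleNamespace(
    module_input_post_prefetch={},
    module_contexts_post_prefetch={},
)

# Hypothetical DeclaredIO-style declaration for EmbPrefetch: snapshot both
# attributes after the task runs, write them back on shortcut replay.
emb_prefetch_io = {
    "capture": lambda: (
        dict(pipeline_ctx.module_input_post_prefetch),
        dict(pipeline_ctx.module_contexts_post_prefetch),
    ),
    "restore": lambda snap: (
        pipeline_ctx.module_input_post_prefetch.update(snap[0]),
        pipeline_ctx.module_contexts_post_prefetch.update(snap[1]),
    ),
}

# Round trip: the task's side effect survives a clear + restore cycle.
pipeline_ctx.module_input_post_prefetch["emb"] = "kjt_after_prefetch"
snap = emb_prefetch_io["capture"]()
pipeline_ctx.module_input_post_prefetch.clear()   # consumed downstream
emb_prefetch_io["restore"](snap)
```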

TaskProfiler — 暴露时间测量TaskProfiler — Exposed Time Measurement

通过 Shortcut 缓存机制跳过单个 Task,测量对整体耗时的影响Skips individual Tasks via Shortcut caching to measure their impact on total time

📖 TaskProfiler 的完整 API 参考(profile()profile_many()ProfileResult.print_report())已迁移至 sw_pipeline_overview.html § Profiler 📖 Full TaskProfiler API reference (profile(), profile_many(), ProfileResult.print_report()) has moved to sw_pipeline_overview.html § Profiler.

SWPipeline 内置 Shortcut 机制:启用后首次执行正常运行并缓存 IterContext 输出,后续跳过 fn(ctx) 直接回放缓存值。TaskProfiler 利用此机制测量 exposed time

  • exposed(T) = baseline_serial − serial_with_T_shortcut
  • 结果驱动自动调度:确定哪些 Task 应分配到不同 stage 以最大化重叠

SWPipeline's built-in Shortcut mechanism: on first execution, runs normally and caches IterContext outputs; subsequent runs skip fn(ctx) and replay cached values. TaskProfiler leverages this for exposed time measurement:

  • exposed(T) = baseline_serial − serial_with_T_shortcut
  • Results drive auto-scheduling: determining which Tasks should be on different stages for max overlap
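The exposed-time formula above is a simple median difference, clamped at zero (a shortcut run can be noisy and come out slower than baseline). A minimal sketch, with hypothetical timing samples:

```python
import statistics


def exposed_time(baseline_runs, shortcut_runs):
    """exposed(T) = max(0, median(baseline serial) - median(serial with T shortcut))."""
    baseline = statistics.median(baseline_runs)
    return max(0.0, baseline - statistics.median(shortcut_runs))


# Hypothetical per-iteration wall-clock samples (seconds):
base = [0.100, 0.102, 0.101]
with_t_shortcut = [0.080, 0.079, 0.081]   # same pipeline with Task T skipped

exp = exposed_time(base, with_t_shortcut)  # time Task T exposes on the critical path
```

A large `exposed(T)` means Task T sits on the critical path and is a candidate for moving to its own stage/stream; a near-zero value means it is already hidden behind other work.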
_exec_task — Three Modes: Shortcut / Caching / Normal
  • Mode 1 (Shortcut Active): name ∈ shortcut_tasks AND cache exists → skip fn, replay cached values into ctx
  • Mode 2 (Caching): name ∈ shortcut_tasks BUT no cache yet → run fn → capture → store in cache
  • Mode 3 (Normal): name ∉ shortcut_tasks → run fn(ctx) directly, zero overhead
ProfileResult: baseline_s = median wall-clock per serial iteration; exposed_s = {task: max(0, baseline − median_shortcut)}

旧 Pipeline vs SWPipeline 对比Old Pipeline vs SWPipeline Comparison

维度Dimension | PrefetchTrainPipelineSparseDist | SWSerialTrainPipeline
设计范式Paradigm | 硬编码 4 阶段Hard-coded 4 stages | 声明式 Task DAG + PipelinePlanDeclarative Task DAG + PipelinePlan
调度方式Scheduling | 固定 stream 分配Fixed stream assignment | 可配置 stage/stream/thread_groupConfigurable stage/stream/thread_group
初始化Init | fill_pipeline() 预取 2-3 batch prefetch 2-3 batches | _ensure_initialized() 首 batch 内 within 1st batch
模型 RewriteModel Rewrite | PrefetchPipelinedForward | PrefetchPipelinedForward(相同)(same)
同步Sync | wait_stream + record_stream | 两阶段:CPU signal + CUDA eventTwo-phase: CPU signal + CUDA event
跨迭代依赖Cross-Iter | 隐式(stream 序)Implicit (stream order) | 显式 cross_iter_depsExplicit cross_iter_deps
Profiling | 无内置None built-in | TaskProfiler shortcut + exposed time
多线程Threading | 单线程Single thread | 每 thread_group 一个 workerOne worker per thread_group
全局序列化Global Order | N/A | _SubmissionSequencer 保证跨 rank NCCL 顺序 consistent NCCL order across ranks
注册名Name | "jagged_prefetch_sparse_dist" | "jagged_sw_serial"

Old: PrefetchTrainPipelineSparseDist

# Fixed 4-stage pipeline
def __init__(self, ...):
    self._memcpy_stream = torch.cuda.Stream()
    self._data_dist_stream = torch.cuda.Stream()
    self._prefetch_stream = torch.cuda.Stream()
    # Stages hard-coded in progress()

New: SWSerialTrainPipeline

# Declarative Task DAG
def _build_pipeline(self):
    t_h2d = PipelineTask("H2D", fn)
    ... 11 tasks ...
    plan = PipelinePlan(
        schedule={t: TaskSchedule(stage=0)},
        deps=[(t_dist, t_h2d), ...],
        pipeline_depth=1,  # serial
    )
    return SWPipeline(plan)
演进路线:SWSerialTrainPipeline 是 serial baseline。下一步创建 pipelined 变体——只改 PipelinePlan(H2D/InputDist 分到 stage 1 + 独立 stream/thread),Task 函数不改。TaskProfiler 的 exposed time 数据自动指导最优 stage 划分。 Evolution: SWSerialTrainPipeline is the serial baseline. The pipelined variant only changes the PipelinePlan (assign H2D/InputDist to stage 1 + separate stream/thread). Task functions stay unchanged. TaskProfiler exposed time guides optimal stage partitioning.
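The "only the plan changes" claim can be made concrete with a toy model. The dataclasses below are hypothetical minimal mirrors of `PipelineTask`/`TaskSchedule`/`PipelinePlan` (not the real sw_pipeline.py API); they show that moving from serial to pipelined touches only the schedule and depth, never the task functions.

```python
from dataclasses import dataclass


# Hypothetical minimal mirrors of TaskSchedule / PipelinePlan.
@dataclass(frozen=True)
class Sched:
    stage: int = 0
    stream: str = "default"


@dataclass
class Plan:
    schedule: dict       # task name -> Sched
    pipeline_depth: int = 1


tasks = ["H2D", "InputDist", "Prefetch", "FwdBwd"]

# Serial baseline: every task on stage 0, depth 1.
serial = Plan(schedule={t: Sched(stage=0) for t in tasks}, pipeline_depth=1)

# Pipelined variant: move data movement to stage 1 on its own streams;
# the task functions themselves are untouched.
pipelined = Plan(
    schedule={
        "H2D": Sched(stage=1, stream="copy"),
        "InputDist": Sched(stage=1, stream="a2a"),
        "Prefetch": Sched(stage=0),
        "FwdBwd": Sched(stage=0),
    },
    pipeline_depth=2,
)
```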

源码文件索引Source File Reference

文件File | 关键内容Key Contents
examples/commons/pipeline/train_pipeline.py | TrainPipelineSparseDist, PrefetchTrainPipelineSparseDist, JaggedMegatron* variants
examples/commons/pipeline/utils.py | TrainPipelineContext, PrefetchTrainPipelineContext, PipelinedForward, PrefetchPipelinedForward, _rewrite_model, _start_data_dist, _prefetch_embeddings, ArgInfo
examples/commons/pipeline/sw_pipeline.py | SWPipeline, PipelineTask, PipelinePlan, TaskSchedule, IterContext, TaskProfiler, _SubmissionSequencer
examples/commons/pipeline/sw_train_pipeline.py | SWSerialTrainPipeline (11 tasks, serial baseline)
examples/commons/pipeline/train_pipeline_factory.py | Pipeline factory — selects variant by config name
examples/hstu/test/test_pipeline.py | test_sw_serial_pipeline — numerical equivalence vs JaggedMegatronTrainNonePipeline
examples/hstu/training/trainer/training.py | pipeline_breakdown() — TaskProfiler integration
examples/commons/distributed/finalize_model_grads.py | DDP gradient AllReduce synchronization