NVIDIA recsys-examples · TrainPipelineSparseDist · PrefetchTrainPipelineSparseDist · SWPipeline
⭐ GitHub: NVIDIA/recsys-examples

recsys-examples provides two TorchRec training pipelines that use _rewrite_model on the first batch to replace ShardedModule.forward with pipelined versions, enabling input_dist (AllToAll communication) to overlap with the previous batch's forward / backward.
📖 Forward Split Mechanism: The pipeline's core is splitting ShardedModule.forward() into input_dist + compute_and_output_dist, so the AllToAll communication executes ahead of time. For details on the Context data flow and the torch.fx tracing mechanism → Forward Hijack Deep Dive
| Property | TrainPipelineSparseDist | PrefetchTrainPipelineSparseDist |
|---|---|---|
| Pipeline Stages | 3: ① H2D → ② input_dist → ③ fwd/bwd | 4: ① H2D → ② input_dist → ③ prefetch → ④ fwd/bwd |
| CUDA Streams | memcpy_stream, data_dist_stream | memcpy_stream, data_dist_stream, prefetch_stream, default_stream |
| Pipeline Depth | 2 batches | 3 batches (batch_i, batch_ip1, batch_ip2) |
| Context | TrainPipelineContext | PrefetchTrainPipelineContext |
| Forward Replacement | PipelinedForward | PrefetchPipelinedForward |
| Forward Data Source | input_dist_tensors_requests → request.wait() in fwd | module_input_post_prefetch → data already prefetched |
| Use Case | Standard FBGEMM static embedding | DynamicEmb (needs prefetch for cache lookup) |
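Both pipelines exploit the same core pattern: launch the next batch's input_dist before blocking on the current one. A minimal pure-Python sketch of that overlap, with a thread pool standing in for the communication stream (ShardedEmbedding, input_dist, and compute_and_output_dist below are illustrative stand-ins, not the real TorchRec API):

```python
from concurrent.futures import ThreadPoolExecutor

_pool = ThreadPoolExecutor(max_workers=1)  # stands in for the data_dist stream

class ShardedEmbedding:
    def input_dist(self, batch):
        # stands in for launching the AllToAll: returns an awaitable (Future)
        return _pool.submit(lambda: [x * 2 for x in batch])

    def compute_and_output_dist(self, data):
        # stands in for the embedding lookup + output AllToAll
        return sum(data)

mod = ShardedEmbedding()
batches = [[1, 2], [3, 4], [5, 6]]

# Pipelined loop: launch input_dist for batch i+1 BEFORE waiting on batch i,
# so "communication" overlaps with "compute"
pending = mod.input_dist(batches[0])
results = []
for i in range(len(batches)):
    nxt = mod.input_dist(batches[i + 1]) if i + 1 < len(batches) else None
    data = pending.result()  # the wait() inside the pipelined forward
    results.append(mod.compute_and_output_dist(data))
    pending = nxt
```

The real pipelines do the same thing with CUDA streams and TorchRec awaitables instead of a thread pool.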
_rewrite_model is the core of pipeline initialization — it replaces all ShardedModule.forward in the model with pipelined versions (PipelinedForward or PrefetchPipelinedForward). This runs only once when the first batch arrives, invoked by _pipeline_model().
Why two steps? Step 5 (start_sparse_data_dist) must call module.input_dist() once because ShardedModule's _input_dists attribute is lazily initialized. Only after this first invocation can Step 6's _override_input_dist_forwards find and replace the internal KJTAllToAll forward.
_rewrite_model has two paths: the FX tracing path (using torch.fx.Tracer for symbolic execution to derive argument mappings) and the hack path (mod_directly=True, scanning batch attributes to find the KJT field directly). Production code currently uses the hack path as it is simpler and more robust. The diagram below shows the full FX path logic and how the hack path bypasses FX tracing.
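A minimal sketch of the hack path's idea: scan the batch object's attributes to locate the jagged-tensor field, then swap the module's forward for a wrapper bound to that field. All names here (FakeKJT, Batch, rewrite_model, find_kjt_field) are illustrative stand-ins, not the real recsys-examples/TorchRec code:

```python
from dataclasses import dataclass

class FakeKJT:
    """Stand-in for TorchRec's KeyedJaggedTensor."""
    def __init__(self, values):
        self.values = values

@dataclass
class Batch:
    dense: list
    sparse: FakeKJT  # the field the hack path must discover

def find_kjt_field(batch):
    # the "hack path": scan batch attributes for the KJT input
    for name, val in vars(batch).items():
        if isinstance(val, FakeKJT):
            return name
    raise ValueError("no KJT field found")

class ShardedModule:
    def forward(self, kjt):
        return [v + 1 for v in kjt.values]

def rewrite_model(module, batch):
    """Swap module.forward for a pipelined wrapper bound to the KJT field."""
    field = find_kjt_field(batch)
    original = module.forward
    def pipelined_forward(b):
        # a real PipelinedForward would pop a pre-started awaitable here
        return original(getattr(b, field))
    module.forward = pipelined_forward
    return field

batch = Batch(dense=[0.5], sparse=FakeKJT([1, 2, 3]))
mod = ShardedModule()
kjt_field = rewrite_model(mod, batch)
```

The FX path derives the same binding by symbolically tracing the model instead of inspecting a concrete batch.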
TrainPipelineSparseDist fills 2 batches on the first progress() call, completing rewrite + input_dist for batch_i (AllToAll #1 initiate + wait → AllToAll #2/#3 initiate).
PrefetchTrainPipelineSparseDist fills 2 batches, but additionally completes wait_sparse + prefetch (DynamicEmb cache lookup) for batch_1, then starts input_dist for batch_2. One extra prefetch stage compared to the Base Pipeline.
Both Forward classes inherit from BaseForward. The key difference is where data comes from and when waiting occurs in __call__.
```python
def __call__(self, *input, **kwargs):
    # Fetch the awaitable from the context
    request = self._context \
        .input_dist_tensors_requests.pop(name)
    # ⚠ Wait for the AllToAll to finish inside forward
    with cuda.stream(self._stream):
        data = request.wait()  # blocking!
    # Synchronize streams
    current_stream.wait_stream(self._stream)
    data.record_stream(current_stream)
    return module.compute_and_output_dist(ctx, data)
```
```python
def __call__(self, *input, **kwargs):
    # ✅ Fetch already-prefetched data directly
    data = self._context \
        .module_input_post_prefetch.pop(name)
    ctx = self._context \
        .module_contexts_post_prefetch.pop(name)
    # Synchronize prefetch_stream → default
    current_stream.wait_stream(self._stream)
    data.record_stream(current_stream)
    return module.compute_and_output_dist(ctx, data)
```
```python
@dataclass
class TrainPipelineContext:
    # AllToAll #1 awaitable
    input_dist_splits_requests: Dict[str, Awaitable]
    # AllToAll #2+#3 awaitable
    input_dist_tensors_requests: Dict[str, Awaitable]
    # ShardedModule contexts
    module_contexts: Dict[str, Multistreamable]
    # Fused splits for batched AllToAll
    fused_splits_awaitables: List[...]
    # Postproc cache
    postproc_fwd_results: Dict[str, Any]
    events: List[torch.Event]
    index: Optional[int]
    version: int  # 0=legacy, 1=current
```
```python
@dataclass
class PrefetchTrainPipelineContext(
    TrainPipelineContext
):
    # ↓↓↓ NEW: prefetch outputs ↓↓↓
    # KJTList data after prefetch
    module_input_post_prefetch: Dict[str, Multistreamable]
    # module context after prefetch
    module_contexts_post_prefetch: Dict[str, Multistreamable]
    # (deprecated v0 fields)
    module_input_post_prefetch_next_batch: ...
    module_contexts_post_prefetch_next_batch: ...
```
Data flow: In the Base pipeline, input_dist_tensors_requests is popped and wait()ed in PipelinedForward.__call__. In the Prefetch pipeline, _prefetch_embeddings() first consumes input_dist_tensors_requests, runs prefetch, and stores the results in module_input_post_prefetch, which is then consumed by PrefetchPipelinedForward.__call__.
TrainPipelineContext bridges data between pipeline stages. Each batch owns an independent context instance; data flows between stages via context dictionaries. The diagram below shows how context fields are populated and consumed at each stage in both Base and Prefetch pipelines.
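A toy model of this per-batch handoff, assuming heavily simplified stand-ins (only the input_dist_tensors_requests field name mirrors the real context; Awaitable and the stage functions are invented for the sketch):

```python
class Awaitable:
    """Stand-in for a TorchRec awaitable: wait() returns the AllToAll result."""
    def __init__(self, value):
        self._value = value
    def wait(self):
        return self._value

class TrainPipelineContext:
    def __init__(self):
        self.input_dist_tensors_requests = {}  # module name -> awaitable
        self.module_contexts = {}

def start_data_dist(ctx, name, batch):
    # Stage 2: "launch" the AllToAll and park the awaitable in the context
    ctx.input_dist_tensors_requests[name] = Awaitable([x + 10 for x in batch])

def pipelined_forward(ctx, name):
    # Stage 3: forward pops the awaitable and blocks on it
    request = ctx.input_dist_tensors_requests.pop(name)
    return request.wait()

ctx = TrainPipelineContext()
start_data_dist(ctx, "emb", [1, 2, 3])
result = pipelined_forward(ctx, "emb")  # each entry is consumed exactly once
```

The pop-then-wait pattern is why each context field is written by exactly one stage and consumed by exactly one later stage.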
Using the JaggedMegatron* production variants as examples, the table below shows the steady-state execution steps of each progress() call.
| Step | JaggedMegatronTrainPipelineSparseDist | JaggedMegatronPrefetchTrainPipelineSparseDist |
|---|---|---|
| 1 | zero_grad() | zero_grad() |
| 2 | wait_for_batch(batch_i) | wait_prefetch_async(batch_i): CPU waits for the background-thread prefetch to complete |
| 3 | start_sparse_data_dist(batch_ip1) | copy_batch_to_gpu(batch_ip2) [H2D] |
| 4 | copy_batch_to_gpu(batch_ip2) [H2D] | wait_sparse_data_dist() |
| 5 | wait_sparse_data_dist(batch_ip1) | Shuffle Phase 1 (AllGather workloads) |
| 6 | Shuffle Phase 1 (AllGather + submit KK) | forward(batch_i) PrefetchPipelinedForward: zero wait |
| 7 | forward(batch_i) PipelinedForward: wait AllToAll #2/#3 | _start_prefetch_async(batch_ip1): background-thread async prefetch |
| 8 | Shuffle Phase 2 (finish_shuffle) | Shuffle Phase 2 (finish_shuffle) |
| 9 | loss AllReduce + backward + optimizer | loss AllReduce + backward + optimizer |
| 10 | dequeue_batch() | _start_sparse_data_dist(batch_ip2) |
| 11 | — | batch_i ← batch_ip1, batch_ip1 ← batch_ip2 |
SWPipeline is a new general-purpose software pipeline framework in recsys-examples that decomposes the hard-coded stages described above into a declarative Task DAG. It cleanly separates "what to compute" (PipelineTask) from "how to schedule" (PipelinePlan) — the same Tasks can run with different Plans to switch between serial / pipelined modes.
examples/commons/pipeline/sw_pipeline.py · examples/commons/pipeline/sw_train_pipeline.py
_build_pipeline() constructs 11 Tasks with dependencies. Serial mode (depth=1) executes all of them on stage 0 / the default stream / the main thread. DenseForward(6) and EmbBackward(9) are no-ops in serial mode, reserved for future pipelined variants (embedding/dense forward/backward split).
SWPipeline's multi-threaded execution introduces a classic race: if Task B calls stream.wait_event(event) before Task A has executed event.record(), the wait returns immediately (the event was never recorded) and GPU execution order breaks.

Two-phase synchronization avoids this:

- Upstream: record the CUDA event, then set a threading.Event (CPU signal) after completion.
- Downstream: first cpu_signal.wait() to ensure the CUDA event has been recorded, then stream.wait_event(cuda_event) to establish the correct GPU dependency.

Difference from old pipeline: the old pipeline explicitly prefetches 2-3 batches in fill_pipeline(). SWSerial (depth=1) needs no prefetch; initialization completes within the first batch via _ensure_initialized. A future pipelined variant (depth>1) will use SWPipeline.fill_pipeline() for multi-batch prefetch.
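The two-phase handshake can be demonstrated without a GPU. DummyCudaEvent below is an assumption standing in for torch.cuda.Event; only the ordering protocol is the point:

```python
import threading

class DummyCudaEvent:
    """Assumption: stands in for torch.cuda.Event to show the protocol."""
    def __init__(self):
        self.recorded = False
    def record(self):
        self.recorded = True
    def query_recorded(self):
        # stream.wait_event on an unrecorded event would return immediately,
        # silently breaking GPU ordering; here we just report the state
        return self.recorded

cuda_event = DummyCudaEvent()
cpu_signal = threading.Event()
observed = []

def task_a():
    # Upstream: record the CUDA event FIRST, then raise the CPU signal
    cuda_event.record()
    cpu_signal.set()

def task_b():
    # Downstream: CPU-wait first, so the CUDA event is guaranteed recorded
    cpu_signal.wait()
    observed.append(cuda_event.query_recorded())

tb = threading.Thread(target=task_b)
ta = threading.Thread(target=task_a)
tb.start()
ta.start()
ta.join()
tb.join()
```

Because task_b only inspects the event after cpu_signal.wait() returns, the record-before-wait ordering holds regardless of thread scheduling.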
Shortcut is one of SWPipeline's core infrastructure features. It allows skipping a Task's computation at runtime, replaying previously cached IterContext outputs instead. This is not a simple "delete"; it must:

- Replay all of the Task's side effects on IterContext (added/modified/deleted attributes)
- Still run _SubmissionSequencer ordering for globally_ordered Tasks, avoiding cross-rank NCCL deadlocks

Three-Phase Execution — _exec_task()

_shortcut_cache[name] is a 5-tuple:
| Field | Type | Meaning |
|---|---|---|
| produced | Dict[str, Any] | Attributes added or modified on IterContext by the Task. Tensors are stored as (detached_clone, requires_grad) tuples. |
| deleted | frozenset | Attribute names deleted from IterContext by the Task. |
| input_attr_names | tuple | Tensor attributes that existed before the Task, were unchanged after it, and have requires_grad=True. These are the "upstream inputs" _GraftGrad needs to maintain the gradient chain. |
| zero_grads | Tuple[Tensor, ...] | Pre-allocated zero tensors, one per upstream input, returned directly in _GraftGrad.backward to avoid a fresh allocation on every backward. |
| io_cached | List[Any] | Each DeclaredIO.capture() return value, recursively cached via _cache_val. During _apply_shortcut, the corresponding dio.restore(_restore_val(cached)) writes the external state back. Empty if the Task has no io. |
This is the most elegant part of Shortcut. When a Task is shortcut, its output Tensors are detach()ed leaf nodes restored from cache, disconnected from the upstream computation graph. If the downstream loss.backward() needs gradient flow through here, it breaks.

_GraftGrad is a custom torch.autograd.Function acting as a "gradient bridge": in backward it passes grad_output through to restored (so downstream gets normal gradients) and passes the pre-allocated zero tensors to all upstream inputs (triggering their backward chains, but since the grads are zero they do not affect parameter gradients).

This ensures that even with the Task skipped, upstream parameters still participate in the backward graph (no "unused parameter" errors) but contribute zero gradients, matching the semantics of "this Task didn't actually compute."
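A hedged sketch of such a gradient bridge, assuming PyTorch is available (the real _GraftGrad in sw_pipeline.py:317–342 may differ in signature and details; GraftGrad, w, and restored here are invented for the sketch):

```python
import torch

class GraftGrad(torch.autograd.Function):
    """Gradient-bridge sketch: forward returns the restored (cached) tensor;
    backward routes the real gradient to it and zeros to upstream inputs."""

    @staticmethod
    def forward(ctx, restored, zero_grads, *upstream_inputs):
        ctx.saved_zero_grads = zero_grads  # pre-allocated, one per upstream input
        return restored.view_as(restored)

    @staticmethod
    def backward(ctx, grad_output):
        # restored gets grad_output; the zero_grads tuple arg itself gets None;
        # each upstream input gets its pre-allocated zero tensor
        return (grad_output, None) + ctx.saved_zero_grads

w = torch.ones(3, requires_grad=True)          # upstream parameter
upstream = w * 2                               # activation the skipped Task consumed
restored = torch.ones(3, requires_grad=True)   # value replayed from the shortcut cache
zeros = (torch.zeros(3),)                      # pre-allocated zero grads

out = GraftGrad.apply(restored, zeros, upstream)
out.sum().backward()
```

After backward, restored.grad is all ones while w.grad is all zeros: w stayed in the backward graph (no "unused parameter" error) without contributing any gradient.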
| Condition | Shortcuttable? | Reason |
|---|---|---|
| Pure ctx compute Task (no external side effects) | ✅ | Outputs can be detach+clone cached; _GraftGrad maintains gradients |
| globally_ordered Task (with NCCL) | ✅ | _SubmissionSequencer still fires for ordering; only the fn() computation is skipped |
| Tasks with external side effects (pipeline_ctx, module state, etc.) | ✅ with DeclaredIO | The framework auto-captures/restores external state, so the Task can be shortcut independently |
| Non-replayable Tasks (e.g. checkpoint writes) | ❌ | Side effects are irreversible and cannot be replayed via capture/restore |
sw_pipeline.py:258–411 (Shortcut API + internals) ·
sw_pipeline.py:317–342 (_GraftGrad)
Shortcut diffs IterContext attributes to cache/restore a task's output. But many tasks have side effects invisible to the framework: they read/write shared state outside of ctx:

- PrefetchTrainPipelineContext (TorchRec pipeline shared state: fused_splits_awaitables, input_dist_tensors_requests, module_input_post_prefetch, etc.)
- Module-internal state (e.g. SplitPrefetchPipelinedForward._cached_awaitable)

Without handling these, shortcutting a task causes downstream tasks to read stale data from shared state, leading to AssertionError or "backward through graph a second time" errors.
```python
# Task B: writes shared state
def encode(ctx):
    ctx.embedding = encoder(ctx.features)
    shared.buffer["emb"] = ctx.embedding
    ctx._mirror_buffer = dict(shared.buffer)  # ← manually mirror into ctx

# Task C: reads shared state
def decode(ctx):
    shared.buffer = dict(ctx._mirror_buffer)  # ← manually restore
    ctx.output = decoder(shared.buffer["emb"])
```
Every task needs mirror code; downstream must know upstream's naming conventions; and the mirror chain is fragile when adding new tasks.
```python
# Task B: pure business logic
def encode(ctx):
    ctx.embedding = encoder(ctx.features)
    shared.buffer["emb"] = ctx.embedding  # normal write

# Task C: pure business logic
def decode(ctx):
    ctx.output = decoder(shared.buffer["emb"])  # normal read

# Side effect declared at the task definition
PipelineTask("encode", encode, io=[
    DeclaredIO(
        capture=lambda: dict(shared.buffer),
        restore=lambda s: shared.buffer.update(s),
    ),
])
```
DeclaredIO is a simple dataclass declaring the task's read/write contract with external state:
| Field | Type | Meaning |
|---|---|---|
| capture | Callable[[], Any] | Called after the Task runs. Snapshots external state and returns a cacheable value, which is passed through _cache_val (recursive tensor detach+clone). |
| restore | Callable[[Any], None] | Called during shortcut. Receives the value processed by _restore_val (fresh detached tensors) and writes it back to external state so downstream tasks find the correct data. |
| Concept | Responsibility | Who writes it |
|---|---|---|
| ctx attributes | Explicit data flow between Tasks | Framework (auto diff/replay) |
| DeclaredIO | Task's external side effects (dark channels) | User, at Task definition |
| capture() | side effect → cacheable snapshot | User-provided callback |
| restore() | cached snapshot → write back side effect | User-provided callback |
| _cache_val/_restore_val | Recursive tensor detach+clone (incl. nested structures) | Framework internal |
| shortcut decision | Which Tasks take the shortcut path | Profiler / User API |
Design philosophy: DeclaredIO is essentially a lightweight effect declaration system — instead of relying on a type system, it uses runtime capture/restore callbacks to let the framework manage side-effect lifecycles. Core principle: task functions handle only computation; side-effect lifecycle management is the framework's job.
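The principle can be shown end-to-end with a toy shared dict. DeclaredIO is re-declared locally as a plain dataclass, and shared_buffer / encode are invented for this sketch:

```python
from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class DeclaredIO:
    capture: Callable[[], Any]
    restore: Callable[[Any], None]

shared_buffer = {}  # state living OUTSIDE the per-iteration ctx

def encode(ctx):
    ctx["emb"] = [x * 2 for x in ctx["features"]]
    shared_buffer["emb"] = ctx["emb"]  # side effect the framework cannot see

io = DeclaredIO(
    capture=lambda: dict(shared_buffer),
    restore=lambda snap: shared_buffer.update(snap),
)

# First run: execute the task, then snapshot its side effect
ctx = {"features": [1, 2]}
encode(ctx)
snapshot = io.capture()

# Shortcut run: skip encode() entirely, restore the side effect instead
shared_buffer.clear()  # simulate a fresh iteration
io.restore(snapshot)   # downstream now finds the correct data
```

The task body stayed pure business logic; the framework (here, two lines of driver code) owned the capture/restore lifecycle.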
In the recommendation pipeline, Tasks 2–5 all have side effects (reading/writing PrefetchTrainPipelineContext and SplitPrefetchPipelinedForward internal state). With DeclaredIO, each task function contains only business logic:
| Task | Side Effect | DeclaredIO.capture | DeclaredIO.restore |
|---|---|---|---|
| EmbInputDistStart | pipeline_ctx.fused_splits_awaitables, pipeline_ctx.module_contexts_next_batch | Snapshot both (wrap the Awaitable in _IdempotentAwaitable) | Write back to pipeline_ctx |
| EmbInputDistWait | pipeline_ctx.input_dist_tensors_requests, pipeline_ctx.module_contexts | Snapshot both | Write back to pipeline_ctx |
| EmbPrefetch | pipeline_ctx.module_input_post_prefetch, pipeline_ctx.module_contexts_post_prefetch | Snapshot both | Write back to pipeline_ctx |
| MPEmbForward | fwd._cached_awaitable (module internal state) | fwd.get_embedding_result() | fwd.set_embedding_result(cached) |
📖 Full TaskProfiler API reference (profile(), profile_many(), ProfileResult.print_report()) has moved to sw_pipeline_overview.html § Profiler.
SWPipeline's built-in Shortcut mechanism works as follows: on first execution, the pipeline runs normally and caches the IterContext outputs; subsequent runs skip fn(ctx) and replay the cached values. TaskProfiler leverages this for exposed-time measurement:

exposed(T) = baseline_serial − serial_with_T_shortcut

For example, if a serial iteration takes 100 ms and takes 88 ms with Task T shortcut, then exposed(T) = 12 ms: T contributes 12 ms that is not hidden by overlap.
| Dimension | PrefetchTrainPipelineSparseDist | SWSerialTrainPipeline |
|---|---|---|
| Paradigm | Hard-coded 4 stages | Declarative Task DAG + PipelinePlan |
| Scheduling | Fixed stream assignment | Configurable stage/stream/thread_group |
| Init | fill_pipeline() prefetches 2-3 batches | _ensure_initialized() within the 1st batch |
| Model Rewrite | PrefetchPipelinedForward | PrefetchPipelinedForward (same) |
| Sync | wait_stream + record_stream | Two-phase: CPU signal + CUDA event |
| Cross-Iter Deps | Implicit (stream order) | Explicit cross_iter_deps |
| Profiling | None built-in | TaskProfiler (shortcut + exposed time) |
| Threading | Single thread | One worker per thread_group |
| Global Order | N/A | _SubmissionSequencer ensures consistent NCCL order across ranks |
| Registered Name | "jagged_prefetch_sparse_dist" | "jagged_sw_serial" |
```python
# Fixed 4-stage pipeline
def __init__(self, ...):
    self._memcpy_stream = cuda.Stream()
    self._data_dist_stream = cuda.Stream()
    self._prefetch_stream = cuda.Stream()
    # Stages hard-coded in progress()
```
```python
# Declarative Task DAG
def _build_pipeline(self):
    t_h2d = PipelineTask("H2D", fn)
    ...  # 11 tasks ...
    plan = PipelinePlan(
        schedule={t: TaskSchedule(stage=0)},
        deps=[(t_dist, t_h2d), ...],
        pipeline_depth=1,  # serial
    )
    return SWPipeline(plan)
```
| File | Key Contents |
|---|---|
examples/commons/pipeline/train_pipeline.py | TrainPipelineSparseDist, PrefetchTrainPipelineSparseDist, JaggedMegatron* variants |
examples/commons/pipeline/utils.py | TrainPipelineContext, PrefetchTrainPipelineContext, PipelinedForward, PrefetchPipelinedForward, _rewrite_model, _start_data_dist, _prefetch_embeddings, ArgInfo |
examples/commons/pipeline/sw_pipeline.py | SWPipeline, PipelineTask, PipelinePlan, TaskSchedule, IterContext, TaskProfiler, _SubmissionSequencer |
examples/commons/pipeline/sw_train_pipeline.py | SWSerialTrainPipeline (11 tasks, serial baseline) |
examples/commons/pipeline/train_pipeline_factory.py | Pipeline factory — selects variant by config name |
examples/hstu/test/test_pipeline.py | test_sw_serial_pipeline — numerical equivalence vs JaggedMegatronTrainNonePipeline |
examples/hstu/training/trainer/training.py | pipeline_breakdown() — TaskProfiler integration |
examples/commons/distributed/finalize_model_grads.py | DDP gradient AllReduce synchronization |