Pipeline Context — Class Hierarchy, TorchRec Types & Lifecycle

NVIDIA recsys-examples · TrainPipelineContext Deep Dive

Context Class Hierarchy

TrainPipelineContext, its derived classes, and the associated TorchRec types

Class Diagram — Inheritance + Composition
TrainPipelineContext (@dataclass · utils.py:89)
- Fields (split phase): input_dist_splits_requests: Dict[str, Awaitable]; fused_splits_awaitables: List[(..., FusedKJTListSplitsAwaitable)]
- Fields (tensor phase): input_dist_tensors_requests: Dict[str, Awaitable]; module_contexts: Dict[str, Multistreamable]
- Fields (metadata): events: List[torch.Event]; postproc_fwd_results: Dict[str, Any]; index: Optional[int]; module_contexts_next_batch: Dict (v0, deprecated)

PrefetchTrainPipelineContext (extends TrainPipelineContext · utils.py:130)
- Additional fields (post-prefetch): module_input_post_prefetch: Dict[str, Multistreamable]; module_contexts_post_prefetch: Dict[str, Multistreamable]
- module_input_post_prefetch_next_batch / module_contexts_post_prefetch_next_batch (v0, deprecated)

Awaitable[W] (torchrec.distributed.types)
- ABC + Generic[W]; wait() → W (calls _wait_impl); callbacks: List[Callable[[W], W]]

Multistreamable (torchrec.streamable)
- ABC; record_stream(stream)

KJTListSplitsAwaitable (torchrec.distributed.embedding_sharding)
- Awaitable[Awaitable[KJTList]]; wait() → KJTListAwaitable; Phase 1: splits AllToAll

FusedKJTListSplitsAwaitable
- Fuses multiple KJTListSplitsAwaitables; wait() → List[KJTListAwaitable]

KJTAllToAllTensorsAwaitable
- Awaitable[KeyedJaggedTensor]; wait() → KJT on the local rank

ShardedModule (torchrec.distributed.types)
- create_context() → ShrdCtx; input_dist(ctx, kjt) → Awaitable[Awaitable]; compute_and_output_dist(ctx, data)

BaseForward[TCtx]
- _context: TCtx (bound to TrainPipelineContext); _module: ShardedModule; _args: List[ArgInfo]; set_context(ctx) / get_context()
- PipelinedForward pops input_dist_tensors_requests
- PrefetchPipelinedForward pops the post_prefetch fields
- SplitPrefetchPipelinedForward: two-stage embedding + dense; holds a _context ref

TrainPipelineSparseDist
- batches: Deque[Optional[In]]; contexts: Deque[TrainPipelineContext]; _pipelined_modules: List[ShardedModule]
- contexts[0] → batch currently in forward/backward; contexts[1] → next batch (input_dist running)
- Each progress() calls dequeue_batch() → popleft + set_module_context; enqueueing a new batch calls _create_context(index) → append
- PrefetchPipeline uses _batch_i/_batch_ip1 instead (v0 legacy)

TorchRec's two-phase AllToAll model:
- Phase 1 (splits): input_dist() → KJTListSplitsAwaitable → .wait() → KJTAllToAllTensorsAwaitable
- Phase 2 (tensors): tensors_awaitable.wait() → KeyedJaggedTensor (local data ready)
- The Pipeline Context field names map directly onto these two phases: input_dist_splits_requests → input_dist_tensors_requests
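The two-phase chaining can be sketched with toy classes. This is a minimal, self-contained imitation of TorchRec's wait pattern (no real communication, and the class bodies here are simplified stand-ins, not the actual TorchRec implementations): a Phase 1 handle whose wait() returns the Phase 2 handle, with callbacks threaded through the result as in Awaitable.wait().

```python
from typing import Callable, Generic, List, TypeVar

W = TypeVar("W")

class Awaitable(Generic[W]):
    """Toy version of torchrec.distributed.types.Awaitable: wait() resolves
    the handle via _wait_impl(), then threads the result through callbacks."""
    def __init__(self) -> None:
        self.callbacks: List[Callable[[W], W]] = []

    def _wait_impl(self) -> W:
        raise NotImplementedError

    def wait(self) -> W:
        ret = self._wait_impl()
        for cb in self.callbacks:
            ret = cb(ret)
        return ret

class TensorsAwaitable(Awaitable[list]):
    """Phase 2 handle: resolves to the 'local' data (stand-in for a KJT)."""
    def __init__(self, data: list) -> None:
        super().__init__()
        self._data = data

    def _wait_impl(self) -> list:
        return self._data

class SplitsAwaitable(Awaitable[TensorsAwaitable]):
    """Phase 1 handle: wait() 'completes' the splits exchange and hands back
    the Phase 2 handle, mirroring KJTListSplitsAwaitable's behavior."""
    def __init__(self, data: list) -> None:
        super().__init__()
        self._data = data

    def _wait_impl(self) -> TensorsAwaitable:
        return TensorsAwaitable(self._data)

# Two-phase usage, in the same order the pipeline drives it:
splits_req = SplitsAwaitable([1, 2, 3])   # stored in input_dist_splits_requests
tensors_req = splits_req.wait()           # stored in input_dist_tensors_requests
local_kjt = tensors_req.wait()            # data ready on the local rank
```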

TorchRec Core Types Deep Dive

What are the TorchRec objects actually stored in the Context fields?

| Type | Source | Key Method | Meaning |
| --- | --- | --- | --- |
| Awaitable[W] | torchrec.distributed.types | wait() → W | ABC for async communication handles. Not Python asyncio: TorchRec's own "wait pattern", a uniform representation of the not-yet-ready result of AllToAll and similar collectives. wait() calls _wait_impl(), then runs each entry in callbacks in sequence. |
| Multistreamable | torchrec.streamable | record_stream(s) | Interface for objects that can be safely passed across CUDA streams. Semantically equivalent to torch.Tensor.record_stream: it declares that the object will be read on the given stream, so the CUDA caching allocator does not reclaim its memory too early. The contexts returned by ShardedModule.create_context() implement this interface. |
| KJTListSplitsAwaitable | torchrec.distributed.embedding_sharding | wait() → KJTListAwaitable | AllToAll Phase 1: holds multiple KJTSplitsAllToAllMeta entries (lists of splits tensors). wait() completes the splits communication, learning each rank's tensor sizes, and returns a KJTListAwaitable (the Phase 2 handle). Stored in input_dist_splits_requests. |
| FusedKJTListSplitsAwaitable | torchrec.distributed.embedding_sharding | wait() → List[KJTListAwaitable] | Fuses the splits AllToAll of multiple KJT groups into a single communication. Created by _fuse_input_dist_splits(), grouped by ProcessGroup. Stored in fused_splits_awaitables. |
| KJTAllToAllTensorsAwaitable | torchrec.distributed.dist_data | wait() → KeyedJaggedTensor | AllToAll Phase 2: with splits known, performs all_to_all_single on the values/lengths tensor components. wait() returns the re-assembled KeyedJaggedTensor for the local rank. This is what input_dist_tensors_requests stores. |
| ShrdCtx (Multistreamable) | torchrec.distributed.types | record_stream() | Return value of ShardedModule.create_context(). Holds sharding metadata (input_splits, output_splits, recat, etc.). Created by _start_data_dist, stored in module_contexts, and passed to compute_and_output_dist(ctx, data) in PipelinedForward.__call__. |
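The grouping behavior of _fuse_input_dist_splits() can be illustrated with a toy bucketing function. This is a hypothetical analogue (the function name, the (awaitable, pg) pair encoding, and the FQNs below are illustrative, not TorchRec's actual signatures): per-module splits awaitables are bucketed by their process group so each bucket can be issued as one fused AllToAll.

```python
from collections import defaultdict
from typing import Any, Dict, List, Tuple

def fuse_splits_by_group(
    splits_requests: Dict[str, Tuple[Any, str]],
) -> List[Tuple[List[str], List[Any]]]:
    """Toy analogue of _fuse_input_dist_splits(): bucket per-module splits
    awaitables by their communicator ('process group') name. Each value is a
    (awaitable, pg_name) pair; each output bucket is (fqns, awaitables)."""
    buckets: Dict[str, Tuple[List[str], List[Any]]] = defaultdict(lambda: ([], []))
    for fqn, (awaitable, pg) in splits_requests.items():
        names, awaitables = buckets[pg]
        names.append(fqn)
        awaitables.append(awaitable)
    return list(buckets.values())

# Hypothetical module FQNs and placeholder awaitables:
reqs = {
    "model.ebc": ("awt_a", "pg0"),
    "model.weighted_ebc": ("awt_b", "pg0"),
    "model.managed_collision": ("awt_c", "pg1"),
}
fused = fuse_splits_by_group(reqs)
# Modules sharing pg0 fuse into one bucket; pg1 stays separate.
```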

Field Producer / Consumer Mapping

| Field | Type | Writer (Producer) | Reader (Consumer) | Stream |
| --- | --- | --- | --- | --- |
| input_dist_splits_requests | Dict[str, Awaitable] | _start_data_dist() | _fuse_input_dist_splits() | data_dist |
| fused_splits_awaitables | List[(names, FusedAwaitable)] | _fuse_input_dist_splits() | wait_sparse_data_dist() | data_dist |
| input_dist_tensors_requests | Dict[str, Awaitable] | wait_sparse_data_dist() | PipelinedForward / _prefetch_embeddings | data_dist → default / prefetch |
| module_contexts | Dict[str, Multistreamable] | _start_data_dist() | PipelinedForward / _prefetch_embeddings | data_dist → default / prefetch |
| module_input_post_prefetch | Dict[str, Multistreamable] | _prefetch_embeddings() | PrefetchPipelinedForward | prefetch → default |
| module_contexts_post_prefetch | Dict[str, Multistreamable] | _prefetch_embeddings() | PrefetchPipelinedForward | prefetch → default |
Key insight: Context is a "progressively filled" dictionary container. Each pipeline stage writes specific fields and clears the previous stage's fields. The dict keys are ShardedModule FQNs (e.g. "model.sparse_arch.embedding_bag_collection"); the values are TorchRec Awaitables or already-resolved data.
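A stripped-down sketch of this fill-then-drain pattern, assuming only three of the real fields and using placeholder strings instead of real awaitables (the stage names in comments follow the pipeline functions described above):

```python
from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class MiniPipelineContext:
    """Stripped-down sketch of TrainPipelineContext: dict fields keyed by
    ShardedModule FQN, filled by one stage and popped by the next."""
    input_dist_splits_requests: Dict[str, Any] = field(default_factory=dict)
    input_dist_tensors_requests: Dict[str, Any] = field(default_factory=dict)
    module_contexts: Dict[str, Any] = field(default_factory=dict)
    index: int = 0

ctx = MiniPipelineContext(index=0)
fqn = "model.sparse_arch.embedding_bag_collection"

# Stage A (_start_data_dist) writes splits handle + sharding context:
ctx.input_dist_splits_requests[fqn] = "splits_awaitable"
ctx.module_contexts[fqn] = "sharding_ctx"

# Stage B (wait_sparse_data_dist) converts splits → tensors and clears:
ctx.input_dist_tensors_requests[fqn] = "tensors_awaitable"
ctx.input_dist_splits_requests.clear()

# Stage C (pipelined forward) pops — the context drains itself:
request = ctx.input_dist_tensors_requests.pop(fqn)
sharding_ctx = ctx.module_contexts.pop(fqn)
```

By dequeue time every dict is empty, which is exactly the state the lifecycle diagrams below describe.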

Context Lifecycle — Base Pipeline

Full lifecycle of a TrainPipelineContext inside TrainPipelineSparseDist, from creation to destruction

Base Pipeline Context Lifecycle (sequence)
Sequence participants: Pipeline, Context, and the four dict fields (splits_requests, fused_awaitables, tensors_requests, module_contexts).

PHASE 0: Creation — _create_context()
- new Context(index=N, version=1); all Dict fields start empty ({}).

PHASE 1: _start_data_dist(batch, ctx) — on data_dist_stream
- sm.create_context() → module_contexts[name] = ShrdCtx
- sm.input_dist(module_ctx, kjt) → input_dist_splits_requests[name] = KJTListSplitsAwaitable
- _fuse_input_dist_splits() → fused_splits_awaitables
- AllToAll Phase 1 communication is now in flight (async).

PHASE 2: wait_sparse_data_dist(ctx) — on data_dist_stream
- fused_awaitable.wait() → input_dist_tensors_requests[name] = KJTAllToAllTensorsAwaitable
- input_dist_splits_requests.clear(); fused_splits_awaitables.clear()

PHASE 3: PipelinedForward.__call__() — on default_stream (during model(batch))
- request = input_dist_tensors_requests.pop(name); data = request.wait() → KeyedJaggedTensor
- ctx = module_contexts.pop(name)
- sm.compute_and_output_dist(ctx, data) → embedding output

PHASE 4: dequeue_batch() — after backward + optimizer
- contexts.popleft() → Context is dequeued and awaits GC; all Dicts have already been pop-emptied.

Field-state timeline:
- input_dist_splits_requests: LIVE in Phase 1 → cleared in Phase 2
- fused_splits_awaitables: LIVE in Phase 1 → cleared in Phase 2
- input_dist_tensors_requests: LIVE Phases 2–3 (AllToAll Phase 2 running) → popped in Phase 3
- module_contexts: LIVE Phases 1–3 (holds the ShrdCtx from create_context) → popped in Phase 3
- Phase 4: GC
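The batch/context double-buffering that drives these phases can be sketched as a toy scheduler. This is a deliberately simplified model (the class name and the log strings are invented for illustration; the real TrainPipelineSparseDist.progress() interleaves the phases across CUDA streams): it only shows that contexts travel through the deque in lockstep with their batches and that input_dist for batch N+1 is kicked off before batch N finishes.

```python
from collections import deque

class TwoStagePipeline:
    """Toy of TrainPipelineSparseDist scheduling: contexts[0] is the batch in
    forward/backward, contexts[1] is the next batch whose input_dist runs ahead."""
    def __init__(self, loader):
        self._loader = iter(loader)
        self.batches = deque()
        self.contexts = deque()
        self._idx = 0
        self.log = []

    def _enqueue(self):
        batch = next(self._loader, None)
        if batch is not None:
            self.batches.append(batch)
            self.contexts.append({"index": self._idx})     # _create_context()
            self._idx += 1
            self.log.append(f"start_data_dist(batch={batch})")  # Phase 1

    def progress(self):
        if not self.batches:
            self._enqueue()                                # warm-up fill
        self._enqueue()                                    # stay one batch ahead
        if not self.batches:
            return None                                    # loader exhausted
        batch = self.batches.popleft()
        ctx = self.contexts.popleft()                      # dequeue_batch()
        self.log.append(f"wait_sparse_data_dist(ctx={ctx['index']})")  # Phase 2
        self.log.append(f"forward/backward(batch={batch})")            # Phase 3
        return batch

pipe = TwoStagePipeline([10, 11, 12])
out = [pipe.progress() for _ in range(3)]
```

Note how batch 11's "start_data_dist" appears in the log before batch 10's "forward/backward": that overlap is the whole point of the two-context window.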

Context Lifecycle — Prefetch Pipeline

PrefetchTrainPipelineContext's additional prefetch phase

The Prefetch Pipeline inserts Phase 2.5, _prefetch_embeddings(), between Phase 2 and Phase 3. This phase runs on prefetch_stream:

  1. Pop the awaitable from input_dist_tensors_requests and wait() on it (completing AllToAll Phase 2)
  2. Call ShardedModule.prefetch(ctx, dist_input) (DynamicEmb cache warm-up)
  3. Store the results in module_input_post_prefetch and module_contexts_post_prefetch

Phase 3 (PrefetchPipelinedForward.__call__) then simply pops from the post_prefetch fields, with zero wait.
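The three steps above can be sketched as a toy Phase 2.5. Everything here is a simplified stand-in (plain dicts for the context, a callable in place of Awaitable.wait(), a fake module instead of a real ShardedModule); it only demonstrates the drain-prefetch-repark data flow:

```python
from typing import Any, Dict

def prefetch_embeddings(ctx: Dict[str, Dict[str, Any]],
                        modules: Dict[str, Any]) -> None:
    """Toy of _prefetch_embeddings() (Phase 2.5): drain tensors_requests and
    module_contexts, run prefetch, and park the results in the post_prefetch
    fields so the forward pass never has to wait."""
    for fqn, module in modules.items():
        request = ctx["input_dist_tensors_requests"].pop(fqn)
        dist_input = request()                      # stands in for request.wait()
        module_ctx = ctx["module_contexts"].pop(fqn)
        module.prefetch(module_ctx, dist_input)     # e.g. embedding cache warm-up
        ctx["module_input_post_prefetch"][fqn] = dist_input
        ctx["module_contexts_post_prefetch"][fqn] = module_ctx

class FakeModule:
    """Records prefetch calls instead of touching any cache."""
    def __init__(self):
        self.prefetched = []
    def prefetch(self, module_ctx, dist_input):
        self.prefetched.append((module_ctx, dist_input))

ctx = {
    "input_dist_tensors_requests": {"ebc": lambda: "kjt_local"},
    "module_contexts": {"ebc": "shard_ctx"},
    "module_input_post_prefetch": {},
    "module_contexts_post_prefetch": {},
}
mod = FakeModule()
prefetch_embeddings(ctx, {"ebc": mod})
# The forward pass now just pops ctx["module_input_post_prefetch"]["ebc"].
```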

Prefetch Pipeline Context Field Timeline
Stages: _start_data_dist → wait_sparse_data_dist → _prefetch_embeddings → PrefetchPipelinedForward → dequeue.

- splits_requests / fused_awaitables: live during _start_data_dist; cleared by wait_sparse_data_dist
- tensors_requests: AllToAll Phase 2 in flight after wait_sparse_data_dist; popped by _prefetch_embeddings
- module_contexts: holds ShrdCtx (sharding metadata); popped by _prefetch_embeddings
- ★ Prefetch-only fields:
  - module_input_post_prefetch: filled once wait + prefetch complete; popped by PrefetchPipelinedForward
  - module_contexts_post_prefetch: the ShrdCtx is transferred here; popped by PrefetchPipelinedForward

Key difference: the Base Pipeline waits on tensors during Phase 3 (forward), incurring latency there; the Prefetch Pipeline has already waited in Phase 2.5, so Phase 3 has zero wait.
Summary: TrainPipelineContext is a stage-progressive dictionary container. Each pipeline stage "stores" data into it (async handles or resolved results), and the next stage "pops" and consumes. The Context is enqueued with its batch (in the contexts deque) and lives for exactly one batch, from H2D through forward, backward, and optimizer. By dequeue time all fields have been pop-emptied, and the Context awaits Python GC.