NVIDIA recsys-examples · TrainPipelineContext 深度解析
utils.py:89–141 (Context) · utils.py:467–636 (BaseForward hierarchy)
| Type | Source | Key Method | Meaning |
|---|---|---|---|
| Awaitable[W] | torchrec.distributed.types | wait() → W | ABC for async communication handles. NOT Python asyncio: this is TorchRec's own "wait pattern", a uniform representation of the not-yet-ready result of communication ops such as AllToAll. wait() calls _wait_impl(), then runs the callbacks in order. |
| Multistreamable | torchrec.streamable | record_stream(s) | Interface for objects that can be safely passed across CUDA streams. Semantically equivalent to torch.Tensor.record_stream: declares that the object will be read on the given stream, preventing the CUDA caching allocator from reclaiming its memory too early. ShardedModule.create_context() returns objects implementing this. |
| KJTListSplitsAwaitable | torchrec.distributed.embedding_sharding | wait() → KJTListAwaitable | AllToAll Phase 1: holds multiple KJTSplitsAllToAllMeta (lists of splits tensors). wait() completes the splits communication, yielding each rank's tensor sizes, and returns a KJTListAwaitable (the Phase 2 handle). Stored in input_dist_splits_requests. |
| FusedKJTListSplitsAwaitable | torchrec.distributed.embedding_sharding | wait() → List[KJTListAwaitable] | Fuses the splits AllToAll of multiple KJT groups into a single communication. Created by _fuse_input_dist_splits(), grouped by ProcessGroup. Stored in fused_splits_awaitables. |
| KJTAllToAllTensorsAwaitable | torchrec.distributed.dist_data | wait() → KeyedJaggedTensor | AllToAll Phase 2: with the splits known, performs all_to_all_single on the values/lengths tensor components. wait() returns this rank's reassembled KeyedJaggedTensor. This is what is stored in input_dist_tensors_requests. |
| ShrdCtx (Multistreamable) | torchrec.distributed.types | record_stream() | Return value of ShardedModule.create_context(). Contains sharding metadata (input_splits, output_splits, recat, etc.). Created by _start_data_dist, stored in module_contexts, and passed to compute_and_output_dist(ctx, data) in PipelinedForward.__call__. |
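The "wait pattern" in the first row can be sketched as a toy. This is a simplified illustration under stated assumptions, not the real torchrec.distributed.types code: no communication happens, and AlreadyDone is a hypothetical subclass invented for the example.

```python
# Toy sketch of TorchRec's "wait pattern" (simplified; not the real
# torchrec.distributed.types implementation, no actual communication).
from typing import Callable, Generic, List, TypeVar

W = TypeVar("W")

class Awaitable(Generic[W]):
    """Async handle: wait() resolves the result, then runs chained callbacks."""

    def __init__(self) -> None:
        # Callbacks are applied in order to the resolved value.
        self.callbacks: List[Callable[[W], W]] = []

    def _wait_impl(self) -> W:
        # Subclasses block on the underlying comm op here.
        raise NotImplementedError

    def wait(self) -> W:
        ret = self._wait_impl()      # e.g. block until AllToAll finishes
        for cb in self.callbacks:    # then run the callback chain
            ret = cb(ret)
        return ret

class AlreadyDone(Awaitable[int]):
    """Hypothetical handle whose result is available immediately."""

    def __init__(self, value: int) -> None:
        super().__init__()
        self._value = value

    def _wait_impl(self) -> int:
        return self._value

handle = AlreadyDone(21)
handle.callbacks.append(lambda v: v * 2)  # post-processing hook
result = handle.wait()  # → 42
```

The callback chain is what lets TorchRec attach post-processing (e.g. tensor reshaping) to a pending communication without forcing the caller to wait early.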
| Field | Type | Writer (Producer) | Reader (Consumer) | Stream |
|---|---|---|---|---|
| input_dist_splits_requests | Dict[str, Awaitable] | _start_data_dist() | _fuse_input_dist_splits() | data_dist |
| fused_splits_awaitables | List[(names, FusedAwaitable)] | _fuse_input_dist_splits() | wait_sparse_data_dist() | data_dist |
| input_dist_tensors_requests | Dict[str, Awaitable] | wait_sparse_data_dist() | PipelinedForward / _prefetch_embeddings | data_dist → default / prefetch |
| module_contexts | Dict[str, Multistreamable] | _start_data_dist() | PipelinedForward / _prefetch_embeddings | data_dist → default / prefetch |
| module_input_post_prefetch | Dict[str, Multistreamable] | _prefetch_embeddings() | PrefetchPipelinedForward | prefetch → default |
| module_contexts_post_prefetch | Dict[str, Multistreamable] | _prefetch_embeddings() | PrefetchPipelinedForward | prefetch → default |
Key insight: Context is a "progressively filled" dictionary container. Each pipeline stage writes specific fields and clears previous ones. Dict keys are ShardedModule FQNs (e.g. "model.sparse_arch.embedding_bag_collection"); values are TorchRec Awaitables or resolved data.
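The fill-then-pop discipline can be illustrated with a toy context. Field names mirror the real TrainPipelineContext, but this is a hypothetical sketch with string stand-ins, not the recsys-examples code.

```python
# Toy model of the "progressively filled" context: a producer stage writes
# under the module's FQN, the consumer stage pops, leaving the dict empty.
from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class ToyContext:
    # Only two of the real fields; values here are plain stand-ins.
    input_dist_tensors_requests: Dict[str, Any] = field(default_factory=dict)
    module_contexts: Dict[str, Any] = field(default_factory=dict)

fqn = "model.sparse_arch.embedding_bag_collection"  # key = ShardedModule FQN
ctx = ToyContext()

# Producer stage stores an async handle and its sharding context ...
ctx.input_dist_tensors_requests[fqn] = "pending-awaitable"
ctx.module_contexts[fqn] = "sharding-ctx"

# ... consumer stage pops both, so the fields are empty by dequeue time.
dist_req = ctx.input_dist_tensors_requests.pop(fqn)
module_ctx = ctx.module_contexts.pop(fqn)
assert not ctx.input_dist_tensors_requests and not ctx.module_contexts
```

Popping (rather than reading) is the point: it makes each value single-consumer and guarantees the context is drained when its batch leaves the pipeline.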
The Prefetch Pipeline inserts a Phase 2.5, _prefetch_embeddings(), between Phase 2 and Phase 3. This stage runs on prefetch_stream:

1. Pop the awaitable out of input_dist_tensors_requests and wait() on it (completes AllToAll Phase 2).
2. Call ShardedModule.prefetch(ctx, dist_input) (DynamicEmb cache warm-up).
3. Store the results into module_input_post_prefetch and module_contexts_post_prefetch.

As a result, Phase 3 (PrefetchPipelinedForward.__call__) only needs to pop data from the post_prefetch fields: zero wait.
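The three steps above can be sketched as a toy function. Assumptions: a fake awaitable whose wait() returns instantly, no CUDA streams, and the actual ShardedModule.prefetch call reduced to a comment; names other than the four dict fields are invented for the example.

```python
# Toy Phase 2.5: resolve each pending AllToAll and pre-stage its result so
# the forward pass never blocks. Not the real _prefetch_embeddings().
from typing import Any, Dict

class FakeAwaitable:
    """Stand-in for KJTAllToAllTensorsAwaitable; wait() is instantaneous."""
    def __init__(self, value: Any) -> None:
        self._value = value

    def wait(self) -> Any:
        return self._value

def prefetch_embeddings(
    input_dist_tensors_requests: Dict[str, FakeAwaitable],
    module_contexts: Dict[str, Any],
    module_input_post_prefetch: Dict[str, Any],
    module_contexts_post_prefetch: Dict[str, Any],
) -> None:
    for fqn in list(input_dist_tensors_requests):
        awaitable = input_dist_tensors_requests.pop(fqn)  # take ownership
        dist_input = awaitable.wait()   # AllToAll Phase 2 completes here
        sharding_ctx = module_contexts.pop(fqn)
        # Real code would call sharded_module.prefetch(ctx, dist_input) here
        # to warm the DynamicEmb cache before the forward pass.
        module_input_post_prefetch[fqn] = dist_input
        module_contexts_post_prefetch[fqn] = sharding_ctx

# Usage: after this runs, Phase 3 pops from the post_prefetch dicts only.
requests = {"ebc": FakeAwaitable("kjt")}
contexts = {"ebc": "ctx"}
post_input: Dict[str, Any] = {}
post_ctx: Dict[str, Any] = {}
prefetch_embeddings(requests, contexts, post_input, post_ctx)
```

Moving the wait() into this stage is what shifts the blocking cost onto prefetch_stream, leaving the default stream's forward call wait-free.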
Summary: TrainPipelineContext is a stage-progressive dictionary container. Each pipeline stage "stores" data into it (async handles or resolved results) and the next stage "pops" and consumes. The Context is enqueued with its batch (the contexts deque); its lifetime spans one batch from H2D through forward, backward, and optimizer. By dequeue time all fields have been pop-emptied, leaving the object to Python GC.
train_pipeline.py:304 (_create_context) · train_pipeline.py:249 (dequeue_batch) · train_pipeline.py:642 (progress)