NVIDIA recsys-examples · The core mechanism behind TorchRec pipeline overlapping
In TorchRec's distributed recommendation system, a ShardedModule (e.g., ShardedEBC) has a forward() that internally performs two distinct phases:

1. input_dist: redistributes the input KJT across ranks via AllToAll (communication-bound).
2. compute_and_output_dist: embedding lookup plus the output AllToAll (compute-bound).

Without this split, model(batch_i) runs the two phases sequentially, and the input AllToAll (on the order of milliseconds) blocks GPU compute.

Core idea: extract input_dist from batch_i's forward and run it asynchronously on a separate CUDA stream during batch_(i-1)'s forward/backward. The AllToAll latency is then fully hidden behind GPU compute.
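The overlap can be modeled without CUDA. Below is a minimal pure-Python sketch, under the stated assumption that `input_dist` and `compute` are placeholders for the real TorchRec calls: a worker thread (standing in for the dist CUDA stream) runs batch i+1's input_dist while the main loop computes batch i.

```python
from concurrent.futures import ThreadPoolExecutor, Future

# Placeholder for the communication stage (real code: ShardedModule.input_dist,
# an AllToAll issued on a separate CUDA stream).
def input_dist(batch):
    return [x * 2 for x in batch]  # pretend-redistributed KJT

# Placeholder for the compute stage (real code: compute_and_output_dist).
def compute(dist_batch):
    return sum(dist_batch)

def pipelined_train(batches):
    results = []
    with ThreadPoolExecutor(max_workers=1) as comm_stream:  # stands in for the dist stream
        # Kick off input_dist for the first batch.
        pending: Future = comm_stream.submit(input_dist, batches[0])
        for i in range(len(batches)):
            dist_batch = pending.result()  # analogous to awaitable.wait()
            # Launch batch i+1's input_dist BEFORE computing batch i,
            # so communication overlaps with compute.
            if i + 1 < len(batches):
                pending = comm_stream.submit(input_dist, batches[i + 1])
            results.append(compute(dist_batch))
    return results

print(pipelined_train([[1, 2], [3, 4]]))  # [6, 14]
```

The single-worker executor preserves the key property of a CUDA stream: operations on it are ordered among themselves but asynchronous with respect to the main (compute) flow.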
TrainPipelineContext is the key to the split: it is the shared state between the output side of _start_data_dist() and the input side of PipelinedForward.__call__().
| Field | Writer | Reader | Meaning |
|---|---|---|---|
| input_dist_splits_requests | _start_data_dist | _fuse_input_dist_splits | AllToAll splits-phase awaitable (first comm phase) |
| fused_splits_awaitables | _fuse_input_dist_splits | wait_sparse_data_dist | Fused splits awaitable (grouped by PG) |
| input_dist_tensors_requests | wait_sparse_data_dist | PipelinedForward / _prefetch_embeddings | AllToAll tensors-phase awaitable (second comm phase) |
| module_contexts | wait_sparse_data_dist | PipelinedForward / _prefetch_embeddings | ShardedModule context (sharding metadata) |
| module_input_post_prefetch | _prefetch_embeddings | PrefetchPipelinedForward | Post-prefetch KJT data (comm complete + prefetched) |
| module_contexts_post_prefetch | _prefetch_embeddings | PrefetchPipelinedForward | Post-prefetch module context |
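The table maps naturally onto a plain container. Below is a hypothetical sketch of the context's shape (field names follow the table; the `Any` types are simplified stand-ins for TorchRec's real Awaitable and module-context objects, and the class name is illustrative, not the actual TorchRec definition):

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class TrainPipelineContextSketch:
    # Written by _start_data_dist, read by _fuse_input_dist_splits.
    input_dist_splits_requests: Dict[str, Any] = field(default_factory=dict)
    # Written by _fuse_input_dist_splits, read by wait_sparse_data_dist.
    fused_splits_awaitables: List[Any] = field(default_factory=list)
    # Written by wait_sparse_data_dist, read by PipelinedForward / _prefetch_embeddings.
    input_dist_tensors_requests: Dict[str, Any] = field(default_factory=dict)
    module_contexts: Dict[str, Any] = field(default_factory=dict)
    # Written by _prefetch_embeddings, read by PrefetchPipelinedForward.
    module_input_post_prefetch: Dict[str, Any] = field(default_factory=dict)
    module_contexts_post_prefetch: Dict[str, Any] = field(default_factory=dict)

ctx = TrainPipelineContextSketch()
# Entries are keyed by the sharded module's FQN (illustrative key below).
ctx.input_dist_tensors_requests["sparse_arch.ebc"] = object()
```

Using default_factory (rather than shared mutable defaults) matters here: each pipeline context must own its dicts, since stages pop entries as they consume them.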
_rewrite_model has two code paths. The default path uses torch.fx.Tracer for symbolic tracing, analyzing how ShardedModules are called in the computation graph. The hack path (currently active) looks up KJT attributes directly on the batch, bypassing FX tracing.
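Under the stated assumption that the hack path simply scans the batch object for KJT-typed attributes, a toy version looks like this (class and function names here are illustrative, not the actual recsys-examples code):

```python
class KJTPlaceholder:
    """Stand-in for torchrec's KeyedJaggedTensor."""
    pass

def find_kjt_attrs(batch):
    # Hack path: skip FX tracing entirely; just scan the batch's attributes
    # for KJT-typed values, which become the input_dist arguments.
    return [name for name, value in vars(batch).items()
            if isinstance(value, KJTPlaceholder)]

class Batch:
    def __init__(self):
        self.dense_features = [1.0, 2.0]
        self.sparse_features = KJTPlaceholder()

print(find_kjt_attrs(Batch()))  # ['sparse_features']
```

The trade-off is visible even in the toy: the scan needs no graph analysis, but it only works when the KJT is a direct attribute of the batch, whereas FX tracing can follow arbitrary call structure.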
Why ArgInfo? In the pipeline, _start_data_dist() must call ShardedModule.input_dist() before model(batch). But input_dist needs its KJT arguments extracted from the batch, which the original model graph did via getattr chains. ArgInfo records this extraction path so that _build_args_kwargs(batch, forward.args) can extract the arguments independently of the model's forward.
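A toy reconstruction of the idea (field and function names are illustrative, not the exact TorchRec ArgInfo layout): record the attribute path from the batch to the KJT once, then replay it against any future batch outside the model's forward.

```python
from dataclasses import dataclass
from typing import Any, List

@dataclass
class ArgInfoSketch:
    # e.g. ["sparse_features"] records the chain `batch.sparse_features`
    input_attrs: List[str]

def build_args(batch: Any, arg_infos: List["ArgInfoSketch"]) -> list:
    """Replay each recorded getattr chain against a new batch, which is
    what lets input_dist be called without running the model's forward."""
    args = []
    for info in arg_infos:
        value = batch
        for attr in info.input_attrs:
            value = getattr(value, attr)
        args.append(value)
    return args

class Batch:
    def __init__(self):
        self.sparse_features = "KJT-placeholder"

infos = [ArgInfoSketch(input_attrs=["sparse_features"])]
print(build_args(Batch(), infos))  # ['KJT-placeholder']
```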
utils.py:1309–1458 · utils.py:678–712 (Tracer) · utils.py:1191–1231 (_get_node_args)
Both classes inherit from BaseForward and implement the forward replacement. Their __call__ is triggered during model(batch), but instead of calling input_dist they retrieve the precomputed results from the Context.
Key difference: PipelinedForward still needs request.wait() during forward to let the AllToAll complete. PrefetchPipelinedForward's wait has already happened in _prefetch_embeddings (on prefetch_stream), so its forward is truly zero-wait.
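A condensed sketch of the difference (class bodies are heavily simplified and the names carry a "Sketch" suffix to flag that they are not the real implementations):

```python
class FakeAwaitable:
    """Stand-in for TorchRec's Awaitable: wait() blocks until comm finishes."""
    def __init__(self, value):
        self._value = value
    def wait(self):
        return self._value

class Ctx:
    """Minimal stand-in for TrainPipelineContext."""
    def __init__(self):
        self.input_dist_tensors_requests = {}
        self.module_input_post_prefetch = {}

class PipelinedForwardSketch:
    # Runs inside model(batch): no input_dist here; it pulls the in-flight
    # tensors-phase awaitable from the shared context and STILL waits on it.
    def __init__(self, name, context):
        self._name, self._context = name, context
    def __call__(self):
        request = self._context.input_dist_tensors_requests.pop(self._name)
        data = request.wait()  # blocks if the AllToAll has not finished
        return f"compute_and_output_dist({data})"

class PrefetchPipelinedForwardSketch:
    # wait() already happened in _prefetch_embeddings on prefetch_stream,
    # so __call__ just reads the materialized input: zero-wait forward.
    def __init__(self, name, context):
        self._name, self._context = name, context
    def __call__(self):
        data = self._context.module_input_post_prefetch.pop(self._name)
        return f"compute_and_output_dist({data})"

ctx = Ctx()
ctx.input_dist_tensors_requests["ebc"] = FakeAwaitable("kjt")
print(PipelinedForwardSketch("ebc", ctx)())  # compute_and_output_dist(kjt)
```

Note that both variants pop their entry from the context: each slot is consumed exactly once per batch, which keeps the per-iteration state from leaking across batches.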
utils.py:503–537 (PipelinedForward) · utils.py:540–584 (PrefetchPipelinedForward)
For embedding tables that support prefetch() (e.g., DynamicEmb), the pipeline adds a third split: the AllToAll wait plus cache warm-up is also separated out of forward and moved onto a dedicated prefetch_stream:

- module.prefetch() runs on prefetch_stream
- compute_and_output_dist runs on default_stream
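The prefetch stage can be sketched as draining the tensors-phase awaitables into the post-prefetch slots (a pure-Python approximation with dicts; in the real code the wait and the cache warm-up via module.prefetch() run on prefetch_stream):

```python
class FakeAwaitable:
    """Stand-in for the tensors-phase AllToAll awaitable."""
    def __init__(self, value):
        self._value = value
    def wait(self):  # in the real code this runs on prefetch_stream
        return self._value

def prefetch_embeddings(context):
    """Sketch of the _prefetch_embeddings role: complete the AllToAll wait
    (and, for tables that support it, the cache warm-up) off the default
    stream, then stash the ready inputs so forward can be zero-wait."""
    for name, request in list(context["input_dist_tensors_requests"].items()):
        kjt = request.wait()  # AllToAll tensors phase completes here
        # real code would also call sharded_module.prefetch(kjt) here
        context["module_input_post_prefetch"][name] = kjt
        del context["input_dist_tensors_requests"][name]

ctx = {"input_dist_tensors_requests": {"ebc": FakeAwaitable("kjt")},
       "module_input_post_prefetch": {}}
prefetch_embeddings(ctx)
print(ctx["module_input_post_prefetch"])  # {'ebc': 'kjt'}
```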
Summary: The pipeline's core is a "forward hijack surgery" — _rewrite_model records the batch-to-ShardedModule argument extraction path (ArgInfo), then replaces forward() with a PipelinedForward that only does compute_and_output_dist(). input_dist() is moved to execute asynchronously during the previous batch's time window, with TrainPipelineContext bridging intermediate results between the two stages.