Forward Hijack — Splitting ShardedModule.forward()

NVIDIA recsys-examples · The core mechanism behind TorchRec pipeline overlapping

The Problem — Why Split forward()?

In TorchRec's distributed recommendation system, a ShardedModule (e.g., ShardedEBC) has a forward() that internally performs two distinct phases:

  1. input_dist(ctx, kjt) — Executes AllToAll communication, distributing KJT (KeyedJaggedTensor) to ranks per the sharding plan
  2. compute_and_output_dist(ctx, data) — Performs local embedding lookup, then AllToAll redistributes results

Without splitting, model(batch_i) runs these sequentially, and AllToAll (~milliseconds) blocks GPU compute.

Core idea: Extract input_dist from batch_i's forward, move it to execute during batch_(i-1)'s forward/backward on a separate CUDA stream. This fully hides AllToAll latency behind GPU compute.

ShardedModule.forward() Internals
ShardedModule.forward(kjt) — original, unsplit:
  ① input_dist(ctx, kjt)                AllToAll — NETWORK BOUND
  ② compute_and_output_dist(ctx, data)  embedding lookup — GPU BOUND

↓ The pipeline splits this into two stages on different streams ↓

  ① _start_data_dist(batch_i+1)  on data_dist_stream — runs during batch_i fwd/bwd
  ② PipelinedForward(batch_i)    on default_stream   — compute_and_output_dist only

Context: batch_i+1 is the next batch (pipelined ahead); batch_i is the current batch (forward pass).

The Split — 4 Key Steps

Complete Forward-Split Flow: _rewrite_model → _start_data_dist → PipelinedForward
STEP 1 — One-time setup (_rewrite_model, called during fill_pipeline):
  • named_modules() → find each ShardedModule
  • inspect the batch → build ArgInfo(input_attrs, is_getitems)
  • module.forward = PipelinedForward(name, args, module, ctx, stream)
    forward is now a PipelinedForward instance
  • ArgInfo stores "how to extract the KJT from the batch":
    input_attrs=["", "sparse_features"] → batch.sparse_features
    This lets _start_data_dist call input_dist() without going through model()

STEP 2 — Each iteration: _start_data_dist(batch_i+1) [on data_dist_stream]:
  • args = _build_args_kwargs(batch_i+1, forward.args)  — reads ArgInfo
  • module.input_dist(ctx, *args) → Awaitable
  • context.input_dist_tensors_requests[name] = awaitable

STEP 3 — model(batch_i) → PipelinedForward.__call__() [on default_stream]:
  • request = context.input_dist_tensors_requests.pop(name)
  • data = request.wait()  [wait on data_dist_stream]
  • module.compute_and_output_dist(ctx, data)

RESULT — model() gets the embedding output as if forward() had been called normally:
  • The model graph is unchanged — only ShardedModule.forward is monkey-patched
  • All non-ShardedModule layers (dense layers, MLP, etc.) run normally on default_stream
  • batch_i+1's input_dist was already running during batch_i's compute → AllToAll hidden!

Context — The Bridge Between Split Halves

TrainPipelineContext is the key to the split: it's the shared state between the output of _start_data_dist() and the input of PipelinedForward.__call__().

Context Data Flow — TrainPipelineContext vs PrefetchTrainPipelineContext
TrainPipelineContext — base pipeline:
  Fields: input_dist_splits_requests: Dict[str, Awaitable]
          input_dist_tensors_requests: Dict[str, Awaitable]
          module_contexts: Dict[str, Multistreamable]
          fused_splits_awaitables: List[(...)]
          module_contexts_next_batch: Dict[...]
  Flow:   _start_data_dist() (on data_dist_stream) writes splits + ctx
          → wait_sparse_data_dist() fuses + waits → tensors_requests
          → PipelinedForward (on default_stream) pops tensors_request + module_ctx

PrefetchTrainPipelineContext — adds a prefetch layer (extends TrainPipelineContext):
  Extra fields: module_input_post_prefetch, module_contexts_post_prefetch
                (inherits all base fields above)
  Flow:   _prefetch_embeddings() (on prefetch_stream) pops input_dist_tensors_requests
          → wait() + prefetch() → stores into the post_prefetch fields
          → PrefetchPipelinedForward pops from post_prefetch → ZERO WAIT!

Key: the base pipeline waits for AllToAll inside forward(); the prefetch pipeline
waits in a background stage, so forward has zero AllToAll latency.
Field                           Writer                    Reader                                    Meaning
input_dist_splits_requests      _start_data_dist          _fuse_input_dist_splits                   AllToAll splits-phase awaitable (first comm phase)
fused_splits_awaitables         _fuse_input_dist_splits   wait_sparse_data_dist                     Fused splits awaitables (grouped by PG)
input_dist_tensors_requests     wait_sparse_data_dist     PipelinedForward / _prefetch_embeddings   AllToAll tensors-phase awaitable (second comm phase)
module_contexts                 wait_sparse_data_dist     PipelinedForward / _prefetch_embeddings   ShardedModule context (sharding metadata)
module_input_post_prefetch      _prefetch_embeddings      PrefetchPipelinedForward                  Post-prefetch KJT data (comm complete + prefetched)
module_contexts_post_prefetch   _prefetch_embeddings      PrefetchPipelinedForward                  Post-prefetch module context
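The writer/reader contract above maps onto a context object roughly like the following sketch. The field names follow TorchRec's context classes, but the Awaitable/Multistreamable value types are replaced with Any to keep the sketch torch-free and self-contained.

```python
# Hedged sketch of the two context classes (illustrative, torch-free).
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple

@dataclass
class TrainPipelineContext:
    # written by _start_data_dist, read by _fuse_input_dist_splits
    input_dist_splits_requests: Dict[str, Any] = field(default_factory=dict)
    # written by _fuse_input_dist_splits, read by wait_sparse_data_dist
    fused_splits_awaitables: List[Tuple[List[str], Any]] = field(default_factory=list)
    # written by wait_sparse_data_dist, read by PipelinedForward
    input_dist_tensors_requests: Dict[str, Any] = field(default_factory=dict)
    module_contexts: Dict[str, Any] = field(default_factory=dict)

@dataclass
class PrefetchTrainPipelineContext(TrainPipelineContext):
    # written by _prefetch_embeddings, read by PrefetchPipelinedForward
    module_input_post_prefetch: Dict[str, Any] = field(default_factory=dict)
    module_contexts_post_prefetch: Dict[str, Any] = field(default_factory=dict)

ctx = PrefetchTrainPipelineContext()
ctx.module_input_post_prefetch["ebc"] = "kjt"   # what _prefetch_embeddings would store
print(isinstance(ctx, TrainPipelineContext))    # True
```

The inheritance mirrors the real relationship: the prefetch context keeps every base field (the first two pipeline stages are unchanged) and only adds the post-prefetch handoff slots.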

_rewrite_model — Two Paths for Forward Replacement

_rewrite_model has two code paths. The default path uses torch.fx.Tracer for symbolic tracing, analyzing how ShardedModules are called in the computation graph. The hack path (currently active) directly finds KJT attributes in the batch, bypassing FX tracing.

Path A: mod_with_trace() — torch.fx Symbolic Trace

# 1. Determine trace depth — only trace up to ShardedModule
leaf_modules = _get_leaf_module_names(model)

# 2. Custom Tracer: ShardedModule/FSDP are leaf nodes
tracer = Tracer(leaf_modules=leaf_modules)
graph = tracer.trace(model, concrete_args)

# 3. For each call_module node that targets a ShardedModule:
for node in graph.nodes:
    if node.op == "call_module" and node.target in sharded_modules:
        # 4. Extract ArgInfo by walking the FX graph backward
        # ArgInfo captures: batch.attr1.attr2 → input_attrs=["attr1","attr2"]
        arg_info_list = _get_node_args(model, node, ...)

        # 5. Replace forward
        child.forward = PipelinedForward(
            node.target, arg_info_list, child, context, dist_stream
        )

Path B: mod_directly() — Direct KJT Lookup (Active)

# 1. Scan batch.__dict__ for a KeyedJaggedTensor attribute
for attr_name, attr_value in batch.__dict__.items():
    if isinstance(attr_value, KeyedJaggedTensor):
        kjt_name = attr_name  # e.g. "sparse_features"

# 2. For each ShardedModule, build ArgInfo directly
for n, module in sharded_modules.items():
    module.forward = PipelinedForward(
        n,
        [ArgInfo(
            input_attrs=["", kjt_name],  # batch → batch.sparse_features
            is_getitems=[False, False],
            name=None,
            postproc_modules=[None, None],
            constants=[None, None],
        )],
        module,
        context,
        dist_stream,
    )
Why ArgInfo? In the pipeline, _start_data_dist() must call ShardedModule.input_dist() before model(batch) runs. But input_dist needs its KJT arguments extracted from the batch — in the original model graph this is done by getattr chains. ArgInfo records that extraction path, so _build_args_kwargs(batch, forward.args) can extract the arguments independently of the model's forward.
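A hedged sketch of that extraction walk: ArgInfo's input_attrs/is_getitems fields follow the text above, while extract_arg is a hypothetical stand-in for TorchRec's _build_args_kwargs.

```python
# Sketch of replaying an ArgInfo path against a batch (illustrative).
from dataclasses import dataclass
from typing import List

@dataclass
class ArgInfo:
    input_attrs: List[str]   # "" means "the batch object itself"
    is_getitems: List[bool]  # True -> batch[attr], False -> batch.attr

def extract_arg(batch, info: ArgInfo):
    """Hypothetical stand-in for _build_args_kwargs on a single arg."""
    value = batch
    for attr, is_getitem in zip(info.input_attrs, info.is_getitems):
        if attr == "":
            continue                      # stay on the batch object
        value = value[attr] if is_getitem else getattr(value, attr)
    return value

class Batch:
    def __init__(self, kjt):
        self.sparse_features = kjt

info = ArgInfo(input_attrs=["", "sparse_features"], is_getitems=[False, False])
kjt = extract_arg(Batch("kjt_payload"), info)
print(kjt)  # kjt_payload
```

Because the path is data, not code, _start_data_dist can pull the KJT out of any future batch without ever invoking the model's forward.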
torch.fx Tracing Principle — mod_with_trace(), symbolic execution
mod_with_trace() walkthrough:
  1. Unwrap the model: DMP → DDP → Float16 → Module
  2. _get_leaf_module_names(model) — shallow traversal: find modules adjacent to
     ShardedModule and mark those siblings as leaves. Result: trace INTO the parent,
     NOT into the leaves.
  3. Tracer.trace(model, concrete_args) — symbolic execution, no real data:
     • proxy_buffer_attributes = False
     • ShardedModule → is_leaf = True → recorded as a call_module node
     Returns an FX Graph.
  4. for node in graph.nodes: match node.op == "call_module" with
     node.target in sharded_modules
  5. _get_node_args(node) — walk the graph backward from node.args:
     • "placeholder"                → the batch itself
     • "call_function getattr"      → attribute name
     • "call_function __getitem__"  → index
     → ArgInfo(["", "sparse_features"]), meaning batch.sparse_features
  6. module.forward = PipelinedForward(name, arg_info_list, module, ctx, stream)

Result: ShardedModule.forward no longer calls input_dist(). It is now a
PipelinedForward that picks up pre-computed results from TrainPipelineContext.
The original forward is saved in _original_forwards for restore.
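The backward ArgInfo walk that _get_node_args performs can be sketched without torch. Node here is a minimal stand-in for torch.fx.Node (the real getattr target is the builtin function, not a string), and only the getattr-chain case is handled.

```python
# Torch-free sketch of walking an FX-like graph backward (illustrative).

class Node:
    """Minimal stand-in for torch.fx.Node."""
    def __init__(self, op, target=None, args=()):
        self.op, self.target, self.args = op, target, args

def walk_getattr_chain(node):
    """Return the attribute path from the batch placeholder down to `node`."""
    attrs = []
    while node.op == "call_function" and node.target == "getattr":
        parent, attr_name = node.args
        attrs.append(attr_name)
        node = parent                 # step backward toward the input
    assert node.op == "placeholder"   # reached the batch input
    attrs.append("")                  # "" marks "the batch itself"
    return list(reversed(attrs))

# batch.sparse_features as two graph nodes:
batch = Node("placeholder")
sparse = Node("call_function", "getattr", (batch, "sparse_features"))
print(walk_getattr_chain(sparse))  # ['', 'sparse_features']
```

The recovered list is exactly the input_attrs that Path B writes by hand, which is why both paths can feed the same PipelinedForward constructor.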

PipelinedForward vs PrefetchPipelinedForward

Both classes inherit from BaseForward and implement the forward replacement. Their __call__ is triggered during model(batch) — but instead of calling input_dist, they retrieve pre-computed results from the Context.

PipelinedForward.__call__()

# Called during model(batch_i) on default_stream

# 1. Pop the awaitable from context (written by _start_data_dist)
request = self._context.input_dist_tensors_requests.pop(self._name)

# 2. Wait for AllToAll to complete (on data_dist_stream)
with cuda.stream(self._stream):   # data_dist_stream
    data = request.wait()         # ← BLOCKING WAIT HERE

# 3. Transfer data to default_stream
current_stream.wait_stream(self._stream)
data.record_stream(current_stream)

# 4. Only do embedding compute (input_dist already done!)
return self._module.compute_and_output_dist(ctx, data)

PrefetchPipelinedForward.__call__()

# Called during model(batch_i) on default_stream

# 1. Pop ALREADY-PREFETCHED data (written by _prefetch_embeddings)
data = self._context.module_input_post_prefetch.pop(self._name)
ctx = self._context.module_contexts_post_prefetch.pop(self._name)

# 2. Transfer from prefetch_stream to default_stream
current_stream.wait_stream(self._stream)
data.record_stream(current_stream)

# 3. Only do embedding compute (NO WAIT AT ALL!)
return self._module.compute_and_output_dist(ctx, data)
# ↑ The wait() already happened in _prefetch_embeddings,
#   on a separate prefetch_stream, during the previous step
Key difference: PipelinedForward must still call request.wait() during forward to wait for the AllToAll to complete. PrefetchPipelinedForward's wait has already happened in _prefetch_embeddings (on prefetch_stream), so its forward is truly zero-wait.

Prefetch Layer — The Third Split

For embedding tables supporting prefetch() (e.g., DynamicEmb), the pipeline adds a third split: separating AllToAll wait + cache warm-up from forward into its own prefetch_stream:

  • _start_data_dist — starts AllToAll (data_dist_stream)
  • _prefetch_embeddings — waits for AllToAll + calls module.prefetch() (prefetch_stream)
  • PrefetchPipelinedForward — directly uses prefetched results for compute_and_output_dist (default_stream)
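The three-stage handoff above can be sketched torch-free. Stream behavior is modeled as plain function calls, so this shows the ordering and the context handoff only, not real concurrency; the function names mirror the real stages, the bodies are stand-ins.

```python
# Illustrative three-stage steady state: while batch i is in forward,
# batch i+1 is being prefetched and batch i+2 is in data-dist.

context = {"dist": {}, "post_prefetch": {}}

def start_data_dist(batch):                  # stage 1 (data_dist_stream)
    context["dist"][batch] = f"alltoall({batch})"

def prefetch_embeddings(batch):              # stage 2 (prefetch_stream)
    request = context["dist"].pop(batch)     # wait() would happen here
    context["post_prefetch"][batch] = f"prefetched({request})"

def forward(batch):                          # stage 3 (default_stream)
    data = context["post_prefetch"].pop(batch)   # zero wait in forward
    return f"embeddings({data})"

batches = [f"b{i}" for i in range(4)]

# Warm-up: fill the first two stages before the first forward.
start_data_dist(batches[0])
prefetch_embeddings(batches[0])
start_data_dist(batches[1])

outputs = []
for i in range(len(batches)):
    outputs.append(forward(batches[i]))          # batch i
    if i + 1 < len(batches):
        prefetch_embeddings(batches[i + 1])      # batch i+1
    if i + 2 < len(batches):
        start_data_dist(batches[i + 2])          # batch i+2
print(outputs[0])  # embeddings(prefetched(alltoall(b0)))
```

With real CUDA streams, the three calls inside each loop iteration run concurrently on different streams; the dict handoff plays the role of TrainPipelineContext.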
Three-Level Split — Full Decomposition of ShardedModule.forward()
ORIGINAL (no pipeline):
  ShardedModule.forward(kjt) = input_dist + compute_and_output_dist

BASE PIPELINE (2-stage split):
  _start_data_dist → input_dist       on data_dist_stream (batch_i+1)
    ↓ Context
  PipelinedForward → wait + compute   on default_stream (batch_i) — still waits for AllToAll

PREFETCH PIPELINE (3-stage split):
  _start_data_dist                    on data_dist_stream
    ↓ Context
  _prefetch_embeddings                on prefetch_stream — wait + prefetch()
    ↓ Context
  PrefetchPipelinedForward            on default_stream — ZERO WAIT!

Each stage runs on a different CUDA stream, so all three stages — each working on a
different batch — overlap in time.

Timeline — Execution After Split

Prefetch Pipeline Steady-State Timeline (3 batches in-flight)
time →
  memcpy_stream:     H2D copy of batch_i+2
  data_dist_stream:  input_dist batch_i+1 (AllToAll), then input_dist batch_i+2
  prefetch_stream:   prefetch batch_i, then prefetch batch_i+1
  default_stream:    fwd batch_i, then bwd + optim batch_i
  CPU thread:        wait sparse data dist (i+2), fwd, prefetch

All 4 streams are active simultaneously.
Summary: The pipeline's core is a "forward hijack surgery" — _rewrite_model records the batch-to-ShardedModule argument-extraction path (ArgInfo), then replaces forward() with a PipelinedForward that only runs compute_and_output_dist(). input_dist() is moved to execute asynchronously during the previous batch's time window, with TrainPipelineContext bridging the intermediate results between the two stages.