DynamicEmb Training Pipeline — Sequence Diagram Analysis

NVIDIA recsys-examples · DynamicEmb Embedding Layer · RW Sharding

⭐ GitHub: shijieliu/recsys-examples (fea-dynamicemb_table_fusion)

DynamicEmb Training Iteration Overview

DynamicEmb is a dynamic embedding layer in NVIDIA recsys-examples, featuring GPU-accelerated hash tables, LRU/LFU caching, and fused optimizers.

During one iteration of distributed training, DynamicEmb's embedding layer goes through the following five phases. The first two handle cross-rank communication of input keys (before DynamicEmb's forward pass), and the latter three are DynamicEmb's own prefetch → forward → backward pipeline.

📌 Scope of This Document

This document covers the complete flow of the DynamicEmb Embedding layer during one training iteration, spanning all 6 phases of TorchRec's TrainPipelineSparseDist:

  • input_dist / wait_sparse (Phase 1-2): Detailed internal sequence diagrams covering every step of dedup, bucketize, AllToAll initiation and awaiting
  • prefetch (Phase 3): DynamicEmb's key resolution and cache management (Cache / HBM_DIRECT / DEFAULT paths)
  • forward (Phase 4): DynamicEmb's own embedding lookup logic
  • output_dist (Phase 5): Detailed internal sequence diagram showing the full call chain for AllToAll redistribution of embedding results under RW Sequence Sharding
  • backward (Phase 6): Gradient update and fused optimizer

StorageMode: Three Storage Modes Explained

DynamicEmb automatically selects one of three internal storage modes (StorageMode) based on user configuration and available HBM capacity; the mode determines the execution path of the prefetch / forward / backward phases.

CACHE
  • Trigger: caching=True and total embedding size > local_hbm_for_values
  • Storage architecture: HBM cache (DynamicEmbCache) + Host/PS storage (DynamicEmbStorage / external_storage); the GPU table acts as a hot-key cache while cold keys reside on the host
  • Phase behavior:
      prefetch: _prefetch_cache_path (hash lookup → cache hit/miss → evict → store)
      forward: load_from_flat (direct read via slot_indices, no hash)
      backward: fused_update_for_flat_table + decrement_counter
  • Use cases: embedding table far exceeds GPU memory; sparse features have temporal locality; caching hot keys in HBM yields significant speedup

HBM_DIRECT
  • Trigger: caching=False and total embedding size ≤ local_hbm_for_values
  • Storage architecture: single HBM hash table (DynamicEmbStorage, .is_cuda = True); all embeddings + optimizer states reside entirely in GPU memory
  • Phase behavior:
      prefetch: _prefetch_hbm_direct_path (hash lookup → no eviction, direct slot allocation)
      forward: load_from_flat (same as CACHE path)
      backward: fused_update_for_flat_table + decrement_counter
  • Use cases: embedding table fits entirely in GPU memory; best performance (no host ↔ GPU data transfer)

DEFAULT
  • Trigger: caching=False and total embedding size > local_hbm_for_values
  • Storage architecture: HybridStorage (HBM tier + Host tier, no cache layer); key placement (HBM or host) is determined by a hash function and cannot be actively managed
  • Phase behavior:
      prefetch: dedup only (segmented_unique), skips hash lookup / slot allocation (slot_indices = None)
      forward: _generic_forward_path (storage.find hash lookup → admission → init → insert)
      backward: update_for_padded_buffer + storage.insert (re-insert)
      ⚠ Cannot overlap the prefetch stream with compute to hide lookup latency
  • Use cases: caching is unnecessary or offers little benefit; HBM + Host hybrid mode with hash-determined key placement

User Interface: Mode selection is controlled by two key parameters of DynamicEmbTableOptions:
  • caching: bool = False — whether to enable GPU cache mode
  • global_hbm_for_values: int = 0 — GPU memory allocated for embeddings + optimizer states (bytes)

Auto-detection logic (_create_cache_storage()):

total_memory = sum(init_capacity * dtype_bytes * value_dim)  # embedding + optimizer states
if total_memory > local_hbm:
    if caching:     → CACHE  (DynamicEmbCache + Storage backend)
    else:           → DEFAULT (HybridStorage: HBM + Host hybrid)
else:
    → HBM_DIRECT (single DynamicEmbStorage, all on GPU)

At runtime, _is_hbm_storage(storage) checks for HBM_DIRECT (isinstance(storage, DynamicEmbStorage) and storage._state.tables[0].is_cuda).
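The auto-detection logic above can be made concrete with a small, hedged sketch (not the real recsys-examples code; `select_storage_mode` and its parameters are illustrative stand-ins for what `_create_cache_storage()` computes):

```python
# Hedged sketch of the StorageMode decision described above. A single table's
# footprint is used for simplicity; the real code sums across tables.
from enum import Enum, auto

class StorageMode(Enum):
    CACHE = auto()       # DynamicEmbCache + Storage backend
    HBM_DIRECT = auto()  # single DynamicEmbStorage, all on GPU
    DEFAULT = auto()     # HybridStorage: HBM + Host hybrid

def select_storage_mode(init_capacity: int, dtype_bytes: int, value_dim: int,
                        local_hbm: int, caching: bool) -> StorageMode:
    # Footprint of embeddings + optimizer states, as in the snippet above.
    total_memory = init_capacity * dtype_bytes * value_dim
    if total_memory > local_hbm:
        return StorageMode.CACHE if caching else StorageMode.DEFAULT
    return StorageMode.HBM_DIRECT

# 1M rows x 4 bytes x dim 128 = 512 MB; fits into a 1 GB HBM budget.
assert select_storage_mode(1_000_000, 4, 128, 1 << 30, caching=False) is StorageMode.HBM_DIRECT
```

Note that in this formulation, `caching=True` only takes effect when the table does not fit in the HBM budget; a table that fits is always stored HBM-direct.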

Diagram Mapping: The input_dist / wait_sparse phases are independent of StorageMode; the prefetch / forward / backward diagrams use alt frames to distinguish paths. CACHE and HBM_DIRECT share the use_counter=True branch (differing only in state source), while DEFAULT takes the use_counter=False branch.
[Pipeline timeline: input_dist → wait_sparse sub-stage 1 → prefetch (incl. wait_sparse sub-stage 2) → forward → output_dist → backward]
Phase overview (per phase: Key Operations · CUDA Stream · D2H Stall · H2D Transfer · Core Files):

  • input_dist — dedup keys → bucketize → initiate splits AllToAll (async) · data_dist_stream · 2× D2H (num_uniques.item(), length_per_key.tolist()) · 4× H2D (dist_type, batch_size_per_split, length_per_split, stride tensors) · shard/embedding.py, input_dist.py
  • wait_sparse sub-stage 1 — await splits AllToAll → initiate lengths/values AllToAll · data_dist_stream · 1× D2H (output_splits.tolist()) · 1× H2D (recat permute index) · train_pipeline.py, dist_data.py
  • wait_sparse sub-stage 2 (⚠ deferred to prefetch) — await lengths/values AllToAll → build KJT → permute · data_dist_stream · 1× D2H (length_per_key.tolist()) · no H2D · utils.py (_prefetch_embeddings)
  • prefetch, Cache/HBM path — dedup → hash lookup → cache hit/miss → storage fallback → insert & evict · prefetch_stream (separate) · multiple .item()/.any() implicit syncs · no H2D · batched_dynamicemb_function.py
  • prefetch, DEFAULT path — dedup only (segmented_unique), skip hash lookup / slot allocation · prefetch_stream · no D2H · no H2D · batched_dynamicemb_function.py
  • forward, Cache path — consume PrefetchState → load_from_flat → gather_embedding_pooled · default_stream · no D2H ✅ · no H2D ✅ · batched_dynamicemb_function.py
  • forward, DEFAULT path — _generic_forward_path: storage.find → admission → init → insert → gather_embedding_pooled · default_stream · no D2H · no H2D · batched_dynamicemb_function.py
  • output_dist, RW Sequence — permute_2D → AllToAll (async) → wait → index_select · default_stream · no D2H ✅ · no H2D ✅ · rw_sequence_sharding.py, dist_data.py, comm_ops.py
  • backward, Cache/HBM path — reduce_grads → fused_update_for_flat_table → decrement_counter · default_stream · no D2H ✅ · no H2D ✅ · batched_dynamicemb_function.py, optimizer.py
  • backward, DEFAULT path — reduce_grads → update_for_padded_buffer → storage.insert (re-insert) · default_stream · no D2H ✅ · no H2D ✅ · batched_dynamicemb_function.py, optimizer.py
Pipeline Mechanism: TorchRec TrainPipelineSparseDist uses pipeline rewrite to overlap input_dist / wait_sparse with the previous batch's forward. Key optimization: the second sub-stage of wait_sparse (awaiting lengths/values AllToAll completion + rebuilding KJT) is deferred to the entry of _prefetch_embeddings(), allowing AllToAll #2/#3 communication to be fully overlapped with forward compute. Prefetch itself runs on a separate CUDA stream, hiding hash lookup + cache miss + I/O latency.
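The deferred-wait mechanism can be sketched with a toy awaitable. This is an illustrative model only: the class name, the use of a Python thread in place of an NCCL async handle, and `fake_all_to_all` are all assumptions, not TorchRec's actual implementation.

```python
# Minimal sketch of the deferred-wait pattern: issue the communication,
# return immediately, and block only when the result is first consumed.
import threading
import time

class DeferredAwaitable:
    """Launches work asynchronously; wait() blocks only when called."""
    def __init__(self, fn):
        self._result = None
        self._thread = threading.Thread(target=self._run, args=(fn,))
        self._thread.start()          # "initiate AllToAll", return at once

    def _run(self, fn):
        self._result = fn()

    def wait(self):                   # called at _prefetch_embeddings() entry
        self._thread.join()
        return self._result

def fake_all_to_all():
    time.sleep(0.01)                  # stands in for NCCL communication time
    return [1, 2, 3]

req = DeferredAwaitable(fake_all_to_all)   # sub-stage 1: launch and return
overlap_result = sum(range(1000))          # forward compute overlaps comm here
assert req.wait() == [1, 2, 3]             # sub-stage 2: block only when needed
```

The key property mirrored here is that the compute between launch and `wait()` runs while the "communication" is in flight, which is exactly what the pipeline rewrite buys.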

Input Distribution — Cross-Rank Key Communication

Before DynamicEmb's forward pass, input keys from each rank must be redistributed to the owning rank according to the RW Sharding strategy.

Phase 1: input_dist — Initiate Splits AllToAll (independent of caching mode)


input_dist is the first phase triggered by the pipeline's _start_sparse_data_dist(), executing on data_dist_stream. Core tasks:

  • Dedup: calls segmented_unique_cuda to deduplicate keys within the batch, reducing subsequent communication volume
  • Bucketize: uses block_bucketize_sparse_features to bucket keys by target rank
  • Initiate splits AllToAll: asynchronously sends split sizes metadata for each rank (AllToAll #1)

This phase involves 2 D2H syncs (num_uniques.item() and length_per_key().tolist()) and 4 H2D transfers (dist_type, batch_size_per_split, length_per_split, stride tensors), visible as cudaMemcpy on the nsys timeline.
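The dedup and bucketize steps can be modeled in pure Python. This is a hedged sketch of the semantics only; the real `segmented_unique_cuda` and `block_bucketize_sparse_features` are CUDA kernels, and the helper names below are illustrative.

```python
# Toy model of the two GPU transformations in input_dist.
def dedup_keys(keys):
    """Deduplicate, preserving first-seen order; also return reverse indices."""
    uniques, index_of, reverse = [], {}, []
    for k in keys:
        if k not in index_of:
            index_of[k] = len(uniques)
            uniques.append(k)
        reverse.append(index_of[k])
    return uniques, reverse

def bucketize_by_rank(keys, total_rows, world_size):
    """RW sharding: rows are block-partitioned, so rank = key // block_size."""
    block = (total_rows + world_size - 1) // world_size
    buckets = [[] for _ in range(world_size)]
    for k in keys:
        buckets[k // block].append(k)
    splits = [len(b) for b in buckets]       # metadata sent in AllToAll #1
    return buckets, splits

uniq, rev = dedup_keys([5, 7, 5, 9, 7])
assert uniq == [5, 7, 9] and rev == [0, 1, 0, 2, 1]
buckets, splits = bucketize_by_rank(uniq, total_rows=12, world_size=3)
assert splits == [0, 2, 1]   # block=4: keys 5,7 -> rank 1; key 9 -> rank 2
```

The `splits` list is the per-rank size metadata whose exchange AllToAll #1 performs, and `len(uniques)` is the value that `num_uniques.item()` forces back to the CPU.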

[Sequence diagram: input_dist — Pipeline.progress() → ShardedDynEmbCollection.input_dist() → RwSparseFeatDist.forward() → KJTAllToAll.forward() → SplitsAllToAll.forward(); GPU data_dist_stream, NCCL AllToAll]
D2H #1 (num_uniques.item()) triggers a GPU→CPU synchronization wait because the CPU needs the deduplicated key count for tensor slicing. This is an unavoidable stall in DynamicEmb's input_dist phase.

Phase 2: wait_sparse_data_dist — Two Sub-stages (split by pipeline rewrite) (independent of caching mode)


wait_sparse_data_dist consists of two sub-stages, split by the train_pipeline rewrite:

  • Sub-stage 1 (_wait_sparse_data_dist(), before forward): await AllToAll #1 (splits metadata) completion → use output_splits to initiate AllToAll #2 (lengths) + AllToAll #3 (values)
  • Sub-stage 2 (deferred to _prefetch_embeddings()): await AllToAll #2 + #3 completion → rebuild KJT (KeyedJaggedTensor) → permute_2D_sparse_data reorder

Purpose of this rewrite: Sub-stage 1 returns immediately after initiating lengths/values AllToAll, allowing forward(batch_i) to execute in between, during which AllToAll #2/#3 NCCL communication fully overlaps with forward compute. Sub-stage 2 defers waiting for communication completion to the entry of _prefetch_embeddings(). See utils.py:1639 (_prefetch_embeddings wraps request.wait(); prefetch logic is external to it).
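Why the exchange needs two rounds can be shown with a toy all-to-all. This is a simplified model under stated assumptions (a function simulating what NCCL does across ranks), not the dist_data.py implementation:

```python
# Ranks must first exchange *sizes* (AllToAll #1) so each rank can size its
# receive buffers for the variable-length lengths/values exchange (#2/#3).
def all_to_all(per_rank_payloads):
    """per_rank_payloads[src][dst] -> received[dst][src]; models one AllToAll."""
    world = len(per_rank_payloads)
    return [[per_rank_payloads[src][dst] for src in range(world)]
            for dst in range(world)]

# Each rank's outgoing values, already bucketized by destination rank.
send_values = [[[10, 11], [12]],       # rank 0 -> rank 0, rank 1
               [[],       [20, 21]]]   # rank 1 -> rank 0, rank 1
input_splits = [[len(v) for v in row] for row in send_values]

output_splits = all_to_all(input_splits)   # AllToAll #1: sizes only
recv_values = all_to_all(send_values)      # AllToAll #2/#3: actual data
# output_splits is what output_splits.tolist() materializes on the CPU.
assert output_splits == [[2, 0], [1, 2]]
assert recv_values[1] == [[12], [20, 21]]
```

In the real pipeline, sub-stage 1 consumes `output_splits` (the D2H stall) to launch #2/#3, and sub-stage 2 merely waits on their completion.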

[Sequence diagram: wait_sparse — Pipeline.progress() → KJTAllToAllSplitsAwaitable → SplitsAllToAllAwaitable._wait_impl() → KJTAllToAllTensorsAwaitable; GPU data_dist_stream, NCCL AllToAll]
D2H #1 (outputs.tolist())outputs.tolist() 的开销——需要等 AllToAll #1 结束后把 splits metadata 搬回 CPU。这个 stall 是整个 input_dist 流程中最大的,因为它同时包含了 NCCL 通信等待和 stream 同步。
D2H #1 (outputs.tolist()) is the cost of outputs.tolist() — waiting for AllToAll #1 to finish and moving splits metadata back to CPU. This stall is the largest in the entire input_dist flow, as it includes both NCCL communication wait and stream synchronization.

DynamicEmb Embedding — Forward / Backward Pipeline

After key communication completes, DynamicEmb enters its own three-phase pipeline: prefetch (key resolution) → forward (embedding lookup) → backward (gradient update).

Phase 3: prefetch — Key Resolution & Cache Management (alt frames: Cache Path + HBM Direct Path)


At execution time, _prefetch_embeddings() first executes wait_sparse sub-stage 2 at its entry (awaiting lengths/values AllToAll completion + rebuilding KJT), then proceeds to DynamicEmb's own prefetch logic.

Prefetch runs on a separate CUDA stream, aiming to resolve unique keys into flat-table slot positions (slot_indices) for direct use by forward. Two core paths:

Cache Path (caching=True)
  • Use case: Zipf distribution, concentrated hot keys
  • Flow: cache.lookup (separate hit/miss) → storage.find (miss fallback) → cache.insert_and_evict (LRU/LFU) → load_from_flat / store_to_flat (data migration) → storage.insert (write back evicted)
  • Implicit sync: multiple .item() / .any() syncs

HBM Direct Path (no cache)
  • Use case: uniform distribution, or table fits entirely in HBM
  • Flow: _find_keys (hash table lookup) → key_index_map.insert (insert new keys) → store_to_flat (initialize embeddings)
  • Implicit sync: fewer syncs
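The Cache Path's lookup / insert_and_evict cycle can be sketched with a small LRU structure. This is a hedged illustration: the real DynamicEmbCache is a GPU hash table with LRU/LFU scoring, and the class and method shapes below are assumptions chosen to mirror the API names in the flow above.

```python
# Toy LRU cache mirroring the lookup -> insert_and_evict cycle of the Cache Path.
from collections import OrderedDict

class LruEmbCache:
    def __init__(self, capacity):
        self.capacity = capacity
        self.slots = OrderedDict()          # key -> embedding payload

    def lookup(self, keys):
        hits, misses = {}, []
        for k in keys:
            if k in self.slots:
                self.slots.move_to_end(k)   # touch: mark as recently used
                hits[k] = self.slots[k]
            else:
                misses.append(k)            # fetched via storage.find fallback
        return hits, misses

    def insert_and_evict(self, items):
        evicted = {}                        # written back via storage.insert
        for k, v in items.items():
            if k not in self.slots and len(self.slots) >= self.capacity:
                old_k, old_v = self.slots.popitem(last=False)  # evict LRU slot
                evicted[old_k] = old_v
            self.slots[k] = v
            self.slots.move_to_end(k)
        return evicted

cache = LruEmbCache(capacity=2)
cache.insert_and_evict({1: "e1", 2: "e2"})
hits, misses = cache.lookup([1, 3])         # key 1 hits (and is refreshed)
assert list(hits) == [1] and misses == [3]
evicted = cache.insert_and_evict({3: "e3"}) # cache full: LRU key 2 is evicted
assert evicted == {2: "e2"}
```

The evicted entries correspond to the rows that prefetch must write back to host storage before their slots can be reused.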
[Sequence diagram: prefetch — BatchedDynamicEmbTablesV2 (batched_dynamicemb_tables.py), DynamicEmbStorage / DynamicEmbCache (key_value_table.py), CUDA kernels (dynamicemb_ext); red bars = implicit cudaStreamSynchronize]
Why a separate stream? Prefetch overlaps with the previous batch's compute, hiding hash lookup + cache miss + storage I/O latency within GPU compute time. The output PrefetchState contains slot_indices, unique_keys, reverse_indices, etc., enabling a zero-lookup forward pass.
Sync Hotspots: Multiple .item() / .any() calls in prefetch (red SYNC markers in the diagram) each trigger an implicit cudaStreamSynchronize. This is one of DynamicEmb's main performance bottlenecks, potentially addressable via kernel fusion or asynchronous approaches.

Phase 4: forward — Embedding Lookup (alt frames: Cache/HBM_DIRECT Path + DEFAULT Path)


The forward phase branches into two paths based on use_counter (i.e., update_slot_indices is not None):

Cache / HBM_DIRECT Path (use_counter = True)
  • load_from_flat: reads embeddings directly via PrefetchState's slot_indices; no hash lookup, extremely fast
  • [if non_admitted_positions] initializer initializes new embeddings

DEFAULT Path (use_counter = False)
  • _generic_forward_path: storage.find (hash lookup for existing keys) → _apply_admission (filter new keys) → initializer (initialize new embeddings) → storage.insert (persist new keys)

Common step
  • gather_embedding_pooled: scatter unique_embs to the output tensor, applying SUM/MEAN pooling and per_sample_weights
  • If the PrefetchState queue is empty, forward automatically falls back to calling prefetch() (line 921-922)
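The common gather-and-pool step can be sketched in pure Python. This is an assumed reading of gather_embedding_pooled's semantics (the real version is a fused CUDA kernel), with illustrative argument names:

```python
# Toy model: scatter per-unique-key embeddings back to each sample occurrence
# and pool them per bag, optionally weighted by per_sample_weights.
def gather_embedding_pooled(unique_embs, reverse_indices, offsets,
                            pooling="SUM", per_sample_weights=None):
    dim = len(unique_embs[0])
    out = []
    for bag in range(len(offsets) - 1):
        start, end = offsets[bag], offsets[bag + 1]
        acc = [0.0] * dim
        for pos in range(start, end):
            w = per_sample_weights[pos] if per_sample_weights else 1.0
            emb = unique_embs[reverse_indices[pos]]   # scatter from unique rows
            acc = [a + w * e for a, e in zip(acc, emb)]
        if pooling == "MEAN" and end > start:
            acc = [a / (end - start) for a in acc]
        out.append(acc)
    return out

unique_embs = [[1.0, 1.0], [2.0, 2.0]]       # deduped lookup results
reverse = [0, 1, 0]                          # sample occurrences -> unique rows
offsets = [0, 2, 3]                          # bag 0 = samples 0-1, bag 1 = sample 2
assert gather_embedding_pooled(unique_embs, reverse, offsets) == [[3.0, 3.0], [1.0, 1.0]]
```

The reverse_indices here are the same mapping produced by dedup during prefetch, which is why forward needs no further lookups.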
[Sequence diagram: forward — alt frames distinguish Cache vs DEFAULT paths]
Why is the Cache path forward so fast? All the heavy lifting (hash lookup, cache miss, eviction, storage I/O) was completed on a separate stream during prefetch. Forward only does two steps: load_from_flat (read by slot index) → gather_embedding_pooled (scatter + pool), with zero CPU-GPU synchronization.
Cost of the DEFAULT path: Without a prefetch cache, forward must execute the full storage.find (hash lookup) → admission → init → insert flow inline, resulting in noticeably higher latency than the Cache path. This is one reason the CACHE or HBM_DIRECT modes are recommended.

Phase 5: output_dist — Cross-Rank Redistribution of Embedding Results (RW Sequence Sharding · AllToAll)


DynamicEmb uses EmbeddingCollection (sequence / unpooled embeddings), so output_dist uses all_to_all_single (not ReduceScatterV for pooled mode). RW sharding partitions embedding tables by rows across ranks — each rank holds only partial lookup results, requiring AllToAll to route results back to the requesting rank.

Output_dist consists of two sub-phases:

  • Async Launch (inside compute_and_output_dist()): fbgemm.permute_2D_sparse_data reorders embeddings by rank → dist.all_to_all_single(async_op=True) launches async NCCL AllToAll
  • .wait() (SequenceEmbeddingsAwaitable._wait_impl()): waits for AllToAll completion → torch.index_select restores original row order

This path has no D2H / H2D transfers: forward_recat_tensor is pre-registered as a GPU buffer in __init__; input_splits / output_splits are Python List[int], handled internally by the NCCL backend. The only sync point is myreq.req.wait().

[Sequence diagram: output_dist — Pipeline compute_and_output_dist → RwSequenceEmbeddingDist → SequenceEmbeddingsAllToAll → SequenceEmbeddingsAwaitable; GPU default_stream, 1 NCCL AllToAll]
Detailed call chain:
RwSequenceEmbeddingDist.forward(local_embs, sharding_ctx)
  → SequenceEmbeddingsAllToAll.forward(local_embs, lengths, input_splits, output_splits, ...)
    → alltoall_sequence(...)                          # comm_ops.py
      → All2All_Seq_Req.forward(pg, myreq, a2ai, embs)
        1. fbgemm.permute_2D_sparse_data(recat, ...)  # GPU kernel: reorder embeddings by rank
        2. dist.all_to_all_single(output, input,       # ★ NCCL AllToAll (async_op=True)
             output_split_sizes, input_split_sizes, pg, async_op=True)
      → return Request (async handle)
  → SequenceEmbeddingsAwaitable(tensor_awaitable, unbucketize_permute_tensor, D)

# Triggered on .wait():
SequenceEmbeddingsAwaitable._wait_impl()
  → All2All_Seq_Req_Wait.forward(pg, myreq)
    1. myreq.req.wait()                               # ★ NCCL sync point (blocks until AllToAll completes)
    2. return sharded_output_embeddings.view(-1, D)
  → torch.index_select(ret, 0, unbucketize_permute_tensor)  # GPU kernel: restore original row order
Why no D2H / H2D in output_dist? All tensors (recat, unbucketize_permute_tensor, embeddings) reside on the GPU; the splits are Python List[int], handled internally by the PyTorch NCCL backend. Compared to input_dist's 6 cudaMemcpy operations, output_dist has minimal overhead.

Phase 6: backward — Gradient Update & Fused Optimizer (alt frames: CACHE/HBM_DIRECT Path + DEFAULT Path)


Backward is triggered by PyTorch autograd calling DynamicEmbeddingFunction.backward(). Core steps:

  • reduce_grads: reverse gather — reduces [B × total_D] gradients to [unique_N × max_D]
  • Fused optimizer update: a single kernel performs read emb + apply optimizer + write back (fused implementations for SGD/Adam/AdaGrad/RowWiseAdaGrad)
  • Post-update: CACHE/HBM_DIRECT path releases refcount (decrement_counter); DEFAULT path re-inserts updated values into the hash table via storage.insert
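The reduce_grads step can be modeled in a few lines. This is a hedged sketch of its assumed semantics (the real kernel also handles per-table max_D padding), with illustrative parameter names:

```python
# Toy model of reduce_grads: gradients arrive per sample occurrence; duplicate
# occurrences of the same key are summed into one row per unique key, so the
# fused optimizer applies exactly one update per embedding row.
def reduce_grads(grads, reverse_indices, num_unique, dim):
    reduced = [[0.0] * dim for _ in range(num_unique)]
    for row, u in zip(grads, reverse_indices):
        for j in range(dim):
            reduced[u][j] += row[j]
    return reduced

grads = [[1.0, 0.0], [0.5, 0.5], [1.0, 1.0]]
reverse = [0, 1, 0]                 # samples 0 and 2 hit the same unique key
assert reduce_grads(grads, reverse, 2, 2) == [[2.0, 1.0], [0.5, 0.5]]
```

Reusing the reverse_indices from dedup makes this the exact adjoint of the forward scatter: what was duplicated on the way out is summed on the way back.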
[Sequence diagram: backward — no cudaStreamSynchronize]
CACHE / HBM_DIRECT (use_counter=True)
  • Update target: cache/storage flat table (via update_slot_indices)
  • Kernel: _update_for_flat_table_kernel
  • Post-processing: decrement_counter (release ref-count lock, allowing subsequent prefetch eviction)
  • Memory layout: flat table (contiguous, variable-length per table)

DEFAULT (use_counter=False)
  • Update target: unique_values padded buffer (in memory)
  • Kernel: _update_for_padded_buffer_kernel
  • Post-processing: storage.insert (re-insert updated values into the hash table)
  • Memory layout: padded buffer (aligned to max_D, written to the hash table after insert)
Why fused optimizers? The traditional approach requires 4 memory round-trips (read emb → read grad → update state → write emb). The fused kernel completes everything in a single pass, reducing memory bandwidth consumption by 4×, which is critical for large-scale sparse updates on embedding tables.
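The single-pass idea can be illustrated with a row-wise AdaGrad update (one of the fused variants mentioned above). This is a pure-Python sketch of the math only; the real _update_for_flat_table_kernel is a CUDA kernel, and the function shape here is an assumption:

```python
# One pass per row: read grad, update the per-row accumulator, scale, and
# write the embedding back in place, instead of three separate memory passes.
import math

def fused_rowwise_adagrad(emb_row, grad_row, state, lr=0.1, eps=1e-8):
    g2 = sum(g * g for g in grad_row) / len(grad_row)   # row-wise mean of g^2
    state += g2                                          # accumulator update
    scale = lr / (math.sqrt(state) + eps)
    for j, g in enumerate(grad_row):
        emb_row[j] -= scale * g                          # in-place write-back
    return state

row, state = [1.0, 1.0], 0.0
state = fused_rowwise_adagrad(row, [2.0, 0.0], state)
assert abs(state - 2.0) < 1e-9          # mean(4, 0) accumulated into state
assert abs(row[0] - (1.0 - 0.1 * 2.0 / math.sqrt(2.0))) < 1e-6
assert row[1] == 1.0                    # zero gradient leaves the value intact
```

Row-wise AdaGrad keeps one accumulator scalar per row rather than per element, which is what makes the state cheap enough to co-locate with the embedding in the flat table.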