NVIDIA recsys-examples · DynamicEmb Embedding Layer · RW Sharding
⭐ GitHub: shijieliu/recsys-examples (fea-dynamicemb_table_fusion)

During one iteration of distributed training, DynamicEmb's embedding layer goes through the following 5 phases. The first two handle cross-rank communication of input keys (before DynamicEmb's own forward pass), and the latter three are DynamicEmb's prefetch → forward → backward pipeline.
This document covers the complete flow of the DynamicEmb embedding layer during one training iteration, spanning all 6 phases of TorchRec's TrainPipelineSparseDist.
DynamicEmb automatically selects one of three internal storage modes (StorageMode) based on user configuration and available HBM capacity; the chosen mode determines the execution path of the prefetch / forward / backward phases.
| Mode | Trigger | Storage Architecture | Phase Behavior | Use Cases |
|---|---|---|---|---|
| CACHE | `caching=True` and total embedding size > `local_hbm_for_values` | HBM cache (`DynamicEmbCache`) + Host/PS storage (`DynamicEmbStorage` / `external_storage`); the GPU table acts as a hot-key cache, cold keys reside on host | prefetch: `_prefetch_cache_path` (hash lookup → cache hit/miss → evict → store)<br>forward: `load_from_flat` (direct read via `slot_indices`, no hash)<br>backward: `fused_update_for_flat_table` + `decrement_counter` | Embedding table far exceeds GPU memory; sparse features have temporal locality; caching hot keys in HBM yields significant speedup |
| HBM_DIRECT | `caching=False` and total embedding size ≤ `local_hbm_for_values` | Single HBM hash table (`DynamicEmbStorage`, `.is_cuda = True`); all embeddings + optimizer states reside entirely in GPU memory | prefetch: `_prefetch_hbm_direct_path` (hash lookup → no eviction, direct slot allocation)<br>forward: `load_from_flat` (same as CACHE path)<br>backward: `fused_update_for_flat_table` + `decrement_counter` | Embedding table fits entirely in GPU memory; best performance (no host ↔ GPU data transfer) |
| DEFAULT | `caching=False` and total embedding size > `local_hbm_for_values` | HybridStorage (HBM tier + Host tier combined, no cache layer); key placement (HBM or host) is determined by the hash function, not actively managed | prefetch: dedup only (`segmented_unique`), skips hash lookup / slot allocation (`slot_indices = None`)<br>forward: `_generic_forward_path` (`storage.find` hash lookup → admission → init → insert)<br>backward: `update_for_padded_buffer` + `storage.insert` (re-insert)<br>⚠ Cannot overlap the prefetch stream with compute to hide lookup latency | Cases where caching is unnecessary or offers little benefit; key placement is determined by hash |
Two parameters of `DynamicEmbTableOptions` determine the mode:

- `caching: bool = False` — whether to enable GPU cache mode
- `global_hbm_for_values: int = 0` — GPU memory allocated for embeddings + optimizer states (bytes)

Auto-detection logic (`_create_cache_storage()`):

```
total_memory = sum(init_capacity * dtype_bytes * value_dim)  # embedding + optimizer states
if total_memory > local_hbm:
    if caching: → CACHE (DynamicEmbCache + Storage backend)
    else:       → DEFAULT (HybridStorage: HBM + Host hybrid)
else:
    → HBM_DIRECT (single DynamicEmbStorage, all on GPU)
```

At runtime, `_is_hbm_storage(storage)` checks for HBM_DIRECT (`isinstance(storage, DynamicEmbStorage) and storage._state.tables[0].is_cuda`).
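The decision logic above can be sketched as a small pure-Python function. `select_mode` and `TableConfig` are illustrative stand-ins, not the actual DynamicEmb API; only the inequality structure mirrors `_create_cache_storage()`:

```python
from dataclasses import dataclass
from enum import Enum

class StorageMode(Enum):
    CACHE = "cache"
    HBM_DIRECT = "hbm_direct"
    DEFAULT = "default"

@dataclass
class TableConfig:
    # Mirrors the deciding fields of DynamicEmbTableOptions (illustrative).
    caching: bool = False
    init_capacity: int = 0
    dtype_bytes: int = 4
    value_dim: int = 128  # embedding dim + optimizer-state slots

def select_mode(tables: list[TableConfig], local_hbm: int) -> StorageMode:
    """Sketch of the storage-mode decision in _create_cache_storage()."""
    # Total bytes for embeddings + optimizer states across all fused tables.
    total_memory = sum(t.init_capacity * t.dtype_bytes * t.value_dim for t in tables)
    if total_memory > local_hbm:
        # Does not fit in HBM: cache mode if requested, else hybrid HBM + Host.
        return StorageMode.CACHE if any(t.caching for t in tables) else StorageMode.DEFAULT
    # Fits entirely in HBM: single GPU-resident hash table.
    return StorageMode.HBM_DIRECT
```

For example, a table of 10M rows at 256 float-sized values per row needs ~10 GB; against a 1 GB `local_hbm` budget it selects CACHE when `caching=True` and DEFAULT otherwise.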
The sequence diagrams use alt frames to distinguish the paths: CACHE and HBM_DIRECT share the `use_counter=True` branch (differing only in the source of the state), while DEFAULT takes the `use_counter=False` branch.

| Phase | Key Operations | CUDA Stream | D2H Stall | H2D Transfer | Core Files |
|---|---|---|---|---|---|
| input_dist | dedup keys → bucketize → initiate splits AllToAll (async) | data_dist_stream | 2×: `num_uniques.item()`, `length_per_key.tolist()` | 4×: dist_type, batch_size_per_split, length_per_split, stride tensors | shard/embedding.py, input_dist.py |
| wait_sparse sub-stage 1 | await splits AllToAll → initiate lengths/values AllToAll | data_dist_stream | 1×: `output_splits.tolist()` | 1×: recat permute index | train_pipeline.py, dist_data.py |
| wait_sparse sub-stage 2 (⚠ deferred to prefetch) | await lengths/values AllToAll → build KJT → permute | data_dist_stream | 1×: `length_per_key.tolist()` | None | utils.py (`_prefetch_embeddings`) |
| prefetch (Cache/HBM path) | dedup → hash lookup → cache hit/miss → storage fallback → insert & evict | prefetch_stream (separate) | multiple `.item()`/`.any()` implicit syncs | None | batched_dynamicemb_function.py |
| prefetch (DEFAULT path) | dedup only (`segmented_unique`), skip hash lookup / slot allocation | prefetch_stream | None | None | batched_dynamicemb_function.py |
| forward (Cache path) | consume PrefetchState → `load_from_flat` → `gather_embedding_pooled` | default_stream | None ✅ | None ✅ | batched_dynamicemb_function.py |
| forward (DEFAULT path) | `_generic_forward_path`: `storage.find` → admission → init → insert → `gather_embedding_pooled` | default_stream | None | None | batched_dynamicemb_function.py |
| output_dist (RW Sequence) | `permute_2D` → AllToAll (async) → wait → `index_select` | default_stream | None ✅ | None ✅ | rw_sequence_sharding.py, dist_data.py, comm_ops.py |
| backward (Cache/HBM path) | `reduce_grads` → `fused_update_for_flat_table` → `decrement_counter` | default_stream | None ✅ | None ✅ | batched_dynamicemb_function.py, optimizer.py |
| backward (DEFAULT path) | `reduce_grads` → `update_for_padded_buffer` → `storage.insert` (re-insert) | default_stream | None ✅ | None ✅ | batched_dynamicemb_function.py, optimizer.py |
TrainPipelineSparseDist uses a pipeline rewrite to overlap input_dist / wait_sparse with the previous batch's forward. Key optimization: the second sub-stage of wait_sparse (awaiting lengths/values AllToAll completion + rebuilding the KJT) is deferred to the entry of `_prefetch_embeddings()`, allowing the AllToAll #2/#3 communication to be fully covered by forward compute. Prefetch itself runs on a separate CUDA stream, hiding hash lookup + cache miss + I/O latency.

input_dist is the first phase, triggered by the pipeline's `_start_sparse_data_dist()` and executed on `data_dist_stream`. Core tasks:

- `segmented_unique_cuda` deduplicates keys within the batch, reducing subsequent communication volume
- `block_bucketize_sparse_features` buckets keys by target rank

This phase involves 2 D2H syncs (`num_uniques.item()` and `length_per_key().tolist()`) and 4 H2D transfers (dist_type, batch_size_per_split, length_per_split, stride tensors), all visible as cudaMemcpy on the nsys timeline.
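A CPU sketch of the two core tasks, standing in for the `segmented_unique_cuda` and `block_bucketize_sparse_features` kernels (the real ones run on GPU over jagged batches; the function names and the modulo placement rule here are illustrative):

```python
def dedup_keys(keys):
    """First-occurrence dedup with inverse indices, like segmented_unique:
    unique[inverse[i]] == keys[i] for every original position i."""
    unique, inverse, seen = [], [], {}
    for k in keys:
        if k not in seen:
            seen[k] = len(unique)
            unique.append(k)
        inverse.append(seen[k])
    return unique, inverse

def bucketize(unique_keys, world_size):
    """Assign each unique key to a target rank (modulo here, standing in
    for the row-wise placement of a hash-based table)."""
    buckets = [[] for _ in range(world_size)]
    for k in unique_keys:
        buckets[k % world_size].append(k)
    # splits[r] = how many keys this rank sends to rank r
    # (this is the payload of the splits-metadata AllToAll #1)
    splits = [len(b) for b in buckets]
    return buckets, splits
```

For instance, `dedup_keys([4, 3, 5, 4])` yields `([4, 3, 5], [0, 1, 2, 0])`, and `bucketize([4, 3, 5], 2)` produces `splits = [1, 2]`, the per-rank counts exchanged before the lengths/values AllToAll.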
wait_sparse_data_dist consists of two sub-stages, split apart by the train_pipeline rewrite:

1. Sub-stage 1 (`_wait_sparse_data_dist()`, before forward): await completion of AllToAll #1 (splits metadata) → use `output_splits` to initiate AllToAll #2 (lengths) + AllToAll #3 (values)
2. Sub-stage 2 (deferred into `_prefetch_embeddings()`): await completion of AllToAll #2 + #3 → rebuild the KJT (KeyedJaggedTensor) → reorder via `permute_2D_sparse_data`

The purpose of this rewrite: sub-stage 1 returns immediately after initiating the lengths/values AllToAll, letting forward(batch_i) execute in between, so that the AllToAll #2/#3 NCCL communication fully overlaps with forward compute. Sub-stage 2 defers the wait for communication completion to the entry of `_prefetch_embeddings()`. See utils.py:1639 (`_prefetch_embeddings` only wraps `request.wait()`; the prefetch logic lives outside it).
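The overlap trick is the classic "launch async, compute, then wait" pattern. A CPU analogy, with a thread standing in for the NCCL AllToAll (purely illustrative; the real code launches `dist.all_to_all_single(async_op=True)` and later waits on the returned work handle):

```python
import threading
import time

class FakeAllToAll:
    """Stands in for the async work handle returned by
    all_to_all_single(async_op=True)."""
    def __init__(self, payload):
        self.result = None
        self._t = threading.Thread(target=self._run, args=(payload,))
        self._t.start()            # sub-stage 1: initiate and return immediately

    def _run(self, payload):
        time.sleep(0.05)           # simulated network time
        self.result = sorted(payload)

    def wait(self):                # sub-stage 2: block until "communication" completes
        self._t.join()
        return self.result

req = FakeAllToAll([3, 1, 2])                   # _wait_sparse_data_dist(): launch AllToAll
forward_out = sum(x * x for x in range(1000))   # forward(batch_i) overlaps the "comm"
assert req.wait() == [1, 2, 3]                  # _prefetch_embeddings() entry: request.wait()
```

The compute between launch and `wait()` is "free": it runs while the exchange is in flight, which is exactly what the pipeline rewrite buys for AllToAll #2/#3.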
The dominant cost here is `outputs.tolist()`: waiting for AllToAll #1 to finish, then moving the splits metadata back to the CPU. This stall is the largest in the entire input_dist flow, as it includes both the NCCL communication wait and a stream synchronization.

At execution time, `_prefetch_embeddings()` first runs wait_sparse sub-stage 2 at its entry (awaiting lengths/values AllToAll completion + rebuilding the KJT), then proceeds to DynamicEmb's own prefetch logic.

Prefetch runs on a separate CUDA stream, aiming to resolve unique keys into flat-table slot positions (`slot_indices`) that forward can use directly. There are two core paths:
| | Cache Path (caching=True) | HBM Direct Path (no cache) |
|---|---|---|
| Use Case | Zipf distribution, concentrated hot keys | Uniform distribution, or table fits entirely in HBM |
| Flow | `cache.lookup` (separate hits/misses) → `storage.find` (miss fallback) → `cache.insert_and_evict` (LRU/LFU) → `load_from_flat` / `store_to_flat` (data migration) → `storage.insert` (write back evicted keys) | `_find_keys` (hash table lookup) → `key_index_map.insert` (insert new keys) → `store_to_flat` (initialize embeddings) |
| Implicit Sync | multiple `.item()` / `.any()` syncs | fewer syncs |
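The Cache-path flow can be illustrated with a toy dictionary cache. `ToyCache` and its plain-LRU policy are illustrative stand-ins for `DynamicEmbCache`, which is a GPU hash table with LRU/LFU scoring:

```python
from collections import OrderedDict

class ToyCache:
    """LRU dict standing in for the HBM cache."""
    def __init__(self, capacity):
        self.capacity = capacity
        self.data = OrderedDict()

    def lookup(self, keys):
        # Split unique keys into cache hits and misses (cache.lookup).
        hits = [k for k in keys if k in self.data]
        misses = [k for k in keys if k not in self.data]
        for k in hits:
            self.data.move_to_end(k)  # refresh recency on hit
        return hits, misses

    def insert_and_evict(self, kv):
        # Insert miss keys fetched from host storage; evict the least
        # recently used entries beyond capacity (the real path writes
        # evicted values back via storage.insert).
        evicted = {}
        for k, v in kv.items():
            self.data[k] = v
            self.data.move_to_end(k)
        while len(self.data) > self.capacity:
            ek, ev = self.data.popitem(last=False)
            evicted[ek] = ev
        return evicted

host_storage = {k: k * 10 for k in range(100)}  # stand-in for DynamicEmbStorage
cache = ToyCache(capacity=2)
hits, misses = cache.lookup([1, 2, 3])          # cold cache: all three miss
evicted = cache.insert_and_evict({k: host_storage[k] for k in misses})
# inserting 3 keys into a capacity-2 cache evicts the oldest for write-back
```

The real prefetch then resolves each surviving key to a `slot_indices` position in the flat table, so forward never touches the hash table.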
PrefetchState (produced in `prefetch()`, line 921-922) contains `slot_indices`, `unique_keys`, `reverse_indices`, etc., enabling a zero-lookup forward pass.

The `.item()` / `.any()` calls in prefetch (red SYNC markers in the diagram) each trigger an implicit cudaStreamSynchronize. This is one of DynamicEmb's main performance bottlenecks, potentially addressable via kernel fusion or asynchronous approaches.

The forward phase branches into two paths based on `use_counter` (i.e., `update_slot_indices is not None`):
| Cache / HBM_DIRECT Path | DEFAULT Path | Common Step |
|---|---|---|
| `use_counter = True`<br>`load_from_flat`: reads embeddings directly via PrefetchState's `slot_indices`, no hash lookup, extremely fast<br>[if `non_admitted_positions`] initializer initializes new embeddings | `use_counter = False`<br>`_generic_forward_path`: ① `storage.find` (hash lookup for existing keys) ② `_apply_admission` (filter new keys) ③ initializer (initialize new embeddings) ④ `storage.insert` (persist new keys) | `gather_embedding_pooled`: scatter `unique_embs` to the output tensor, applying SUM/MEAN pooling and `per_sample_weights` |
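The common `gather_embedding_pooled` step, scattering `unique_embs` back through the reverse indices and pooling per sample, can be sketched on CPU (illustrative; the real op is a fused CUDA kernel over jagged offsets):

```python
def gather_pooled(unique_embs, inverse, offsets, weights=None, mode="sum"):
    """unique_embs: one embedding vector per deduped key.
    inverse: maps each original key position back to its unique row.
    offsets: sample boundaries in the original key list (jagged layout).
    Returns one pooled vector per sample (SUM or MEAN)."""
    dim = len(unique_embs[0])
    out = []
    for s in range(len(offsets) - 1):
        acc = [0.0] * dim
        n = 0
        for pos in range(offsets[s], offsets[s + 1]):
            w = weights[pos] if weights is not None else 1.0  # per_sample_weights
            row = unique_embs[inverse[pos]]
            for d in range(dim):
                acc[d] += w * row[d]
            n += 1
        if mode == "mean" and n > 0:
            acc = [a / n for a in acc]
        out.append(acc)
    return out
```

With `unique_embs = [[1, 1], [2, 2]]`, `inverse = [0, 1, 0]`, and `offsets = [0, 2, 3]`, SUM pooling yields `[[3, 3], [1, 1]]`: sample 0 pools unique rows 0 and 1, sample 1 reuses row 0 without a second lookup.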
The Cache-path forward is simply `load_from_flat` (read by slot index) → `gather_embedding_pooled` (scatter + pool), with zero CPU-GPU synchronization. The DEFAULT path instead runs the full `storage.find` (hash lookup) → admission → init → insert flow inline, resulting in noticeably higher latency than the Cache path; this is one reason the CACHE or HBM_DIRECT modes are recommended.
DynamicEmb uses EmbeddingCollection (sequence / unpooled embeddings), so output_dist uses all_to_all_single (not ReduceScatterV for pooled mode). RW sharding partitions embedding tables by rows across ranks — each rank holds only partial lookup results, requiring AllToAll to route results back to the requesting rank.
Output_dist consists of two sub-phases:
1. Launch (inside `compute_and_output_dist()`): `fbgemm.permute_2D_sparse_data` reorders embeddings by rank → `dist.all_to_all_single(async_op=True)` launches the async NCCL AllToAll
2. Wait (`SequenceEmbeddingsAwaitable._wait_impl()`): waits for AllToAll completion → `torch.index_select` restores the original row order

This path has no D2H / H2D transfers: `forward_recat_tensor` is pre-registered as a GPU buffer in `__init__`; `input_splits` / `output_splits` are Python `List[int]`, handled internally by the NCCL backend. The only sync point is `myreq.req.wait()`.
```
RwSequenceEmbeddingDist.forward(local_embs, sharding_ctx)
→ SequenceEmbeddingsAllToAll.forward(local_embs, lengths, input_splits, output_splits, ...)
  → alltoall_sequence(...)                             # comm_ops.py
    → All2All_Seq_Req.forward(pg, myreq, a2ai, embs)
        1. fbgemm.permute_2D_sparse_data(recat, ...)   # GPU kernel: reorder embeddings by rank
        2. dist.all_to_all_single(output, input,       # ★ NCCL AllToAll (async_op=True)
               output_split_sizes, input_split_sizes, pg, async_op=True)
        → return Request (async handle)
  → SequenceEmbeddingsAwaitable(tensor_awaitable, unbucketize_permute_tensor, D)

# Triggered on .wait():
SequenceEmbeddingsAwaitable._wait_impl()
→ All2All_Seq_Req_Wait.forward(pg, myreq)
    1. myreq.req.wait()                                # ★ NCCL sync point (blocks until AllToAll completes)
    2. return sharded_output_embeddings.view(-1, D)
→ torch.index_select(ret, 0, unbucketize_permute_tensor)  # GPU kernel: restore original row order
```
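The data movement in this trace can be simulated in a single process, with list slicing standing in for the NCCL exchange (purely illustrative; no distributed runtime involved, and `simulate_alltoall` is a hypothetical helper):

```python
def simulate_alltoall(per_rank_send):
    """Simulate dist.all_to_all_single for every rank at once.
    per_rank_send[r][p] = the contiguous block rank r sends to rank p
    (this contiguity is what permute_2D_sparse_data establishes).
    Returns per_rank_recv[p] = concatenation of blocks received by rank p."""
    world = len(per_rank_send)
    return [
        [row for r in range(world) for row in per_rank_send[r][p]]
        for p in range(world)
    ]

# Each rank's local lookup results, already grouped by destination rank:
rank0_send = [[["a0"]], [["a1"]]]   # block for rank 0, block for rank 1
rank1_send = [[["b0"]], [["b1"]]]
recv = simulate_alltoall([rank0_send, rank1_send])
# Rank 0 now holds its own "a0" plus "b0" from rank 1; applying
# index_select with unbucketize_permute_tensor would then restore
# the original request order on the receiving side.
```

The key observation is that the split sizes fully determine the exchange, which is why the real path needs no extra D2H/H2D traffic: the splits are already plain Python lists.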
All key tensors (`recat`, `unbucketize_permute_tensor`, the embeddings themselves) reside on GPU; the splits are Python `List[int]`, handled internally by the PyTorch NCCL backend. Compared with input_dist's 6 cudaMemcpy operations, output_dist has minimal overhead.

Backward is triggered by PyTorch autograd calling `DynamicEmbeddingFunction.backward()`. Core steps:

1. Reduce the `[B × total_D]` output gradients to `[unique_N × max_D]` (`reduce_grads`)
2. Apply the optimizer update with a fused kernel
3. Post-processing: the Cache/HBM path releases the reference-count lock (`decrement_counter`); the DEFAULT path re-inserts the updated values into the hash table via `storage.insert`

| Comparison | CACHE / HBM_DIRECT (use_counter=True) | DEFAULT (use_counter=False) |
|---|---|---|
| Update Target | Cache/storage flat table (via `update_slot_indices`) | `unique_values` padded buffer (in memory) |
| Kernel | `_update_for_flat_table_kernel` | `_update_for_padded_buffer_kernel` |
| Post-processing | `decrement_counter` (release ref-count lock, allowing subsequent prefetch eviction) | `storage.insert` (re-insert updated values into hash table) |
| Memory Layout | Flat table (contiguous, variable-length per table) | Padded buffer (aligned to `max_D`, written to hash table after insert) |
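The first backward step, reducing the per-occurrence output gradients into one gradient row per unique key via the reverse indices, amounts to a scatter-add. A CPU sketch (illustrative; the real `reduce_grads` is a fused GPU kernel):

```python
def reduce_grads(out_grad, inverse, unique_n, dim):
    """Accumulate the gradient of each key occurrence into its unique row:
    grad_unique[inverse[pos]] += out_grad[pos]."""
    grad_unique = [[0.0] * dim for _ in range(unique_n)]
    for pos, u in enumerate(inverse):
        for d in range(dim):
            grad_unique[u][d] += out_grad[pos][d]
    return grad_unique
```

A key appearing twice (e.g. `inverse = [0, 1, 0]`) has both gradient contributions summed into row 0; the optimizer kernel then updates that single row, either in the flat table (CACHE / HBM_DIRECT) or in the padded buffer that is re-inserted afterwards (DEFAULT).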