SWPipeline API Reference

General-purpose Software Pipeline for PyTorch / CUDA · sw_pipeline.py

Data Classes

Class
class IterContext(batch: Any, iter_idx: int)
Per-iteration context passed to every task function. batch is the raw data returned by next(data_iter). Tasks pass intermediate results via ctx.xxx = ... and free GPU memory early via del ctx.xxx. Each in-flight iteration gets its own IterContext, which avoids cross-iteration data hazards.
Attr         Type  Description
batch        Any   Raw batch from the data iterator
iter_idx     int   0-based iteration index
(arbitrary)  Any   Dynamically attached by tasks: ctx.X = ...
Example
def copy_batch(ctx: IterContext) -> None:
    ctx.x_gpu = ctx.batch["x"].to("cuda", non_blocking=True)
    ctx.y_gpu = ctx.batch["y"].to("cuda", non_blocking=True)

def forward(ctx: IterContext) -> None:
    ctx.logits = model(ctx.x_gpu)
    del ctx.x_gpu  # free GPU memory early
Dataclass
@dataclass
class DeclaredIO
Declares a Task's read/write contract with shared state outside IterContext (external side effects). The framework calls capture when caching a Task's result and restore during shortcut replay. Return values pass through _cache_val, which recursively detaches and clones every Tensor.
Field    Type                   Description
capture  Callable[[], Any]      Called after the Task runs. Snapshots external state and returns a cacheable value
restore  Callable[[Any], None]  Called during shortcut replay. Receives the output of _restore_val and writes it back to external state
Example
shared_buffer = {}

PipelineTask("encode", encode_fn, io=[
    DeclaredIO(
        capture=lambda: dict(shared_buffer),
        restore=lambda s: shared_buffer.update(s),
    ),
])
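The capture/restore contract can be exercised without a GPU. The sketch below is a minimal stand-in under stated assumptions: state is plain Python, copy.deepcopy substitutes for the framework's _cache_val/_restore_val (which detach and clone tensors), and the local DeclaredIO class only mirrors the two documented fields. It also shows that restore=shared_buffer.update can leave keys written after the snapshot in place; clearing first makes the restore exact.

```python
import copy
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class DeclaredIO:
    # Stand-in mirroring the two documented fields.
    capture: Callable[[], Any]
    restore: Callable[[Any], None]


shared_buffer: dict = {}


def encode_fn() -> None:
    # Hypothetical task with an external side effect.
    shared_buffer["emb"] = [1.0, 2.0]


def exact_restore(snapshot: dict) -> None:
    # Clear first so keys written after the snapshot do not survive.
    shared_buffer.clear()
    shared_buffer.update(snapshot)


decl = DeclaredIO(capture=lambda: dict(shared_buffer), restore=exact_restore)

# Caching run: execute the task, then snapshot (deepcopy stands in for _cache_val).
encode_fn()
cached = copy.deepcopy(decl.capture())

# Later, unrelated code mutates the shared state.
shared_buffer["emb"].append(3.0)
shared_buffer["extra"] = True

# Shortcut replay: restore from a copy of the cache (stands in for _restore_val).
decl.restore(copy.deepcopy(cached))
print(shared_buffer)  # {'emb': [1.0, 2.0]}
```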
Dataclass
@dataclass(eq=False)
class PipelineTask
One schedulable computation unit. It declares only computation identity (name + fn + io), not scheduling; scheduling is configured in the schedule dict of PipelinePlan. __hash__ and __eq__ are based on name, so tasks can be used directly as schedule keys.
Field  Type                           Description
name   str                            Globally unique identifier (also used as the NVTX tag)
fn     Callable[[IterContext], None]  Compute function; reads and writes ctx attributes
io     List[DeclaredIO]               External side-effect declarations (default [])
Example
h2d = PipelineTask("CopyBatch", copy_batch_fn)
fwd = PipelineTask("Forward", forward_fn)
bwd = PipelineTask("Backward", backward_fn)

# with DeclaredIO for tasks with external side effects
emb = PipelineTask("EmbLookup", emb_fn, io=[
    DeclaredIO(capture=capture_emb, restore=restore_emb),
])
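Name-based identity is what lets a task serve as a schedule key. A minimal sketch of that contract, assuming the field names from the table above; with @dataclass(eq=False) no __eq__ is generated, so both methods are written by hand here (the real class may differ in details):

```python
from dataclasses import dataclass, field
from typing import Any, Callable, List


@dataclass(eq=False)
class PipelineTask:
    name: str
    fn: Callable[[Any], None]
    io: List[Any] = field(default_factory=list)

    # Identity is the name only, so tasks can key PipelinePlan.schedule.
    def __hash__(self) -> int:
        return hash(self.name)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, PipelineTask):
            return NotImplemented
        return self.name == other.name


a = PipelineTask("Forward", lambda ctx: None)
b = PipelineTask("Forward", lambda ctx: None)  # different fn, same name
schedule = {a: "stage 1"}
print(b in schedule)  # True: lookups match on name alone
```

A consequence of this design is that two tasks with the same name are indistinguishable as keys, which is why name must be globally unique.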
Dataclass
@dataclass
class TaskSchedule
Per-task scheduling configuration in a PipelinePlan. All fields have defaults; the simplest usage sets only stage.
Field             Type              Default    Description
stage             int               0          Pipeline stage: in period p, stage k processes iteration p − k
stream            Optional[Stream]  None       Stream to launch on; None = default stream
thread_group      str               "default"  Worker thread group; tasks in the same group execute serially
globally_ordered  bool              False      If True, _SubmissionSequencer enforces the same order across ranks (prevents NCCL deadlocks)
Example
TaskSchedule(stage=0, stream=memcpy_stream, thread_group="io")
TaskSchedule(stage=1, globally_ordered=True)  # NCCL task
TaskSchedule(stage=2)  # default stream, default thread
Dataclass
@dataclass
class PipelinePlan
Complete scheduling plan. Analogous to TorchRec's ShardingPlan, it separates "what to compute" from "how to schedule it". The keys of schedule define the task set, so no separate task list is needed. intra_iter_deps declares dependencies within an iteration; inter_iter_deps declares dependencies across iterations.
Field            Type                           Description
schedule         Dict[TaskRef, TaskSchedule]    Task → scheduling config. The keys define the task set
intra_iter_deps  List[Tuple[TaskRef, TaskRef]]  Intra-iteration deps as (task, depends_on)
inter_iter_deps  List[Tuple[…]]                 Cross-iteration deps as (task_i, task_i-1) (default [])
pipeline_depth   int                            Pipeline depth, equal to max(stage) + 1; 1 = serial, 2+ = pipelined (default 2)
Full Example
copy  = PipelineTask("H2D", h2d_fn)
fwd   = PipelineTask("Forward", fwd_fn)
bwd   = PipelineTask("Backward", bwd_fn)
optim = PipelineTask("Optim", optim_fn)

plan = PipelinePlan(
    schedule={
        copy:  TaskSchedule(stage=0, stream=copy_stream, thread_group="io"),
        fwd:   TaskSchedule(stage=1, thread_group="compute"),
        bwd:   TaskSchedule(stage=1, thread_group="compute"),
        optim: TaskSchedule(stage=1, thread_group="compute"),
    },
    intra_iter_deps=[(fwd, copy), (bwd, fwd), (optim, bwd)],
    inter_iter_deps=[(fwd, optim)],  # Fwd(i) waits for Optim(i-1)
    pipeline_depth=2,
)
Dataclass
@dataclass
class ProfileResult
Return type of TaskProfiler.profile(), holding the baseline time and the per-task exposed time.
Field       Type              Description
baseline_s  float             Median wall-clock time of one serial iteration (seconds)
exposed_s   Dict[str, float]  {task_name: exposed_seconds}
Type Alias
TaskRef = Union[str, PipelineTask]
PipelinePlanscheduleintra_iter_depsinter_iter_deps 中,可用 PipelineTask 对象或 name 字符串引用 task。推荐使用对象以获得更好的类型检查。 In PipelinePlan fields, tasks can be referenced by PipelineTask objects or name strings. Objects are recommended for better type checking.


Synchronization & Scheduling

Two-Phase Synchronization Protocol

SWPipeline uses two layers of dependency declarations to describe data relationships between tasks:

  • intra_iter_deps: data dependency within the same batch (same iter). If task B consumes A's output, write intra_iter_deps=[(B, A)].
  • inter_iter_deps: cross-iteration data dependency. If task A (iter i) consumes the output of B (iter i-1), write inter_iter_deps=[(A, B)]. Only a distance of one iteration (i-1) is supported.

Both describe pure data-dependency semantics and are independent of stream and stage; the constraints that stream and stage add are discussed in the Corollaries section. At runtime, each dependency is enforced via a two-phase protocol:

# Phase 1: CPU synchronization via threading.Event
sig = self._cpu_signal(iter_idx, dep_name)
sig.wait(timeout=30)                         # CPU blocks until dep's CPU work is done
                                             # (Event.wait returns False on timeout; not checked here)

# Phase 2: GPU synchronization via torch.cuda.Event
stream.wait_event(self._cuda_event(iter_idx, dep_name))  # GPU waits for dep's kernels

# Signal sender (after task completes):
self._cuda_event(iter_idx, task_def.name).record(stream)
self._cpu_signal(iter_idx, task_def.name).set()
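The excerpt assumes _cpu_signal hands the producer and the consumer the same Event for a given (iter_idx, task name). A minimal CPU-only sketch of such a registry follows; the class name SignalTable and its methods are hypothetical, not the framework's API. Unlike the excerpt, it raises when Event.wait times out (wait returns False in that case), so a missing producer surfaces as an error instead of a silent race.

```python
import threading
from typing import Dict, Tuple


class SignalTable:
    """Hypothetical registry of per-(iter, task) CPU signals."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._events: Dict[Tuple[int, str], threading.Event] = {}

    def get(self, iter_idx: int, name: str) -> threading.Event:
        # The lock makes lazy creation race-free when producer and
        # consumer threads ask for the same key concurrently.
        with self._lock:
            key = (iter_idx, name)
            if key not in self._events:
                self._events[key] = threading.Event()
            return self._events[key]

    def wait(self, iter_idx: int, name: str, timeout: float = 30.0) -> None:
        if not self.get(iter_idx, name).wait(timeout):
            raise TimeoutError(f"{name} for iter {iter_idx} never signaled")

    def release(self, iter_idx: int) -> None:
        # Drop every signal of a retired iteration to bound memory.
        with self._lock:
            for key in [k for k in self._events if k[0] == iter_idx]:
                del self._events[key]


signals = SignalTable()
signals.get(0, "H2D").set()   # producer side, after its work is done
signals.wait(0, "H2D")        # consumer side: returns immediately
```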

Multi-Thread Dispatch & CPU Synchronization

SWPipeline supports multi-threaded kernel submission. Each task names its worker thread via the thread_group attribute. At runtime, _enqueue_period dispatches tasks to each thread's work_queue in _enqueue_order:

_enqueue_period dispatch flow:
for task_def in self._enqueue_order:           # order precomputed by _build_enqueue_order
    self._work_queues[task_def.thread_group].put((iter_idx, task_def))
                  ↓                        ↓
        ┌───────────────────┐    ┌───────────────────┐
        │ Thread "memcpy"   │    │ Thread "default"  │
        │ work_queue:       │    │ work_queue:       │
        │   H2D             │    │   WaitBatch       │
        │                   │    │   Forward         │
        │                   │    │   Backward        │
        └─────────┬─────────┘    └─────────┬─────────┘
                  │                        │
            _submit_task()           _submit_task()

CPU dependency handling is not a separate layer: Phase 1 of the two-phase protocol takes care of it automatically. Suppose WaitBatch on Thread B ("default") declares intra_iter_deps=[H2D], where H2D runs on Thread A ("memcpy"):

  1. Thread B calls cpu_signal("H2D").wait() at the start of _submit_task → the CPU blocks until Thread A finishes H2D and calls .set()
  2. Once unblocked, Thread B calls stream.wait_event(cuda_event("H2D")) → the GPU is synchronized as well

Even when both tasks are on the same thread, Phase 1's wait() still runs; if the dep is already done, wait() returns immediately at negligible cost. Users therefore only declare data dependencies (intra_iter_deps / inter_iter_deps), and Phase 1 guarantees CPU thread synchronization automatically.

Example: Cross-Thread Dependency

Thread "memcpy" (H2D)                    Thread "default" (WaitBatch)
═══════════════════════════════          ═══════════════════════════════════
  ├─ run H2D on memcpy_stream              ├─ cpu_signal("H2D").wait()  ← Phase 1: CPU blocks
  ├─ cuda_event("H2D").record(memcpy)      │   ... Thread B blocked ...
  ├─ cpu_signal("H2D").set()  ─────────────┤   unblocked!
  ▼ (done, process next in queue)          ├─ stream.wait_event(cuda_event("H2D"))  ← Phase 2: GPU waits
                                           ├─ run WaitBatch on default_stream
                                           ▼ (done)
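The handshake in this timeline can be reproduced on the CPU with one queue.Queue per thread group. This is a hypothetical simulation, not framework code: Phase 2 is omitted because it needs a GPU, and "launching kernels" is just appending to a log. The consumer is enqueued before the producer on purpose, to show that Phase 1 alone orders them.

```python
import queue
import threading
from typing import Dict, List

events: Dict[str, threading.Event] = {"H2D": threading.Event(), "WaitBatch": threading.Event()}
deps: Dict[str, List[str]] = {"H2D": [], "WaitBatch": ["H2D"]}
log: List[str] = []
log_lock = threading.Lock()


def worker(work_queue: "queue.Queue") -> None:
    while True:
        name = work_queue.get()
        if name is None:          # shutdown sentinel
            return
        for dep in deps[name]:    # Phase 1: block until the dep's CPU work is done
            if not events[dep].wait(timeout=5):
                raise TimeoutError(dep)
        with log_lock:
            log.append(name)      # stands in for launching kernels
        events[name].set()        # signal consumers


queues = {"memcpy": queue.Queue(), "default": queue.Queue()}
threads = [threading.Thread(target=worker, args=(q,)) for q in queues.values()]
for t in threads:
    t.start()
queues["default"].put("WaitBatch")  # consumer enqueued first on purpose
queues["memcpy"].put("H2D")
for q in queues.values():
    q.put(None)
for t in threads:
    t.join()
print(log)  # ['H2D', 'WaitBatch']
```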

period / iter / stage

Concept    Meaning
period     Global clock, incremented by 1 on each progress() call
stage      Pipeline depth offset (0 = earliest stage, max = final stage)
iter_idx   Actual data iteration index processed by the task
iter_idx = period − stage
pipeline_depth = max(stage) + 1
Example: pipeline_depth=3 (SparseDist)
              stage=0(H2D)  stage=1(InputDist)  stage=2(Fwd/Bwd/Opt)
period=2:     iter 2        iter 1              iter 0
period=3:     iter 3        iter 2              iter 1
period=4:     iter 4        iter 3              iter 2
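The mapping iter_idx = period − stage also explains pipeline warm-up and drain: a stage has work only while its iteration index is valid. A sketch assuming a hypothetical helper active_iters and a finite run of num_iters iterations:

```python
from typing import Dict


def active_iters(period: int, pipeline_depth: int, num_iters: int) -> Dict[int, int]:
    """Map stage -> iter_idx for the stages that have work in this period."""
    result: Dict[int, int] = {}
    for stage in range(pipeline_depth):
        iter_idx = period - stage
        if 0 <= iter_idx < num_iters:   # skip warm-up (<0) and drain (>=num_iters)
            result[stage] = iter_idx
    return result


print(active_iters(2, 3, 10))   # {0: 2, 1: 1, 2: 0}
print(active_iters(0, 3, 10))   # {0: 0}  (warm-up: only stage 0 runs)
print(active_iters(11, 3, 10))  # {2: 9}  (drain: only the last stage runs)
```

With num_iters iterations and depth d, the run spans num_iters + d − 1 periods (12 for the values above).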

Three Corollaries & Deadlock Analysis

Corollary 1: stage constraint for intra_iter_deps. The depended-on task's stage must not exceed the dependent task's stage: stage_dep ≤ stage_task. For iteration i, dep runs in period i + stage_dep and the task in period i + stage_task; if stage_dep > stage_task, dep would run in a later period than the task waiting for it, and the task would never be released (deadlock).
Corollary 2: stage-descending submission. Within one period, a higher stage processes an older batch. _build_enqueue_order uses a stage-descending order for two reasons:
  • Correctness: an inter_iter_deps edge with k = stage_dep − stage_task = 1 places the supplier (higher stage) and the consumer (lower stage) in the same period, so the supplier must be submitted first
  • Throughput: prioritizing the older batch reduces head-of-line blocking
Corollary 3: stage-gap constraint for inter_iter_deps. The depended-on task's stage may exceed the dependent task's stage by at most 1: stage_dep − stage_task ≤ 1. Task(i) runs in period i + stage_task and dep(i-1) in period i − 1 + stage_dep, so Δperiod = period(dep) − period(task) = k − 1. If k ≥ 2, dep falls in a future period: a structural deadlock.
k = stage_dep − stage_task  Δperiod  Result                                     Example
≤ 0                         ≤ −1     ✅ dep in an earlier period, zero blocking  Same stage, cross-iter: dep completed in the previous period
1                           0        ⚠️ Same period, needs stage-descending      A(s=0) depends on B(s=1): safe if B is submitted first
≥ 2                         ≥ 1      ❌ Structural deadlock                       fwd(s=0) depends on opt(s=2): opt runs in a future period
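Both stage rules reduce to a check on k = stage_dep − stage_task, with Δperiod = k − 1 for inter_iter_deps. A sketch of such a check, assuming hypothetical names (check_plan, plain dicts instead of PipelinePlan); the framework's _validate may be organized differently:

```python
from typing import Dict, List, Tuple


def check_plan(stage: Dict[str, int],
               intra: List[Tuple[str, str]],
               inter: List[Tuple[str, str]]) -> List[str]:
    """Return a list of violations; each edge is (task, depends_on)."""
    errors: List[str] = []
    for task, dep in intra:
        # Corollary 1: same iteration, so period i+s_dep must not exceed i+s_task.
        if stage[dep] > stage[task]:
            errors.append(f"intra {task}->{dep}: dep stage {stage[dep]} > {stage[task]}")
    for task, dep in inter:
        k = stage[dep] - stage[task]
        delta_period = k - 1  # dep(i-1) runs in period i-1+s_dep; task(i) in i+s_task
        # Corollary 3: dep must not land in a future period.
        if delta_period > 0:
            errors.append(f"inter {task}->{dep}: k={k}, dep {delta_period} period(s) late")
    return errors


stages = {"fwd": 0, "bwd": 1, "opt": 2}
print(check_plan(stages, [("bwd", "fwd"), ("opt", "bwd")], [("fwd", "opt")]))
# ['inter fwd->opt: k=2, dep 1 period(s) late']
```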

Deadlock Scenarios

Scenario 1: Single-thread k=1 ascending deadlock

A(s=0, T1), B(s=1, T1), A.inter_iter_deps=[B]  →  same period
Ascending ❌:   T1: A(s=0) → wait B → BLOCKED → B is never submitted 💀
Descending ✅:  T1: B(s=1) → set ✓ → A(s=0) → already set ✓

Scenario 2: Multi-thread cross deadlock

T1: A(s=0), D(s=1), A.inter_iter_deps=[B] (B on T2)    T2: C(s=0), B(s=1), C.inter_iter_deps=[D] (D on T1)
Ascending ❌:   T1: A → wait B → BLOCKED    T2: C → wait D → BLOCKED    → mutual wait 💀
Descending ✅:  T1: D(s=1) ✓    T2: B(s=1) ✓    T1: A(s=0) ✓    T2: C(s=0) ✓
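Both scenarios can be reproduced with a tiny CPU-side model: each thread group is a FIFO queue, and a task at the head of a queue can run only once all of its same-period dependencies have run. This is a hypothetical simulation, not framework code; it reports whether an enqueue order completes or gets stuck.

```python
from typing import Dict, List


def simulate(order: List[str], thread_of: Dict[str, str],
             deps: Dict[str, List[str]]) -> bool:
    """Return True if every task runs, False on a deadlock."""
    queues: Dict[str, List[str]] = {}
    for name in order:  # dispatch in enqueue order, as _enqueue_period does
        queues.setdefault(thread_of[name], []).append(name)
    done = set()
    progressed = True
    while progressed:
        progressed = False
        for q in queues.values():
            # A worker can only look at the head of its FIFO queue.
            if q and all(d in done for d in deps.get(q[0], [])):
                done.add(q.pop(0))
                progressed = True
    return all(not q for q in queues.values())


# Scenario 2: A, D on T1; C, B on T2; A waits for B, C waits for D (same period).
thread_of = {"A": "T1", "D": "T1", "C": "T2", "B": "T2"}
deps = {"A": ["B"], "C": ["D"]}
stage = {"A": 0, "C": 0, "B": 1, "D": 1}

ascending = sorted(thread_of, key=lambda n: (stage[n], n))
descending = sorted(thread_of, key=lambda n: (-stage[n], n))
print(simulate(ascending, thread_of, deps))   # False: mutual deadlock
print(simulate(descending, thread_of, deps))  # True
```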

Scenario 3: k ≥ 2 structural deadlock

fwd(s=0), bwd(s=1), opt(s=2), fwd.inter_iter_deps=[opt]  →  k=2
           stage=2(opt)   stage=1(bwd)   stage=0(fwd)
period 3:   opt(1)        bwd(2)        fwd(3)
period 4:   opt(2)        bwd(3)        fwd(4)
fwd(3) needs opt(2):  fwd(3) @ period 3,  opt(2) @ period 4 → a future period; deadlock under any order 💀
In practice: in a 3-stage pipeline, fwd(I) naturally coexists with opt(I-2) in the same period. On a shared stream, stage-descending submission queues opt first, so no inter_iter_deps is needed. The cost is weight staleness = pipeline_depth − 1.

Example: fwd / bwd / opt Stage Arrangements

Base setup: intra_iter_deps = [(bwd, fwd), (opt, bwd)]. The arrangements are analyzed both without and with inter_iter_deps = [(fwd, opt)].

Without inter_iter_deps (accepting weight staleness)

Mode  fwd  bwd  opt  depth  Speedup  Notes
A     0    0    0    1               Serial
B     0    0    1    2               opt(i-1) ‖ fwd(i)+bwd(i)
C     0    1    1    2               bwd(i-1)+opt(i-1) ‖ fwd(i)
D     0    1    2    3               Three-way, staleness=2

With inter_iter_deps → no GPU speedup

Corollaries 1 and 3 leave only 3 legal arrangements (A/B/C), and wait_event forces every one of them to run serially:

Mode  fwd  bwd  opt  k  Speedup
A     0    0    0    0
B     0    0    1    1  ❌ fwd wait_event opt
C     0    1    1    1  ❌ fwd wait_event opt
Key insight: pipelining speeds things up precisely because inter_iter_deps is not declared. Without it, different stages truly run in parallel; with it, execution is forced serial. Advice: do not declare inter_iter_deps for fwd/bwd/opt; rely on stream FIFO order plus stage-descending submission to guarantee weight visibility.
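The count of legal arrangements follows mechanically from the two corollaries. The brute-force sketch below (hypothetical helper; stages limited to 0..2 and normalized so fwd starts at stage 0) applies Corollary 1 to the intra-iteration chain fwd ≤ bwd ≤ opt and Corollary 3 to the inter-iteration edge (fwd, opt), and recovers exactly arrangements A, B and C:

```python
from itertools import product
from typing import List, Tuple


def legal_arrangements(max_stage: int = 2) -> List[Tuple[int, int, int]]:
    legal = []
    for fwd, bwd, opt in product(range(max_stage + 1), repeat=3):
        if fwd != 0:
            continue  # normalize: the earliest task starts at stage 0
        # Corollary 1: intra_iter_deps (bwd, fwd) and (opt, bwd).
        if not (fwd <= bwd <= opt):
            continue
        # Corollary 3: inter_iter_deps (fwd, opt) needs k = opt - fwd <= 1.
        if opt - fwd > 1:
            continue
        legal.append((fwd, bwd, opt))
    return legal


print(legal_arrangements())  # [(0, 0, 0), (0, 0, 1), (0, 1, 1)]
```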

Common Pitfalls

Pitfall                                    Consequence                                               Fix
Missing cross-stream dep                   GPU data race                                             Add the dep
Missing intra-stage dep                    _topo_sort falls back to alphabetical order; wrong order  Add the dep
Circular deps                              Deadlock                                                  Ensure the deps form a DAG
dep.stage > task.stage in intra_iter_deps  Deadlock (Cor. 1)                                         Use inter_iter_deps
inter_iter_deps with k ≥ 2                 Structural deadlock (Cor. 3)                              Reduce the stage gap to ≤ 1

Ready-First Scheduling Algorithm

_build_enqueue_order uses a ready-first strategy: it builds a period-local dependency graph and topologically sorts it with (stall_cost, name) as the priority, minimizing GPU stream stalls while preserving correctness.

Problem: Why Stage-Descending Is Not Optimal

GPU streams are FIFO queues. Consider this setup:

Task   stage   stream   deps
────────────────────────────────────────────────────────────────────────
P      0       X        none
Q      1       Y        intra_iter_deps=[P]    cross-stage (0→1): P finished in an earlier period
R      0       X        inter_iter_deps=[Q]    k=1: Q(i-1) runs in the same period, on another stream

P and R are both at stage 0 on stream X. Stage-descending submits Q (stage 1) first, then P and R (stage 0). The problem: R depends on Q across streams (Y→X), so submitting R inserts stream_X.wait_event(Q). Stage-descending does not order tasks within a stage, so if R is queued before P, P idles behind R's wait_event:

Suboptimal (R before P):  stream_X: [R.wait_event(Q)] → [R kernels] → [P kernels]
                                                                       ↑ P idles

Optimal (P before R):     stream_X: [P kernels] → [R.wait_event(Q)] → [R kernels]
                                     ↑ P runs immediately
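The effect of queue order on a FIFO stream can be quantified with a toy timing model. Everything here is hypothetical: the helper name, and durations in which Q finishes on stream Y at t=5 while P and R each take 2 time units on stream X, with R gated by Q's event.

```python
from typing import Dict, List


def stream_x_schedule(order: List[str], durations: Dict[str, float],
                      event_time: Dict[str, float]) -> Dict[str, float]:
    """Start times on one FIFO stream; event_time[t] is when t's wait_event clears."""
    clock = 0.0
    starts: Dict[str, float] = {}
    for name in order:
        # FIFO: a task starts after its predecessor and after its wait_event.
        clock = max(clock, event_time.get(name, 0.0))
        starts[name] = clock
        clock += durations[name]
    return starts


durations = {"P": 2.0, "R": 2.0}
event_time = {"R": 5.0}  # Q finishes on stream Y at t=5

print(stream_x_schedule(["R", "P"], durations, event_time))  # {'R': 5.0, 'P': 7.0}
print(stream_x_schedule(["P", "R"], durations, event_time))  # {'P': 0.0, 'R': 5.0}
```

In the first order stream X does nothing until t=5 and finishes at t=9; in the second, P fills part of the wait and the stream finishes at t=7.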

GPU stall defined: stream.wait_event(event) makes all later work on the stream wait until the work captured by event has finished. If that event belongs to a task that is still executing on another stream in the same period, the stream idles; this idle time is the stall.

Stage-descending limitation: it sorts only by stage and ignores stall_cost when ordering tasks within a stage. Ready-first fixes this by prioritizing tasks with no in-period cross-stream deps (stall_cost=0).

Core Idea: Period-Local Dependency Graph + Stall Cost Priority

The algorithm has two steps:

Step 1: Build the period-local dependency graph

Include only dependency edges whose two ends run in the same period (see the classification under Corollary 3):

Edge type        Condition                   Edge direction
intra_iter_deps  stage_dep = stage_task      dep → task
inter_iter_deps  stage_dep = stage_task + 1  dep → task

Excluded edges: intra_iter_deps where stage_dep < stage_task, and inter_iter_deps where stage_dep ≤ stage_task. In both cases the dep completed in an earlier period: its threading.Event is already set and its cuda_event already recorded, so waiting on it never blocks the CPU.

Step 2: Stall-cost-aware topological sort

Topologically sort the period-local graph. When several tasks are ready at once (in-degree = 0), pick them in ascending stall_cost order:

stall_cost(task) = number of cross-stream in-edges of the task in the period-local graph (edges whose dep runs on a different stream than the task)

Sort key: (stall_cost, task_name)
  → tasks with stall_cost = 0 go first: they insert no wait_event and cannot block later tasks on the same stream
  → task_name breaks ties, which keeps the order deterministic (identical across ranks)

Pseudocode (a method of the pipeline class; _defs, _stage_map and _stream_map are keyed by task name):

import heapq
from collections import defaultdict
from typing import Dict, List

def _build_enqueue_order_ready_first(self) -> List[PipelineTask]:
    # --- Step 1: build period-local graph ---
    adj:       Dict[str, List[str]] = defaultdict(list)   # dep → [tasks]
    in_degree: Dict[str, int]       = {t.name: 0 for t in self._defs.values()}

    for task_name, dep_list in self._intra_iter_deps.items():
        task_stage = self._stage_map[task_name]
        for dep_name in dep_list:
            if self._stage_map[dep_name] == task_stage:   # same stage → same period
                adj[dep_name].append(task_name)
                in_degree[task_name] += 1

    for task_name, dep_list in self._inter_iter_deps.items():
        task_stage = self._stage_map[task_name]
        for dep_name in dep_list:
            if self._stage_map[dep_name] == task_stage + 1:  # k=1 → same period
                adj[dep_name].append(task_name)
                in_degree[task_name] += 1

    # --- Step 2: compute stall_cost ---
    stall_cost: Dict[str, int] = {t.name: 0 for t in self._defs.values()}
    for task_name, dep_list in self._intra_iter_deps.items():
        task_stream = self._stream_map[task_name]
        task_stage  = self._stage_map[task_name]
        for dep_name in dep_list:
            if self._stage_map[dep_name] == task_stage \
               and self._stream_map[dep_name] != task_stream:
                stall_cost[task_name] += 1        # in-period cross-stream dep

    for task_name, dep_list in self._inter_iter_deps.items():
        task_stream = self._stream_map[task_name]
        task_stage  = self._stage_map[task_name]
        for dep_name in dep_list:
            if self._stage_map[dep_name] == task_stage + 1 \
               and self._stream_map[dep_name] != task_stream:
                stall_cost[task_name] += 1        # in-period cross-stream dep

    # --- Step 3: topo sort; a min-heap pops the smallest (stall_cost, name) ---
    ready = [(stall_cost[n], n) for n, d in in_degree.items() if d == 0]
    heapq.heapify(ready)
    order: List[str] = []
    while ready:
        _, name = heapq.heappop(ready)
        order.append(name)
        for succ in adj[name]:
            in_degree[succ] -= 1
            if in_degree[succ] == 0:
                heapq.heappush(ready, (stall_cost[succ], succ))

    # Tasks left over means a cycle in the period-local graph.
    if len(order) != len(in_degree):
        raise RuntimeError(f"dependency cycle among: {sorted(set(in_degree) - set(order))}")

    return [self._defs[n] for n in order]
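For experimentation outside the class, the same ready-first procedure can be written as a free function over plain dicts. This is a sketch: the function name and input shapes ({task: [deps]}) are assumptions, stall_cost is accumulated while edges are added (equivalent to the separate Step 2 above), and a cycle raises instead of returning a partial order. Applied to the TrainPipelineSparseDist configuration (8 tasks, 3 streams) it prints the order in the comment.

```python
import heapq
from collections import defaultdict
from typing import Dict, List, Optional


def build_enqueue_order(stage: Dict[str, int],
                        stream: Dict[str, Optional[str]],
                        intra: Dict[str, List[str]],
                        inter: Dict[str, List[str]]) -> List[str]:
    adj: Dict[str, List[str]] = defaultdict(list)
    in_degree = {n: 0 for n in stage}
    stall_cost = {n: 0 for n in stage}

    def add_edge(dep: str, task: str) -> None:
        adj[dep].append(task)
        in_degree[task] += 1
        if stream[dep] != stream[task]:
            stall_cost[task] += 1   # in-period cross-stream wait_event

    for task, dep_list in intra.items():
        for dep in dep_list:
            if stage[dep] == stage[task]:        # same stage -> same period
                add_edge(dep, task)
    for task, dep_list in inter.items():
        for dep in dep_list:
            if stage[dep] == stage[task] + 1:    # k = 1 -> same period
                add_edge(dep, task)

    ready = [(stall_cost[n], n) for n, d in in_degree.items() if d == 0]
    heapq.heapify(ready)
    order: List[str] = []
    while ready:
        _, name = heapq.heappop(ready)
        order.append(name)
        for succ in adj[name]:
            in_degree[succ] -= 1
            if in_degree[succ] == 0:
                heapq.heappush(ready, (stall_cost[succ], succ))
    if len(order) != len(stage):
        raise ValueError("cycle in period-local dependency graph")
    return order


stage = {"H2D": 0, "InputDistStart": 1, "InputDistWait": 1, "ZeroGrad": 2,
         "WaitBatch": 2, "Forward": 2, "Backward": 2, "OptimizerStep": 2}
stream = {"H2D": "memcpy", "InputDistStart": "data_dist", "InputDistWait": "data_dist"}
stream.update({n: "default" for n in stage if n not in stream})
intra = {"InputDistStart": ["H2D"], "InputDistWait": ["InputDistStart"],
         "WaitBatch": ["InputDistWait", "ZeroGrad"],
         "Forward": ["InputDistWait", "WaitBatch"],
         "Backward": ["Forward"], "OptimizerStep": ["Backward"]}
print(build_enqueue_order(stage, stream, intra, {}))
# ['H2D', 'InputDistStart', 'InputDistWait', 'ZeroGrad',
#  'WaitBatch', 'Forward', 'Backward', 'OptimizerStep']
```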

Worked Example: TrainPipelineSparseDist (3-stage)

Configuration recap (8 tasks, 3 streams, 3 stages):

Task             stage  stream        intra_iter_deps               inter_iter_deps
─────────────────────────────────────────────────────────────────────────
H2D              0      memcpy        —                             —
InputDistStart   1      data_dist     H2D (stage 0→1, cross-stage)   —
InputDistWait    1      data_dist     InputDistStart                —
ZeroGrad         2      default       —                             —
WaitBatch        2      default       InputDistWait (stage 1→2),    —
                                      ZeroGrad
Forward          2      default       InputDistWait (stage 1→2),    —
                                      WaitBatch
Backward         2      default       Forward                       —
OptimizerStep    2      default       Backward                      —

Step 1: Period-Local Graph

Keep only intra_iter_deps edges where stage_dep = stage_task (this configuration has no inter_iter_deps):

Retained edges (same stage = same period):
  InputDistStart → InputDistWait        (stage 1 → 1) ✓
  ZeroGrad       → WaitBatch            (stage 2 → 2) ✓
  WaitBatch      → Forward              (stage 2 → 2) ✓
  Forward        → Backward             (stage 2 → 2) ✓
  Backward       → OptimizerStep        (stage 2 → 2) ✓

Excluded edges (cross-stage = cross-period; dep completed in an earlier period):
  H2D → InputDistStart                  (stage 0 → 1) ✗
  InputDistWait → WaitBatch             (stage 1 → 2) ✗
  InputDistWait → Forward               (stage 1 → 2) ✗

Step 2: Compute stall_cost

stall_cost is the number of cross-stream in-edges in the period-local graph. Every retained edge stays within one stream:

Task             period-local in-edges        cross-stream?   stall_cost
──────────────────────────────────────────────────────────────────────
H2D              none                         n/a             0
InputDistStart   none                         n/a             0
InputDistWait    InputDistStart (data_dist)   same stream     0
ZeroGrad         none                         n/a             0
WaitBatch        ZeroGrad (default)           same stream     0
Forward          WaitBatch (default)          same stream     0
Backward         Forward (default)            same stream     0
OptimizerStep    Backward (default)           same stream     0

Every stall_cost is 0 because the cross-stage deps, which here are also the cross-stream ones, are excluded: they come from earlier periods whose GPU events have already been recorded.

Step 3: Topological Sort

Initial ready set (in-degree = 0, sorted by (stall_cost, name)):
  [(0,"H2D"), (0,"InputDistStart"), (0,"ZeroGrad")]

Dequeue process (the ready list is re-sorted after every release, so a newly released task can overtake tasks that were ready earlier):
  1. H2D            → no successors in the period-local graph
  2. InputDistStart → releases InputDistWait, which sorts ahead of ZeroGrad ("I" < "Z")
  3. InputDistWait  → no period-local successors
  4. ZeroGrad       → releases WaitBatch
  5. WaitBatch      → releases Forward
  6. Forward        → releases Backward
  7. Backward       → releases OptimizerStep
  8. OptimizerStep

Final submission order:

 #  Task             stream        stage   stall_cost
 1. H2D              memcpy        0       0
 2. InputDistStart   data_dist     1       0
 3. InputDistWait    data_dist     1       0
 4. ZeroGrad         default       2       0
 5. WaitBatch        default       2       0
 6. Forward          default       2       0
 7. Backward         default       2       0
 8. OptimizerStep    default       2       0
Comparison with stage-descending: stage-descending produces [ZeroGrad, WaitBatch, Forward, Backward, OptimizerStep, InputDistStart, InputDistWait, H2D]. The two orders are equivalent in this example, because every cross-stage dep comes from an earlier period and causes no stall. Ready-first pays off in scenarios with inter_iter_deps (k=1).

Advanced Example: Same Stream, Different Stages + inter_iter_deps

Consider a hypothetical pipeline in which one stream hosts tasks from different stages and an inter_iter_deps edge with k=1 exists:

Setup:
  P: stage=0, stream=X                          # stall_cost = 0
  Q: stage=1, stream=Y, intra_iter_deps=[P]     # cross-stage: excluded from the period-local graph
  R: stage=0, stream=X                          # no deps (an intra dep on Q would violate Corollary 1)
  A: stage=0, stream=Z, inter_iter_deps=[Q]     # k=1: Q (stage 1) enters the period-local graph

Period-local graph (only the k=1 inter_iter_deps edge):

  Q(stage 1) ──→ A(stage 0)       # inter_iter_deps, k=1, same period
  P, R: no period-local edges

stall_cost:

  P: 0    (no period-local in-edge)
  Q: 0    (no period-local in-edge)
  R: 0    (no period-local in-edge)
  A: 1    (Q→A crosses streams: Y→Z)

Topological sort:

Initial ready: [(0,"P"), (0,"Q"), (0,"R")]    A has in-degree 1 and is not ready yet

  1. P  (stall=0, stream=X) → runs immediately, waits on nothing
  2. Q  (stall=0, stream=Y) → runs immediately; releases A
  3. R  (stall=0, stream=X) → runs immediately (stall 0 beats A's stall 1; queued behind P on X, stream FIFO keeps the order)
  4. A  (stall=1, stream=Z) → wait_event(Q); Q is already submitted, so the delay is minimal
Final order:  P(X,s0) → Q(Y,s1) → R(X,s0) → A(Z,s0)

GPU timeline:
  stream_X: ▓▓P▓▓▓▓R▓▓▓▓▓▓           ← P and R run back-to-back, zero stall
  stream_Y: ▓▓Q▓▓▓▓▓▓▓▓▓▓▓
  stream_Z: ░░░░░░▓▓A▓▓▓▓▓           ← A waits for Q, then runs
Stage-descending comparison: stage-descending submits Q (stage 1) first, then P, R, A (stage 0), giving Q → P → R → A. Here P is delayed: Q runs on a different stream and P does not need Q, yet stage-descending forces Q to be submitted first. Ready-first puts P and Q in the ready list at the same time, and P dequeues first by name.

Correctness Guarantee: Why Ready-First Prevents Deadlock

Ready-first is not merely a performance optimization; it also rules out deadlock structurally. Deadlock requires a circular wait: task X waits for Y while Y directly or indirectly waits for X. Ready-first eliminates circular waits with two layers of guarantees:

Dependency source           Why no deadlock
Cross-period                The dependency ran in an earlier period: its threading.Event is already set and its cuda_event already recorded, so waits return immediately with zero blocking.
Same period (period-local)  Topological sort guarantees that if X depends on Y (period-local edge Y→X), Y is enqueued before X. In a FIFO work queue the worker picks up Y first, so when X waits for Y, Y is already running or completed.

Multi-threaded scenario: each thread group has its own FIFO queue. If task X sits at the head of a queue and blocks waiting for a task Y further back in the same queue, the result is a head-of-line deadlock. Ready-first's topological sort always places Y ahead of X, which eliminates this deadlock pattern.

❌ Potentially deadlocking enqueue order (not topologically sorted):
  thread_queue: [X, Y]      X dequeues first and waits for Y's CPU signal → Y never dequeues → deadlock

✓ Enqueue order guaranteed by ready-first (topologically sorted):
  thread_queue: [Y, X]      Y dequeues first → sets its CPU signal → X dequeues and finds the signal already set

Formal argument: _validate enforces Corollaries 1 and 3 (intra_iter_deps satisfy dep.stage ≤ task.stage; inter_iter_deps satisfy dep.stage − task.stage ≤ 1), so every dependency is satisfied in the same period or an earlier one. Provided the declared dependencies contain no cycle (see Common Pitfalls), the period-local dependency graph is a DAG, and every DAG admits a topological order. Enqueuing in that order means no task ever waits on a task queued behind it, so the ready-first enqueue order is deadlock-free.

完整性证明:排除的跨 stage 边不影响正确性Completeness Proof: Excluded Cross-Stage Edges Do Not Affect Correctness

Period-local 图排除intra_iter_depsstagedep < stagetask 的边。这是否意味着这些同 batch 的依赖得不到保证?以下定理证明不会。 The period-local graph excludes intra_iter_deps edges where stagedep < stagetask. Does this mean these same-batch dependencies are unsatisfied? The following theorem proves they are not.

定理:对于任意 intra_iter_deps 边 (T, D)(T 依赖 D,同一迭代 i),ready-first 入队顺序保证 D 的 CPU signal 和 CUDA event 在 T 执行时已可用。 Theorem: For any intra_iter_deps edge (T, D) where T depends on D within iteration i, the ready-first enqueue order guarantees D's CPU signal and CUDA event are available when T executes.

证明:设 D 的 stage 为 sD,T 的 stage 为 sT。推论 1(_validate 强制)要求 sD ≤ sT。分三种情况: Proof: Let sD and sT be the stages of D and T. Corollary 1 (_validate enforced) requires sD ≤ sT. Three cases:

Case A (s_D = s_T)
D and T are in the same period P = i + s_D. Edge D→T is in the period-local graph, so the topological sort enqueues D before T.
· Single-threaded: D is CPU-submitted before T. ✓
· Multi-threaded, same thread group: FIFO order ensures the worker picks up D first. ✓
· Multi-threaded, different thread groups: T's worker calls cpu_signal(i, D).wait() and blocks. D's dependencies come from earlier periods or from upstream tasks in the same period (guaranteed by the topological order), and D does not depend on T, so there is no circular wait and D eventually completes. ✓

Case B (s_D < s_T)
D runs in period P_D = i + s_D and T in period P_T = i + s_T, with P_D < P_T. The edge is not in the period-local graph.

Key: periods are processed in order, so _enqueue_period(P_T) is called only after _enqueue_period(P_D) has completed.

Sub-proof (D is never blocked by a task in P_T): each dependency of D is either in an earlier period (< P_D, already enqueued) or a same-period task in P_D, which Case A covers. By Corollary 1 (dep.stage ≤ task.stage) and Corollary 3 (dep.stage - task.stage ≤ 1 for a dependency on iteration i-1), every dependency of a task in period P lies in a period ≤ P, so D's dependency chain never reaches any task in period P_T > P_D.

Therefore D eventually completes: once cpu_signal(i, D) is set, T's sig.wait() returns (immediately if D has already finished), and because cuda_event(i, D) has been recorded, T's stream.wait_event() enforces GPU ordering. ✓

Case C (s_D > s_T)
Forbidden by _validate (Corollary 1). D would run in period P_D = i + s_D > P_T, so when T executes D has not been enqueued yet and cpu_signal(i, D).wait() blocks forever: a structural deadlock. This is precisely why Corollary 1 exists.

Concrete example (FusedSparseDist): InputDistStart (stage=1) depends on H2D (stage=0), which is Case B.

Iteration i = 3:
  H2D(i=3):            stage=0, period = 3+0 = 3
  InputDistStart(i=3): stage=1, period = 3+1 = 4

Timeline:
  period 3: enqueue H2D(i=3)             →  cpu_signal(3,"H2D").set()  →  cuda_event(3,"H2D").record()
  period 4: enqueue InputDistStart(i=3)
            → cpu_signal(3,"H2D").wait()    →  already set, returns immediately
            → stream.wait_event(cuda_event) →  already recorded, GPU guarantees order

QED: Corollary 1 rules out Case C; Case A is guaranteed by the period-local topological sort; Case B is guaranteed by sequential period processing. Together, ready-first is correct for all intra_iter_deps. The proof for inter_iter_deps is analogous (substitute Corollary 3 for Corollary 1; the proof structure is identical).
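The two stage constraints the proof relies on are easy to state as a check. The helper below is a hypothetical sketch of what a _validate-style pass might verify; the real _validate also checks acyclicity and name uniqueness, which this sketch omits. Tasks are plain names mapped to stages, and `pipeline_depth` follows the max(stage)+1 rule stated in the end-to-end example.

```python
from typing import Dict, List, Sequence, Tuple

Edge = Tuple[str, str]  # (task, dep): task depends on dep


def check_stage_constraints(stage: Dict[str, int],
                            intra_iter_deps: Sequence[Edge],
                            inter_iter_deps: Sequence[Edge]) -> List[str]:
    """Return violations of Corollary 1 and Corollary 3 (an empty list means valid)."""
    errors = []
    for task, dep in intra_iter_deps:
        # Corollary 1: a same-iteration dependency must not come from a later stage.
        if stage[dep] > stage[task]:
            errors.append(f"intra: {task}(stage {stage[task]}) <- {dep}(stage {stage[dep]})")
    for task, dep in inter_iter_deps:
        # Corollary 3: a dependency on iteration i-1 may be at most one stage later.
        if stage[dep] - stage[task] > 1:
            errors.append(f"inter: {task}(stage {stage[task]}) <- {dep}(stage {stage[dep]})")
    return errors


def pipeline_depth(stage: Dict[str, int]) -> int:
    """Depth inferred as max(stage) + 1."""
    return max(stage.values()) + 1
```

For the FusedSparseDist plan at the end of this page, both lists come back clean and the depth is 3, while a Case C edge such as `check_stage_constraints({"T": 0, "D": 1}, [("T", "D")], [])` reports one violation.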

SWPipeline

Class
class SWPipeline(plan: PipelinePlan, device: int = 0)
General-purpose software pipeline engine. Given a PipelinePlan, it automatically performs the topological sort, allocates synchronization resources, and creates a _SubmissionSequencer. It supports three execution modes (pipelined, serial, single-iteration) plus runtime task shortcutting. The DAG is validated at construction (acyclic, monotone stages, unique names).
Param    Type           Description
plan     PipelinePlan   Complete scheduling plan
device   int            CUDA device index (default 0)
Three Execution Modes
pipe = SWPipeline(plan, device=0)

# Mode 1: run() — one-shot pipelined execution
elapsed = pipe.run(dataloader, verbose=True)

# Mode 2: run_serial() — serial baseline
baseline = pipe.run_serial(dataloader, verbose=True)

# Mode 3: fill_pipeline + progress — fine-grained control
data_iter = pipe.fill_pipeline(dataloader)
while True:
    try:
        idx = pipe.progress(data_iter)
    except StopIteration:
        break

Execution Sequence Diagram

fill_pipeline() followed by the progress() loop, with depth=2 and 3 batches. Stage 0: CopyBatch (IO Worker); Stage 1: Forward + Backward (Comp Worker).

[Figure: sequence diagram with lanes IterContext, Main Thread, IO Worker, and Comp Worker. fill_pipeline(data_iter) creates ctx₀ and dispatches CopyBatch(0), then creates ctx₁ and dispatches Fwd(0), Bwd(0), and CopyBatch(1); CopyBatch performs the H2D copy and sets its cpu_sig + cuda_evt, and Bwd(0) sets iter_complete[0]. Each progress(data_iter) call waits on iter_complete[k] (plus stream.wait_event()), deletes ctx_k, creates the next context if data remains, dispatches the next period, and returns k, so Stage 0 of one iteration overlaps Stage 1 of the previous one. The final progress() deletes ctx₂, calls stop_workers(), and raises StopIteration.]

SWPipeline Methods

fill_pipeline

Method
fill_pipeline(data_iter) → Iterator
Pre-fills the pipeline: prefetches depth batches, starts the worker threads, and submits the first depth periods. Must be called before the first progress(). Accepts any iterable.
Param       Type       Description
data_iter   Iterable   Data source

Returns: Iterator

Raises: RuntimeError (called twice without an intervening drain())

data_iter = pipe.fill_pipeline(train_loader)
# Now ready for pipe.progress(data_iter)

progress

Method
progress(data_iter: Optional[Iterator] = None) → int
Advances one step: (1) waits for the oldest in-flight iteration to complete, (2) enqueues the next period, (3) returns the completed iter_idx. Passing data_iter=None retires an iteration without enqueuing new data. If the wait exceeds 60 s, RuntimeError is raised.

Returns: int · Raises: StopIteration, RuntimeError

data_iter = pipe.fill_pipeline(dataloader)
while True:
    try:
        idx = pipe.progress(data_iter)
    except StopIteration:
        break

run

Method
run(data_iter: Iterator, verbose: bool = False, emit_nvtx: bool = False) → float
One-shot pipelined execution: consumes the entire data_iter and returns the wall-clock time in seconds. Internally it is fill_pipeline followed by a progress loop. NVTX ranges use the prefix SWP/{TaskName}/iter{N}. emit_nvtx=True enables torch.autograd.profiler.emit_nvtx() so that backward kernels are visible in nsys.

Returns: float (seconds, including cuda.synchronize)

elapsed = pipe.run(batches, verbose=True, emit_nvtx=True)
print(f"Pipeline: {elapsed*1000:.1f} ms")

run_serial

Method
run_serial(data_iter: Iterator, verbose: bool = False, emit_nvtx: bool = False) → float
Runs all iterations serially on the default stream (no pipelining, no worker threads) as a performance baseline. NVTX ranges use the prefix SWP_serial/{TaskName}/iter{N}; shortcutted tasks get a [skip] suffix.

Returns: float

serial   = pipe.run_serial(make_data(100), verbose=True)
pipeline = pipe.run(make_data(100), verbose=True)
print(f"Speedup: {serial / pipeline:.2f}x")

run_one_serial_iter

Method
run_one_serial_iter(batch: Any, iter_idx: int = 0) → None
Runs a single iteration on the default stream, executing all tasks in topological order. Shortcut state is respected: shortcutted tasks skip fn(ctx) and replay their cache. No fill_pipeline() call is needed and no worker threads are involved. The call is atomic and leaves no in-flight state.
batch = next(iter(train_loader))
pipe.run_one_serial_iter(batch, iter_idx=0)

# With shortcut
pipe.enable_shortcut("Forward")
pipe.run_one_serial_iter(batch, iter_idx=1)  # caching
pipe.run_one_serial_iter(batch, iter_idx=2)  # shortcut (Forward skipped)

enable_shortcut

Method
enable_shortcut(*task_names: str) → None
Enables shortcut mode for the named tasks. The first execution after enabling is a caching run (runs fn normally, snapshots its outputs, and calls DeclaredIO.capture()); subsequent executions are shortcut runs (skip fn, replay the cache, and call DeclaredIO.restore()). globally_ordered tasks still go through _SubmissionSequencer to preserve submission order. All ranks must enable and disable the same tasks in sync.

Raises: ValueError (unknown task name)

pipe.enable_shortcut("EmbLookup")
pipe.enable_shortcut("EmbLookup", "Prefetch", "MPEmbForward")

disable_shortcut

Method
disable_shortcut(*task_names: str) → None
Disables shortcut mode and clears the cache. The tasks resume normal execution with no extra overhead.
pipe.disable_shortcut("EmbLookup")

drain

Method
drain() → None
Pipelined mode: retires all in-flight iterations → stops the workers → synchronizes CUDA → _reset_state(). fill_pipeline() can then be called again.
Serial mode (run_one_serial_iter): only cuda.synchronize(), since serial iterations are atomic and leave no in-flight state.
Shortcut configuration and caches are preserved; call disable_shortcut() to clear them.
# Switch data iterator in pipelined mode
pipe.drain()
data_iter = pipe.fill_pipeline(new_loader)

# Change shortcut config mid-pipeline
pipe.drain()
pipe.enable_shortcut("Forward")
data_iter = pipe.fill_pipeline(loader)

print_schedule

Method
print_schedule(num_iterations: int = 3) → None
Prints a period × task schedule table with one task per row. The columns show the thread, the stream, and which iteration runs in each period. Shortcutted tasks have a [skip] suffix.
pipe.print_schedule(3)
#    #  Task                      Thread   Stream   | P0   P1   P2
#   --  ------------------------  -------  -------  + ---- ---- ----
#    0  H2DAndShuffle             default  default  | i0   i1   i2
#    1  EmbInputDistStart [skip]  default  default  |  .    .    .
#    2  EmbInputDistWait          default  default  | i0   i1   i2

format_schedule

Method
format_schedule(num_iterations: int = 3) → str
Same as print_schedule, but returns the string instead of printing it. Useful for logging or further processing.
logger.info(pipe.format_schedule(5))

print_stage_analysis

Method
print_stage_analysis() → None
Prints the topological order, the stage assignment, the Dilworth longest decreasing subsequence (LDS), and an optimality check that passes when LDS == depth.
pipe.print_stage_analysis()
# Task schedule (topological order):
#   Pos  Task              Idx  Stage  Type
#   --------------------------------------------------
#     0  CopyBatch           0      0  stream
#     1  Forward             1      1  stream
#     ...
# LDS = 2  (example: (1, 0))
# Pipeline depth = 2
#   ✓  LDS == depth (Dilworth-optimal)
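For reference, an LDS length can be computed in O(n log n) with patience sorting. This is a generic sketch: which integer sequence print_stage_analysis() feeds into the LDS is not spelled out here, so the inputs below are illustrative only. By Dilworth's theorem in its sequence form, the length of the longest strictly decreasing subsequence equals the minimum number of non-decreasing subsequences the sequence can be partitioned into.

```python
from bisect import bisect_left
from typing import List, Sequence


def lds_length(seq: Sequence[int]) -> int:
    """Length of the longest strictly decreasing subsequence (patience sorting)."""
    # A strictly decreasing subsequence of seq is a strictly increasing one of -seq.
    tails: List[int] = []  # tails[k]: smallest tail of an increasing run of length k+1 in -seq
    for x in seq:
        y = -x
        k = bisect_left(tails, y)  # bisect_left keeps the increase strict
        if k == len(tails):
            tails.append(y)
        else:
            tails[k] = y
    return len(tails)
```

For example, `lds_length([1, 0])` is 2, matching the shape of the `(1, 0)` example in the sample output above.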

__repr__

Method
__repr__() → str
Returns a readable summary: device, depth, task list, active shortcuts, and the full schedule table.
print(pipe)
# SWPipeline(device=0, depth=1, tasks=[...], shortcuts=['MPEmbForward'])
#
#    #  Task                      Thread   Stream   | P0
#   --  ------------------------  -------  -------  + ---
#    0  H2DAndShuffle             default  default  | i0

TaskProfiler

Class
class TaskProfiler(pipeline: SWPipeline)
Per-task exposed-time measurement built on the SWPipeline shortcut mechanism: each task is shortcutted in turn and exposed(T) = baseline − serial_with_T_shortcut is measured. The results drive auto-scheduling decisions.
profiler = TaskProfiler(pipe)
result = profiler.profile(batch, num_warmup=5, num_measure=20)
result.print_report()

profile

Method
profile(batch: Any, num_warmup: int = 3, num_measure: int = 10, num_rounds: int = 3, skip_tasks: Optional[set] = None) → ProfileResult
Four-phase profiling: (1) warmup; (2) baseline, taken as the median over num_rounds rounds; (3) per-task shortcut measurement; (4) exposed = max(0, baseline - median_shortcut). Each round issues only two cuda.synchronize calls. skip_tasks excludes tasks that cannot be shortcut independently (e.g. LossBackward).
Param         Type            Default    Description
batch         Any             required   Representative batch
num_warmup    int             3          Warmup iterations
num_measure   int             10         Iterations per round
num_rounds    int             3          Number of rounds
skip_tasks    Optional[set]   None       Task names to exclude

Returns: ProfileResult

result = profiler.profile(batch, num_warmup=5, num_measure=20, num_rounds=5)
result.print_report()

# Access raw data
for name, t in result.exposed_s.items():
    print(f"{name}: {t*1e3:.3f} ms")
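The measurement scheme described above can be sketched without a pipeline: time num_measure iterations between exactly two synchronization calls, repeat for num_rounds rounds, take the median, then clamp the difference at zero. `time_rounds` and `exposed_times` are hypothetical helpers, and `step` and `sync` are stand-ins; with CUDA, `sync` would presumably be torch.cuda.synchronize.

```python
import time
from statistics import median
from typing import Callable, Dict


def time_rounds(step: Callable[[], None], sync: Callable[[], None],
                num_measure: int = 10, num_rounds: int = 3) -> float:
    """Median per-iteration wall time (seconds) over num_rounds rounds.

    Each round calls sync() exactly twice, before the first step and after the last,
    so queued GPU work is included without synchronizing every iteration.
    """
    per_iter = []
    for _ in range(num_rounds):
        sync()
        t0 = time.perf_counter()
        for _ in range(num_measure):
            step()
        sync()
        per_iter.append((time.perf_counter() - t0) / num_measure)
    return median(per_iter)


def exposed_times(baseline_s: float, shortcut_s: Dict[str, float]) -> Dict[str, float]:
    """exposed(T) = max(0, baseline - time measured with T shortcut)."""
    return {name: max(0.0, baseline_s - t) for name, t in shortcut_s.items()}
```

Clamping at zero keeps noise from producing negative exposed times when shortcutting a fully hidden task does not make the iteration any faster.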

profile_many

Method
profile_many(batches: list, num_warmup: int = 3, num_measure: int = 10, num_rounds: int = 3) → List[ProfileResult]
Profiles each batch independently. Useful for analyzing how exposed time varies across data distributions.
results = profiler.profile_many([batch_a, batch_b, batch_c])
for i, r in enumerate(results):
    print(f"\n--- Batch {i} ---")
    r.print_report()

ProfileResult.print_report

Method
ProfileResult.print_report() → None
Prints the baseline, each task's exposed time (in ms and as a percentage of the baseline), and the total. A SUM above 100% means tasks overlap on the critical path.
result.print_report()
# Baseline serial iteration: 8.503 ms
#
#   Task                       Exposed   % baseline
#   ------------------------ ---------- ------------
#   CopyBatch                  0.521ms        6.1%
#   EmbLookup                  1.234ms       14.5%
#   ...
#   ------------------------ ---------- ------------
#   SUM                        7.571ms       89.0%
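The "% baseline" column is exposed time divided by the baseline serial iteration time. A hypothetical helper applied to the numbers in the sample report (baseline 8.503 ms) reproduces the printed percentages:

```python
def pct_of_baseline(exposed_ms: float, baseline_ms: float) -> float:
    """Exposed time as a percentage of the baseline, rounded to one decimal like the report."""
    return round(100.0 * exposed_ms / baseline_ms, 1)


baseline_ms = 8.503
for name, exposed_ms in [("CopyBatch", 0.521), ("EmbLookup", 1.234), ("SUM", 7.571)]:
    print(f"{name:<12} {exposed_ms:.3f}ms {pct_of_baseline(exposed_ms, baseline_ms):>6.1f}%")
```

This prints 6.1%, 14.5%, and 89.0%, matching the sample report.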

Internals

The following are internal implementation details that normal users do not need to call directly. Understanding them helps with debugging and with reasoning about shortcut behavior.

_exec_task

Internal
_exec_task(task_def: PipelineTask, ctx: IterContext) → None
Core task dispatch logic. It switches among three modes based on shortcut state:
# Simplified logic:
def _exec_task(self, task_def, ctx):
    name = task_def.name
    if name in self._shortcut_tasks:
        if name in self._shortcut_cache:
            self._apply_shortcut(ctx, task_def)      # Mode 1: replay
            return
        before_ids = {k: id(v) for k, v in vars(ctx).items()}
        task_def.fn(ctx)
        self._capture_and_cache(ctx, task_def, before_ids)  # Mode 2: cache
        return
    task_def.fn(ctx)                                  # Mode 3: normal
Mode 1 (Shortcut active): skips fn(ctx); restores the cached ctx attributes, attaches the _GraftGrad gradient bridge, and calls DeclaredIO.restore().
Mode 2 (Caching): runs fn(ctx) normally, diffs the ctx attributes, recursively detaches and clones them via _cache_val, calls DeclaredIO.capture(), and stores the result in the cache.
Mode 3 (Normal): runs fn(ctx) directly with zero extra overhead.
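The three modes can be reproduced in a self-contained toy. Everything here is hypothetical and simplified: `ToyDispatcher`, `ToyTask`, and `ToyIO` stand in for the engine, PipelineTask, and DeclaredIO; copy.deepcopy stands in for _cache_val/_restore_val; and there is no _GraftGrad. It shows the id()-based attribute diff that Mode 2 relies on and the capture/restore round trip.

```python
import copy
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import Any, Callable, Dict, List


@dataclass
class ToyIO:  # stand-in for DeclaredIO
    capture: Callable[[], Any]
    restore: Callable[[Any], None]


@dataclass
class ToyTask:  # stand-in for PipelineTask
    name: str
    fn: Callable[[Any], None]
    io: List[ToyIO] = field(default_factory=list)


class ToyDispatcher:
    def __init__(self) -> None:
        self.shortcut_tasks: set = set()
        self.cache: Dict[str, tuple] = {}

    def exec_task(self, task: ToyTask, ctx: SimpleNamespace) -> None:
        if task.name not in self.shortcut_tasks:
            task.fn(ctx)  # Mode 3: normal, no bookkeeping
            return
        if task.name in self.cache:  # Mode 1: replay; fn is skipped entirely
            attrs, io_vals = self.cache[task.name]
            for k, v in attrs.items():
                setattr(ctx, k, copy.deepcopy(v))
            for io, val in zip(task.io, io_vals):
                io.restore(copy.deepcopy(val))
            return
        # Mode 2: run fn, then cache every attribute that is new or was rebound.
        # (An id()-based diff cannot see in-place mutation of an existing object.)
        before_ids = {k: id(v) for k, v in vars(ctx).items()}
        task.fn(ctx)
        changed = {k: copy.deepcopy(v) for k, v in vars(ctx).items()
                   if before_ids.get(k) != id(v)}
        io_vals = [copy.deepcopy(io.capture()) for io in task.io]
        self.cache[task.name] = (changed, io_vals)
```

Usage mirrors the DeclaredIO example earlier on this page: a task that writes `ctx.emb` and a `shared` dict runs once while caching, then on the next call `fn` is skipped, `ctx.emb` is restored, and `shared` is written back through `restore`.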

_cache_val

Internal Static
_cache_val(v: Any) → Any
Recursively caches a value. Each torch.Tensor is stored as detach().clone() together with its requires_grad flag. Supports dict, list, tuple, and arbitrary objects with tensor attributes (e.g. TorchRec JaggedTensor). The result is tagged with the _CACHE_TAG sentinel so that _restore_val can recognize and rebuild it.
# Tags: "t"=Tensor, "d"=dict, "l"=list, "u"=tuple, "o"=object
_cache_val(tensor)     → (_CACHE_TAG, "t", tensor.detach().clone(), requires_grad)
_cache_val({"a": t})   → (_CACHE_TAG, "d", {"a": _cache_val(t)})
_cache_val(jt)         → (_CACHE_TAG, "o", JaggedTensor, {attr: _cache_val(v) ...})

_restore_val

Internal Static
_restore_val(v: Any) → Any
Recursively rebuilds a value from _cache_val output. Each tensor comes back as a detach() of the cached copy (zero-copy), i.e. a fresh leaf with its original requires_grad faithfully restored. Every call returns brand-new tensor objects, which prevents "backward through the graph a second time" errors.

_GraftGrad

Internal
class _GraftGrad(torch.autograd.Function)
Custom autograd Function acting as the shortcut's "gradient bridge". Forward: identity (returns the restored tensor unchanged). Backward: passes grad_output to the restored tensor and pre-allocated zero tensors to the upstream inputs, which triggers their backward without changing parameter gradients. This keeps a downstream loss.backward() propagating gradients correctly after a task is shortcutted.
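A minimal sketch of such a bridge, assuming one restored tensor plus any number of upstream tensors, and reading "pre-allocated" as "the zero gradients are created in forward". The class name `GraftGradSketch` is hypothetical; the actual _GraftGrad signature is not documented here.

```python
import torch


class GraftGradSketch(torch.autograd.Function):
    """Identity on `restored`; backward routes the gradient to `restored` and zeros upstream."""

    @staticmethod
    def forward(ctx, restored, *upstream):
        # Pre-allocate one zero gradient per upstream input.
        ctx.zero_grads = tuple(torch.zeros_like(u) for u in upstream)
        # Return an alias of the input rather than the input object itself.
        return restored.view_as(restored)

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient for `restored`, then one zero tensor per upstream input.
        return (grad_output, *ctx.zero_grads)
```

Usage: `y = GraftGradSketch.apply(restored, upstream)`. Calling backward on anything computed from `y` reaches `restored` with the real gradient and still walks the upstream graph, which only receives zeros.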

_SubmissionSequencer

Internal
class _SubmissionSequencer
Deterministic cross-thread submission sequencer. Each globally_ordered task gets a pre-assigned sequence number. execute_ordered(seq, fn) blocks until it is that task's turn, so every rank issues its NCCL collectives in the same order, which prevents cross-rank deadlocks. Waiting longer than 30 s raises RuntimeError.
Method                     Description
reset()                    Reset the sequence counter to 0
execute_ordered(seq, fn)   Wait until _next_seq == seq, then execute fn()
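The ordering contract can be sketched with a threading.Condition. `SequencerSketch` is hypothetical, not the class in sw_pipeline.py; it keeps the `_next_seq` counter, the reset() / execute_ordered(seq, fn) pair, and the 30 s timeout from the table above. One design choice specific to this sketch: fn runs outside the lock and the counter advances in a `finally` block, so an exception in one task does not leave later sequence numbers in this process waiting forever.

```python
import threading
from typing import Callable, TypeVar

T = TypeVar("T")


class SequencerSketch:
    """Runs callables strictly in sequence-number order across threads."""

    def __init__(self, timeout_s: float = 30.0) -> None:
        self._cond = threading.Condition()
        self._next_seq = 0
        self._timeout_s = timeout_s

    def reset(self) -> None:
        with self._cond:
            self._next_seq = 0
            self._cond.notify_all()

    def execute_ordered(self, seq: int, fn: Callable[[], T]) -> T:
        with self._cond:
            # wait_for returns the predicate's final value, i.e. False on timeout.
            if not self._cond.wait_for(lambda: self._next_seq == seq, timeout=self._timeout_s):
                raise RuntimeError(
                    f"sequencer timeout: waiting for seq {seq}, next is {self._next_seq}")
        try:
            return fn()
        finally:
            with self._cond:
                self._next_seq += 1
                self._cond.notify_all()
```

Even if the threads for sequence numbers 3, 2, 1, 0 start in that order, their callables run as 0, 1, 2, 3.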

End-to-End Example: FusedSparseDist

FusedSparseDist Pipeline (3 stages, 4 streams, 9 tasks)
# SWPipeline equivalent of TrainPipelineFusedSparseDist
# 9 tasks · 3 stages · 4 streams: memcpy, data_dist, emb_lookup, default

import torch
from sw_pipeline import (
    PipelineTask, TaskSchedule, PipelinePlan, SWPipeline,
)

# 1. Streams
memcpy_stream     = torch.cuda.Stream()
data_dist_stream  = torch.cuda.Stream()
emb_lookup_stream = torch.cuda.Stream()

# 2. Tasks (function bodies omitted for brevity)
t_h2d        = PipelineTask("H2D",            h2d)
t_dist_start = PipelineTask("InputDistStart", input_dist_start)
t_dist_wait  = PipelineTask("InputDistWait",  input_dist_wait)
t_emb        = PipelineTask("EmbLookup",      emb_lookup)
t_zero       = PipelineTask("ZeroGrad",       zero_grad)
t_wait       = PipelineTask("WaitBatch",      wait_batch)
t_fwd        = PipelineTask("Forward",        dense_forward)
t_bwd        = PipelineTask("Backward",       backward)
t_opt        = PipelineTask("OptimizerStep",  optimizer_step)

# 3. PipelinePlan — 3 stages, 4 streams
plan = PipelinePlan(
    schedule={
        t_h2d:        TaskSchedule(stage=0, stream=memcpy_stream),
        t_dist_start: TaskSchedule(stage=1, stream=data_dist_stream, globally_ordered=True),
        t_dist_wait:  TaskSchedule(stage=1, stream=data_dist_stream),
        t_emb:        TaskSchedule(stage=2, stream=emb_lookup_stream),
        t_zero:       TaskSchedule(stage=2),           # default stream
        t_wait:       TaskSchedule(stage=2),
        t_fwd:        TaskSchedule(stage=2),
        t_bwd:        TaskSchedule(stage=2),
        t_opt:        TaskSchedule(stage=2),
    },
    intra_iter_deps=[
        (t_dist_start, t_h2d),         # cross-stage: memcpy → data_dist
        (t_dist_wait,  t_dist_start),   # same-stream ordering
        (t_emb,        t_dist_wait),    # cross-stage: data_dist → emb_lookup
        (t_fwd,        t_emb),          # cross-stream: emb_lookup → default
        (t_wait,       t_zero),         # same-stream ordering
        (t_fwd,        t_wait),         # same-stream ordering
        (t_bwd,        t_fwd),
        (t_opt,        t_bwd),
    ],
    inter_iter_deps=[
        (t_emb, t_bwd),               # EmbLookup(i) depends on Backward(i-1): the fused TBE updates weights during backward
    ],
)
# pipeline_depth is inferred automatically: max(stage)+1 = 3

# 4. Construct & inspect
pipe = SWPipeline(plan, device=0)
pipe.print_schedule(5)

# print_schedule(5) output:
#    #  Task               Thread   Stream        | P0    P1    P2    P3    P4
#   --  -----------------  -------  ------------  + ----- ----- ----- ----- -----
#    0  EmbLookup          default  emb_lookup    |  --    --   i0    i1    i2
#    1  ZeroGrad           default  default       |  --    --   i0    i1    i2
#    2  WaitBatch          default  default       |  --    --   i0    i1    i2
#    3  Forward            default  default       |  --    --   i0    i1    i2
#    4  Backward           default  default       |  --    --   i0    i1    i2
#    5  OptimizerStep      default  default       |  --    --   i0    i1    i2
#    6  InputDistStart     default  data_dist     |  --   i0    i1    i2    i3
#    7  InputDistWait      default  data_dist     |  --   i0    i1    i2    i3
#    8  H2D                default  memcpy        | i0    i1    i2    i3    i4
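Each row of the grid above follows one rule: in period P, a task at stage s works on iteration P - s, and on nothing while P < s. A hypothetical helper rebuilds the printed cells from the stages in the plan. It ignores the Thread/Stream columns and the row order, and treats print_schedule's argument as the number of periods shown, as the sample output suggests.

```python
from typing import Dict, List

STAGES: Dict[str, int] = {
    "H2D": 0, "InputDistStart": 1, "InputDistWait": 1,
    "EmbLookup": 2, "ZeroGrad": 2, "WaitBatch": 2,
    "Forward": 2, "Backward": 2, "OptimizerStep": 2,
}


def schedule_row(stage: int, num_periods: int) -> List[str]:
    """Cells for one task: 'i{P - stage}' once that iteration exists, else '--'."""
    return [f"i{p - stage}" if p >= stage else "--" for p in range(num_periods)]


for name, s in STAGES.items():
    print(f"{name:<16}", " ".join(f"{c:>4}" for c in schedule_row(s, 5)))
```

The stage-2 rows come out as `-- -- i0 i1 i2`, the stage-1 rows as `-- i0 i1 i2 i3`, and H2D as `i0 i1 i2 i3 i4`, matching the table above.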

Based on sw_pipeline.py