How SimAI separates prefill and decode phases onto different GPU pools -- the key new capability compared to original Vidur.
vidur-alibabacloud · SplitWise Scheduler

In standard LLM serving, every GPU replica runs both prefill (processing the full prompt in parallel) and decode (generating tokens one-by-one autoregressively). These two phases have fundamentally different computational profiles: prefill is compute-bound with high arithmetic intensity and long per-request latency, while decode is memory-bandwidth-bound with low arithmetic intensity but extreme latency sensitivity, since each token delays the user-visible response.
When co-located on the same GPU, prefill and decode contend for the same resources. A long prefill batch can stall ongoing decode iterations, causing Time-Between-Tokens (TBT) spikes. Conversely, interleaving small decode batches with prefill wastes GPU compute cycles. The disaggregation approach -- pioneered by the SplitWise and DistServe papers -- separates prefill and decode onto dedicated GPU pools (P-nodes and D-nodes), allowing each pool to be independently optimized for its workload profile.
The key tradeoff is straightforward: disaggregation introduces a KV cache transfer cost between the P-node (which generates the KV cache during prefill) and the D-node (which needs it for decode). If the transfer bandwidth is sufficient and the model's KV cache per token is small, disaggregation can dramatically improve both TTFT (Time-To-First-Token) and TBT. If the transfer is slow or the KV cache is large, the overhead may negate the benefits.
SimAI extends the original Vidur simulator with four tightly integrated additions: (1) a ReplicaType classification that tags each replica as PREFILL, DECODE, or MIXED; (2) a request DAG architecture that models the KV cache transfer as an explicit graph edge; (3) the SplitwiseGlobalScheduler that routes requests to the correct pool; and (4) KV cache size estimation and transfer flow modeling. Together, these enable full PD disaggregation simulation without any GPU hardware.
The foundation of PD disaggregation is the ReplicaType enum. In original Vidur, all replicas are implicitly MIXED -- they handle both prefill and decode. SimAI adds explicit P/D typing so each replica knows its role.
from enum import IntEnum

class ReplicaType(IntEnum):
    MIXED = 0    # Co-located: both prefill + decode (original Vidur)
    PREFILL = 1  # Dedicated prefill nodes (P-nodes)
    DECODE = 2   # Dedicated decode nodes (D-nodes)
PREFILL (P-node): Processes only the prompt phase. Optimized for compute throughput -- large batch sizes, high GPU utilization. After prefill completes, the KV cache is transferred to a D-node.

DECODE (D-node): Processes only the token generation phase. Optimized for latency consistency -- steady decode iterations without prefill interference, ensuring low and predictable TBT.
SimAI represents each request as a directed acyclic graph (DAG) using nx.DiGraph(). In the co-located case, the DAG is simple: PromptTask -> TokenTask. In the disaggregated case, a KVCacheTransferFlow node is inserted between them, modeling the data movement cost.
class Request(BaseEntity):
    def __init__(self, ...):
        # DAG for task dependency modeling
        self.dag = nx.DiGraph()
        self.root_node = None
        self.flow_node = None
        # PD disaggregation tracking
        self.prefill_replica_id = None
        self.decode_replica_id = None
        # Timing fields
        self.prefill_arrived_at = arrived_at
        self.decode_arrived_at = float('inf')
        self.decode_time = float('inf')
        # KV cache transfer metadata
        self.pd_p2p_comm_size = float('inf')
        self.pd_p2p_comm_time = float('inf')
        self.pd_p2p_bytes_per_token = None
        self.pd_p2p_comm_dtype = None
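The DAG rewiring itself is mechanical. The minimal sketch below uses plain dicts instead of nx.DiGraph (names are illustrative, not SimAI's actual API) to show the edge replacement that disaggregation performs:

```python
# Adjacency-list stand-in for the request DAG: PromptTask -> TokenTask.
edges = {"PromptTask": ["TokenTask"], "TokenTask": []}

def insert_flow(edges, src, dst, flow):
    """Replace src->dst with src->flow->dst, mirroring the KV transfer insertion."""
    edges[src].remove(dst)
    edges[src].append(flow)
    edges[flow] = [dst]

insert_flow(edges, "PromptTask", "TokenTask", "KVCacheTransferFlow")
print(edges)
# {'PromptTask': ['KVCacheTransferFlow'], 'KVCacheTransferFlow': ['TokenTask'], 'TokenTask': []}
```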
The SplitwiseGlobalScheduler is the central orchestrator for PD disaggregation. During initialization, it splits the replica pool into prefill and decode subsets based on pd_node_ratio, then creates separate sub-schedulers for each pool. When a request arrives, it creates a PromptTask and TokenTask, routes the PromptTask to a P-node and the TokenTask to a D-node, and inserts the KV cache transfer flow between them.
class SplitwiseGlobalScheduler(BaseGlobalScheduler):
    def __init__(self, config: SimulationConfig, replicas: Dict[int, Replica]):
        super().__init__(config, replicas)
        # P:D ratio from config (e.g., 0.5 = 1P:1D, 0.33 = 1P:2D)
        self.pd_node_ratio = self._replicas[0].pd_node_ratio
        # Split replicas into P-pool and D-pool (int() guards against a
        # fractional node count when the ratio does not divide evenly)
        self._num_prefill_nodes = int(self._num_replicas * self.pd_node_ratio)
        self._num_decode_nodes = self._num_replicas - self._num_prefill_nodes
        # Build prefill replica dict (IDs 0..N_p-1)
        self.prefill_replicas = {}
        for replica_id, replica in self._replicas.items():
            if replica_id < self._num_prefill_nodes:
                self.prefill_replicas[replica_id] = replica
                replica.replica_type = ReplicaType.PREFILL
        # Build decode replica dict (IDs N_p..N-1)
        self.decode_replicas = {}
        for replica_id, replica in self._replicas.items():
            if replica_id >= self._num_prefill_nodes:
                self.decode_replicas[replica_id] = replica
                replica.replica_type = ReplicaType.DECODE
        # Create sub-schedulers for each pool (default: LOR)
        self.prefill_scheduler = self.get_global_scheduler(self.prefill_replicas)
        self.decode_scheduler = self.get_global_scheduler(self.decode_replicas)
        # Round-robin counters
        self.p_request_counter = 0
        self.d_request_counter = 0
The schedule() method processes each request in three steps:

1. Task creation. For each request, create a PromptTask (with prompt_size tokens) and a TokenTask (with decode_tokens - 1 iterations). Build the initial DAG edge: PromptTask -> TokenTask.
2. Replica selection. Select a prefill replica via round-robin over the P-pool, and a decode replica via round-robin with offset over the D-pool. Record request.prefill_replica_id and request.decode_replica_id. Note: although the constructor sets self._sub_scheduler = "lor" and creates LOR sub-schedulers, the actual schedule() method bypasses them and uses simple round-robin counters instead (see the source code analysis below).
3. Transfer insertion. If prefill_replica != decode_replica, call add_kv_cache_transfer(), which transforms the DAG from PromptTask -> TokenTask to PromptTask -> KVCacheTransferFlow -> TokenTask, with the flow linked via a DummyLink that simulates transfer latency.
SplitwiseGlobalScheduler contains two distinct node-selection strategies in its source code, but only one is actually active. Understanding this gap is critical for anyone extending or benchmarking SimAI's PD disaggregation.
The schedule() method uses simple modular counters to distribute requests across pools. This is the code path that actually executes:
for request in self._request_queue:
    # --- Prefill node selection: simple round-robin ---
    # (LOR code is commented out on line 351)
    # replica_id = min(pending_prefill_requests_map.items(), key=lambda x: x[1])[0]
    replica_id = self.p_request_counter % len(self.prefill_replicas)  # Round-robin
    self.p_request_counter += 1
    prefill_replica = self.prefill_replicas[replica_id]
    request.prefill_replica_id = replica_id

    # --- Decode node selection: round-robin with offset ---
    # (LOR code is commented out on lines 364-366)
    # replica_id = min(pending_decode_requests_map.items(), key=lambda x: x[1])[0]
    replica_id = (self.d_request_counter % len(self.decode_replicas)) \
        + len(self.prefill_replicas)  # Offset by N_prefill
    self.d_request_counter += 1
    decode_replica = self.decode_replicas[replica_id]
    request.decode_replica_id = replica_id
The class defines two load-aware methods — find_best_prefill_replica() and find_best_decode_replica() — that implement smarter selection based on pending token count and memory utilization. However, neither method is ever called by schedule() in the current codebase:
def find_best_prefill_replica(self, prefill_replicas, prefill_task):
    """Select prefill replica with minimum pending tokens, if queue not overloaded."""
    if len(prefill_replicas) == 0:
        return None
    prefill_replica = min(
        prefill_replicas,
        key=lambda replica: replica.sched_pending_tokens)  # Load-aware!
    if self.is_queue_long(prefill_replica, prefill_task):
        return None  # Back-pressure: reject if queue too long
    return prefill_replica

def find_best_decode_replica(self, decode_replicas, prefill_task, decode_task):
    """Select decode replica with minimum pending tokens, if memory sufficient."""
    if len(decode_replicas) == 0:
        return None
    decode_replica = min(
        decode_replicas,
        key=lambda replica: replica.sched_pending_tokens)  # Load-aware!
    if self.is_memory_loaded(decode_replica, [prefill_task, decode_task]):
        return None  # Back-pressure: reject if OOM risk
    return decode_replica
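A tiny demo makes the dormant policy concrete. Replica here is a hypothetical stand-in class carrying the same sched_pending_tokens attribute the real methods read:

```python
from dataclasses import dataclass

@dataclass
class Replica:
    replica_id: int
    sched_pending_tokens: int  # tokens queued on this replica

replicas = [Replica(0, 4096), Replica(1, 512), Replica(2, 2048)]
# The load-aware core of find_best_prefill_replica(): pick the least-loaded node.
best = min(replicas, key=lambda r: r.sched_pending_tokens)
print(best.replica_id)  # 1: replica 1 has the fewest pending tokens
```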
The constructor sets self._sub_scheduler = "lor" and creates LOR (Least Outstanding Requests) sub-schedulers for both pools via get_global_scheduler(). These sub-schedulers are stored as self.prefill_scheduler and self.decode_scheduler, but neither is ever invoked — the schedule() method implements its own routing logic instead of delegating to them.
| Aspect | Round-Robin (Active) | Load-Aware (Dead Code) |
|---|---|---|
| P-node selection | p_request_counter % len(prefill_replicas) | min(replicas, key=sched_pending_tokens) |
| D-node selection | (d_request_counter % len(decode_replicas)) + N_p | min(replicas, key=sched_pending_tokens) + memory check |
| Back-pressure | None (always assigns) | is_queue_long() / is_memory_loaded() |
| Fairness | Perfect (deterministic) | Approximate (greedy min) |
| Load imbalance risk | High if request sizes vary | Low (adapts to token counts) |
The source code contains clear evidence that this is a work-in-progress codebase where load-aware routing was planned but not yet completed, rather than a deliberate design decision to use round-robin. Four pieces of evidence:
# TODO: > > 参考 sw写的;但也很多区别;换一个名字;类似pd分离的其他名字;不严格是sw了
# (Translation: Refer to the Splitwise implementation; but many differences;
#  need a new name; similar to PD separation; not strictly Splitwise anymore)
class SplitwiseGlobalScheduler(BaseGlobalScheduler):
    def __init__(self, config, replicas):
        # ① Config-driven sub-scheduler was planned but commented out:
        # self._sub_scheduler = self._config.splitwise_scheduler_sub_scheduler
        # ② TODO explicitly marks this as incomplete:
        # TODO > improve _sub_scheduler flexible choice
        # ③ Round-robin was tried, then switched to LOR:
        # self._sub_scheduler = "round_robin"
        self._sub_scheduler = "lor"  # ← Creates LOR sub-schedulers
        # ④ But schedule() bypasses sub-schedulers entirely,
        #    using inline round-robin counters instead (lines 354, 369)
The development trajectory is clear from these comments:
| Evidence | What it tells us |
|---|---|
| # TODO: > > 参考 sw写的 (line 24) | The class was initially modeled after the Splitwise paper (ISCA '24), which uses load-aware routing. The developers acknowledge it has diverged significantly. |
| # self._config.splitwise_scheduler_sub_scheduler (line 40) | A config-driven sub-scheduler selector was planned (it would read from the CLI/config file) but never completed. The attribute splitwise_scheduler_sub_scheduler does not exist in the config class. |
| # TODO > improve _sub_scheduler flexible choice (line 41) | The developers explicitly intend to make sub-scheduler selection configurable. This is an unfinished TODO, not a deliberate omission. |
| LOR commented in schedule(), round-robin active (lines 351, 364) | Both LOR lines (min(pending_*_requests_map.items(), key=lambda x: x[1])[0]) are present but commented out immediately above the round-robin lines. This suggests the developers tested LOR, hit issues (likely the pending_requests_map not updating correctly across scheduling rounds), and fell back to round-robin as a simpler interim solution. |
The schedule() method also reveals a subtle bug: pending_prefill_requests_map and pending_decode_requests_map are initialized at lines 306-320 by reading replica.pending_tasks, but they are never updated inside the scheduling loop (unlike the original Vidur LOR scheduler at line 63, which increments pending_requests_map[replica_id] += 1 after each assignment). This means the LOR code, if uncommented, would always pick the same replica -- effectively degrading to a constant assignment rather than true load balancing. The round-robin fallback was likely adopted after this bug was discovered.
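The failure mode is easy to reproduce in isolation. The sketch below is illustrative (plain dicts, not SimAI code): a pending-count snapshot taken once before the loop makes min() return the same replica for every request, while incrementing inside the loop, as the original Vidur LOR scheduler does, spreads the load:

```python
# Buggy pattern: snapshot taken once, never updated inside the loop.
pending = {0: 1, 1: 0, 2: 2}  # replica 1 is least loaded at snapshot time
buggy = [min(pending.items(), key=lambda x: x[1])[0] for _ in range(5)]
print(buggy)  # [1, 1, 1, 1, 1] -- every request lands on replica 1

# Fixed pattern: increment the count after each assignment.
pending = {0: 1, 1: 0, 2: 2}
fixed = []
for _ in range(5):
    rid = min(pending.items(), key=lambda x: x[1])[0]
    pending[rid] += 1  # keep the snapshot current across the loop
    fixed.append(rid)
print(fixed)  # [1, 0, 1, 0, 1] -- load now spreads across replicas
```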
The add_kv_cache_transfer() method is the heart of the DAG transformation. It estimates the KV cache size for the given prompt, creates a FlowType.KVCacheTransfer flow node, removes the direct PromptTask->TokenTask edge, and inserts the transfer flow as an intermediary.
def add_kv_cache_transfer(self, request, src_replica, dest_replica, bandwidth):
    """Transform DAG: prompt->token into prompt->kv_transfer->token."""
    prefill_task = request.root_node                          # Root = PromptTask
    decode_task = next(request.dag.successors(prefill_task))  # Next = TokenTask
    # Estimate KV cache size based on prompt length and model config
    flow_size = request.estimate_kv_cache_size(
        num_tokens=prefill_task.prompt_size,
        replica=src_replica)
    # Create KVCacheTransfer flow node
    kv_transfer_flow = request.create_flow(
        FlowType.KVCacheTransfer,
        size=flow_size,
        src=src_replica,
        dest=dest_replica)
    kv_transfer_flow.notify = True
    # Rewire the DAG
    request.flow_node = kv_transfer_flow
    request.dag.remove_edge(prefill_task, decode_task)
    request.dag.add_edge(prefill_task, kv_transfer_flow)
    request.dag.add_edge(kv_transfer_flow, decode_task)
    # Assign tasks to replicas
    prefill_task.instance = src_replica
    decode_task.instance = dest_replica
    # Simulate transfer latency with configurable bandwidth
    kv_transfer_flow.link = DummyLink(
        name="DummyLink",
        bandwidth=bandwidth)
The KV cache size is computed by the estimate_kv_cache_size() method on the Request object. The formula accounts for both the key and value tensors across all layers:
kv_size = 2 * num_tokens * hidden_dim * num_layers * dtype_bytes
Where the factor of 2 accounts for both K and V tensors, hidden_dim is mlp_hidden_dim from the model config, num_layers is the total transformer layers, and dtype_bytes depends on pd_p2p_comm_dtype (e.g., 2 bytes for float16).
def estimate_kv_cache_size(self, num_tokens=None, replica=None):
    """Returns the KV-cache size (in bytes) after generating num_tokens."""
    # Determine bytes per element based on dtype
    if replica.pd_p2p_comm_dtype == 'float16':
        pd_p2p_bytes_per_token = 2
    elif replica.pd_p2p_comm_dtype == 'float32':
        pd_p2p_bytes_per_token = 4
    elif replica.pd_p2p_comm_dtype == 'bfloat16':
        pd_p2p_bytes_per_token = 2
    elif replica.pd_p2p_comm_dtype == 'int8':
        pd_p2p_bytes_per_token = 1
    # ... (int16=2, int32=4, int64=8, float64=8)
    self.pd_p2p_bytes_per_token = pd_p2p_bytes_per_token
    self.pd_p2p_comm_dtype = replica.pd_p2p_comm_dtype
    # KV = 2 (K+V) * seq_len * hidden_dim * layers * dtype_bytes
    return 2 * num_tokens * replica.mlp_hidden_dim \
        * replica.num_layers * pd_p2p_bytes_per_token
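Plugging LLaMA-70B-style numbers into the formula above shows the scale involved (mlp_hidden_dim=8192, num_layers=80, float16; the shapes are assumptions for illustration):

```python
# Worked example of SimAI's KV size estimate for a 2048-token prompt.
num_tokens, hidden_dim, num_layers, dtype_bytes = 2048, 8192, 80, 2
kv_bytes = 2 * num_tokens * hidden_dim * num_layers * dtype_bytes
print(f"{kv_bytes:,} bytes (~{kv_bytes / 2**30:.1f} GiB)")
# 5,368,709,120 bytes (~5.0 GiB)
```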
At the global level, the SplitwiseGlobalScheduler uses round-robin to decide which replica handles each phase (as shown in Section 2.3.1). Once assigned, the per-replica scheduling behavior differs between prefill and decode replicas — prefill replicas optimize for throughput, while decode replicas optimize for latency. The codebase also contains unused load-aware methods that could replace round-robin in future versions.
Prefill replicas: maximize compute throughput. Prefill replicas can afford to accumulate larger batches because they only need to process the prompt once. The scheduler checks sched_pending_tokens to avoid overloading any single P-node. The find_best_prefill_replica() method selects the replica with the minimum pending token count, subject to queue length limits.

Decode replicas: minimize latency variance. Decode replicas run continuous iteration loops generating one token per step. The scheduler checks sched_memory to ensure the D-node has enough KV cache memory for both the incoming prefill result and the ongoing decode. The find_best_decode_replica() method uses memory-awareness to prevent OOM.
| Aspect | Prefill Replica | Decode Replica |
|---|---|---|
| Selection criterion | Round-robin (load-aware method exists but unused) | Round-robin with offset (memory-aware method exists but unused) |
| Bottleneck check | is_queue_long() | is_memory_loaded() |
| Optimization target | Throughput (tokens/sec) | Latency (TBT / TPOT) |
| Batch pattern | Large, variable-size batches | Continuous iteration, steady batch |
| Memory pressure | Transient (KV freed after transfer) | Accumulating (KV held until request completes) |
The transfer cost between P-nodes and D-nodes is the critical overhead introduced by disaggregation. SimAI models this cost in two ways: (1) an analytical model using pd_p2p_comm_bandwidth from the config, where transfer_time = kv_size / bandwidth; and (2) integration with astra-sim + NS-3 for detailed network simulation when higher fidelity is needed.
The KV cache size per token varies dramatically across model architectures due to differences in attention head configuration. Models using GQA (Grouped Query Attention) or MLA (Multi-head Latent Attention) have significantly smaller KV caches than standard MHA (Multi-Head Attention).
| Model | Attention Type | KV Heads | Head Dim | Layers | KV Size / Token (fp16) |
|---|---|---|---|---|---|
| LLaMA-70B | GQA | 8 | 128 | 80 | ~320 KB (2*2*80*128*8 bytes) |
| LLaMA-405B | GQA | 8 | 128 | 126 | ~504 KB (2*2*126*128*8 bytes) |
| DeepSeek-V3-671B | MLA (d_kv_c=512) | compressed | 512 | 61 | ~125 KB (MLA compresses ~10x vs MHA) |
| Qwen3-MoE-235B | GQA | 4 | 128 | 94 | ~188 KB (2*2*94*128*4 bytes) |
For a prompt of 2048 tokens on LLaMA-70B with float16, the KV cache transfer size is approximately 640 MB (320 KB/token * 2048 tokens). At the default 800 Gbps bandwidth, this takes about 6.4 ms -- negligible compared to the prefill compute time. However, at 100 Gbps (e.g., RoCE across racks), it rises to ~51 ms, which may become significant.
# Analytical model (DummyLink)
transfer_time = kv_cache_size_bytes / (pd_p2p_comm_bandwidth_bps / 8)

# Example: LLaMA-70B, 2048 tokens, float16, 800 Gbps
# (SimAI's hidden_dim-based formula; ~8x the GQA-accurate ~640 MB above)
kv_size = 2 * 2048 * 8192 * 80 * 2        # = 5,368,709,120 bytes (~5.0 GiB)
transfer_time = 5368709120 / (800e9 / 8)  # = 53.7 ms

# With NS-3 simulation: accounts for congestion, link utilization, etc.
Note that SimAI uses replica.mlp_hidden_dim as the hidden dimension in the KV cache formula. For standard MHA/GQA models, the actual per-token KV size is 2 * num_kv_heads * head_dim * num_layers * dtype_bytes. The SimAI formula uses the full hidden_dim, which gives a conservative upper bound. When running with real model configs, verify the mapping between mlp_hidden_dim and your model's actual KV head configuration.
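The gap between the two formulas can be quantified directly. The sketch below assumes LLaMA-70B-style shapes (8 KV heads, head_dim 128, 80 layers, hidden_dim 8192):

```python
# Architecture-accurate GQA per-token KV size vs SimAI's hidden_dim estimate.
def kv_per_token_gqa(num_kv_heads, head_dim, num_layers, dtype_bytes=2):
    return 2 * num_kv_heads * head_dim * num_layers * dtype_bytes  # K + V

def kv_per_token_simai(hidden_dim, num_layers, dtype_bytes=2):
    return 2 * hidden_dim * num_layers * dtype_bytes  # full hidden_dim

accurate = kv_per_token_gqa(8, 128, 80)  # 327,680 B (~320 KiB)
upper = kv_per_token_simai(8192, 80)     # 2,621,440 B (~2.5 MiB)
print(f"overestimate factor: {upper / accurate:.0f}x")  # 8x for 8 KV heads
```

The 8x factor matches the 64-query-head / 8-KV-head grouping: SimAI effectively prices the KV cache as if every query head stored its own K/V.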
The discrete-event simulation in SimAI processes different event chains depending on whether PD disaggregation is enabled. Below we compare the two event chains side-by-side.
Co-located (MIXED): RequestArrival -> GlobalSchedule -> ReplicaSchedule -> BatchStageArrival -> BatchStageEnd (prefill) -> BatchStageEnd (decode iter 1) -> ... -> BatchStageEnd (decode iter N) -> BatchEnd. All events happen on the same replica.

Disaggregated (PD): RequestArrival -> GlobalSchedule(Prefill) -> ReplicaSchedule(P) -> BatchStageEnd(Prefill) -> KVTransfer -> GlobalSchedule(Decode) -> ReplicaSchedule(D) -> BatchStageEnd(Decode)* -> BatchEnd. Events span two replicas with a transfer in between.
PD disaggregation in SimAI is controlled by three main parameters in the ReplicaConfig dataclass. These are set via command-line arguments or configuration files.
@dataclass
class ReplicaConfig:
    # ... other fields ...
    # PD disaggregation parameters
    pd_p2p_comm_bandwidth: int = field(
        default=800,
        metadata={"help": "PD P2P communication bandwidth (bps)"},
    )
    pd_p2p_comm_dtype: str = field(
        default='float16',
        metadata={"help": "Data type for KV cache transfer"},
    )
    pd_node_ratio: float = field(
        default=0.5,
        metadata={"help": "Ratio of P replicas to total replicas"},
    )
    # Related bandwidth parameters
    nvlink_bandwidth: int = field(
        default=1600,
        metadata={"help": "NVLink bandwidth for TP/EP (bps)"},
    )
    rdma_bandwidth: int = field(
        default=800,
        metadata={"help": "RDMA bandwidth for TP/EP (bps)"},
    )
| Parameter | Default | Description | Impact |
|---|---|---|---|
| pd_p2p_comm_bandwidth | 800 (Gbps in practice, despite the "bps" label in the config metadata) | Point-to-point bandwidth between P and D nodes | Higher = lower KV transfer latency |
| pd_p2p_comm_dtype | float16 | Data type for KV cache during transfer | Smaller dtype = less data to transfer |
| pd_node_ratio | 0.5 | Fraction of replicas allocated as P-nodes. 0.5 = 1P:1D, 0.33 = 1P:2D, 0.67 = 2P:1D | Most sensitive parameter for overall performance |
# Balanced P:D (default)
--pd_node_ratio 0.5 --pd_p2p_comm_bandwidth 800 --pd_p2p_comm_dtype float16
# More decode capacity (long outputs, chat workload)
--pd_node_ratio 0.33 --pd_p2p_comm_bandwidth 800 --pd_p2p_comm_dtype float16
# More prefill capacity (long prompts, summarization workload)
--pd_node_ratio 0.67 --pd_p2p_comm_bandwidth 800 --pd_p2p_comm_dtype float16
# INT8 quantized KV transfer (halves transfer size)
--pd_node_ratio 0.5 --pd_p2p_comm_bandwidth 400 --pd_p2p_comm_dtype int8
When PD disaggregation is enabled, SimAI tracks additional per-request metrics that capture the behavior of each phase independently. These metrics are recorded in the MetricsStore and exported for analysis.
class MetricConstants:
    # --- PD-specific timing metrics ---
    PREFILL_ARRIVED_AT = "prefill_arrived_at"      # When request reached P-node
    PREFILL_COMPLETED_AT = "prefill_completed_at"  # When prefill finished
    DECODE_ARRIVED_AT = "decode_arrived_at"        # When KV arrived at D-node
    DECODE_TIME = "decode_time"                    # Total decode duration
    # --- PD-specific routing metrics ---
    PREFILL_REPLICA_ID = "prefill_replica_id"      # Which P-node was used
    DECODE_REPLICA_ID = "decode_replica_id"        # Which D-node was used
    # --- KV transfer metrics ---
    PD_P2P_COMM_SIZE = "pd_p2p_comm_size"          # KV cache bytes transferred
    PD_P2P_COMM_TIME = "pd_p2p_comm_time"          # Transfer latency
    # --- Derived latency metrics ---
    PREFILL_TIME_E2E = "prefill_e2e_time"          # TTFT = prefill_completed - arrived
    TBT = "tbt"                                    # Time Between Tokens (TPOT)
These metrics enable fine-grained analysis of PD disaggregation performance:
TTFT: prefill_e2e_time = prefill_completed_at - arrived_at. This is the Time-To-First-Token -- the latency the user waits before seeing the first response token. In disaggregated mode, this also includes scheduling latency on the P-node.

TBT: tbt measures the Time Between Tokens (also called TPOT -- Time Per Output Token). In disaggregated mode, D-nodes are free from prefill interference, so TBT should be more consistent and lower-variance than in co-located mode.

Transfer overhead: by comparing pd_p2p_comm_time against prefill_e2e_time, you can determine what fraction of the end-to-end TTFT is consumed by the KV transfer. If pd_p2p_comm_time / prefill_e2e_time > 0.3, the transfer overhead is significant and you may want to increase bandwidth, use a smaller dtype (int8), or use MLA-based models.
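That check is easy to apply in post-processing. The record dicts below are hypothetical (field names follow the metric constants above; the exporter's actual format may differ):

```python
# Flag requests whose KV transfer consumes more than 30% of TTFT.
records = [
    {"pd_p2p_comm_time": 0.006, "prefill_e2e_time": 0.120},  # transfer ~5%
    {"pd_p2p_comm_time": 0.050, "prefill_e2e_time": 0.110},  # transfer ~45%
]
flagged = [r for r in records
           if r["pd_p2p_comm_time"] / r["prefill_e2e_time"] > 0.3]
print(len(flagged))  # 1: only the second request crosses the 30% threshold
```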
The metrics store also computes derived values during the on_request_end() callback:
# From vidur/metrics/metrics_store.py

# TTFT: Time from arrival to first token (prefill latency)
prefill_e2e_time = request.prefill_completed_at - request.arrived_at

# Scheduling delay for prefill (time spent queued before execution)
prefill_scheduling_delay = request.scheduled_at - request.arrived_at

# Model execution time (excluding scheduling)
prefill_model_execution_time = (
    request.prefill_completed_at - request.scheduled_at
)

# Decode phase duration
decode_model_execution_time = (
    request.completed_at - request.prefill_completed_at
)
SimAI enables PD disaggregation simulation through cross-component integration -- vidur-alibabacloud handles the discrete-event scheduling (replica classification, request DAG, SplitWise scheduler), while astra-sim can model the KV cache transfer costs with network-level fidelity via NS-3. The DummyLink provides a fast analytical fallback when full network simulation is not needed. This layered approach allows users to choose their fidelity/speed tradeoff.
The pd_node_ratio is the most sensitive parameter in PD disaggregation. Too many prefill nodes wastes decode capacity -- D-nodes sit idle while users wait for tokens. Too few prefill nodes creates a prefill bottleneck -- requests queue up waiting for prefill slots, increasing TTFT. The optimal ratio depends on the workload: chat workloads (short prompts, long outputs) favor more D-nodes (ratio ~0.25-0.33), while summarization workloads (long prompts, short outputs) favor more P-nodes (ratio ~0.5-0.67). SimAI allows sweeping this parameter to find the sweet spot.
DeepSeek's Multi-head Latent Attention (MLA) reduces KV cache size by approximately 10x compared to standard MHA at similar model scale. This fundamentally changes the PD disaggregation tradeoff: with MLA, the KV transfer cost becomes almost negligible, making disaggregation nearly "free" in terms of transfer overhead. This means aggressive disaggregation strategies become viable even with lower-bandwidth interconnects. SimAI can model this by configuring the appropriate mlp_hidden_dim and num_layers for MLA-based models.
The nx.DiGraph-based request DAG is remarkably extensible. While currently used for the simple PromptTask -> KVTransfer -> TokenTask chain, it could model more complex patterns: multi-step prefill (chunked prefill across multiple P-nodes), speculative decoding (draft + verify stages), or even multi-modal pipelines (image encoding -> text prefill -> decode). The DAG abstraction decouples scheduling logic from the specific task graph topology.
In disaggregated mode, memory pressure is fundamentally asymmetric. P-nodes hold KV caches transiently (only during prefill, then transferred out), while D-nodes accumulate KV caches for the entire decode duration of every active request. This means D-nodes typically need much more KV cache memory than P-nodes. The is_memory_loaded() check in the decode scheduler is critical -- without it, D-nodes would accept more requests than they can hold in memory, causing OOM-equivalent failures in the simulation.
To fully understand how PD disaggregation works in SimAI, it is helpful to trace a single request through the entire scheduling pipeline. The following step-flow shows every major decision point and data mutation that occurs.
1. Arrival. The request generator creates a Request object with num_prefill_tokens and num_decode_tokens. The request is added to self._request_queue in the SplitwiseGlobalScheduler. At this point, prefill_arrived_at is set to the current simulation time, and decode_arrived_at remains float('inf').
2. Task creation. Inside schedule(), for each request: (a) create prefill_task = request.create_task(TaskType.PROMPT, prompt_size=request.num_prefill_tokens); (b) create decode_task = request.create_task(TaskType.TOKEN, token_size=request.num_decode_tokens - 1); (c) add DAG edge: request.dag.add_edge(prefill_task, decode_task); (d) set request.root_node = prefill_task.
3. Prefill routing. The scheduler selects a prefill replica: replica_id = self.p_request_counter % len(self.prefill_replicas). The counter is incremented. The request records request.prefill_replica_id = replica_id. The (replica_id, request) pair is appended to prefill_request_mapping.
4. Decode routing. The decode replica is selected with an offset: replica_id = (self.d_request_counter % len(self.decode_replicas)) + len(self.prefill_replicas). This offset ensures decode replica IDs start after prefill replica IDs. The request records request.decode_replica_id = replica_id.
5. Transfer insertion. Since prefill_replica != decode_replica, add_kv_cache_transfer() is called. This: (a) estimates KV cache size via estimate_kv_cache_size(); (b) creates a KVCacheTransferFlow node; (c) removes the PromptTask->TokenTask edge; (d) adds PromptTask->KVFlow and KVFlow->TokenTask edges; (e) assigns task instances and creates the DummyLink. Both tasks are also added to their respective replica pools via add_to_pool().
6. Memory accounting. The scheduler updates memory tracking: prefill_replica.sched_memory += prefill_task.max_memory(prefill_replica) and decode_replica.sched_memory += prefill_task.max_memory(decode_replica) + decode_task.max_memory(decode_replica). Note that the D-node accounts for both the incoming KV cache (from prefill) and the ongoing decode memory. Pending token counts are also updated for both replicas.
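The routing steps of this walkthrough can be re-enacted in a few lines (illustrative constants and helper, not SimAI's actual API):

```python
# 4 replicas, pd_node_ratio 0.5: replicas 0-1 are P-nodes, 2-3 are D-nodes.
num_replicas, pd_node_ratio = 4, 0.5
n_prefill = int(num_replicas * pd_node_ratio)  # 2 P-nodes: IDs 0-1
n_decode = num_replicas - n_prefill            # 2 D-nodes: IDs 2-3

def route(p_counter, d_counter):
    p_id = p_counter % n_prefill               # round-robin over P-pool
    d_id = (d_counter % n_decode) + n_prefill  # round-robin with N_p offset
    return p_id, d_id

print([route(i, i) for i in range(4)])
# [(0, 2), (1, 3), (0, 2), (1, 3)]
```

The offset guarantees that decode assignments never collide with prefill IDs, which is why add_kv_cache_transfer() fires for every request in a split pool.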
This section provides practical guidance for setting up and running PD disaggregation simulations with SimAI.
The P:D ratio is workload-dependent. Use the following guidelines as a starting point:
| Workload Type | Typical Prompt Length | Typical Output Length | Suggested pd_node_ratio | Reasoning |
|---|---|---|---|---|
| Chat / Conversational | 100-500 | 200-2000 | 0.25 - 0.33 | Decode-heavy: more D-nodes needed for long outputs |
| Summarization | 2000-8000 | 100-500 | 0.5 - 0.67 | Prefill-heavy: more P-nodes needed for long prompts |
| Code Generation | 500-2000 | 500-4000 | 0.33 - 0.5 | Balanced: moderate prompt with long output |
| RAG / Document Q&A | 4000-32000 | 100-1000 | 0.5 - 0.75 | Very prefill-heavy: long context requires P-node capacity |
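A crude starting point can be derived from the workload's average token mix. The helper below is hypothetical (not part of SimAI) and clamps to the range the table spans; a parameter sweep should refine whatever it suggests:

```python
# Rough pd_node_ratio heuristic: share of total work that is prefill tokens.
def suggest_pd_node_ratio(avg_prompt_tokens, avg_output_tokens):
    prefill_share = avg_prompt_tokens / (avg_prompt_tokens + avg_output_tokens)
    # Clamp to the 0.25-0.75 range used in the guidelines above.
    return round(min(0.75, max(0.25, prefill_share)), 2)

print(suggest_pd_node_ratio(300, 1500))   # chat-like mix -> 0.25
print(suggest_pd_node_ratio(16000, 500))  # RAG-like mix  -> 0.75
```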
The pd_p2p_comm_bandwidth parameter has a direct impact on the KV transfer latency. Running a sweep across bandwidths helps identify the minimum viable interconnect for a given workload. Below is the relationship between bandwidth and transfer time for a 2048-token prompt on LLaMA-70B with float16:
| Interconnect | Bandwidth | Transfer Time (2048 tokens) | Viable? |
|---|---|---|---|
| NVLink (intra-node) | 900 GB/s | ~0.6 ms | Excellent |
| PCIe Gen5 x16 | 64 GB/s | ~8.4 ms | Good |
| 400GbE RDMA (intra-rack) | 50 GB/s | ~10.7 ms | Acceptable |
| 100GbE RoCE (cross-rack) | 12.5 GB/s | ~42.9 ms | Marginal for large models |
| 25GbE TCP | ~3 GB/s | ~178 ms | Not recommended |
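The shape of this sweep can be reproduced analytically. The sketch below assumes a 2048-token prompt at ~320 KiB/token (~640 MiB payload); the table's exact figures appear to use a slightly smaller payload assumption, so treat these as order-of-magnitude checks:

```python
# Transfer time across interconnect tiers for a fixed KV payload.
kv_bits = 2048 * 320 * 1024 * 8  # ~640 MiB payload in bits
tiers = [("NVLink", 900 * 8), ("PCIe5 x16", 64 * 8),  # GB/s converted to Gbps
         ("400GbE RDMA", 400), ("100GbE RoCE", 100), ("25GbE", 25)]
for name, gbps in tiers:
    ms = kv_bits / (gbps * 1e9) * 1e3
    print(f"{name:>12}: {ms:7.1f} ms")
```

The two-orders-of-magnitude spread between NVLink and 25GbE is what makes the interconnect choice a first-class design decision for disaggregation.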
Reducing the dtype of the KV cache during transfer can halve the transfer time without significantly impacting model quality. The pd_p2p_comm_dtype parameter supports multiple options:
float16: 2 bytes/element. Default choice. No quantization error. Use when bandwidth is sufficient.

int8: 1 byte/element. Halves transfer size. Minor quality impact. Recommended for bandwidth-constrained setups.

float32: 4 bytes/element. Doubles transfer size. Only use if the model requires a full-precision KV cache.
To evaluate whether PD disaggregation is beneficial for your workload, run the simulation twice with the same parameters except for the global scheduler type. Compare the following metrics:
# Use round_robin or lor global scheduler
--global_scheduler_type round_robin
--num_replicas 4
All 4 replicas handle both prefill and decode (MIXED mode).
# Use splitwise global scheduler
--global_scheduler_type splitwise
--num_replicas 4
--pd_node_ratio 0.5
--pd_p2p_comm_bandwidth 800
2 P-nodes (replicas 0-1) and 2 D-nodes (replicas 2-3).
Key metrics to compare between the two runs:
| Metric | Expect from Disagg | Why |
|---|---|---|
| P50 TTFT | Similar or slightly worse | KV transfer adds latency, but dedicated P-nodes reduce queueing |
| P99 TTFT | Better (lower) | No decode interference on P-nodes means less tail latency |
| TBT variance | Much better | D-nodes free from prefill stalls: steady decode iteration |
| Overall throughput | Depends on ratio | Wrong P:D ratio can reduce throughput vs co-located |
| File | Role | Key Additions |
|---|---|---|
| entities/replica.py | Replica classification | ReplicaType enum (MIXED / PREFILL / DECODE) |
| entities/request.py | Request DAG + KV size estimation | nx.DiGraph DAG, estimate_kv_cache_size(), pd_p2p fields |
| scheduler/global_scheduler/splitwise_global_scheduler.py | PD scheduling orchestration | P/D pool split, schedule(), add_kv_cache_transfer() |
| entities/flow.py | Transfer flow abstraction | FlowType.KVCacheTransfer, Flow class |
| entities/interconnect.py | Transfer latency modeling | DummyLink with configurable bandwidth |
| config/config.py | Configuration | pd_p2p_comm_bandwidth, pd_p2p_comm_dtype, pd_node_ratio |
| metrics/constants.py | Metric definitions | PD timing/routing/transfer metrics |
| metrics/metrics_store.py | Metric collection | on_request_end() PD metric recording |
| events/batch_end_event.py | Event handling | PD-aware batch completion logic |