A discrete-event simulator for LLM inference -- exploring the full serving stack without a single GPU.
Microsoft Research -- MLSys 2024

Vidur is built as a classical discrete-event simulator (DES). Instead of stepping through wall-clock time, it maintains a priority queue of timestamped events and processes them in order. Each event handler may produce new events, driving the simulation forward without any actual GPU computation.
The Simulator class is remarkably simple -- under 130 lines. It initializes the cluster,
request generator, scheduler, and metrics store, then runs a tight event loop using Python's heapq.
class Simulator:
    def __init__(self, config: SimulationConfig) -> None:
        self._time = 0
        self._terminate = False
        self._event_queue = []  # min-heap of (priority, event)
        # Create the simulated cluster (replicas, NOT real GPUs)
        self._cluster = Cluster(config.cluster_config, ...)
        self._metric_store = MetricsStore(config)
        # Generate ALL requests up-front (synthetic or trace-based)
        self._request_generator = RequestGeneratorRegistry.get(
            config.request_generator_config.get_type(), ...)
        # Create the hierarchical scheduler
        self._scheduler = GlobalSchedulerRegistry.get(
            config.cluster_config.global_scheduler_config.get_type(),
            config, self._cluster.replicas)
        self._init_event_queue()  # Seed with RequestArrivalEvents

    def run(self) -> None:
        while self._event_queue and not self._terminate:
            _, event = heapq.heappop(self._event_queue)  # Pop lowest-time event
            self._set_time(event._time)                  # Advance sim clock
            new_events = event.handle_event(             # Process event
                self._scheduler, self._metric_store)
            self._add_events(new_events)                 # Push new events
        assert self._scheduler.is_empty() or self._terminate
When a future event is scheduled at time + execution_time, the execution time comes from the ML predictor, not from actually
running any GPU kernel. This is what makes Vidur a simulator.
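The jump-to-next-event mechanic can be illustrated with a self-contained sketch (toy code, not Vidur's classes): events pushed out of order pop in timestamp order, and the clock teleports between them with no wall-clock time passing.

```python
import heapq

# Minimal discrete-event loop (illustrative only, not Vidur's actual classes).
queue = []
heapq.heappush(queue, (10.0, "batch_end"))
heapq.heappush(queue, (0.0, "request_arrival"))
heapq.heappush(queue, (3.5, "stage_end"))

clock, trace = 0.0, []
while queue:
    clock, name = heapq.heappop(queue)  # clock jumps straight to the next event
    trace.append((clock, name))

print(trace)  # [(0.0, 'request_arrival'), (3.5, 'stage_end'), (10.0, 'batch_end')]
```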
class EventType(BaseIntEnum):
    # At any given time step, call the schedule event last
    # to ensure that all the requests are processed
    BATCH_STAGE_ARRIVAL = 1     # Batch arrives at a pipeline stage
    REQUEST_ARRIVAL = 2         # New request enters the system
    BATCH_STAGE_END = 3         # Stage finishes "execution"
    BATCH_END = 4               # All stages done for a batch
    GLOBAL_SCHEDULE = 5         # Global scheduler dispatches
    REPLICA_SCHEDULE = 6        # Replica creates batches
    REPLICA_STAGE_SCHEDULE = 7  # Stage begins "execution"
The integer values define priority ordering at the same timestamp. BATCH_STAGE_ARRIVAL (1) is
processed before REQUEST_ARRIVAL (2), which is before scheduling events (5-7). This ensures all
completions and arrivals are registered before scheduling decisions are made.
| Event Type | Priority | Source File | Produces |
|---|---|---|---|
| REQUEST_ARRIVAL | 2 | request_arrival_event.py | GlobalScheduleEvent |
| GLOBAL_SCHEDULE | 5 | global_schedule_event.py | ReplicaScheduleEvent(s) |
| REPLICA_SCHEDULE | 6 | replica_schedule_event.py | BatchStageArrivalEvent(s) |
| BATCH_STAGE_ARRIVAL | 1 | batch_stage_arrival_event.py | ReplicaStageScheduleEvent |
| REPLICA_STAGE_SCHEDULE | 7 | replica_stage_schedule_event.py | BatchStageEndEvent |
| BATCH_STAGE_END | 3 | batch_stage_end_event.py | BatchEndEvent or BatchStageArrivalEvent (next stage) |
| BATCH_END | 4 | batch_end_event.py | ReplicaScheduleEvent |
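The same-timestamp ordering can be sketched with plain tuples (an illustrative layout, not Vidur's exact heap entries): when timestamps tie, the smaller EventType integer pops first.

```python
import heapq

# Tie-breaking sketch: (time, event_type_value, name) tuples in a min-heap.
BATCH_STAGE_ARRIVAL, REQUEST_ARRIVAL, GLOBAL_SCHEDULE = 1, 2, 5

heap = []
heapq.heappush(heap, (4.0, GLOBAL_SCHEDULE, "global_schedule"))
heapq.heappush(heap, (4.0, BATCH_STAGE_ARRIVAL, "batch_stage_arrival"))
heapq.heappush(heap, (4.0, REQUEST_ARRIVAL, "request_arrival"))

order = [heapq.heappop(heap)[2] for _ in range(3)]
print(order)  # ['batch_stage_arrival', 'request_arrival', 'global_schedule']
```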
Let us trace a single simulated request through all six stages of the pipeline. Every transition between stages is mediated by an event in the priority queue -- there are no direct function calls between components.
Before the simulation starts, the request generator creates all requests and seeds them into the event queue. Vidur supports two generation strategies: synthetic (with configurable arrival rate distributions and token length distributions) and trace-replay (replaying real workload traces from CSV files).
class SyntheticRequestGenerator(BaseRequestGenerator):
    def _generate_next_request(self, last_arrived_at):
        # Get inter-request time from distribution
        inter_request_time = (
            self.request_interval_generator
                .get_next_inter_request_time())
        arrived_at = last_arrived_at + inter_request_time
        # Get token counts from distribution
        prefill_tokens, decode_tokens = (
            self.request_length_generator
                .get_next_num_tokens())
        return Request(
            arrived_at=arrived_at,
            num_prefill_tokens=int(prefill_tokens),
            num_decode_tokens=int(decode_tokens),
        )
class TraceReplayRequestGenerator(BaseRequestGenerator):
    def __init__(self, config):
        # Load CSV: arrived_at, num_prefill_tokens, num_decode_tokens
        self.trace_df = pd.read_csv(config.trace_file)
        # Scale and clamp tokens
        self.trace_df["num_prefill_tokens"] = (
            self.trace_df["num_prefill_tokens"]
            * config.prefill_scale_factor
        ).clip(lower=1)

    def generate_requests(self):
        return [
            Request(row["arrived_at"],
                    row["num_prefill_tokens"],
                    row["num_decode_tokens"])
            for _, row in self.trace_df.iterrows()
        ]
def _init_event_queue(self) -> None:
    requests = self._request_generator.generate()  # ALL requests created here
    for request in requests:
        self._add_event(
            RequestArrivalEvent(request.arrived_at, request)  # Scheduled for future
        )
All RequestArrivalEvents are created before run() starts. The heap ensures they are
processed in chronological order regardless of insertion order.
| Category | Type | Description |
|---|---|---|
| Interval Generators | Poisson | Exponential inter-arrival times (Poisson process) |
| | Gamma | Gamma-distributed inter-arrival times |
| | Static | All requests arrive at time 0 (offline/batch) |
| Length Generators | Uniform | Uniform random prefill/decode lengths |
| | Zipf | Zipf-distributed token lengths (realistic skew) |
| | Trace | Lengths read from trace file |
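A Poisson interval generator reduces to sampling exponential gaps. Here is a minimal stdlib sketch (the poisson_arrival_times helper is hypothetical, not Vidur's API):

```python
import random

# Poisson arrivals at rate `qps`: exponential inter-arrival gaps with mean 1/qps.
def poisson_arrival_times(qps, num_requests, seed=42):
    rng = random.Random(seed)
    t, times = 0.0, []
    for _ in range(num_requests):
        t += rng.expovariate(qps)  # mean gap = 1 / qps
        times.append(t)
    return times

times = poisson_arrival_times(qps=2.0, num_requests=1000)
mean_gap = times[-1] / len(times)
print(round(mean_gap, 2))  # close to 0.5 for qps=2
```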
When a RequestArrivalEvent fires, it adds the request to the global scheduler's queue and
immediately creates a GlobalScheduleEvent. The global scheduler then decides which replica
each pending request should go to.
class RequestArrivalEvent(BaseEvent):
    def handle_event(self, scheduler, metrics_store):
        scheduler.add_request(self._request)  # Add to global queue
        metrics_store.on_request_arrival(self.time, self._request)
        return [GlobalScheduleEvent(self.time)]  # Trigger scheduling NOW

class GlobalScheduleEvent(BaseEvent):
    def handle_event(self, scheduler, metrics_store):
        self._replica_set = set()
        # schedule() returns List[(replica_id, request)]
        self._request_mapping = scheduler.schedule()
        for replica_id, request in self._request_mapping:
            self._replica_set.add(replica_id)
            scheduler.get_replica_scheduler(replica_id).add_request(request)
        # One ReplicaScheduleEvent per affected replica
        return [
            ReplicaScheduleEvent(self.time, replica_id)
            for replica_id in self._replica_set
        ]
# Round-robin
def schedule(self):
    self.sort_requests()
    request_mapping = []
    while self._request_queue:
        request = self._request_queue.pop(0)
        replica_id = (self._request_counter
                      % self._num_replicas)
        self._request_counter += 1
        request_mapping.append(
            (replica_id, request))
    return request_mapping

# Least outstanding requests (LOR)
def schedule(self):
    self.sort_requests()
    request_mapping = []
    pending_map = {
        rs.replica_id: rs.num_pending_requests
        for rs in self._replica_schedulers.values()
    }
    while self._request_queue:
        request = self._request_queue.pop(0)
        replica_id = min(pending_map.items(),
                         key=lambda x: x[1])[0]
        pending_map[replica_id] += 1
        request_mapping.append(
            (replica_id, request))
    return request_mapping

# Random
def schedule(self):
    self.sort_requests()
    request_mapping = []
    while self._request_queue:
        request = self._request_queue.pop(0)
        replica_id = randint(1, self._num_replicas) - 1
        request_mapping.append(
            (replica_id, request))
    return request_mapping
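The three policies can be compared on a toy workload (illustrative helpers, not Vidur's classes). With no request ever completing, round-robin and least-pending both balance perfectly, while random does not:

```python
import random

# Toy versions of the three global scheduling policies.
def round_robin(n_replicas, n_requests):
    return [i % n_replicas for i in range(n_requests)]

def least_pending(n_replicas, n_requests):
    pending, out = [0] * n_replicas, []
    for _ in range(n_requests):
        r = min(range(n_replicas), key=lambda i: pending[i])
        pending[r] += 1
        out.append(r)
    return out

def random_policy(n_replicas, n_requests, seed=0):
    rng = random.Random(seed)
    return [rng.randrange(n_replicas) for _ in range(n_requests)]

for policy in (round_robin, least_pending):
    counts = [policy(4, 100).count(i) for i in range(4)]
    print(counts)  # [25, 25, 25, 25] for both, when requests never finish
```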
The replica scheduler is the most complex component. It must decide which requests to include in the next batch while respecting memory constraints (KV-cache blocks), batch size limits, and token budget limits. Each scheduler implementation represents a different real-world serving system's batching strategy.
class ReplicaScheduleEvent(BaseEvent):
    def handle_event(self, scheduler, metrics_store):
        replica_scheduler = scheduler.get_replica_scheduler(self._replica_id)
        self._batches = replica_scheduler.on_schedule()  # Creates batches
        if not self._batches:
            return []
        metrics_store.on_replica_schedule(
            self.time, self._replica_id,
            replica_scheduler.memory_usage_percent)
        for batch in self._batches:
            batch.on_schedule(self.time)  # Mark requests as scheduled
        # Send each batch to pipeline stage 0
        return [
            BatchStageArrivalEvent(self.time, self._replica_id,
                                   0,  # stage_id = 0 (first stage)
                                   batch)
            for batch in self._batches
        ]
class VLLMReplicaScheduler(BaseReplicaScheduler):
    def _get_next_batch(self) -> Batch:
        requests, num_tokens = [], []
        # First: try to schedule new requests (prefill-prioritizing)
        while self._request_queue:
            request = self._request_queue[0]
            next_num_tokens = self._get_request_next_num_tokens(request)
            if not self._can_allocate_request(request):
                break  # Out of KV-cache blocks
            # Check token budget: batch_size * max_tokens_per_request
            new_num_tokens = num_tokens + [next_num_tokens]
            new_batch_tokens = len(new_num_tokens) * max(new_num_tokens)
            if new_batch_tokens > self._config.max_tokens_in_batch:
                break
            request = self._request_queue.pop(0)
            self._allocate_request(request)  # Reserve KV-cache blocks
            requests.append(request)
            num_tokens.append(next_num_tokens)
        if requests:
            return Batch(self._replica_id, requests, num_tokens)
        # Fallback: schedule preempted requests (decode tokens)
        # With OOM handling: evict victim requests if needed
        while self._preempted_requests:
            request = self._preempted_requests.pop(0)
            while not self._can_allocate_request(request):
                victim = self._preempted_requests.pop(-1)
                victim.restart()  # Restart from scratch
                self.free(victim.id)
                self._request_queue = [victim] + self._request_queue
            else:
                self._allocate_request(request)
            ...
        return Batch(self._replica_id, requests, num_tokens)
class SarathiReplicaScheduler(BaseReplicaScheduler):
    def _get_request_next_num_tokens(self, request, batch_contains_prefill,
                                     num_batch_tokens):
        if request.is_prefill_complete:
            return 1  # Decode: always 1 token
        # Chunked prefill: limit to remaining chunk budget
        next_num_tokens = min(
            request.num_prefill_tokens - request.num_processed_tokens,
            self._config.chunk_size - num_batch_tokens,  # KEY: chunk_size limit
        )
        return max(0, next_num_tokens)
The chunk_size parameter (e.g., 512, 1K, or 2K tokens) controls the maximum number of prefill tokens processed per iteration.
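A quick back-of-envelope sketch of the chunking arithmetic (the helper and its numbers are hypothetical):

```python
# A 1500-token prefill with chunk_size=512 takes ceil(1500/512) = 3 iterations;
# with chunk_size=2048 the whole prefill fits in one iteration.
def prefill_iterations(num_prefill_tokens, chunk_size):
    iterations, processed = 0, 0
    while processed < num_prefill_tokens:
        processed += min(chunk_size, num_prefill_tokens - processed)
        iterations += 1
    return iterations

print(prefill_iterations(1500, 512))   # 3
print(prefill_iterations(1500, 2048))  # 1
```

Smaller chunks spread prefill cost over more iterations, smoothing decode latency at the cost of more scheduling rounds.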
class BaseReplicaScheduler(ABC):
    def __init__(self, ...):
        self._request_queue = []
        self._num_allocated_blocks = 0
        self._allocation_map = {}  # {request_id: num_blocks}
        # Memory-aware batch size cap
        self._max_batch_size = min(
            memory_planner.get_max_batch_size(),
            self._config.batch_size_cap)

    def can_allocate(self, num_blocks) -> bool:
        return (self._config.num_blocks
                - self._num_allocated_blocks >= num_blocks)

    def allocate(self, request_id, num_blocks):
        self._num_allocated_blocks += num_blocks
        self._allocation_map[request_id] = (
            self._allocation_map.get(request_id, 0) + num_blocks)

    def free(self, *request_ids):
        for request_id in request_ids:
            num_blocks = self._allocation_map.pop(request_id)
            self._num_allocated_blocks -= num_blocks
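The whole KV-cache model is just this counter arithmetic. A minimal standalone sketch, assuming a block_size of 16 tokens (helper names are hypothetical):

```python
import math

# Block-counter bookkeeping: no GPU memory is ever touched.
block_size, num_blocks = 16, 1000
allocated, allocation_map = 0, {}

def can_allocate(n):
    return num_blocks - allocated >= n

def allocate(request_id, num_tokens):
    global allocated
    n = math.ceil(num_tokens / block_size)  # blocks to cover the KV cache
    assert can_allocate(n)
    allocated += n
    allocation_map[request_id] = allocation_map.get(request_id, 0) + n

allocate("req-0", 100)   # ceil(100/16) = 7 blocks
allocate("req-1", 2048)  # 2048/16 = 128 blocks
print(allocated)  # 135
```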
Once a batch is created, it must traverse all pipeline stages. The ReplicaStageScheduler
is the component that actually invokes the execution time predictor and creates the BatchStage entity
with predicted timings. It also enforces pipeline synchrony: each stage can only process one batch at a time.
class ReplicaStageScheduler:
    def __init__(self, replica_id, stage_id, is_last_stage,
                 execution_time_predictor):
        self._batch_queue = []
        self._is_busy = False  # Only one batch at a time per stage

    def on_schedule(self) -> Tuple[Batch, BatchStage, ExecutionTime]:
        if self._is_busy or not self._batch_queue:
            return None, None, None
        self._is_busy = True
        batch = self._batch_queue.pop(0)
        # THIS IS WHERE PREDICTION HAPPENS (no real GPU execution)
        execution_time = self._execution_time_predictor.get_execution_time(
            batch, self._stage_id)
        total_execution_time = execution_time.total_time  # model + CPU overhead
        model_execution_time = execution_time.model_time  # model only
        batch_stage = BatchStage(
            batch.id, self._replica_id, self._stage_id,
            total_execution_time, model_execution_time,
            batch.requests, batch.num_tokens)
        return batch, batch_stage, execution_time

    def on_stage_end(self):
        self._is_busy = False  # Free the stage for next batch
class ReplicaStageScheduleEvent(BaseEvent):
    def handle_event(self, scheduler, metrics_store):
        stage_scheduler = scheduler._replica_schedulers[
            self._replica_id]._replica_stage_schedulers[self._stage_id]
        self._batch, self._batch_stage, execution_time = (
            stage_scheduler.on_schedule())  # Predicts time!
        if self._batch is None:
            return []  # Stage busy or nothing queued
        self._batch_stage.on_schedule(self.time)
        # Create end event at: now + predicted_execution_time
        return [
            BatchStageEndEvent(
                self.time + self._batch_stage.execution_time,  # FUTURE time
                self._replica_id, self._stage_id,
                stage_scheduler.is_last_stage,
                self._batch, self._batch_stage)
        ]
The BatchStageEndEvent is scheduled at current_time + predicted_time. The event loop will naturally process this "future" event when
its timestamp becomes the lowest in the heap. No actual time passes -- the simulator simply jumps to that moment.
class BatchStageEndEvent(BaseEvent):
    def handle_event(self, scheduler, metrics_store):
        # Free the pipeline stage
        scheduler.get_replica_stage_scheduler(
            self._replica_id, self._stage_id).on_stage_end()
        self._batch_stage.on_stage_end(self.time)
        metrics_store.on_batch_stage_end(self._batch_stage, ...)
        # Always try to schedule next batch on this stage
        next_events = [ReplicaStageScheduleEvent(
            self.time, self._replica_id, self._stage_id)]
        if self._is_last_stage:
            # All PP stages done -> batch iteration complete
            return next_events + [
                BatchEndEvent(self.time, self._replica_id, self._batch)]
        # Forward to NEXT pipeline stage
        return next_events + [
            BatchStageArrivalEvent(
                self.time, self._replica_id,
                self._stage_id + 1,  # stage_id + 1
                self._batch)]
The execution time predictor is what makes Vidur a high-fidelity simulator rather than a toy model. It uses sklearn-based ML models (Random Forest, Linear Regression) trained on profiled kernel runtimes to predict execution time for 13+ individual operators within each transformer layer.
class BaseExecutionTimePredictor(ABC):
    def get_execution_time(self, batch: Batch, pipeline_stage: int) -> ExecutionTime:
        # Conditionally compute communication costs
        if pipeline_stage == self._replica_config.num_pipeline_stages - 1:
            pp_comm_time = 0  # Last stage: no PP send
        else:
            pp_comm_time = self._get_pipeline_parallel_communication_time(batch)
        if self._replica_config.tensor_parallel_size == 1:
            tp_comm_time = 0  # No TP: no all-reduce
        else:
            tp_comm_time = self._get_tensor_parallel_communication_time(batch)
        return ExecutionTime(
            self._num_layers_per_pipeline_stage,
            self._get_attention_rope_execution_time(batch),
            self._get_attention_kv_cache_save_execution_time(batch),
            self._get_attention_decode_execution_time(batch),
            self._get_attention_prefill_execution_time(batch),
            self._get_attention_layer_pre_proj_execution_time(batch),
            self._get_attention_layer_post_proj_execution_time(batch),
            self._get_mlp_layer_up_proj_execution_time(batch),
            self._get_mlp_layer_down_proj_execution_time(batch),
            self._get_mlp_layer_act_execution_time(batch),
            self._get_attn_norm_layer_act_execution_time(batch),
            self._get_mlp_norm_layer_act_execution_time(batch),
            self._get_add_layer_act_execution_time(batch),
            tp_comm_time, pp_comm_time,
            self._get_schedule_time(batch),
            self._get_sampler_e2e_time(batch),
            self._get_prepare_inputs_e2e_time(batch),
            self._get_process_model_outputs_time(batch),
            self._get_ray_comm_time(batch),
        )
class ExecutionTime(BaseEntity):
    @property
    def model_time(self) -> float:
        # Per-layer time * number of layers + PP communication
        block_time = self._get_block_execution_time()  # attn + mlp + add
        stage_time = block_time * self._num_layers_per_pipeline_stage
        return (stage_time + self.pipeline_parallel_communication_time) * 1e-3

    @property
    def total_time(self) -> float:
        # model_time (GPU) + CPU overhead (schedule, sampler, etc.)
        return self.model_time + self._get_cpu_overhead() * 1e-3

    def _get_block_execution_time(self) -> float:
        return (self._get_attention_layer_execution_time()
                + self._get_mlp_layer_execution_time()
                + self._add_time)

    def _get_cpu_overhead(self) -> float:
        return (self._schedule_time + self._sampler_e2e_time
                + self._prepare_inputs_e2e_time
                + self._process_model_outputs_time
                + self._ray_comm_time)
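Plugging numbers into these formulas shows the scale involved (all values below are invented for illustration, not profiled data; times in milliseconds, matching the * 1e-3 conversion to seconds):

```python
# Hypothetical per-iteration timing for one pipeline stage.
block_time_ms = 1.8       # attn + mlp + add for one transformer block
layers_per_stage = 16     # e.g. 32 layers split over 2 pipeline stages
pp_comm_ms = 0.4          # PP send/recv
cpu_overhead_ms = 2.5     # schedule + sampler + prepare_inputs + ...

model_time_s = (block_time_ms * layers_per_stage + pp_comm_ms) * 1e-3
total_time_s = model_time_s + cpu_overhead_ms * 1e-3
print(round(model_time_s, 4), round(total_time_s, 4))  # 0.0292 0.0317
```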
The paper identifies a key insight: LLM operators can be categorized into three groups based on what determines their runtime. This allows targeted prediction strategies rather than needing to profile every possible combination.
- **Token-count-dependent operators.** Runtime depends on the total number of tokens in the batch (prefill + decode). Examples: linear projections, activation functions. The MLP layer takes the same compute regardless of request history.
- **Context-dependent operators.** Runtime depends on the context length of each request. The attention kernel is sensitive to both current tokens and KV-cache size: prefill attention is quadratic, while decode attention depends on total KV-cache reads.
- **Communication operators.** Runtime depends on the amount of data transferred, independent of model architecture. Includes all-reduce (TP), all-gather (TP), and send/recv (PP). Profiled once, reused across models.
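For the token-count-dependent category described above, a simple regression over (tokens, runtime) pairs suffices. The sketch below fits a hand-rolled least-squares line on fabricated data; Vidur itself trains sklearn models (Random Forest, Linear Regression) on profiled kernel runtimes.

```python
# Fabricated, roughly linear (tokens, runtime) profile for one operator.
tokens = [128, 256, 512, 1024, 2048]
runtime_ms = [0.21, 0.39, 0.75, 1.47, 2.91]

# Closed-form least-squares line: runtime ≈ intercept + slope * tokens.
n = len(tokens)
mean_x = sum(tokens) / n
mean_y = sum(runtime_ms) / n
num = sum((x - mean_x) * (y - mean_y) for x, y in zip(tokens, runtime_ms))
den = sum((x - mean_x) ** 2 for x in tokens)
slope = num / den
intercept = mean_y - slope * mean_x

def predict(num_tokens):
    return intercept + slope * num_tokens

print(round(predict(768), 2))  # interpolated runtime for an unseen batch size
```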
When the last pipeline stage completes, a BatchEndEvent fires. This updates all request
state, frees completed requests' memory, and triggers the next scheduling cycle. The autoregressive decode loop is an
emergent behavior of this event chain: each batch processes one decode token per request, and the cycle repeats until
all tokens are generated.
class BatchEndEvent(BaseEvent):
    def handle_event(self, scheduler, metrics_store):
        self._batch.on_batch_end(self.time)  # Updates all request tokens
        replica_scheduler = scheduler.get_replica_scheduler(self._replica_id)
        replica_scheduler.on_batch_end(self._batch)  # Free/preempt requests
        metrics_store.on_batch_end(
            self.time, self._batch, self._replica_id,
            replica_scheduler.memory_usage_percent)
        # Re-trigger replica scheduling (next decode iteration!)
        return [ReplicaScheduleEvent(self.time, self._replica_id)]
def on_batch_end(self, time, num_tokens_processed):
    self._num_processed_tokens += num_tokens_processed
    # Check: did we just finish all prefill tokens?
    if self._num_processed_tokens == self._num_prefill_tokens:
        self._is_prefill_complete = True
        self._num_processed_tokens += 1  # First decode token is "free"
        if self._prefill_completed_at == 0:
            self._prefill_completed_at = time  # Record TTFT
    # Check: is the request fully complete?
    if self._num_processed_tokens == self.total_tokens:
        self._completed_at = time
        self._completed = True  # Done!
The cycle BatchEnd → ReplicaSchedule → BatchStageArrival → StageSchedule → BatchStageEnd → BatchEnd
naturally repeats. Each iteration, the replica scheduler sees that requests still have unprocessed decode tokens and includes them
in the next batch with num_tokens=1. A request completes itself when num_processed_tokens == total_tokens.
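The emergent loop can be checked with toy arithmetic (hypothetical token counts): a request with a 200-token prefill done in one batch plus 50 decode tokens needs 50 iterations in total, because the prefill iteration also yields the first decode token.

```python
# Count batch iterations for one request (mirrors the on_batch_end logic above).
num_prefill, num_decode = 200, 50
processed, iterations = 0, 0

while processed < num_prefill + num_decode:
    if processed < num_prefill:
        processed += num_prefill  # whole prefill fits in one batch here
        processed += 1            # first decode token is "free"
    else:
        processed += 1            # one decode token per iteration
    iterations += 1

print(iterations)  # 50
```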
The fundamental unit. Tracks its own lifecycle state through callbacks from batch processing.
class Request(BaseEntity):
    # Identity
    _arrived_at: float
    _num_prefill_tokens: int
    _num_decode_tokens: int
    _num_processed_tokens: int
    # Lifecycle timestamps
    _scheduled_at: float
    _completed_at: float
    _prefill_completed_at: float  # TTFT
    # State flags
    _scheduled: bool
    _completed: bool
    _is_prefill_complete: bool
    _preempted: bool
    _num_restarts: int
Groups requests for one iteration. Each request contributes a specific number of tokens to the batch.
class Batch(BaseEntity):
    _replica_id: int
    _requests: List[Request]
    _num_tokens: List[int]    # Per-request
    _total_num_tokens: int    # sum(num_tokens)
    _num_prefill_tokens: int  # Prefill subset
    # Rounded for hardware alignment
    _total_num_tokens_rounded = (total + 7) // 8 * 8
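The rounding expression, pulled out for clarity (a standalone sketch, not the entity code): it rounds the token count up to the next multiple of 8 for hardware alignment.

```python
# Round up to the next multiple of 8: add 7, then floor-divide and re-multiply.
def round_up_to_8(total_tokens):
    return (total_tokens + 7) // 8 * 8

print([round_up_to_8(n) for n in (1, 8, 9, 100)])  # [8, 8, 16, 104]
```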
One batch processed at one pipeline stage. Carries the predicted execution time and generates Chrome trace events.
class BatchStage(BaseEntity):
    _batch_id: int
    _replica_id: int
    _pipeline_stage: int
    _execution_time: float        # total (model + CPU)
    _model_execution_time: float  # GPU model only
    _requests: List[Request]
    _num_tokens: List[int]
Cluster creates num_replicas Replica objects. Each Replica encapsulates model config
(layers, heads, embedding dims) and device config (memory, FLOPS), but holds NO actual weights.
class Cluster(BaseEntity):
    def __init__(self, cluster_config, ...):
        self._replicas = {}
        for _ in range(cluster_config.num_replicas):
            replica = Replica(cluster_config.replica_config,
                              generator_config)
            self._replicas[replica.id] = replica
Vidur's configuration hierarchy mirrors its component hierarchy. The SimulationConfig
is the root, containing cluster config (replicas, parallelism), scheduler config (which batching policy), request generator
config (workload), and metrics config (output format). This is what enables Vidur-Search to programmatically sweep across
hundreds of configurations.
SimulationConfig
├── cluster_config: ClusterConfig
│ ├── num_replicas: int
│ ├── replica_config: ReplicaConfig
│ │ ├── model_config: ModelConfig # layers, heads, embedding_dim
│ │ ├── device_config: DeviceSKUConfig # A100, H100 specs
│ │ ├── num_pipeline_stages: int # PP dimension
│ │ └── tensor_parallel_size: int # TP dimension
│ ├── replica_scheduler_config: BaseReplicaSchedulerConfig
│ │ ├── batch_size_cap, block_size: int
│ │ ├── max_tokens_in_batch: int
│ │ └── chunk_size: int # Sarathi only
│ └── global_scheduler_config: BaseGlobalSchedulerConfig
├── request_generator_config: BaseRequestGeneratorConfig
│ ├── max_tokens: int
│ └── seed, duration, num_requests: ...
├── execution_time_predictor_config: BaseExecutionTimePredictorConfig
│ ├── compute_input_file: str # Profiled data paths
│ └── attention_input_file: str
├── metrics_config: MetricsConfig
│ ├── output_dir: str
│ ├── write_json_trace: bool
│ └── enable_chrome_trace: bool
└── time_limit: Optional[float]
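A minimal sketch of how such a hierarchy enables programmatic sweeps (the dataclass names follow the tree above, but the fields shown are a small, illustrative subset of the real config classes):

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ReplicaConfig:
    num_pipeline_stages: int = 1
    tensor_parallel_size: int = 1

@dataclass
class ClusterConfig:
    num_replicas: int = 1
    replica_config: ReplicaConfig = field(default_factory=ReplicaConfig)

@dataclass
class SimulationConfig:
    cluster_config: ClusterConfig = field(default_factory=ClusterConfig)
    time_limit: Optional[float] = None

# A Vidur-Search-style sweep is then plain object construction:
sweep = [
    SimulationConfig(ClusterConfig(
        num_replicas=n,
        replica_config=ReplicaConfig(tensor_parallel_size=tp)))
    for n in (1, 2, 4) for tp in (1, 2)
]
print(len(sweep))  # 6 configurations
```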
| Category | Metrics | Collected At |
|---|---|---|
| Request-level | TTFT, TBT, E2E latency, scheduling delay, preempted time, num_restarts | on_batch_end, on_request_arrival |
| Batch-level | Batch size, num_tokens, prefill/decode mix, execution time | on_batch_stage_end, on_batch_end |
| Replica-level | Memory usage %, busy/idle time, tokens processed per iteration | on_replica_schedule |
| Cluster-level | Model FLOPs Utilization (MFU), Memory Bandwidth Utilization (MBU), throughput | Derived at plot time |
| Aspect | Real Inference Engine (e.g., vLLM) | Vidur Simulator |
|---|---|---|
| Model Weights | Loaded into GPU memory (GBs) | Never loaded. Only model spec (num_layers, dims) |
| GPU Execution | Real CUDA kernels, real latency | ML-predicted execution times, instant "fast-forward" |
| KV-Cache | Physical GPU memory allocated/freed | Block counter: _num_allocated_blocks += n |
| Time Progression | Real wall-clock time | Simulated: heapq.heappop jumps to next event |
| Hardware Required | GPU cluster (A100/H100) | CPU only (laptop is fine) |
| One Workload Cost | ~$97K (42K GPU hours) | ~$0.10 (~1 CPU hour) |
| Fidelity | Ground truth (by definition) | <9% error on execution latency, <5% on E2E |