A source-code deep dive into Vidur's three-tier scheduling hierarchy: Global Scheduler dispatches requests to replicas, Replica Scheduler batches them for execution using strategies from vLLM, Sarathi, Orca, LightLLM, and FasterTransformer, and Stage Scheduler manages work across pipeline-parallel stages.
Vidur models the scheduling problem of LLM inference as a three-level hierarchy. Each level operates on a different abstraction: the Global Scheduler routes incoming requests to model replicas, each Replica Scheduler batches requests into execution units using a specific batching strategy, and the Stage Scheduler dispatches batches across pipeline-parallel stages within a single replica.
class BaseGlobalScheduler(ABC):
def __init__(self, config: SimulationConfig, replicas: Dict[int, Replica]):
self._config = config
self._replicas = replicas
self._num_replicas = len(self._replicas)
execution_time_predictor = ExecutionTimePredictorRegistry.get(...)
# Tier 2: One replica scheduler per replica
self._replica_schedulers = {
replica_id: ReplicaSchedulerRegistry.get(
config.cluster_config.replica_scheduler_config.get_type(),
replica=replica,
num_stages=replica.num_pipeline_stages,
execution_time_predictor=execution_time_predictor,
)
for replica_id, replica in replicas.items()
}
self._request_queue = []
vidur/entities/request.py
The Request class tracks the complete lifecycle of an inference request, from arrival through prefill, decode, potential preemption, restart, and completion. It extends BaseEntity which provides auto-incrementing IDs.
class Request(BaseEntity):
def __init__(self,
arrived_at: float, # timestamp of arrival
num_prefill_tokens: int, # prompt length
num_decode_tokens: int, # expected generation length
num_processed_tokens: int = 0, # tokens processed so far
):
self._id = Request.generate_id()
self._arrived_at = arrived_at
self._num_prefill_tokens = num_prefill_tokens
self._num_decode_tokens = num_decode_tokens
self._num_processed_tokens = num_processed_tokens
# --- Timing Fields ---
self._scheduled_at = 0
self._execution_time = 0
self._model_execution_time = 0
self._scheduling_delay = 0
self._preempted_time = 0
self._completed_at = 0
self._prefill_completed_at = 0
self._latest_stage_scheduled_at = 0
self._latest_stage_completed_at = 0
self._latest_iteration_scheduled_at = 0
self._latest_iteration_completed_at = 0
# --- State Flags ---
self._scheduled = False
self._preempted = False
self._completed = False
self._is_prefill_complete = False
self._num_restarts = 0
| Property | Type | Description |
|---|---|---|
| num_prefill_tokens | int | Number of prompt tokens to process in the prefill phase |
| num_decode_tokens | int | Number of tokens to generate in the decode phase |
| num_processed_tokens | int | Total tokens processed so far (prefill + decode) |
| is_prefill_complete | bool | Whether all prefill tokens have been processed |
| total_tokens | int | num_prefill_tokens + num_decode_tokens |
| scheduling_delay | float | scheduled_at - arrived_at (time spent waiting in queue) |
| e2e_time | float | completed_at - arrived_at (total end-to-end latency) |
| num_restarts | int | How many times this request was preempted and restarted |
| pd_ratio | float | Prefill-to-decode ratio: num_prefill_tokens / num_decode_tokens |
def on_batch_end(self, time: float, num_tokens_processed: int) -> None:
self._num_processed_tokens += num_tokens_processed
if self._num_processed_tokens == self._num_prefill_tokens:
self._is_prefill_complete = True
# we get one decode token when the prefill processing completes
self._num_processed_tokens += 1
if self._prefill_completed_at == 0:
self._prefill_completed_at = time
# check if request is completed
if self._num_processed_tokens == self.total_tokens:
self._completed_at = time
self._completed = True
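The prefill-to-decode transition above is easy to exercise with a stripped-down stand-in for Request (a hedged sketch: field names follow the excerpt, everything else is simplified away):

```python
class MiniRequest:
    """Simplified stand-in for vidur's Request, tracking only token accounting."""

    def __init__(self, num_prefill_tokens: int, num_decode_tokens: int):
        self.num_prefill_tokens = num_prefill_tokens
        self.num_decode_tokens = num_decode_tokens
        self.num_processed_tokens = 0
        self.is_prefill_complete = False
        self.completed = False

    @property
    def total_tokens(self) -> int:
        return self.num_prefill_tokens + self.num_decode_tokens

    def on_batch_end(self, num_tokens_processed: int) -> None:
        self.num_processed_tokens += num_tokens_processed
        if (not self.is_prefill_complete
                and self.num_processed_tokens == self.num_prefill_tokens):
            self.is_prefill_complete = True
            # the prefill pass also yields the first decode token
            self.num_processed_tokens += 1
        if self.num_processed_tokens == self.total_tokens:
            self.completed = True


# 100-token prompt, 3 tokens to generate
r = MiniRequest(100, 3)
r.on_batch_end(100)  # full prefill: first decode token comes along with it
r.on_batch_end(1)    # second decode token
r.on_batch_end(1)    # third decode token -> request completes
```

Note the bookkeeping consequence: after prefill, `num_processed_tokens` is already `num_prefill_tokens + 1`, so only `num_decode_tokens - 1` further decode iterations are needed.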
When a request is preempted and restarted, all previously decoded tokens are folded into the prefill phase. This allows the restarted request to reprocess all prior work in a single parallel prefill pass rather than sequential decode steps:
def restart(self):
# when we restart the request, we can process all the previously
# decoded tokens in parallel (i.e., we can prefill all the tokens)
total_tokens = self._num_prefill_tokens + self._num_decode_tokens
self._num_prefill_tokens = self._num_processed_tokens
self._num_decode_tokens = total_tokens - self._num_prefill_tokens
self._num_processed_tokens = 0
self._scheduled = False
self._preempted = False
self._completed = False
self._is_prefill_complete = False
self._num_restarts += 1
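A worked example of the restart arithmetic (the numbers are illustrative): a request with a 100-token prompt and 20 decode tokens, preempted after 110 tokens have been processed (the prompt plus 10 decodes), re-enters the queue with those 110 tokens reclassified as prefill:

```python
# Illustrative state at preemption time
num_prefill_tokens, num_decode_tokens = 100, 20
num_processed_tokens = 110  # prompt + 10 generated tokens

# restart() folds all processed tokens into the prefill phase:
total_tokens = num_prefill_tokens + num_decode_tokens    # 120, invariant
num_prefill_tokens = num_processed_tokens                # 110, redone in one pass
num_decode_tokens = total_tokens - num_prefill_tokens    # 10 still to generate
num_processed_tokens = 0                                 # start over, but in parallel
```

The total work is unchanged; what changes is that the 10 already-generated tokens are recomputed as a single parallel prefill instead of 10 sequential decode steps.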
vidur/entities/batch.py
A Batch groups multiple requests for simultaneous execution. It tracks per-request token counts (critical for mixed prefill/decode batches) and includes token-count rounding for hardware alignment.
class Batch(BaseEntity):
def __init__(self,
replica_id: int,
requests: List[Request],
num_tokens: List[int], # per-request token counts for this iteration
) -> None:
self._id = Batch.generate_id()
self._replica_id = replica_id
self._requests = requests
self._num_tokens = num_tokens
self._total_num_tokens = sum(num_tokens)
# Separate prefill vs decode token counts
self._num_prefill_tokens = sum([
(t if not r.is_prefill_complete else 0)
for r, t in zip(self.requests, self._num_tokens)
])
# Round up to multiple of 8 for GPU efficiency
self._total_num_tokens_rounded = (self._total_num_tokens + 7) // 8 * 8
self._scheduled_at = None
self._completed_at = None
self._scheduled = False
self._completed = False
def on_schedule(self, time: float) -> None:
self._scheduled_at = time
self._scheduled = True
for request in self._requests:
request.on_batch_schedule(time)
def on_batch_end(self, time: float):
self._completed = True
self._completed_at = time
for request, num_tokens in zip(self._requests, self._num_tokens):
request.on_batch_end(time, num_tokens)
vidur/entities/replica.py
A Replica represents a single model deployment (potentially across multiple GPUs with pipeline and tensor parallelism). It encapsulates model architecture details, device configuration, and parallelism dimensions.
class Replica(BaseEntity):
def __init__(self,
replica_config: ReplicaConfig,
generator_config: BaseRequestGeneratorConfig,
) -> None:
self._id = Replica.generate_id()
self._replica_config = replica_config
self._model_config = replica_config.model_config
self._device_config = replica_config.device_config
# Validation: layers must divide evenly across pipeline stages
assert self._model_config.num_layers % \
self._replica_config.num_pipeline_stages == 0
# Validation: embedding must divide for tensor parallelism
assert self._model_config.embedding_dim % \
self._replica_config.tensor_parallel_size == 0
| Replica Property | Computation |
|---|---|
| num_pipeline_stages | From replica_config |
| num_layers_per_pipeline_stage | num_layers // num_pipeline_stages |
| attention_head_dim | embedding_dim // num_q_heads |
| q_heads_per_tensor_parallel_worker | num_q_heads // tensor_parallel_size |
| kv_heads_per_tensor_parallel_worker | ceil(num_kv_heads / tensor_parallel_size) |
| per_device_flops | fp16_tflops * 2**40 |
| total_memory_gb | From device_config (e.g. 80 for A100) |
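The derived quantities can be computed directly. This sketch uses illustrative LLaMA-7B-like numbers, not values from any real vidur config file:

```python
import math

# Illustrative model and parallelism configuration (assumed, not from vidur)
num_layers, embedding_dim = 32, 4096
num_q_heads, num_kv_heads = 32, 32
num_pipeline_stages, tensor_parallel_size = 4, 2

# The same divisibility checks Replica asserts in its constructor
assert num_layers % num_pipeline_stages == 0
assert embedding_dim % tensor_parallel_size == 0

num_layers_per_pipeline_stage = num_layers // num_pipeline_stages        # 8
attention_head_dim = embedding_dim // num_q_heads                        # 128
q_heads_per_tp_worker = num_q_heads // tensor_parallel_size              # 16
kv_heads_per_tp_worker = math.ceil(num_kv_heads / tensor_parallel_size)  # 16
```

The ceil in the KV-heads computation matters for grouped-query models, where num_kv_heads may not divide evenly by the tensor-parallel degree.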
The Global Scheduler's job is simple but critical: distribute incoming requests across available replicas. All three strategies inherit from BaseGlobalScheduler and implement a single abstract method: schedule() -> List[Tuple[int, Request]].
round_robin_global_scheduler.py
Maintains a monotonically incrementing counter. Each request is assigned to replica_id = counter % num_replicas. Simple, stateless, and deterministic.
def schedule(self):
self.sort_requests()
request_mapping = []
while self._request_queue:
request = self._request_queue.pop(0)
replica_id = self._request_counter \
% self._num_replicas
self._request_counter += 1
request_mapping.append(
(replica_id, request))
return request_mapping
lor_global_scheduler.py
Load-aware scheduling. Builds a pending request count map from all replica schedulers, then assigns each request to the replica with the fewest pending requests. Updates the count as it goes.
def schedule(self):
self.sort_requests()
pending_map = {
rs.replica_id: rs.num_pending_requests
for rs in
self._replica_schedulers.values()
}
request_mapping = []
while self._request_queue:
request = self._request_queue.pop(0)
replica_id = min(
pending_map.items(),
key=lambda x: x[1])[0]
pending_map[replica_id] += 1
request_mapping.append(
(replica_id, request))
return request_mapping
random_global_scheduler.py
Assigns each request to a random replica using randint(1, N) - 1. Provides a baseline for comparison. No state, no load awareness.
def schedule(self):
self.sort_requests()
request_mapping = []
while self._request_queue:
request = self._request_queue.pop(0)
replica_id = randint(
1, self._num_replicas) - 1
request_mapping.append(
(replica_id, request))
return request_mapping
vidur/scheduler/replica_scheduler/base_replica_scheduler.py
BaseReplicaScheduler establishes the common infrastructure used by all five replica scheduler implementations. It manages block-based KV-cache memory allocation, creates one ReplicaStageScheduler per pipeline stage, and defines the on_schedule() loop that fills pipeline stages with batches.
class BaseReplicaScheduler(ABC):
def __init__(self,
replica_config, replica_scheduler_config,
request_generator_config, replica,
num_stages, execution_time_predictor,
) -> None:
self._max_blocks_per_sequence = (
request_generator_config.max_tokens // replica_scheduler_config.block_size
)
# Memory planning: compute max batch size from available GPU memory
memory_planner = MemoryPlanner(replica_config, replica)
if not self._config.num_blocks:
self._config.num_blocks = (
self._max_blocks_per_sequence * memory_planner.get_max_request_slots()
)
self._max_batch_size = min(
memory_planner.get_max_batch_size(),
self._config.batch_size_cap,
)
self._request_queue = []
self._num_allocated_blocks = 0
self._allocation_map = {} # request_id -> num_blocks
# Tier 3: Create one stage scheduler per pipeline stage
self._replica_stage_schedulers = {
stage_id: ReplicaStageScheduler(
replica.id, stage_id,
stage_id == num_stages - 1, # is_last_stage
execution_time_predictor,
)
for stage_id in range(num_stages)
}
Memory is tracked in units of blocks. Each block holds block_size tokens worth of KV-cache. The allocate() and free() methods maintain a running count and per-request allocation map.
def can_allocate(self, num_blocks: int) -> bool:
return self._config.num_blocks - self._num_allocated_blocks >= num_blocks
def allocate(self, request_id: int, num_blocks: int) -> None:
self._num_allocated_blocks += num_blocks
if request_id not in self._allocation_map:
self._allocation_map[request_id] = num_blocks
else:
self._allocation_map[request_id] += num_blocks
assert self._num_allocated_blocks <= self._config.num_blocks
def free(self, *request_ids: List[int]) -> None:
for request_id in request_ids:
num_blocks = self._allocation_map.pop(request_id)
self._num_allocated_blocks -= num_blocks
assert self._num_allocated_blocks >= 0
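The accounting is easy to exercise in isolation. This toy pool (a sketch with a made-up pool size) mirrors the three methods above without any of the scheduler machinery:

```python
class MiniBlockPool:
    """Toy version of the block accounting in BaseReplicaScheduler."""

    def __init__(self, num_blocks: int):
        self.num_blocks = num_blocks
        self.num_allocated_blocks = 0
        self.allocation_map = {}  # request_id -> num_blocks

    def can_allocate(self, num_blocks: int) -> bool:
        return self.num_blocks - self.num_allocated_blocks >= num_blocks

    def allocate(self, request_id: int, num_blocks: int) -> None:
        self.num_allocated_blocks += num_blocks
        self.allocation_map[request_id] = (
            self.allocation_map.get(request_id, 0) + num_blocks)
        assert self.num_allocated_blocks <= self.num_blocks

    def free(self, *request_ids: int) -> None:
        for request_id in request_ids:
            self.num_allocated_blocks -= self.allocation_map.pop(request_id)
        assert self.num_allocated_blocks >= 0


pool = MiniBlockPool(num_blocks=10)
pool.allocate(0, 4)        # request 0 claims 4 blocks
pool.allocate(1, 4)        # request 1 claims 4 blocks
ok = pool.can_allocate(4)  # only 2 blocks remain -> False
pool.free(0)               # request 0 finishes; its 4 blocks return
```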
This is the entry point called by the simulation engine. It fills pipeline stages by calling the subclass-specific _get_next_batch() up to num_stages times:
def on_schedule(self) -> List[Batch]:
scheduled_batches = []
while self._num_running_batches < self._num_stages:
batch = self._get_next_batch() # abstract: each scheduler implements differently
if not batch:
break
scheduled_batches.append(batch)
self._num_running_batches += 1
return scheduled_batches
def _get_request_next_num_tokens(self, request: Request) -> int:
if request.is_prefill_complete:
return 1 # decode: one token at a time
return request.num_prefill_tokens # prefill: all prompt tokens
vidur/scheduler/replica_scheduler/vllm_replica_scheduler.py
The vLLM scheduler implements PagedAttention-style continuous batching. It uses fine-grained block allocation with a watermark mechanism to reserve memory for running requests. Preempted requests can be evicted from the back of the queue (LIFO victim selection) and restarted.
def _can_allocate_request(self, request: Request) -> bool:
if request.id not in self._allocation_map:
# New request: need blocks for all prefill tokens + watermark
num_required_blocks = ceil(request.num_prefill_tokens / self._config.block_size)
return (self._config.num_blocks
- self._num_allocated_blocks
- num_required_blocks
>= self._watermark_blocks)
# Existing request: vllm requires at least one block to be available
return self._config.num_blocks - self._num_allocated_blocks >= 1
def _allocate_request(self, request: Request) -> None:
if request.id not in self._allocation_map:
num_required_blocks = ceil(request.num_prefill_tokens / self._config.block_size)
self.allocate(request.id, num_required_blocks)
return
# For existing requests: allocate one more block if needed
num_tokens_reserved = self._allocation_map[request.id] * self._config.block_size
num_tokens_required = max(0, request.num_processed_tokens - num_tokens_reserved)
if num_tokens_required == 0:
return
self.allocate(request.id, 1)
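A quick worked example of the watermark check, with illustrative sizes: 100 blocks total, block_size of 16, a watermark of 5 blocks, and 87 blocks already allocated. A new request with a 130-token prompt needs ceil(130/16) = 9 blocks, and 100 - 87 - 9 = 4 falls below the watermark, so admission is refused even though the blocks physically exist; a running request, by contrast, only needs one free block:

```python
from math import ceil

# Illustrative configuration, not from any real vidur config
num_blocks, block_size, watermark_blocks = 100, 16, 5
num_allocated_blocks = 87

num_prefill_tokens = 130
num_required_blocks = ceil(num_prefill_tokens / block_size)  # 9

# New-request path: keep watermark headroom so running decodes can keep growing
admit_new = (num_blocks - num_allocated_blocks - num_required_blocks
             >= watermark_blocks)  # 4 >= 5 -> refused

# Existing-request path: just needs one free block for its next token
can_grow = num_blocks - num_allocated_blocks >= 1  # 13 >= 1 -> allowed
```

The asymmetry is deliberate: refusing new work is cheap, but stalling an in-flight decode forces a preemption.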
# Excerpt from _get_next_batch(): preempted-request processing with eviction
while self._preempted_requests:
request = self._preempted_requests.pop(0)
while not self._can_allocate_request(request):
if self._preempted_requests:
victim_request = self._preempted_requests.pop(-1) # LIFO eviction
victim_request.restart()
self.free(victim_request.id)
self._request_queue = [victim_request] + self._request_queue
else:
request.restart() # no victims: restart self
self.free(request.id)
self._request_queue = [request] + self._request_queue
break
else:
self._allocate_request(request)
requests.append(request)
num_tokens.append(self._get_request_next_num_tokens(request))
vidur/scheduler/replica_scheduler/sarathi_replica_scheduler.py
Sarathi's key innovation is chunked prefill. Instead of processing the entire prompt at once, it breaks prefill into chunks of configurable size (chunk_size). This allows mixing prefill chunks with decode tokens in the same batch, reducing pipeline bubbles and improving GPU utilization.
def _get_request_next_num_tokens(self, request, batch_contains_prefill, num_batch_tokens):
assert not request.completed
if request.is_prefill_complete:
return 1 # decode: always 1 token
# CHUNKED PREFILL: cap by remaining budget in this chunk
next_num_tokens = min(
request.num_prefill_tokens - request.num_processed_tokens, # remaining prefill
self._config.chunk_size - num_batch_tokens, # budget left
)
next_num_tokens = max(0, next_num_tokens)
return next_num_tokens
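A worked example with illustrative numbers: chunk_size = 512, a request with 1000 prefill tokens of which 300 are already processed, and 400 tokens already packed into the current batch:

```python
chunk_size = 512        # illustrative Sarathi config value
num_prefill_tokens = 1000
num_processed_tokens = 300
num_batch_tokens = 400  # tokens already claimed by other requests in this batch

next_num_tokens = min(
    num_prefill_tokens - num_processed_tokens,  # 700 prefill tokens remain
    chunk_size - num_batch_tokens,              # 112 tokens of chunk budget remain
)
next_num_tokens = max(0, next_num_tokens)       # clamp: budget may be exhausted

# If the chunk budget is already spent, the request contributes nothing this iteration:
exhausted = max(0, min(700, chunk_size - 512))
```

So the request contributes only 112 tokens now and resumes its prefill in later iterations, leaving room in every batch for decode tokens.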
Sarathi's _get_next_batch() operates in three ordered phases, prioritizing running requests:
def _get_next_batch(self) -> Batch:
requests, num_tokens = [], []
skipped_requests, running_prefills = [], []
contains_prefill = False
num_batch_tokens = 0
# PHASE 1: Preempted requests with completed prefill (decode only)
while self._preempted_requests:
request = self._preempted_requests.pop(0)
if not request.is_prefill_complete:
running_prefills.append(request) # defer to phase 2
continue
# ... eviction logic similar to vLLM ...
self._allocate_request(request)
num_batch_tokens += next_num_tokens
requests.append(request)
# PHASE 2: Partial prefills from preemption
for request in running_prefills:
next_num_tokens = self._get_request_next_num_tokens(
request, contains_prefill, num_batch_tokens)
if next_num_tokens == 0:
skipped_requests.append(request); continue
contains_prefill = True
num_batch_tokens += next_num_tokens
requests.append(request)
# PHASE 3: New requests from queue (consume remaining chunk budget)
while self._request_queue:
if len(self._allocation_map) == self._config.batch_size_cap: break
next_num_tokens = self._get_request_next_num_tokens(
self._request_queue[0], contains_prefill, num_batch_tokens)
if next_num_tokens == 0: break
request = self._request_queue.pop(0)
self._allocate_request(request)
contains_prefill = True
num_batch_tokens += next_num_tokens
requests.append(request)
if not requests: return
return Batch(self._replica_id, requests, num_tokens)
vidur/scheduler/replica_scheduler/orca_replica_scheduler.py
Orca (the system that introduced iteration-level scheduling / continuous batching) uses the simplest allocation strategy: it pre-allocates max_blocks_per_sequence for each new request. No watermarks, no incremental allocation. Preempted requests get priority but without eviction logic.
class OrcaReplicaScheduler(BaseReplicaScheduler):
def _get_next_batch(self) -> Batch:
requests, num_tokens = [], []
# Preempted requests first (all have completed prefill)
while self._preempted_requests:
if len(requests) == self._max_batch_size: break
request = self._preempted_requests.pop(0)
next_num_tokens = self._get_request_next_num_tokens(request)
requests.append(request)
num_tokens.append(next_num_tokens)
# New requests: pre-allocate max blocks per sequence
while self._request_queue:
if len(requests) == self._max_batch_size: break
if not self.can_allocate(self._max_blocks_per_sequence): break
request = self._request_queue.pop(0)
self.allocate(request.id, self._max_blocks_per_sequence)
next_num_tokens = self._get_request_next_num_tokens(request)
requests.append(request); num_tokens.append(next_num_tokens)
if not requests: return
return Batch(self._replica_id, requests, num_tokens)
vidur/scheduler/replica_scheduler/lightllm_replica_scheduler.py
LightLLM uses block_size=1 (token-level granularity) and a unique memory feasibility check based on cumulative token analysis. It enforces pipeline_stages=1, meaning no pipeline parallelism. It explicitly separates prefill and decode batches with a configurable max_waiting_iters threshold.
class LightLLMReplicaScheduler(BaseReplicaScheduler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self._config.block_size == 1, \
"LightLLM scheduler only supports block size of 1."
assert self._num_stages == 1, \
"LightLLM scheduler does not support pipeline parallel."
self._num_waiting_iters = 0
def _can_allocate_request(self, request: Request) -> bool:
self.cache_len_list.append(self._get_tuple_tokens(request))
self.cache_len_list.sort(key=lambda x: -x[1]) # sort by remaining tokens desc
left_out_len_array = np.array([e[1] for e in self.cache_len_list])
has_run_len_array = np.array([e[0] for e in self.cache_len_list])
cum_run_len_array = np.cumsum(has_run_len_array)
size_array = np.arange(1, len(self.cache_len_list) + 1, 1)
# Peak memory need: worst case across all possible completion orderings
need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
return need_max_token_num < self._config.num_blocks
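The feasibility check can be reproduced standalone. Each tuple is (tokens already cached, tokens still to generate); after sorting by remaining tokens descending, position i's term models the cache footprint if the i+1 longest-remaining requests are all still alive when the others have finished (a sketch with made-up numbers):

```python
import numpy as np

# (tokens already in KV cache, tokens still to generate) per request; made-up values
cache_len_list = [(100, 50), (200, 10), (30, 40)]
cache_len_list.sort(key=lambda x: -x[1])  # most-remaining first

left_out = np.array([e[1] for e in cache_len_list])  # [50, 40, 10]
has_run = np.array([e[0] for e in cache_len_list])   # [100, 30, 200]
cum_run = np.cumsum(has_run)                         # [100, 130, 330]
size = np.arange(1, len(cache_len_list) + 1)         # [1, 2, 3]

# Worst-case peak token count across completion orderings
need_max_token_num = (left_out * size + cum_run).max()  # max(150, 210, 360) = 360

# With block_size == 1, num_blocks is a token capacity
fits = need_max_token_num < 400
```

Here the binding case is all three requests running to the third one's horizon (360 tokens), which fits under a 400-token budget.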
def _get_next_batch(self) -> Batch:
# No preempted (decode) requests -> do prefill
if not self._preempted_requests:
batch = self._get_prefill_batch()
if batch: self._num_waiting_iters = 0
return batch
# If we've waited too long, allow a prefill batch
if self._num_waiting_iters >= self._config.max_waiting_iters:
self._num_waiting_iters = 0
batch = self._get_prefill_batch()
if batch: return batch
# Otherwise, run decode batch
if self._can_decode():
self._num_waiting_iters += 1
return self._get_decode_batch()
else:
raise RuntimeError("OOM handling not implemented yet")
vidur/scheduler/replica_scheduler/faster_transformer_replica_scheduler.py
FasterTransformer represents the traditional static batching approach. Once a batch is created, all its requests run together until they ALL complete. Unlike the other schedulers, it tracks preempted BATCHES (not individual requests). Memory for a batch is only freed when every request in it finishes.
class FasterTransformerReplicaScheduler(BaseReplicaScheduler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._preempted_batches = [] # NOTE: batches, not requests!
self._pending_free_map = {}
def on_batch_end(self, batch: Batch) -> None:
self._num_running_batches -= 1
if batch.all_requests_completed:
# Free memory for ALL requests at once
self.free_batch(batch)
self.free(*self._pending_free_map.pop(batch.id, []))
else:
self._preempted_batches.append(batch) # keep the whole batch
def _get_next_batch(self) -> Batch:
# Priority: continue existing batch (minus completed requests)
if self._preempted_batches:
preempted_batch = self._preempted_batches.pop(0)
return self._generate_next_batch_from_preempted(preempted_batch)
# New batch: allocate max blocks per sequence (static)
requests, num_tokens = [], []
while self._request_queue:
if len(requests) == self._max_batch_size: break
if not self.can_allocate(self._max_blocks_per_sequence): break
request = self._request_queue.pop(0)
self.allocate(request.id, self._max_blocks_per_sequence)
...
return Batch(self._replica_id, requests, num_tokens)
| Feature | vLLM | Sarathi | Orca | LightLLM | FasterTransformer |
|---|---|---|---|---|---|
| Allocation | Incremental + watermark | Incremental + watermark | Max pre-allocate | Token-level (block=1) | Max pre-allocate |
| Prefill Strategy | Full at once | Chunked | Full at once | Separate batch | Full at once |
| Mixed Prefill+Decode | Yes (new+decode) | Yes (chunked+decode) | Yes | No (separate) | Yes (within batch) |
| Eviction | LIFO victim | LIFO victim | None | RuntimeError | None |
| Pipeline Parallel | Yes | Yes | Yes | No (assert) | Yes |
| Preempt Unit | Request | Request | Request | Request | Batch |
vidur/scheduler/replica_stage_scheduler/replica_stage_scheduler.py
The ReplicaStageScheduler is the lowest tier. Each pipeline stage has one instance. It maintains a batch queue and a busy flag. When on_schedule() is called, it pops a batch, queries the ExecutionTimePredictor for timing, and creates a BatchStage object for the simulation engine.
class ReplicaStageScheduler:
def __init__(self,
replica_id: int,
stage_id: int,
is_last_stage: bool,
execution_time_predictor: BaseExecutionTimePredictor,
) -> None:
self._replica_id = replica_id
self._stage_id = stage_id
self._is_last_stage = is_last_stage
self._execution_time_predictor = execution_time_predictor
self._batch_queue = []
self._is_busy = False
def add_batch(self, batch: Batch) -> None:
self._batch_queue.append(batch)
def on_stage_end(self) -> None:
self._is_busy = False
def on_schedule(self) -> Tuple[Batch, BatchStage, ExecutionTime]:
if self._is_busy or not self._batch_queue:
return None, None, None
self._is_busy = True
batch = self._batch_queue.pop(0)
# Query the execution time predictor for this batch at this stage
execution_time = self._execution_time_predictor.get_execution_time(
batch, self._stage_id,
)
total_execution_time = execution_time.total_time
model_execution_time = execution_time.model_time
# Create a BatchStage: the unit of work for simulation
batch_stage = BatchStage(
batch.id, self._replica_id, self._stage_id,
total_execution_time, model_execution_time,
batch.requests, batch.num_tokens,
)
return batch, batch_stage, execution_time
vidur/scheduler/utils/memory_planner.py
The MemoryPlanner computes how many requests can fit in GPU memory by subtracting model parameter memory from total device memory, then dividing by KV-cache memory per request.
class MemoryPlanner:
def _get_kv_cache_memory_per_layer_per_request(self) -> int:
return (
2 # 2 bytes per float (fp16)
* 2 # one for key, one for value
* self._replica.attention_head_dim
* self._replica.kv_heads_per_tensor_parallel_worker
* self._replica.max_request_tokens
)
def get_max_batch_size(self) -> int:
available_memory = (
self._replica.total_memory_gb * 1024**3
* (1 - self._replica.memory_margin_fraction)
)
parameter_memory = self._get_parameter_memory_per_device()
kv_per_request = self._get_kv_cache_memory_per_device_per_request()
memory_for_kv_cache = available_memory - parameter_memory
number_of_requests = memory_for_kv_cache // kv_per_request
assert number_of_requests > 0, "Not enough memory for even a single request"
return number_of_requests
def get_max_request_slots(self) -> int:
return self.get_max_batch_size() * self._replica.num_pipeline_stages
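Plugging in illustrative numbers shows the shape of the computation: an 80 GB device, a 7B-parameter model in fp16, a 10% memory margin, and an assumed 2 GiB of KV cache per request at maximum context length. None of these constants come from vidur itself:

```python
# Assumed inputs (illustrative only)
total_memory_gb = 80
memory_margin_fraction = 0.1
num_parameters = 7e9               # 7B model
bytes_per_param = 2                # fp16
kv_cache_per_request = 2 * 1024**3 # assumed 2 GiB at max context

available = total_memory_gb * 1024**3 * (1 - memory_margin_fraction)
parameter_memory = num_parameters * bytes_per_param  # 14 GB of weights

memory_for_kv_cache = available - parameter_memory
max_batch_size = int(memory_for_kv_cache // kv_cache_per_request)
```

With these assumptions roughly 58 GiB remains for KV cache, enough for 29 concurrent request slots; get_max_request_slots() would then multiply that by the number of pipeline stages.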
A request moves through a well-defined state machine, driven by callbacks from the batch and batch-stage entities described below.
vidur/entities/batch_stage.py
BatchStage bridges the gap between the batch (logical unit) and the pipeline stage (physical execution). It holds the execution time prediction and propagates stage-level events to requests. It also provides Chrome trace format for visualization.
class BatchStage(BaseEntity):
def on_schedule(self, time: float) -> None:
self._scheduled_at = time
self._scheduled = True
for request in self._requests:
request.on_batch_stage_schedule(time)
def on_stage_end(self, time: float) -> None:
assert time == self._scheduled_at + self._execution_time
self._completed_at = time
for request in self._requests:
request.on_batch_stage_end(
time, self._execution_time, self._model_execution_time)
def to_chrome_trace(self, time: int) -> dict:
return {
"name": f"{self.request_ids}",
"ph": "X",
"ts": (time - self._execution_time) * 1e6,
"dur": self._execution_time * 1e6,
"pid": self._replica_id,
"tid": self._pipeline_stage,
}
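These dicts follow the Chrome trace event format ("X" is a complete event; ts and dur are in microseconds), so serializing a list of them as JSON yields a file loadable in chrome://tracing or Perfetto. A sketch with made-up event values:

```python
import json

# Two made-up BatchStage events in Chrome trace "complete event" form:
# same replica (pid=0), consecutive pipeline stages (tid 0 and 1)
events = [
    {"name": "[0, 1]", "ph": "X", "ts": 0.0,      "dur": 12_500.0, "pid": 0, "tid": 0},
    {"name": "[0, 1]", "ph": "X", "ts": 12_500.0, "dur": 12_500.0, "pid": 0, "tid": 1},
]

trace = json.dumps({"traceEvents": events})
```

Mapping replica to pid and pipeline stage to tid means the trace viewer renders one swimlane per stage, grouped by replica.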
vidur/entities/execution_time.py
The ExecutionTime entity models every component of a forward pass at the sub-operation level:
Per pipeline stage, the model time is the per-transformer-block execution time multiplied by the number of layers in that stage, plus pipeline-parallel communication; the total time adds per-iteration CPU overheads (schedule_time, sampler_e2e_time, prepare_inputs_e2e_time, process_model_outputs_time, ray_comm_time):

model_time = (block_execution_time * num_layers_per_pipeline_stage
              + pipeline_parallel_communication_time) * 1e-3  # ms -> s
total_time = model_time + cpu_overhead_time * 1e-3
vidur/entities/base_entity.py
All entities (Request, Batch, BatchStage, Replica, ExecutionTime) share the same auto-incrementing ID mechanism via class-level counters:
class BaseEntity:
_id = -1 # class variable: each subclass has its own counter
@classmethod
def generate_id(cls):
cls._id += 1
return cls._id
@property
def id(self) -> int:
return self._id
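The per-subclass behavior is worth spelling out: `cls._id += 1` is `cls._id = cls._id + 1`, and the assignment binds `_id` on the subclass the first time it runs, shadowing the base-class value. Each subclass therefore gets an independent counter even though the initial -1 lives on BaseEntity. A standalone sketch:

```python
class BaseEntity:
    _id = -1  # shared default; the first increment shadows it on each subclass

    @classmethod
    def generate_id(cls):
        cls._id += 1  # reads from cls (falling back to BaseEntity), writes to cls
        return cls._id


class Request(BaseEntity):
    pass


class Batch(BaseEntity):
    pass


ids = [Request.generate_id(), Request.generate_id(), Batch.generate_id()]
```

Request IDs advance to 0 and 1 while Batch independently starts at 0, and BaseEntity's own `_id` is never mutated.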