How vLLM V1 decides which requests run each step, manages a unified token budget for prefill and decode, allocates GPU KV cache blocks, and implements prefix caching with a doubly-linked free-block queue.
Understanding the scheduler requires knowing five core data structures that bridge the request lifecycle and GPU memory. All live in the scheduler process (CPU-side).
- **`waiting`**: A priority queue (deque-based) of `Request` objects that have arrived but have not yet started execution. The scheduler pops from this queue in FCFS or priority order. Preempted requests are prepended back to the front via `prepend_request()`.
- **`running`**: A plain Python list of requests currently occupying GPU KV cache blocks. These are iterated first on each `schedule()` call. Preempted requests are `pop()`-ed from the tail (LIFO eviction for FCFS).
- **`req_to_blocks`**: Lives inside each `SingleTypeKVCacheManager`. Maps request ID to the ordered list of KV cache blocks assigned. Block IDs are append-only -- new blocks are added, but existing entries never change position.
- **`cached_block_hash_to_block`**: A `BlockHashToBlockMap` in `BlockPool`. Maps `BlockHashWithGroupId` to one or more `KVCacheBlock` objects. This is the hash table that makes prefix caching O(1) per block lookup.
- **`FreeKVCacheBlockQueue`**: A custom doubly-linked list of free/eviction-candidate blocks. Supports O(1) `popleft()`, `append()`, and crucially `remove(block)` from the middle -- impossible with Python's built-in `deque`.
- **`KVCacheBlock`**: The per-block metadata unit. Fields: `block_id`, `ref_cnt`, `_block_hash`, `prev_free_block`, `next_free_block`, `is_null`. The linked-list pointers embed the block into the free queue without allocating wrapper objects.
```python
# From vllm/v1/core/kv_cache_utils.py - The KVCacheBlock dataclass
@dataclass(slots=True)
class KVCacheBlock:
    block_id: int
    ref_cnt: int = 0
    _block_hash: BlockHashWithGroupId | None = None
    # Doubly linked list pointers for FreeKVCacheBlockQueue
    prev_free_block: "KVCacheBlock | None" = None
    next_free_block: "KVCacheBlock | None" = None
    is_null: bool = False
```
## The `schedule()` Method -- Step by Step

The scheduler's main entry point is called once per engine step. Its philosophy is stated in a key comment at the top of the method:
```python
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and
# num_tokens_with_spec. [...] At each step, the scheduler tries
# to assign tokens to the requests so that each request's
# num_computed_tokens can catch up its num_tokens_with_spec.
# This is general enough to cover chunked prefills, prefix caching,
# speculative decoding, and the "jump decoding" optimization.
```
In this view, a "decode" is simply any request where `num_tokens_with_spec - num_computed_tokens == 1`, and a "prefill" is any request where the gap is larger. Chunked prefill simply caps how much of that gap is scheduled in a single step.
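The gap-based view can be made concrete with a toy model. This is a sketch with hypothetical names (`ToyRequest` is not a vLLM class); it only illustrates how the decode/prefill distinction falls out of the two counters:

```python
from dataclasses import dataclass

@dataclass
class ToyRequest:
    num_tokens_with_spec: int  # target token count (prompt + outputs + spec tokens)
    num_computed_tokens: int   # tokens already computed on the GPU

    @property
    def gap(self) -> int:
        # The scheduler's only job: shrink this gap each step.
        return self.num_tokens_with_spec - self.num_computed_tokens

# gap == 1 behaves like a decode; gap > 1 behaves like a (possibly chunked) prefill
decode = ToyRequest(num_tokens_with_spec=101, num_computed_tokens=100)
prefill = ToyRequest(num_tokens_with_spec=2048, num_computed_tokens=0)

assert decode.gap == 1      # one new token: "decode"
assert prefill.gap == 2048  # many new tokens: "prefill"
```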
The method proceeds in five steps:

1. **Initialization.** Set `token_budget = self.max_num_scheduled_tokens`. Create empty collections for scheduled new/running/resumed requests, new block mappings, and encoder inputs. Call `kv_cache_manager.new_step_starts()`.
2. **Running phase.** Iterate `self.running`. For each request, calculate `num_new_tokens = request.num_tokens_with_spec + request.num_output_placeholders - request.num_computed_tokens`. Apply the chunked prefill cap, then `min(num_new_tokens, token_budget)`. Call `allocate_slots()`. If allocation fails, preempt from the tail.
3. **Waiting phase.** Runs only if no preemptions occurred. Pop from the waiting queue, look up prefix cache hits, compute `num_new_tokens = request.num_tokens - num_computed_tokens`, apply the chunked prefill threshold, and call `allocate_slots()`. Move successful requests to `self.running`.
4. **Output construction.** Pack `NewRequestData` (for first-time requests) and `CachedRequestData` (for continuing requests) with block IDs and computed token counts. Compute common prefix blocks for cascade attention.
5. **Post-schedule update.** `_update_after_schedule()` advances each request's `num_computed_tokens` by the scheduled amount so the next call to `schedule()` sees updated state. Clear `finished_req_ids`.
There is a single `token_budget` variable shared between prefill and decode: a decode draws one token from the pool, while a long prefill chunk can draw thousands. This is why vLLM V1 performs "continuous batching" naturally -- there are no separate phases to alternate between.
```python
# scheduler.py line 360-363
token_budget = self.max_num_scheduled_tokens
if self._pause_state == PauseState.PAUSED_ALL:
    token_budget = 0  # Do not schedule when paused
```
The budget is decremented identically for running and waiting requests:
```python
# For RUNNING requests (line 509-510):
num_scheduled_tokens[request_id] = num_new_tokens
token_budget -= num_new_tokens

# For WAITING requests (line 818-819):
num_scheduled_tokens[request_id] = num_new_tokens
token_budget -= num_new_tokens
```
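The shared-budget behavior can be sketched as a toy loop (hypothetical function, not vLLM source): each request takes as much of its gap as the remaining budget allows, in order.

```python
def schedule_step(gaps: list[int], token_budget: int) -> list[int]:
    """Assign tokens from one shared budget to each request's gap (toy model)."""
    scheduled = []
    for gap in gaps:
        n = min(gap, token_budget)  # same min() as the real scheduler
        scheduled.append(n)
        token_budget -= n
    return scheduled

# One decode (gap 1), one long prefill (gap 5000), one decode share the budget.
print(schedule_step([1, 5000, 1], token_budget=8192))  # -> [1, 5000, 1]
# When the budget is tight, later requests get fewer (or zero) tokens this step.
print(schedule_step([1, 5000, 1], token_budget=4096))  # -> [1, 4095, 0]
```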
The budget is set from `max_num_scheduled_tokens` (usually equal to `max_num_batched_tokens`).
| Config Parameter | Default | Effect |
|---|---|---|
| `max_num_scheduled_tokens` | `max_num_batched_tokens` | Total token budget per step (shared across all requests) |
| `max_num_seqs` | 256 | Maximum requests in the running list (`max_num_running_reqs`) |
| `long_prefill_token_threshold` | 0 (disabled) | Cap on single-request prefill chunk size (enables chunked prefill) |
The running queue is iterated first. This ensures requests that already hold KV cache blocks get to continue, minimizing wasted GPU memory from preempted partial computations.
```python
# scheduler.py line 377-378
req_index = 0
while req_index < len(self.running) and token_budget > 0:
    request = self.running[req_index]
    num_new_tokens = (
        request.num_tokens_with_spec
        + request.num_output_placeholders
        - request.num_computed_tokens
    )
    # Apply chunked prefill threshold
    if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
        num_new_tokens = self.scheduler_config.long_prefill_token_threshold
    num_new_tokens = min(num_new_tokens, token_budget)
```
Waiting requests are only considered if no preemptions occurred in the running phase. This prevents a starvation cycle where new requests cause preemptions that cause more new requests.
```python
# scheduler.py line 557
if not preempted_reqs and self._pause_state == PauseState.UNPAUSED:
    # ... schedule waiting requests
    while (self.waiting or self.skipped_waiting) and token_budget > 0:
        if len(self.running) == self.max_num_running_reqs:
            break
```
Waiting requests get a prefix cache lookup (get_computed_blocks()),
which may dramatically reduce the number of tokens that need computing:
```python
# scheduler.py line 603-606
if request.num_computed_tokens == 0:
    new_computed_blocks, num_new_local_computed_tokens = (
        self.kv_cache_manager.get_computed_blocks(request)
    )
```
Requests that cannot currently be scheduled are deferred to `self.skipped_waiting` and re-checked each step. This avoids head-of-line blocking in the main waiting queue.
Chunked prefill is controlled by long_prefill_token_threshold. When a request's remaining tokens exceed this threshold, only the threshold amount is scheduled. The request stays in the running queue with updated num_computed_tokens and will get more tokens next step.
```python
# scheduler.py line 402-404
if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
    num_new_tokens = self.scheduler_config.long_prefill_token_threshold
num_new_tokens = min(num_new_tokens, token_budget)
```
```python
# scheduler.py line 658-661
num_new_tokens = request.num_tokens - num_computed_tokens
threshold = self.scheduler_config.long_prefill_token_threshold
if 0 < threshold < num_new_tokens:
    num_new_tokens = threshold
```
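A toy simulation (hypothetical function, not vLLM source) shows how this cap spreads a long prompt across steps -- the request stays resident and receives a chunk per step until the gap closes:

```python
def run_chunked_prefill(total_tokens: int, threshold: int, budget: int) -> list[int]:
    """Toy model: how many tokens a long prefill gets on each step."""
    computed, per_step = 0, []
    while computed < total_tokens:
        num_new = total_tokens - computed
        if 0 < threshold < num_new:      # same cap as long_prefill_token_threshold
            num_new = threshold
        num_new = min(num_new, budget)   # shared budget still applies
        per_step.append(num_new)
        computed += num_new
    return per_step

# A 100-token prompt with threshold 32 and per-step budget 48: four steps.
print(run_chunked_prefill(100, threshold=32, budget=48))  # -> [32, 32, 32, 4]
```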
Without chunked prefill enabled, the scheduler enforces that the full request fits in the budget:
```python
# scheduler.py line 666-671
if (
    not self.scheduler_config.enable_chunked_prefill
    and num_new_tokens > token_budget
):
    # If chunked_prefill is disabled,
    # we can stop the scheduling here.
    break
```
When `scheduler_reserve_full_isl` is enabled, waiting requests must pass a `can_fit_full_sequence()` admission gate that checks whether the KV cache has enough blocks for the entire sequence (not just the first chunk). This prevents over-admission that would lead to preemptions later.
Every block of tokens has a unique hash computed as a chained hash: the hash of block N depends on the hash of block N-1, forming a Merkle-style chain. The block size for hashing defaults to 16 tokens (configurable via `block_size`).
```python
# kv_cache_utils.py -- hash_block_tokens()
def hash_block_tokens(
    hash_function: Callable[[Any], bytes],
    parent_block_hash: BlockHash | None,
    curr_block_token_ids: Sequence[int],
    extra_keys: tuple[Any, ...] | None = None,
) -> BlockHash:
    if not parent_block_hash:
        parent_block_hash = NONE_HASH  # Random seed or from PYTHONHASHSEED
    curr_block_token_ids_tuple = tuple(curr_block_token_ids)
    return BlockHash(
        hash_function((parent_block_hash, curr_block_token_ids_tuple, extra_keys))
    )
```
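The chaining property is easy to verify with a self-contained sketch. This is a simplified stand-in (sha256 instead of vLLM's configurable hash, no `extra_keys`): equal prefixes produce equal hashes, and a difference in one block changes every later hash.

```python
import hashlib

def chained_block_hashes(token_ids: list[int], block_size: int = 16) -> list[bytes]:
    """Sketch of Merkle-style chained block hashing (simplified, full blocks only)."""
    hashes, parent = [], b"none-hash-seed"  # stand-in for NONE_HASH
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        block = tuple(token_ids[start : start + block_size])
        # Each block's hash covers (parent hash, block tokens).
        parent = hashlib.sha256(repr((parent, block)).encode()).digest()
        hashes.append(parent)
    return hashes

a = chained_block_hashes(list(range(48)))
b = chained_block_hashes(list(range(48)))
c = chained_block_hashes([999] + list(range(1, 48)))  # differs only in token 0
assert a == b        # identical prefixes -> identical hashes
assert a[0] != c[0]  # first block differs...
assert a[1] != c[1]  # ...so later hashes differ too, even with identical tokens
```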
The hash chain is built incrementally. Each Request maintains a
block_hashes list that grows as tokens are appended. The block hasher
function is created per-model:
```python
# kv_cache_utils.py -- get_request_block_hasher()
def request_block_hasher(request: Request) -> list[BlockHash]:
    start_token_idx = len(request.block_hashes) * block_size
    # ...
    while True:
        end_token_idx = start_token_idx + block_size
        if end_token_idx > num_tokens:
            break  # Only hash FULL blocks
        extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
            request, start_token_idx, end_token_idx, curr_mm_idx
        )
        block_tokens = request.all_token_ids[start_token_idx:end_token_idx]
        block_hash = hash_block_tokens(
            caching_hash_fn, prev_block_hash_value, block_tokens, extra_keys
        )
        new_block_hashes.append(block_hash)
        start_token_idx += block_size
        prev_block_hash_value = block_hash
    return new_block_hashes
```
Multimodal inputs contribute `(mm_hash, start_offset)` to the block hash. LoRA requests include the adapter name. The cache salt is included for the first block. This ensures two requests with the same token IDs but different LoRA adapters or images will have different block hashes.
For full attention, the lookup is a simple left-to-right scan. Each block hash is checked in the
cached_block_hash_to_block map. Because block hashes are chained,
a miss at position N guarantees misses at N+1, N+2, ...
```python
# single_type_kv_cache_manager.py -- FullAttentionManager.find_longest_cache_hit()
max_num_blocks = max_length // block_size
for block_hash in itertools.islice(block_hashes, max_num_blocks):
    if cached_block := block_pool.get_cached_block(
        block_hash, kv_cache_group_ids
    ):
        for computed, cached in zip(computed_blocks, cached_block):
            computed.append(cached)
    else:
        break  # Chain broken -- no further hits possible
if use_eagle and computed_blocks[0]:
    # Drop last matched block to force recompute for EAGLE hidden states
    for computed in computed_blocks:
        computed.pop()
```
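Stripped of the multi-group bookkeeping, the scan reduces to a few lines. A toy version (hypothetical function, cache modeled as a set) makes the early-exit guarantee explicit:

```python
def longest_cache_hit(block_hashes: list[bytes], cached: set[bytes]) -> int:
    """Left-to-right scan that stops at the first miss (toy version)."""
    hits = 0
    for h in block_hashes:
        if h not in cached:
            break  # hashes are chained: later blocks cannot hit either
        hits += 1
    return hits

cached = {b"h0", b"h1", b"h2"}
assert longest_cache_hit([b"h0", b"h1", b"h2", b"h3"], cached) == 3
assert longest_cache_hit([b"hX", b"h1"], cached) == 0  # a miss at 0 ends the scan
```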
When a cached block is "hit", its ref_cnt is incremented. If the
block was an eviction candidate (ref_cnt == 0, sitting in the free queue),
it is removed from the free queue via O(1) linked-list surgery:
```python
# block_pool.py -- touch()
def touch(self, blocks: Sequence[KVCacheBlock]) -> None:
    for block in blocks:
        # ref_cnt=0 means this block is in the free list (eviction candidate)
        if block.ref_cnt == 0 and not block.is_null:
            self.free_block_queue.remove(block)  # O(1) doubly-linked list removal
        block.ref_cnt += 1
```
This is a critical performance structure. Python's built-in deque does not support O(1) removal from the middle. The custom implementation uses the prev_free_block and next_free_block pointers directly on KVCacheBlock objects -- no wrapper nodes, no extra memory allocation.
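A minimal sketch of the same technique (hypothetical class names, sentinel nodes instead of vLLM's exact layout) shows why middle removal is O(1): each block carries its own prev/next pointers, so unlinking never requires a search.

```python
class Block:
    """Minimal block with embedded linked-list pointers (sketch, not vLLM code)."""
    def __init__(self, block_id: int):
        self.block_id = block_id
        self.prev_free = None
        self.next_free = None

class FreeQueue:
    """Doubly-linked free list: O(1) popleft, append, and middle removal."""
    def __init__(self, blocks):
        self.head = Block(-1)  # fake head sentinel
        self.tail = Block(-2)  # fake tail sentinel
        self.head.next_free = self.tail
        self.tail.prev_free = self.head
        for b in blocks:
            self.append(b)

    def append(self, b):
        last = self.tail.prev_free
        last.next_free, b.prev_free = b, last
        b.next_free, self.tail.prev_free = self.tail, b

    def remove(self, b):
        # O(1): unlink via the block's own pointers -- no traversal needed.
        b.prev_free.next_free = b.next_free
        b.next_free.prev_free = b.prev_free
        b.prev_free = b.next_free = None

    def popleft(self):
        b = self.head.next_free
        self.remove(b)
        return b

blocks = [Block(i) for i in range(4)]
q = FreeQueue(blocks)
q.remove(blocks[2])               # cache hit on block 2: pull it out mid-queue
assert q.popleft().block_id == 0  # eviction order skips the removed block
assert q.popleft().block_id == 1
assert q.popleft().block_id == 3
```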
```python
# block_pool.py -- get_new_blocks()
def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
    ret: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks)
    if self.enable_caching:
        for block in ret:
            self._maybe_evict_cached_block(block)  # Remove hash if block was cached
            assert block.ref_cnt == 0
            block.ref_cnt += 1
    else:
        for block in ret:
            assert block.ref_cnt == 0
            block.ref_cnt += 1
    return ret
```
When a request finishes or is preempted, blocks are freed in reverse order. This ensures tail blocks (least reusable for prefix caching) are evicted first:
```python
# single_type_kv_cache_manager.py -- free()
def free(self, request_id: str) -> None:
    req_blocks = self.req_to_blocks.pop(request_id, [])
    ordered_blocks = reversed(req_blocks)  # Tail blocks freed first!
    self.block_pool.free_blocks(ordered_blocks)
    self.num_cached_block.pop(request_id, None)
```
```python
# block_pool.py -- free_blocks()
def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
    blocks_list = list(ordered_blocks)
    for block in blocks_list:
        block.ref_cnt -= 1
    # Only return to free queue if ref_cnt dropped to 0
    self.free_block_queue.append_n(
        [block for block in blocks_list if block.ref_cnt == 0 and not block.is_null]
    )
```
When a block is shared by two requests, its `ref_cnt` is 2. Freeing one request decrements it to 1 -- the block stays allocated for the other request and is NOT added to the free queue. Only when `ref_cnt` reaches 0 does the block become an eviction candidate.
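The reference-counting rule can be demonstrated with a toy sketch (hypothetical names; the real code batches this inside `touch()` and `free_blocks()`):

```python
class RefCountedBlock:
    """Toy shared block: only the last free makes it an eviction candidate."""
    def __init__(self):
        self.ref_cnt = 0

free_queue: list[RefCountedBlock] = []

def touch(block: RefCountedBlock) -> None:
    block.ref_cnt += 1          # a request starts using the block

def free(block: RefCountedBlock) -> None:
    block.ref_cnt -= 1          # a request releases the block
    if block.ref_cnt == 0:
        free_queue.append(block)  # now evictable (hash kept for future reuse)

shared = RefCountedBlock()
touch(shared); touch(shared)    # two requests share one prefix block
free(shared)
assert shared.ref_cnt == 1 and not free_queue   # still held by the other request
free(shared)
assert shared.ref_cnt == 0 and free_queue == [shared]  # now an eviction candidate
```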
```python
# kv_cache_utils.py -- FreeKVCacheBlockQueue.popleft_n()
def popleft_n(self, n: int) -> list[KVCacheBlock]:
    if n == 0:
        return []
    assert self.num_free_blocks >= n
    self.num_free_blocks -= n
    curr_block = self.fake_free_list_head.next_free_block
    ret = []
    for _ in range(n):
        ret.append(curr_block)
        last_block = curr_block
        curr_block = curr_block.next_free_block
        # Disconnect popped block from the list
        last_block.prev_free_block = None
        last_block.next_free_block = None
    if curr_block is not None:
        self.fake_free_list_head.next_free_block = curr_block
        curr_block.prev_free_block = self.fake_free_list_head
    return ret
```
vLLM V1 uses a recomputation-only preemption policy. There is no swapping of KV cache to CPU memory. When a running request cannot fit its new tokens and no free blocks remain, the scheduler preempts the lowest-priority request by freeing ALL its blocks and putting it back in the waiting queue with num_computed_tokens = 0.
```python
# scheduler.py -- _preempt_request()
def _preempt_request(self, request: Request, timestamp: float) -> None:
    assert request.status == RequestStatus.RUNNING
    self.kv_cache_manager.free(request)       # Free ALL KV blocks
    self.encoder_cache_manager.free(request)  # Free encoder cache too
    request.status = RequestStatus.PREEMPTED
    request.num_computed_tokens = 0  # Must recompute from scratch!
    if request.spec_token_ids:
        request.spec_token_ids = []
    request.num_preemptions += 1
    # Put back at the front of the waiting queue
    self.waiting.prepend_request(request)
```
When allocate_slots() returns None for a
running request, the scheduler enters a preemption loop:
```python
# scheduler.py lines 455-499
while True:
    new_blocks = self.kv_cache_manager.allocate_slots(
        request,
        num_new_tokens,
        num_lookahead_tokens=self.num_lookahead_tokens,
    )
    if new_blocks is not None:
        break  # Success!
    # OOM: preempt the lowest-priority request
    if self.policy == SchedulingPolicy.PRIORITY:
        preempted_req = max(
            self.running,
            key=lambda r: (r.priority, r.arrival_time),
        )
        self.running.remove(preempted_req)
    else:
        # FCFS: pop from the TAIL (LIFO -- last admitted is first evicted)
        preempted_req = self.running.pop()
    self._preempt_request(preempted_req, scheduled_timestamp)
    preempted_reqs.append(preempted_req)
    if preempted_req == request:
        # We preempted ourselves -- give up
        break
```
| Aspect | vLLM V0 | vLLM V1 |
|---|---|---|
| Preemption strategy | Swap to CPU or recompute | Recompute only |
| KV blocks after preempt | May survive in CPU swap | All freed (ref_cnt decremented) |
| Recovery path | Swap back or recompute | Prefix cache hit + recompute remainder |
| num_computed_tokens | Preserved if swapped | Reset to 0 |
The scheduler produces a SchedulerOutput dataclass each step, which is sent to worker processes. To minimize IPC cost, it distinguishes between new requests (full data) and cached requests (only the diff).
**`NewRequestData`** (first-time requests): contains the full `prompt_token_ids`, `sampling_params`, `block_ids`, `mm_features`, `lora_request`, and `num_computed_tokens`. Workers cache this data so it does not need to be resent.
**`CachedRequestData`** (continuing requests): contains only `new_block_ids` (appended to existing blocks), the updated `num_computed_tokens`, and `resumed_req_ids` (the set of requests that were preempted and are resuming -- these need a full `block_ids` replacement rather than an append).
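On the worker side, applying this diff is just "append for continuing, replace for resumed". A hypothetical sketch of that consumer (names invented, not vLLM worker code):

```python
# Hypothetical worker-side block cache illustrating the diff protocol.
worker_blocks: dict[str, list[int]] = {"req-A": [3, 7]}

def apply_cached_request(req_id: str, new_block_ids: list[int], resumed: bool) -> None:
    if resumed:
        worker_blocks[req_id] = list(new_block_ids)  # preempted+resumed: replace all
    else:
        worker_blocks[req_id].extend(new_block_ids)  # continuing: append-only diff

apply_cached_request("req-A", [9], resumed=False)
assert worker_blocks["req-A"] == [3, 7, 9]   # continuing request: append
apply_cached_request("req-A", [1, 2], resumed=True)
assert worker_blocks["req-A"] == [1, 2]      # resumed after preemption: replace
```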
```python
# output.py -- SchedulerOutput
@dataclass
class SchedulerOutput:
    scheduled_new_reqs: list[NewRequestData]
    scheduled_cached_reqs: CachedRequestData
    num_scheduled_tokens: dict[str, int]  # req_id -> num tokens this step
    total_num_scheduled_tokens: int       # sum of above
    scheduled_spec_decode_tokens: dict[str, list[int]]
    scheduled_encoder_inputs: dict[str, list[int]]
    num_common_prefix_blocks: list[int]   # for cascade attention
    finished_req_ids: set[str]
    free_encoder_mm_hashes: list[str]
    preempted_req_ids: set[str] | None = None
    new_block_ids_to_zero: list[int] | None = None
    kv_connector_metadata: KVConnectorMetadata | None = None
    # ... additional fields for structured output, spec decode, etc.
```
```python
# scheduler.py lines 907-923 -- Final assembly
scheduler_output = SchedulerOutput(
    scheduled_new_reqs=new_reqs_data,
    scheduled_cached_reqs=cached_reqs_data,
    num_scheduled_tokens=num_scheduled_tokens,
    total_num_scheduled_tokens=total_num_scheduled_tokens,
    scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
    scheduled_encoder_inputs=scheduled_encoder_inputs,
    num_common_prefix_blocks=num_common_prefix_blocks,
    preempted_req_ids={req.request_id for req in preempted_reqs},
    finished_req_ids=self.finished_req_ids,
    free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
    new_block_ids_to_zero=new_block_ids_to_zero,
)
```
The _make_cached_request_data() method efficiently packs only the diff:
```python
# scheduler.py -- _make_cached_request_data()
for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
    req_id = req.request_id
    req_ids.append(req_id)
    scheduled_in_prev_step = req_id in self.prev_step_scheduled_req_ids
    if idx >= num_running_reqs:
        resumed_req_ids.add(req_id)  # This is a resumed (preempted) request
        if not scheduled_in_prev_step:
            all_token_ids[req_id] = req.all_token_ids.copy()  # Full token IDs for new entries
    new_block_ids.append(
        req_to_new_blocks[req_id].get_block_ids(allow_none=True)
    )
    num_computed_tokens.append(req.num_computed_tokens)
    num_output_tokens.append(req.num_output_tokens + req.num_output_placeholders)
```
For requests that were scheduled in the previous step, the `all_token_ids` field is omitted: the worker already has those tokens cached. Only `new_block_ids` and the updated `num_computed_tokens` are sent.
The KVCacheCoordinator sits between the scheduler-facing
KVCacheManager and the per-layer-type
SingleTypeKVCacheManager instances. Three concrete
coordinators exist, chosen by the model's attention configuration:
| Coordinator | When Used | Prefix Caching |
|---|---|---|
| `KVCacheCoordinatorNoPrefixCache` | Prefix caching disabled or unsupported specs | No |
| `UnitaryKVCacheCoordinator` | Single KV cache group (e.g., all full attention) | Yes |
| `HybridKVCacheCoordinator` | Multiple cache groups (e.g., full + sliding window) | Yes, with fixed-point algorithm |
```python
# kv_cache_coordinator.py -- Factory function
def get_kv_cache_coordinator(...) -> KVCacheCoordinator:
    if not enable_caching:
        return KVCacheCoordinatorNoPrefixCache(...)
    if len(kv_cache_config.kv_cache_groups) == 1:
        return UnitaryKVCacheCoordinator(...)
    return HybridKVCacheCoordinator(...)
```
For models with mixed attention types (e.g., full attention + sliding window), each attention type may support a different cache hit length. The hybrid coordinator uses an iterative algorithm where each type either accepts or reduces the candidate length:
```python
# kv_cache_coordinator.py -- HybridKVCacheCoordinator.find_longest_cache_hit()
while True:
    curr_hit_length = hit_length
    for spec, group_ids, manager_cls in self.attention_groups:
        hit_blocks = manager_cls.find_longest_cache_hit(
            block_hashes=_get_block_hashes(spec),
            max_length=curr_hit_length,
            kv_cache_group_ids=group_ids,
            block_pool=self.block_pool,
            kv_cache_spec=spec,
            use_eagle=self.use_eagle,
            alignment_tokens=self.lcm_block_size,
        )
        curr_hit_length = len(hit_blocks[0]) * spec.block_size
    if curr_hit_length >= hit_length:
        break  # Converged -- no type reduced the length
    hit_length = curr_hit_length
```
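The fixed-point structure is easier to see in a toy model (hypothetical functions; the "reducers" stand in for per-attention-type managers that can only accept or shorten a candidate length):

```python
from typing import Callable

def fixed_point_hit_length(candidate: int, reducers: list[Callable[[int], int]]) -> int:
    """Iterate until no attention type shortens the candidate hit length (toy)."""
    hit = candidate
    while True:
        curr = hit
        for reduce in reducers:
            curr = reduce(curr)     # each type may accept or shorten
        if curr >= hit:
            return curr             # converged: nobody shortened it
        hit = curr                  # retry with the shorter candidate

full_attn = lambda n: n                  # full attention accepts any length
sliding  = lambda n: n if n >= 64 else 0 # toy sliding window: all-or-nothing
align_32 = lambda n: (n // 32) * 32      # toy block-size alignment

assert fixed_point_hit_length(100, [full_attn, align_32, sliding]) == 96
assert fixed_point_hit_length(50, [full_attn, align_32, sliding]) == 0
```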
This is the most complex method in the KV cache system. It handles prefix cache hits, sliding window cleanup, external (P/D disaggregation) tokens, and speculative decode lookahead -- all in one call. Here is the block layout it operates on:
```python
# kv_cache_manager.py -- Block layout diagram from the source:
#
# -----------------------------------------------------------------
# | <comp> | <new_comp> | <ext_comp> | <new> | <lookahead> |
# -----------------------------------------------------------------
# |                  <to be computed>                      |
# -----------------------------------------------------------------
# |                  <to be allocated>                     |
# -----------------------------------------------------------------
#
# comp      = request.num_computed_tokens
# new_comp  = num_new_computed_tokens (prefix cache hits this step)
# ext_comp  = num_external_computed_tokens (from P/D connector)
# new       = num_new_tokens (tokens to compute this step)
# lookahead = num_lookahead_tokens (for spec decode)
```
The allocation has three phases:
**Phase 1 -- Sliding-window cleanup.** Remove blocks outside the sliding window. This frees memory before allocation begins:

```python
self.coordinator.remove_skipped_blocks(
    request.request_id, total_computed_tokens
)
```
**Phase 2 -- Admission check.** Calculate how many new blocks are needed and compare against the free pool:

```python
num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(...)
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
    return None  # OOM -- caller must preempt
```
**Phase 3 -- Allocation and caching.** Add computed blocks (touching them to prevent eviction), then allocate fresh blocks for new tokens. Finally, cache completed blocks by setting their hashes:

```python
self.coordinator.allocate_new_computed_blocks(...)           # Touch + extend
new_blocks = self.coordinator.allocate_new_blocks(...)       # Fresh blocks from pool
self.coordinator.cache_blocks(request, num_tokens_to_cache)  # Set hashes
```
Block ID 0 is reserved as a "null block" -- a sentinel used to represent positions
where blocks have been freed (e.g., outside a sliding window). Its
ref_cnt is not tracked:
```python
# block_pool.py -- BlockPool.__init__()
self.null_block = self.free_block_queue.popleft()  # Takes block 0
self.null_block.is_null = True
```
When get_new_blocks() pops a block from the free queue, that
block might be a cached eviction candidate. The method calls
_maybe_evict_cached_block() to clean up its hash from the lookup table:
```python
# block_pool.py -- _maybe_evict_cached_block()
def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
    block_hash = block.block_hash
    if block_hash is None:
        return False  # Not cached, nothing to evict
    if self.cached_block_hash_to_block.pop(block_hash, block.block_id) is None:
        return False  # Already removed from cache
    block.reset_hash()  # Clear the block's hash metadata
    return True
```
Used in RLHF flows after weight updates to invalidate all cached blocks. Only succeeds if no requests are actively using blocks:
```python
# block_pool.py -- reset_prefix_cache()
def reset_prefix_cache(self) -> bool:
    num_used_blocks = self.num_gpu_blocks - self.get_num_free_blocks()
    if num_used_blocks != 1:  # Only the null block should be "used"
        return False
    self.cached_block_hash_to_block = BlockHashToBlockMap()
    for block in self.blocks:
        block.reset_hash()
    return True
```
| File | Key Classes / Functions | Responsibility |
|---|---|---|
| `sched/scheduler.py` | `Scheduler.schedule()` | Main scheduling loop: running phase, waiting phase, preemption, output construction |
| `sched/output.py` | `SchedulerOutput`, `NewRequestData`, `CachedRequestData` | IPC-optimized scheduler-to-worker communication protocol |
| `kv_cache_manager.py` | `KVCacheManager` | Scheduler-facing facade for block allocation, prefix cache, and freeing |
| `kv_cache_coordinator.py` | `UnitaryKVCacheCoordinator`, `HybridKVCacheCoordinator` | Coordinate multi-group cache hits, dispatch to per-type managers |
| `single_type_kv_cache_manager.py` | `FullAttentionManager`, `SlidingWindowManager` | Per-attention-type block management, `find_longest_cache_hit()`, `req_to_blocks` |
| `block_pool.py` | `BlockPool`, `BlockHashToBlockMap` | GPU block pool with free queue, hash-to-block cache, eviction logic |
| `kv_cache_utils.py` | `KVCacheBlock`, `FreeKVCacheBlockQueue`, `hash_block_tokens()` | Block metadata, O(1) doubly-linked free list, chained hash computation |
- **One budget, one loop:** Prefill and decode share one token budget. A single scheduling loop handles all request types uniformly.
- **Append-only block IDs:** Block IDs are never reordered within a request; new blocks are appended. This enables efficient GPU-side attention.
- **Recompute-only preemption:** V1 dropped CPU swap. Prefix caching makes recomputation efficient -- most prefix blocks are recovered from cache.
- **Intrusive free list:** The free queue embeds linked-list pointers in `KVCacheBlock` itself. No wrapper objects means no GC overhead.