A walk through vLLM V1 (v0.19.x) source: the schedule() method line by line, unified token budget, prefix caching with Merkle-chained hashes, the O(1) doubly-linked free-block queue, preemption via recomputation, async-scheduling skip-ahead, and speculative-decode rejection bookkeeping.
Understanding the scheduler requires knowing five core data structures that bridge the request lifecycle and GPU memory, plus the per-block metadata unit they all point at. All of them live in the scheduler process (CPU-side).

waiting (self.waiting): a request queue of Request objects that have arrived but have not yet started execution (deque-based under FCFS, heap-based under PRIORITY). The scheduler pops from this queue in FCFS or priority order. Preempted requests are returned via prepend_request(), which puts them back at the front under FCFS.

running (self.running): a plain Python list of requests currently occupying GPU KV cache blocks. It is iterated first on each schedule() call. Under FCFS, preempted requests are pop()-ed from the tail (LIFO eviction).

req_to_blocks: lives inside each SingleTypeKVCacheManager and maps a request ID to the ordered list of KV cache blocks assigned to it. The list is append-only: new blocks are added at the end, and existing entries never change position.

cached_block_hash_to_block: a BlockHashToBlockMap in BlockPool that maps a BlockHashWithGroupId to one or more KVCacheBlock objects. This hash table is what makes prefix caching O(1) per block lookup.

free_block_queue (FreeKVCacheBlockQueue): a custom doubly linked list of free / eviction-candidate blocks. It supports O(1) popleft(), append(), and, crucially, remove(block) from the middle. Python's built-in deque can only remove an arbitrary element with an O(n) scan.

KVCacheBlock: the per-block metadata unit, with fields block_id, ref_cnt, _block_hash, prev_free_block, next_free_block, and is_null. The linked-list pointers embed the block directly into the free queue without allocating wrapper objects.
```python
# From vllm/v1/core/kv_cache_utils.py - The KVCacheBlock dataclass
from dataclasses import dataclass


@dataclass(slots=True)
class KVCacheBlock:
    block_id: int
    ref_cnt: int = 0
    _block_hash: BlockHashWithGroupId | None = None
    # Doubly linked list pointers for FreeKVCacheBlockQueue
    prev_free_block: "KVCacheBlock | None" = None
    next_free_block: "KVCacheBlock | None" = None
    is_null: bool = False
```
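To make the "one or more KVCacheBlock objects per hash" shape of the BlockHashToBlockMap described above concrete, here is a minimal sketch of such an index. The classes HashToBlocks and Block, their methods, and the use of plain strings as hashes are hypothetical simplifications, not vLLM's API. The point is only that a lookup is a single dict access, and that more than one physical block can carry the same hash (for example, if identical content was computed by two requests before either block was cached).

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Block:
    block_id: int
    ref_cnt: int = 0


class HashToBlocks:
    """Hypothetical stand-in for a hash -> {block_id: block} index."""

    def __init__(self) -> None:
        self._map: dict[str, dict[int, Block]] = {}

    def insert(self, block_hash: str, block: Block) -> None:
        # Several physical blocks may share one hash; key them by block_id.
        self._map.setdefault(block_hash, {})[block.block_id] = block

    def get_one(self, block_hash: str) -> Block | None:
        # O(1) dict lookup; any block carrying this hash is a valid hit.
        blocks = self._map.get(block_hash)
        if not blocks:
            return None
        return next(iter(blocks.values()))

    def pop(self, block_hash: str, block_id: int) -> Block | None:
        # Remove one specific block, e.g. when it is evicted and reused.
        blocks = self._map.get(block_hash)
        if blocks is None:
            return None
        block = blocks.pop(block_id, None)
        if not blocks:
            del self._map[block_hash]
        return block


index = HashToBlocks()
index.insert("h0", Block(3))
index.insert("h0", Block(7))  # same content, different physical block
print(index.get_one("h0").block_id)  # 3 (first inserted)
index.pop("h0", 3)
print(index.get_one("h0").block_id)  # 7
print(index.get_one("h1"))  # None
```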
Request Object -- What the Scheduler Sees

Everything the scheduler reasons about lives on a single Request object (vllm/v1/request.py line 58). It is not a frozen dataclass but a plain class that carries three kinds of data: identity and inputs (what the user sent), progress counters (how far we've gotten), and status / lifecycle flags (which queue it belongs in).
```python
# vllm/v1/request.py lines 77-177 -- Request.__init__ assignments (condensed)
class Request:
    # ---- 1. IDENTITY & METADATA (set once at construction) ----
    self.request_id: str
    self.client_index: int  # which engine client sent this
    self.priority: int = 0  # lower = higher priority
    self.arrival_time: float  # tie-breaker for priority
    self.lora_request: LoRARequest | None
    self.cache_salt: str | None  # separates prefix-cache namespaces
    self.trace_headers: Mapping[str, str] | None

    # ---- 2. INPUTS / SAMPLING CONFIG ----
    self.prompt_token_ids: list[int] | None
    self.prompt_embeds: torch.Tensor | None
    self.num_prompt_tokens: int
    self.mm_features: list[MultiModalFeatureSpec]  # vision/audio
    self.sampling_params: SamplingParams | None
    self.pooling_params: PoolingParams | None
    self.max_tokens: int  # hard cap on output length
    self.structured_output_request: StructuredOutputRequest | None

    # ---- 3. TOKEN STREAMS (mutate as tokens are sampled) ----
    self._output_token_ids: list[int]  # verified outputs only
    self._all_token_ids: list[int]  # prompt + output, concatenated
    self.spec_token_ids: list[int]  # unverified speculative drafts

    # ---- 4. PROGRESS COUNTERS (the scheduler's core state) ----
    self.num_computed_tokens: int = 0  # KV written for positions [0, N)
    self.num_output_placeholders: int = 0  # async-scheduling IOU
    self.num_external_computed_tokens: int = 0  # from KVConnector (P/D)
    self.num_cached_tokens: int = -1  # prefix-cache hit on first schedule
    self.num_preemptions: int = 0  # times kicked out of running

    # ---- 5. STATUS / LIFECYCLE ----
    self.status: RequestStatus = RequestStatus.WAITING
    self.events: list[EngineCoreEvent]  # QUEUED / SCHEDULED / PREEMPTED ...
    self.stop_reason: int | str | None
    self.is_prefill_chunk: bool = False
    self.num_nans_in_logits: int = 0

    # ---- 6. KV CACHE INDEXING ----
    self.block_hashes: list[BlockHash]  # chained hashes per full block
    self._block_hasher: Callable  # called on token append
    self.skip_reading_prefix_cache: bool

    # ---- 7. CONNECTOR / STREAMING EXTRAS ----
    self.kv_transfer_params: dict | None  # P/D disaggregation
    self.resumable: bool  # supports streaming input
    self.streaming_queue: deque[StreamingUpdate | None] | None
    self.discard_latest_async_tokens: bool
```
Several "fields" are actually properties derived on demand (vllm/v1/request.py lines 226-242):
```python
@property
def num_tokens(self) -> int:
    return len(self._all_token_ids)  # prompt + verified output

@property
def num_tokens_with_spec(self) -> int:
    return len(self._all_token_ids) + len(self.spec_token_ids)

@property
def num_output_tokens(self) -> int:
    return len(self._output_token_ids)

@property
def has_encoder_inputs(self) -> bool:
    return len(self.mm_features) > 0
```
Status drives which queue a request belongs in. Anything after PREEMPTED in the enum is considered finished (is_finished() at line 317 uses status > PREEMPTED):
```python
# vllm/v1/request.py lines 295-311
class RequestStatus(enum.IntEnum):
    WAITING = enum.auto()                                # in self.waiting (admission candidates)
    WAITING_FOR_STRUCTURED_OUTPUT_GRAMMAR = enum.auto()  # in skipped_waiting until grammar ready
    WAITING_FOR_REMOTE_KVS = enum.auto()                 # in skipped_waiting, connector loading
    WAITING_FOR_STREAMING_REQ = enum.auto()              # in skipped_waiting, next chunk
    RUNNING = enum.auto()                                # in self.running (has KV blocks)
    PREEMPTED = enum.auto()                              # just freed KV, back to waiting
    FINISHED_STOPPED = enum.auto()                       # stop token / stop string matched
    FINISHED_LENGTH_CAPPED = enum.auto()                 # hit max_tokens
    FINISHED_ABORTED = enum.auto()                       # client disconnected / finish_requests()
    FINISHED_IGNORED = enum.auto()                       # prompt > max_model_len
    FINISHED_ERROR = enum.auto()                         # model NaN etc.
    FINISHED_REPETITION = enum.auto()
```
_is_blocked_waiting_status() treats the WAITING_FOR_* statuses as "don't try to schedule me yet" and shunts those requests into skipped_waiting. Only a request in plain WAITING status (including one just flipped back to WAITING by _try_promote_blocked_waiting_request()) can be admitted into running.
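Because RequestStatus is an IntEnum, "finished" reduces to one integer comparison. The trimmed copy below keeps only some members, in the same relative order as the excerpt above; is_finished() and is_blocked_waiting() are illustrative helpers written for this sketch, not vLLM's own functions.

```python
import enum


class Status(enum.IntEnum):
    # Trimmed copy: same relative order as RequestStatus, FINISHED_* after PREEMPTED.
    WAITING = enum.auto()
    WAITING_FOR_STRUCTURED_OUTPUT_GRAMMAR = enum.auto()
    WAITING_FOR_REMOTE_KVS = enum.auto()
    WAITING_FOR_STREAMING_REQ = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()


BLOCKED = frozenset({
    Status.WAITING_FOR_STRUCTURED_OUTPUT_GRAMMAR,
    Status.WAITING_FOR_REMOTE_KVS,
    Status.WAITING_FOR_STREAMING_REQ,
})


def is_finished(status: Status) -> bool:
    # A single integer comparison: everything declared after PREEMPTED is terminal.
    return status > Status.PREEMPTED


def is_blocked_waiting(status: Status) -> bool:
    # "Don't try to schedule me yet": such requests get parked in skipped_waiting.
    return status in BLOCKED


for s in (Status.WAITING, Status.WAITING_FOR_REMOTE_KVS, Status.PREEMPTED, Status.FINISHED_ABORTED):
    print(s.name, is_finished(s), is_blocked_waiting(s))
```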
Request.__lt__ is what heapq in PriorityRequestQueue uses for ordering: priority first, then three tie-breakers:
```python
# vllm/v1/request.py lines 281-292
def __lt__(self, other: "Request") -> bool:
    if self.priority != other.priority:
        return self.priority < other.priority  # 1. lower priority wins
    if self.arrival_time != other.arrival_time:
        return self.arrival_time < other.arrival_time  # 2. earlier arrival wins
    if self.request_id != other.request_id:
        return self.request_id < other.request_id  # 3. lex request_id wins
    return id(self) < id(other)  # 4. pointer tiebreak (deterministic)
```
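The comparison cascade is all heapq needs. In this sketch Req is a hypothetical stand-in carrying only the fields __lt__ reads; pushing four requests and popping them shows priority deciding first, then arrival time, then request_id.

```python
import heapq
from dataclasses import dataclass


@dataclass
class Req:
    request_id: str
    priority: int
    arrival_time: float

    def __lt__(self, other: "Req") -> bool:
        # Same cascade as Request.__lt__: priority, arrival, request_id, id().
        if self.priority != other.priority:
            return self.priority < other.priority
        if self.arrival_time != other.arrival_time:
            return self.arrival_time < other.arrival_time
        if self.request_id != other.request_id:
            return self.request_id < other.request_id
        return id(self) < id(other)


heap: list[Req] = []
for r in [
    Req("c", priority=1, arrival_time=0.0),
    Req("a", priority=0, arrival_time=5.0),
    Req("b", priority=0, arrival_time=2.0),
    Req("d", priority=0, arrival_time=2.0),
]:
    heapq.heappush(heap, r)

order = [heapq.heappop(heap).request_id for _ in range(len(heap))]
print(order)  # ['b', 'd', 'a', 'c']
```

Request "c" arrived first but has the worst priority value, so it pops last; "b" beats "d" only on the request_id tie-break.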
Each engine step, schedule() makes two decisions in order: (1) which of the already-running requests continue, and (2) which waiting requests get admitted. Both phases are governed by the same constraints, evaluated in this priority order:
| # | Constraint | Code location | If violated |
|---|---|---|---|
| 1 | token_budget > 0 (tokens left in this step) | scheduler.py:385, 567 | Stop both loops |
| 2 | len(running) < max_num_running_reqs | scheduler.py:568-569 | Stop admitting from waiting |
| 3 | Status is WAITING (not blocked) | scheduler.py:578-588 | Move to skipped_waiting, try next |
| 4 | max_loras budget per step | scheduler.py:592-603 | Defer this request to skipped_waiting, try next |
| 5 | KV cache has enough free blocks | allocate_slots() returns None | RUNNING: preempt victim & retry; WAITING: break |
| 6 | Encoder compute budget (MM models) | _try_schedule_encoder_inputs | Truncate num_new_tokens to before MM input |
| 7 | num_new_tokens > 0 after all caps | scheduler.py:442-458 | RUNNING: continue to next; WAITING: break |
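The scheduler iterates self.running in list order. For each request, it computes how many tokens it would like to serve (the gap num_tokens_with_spec + num_output_placeholders - num_computed_tokens), applies all caps, then asks allocate_slots() to reserve KV. The answer is "yes" unless memory runs out, and even then the scheduler preempts a victim (the lowest-priority request under PRIORITY, the tail of self.running under FCFS) and retries instead of rejecting the current request outright. Only when the victim turns out to be the current request itself does the running loop stop.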
self.running is a plain list, not a priority queue, so a request keeps its admission position. Under FCFS, preemption pops the tail (self.running.pop()), so the most recently admitted request is the first to lose its KV: classic LIFO eviction.
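The preempt-and-retry loop is easier to see on a toy that counts only KV blocks. ToyReq and schedule_running_fcfs() are hypothetical: the real loop works in tokens through allocate_slots() and also handles the PRIORITY victim choice, encoder budgets, and spec tokens, none of which are modeled here.

```python
from dataclasses import dataclass


@dataclass
class ToyReq:
    name: str
    blocks_held: int
    blocks_needed: int  # new blocks this request wants in the current step


def schedule_running_fcfs(running: list[ToyReq], free_blocks: int):
    """Toy FCFS running loop: evict from the tail until allocation succeeds."""
    scheduled, preempted = [], []
    i = 0
    while i < len(running):
        req = running[i]
        self_preempted = False
        while req.blocks_needed > free_blocks:
            victim = running.pop()  # LIFO: newest admission loses its KV first
            free_blocks += victim.blocks_held
            victim.blocks_held = 0  # its KV will be recomputed after readmission
            preempted.append(victim.name)
            if victim is req:  # nobody left behind this request to evict
                self_preempted = True
                break
        if self_preempted:
            break
        free_blocks -= req.blocks_needed
        req.blocks_held += req.blocks_needed
        scheduled.append(req.name)
        i += 1
    return scheduled, preempted, free_blocks


running = [ToyReq("A", 4, 1), ToyReq("B", 4, 1), ToyReq("C", 4, 2)]
print(schedule_running_fcfs(running, free_blocks=1))
print([r.name for r in running])
```

With one free block, A fits, B forces C (the newest) out, and B then fits in C's freed blocks; C never gets its turn this step.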
The waiting phase runs only if no preemptions happened in the running loop (scheduler.py:564). Each candidate goes through these steps:

1. Queue selection. _select_waiting_queue_for_scheduling() picks between self.waiting and self.skipped_waiting. FCFS: skipped_waiting first (older entries). PRIORITY: whichever queue head sorts lower by Request.__lt__.
2. Peek. peek_request() looks but does not pop. If the request can't be scheduled this step, it stays where it is (or moves to skipped_waiting).
3. Blocked-status check. If the status is WAITING_FOR_STRUCTURED_OUTPUT_GRAMMAR / REMOTE_KVS / STREAMING_REQ, _try_promote_blocked_waiting_request() checks whether the blocker has resolved. If not, pop_request() + prepend into step_skipped_waiting.
4. LoRA budget. If this request's LoRA adapter would exceed max_loras in this step, skip it, with the same treatment as a blocked status.
5. Prefix cache lookup. get_computed_blocks() scans request.block_hashes against cached_block_hash_to_block. Every hit is a block whose KV can be reused without recomputation, so the number of tokens left to compute shrinks by num_new_local_computed_tokens.
6. Full-sequence reservation (optional). If scheduler_reserve_full_isl is enabled, the whole sequence (prompt + max output) must fit before admission. This prevents a common pathology where chunked prefill admits a request that ends up preempted a few steps later.
7. allocate_slots(). The final gate. If it returns None, the scheduler breaks out of the waiting loop entirely: there is no point trying more requests when the pool is full. Unlike the RUNNING phase, the scheduler does not preempt running requests to admit waiting ones.
8. Admit. request_queue.pop_request(), self.running.append(request), set status to RUNNING, record its blocks, subtract from token_budget. Classify as scheduled_new_reqs (fresh) or scheduled_resumed_reqs (was PREEMPTED).
Boiled down, the scheduler's logic is:
```python
# Pseudocode of the full policy
for req in self.running:  # in admission order
    if token_budget == 0:
        break
    if past_max_tokens_even_if_drafts_rejected(req):
        continue  # async skip-ahead
    gap = req.num_tokens_with_spec + req.num_output_placeholders - req.num_computed_tokens
    if gap == 0:
        continue  # nothing to compute for this request this step
    n = min(gap, long_prefill_threshold, token_budget, max_model_len_slack)  # threshold only if > 0
    while allocate_slots(req, n) is None:
        victim = preempt_lowest_priority_or_tail()  # frees blocks, appends to preempted_reqs
        if victim is req:
            break  # nothing left to evict
    if req in preempted_reqs:
        break  # the request preempted itself; stop the running loop
    token_budget -= n

if not preempted_reqs:  # admission only if no churn
    while (waiting or skipped_waiting) and token_budget > 0:
        if len(running) == max_num_running_reqs:
            break
        req = _select_waiting_queue_for_scheduling().peek_request()
        if blocked(req) or lora_full(req):
            move_to_skipped_waiting(req)
            continue
        new_computed, num_hit = get_computed_blocks(req)
        n = min(req.num_tokens - num_hit, long_prefill_threshold, token_budget)
        if allocate_slots(req, n, new_computed_blocks=new_computed) is None:
            break  # pool full -- no retry
        running.append(req)
        token_budget -= n
```
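The pseudocode above can be turned into a small runnable simulation. Everything here is a hypothetical toy (FCFS only, a 4-token block size, no prefix caching, no async or speculative tokens, and one sampled token appended whenever a request catches up); it is meant only to show the two phases sharing one token budget and preemption suppressing admission.

```python
import math
from collections import deque
from dataclasses import dataclass

BLOCK_SIZE = 4  # tokens per KV block (toy value)


@dataclass(eq=False)
class R:
    rid: str
    num_tokens: int        # prompt + verified outputs
    num_computed: int = 0  # positions whose KV is written
    blocks: int = 0        # KV blocks currently held


class ToyScheduler:
    """Hypothetical two-phase scheduler mirroring the pseudocode above (FCFS only)."""

    def __init__(self, num_blocks: int, budget: int, threshold: int, max_running: int):
        self.free, self.budget, self.threshold = num_blocks, budget, threshold
        self.max_running = max_running
        self.running: list[R] = []
        self.waiting: deque = deque()

    def _allocate(self, req: R, n: int) -> bool:
        # Reserve enough blocks to cover positions [0, num_computed + n).
        need = max(math.ceil((req.num_computed + n) / BLOCK_SIZE) - req.blocks, 0)
        if need > self.free:
            return False
        self.free -= need
        req.blocks += need
        return True

    def _cap(self, gap: int, budget: int) -> int:
        if 0 < self.threshold < gap:  # chunked-prefill cap (0 disables it)
            gap = self.threshold
        return min(gap, budget)

    def step(self) -> dict[str, int]:
        budget, scheduled, preempted = self.budget, {}, []
        i = 0
        while i < len(self.running) and budget > 0:  # phase 1: RUNNING
            req = self.running[i]
            n = self._cap(req.num_tokens - req.num_computed, budget)
            while not self._allocate(req, n):
                victim = self.running.pop()  # LIFO eviction
                self.free += victim.blocks
                victim.blocks = victim.num_computed = 0  # recompute after readmission
                self.waiting.appendleft(victim)
                preempted.append(victim)
                if victim is req:
                    break
            if req in preempted:
                break
            scheduled[req.rid] = n
            budget -= n
            i += 1
        if not preempted:  # phase 2: WAITING, only when there was no churn
            while self.waiting and budget > 0 and len(self.running) < self.max_running:
                req = self.waiting[0]  # peek
                n = self._cap(req.num_tokens - req.num_computed, budget)
                if not self._allocate(req, n):
                    break  # pool full: no preemption on behalf of waiting requests
                self.running.append(self.waiting.popleft())
                scheduled[req.rid] = n
                budget -= n
        for rid, n in scheduled.items():  # stand-in for _update_after_schedule + sampling
            req = next(r for r in self.running if r.rid == rid)
            req.num_computed += n
            if req.num_computed == req.num_tokens:
                req.num_tokens += 1  # one sampled token reopens a gap of 1
        return scheduled


sched = ToyScheduler(num_blocks=4, budget=6, threshold=4, max_running=2)
sched.waiting.extend([R("A", 5), R("B", 3)])
for _ in range(6):
    print(sched.step(), "free blocks:", sched.free)
```

In step 1 both prompts are admitted (A chunked to 4 tokens, B truncated to the 2 tokens left in the budget); from step 2 on both decode one token per step; by step 6 the 4-block pool is exhausted, B is preempted to make room for A, and admission is skipped for that step.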
The schedule() Method -- Step by Step

The scheduler's main entry point is called once per engine step. Its philosophy is stated in a key comment at the top of the method:
```python
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and
# num_tokens_with_spec. [...] At each step, the scheduler tries
# to assign tokens to the requests so that each request's
# num_computed_tokens can catch up its num_tokens_with_spec.
# This is general enough to cover chunked prefills, prefix caching,
# speculative decoding, and the "jump decoding" optimization.
```
In this model, a "decode" is simply a request where num_tokens_with_spec - num_computed_tokens == 1 (1 + num_spec with speculative drafts). A "prefill" is any request where the gap is larger. Chunked prefill simply caps how much of that gap is scheduled in a single step.
What are num_tokens_with_spec and num_computed_tokens?

These two counters on Request are the only state the scheduler cares about when deciding how much work to give a request. Everything else (prefill vs decode, chunk size, speculative drafts) is derived from them.
```python
# vllm/v1/request.py line 135 -- num_computed_tokens is a plain attribute
self.num_computed_tokens = 0

# vllm/v1/request.py lines 226-231 -- num_tokens and num_tokens_with_spec are properties
@property
def num_tokens(self) -> int:
    return len(self._all_token_ids)  # prompt + output (verified)

@property
def num_tokens_with_spec(self) -> int:
    return len(self._all_token_ids) + len(self.spec_token_ids)
```
num_tokens_with_spec: the number of token positions the request occupies right now: prompt + verified outputs + speculative drafts. Drafts are included because the scheduler must reserve KV slots for them before the target model verifies them.

num_computed_tokens: the number of positions for which the model has actually run a forward pass and whose KV is written on the GPU. It is advanced by _update_after_schedule() at the end of each schedule() call and rolled back by update_from_output() when speculative drafts are rejected.
The quantity the RUNNING loop computes at lines 404-408 is literally just the gap:

```python
num_new_tokens = (
    request.num_tokens_with_spec
    + request.num_output_placeholders  # async scheduling: placeholders for in-flight samples
    - request.num_computed_tokens
)
```
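A worked version of the gap, with num_tokens_with_spec expanded into its parts. gap() is a hypothetical helper for illustration; num_output_placeholders is left at 0 in every call because async scheduling is outside the scope of this example.

```python
def gap(num_prompt: int, num_output: int, num_spec: int,
        num_placeholders: int, num_computed: int) -> int:
    """Hypothetical helper: positions that still need a forward pass."""
    num_tokens_with_spec = num_prompt + num_output + num_spec
    return num_tokens_with_spec + num_placeholders - num_computed


print(gap(1000, 0, 0, 0, 0))     # 1000: first-time prefill
print(gap(1000, 0, 0, 0, 512))   # 488: rest of a chunked prefill
print(gap(1000, 1, 0, 0, 1000))  # 1: plain decode
print(gap(1000, 1, 3, 0, 1000))  # 4: decode plus 3 drafts to verify
```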
In classic LLM serving, prefill (process the whole prompt in one big batch) and decode (sample one token at a time) are separate phases, often with separate schedulers. vLLM V1 collapses them into "how many positions need computing this step?":
| Classic view | V1 view (the gap) | What the scheduler does |
|---|---|---|
| Prefill: run whole prompt | gap >> 1 | Schedule up to min(gap, token_budget, long_prefill_threshold) tokens |
| Decode: sample 1 new token | gap == 1 | Schedule 1 token (consumes 1 from the same pool) |
| Chunked prefill | gap >> 1, threshold caps chunk | Same formula; long_prefill_token_threshold slices the gap |
| Speculative decode | gap == 1 + num_spec | Drafts are already in num_tokens_with_spec; no special code path |
| Continuous batching | Mix of gaps in one step | Falls out for free -- prefills and decodes share the same budget and batch |
Every row goes through the same gap formula, the same token_budget, and the same allocate_slots() call. A single engine step happily mixes a big first-time prefill, a chunked-prefill continuation, and a dozen decodes in the same batch, because to the scheduler they're all just "requests whose num_computed_tokens is behind their num_tokens_with_spec." That is precisely what the NOTE(woosuk) comment above promises.
1. Initialize. Set token_budget = self.max_num_scheduled_tokens. Create empty collections for scheduled new/running/resumed requests, new block mappings, and encoder inputs. Call kv_cache_manager.new_step_starts().
2. Schedule RUNNING requests. Iterate self.running. For each request, calculate num_new_tokens = request.num_tokens_with_spec + request.num_output_placeholders - request.num_computed_tokens. Apply the chunked prefill cap, then min(num_new_tokens, token_budget). Call allocate_slots(). If allocation fails, preempt a victim (the tail under FCFS) and retry.
3. Schedule WAITING requests. Only if no preemptions occurred: pop from the waiting queue, look up prefix cache hits, compute num_new_tokens = request.num_tokens - num_computed_tokens, apply the chunked prefill threshold, and call allocate_slots(). Move successful requests to self.running.
4. Build the scheduler output. Pack NewRequestData (for first-time requests) and CachedRequestData (for continuing requests) with block IDs and computed token counts. Compute common prefix blocks for cascade attention.
5. Update state. _update_after_schedule() advances each request's num_computed_tokens by the scheduled amount so the next call to schedule() sees updated state. Clear finished_req_ids.
There is a single token_budget variable shared between prefill and decode. Each decode consumes 1 token from that pool (1 + num_spec with speculative drafts), while a long prefill can consume many. This is why vLLM V1 naturally performs "continuous batching": there are no separate phases.
```python
# scheduler.py lines 360-381 -- Budget + bookkeeping init, new step signal
scheduled_new_reqs: list[Request] = []
scheduled_resumed_reqs: list[Request] = []
scheduled_running_reqs: list[Request] = []
preempted_reqs: list[Request] = []
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
if self._pause_state == PauseState.PAUSED_ALL:
    # Do not schedule any requests when paused.
    token_budget = 0

# Encoder-related (multi-modal models).
scheduled_encoder_inputs: dict[str, list[int]] = {}
encoder_compute_budget = self.max_num_encoder_input_tokens

# Spec decode-related (draft tokens previously proposed).
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
scheduled_timestamp = time.monotonic()

# Per-type managers drain their take_new_block_ids buffers
self.kv_cache_manager.new_step_starts()
```
SingleTypeKVCacheManager accumulates newly allocated block IDs across the current step. new_step_starts() clears that buffer so the next take_new_block_ids() at the end of this step only reports blocks allocated within it -- used for KV cache zeroing on certain hardware and for KV events exported to external consumers.
The budget is decremented identically for running and waiting requests -- that uniformity is the whole reason prefill and decode can live in the same loop:
```python
# For RUNNING requests (scheduler.py lines 516-517):
num_scheduled_tokens[request_id] = num_new_tokens
token_budget -= num_new_tokens

# For WAITING requests (scheduler.py lines 825-826):
num_scheduled_tokens[request_id] = num_new_tokens
token_budget -= num_new_tokens
```
The budget comes from max_num_scheduled_tokens (usually set to max_num_batched_tokens).

| Config Parameter | Default | Effect |
|---|---|---|
| max_num_scheduled_tokens | max_num_batched_tokens | Total token budget per step (shared across all requests) |
| max_num_seqs | 256 | Maximum requests in the running list (max_num_running_reqs) |
| long_prefill_token_threshold | 0 (disabled) | Caps a single request's prefill chunk per step (chunking of waiting requests is gated separately by enable_chunked_prefill) |
The running queue is iterated first. This ensures requests that already hold KV cache blocks get to continue, minimizing wasted GPU memory from preempted partial computations.
```python
# scheduler.py lines 383-411 -- Iterate RUNNING + compute num_new_tokens
# First, schedule the RUNNING requests.
req_index = 0
while req_index < len(self.running) and token_budget > 0:
    request = self.running[req_index]

    # ---- ASYNC SCHEDULING SKIP-AHEAD (lines 388-402) ----
    # If we're past max_tokens even after rejecting all draft tokens,
    # skip scheduling this request for one more step. The arithmetic
    # is: num_computed_tokens + 1 (the step we're scheduling) minus
    # (num_output_placeholders - 1) so we account for draft tokens
    # that might all be rejected downstream.
    if (
        request.num_output_placeholders > 0
        and request.num_computed_tokens + 2 - request.num_output_placeholders
        >= request.num_prompt_tokens + request.max_tokens
    ):
        req_index += 1
        continue

    # ---- HOW MANY NEW TOKENS TO SCHEDULE ----
    # num_tokens_with_spec = prompt + output + spec_tokens
    # num_output_placeholders is non-zero only in async scheduling.
    num_new_tokens = (
        request.num_tokens_with_spec
        + request.num_output_placeholders
        - request.num_computed_tokens
    )
    # Chunked prefill threshold (0 disables chunking).
    if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
        num_new_tokens = self.scheduler_config.long_prefill_token_threshold
    num_new_tokens = min(num_new_tokens, token_budget)
    # Never let scheduled input pos overrun max_model_len - 1.
    num_new_tokens = min(
        num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens
    )
```
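In other words: if the tokens already in flight carry the request to max_tokens, even in the worst case where every draft spec token is rejected, there is nothing left to do this step. continue (rather than break) moves on to the next running request, so the requests behind it can still use the remaining budget.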
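Plugging numbers into the skip-ahead condition helps. skip_ahead() below is a hypothetical helper that copies the inequality from the excerpt verbatim; with a 100-token prompt and max_tokens=10, the threshold is 110.

```python
def skip_ahead(num_computed: int, num_placeholders: int,
               num_prompt: int, max_tokens: int) -> bool:
    """Hypothetical mirror of the async skip-ahead test in the excerpt above."""
    # num_computed + 1 (this step) - (num_placeholders - 1) (all drafts rejected)
    return (
        num_placeholders > 0
        and num_computed + 2 - num_placeholders >= num_prompt + max_tokens
    )


print(skip_ahead(109, 1, 100, 10))  # True: in-flight work already reaches max_tokens
print(skip_ahead(108, 1, 100, 10))  # False: one more step is still needed
print(skip_ahead(109, 0, 100, 10))  # False: no placeholders, so not in async flight
```

Read loosely: with 109 positions already scheduled and one sample in flight, that sample would be the 10th output token, so scheduling the request again could only overshoot max_tokens.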
Waiting requests are only considered if no preemptions occurred in the running phase and the scheduler is not paused. This avoids a churn cycle where newly admitted requests immediately trigger more preemptions.
```python
# scheduler.py lines 564-603 -- WAITING request admission phase
# Gate: if the RUNNING phase just preempted someone, skip admission this step
# to avoid the preempt -> admit -> OOM-again thrash cycle.
if not preempted_reqs and self._pause_state == PauseState.UNPAUSED:
    # Local temporary queue: holds only the requests skipped in *this* step.
    # They can't go straight back into self.skipped_waiting, or the next peek
    # would return the same request again -> infinite loop.
    step_skipped_waiting = create_request_queue(self.policy)

    # Loop exit conditions: (1) both source queues empty, (2) token budget used up
    while (self.waiting or self.skipped_waiting) and token_budget > 0:
        # (3) batch-width cap (max_num_seqs): once reached, break --
        # no amount of leftover tokens helps.
        if len(self.running) == self.max_num_running_reqs:
            break

        # Pick self.waiting or self.skipped_waiting
        # (FCFS prefers skipped; PRIORITY compares the heads).
        request_queue = self._select_waiting_queue_for_scheduling()
        # Note: peek, not pop -- we haven't decided whether to admit yet.
        request = request_queue.peek_request()

        # ---- Filter 1: async blocked status ----
        # _is_blocked_waiting_status: is it WAITING_FOR_{GRAMMAR, REMOTE_KVS, STREAMING_REQ}?
        # _try_promote_blocked_waiting_request: if the dependency has resolved,
        # flip back to WAITING and return True.
        # Combined: still blocked AND cannot be promoted -> defer.
        if self._is_blocked_waiting_status(
            request.status
        ) and not self._try_promote_blocked_waiting_request(request):
            request_queue.pop_request()  # pop from the source only once deferral is certain
            step_skipped_waiting.prepend_request(request)  # park locally; untouched for the rest of this step
            continue  # try the next request (not break)

        # ---- Filter 2: LoRA budget ----
        # Four-way AND; condition 4 is the key: requests sharing one LoRA don't
        # consume extra budget (the adapter is loaded once), so only a request
        # that introduces a *new* LoRA after the cap is reached gets blocked.
        if (
            self.lora_config  # 1. LoRA is enabled
            and request.lora_request  # 2. this request carries an adapter
            and len(scheduled_loras) == self.lora_config.max_loras  # 3. LoRA slots full this step
            and request.lora_request.lora_int_id not in scheduled_loras  # 4. it's a new LoRA
        ):
            request_queue.pop_request()
            step_skipped_waiting.prepend_request(request)
            continue

        # (Next: prefix cache lookup -> allocate_slots -> actually move into self.running)

    # After the loop (scheduler.py:848-850), move this step's holding queue back
    # into self.skipped_waiting in one batch. Merging only outside the loop is
    # what lets step_skipped_waiting prevent the infinite loop.
    if step_skipped_waiting:
        self.skipped_waiting.prepend_requests(step_skipped_waiting)
```
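The reason for the per-step holding queue is easiest to see in a toy. admit() below is hypothetical and ignores token budgets, LoRA, and max_num_seqs; it keeps deferred requests in their original relative order, which may differ from how vLLM orders them. If a blocked request were pushed straight back into skipped instead of step_skipped, the next peek would return it again and the loop would never terminate.

```python
from collections import deque


def admit(waiting: deque, skipped: deque, is_blocked) -> list:
    """Toy admission loop with a per-step holding queue (hypothetical helper)."""
    admitted = []
    step_skipped: deque = deque()  # requests deferred during *this* step only
    while waiting or skipped:
        queue = skipped if skipped else waiting  # FCFS: skipped entries first
        req = queue[0]  # peek before deciding
        if is_blocked(req):
            queue.popleft()
            step_skipped.append(req)  # out of both source queues until the step ends
            continue
        admitted.append(queue.popleft())
    # Merge only after the loop, keeping deferred requests at the front.
    skipped.extendleft(reversed(step_skipped))
    return admitted


waiting = deque(["a", "g1", "b"])
skipped = deque(["g0"])
admitted = admit(waiting, skipped, lambda r: r.startswith("g"))
print(admitted, list(skipped))  # ['a', 'b'] ['g0', 'g1']
```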
Waiting requests get a prefix cache lookup (get_computed_blocks()). On a fresh request (num_computed_tokens == 0), a KVConnector (for P/D disaggregation or remote KV transfer) may additionally contribute externally computed tokens:
```python
# scheduler.py lines 605-650 -- Cache lookup + connector-aware match count
num_external_computed_tokens = 0
load_kv_async = False
if request.num_computed_tokens == 0:
    # 1. Local prefix cache hit (from cached_block_hash_to_block).
    new_computed_blocks, num_new_local_computed_tokens = (
        self.kv_cache_manager.get_computed_blocks(request)
    )
    # 2. Optional: remote / external cache (e.g. prefill-decode disagg).
    if self.connector is not None:
        ext_tokens, load_kv_async = self.connector.get_num_new_matched_tokens(
            request, num_new_local_computed_tokens
        )
        if ext_tokens is None:
            # Connector needs a retry -- park in skipped_waiting.
            request_queue.pop_request()
            step_skipped_waiting.prepend_request(request)
            continue
        num_external_computed_tokens = ext_tokens
        request.num_external_computed_tokens = ext_tokens
    num_computed_tokens = min(
        num_new_local_computed_tokens + num_external_computed_tokens,
        request.num_tokens,
    )
else:
    # Resumed / preempted / KV-transfer-completed requests already
    # have num_computed_tokens > 0.
    new_computed_blocks = self.kv_cache_manager.empty_kv_cache_blocks
    num_new_local_computed_tokens = 0
    num_computed_tokens = request.num_computed_tokens
```
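Requests that cannot be admitted yet (blocked status, LoRA budget, or a connector asking for a retry) are parked in self.skipped_waiting and re-checked each step. This avoids head-of-line blocking in the main queue.

Between self.waiting and self.skipped_waiting, which does the scheduler drain first? Under FCFS, skipped_waiting wins (older requests are ahead). Under PRIORITY, the heads of both queues are compared and whichever sorts lower under Request.__lt__ (lower priority value, then earlier arrival) wins: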
```python
# scheduler.py lines 1572-1582 -- queue selection
def _select_waiting_queue_for_scheduling(self) -> RequestQueue | None:
    if self.policy == SchedulingPolicy.FCFS:
        return self.skipped_waiting or self.waiting or None

    # PRIORITY: compare queue heads.
    if self.waiting and self.skipped_waiting:
        waiting_req = self.waiting.peek_request()
        skipped_req = self.skipped_waiting.peek_request()
        return self.waiting if waiting_req < skipped_req else self.skipped_waiting

    return self.waiting or self.skipped_waiting or None
```
Request.__lt__ orders by (priority, arrival_time), so a lower numeric priority (or, as a tie-break, an earlier arrival) sorts first. The two underlying queue classes are:

FCFSRequestQueue: subclasses collections.deque. add_request() appends to the right, pop_request() is popleft(), and prepend_request() is appendleft(). All O(1).

PriorityRequestQueue: wraps a heapq min-heap keyed on Request.__lt__. prepend_request() collapses to add_request(), since you cannot jump the heap order. remove_request() triggers an O(n) heapify.
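Minimal sketches of the two queue shapes follow. FCFSQueueSketch and PriorityQueueSketch are hypothetical classes that reuse the method names from the description above; they are not vLLM's classes, and plain (priority, arrival_time) tuples stand in for Request objects in the heap.

```python
import heapq
from collections import deque


class FCFSQueueSketch(deque):
    """Deque-backed FIFO: every operation is O(1)."""

    def add_request(self, req) -> None:
        self.append(req)

    def pop_request(self):
        return self.popleft()

    def peek_request(self):
        return self[0]

    def prepend_request(self, req) -> None:
        self.appendleft(req)  # a preempted request jumps back to the head


class PriorityQueueSketch:
    """Heap-backed queue ordered by the items' own __lt__."""

    def __init__(self) -> None:
        self._heap: list = []

    def add_request(self, req) -> None:
        heapq.heappush(self._heap, req)

    def pop_request(self):
        return heapq.heappop(self._heap)

    def peek_request(self):
        return self._heap[0]

    def prepend_request(self, req) -> None:
        # A heap has no "front": the item lands wherever its key puts it.
        self.add_request(req)

    def remove_request(self, req) -> None:
        self._heap.remove(req)  # O(n) scan ...
        heapq.heapify(self._heap)  # ... plus O(n) re-heapify

    def __len__(self) -> int:
        return len(self._heap)


fcfs = FCFSQueueSketch()
for r in ["r1", "r2"]:
    fcfs.add_request(r)
fcfs.prepend_request("preempted")
print(fcfs.peek_request())  # preempted

pq = PriorityQueueSketch()
for item in [(1, 0.0), (0, 3.0)]:
    pq.add_request(item)
pq.prepend_request((5, 1.0))  # "prepending" cannot beat priority 0
print(pq.peek_request())  # (0, 3.0)
```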
Chunked prefill is controlled by long_prefill_token_threshold. When a request's remaining tokens exceed this threshold, only the threshold amount is scheduled. The request stays in the running queue with updated num_computed_tokens and will get more tokens next step.
```python
# scheduler.py lines 409-411 -- RUNNING branch
if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
    num_new_tokens = self.scheduler_config.long_prefill_token_threshold
num_new_tokens = min(num_new_tokens, token_budget)
```
```python
# scheduler.py lines 660-668 -- WAITING branch
# We use request.num_tokens (not num_prompt_tokens) so resumed/preempted
# requests with already-sampled outputs are recomputed correctly.
num_new_tokens = request.num_tokens - num_computed_tokens
threshold = self.scheduler_config.long_prefill_token_threshold
if 0 < threshold < num_new_tokens:
    num_new_tokens = threshold
```
Without chunked prefill enabled, the scheduler enforces that the full request fits in the budget:
```python
# scheduler.py lines 672-681 -- Admission with chunking disabled
if (
    not self.scheduler_config.enable_chunked_prefill
    and num_new_tokens > token_budget
):
    # Whole prefill wouldn't fit, and we can't slice it -- stop.
    break

num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
```
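Ignoring other requests competing for the budget, the chunk sizes a single long prompt receives can be computed directly. chunk_schedule() is a hypothetical helper; the second call (threshold 0) assumes enable_chunked_prefill is on, since with chunking disabled the scheduler would refuse to admit a 10,000-token prompt against an 8,192-token budget.

```python
def chunk_schedule(prompt_len: int, threshold: int, budget: int) -> list[int]:
    """Per-step prefill chunk sizes for one request (toy, nothing else competing)."""
    chunks, computed = [], 0
    while computed < prompt_len:
        n = prompt_len - computed  # the gap
        if 0 < threshold < n:  # long_prefill_token_threshold cap
            n = threshold
        n = min(n, budget)  # per-step token budget cap
        chunks.append(n)
        computed += n  # what _update_after_schedule() does between steps
    return chunks


print(chunk_schedule(10_000, threshold=2048, budget=8192))  # [2048, 2048, 2048, 2048, 1808]
print(chunk_schedule(10_000, threshold=0, budget=8192))     # [8192, 1808]
```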
When scheduler_reserve_full_isl is enabled, waiting requests must also pass a can_fit_full_sequence() admission gate that checks whether the KV cache has enough blocks for the entire sequence (not just the first chunk). This prevents over-admission that would otherwise lead to preemptions later.

Every full block of tokens gets a chained hash: the hash of block N depends on the hash of block N-1, forming a Merkle-style chain. The block size used for hashing defaults to 16 tokens (configurable via block_size).
```python
# kv_cache_utils.py -- hash_block_tokens()
def hash_block_tokens(
    hash_function: Callable[[Any], bytes],
    parent_block_hash: BlockHash | None,
    curr_block_token_ids: Sequence[int],
    extra_keys: tuple[Any, ...] | None = None,
) -> BlockHash:
    if not parent_block_hash:
        parent_block_hash = NONE_HASH  # Random seed or from PYTHONHASHSEED

    curr_block_token_ids_tuple = tuple(curr_block_token_ids)
    return BlockHash(
        hash_function((parent_block_hash, curr_block_token_ids_tuple, extra_keys))
    )
```
The hash chain is built incrementally. Each Request maintains a
block_hashes list that grows as tokens are appended. The block hasher
function is created per-model:
```python
# kv_cache_utils.py -- get_request_block_hasher()
def request_block_hasher(request: Request) -> list[BlockHash]:
    start_token_idx = len(request.block_hashes) * block_size
    # ...
    while True:
        end_token_idx = start_token_idx + block_size
        if end_token_idx > num_tokens:
            break  # Only hash FULL blocks

        extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
            request, start_token_idx, end_token_idx, curr_mm_idx
        )
        block_tokens = request.all_token_ids[start_token_idx:end_token_idx]
        block_hash = hash_block_tokens(
            caching_hash_fn, prev_block_hash_value, block_tokens, extra_keys
        )
        new_block_hashes.append(block_hash)
        start_token_idx += block_size
        prev_block_hash_value = block_hash
    return new_block_hashes
```

The extra_keys make the hash depend on more than token IDs. Multimodal inputs contribute (mm_hash, start_offset) to the block hash, LoRA requests include the adapter name, and the cache salt is included for the first block. This ensures that two requests with the same token IDs but different LoRA adapters, images, or salts get different block hashes.
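Below is a minimal chained-hash sketch. SHA-256 over a pickled tuple and the all-zero root are stand-ins chosen for this example (vLLM's hash function and NONE_HASH seed differ), and the salt handling is simplified to "extra key on block 0 only". It shows three properties from the text: partial blocks get no hash, a change in one block changes that block's hash and every later one, and a first-block salt propagates down the whole chain.

```python
from __future__ import annotations

import hashlib
import pickle

NONE_HASH = bytes(32)  # hypothetical all-zero root hash for the first block


def hash_block(parent: bytes | None, tokens: tuple[int, ...],
               extra_keys: tuple | None = None) -> bytes:
    # Each block commits to its parent's hash, so it commits to the whole prefix.
    payload = pickle.dumps((parent or NONE_HASH, tokens, extra_keys))
    return hashlib.sha256(payload).digest()


def block_hashes(token_ids: list[int], block_size: int,
                 salt: str | None = None) -> list[bytes]:
    hashes: list[bytes] = []
    parent: bytes | None = None
    # Step through FULL blocks only; a trailing partial block gets no hash yet.
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        # Simplified extra keys: the salt goes into block 0 only.
        extra = (salt,) if salt is not None and start == 0 else None
        h = hash_block(parent, tuple(token_ids[start:start + block_size]), extra)
        hashes.append(h)
        parent = h
    return hashes


a = list(range(40))
b = a.copy()
b[20] = -1  # differs only inside block 1
c = a.copy()
c[5] = -1  # differs inside block 0
ha, hb, hc = block_hashes(a, 16), block_hashes(b, 16), block_hashes(c, 16)
print(len(ha))                         # 2: tokens 32..39 are a partial block
print(ha[0] == hb[0], ha[1] == hb[1])  # True False
print(ha[0] == hc[0], ha[1] == hc[1])  # False False: the change propagates
hs = block_hashes(a, 16, salt="tenant-x")
print(ha[0] == hs[0], ha[1] == hs[1])  # False False: the salt propagates too
```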
For full attention, the lookup is a simple left-to-right scan: each block hash is checked in the cached_block_hash_to_block map. Because block hashes are chained, a usable hit must be a contiguous prefix, so the first miss at position N ends the scan; nothing at N+1, N+2, ... can extend the hit.
```python
# kv_cache_manager.py lines 176-216 -- top-level cache hit entry point
def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]:
    if not self.enable_caching or request.skip_reading_prefix_cache:
        return self.empty_kv_cache_blocks, 0

    # CRITICAL: cap to num_tokens - 1. If every block matched, we would
    # have zero tokens to run through the model, so logits couldn't be
    # produced for sampling. One-token slack forces at least one recompute.
    max_cache_hit_length = request.num_tokens - 1
    computed_blocks, num_new_computed_tokens = (
        self.coordinator.find_longest_cache_hit(
            request.block_hashes, max_cache_hit_length
        )
    )
    return self.create_kv_cache_blocks(computed_blocks), num_new_computed_tokens
```
```python
# UnitaryKVCacheCoordinator delegates to its only single-type manager.
# single_type_kv_cache_manager.py -- FullAttentionManager.find_longest_cache_hit()
max_num_blocks = max_length // block_size
for block_hash in itertools.islice(block_hashes, max_num_blocks):
    if cached_block := block_pool.get_cached_block(
        block_hash, kv_cache_group_ids
    ):
        for computed, cached in zip(computed_blocks, cached_block):
            computed.append(cached)
    else:
        break  # Chain broken -- no further hits possible

if use_eagle and computed_blocks[0]:
    # Drop last matched block to force recompute for EAGLE hidden states.
    # EAGLE needs the last block's full hidden state, which the draft
    # proposer consumes next step; a cached block alone doesn't carry it.
    for computed in computed_blocks:
        computed.pop()
```
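A toy version of the scan, with the num_tokens - 1 cap from get_computed_blocks() folded in. longest_cache_hit() is hypothetical: a plain dict stands in for the block pool, there is a single KV cache group, and cached blocks are represented by their IDs.

```python
def longest_cache_hit(block_hashes: list, cache: dict, block_size: int,
                      num_tokens: int, use_eagle: bool = False) -> list:
    """Left-to-right prefix match over chained hashes (toy, single group)."""
    max_cache_hit_length = num_tokens - 1  # keep >= 1 token to produce logits
    max_num_blocks = max_cache_hit_length // block_size
    hits = []
    for h in block_hashes[:max_num_blocks]:
        block = cache.get(h)
        if block is None:
            break  # chain broken: later blocks cannot extend the prefix
        hits.append(block)
    if use_eagle and hits:
        hits.pop()  # recompute the last block so EAGLE gets its hidden states
    return hits


cache = {"h0": 10, "h1": 11, "h2": 12}
print(longest_cache_hit(["h0", "h1", "h2"], cache, block_size=4, num_tokens=12))  # [10, 11]
print(longest_cache_hit(["h0", "h1", "h2"], cache, block_size=4, num_tokens=13))  # [10, 11, 12]
```

With exactly 12 tokens, all three blocks are cached, yet only two are used: taking the third would leave nothing to run through the model.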
When a cached block is "hit", its ref_cnt is incremented. If the
block was an eviction candidate (ref_cnt == 0, sitting in the free queue),
it is removed from the free queue via O(1) linked-list surgery:
```python
# block_pool.py -- touch()
def touch(self, blocks: Sequence[KVCacheBlock]) -> None:
    for block in blocks:
        # ref_cnt=0 means this block is in the free list (eviction candidate)
        if block.ref_cnt == 0 and not block.is_null:
            self.free_block_queue.remove(block)  # O(1) doubly-linked list removal
        block.ref_cnt += 1
```
This is a critical performance structure. Python's built-in deque does not support O(1) removal from the middle. The custom implementation uses the prev_free_block and next_free_block pointers directly on KVCacheBlock objects -- no wrapper nodes, no extra memory allocation.
get_new_blocks() pops N blocks in a single traversal of the free queue (popleft_n()). Each popped block may still carry a hash (it was an eviction candidate, not an unused block), so _maybe_evict_cached_block() strips its hash entry from cached_block_hash_to_block before the block's ref_cnt is raised:

```python
# block_pool.py lines 320-350 -- get_new_blocks() with eviction hook
def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
    if num_blocks > self.get_num_free_blocks():
        raise ValueError(f"Cannot get {num_blocks} free blocks from the pool")

    ret: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks)

    if self.enable_caching:
        for block in ret:
            # If this block was a cached eviction candidate, delete its
            # hash -> block entry so future lookups won't match a block
            # that's about to be repurposed.
            self._maybe_evict_cached_block(block)
            assert block.ref_cnt == 0
            block.ref_cnt += 1
            if self.metrics_collector:
                self.metrics_collector.on_block_allocated(block)
    else:
        for block in ret:
            assert block.ref_cnt == 0
            block.ref_cnt += 1
            if self.metrics_collector:
                self.metrics_collector.on_block_allocated(block)
    return ret
```
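The eviction hook can be sketched with a toy pool. ToyPool and Blk are hypothetical; a deque stands in for the linked free list and the cache map holds one block per hash. The sequence shows a freed block staying findable by hash until it is actually handed out again.

```python
from __future__ import annotations

from collections import deque
from dataclasses import dataclass


@dataclass(eq=False)
class Blk:
    block_id: int
    ref_cnt: int = 0
    block_hash: str | None = None


class ToyPool:
    def __init__(self, n: int) -> None:
        self.free = deque(Blk(i) for i in range(n))  # stand-in for the linked free list
        self.cached: dict[str, Blk] = {}  # hash -> block (one block per hash here)

    def get_new_blocks(self, k: int) -> list[Blk]:
        if k > len(self.free):
            raise ValueError(f"Cannot get {k} free blocks from the pool")
        out = [self.free.popleft() for _ in range(k)]
        for b in out:
            # Evict: a block being repurposed must stop answering cache lookups.
            if b.block_hash is not None and self.cached.get(b.block_hash) is b:
                del self.cached[b.block_hash]
            b.block_hash = None
            assert b.ref_cnt == 0
            b.ref_cnt += 1
        return out

    def free_blocks(self, blocks: list[Blk]) -> None:
        for b in blocks:
            b.ref_cnt -= 1
            if b.ref_cnt == 0:
                self.free.append(b)  # keeps its hash: still a prefix-cache candidate


pool = ToyPool(2)
b0, b1 = pool.get_new_blocks(2)
b0.block_hash = "h0"
pool.cached["h0"] = b0  # b0 now holds a full, hashed block
pool.free_blocks([b0])  # its request finished: b0 is an eviction candidate
print("h0" in pool.cached)  # True: a new request could still hit it
(reused,) = pool.get_new_blocks(1)
print(reused is b0, "h0" in pool.cached)  # True False
```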
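When a request finishes or is preempted, its blocks are freed in reverse order. Freed blocks are appended to the back of the free queue and eviction pops from the front, so the request's tail blocks (the least reusable for prefix caching) end up ahead of its prefix blocks and are evicted first: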
```python
# single_type_kv_cache_manager.py -- free()
def free(self, request_id: str) -> None:
    req_blocks = self.req_to_blocks.pop(request_id, [])
    ordered_blocks = reversed(req_blocks)  # Tail blocks freed first!
    self.block_pool.free_blocks(ordered_blocks)
    self.num_cached_block.pop(request_id, None)
```
```python
# block_pool.py -- free_blocks()
def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
    blocks_list = list(ordered_blocks)
    for block in blocks_list:
        block.ref_cnt -= 1
    # Only return to free queue if ref_cnt dropped to 0
    self.free_block_queue.append_n(
        [block for block in blocks_list if block.ref_cnt == 0 and not block.is_null]
    )
```
When two running requests share a prefix block, that block's ref_cnt is 2. Freeing one request decrements it to 1, so the block stays allocated for the other request and is NOT added to the free queue. The block becomes an eviction candidate only when ref_cnt reaches 0.
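The interaction between reversed() freeing and ref counting can be checked with a toy simulation. It is a hypothetical sketch. free_request() folds free() and free_blocks() together. A plain list stands in for the free queue, with new blocks appended at the tail and eviction from the head.

```python
from dataclasses import dataclass


@dataclass(eq=False)
class Block:
    block_id: int
    ref_cnt: int = 0


def free_request(req_blocks: list[Block], free_queue: list[Block]) -> None:
    # Mirror free() + free_blocks(): decrement in reverse order, then append
    # only the blocks whose ref_cnt dropped to 0.
    ordered = list(reversed(req_blocks))
    for b in ordered:
        b.ref_cnt -= 1
    free_queue.extend(b for b in ordered if b.ref_cnt == 0)


# Two requests share prefix blocks 0 and 1; each owns one private tail block.
shared = [Block(0, ref_cnt=2), Block(1, ref_cnt=2)]
tail_a, tail_b = Block(2, ref_cnt=1), Block(3, ref_cnt=1)
free_queue: list[Block] = []

free_request(shared + [tail_a], free_queue)
print([b.block_id for b in free_queue])  # [2]: shared blocks still in use

free_request(shared + [tail_b], free_queue)
print([b.block_id for b in free_queue])  # [2, 3, 1, 0]: prefix evicted last
```

Eviction pops from the head, so the private tails go first. Block 0, the most widely reusable prefix block, survives longest.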
Every free-block queue operation (popleft, append, remove, append_n) has to splice pointers. Edge cases (empty list, first / last element) normally require branching. Two fake sentinel blocks (fake_free_list_head, fake_free_list_tail) eliminate those branches -- every real block is guaranteed to have a prev and next:
```python
# kv_cache_utils.py lines 180-208 -- FreeKVCacheBlockQueue.__init__()
def __init__(self, blocks: list[KVCacheBlock]) -> None:
    self.num_free_blocks = len(blocks)

    # Pre-link consecutive blocks: blocks[i] <<=>> blocks[i+1].
    for i in range(self.num_free_blocks):
        if i > 0:
            blocks[i].prev_free_block = blocks[i - 1]
        if i < self.num_free_blocks - 1:
            blocks[i].next_free_block = blocks[i + 1]

    # Two fake sentinels with block_id=-1. They are NEVER popped and never appear in ret.
    self.fake_free_list_head = KVCacheBlock(block_id=-1)
    self.fake_free_list_tail = KVCacheBlock(block_id=-1)
    if self.num_free_blocks > 0:
        self.fake_free_list_head.next_free_block = blocks[0]
        blocks[0].prev_free_block = self.fake_free_list_head
        self.fake_free_list_tail.prev_free_block = blocks[-1]
        blocks[-1].next_free_block = self.fake_free_list_tail
    else:
        self.fake_free_list_head.next_free_block = self.fake_free_list_tail
        self.fake_free_list_tail.prev_free_block = self.fake_free_list_head
```
```python
# kv_cache_utils.py lines 280-298 -- remove() is what deque can't do
def remove(self, block: KVCacheBlock) -> None:
    if block.prev_free_block is None or block.next_free_block is None:
        # Caller bug: trying to remove a block not in the free list.
        raise RuntimeError(f"remove() called on invalid block: {block}")
    # Splice around `block` -- sentinels guarantee both neighbours exist.
    block.prev_free_block.next_free_block = block.next_free_block
    block.next_free_block.prev_free_block = block.prev_free_block
    block.prev_free_block = block.next_free_block = None
    self.num_free_blocks -= 1
```
When a big prefill needs dozens of new blocks in one shot, a single popleft_n() traversal is cheaper than N calls to popleft(): it only rewires the fake head pointer once, at the end:
```python
# kv_cache_utils.py lines 247-278 -- popleft_n()
def popleft_n(self, n: int) -> list[KVCacheBlock]:
    if n == 0:
        return []
    assert self.num_free_blocks >= n
    self.num_free_blocks -= n

    curr_block = self.fake_free_list_head.next_free_block
    ret = []
    for _ in range(n):
        ret.append(curr_block)
        last_block = curr_block
        curr_block = curr_block.next_free_block
        # Disconnect popped block from the list -- no neighbour patching yet.
        last_block.prev_free_block = None
        last_block.next_free_block = None

    # Rewire fake_head -> first surviving block exactly ONCE.
    if curr_block is not None:
        self.fake_free_list_head.next_free_block = curr_block
        curr_block.prev_free_block = self.fake_free_list_head
    return ret
```
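Here is a standalone sketch of the batched pop, using a minimal hypothetical Node type with the same sentinel layout as FreeKVCacheBlockQueue. It detaches n nodes while walking forward and patches the head sentinel once. Because of the tail sentinel, the node after the last popped one always exists.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass(eq=False)
class Node:
    block_id: int
    prev: Optional[Node] = None
    next: Optional[Node] = None


class SentinelList:
    def __init__(self, ids: list[int]) -> None:
        self.head, self.tail = Node(-1), Node(-1)
        nodes = [self.head] + [Node(i) for i in ids] + [self.tail]
        for left, right in zip(nodes, nodes[1:]):
            left.next, right.prev = right, left
        self.size = len(ids)

    def popleft_n(self, n: int) -> list[Node]:
        if n == 0:
            return []
        assert self.size >= n
        self.size -= n
        curr = self.head.next
        ret = []
        for _ in range(n):
            ret.append(curr)
            last, curr = curr, curr.next
            last.prev = last.next = None  # detach; neighbours patched below
        # curr is the first surviving node (or the tail sentinel): rewire once.
        self.head.next, curr.prev = curr, self.head
        return ret

    def ids(self) -> list[int]:
        out, cur = [], self.head.next
        while cur is not self.tail:
            out.append(cur.block_id)
            cur = cur.next
        return out


lst = SentinelList([10, 11, 12, 13])
print([n.block_id for n in lst.popleft_n(3)], lst.ids())  # [10, 11, 12] [13]
```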
vLLM V1 uses a recomputation-only preemption policy. There is no swapping of KV cache to CPU memory. When a running request cannot fit its new tokens and no free blocks remain, the scheduler preempts the lowest-priority request by freeing ALL its blocks and putting it back in the waiting queue with num_computed_tokens = 0.
```python
# scheduler.py lines 956-976 -- _preempt_request()
def _preempt_request(self, request: Request, timestamp: float) -> None:
    assert request.status == RequestStatus.RUNNING
    self.kv_cache_manager.free(request)  # Decrement ref_cnt on every block
    self.encoder_cache_manager.free(request)
    request.status = RequestStatus.PREEMPTED
    request.num_computed_tokens = 0  # Force full recompute on re-admit
    if request.spec_token_ids:
        request.spec_token_ids = []  # Drafts are context-dependent, discard
    request.num_preemptions += 1
    # FCFS: prepend_request is appendleft on the deque, so preempted
    # requests re-enter at the HEAD of waiting and get first chance.
    # PRIORITY: prepend degrades to add_request (heap order wins).
    self.waiting.prepend_request(request)
```
kv_cache_manager.free() decrements ref_cnt on every block owned by the request. Prefix-shared blocks have ref_cnt > 1 because another running request also uses them. After the decrement they still have ref_cnt ≥ 1, so they stay allocated. Blocks whose ref_cnt hits 0 go to the free queue but keep their hashes. On re-admit they can be recovered through a prefix cache hit, unless another allocation evicts them first. That is why "recompute only" is cheap in practice.
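The following toy model shows the recovery path. It assumes a 4-token block size and uses a tuple of all prefix tokens as a hypothetical stand-in for vLLM's chained block hashes. After preemption num_computed_tokens is 0. Full blocks whose keys are still cached match again on re-admit, so only the remainder is recomputed. The cap at len(tokens) - 1 mirrors vLLM's behaviour of always recomputing at least the last token so it can produce logits.

```python
BLOCK_SIZE = 4  # tokens per block (hypothetical)


def full_block_keys(tokens: list[int]) -> list[tuple[int, ...]]:
    # The key for block i covers every token up to the end of block i, so it
    # identifies the whole prefix, a stand-in for chained block hashes.
    n_full = len(tokens) // BLOCK_SIZE
    return [tuple(tokens[: (i + 1) * BLOCK_SIZE]) for i in range(n_full)]


def readmit(tokens: list[int], cached: set[tuple[int, ...]]) -> tuple[int, int]:
    """Return (prefix-cache hit tokens, tokens to recompute) after preemption."""
    max_hit = len(tokens) - 1  # always recompute at least the last token
    hit = 0
    for key in full_block_keys(tokens):
        if len(key) > max_hit or key not in cached:
            break  # the first miss ends the prefix match
        hit = len(key)
    return hit, len(tokens) - hit


tokens = list(range(10))            # 2 full blocks + a 2-token partial block
keys = full_block_keys(tokens)
print(readmit(tokens, set(keys)))   # (8, 2): both full blocks survived
print(readmit(tokens, {keys[0]}))   # (4, 6): block 1 was evicted meanwhile
```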
When allocate_slots() returns None for a running request, the scheduler enters a preemption loop. Under PRIORITY scheduling the victim may already have been scheduled this step, so the scheduler has to roll back the budget and encoder-cache reservations it made for that victim:
```python
# scheduler.py lines 462-506 -- Full preemption loop with priority rollback
while True:
    new_blocks = self.kv_cache_manager.allocate_slots(
        request,
        num_new_tokens,
        num_lookahead_tokens=self.num_lookahead_tokens,
    )
    if new_blocks is not None:
        break  # Success!

    # ---- OOM PATH ----
    if self.policy == SchedulingPolicy.PRIORITY:
        # Victim = highest (priority, arrival_time) currently running.
        preempted_req = max(
            self.running,
            key=lambda r: (r.priority, r.arrival_time),
        )
        self.running.remove(preempted_req)
        # If that victim was already scheduled earlier in THIS step,
        # roll back the reservations it made so token_budget and the
        # encoder compute budget reflect reality.
        if preempted_req in scheduled_running_reqs:
            pid = preempted_req.request_id
            scheduled_running_reqs.remove(preempted_req)
            token_budget += num_scheduled_tokens.pop(pid)
            req_to_new_blocks.pop(pid)
            scheduled_spec_decode_tokens.pop(pid, None)
            preempted_enc = scheduled_encoder_inputs.pop(pid, None)
            if preempted_enc:
                encoder_compute_budget += sum(
                    preempted_req.get_num_encoder_embeds(i) for i in preempted_enc
                )
            req_index -= 1  # Victim sat before req_index; later entries shifted left
    else:
        # FCFS: pop tail (LIFO -- youngest in running is evicted first).
        preempted_req = self.running.pop()

    self._preempt_request(preempted_req, scheduled_timestamp)
    preempted_reqs.append(preempted_req)
    if preempted_req == request:
        # We preempted ourselves -- nothing more to free; give up.
        break
```
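A victim found in scheduled_running_reqs was scheduled earlier in this step, so it sits at a lower req_index than the request we're trying to schedule. Removing it shrinks self.running, and every later index shifts left by one. req_index -= 1 keeps the while loop pointing at the right next candidate.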
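Victim selection and the index fix-up can be simulated with a hypothetical Req dataclass. As in the PRIORITY branch above, a larger priority value means less important, and the victim is the max of (priority, arrival_time). pick_victim() and adjust_index() are hypothetical helpers. The second one expresses the req_index -= 1 rule.

```python
from dataclasses import dataclass


@dataclass(eq=False)
class Req:
    rid: str
    priority: int        # larger value = less important (evicted first)
    arrival_time: float


def pick_victim(running: list[Req], policy: str) -> tuple[Req, int]:
    if policy == "priority":
        pos = max(range(len(running)),
                  key=lambda i: (running[i].priority, running[i].arrival_time))
    else:
        pos = len(running) - 1  # FCFS: the tail (most recently admitted) goes first
    return running.pop(pos), pos


def adjust_index(req_index: int, victim_pos: int) -> int:
    # Removing an entry before req_index shifts later entries left by one.
    return req_index - 1 if victim_pos < req_index else req_index


running = [Req("a", 0, 1.0), Req("b", 2, 2.0), Req("c", 1, 3.0), Req("d", 2, 0.5)]
req_index = 3                                   # currently scheduling "d"
victim, pos = pick_victim(running, "priority")  # "b" was scheduled earlier
req_index = adjust_index(req_index, pos)
print(victim.rid, req_index, running[req_index].rid)  # b 2 d
```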
| Aspect | vLLM V0 | vLLM V1 |
|---|---|---|
| Preemption strategy | Swap to CPU or recompute | Recompute only |
| KV blocks after preempt | May survive in CPU swap | All freed (ref_cnt decremented) |
| Recovery path | Swap back or recompute | Prefix cache hit + recompute remainder |
| num_computed_tokens | Preserved if swapped | Reset to 0 |
Speculative decoding asks a cheap draft model (or EAGLE head) to propose num_spec_tokens future tokens, then the target model verifies them in one forward pass. The scheduler's job is to (a) make room for draft tokens during schedule() and (b) claw back the num_computed_tokens it optimistically advanced when drafts are rejected.
```python
# scheduler.py lines 520-536 -- Spec decode packing inside RUNNING loop
if request.spec_token_ids:
    # How many of the drafts actually got a slot this step?
    num_scheduled_spec_tokens = (
        num_new_tokens
        + request.num_computed_tokens
        - request.num_tokens
        - request.num_output_placeholders
    )
    if num_scheduled_spec_tokens > 0:
        spec_token_ids = request.spec_token_ids
        if len(spec_token_ids) > num_scheduled_spec_tokens:
            # Token budget capped us below num_spec_tokens -- trim trailing drafts.
            spec_token_ids = spec_token_ids[:num_scheduled_spec_tokens]
        scheduled_spec_decode_tokens[request.request_id] = spec_token_ids
    # Drafts are consumed; update_draft_token_ids() will attach fresh ones.
    request.spec_token_ids = []
```
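The draft-trimming arithmetic can be replayed with concrete numbers. scheduled_drafts() is a hypothetical helper that copies the formula above. The example assumes a decode step where the unconstrained num_new_tokens is 1 (the last sampled token) plus the number of drafts.

```python
def scheduled_drafts(num_new_tokens: int, num_computed_tokens: int,
                     num_tokens: int, num_output_placeholders: int,
                     spec_token_ids: list[int]) -> list[int]:
    """Drafts that fit in this step's allocation (mirrors the snippet above)."""
    n = (num_new_tokens + num_computed_tokens
         - num_tokens - num_output_placeholders)
    if n <= 0:
        return []  # budget only covered the last verified token
    return spec_token_ids[:n]  # trim trailing drafts if the budget was short


# 100 verified tokens, 99 already in the KV cache, 3 drafts pending.
# Unconstrained, num_new_tokens = 1 (last sampled token) + 3 drafts = 4.
print(scheduled_drafts(4, 99, 100, 0, [7, 8, 9]))  # [7, 8, 9]
print(scheduled_drafts(2, 99, 100, 0, [7, 8, 9]))  # [7]: budget capped at 2
```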
update_from_output() runs after the model has sampled. The number of generated tokens minus 1 is the number of accepted drafts. The remaining drafts were rejected, so the num_computed_tokens advance made for them must be undone, and those positions are recomputed in a later step:
```python
# scheduler.py lines 1366-1390 -- spec decode accept / reject accounting
scheduled_spec_token_ids = (
    scheduler_output.scheduled_spec_decode_tokens.get(req_id)
)
if scheduled_spec_token_ids and generated_token_ids:
    num_draft_tokens = len(scheduled_spec_token_ids)
    # One sampled token always comes on top of the accepted drafts (the bonus
    # token, or the target model's own token at the first rejected position).
    num_accepted = len(generated_token_ids) - 1
    num_rejected = num_draft_tokens - num_accepted
    # Roll back the optimistic advance that _update_after_schedule() applied.
    if request.num_computed_tokens > 0:
        request.num_computed_tokens -= num_rejected
    if request.num_output_placeholders > 0:
        request.num_output_placeholders -= num_rejected
    spec_decoding_stats = self.make_spec_decoding_stats(
        spec_decoding_stats,
        num_draft_tokens=num_draft_tokens,
        num_accepted_tokens=num_accepted,
        num_invalid_spec_tokens=scheduler_output.num_invalid_spec_tokens,
        request_id=req_id,
    )
```
num_computed_tokens and num_output_placeholders are both decremented. In async scheduling the placeholders are what kept a slot reserved for the spec tokens; stripping them here means the next schedule() call won't mistakenly believe the request has already produced those outputs.
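The rollback can be traced with concrete numbers. rollback_after_verify() is a hypothetical helper that reproduces the accounting above. The example continues from a step that scheduled the last verified token plus 3 drafts, so _update_after_schedule() had advanced num_computed_tokens from 99 to 103.

```python
def rollback_after_verify(num_computed_tokens: int, num_output_placeholders: int,
                          num_draft_tokens: int, num_generated: int):
    """Mirror of the accept/reject accounting above (hypothetical helper)."""
    num_accepted = num_generated - 1  # one sampled token beyond accepted drafts
    num_rejected = num_draft_tokens - num_accepted
    if num_computed_tokens > 0:
        num_computed_tokens -= num_rejected
    if num_output_placeholders > 0:
        num_output_placeholders -= num_rejected
    return num_accepted, num_rejected, num_computed_tokens, num_output_placeholders


# 3 drafts were verified; the target accepted 1 and sampled 2 tokens in total.
print(rollback_after_verify(103, 0, 3, 2))  # (1, 2, 101, 0)
```

After the 2 sampled tokens are appended, num_tokens becomes 102 while num_computed_tokens is 101. The next step therefore feeds exactly the newest token, plus any fresh drafts.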
After the target model samples, the proposer produces a new batch of drafts. update_draft_token_ids() (lines 1669-1689) filters drafts against the structured-output grammar (if any) and attaches them to request.spec_token_ids so the next schedule() pass will pick them up.
The scheduler produces a SchedulerOutput dataclass each step, which is sent to worker processes. To minimize IPC cost, it distinguishes between new requests (full data) and cached requests (only the diff).
NewRequestData (sent the first time a request is scheduled) contains the full prompt_token_ids, sampling_params, block_ids, mm_features, lora_request, and num_computed_tokens. Workers cache this data so it does not need to be resent.
CachedRequestData (sent on later steps) contains only:
- new_block_ids, appended to the existing blocks;
- num_computed_tokens, updated;
- resumed_req_ids, the set of requests that were preempted and are resuming. These need full block_ids replacement rather than append.
```python
# output.py -- SchedulerOutput
@dataclass
class SchedulerOutput:
    scheduled_new_reqs: list[NewRequestData]
    scheduled_cached_reqs: CachedRequestData
    num_scheduled_tokens: dict[str, int]  # req_id -> num tokens this step
    total_num_scheduled_tokens: int  # sum of above
    scheduled_spec_decode_tokens: dict[str, list[int]]
    scheduled_encoder_inputs: dict[str, list[int]]
    num_common_prefix_blocks: list[int]  # for cascade attention
    finished_req_ids: set[str]
    free_encoder_mm_hashes: list[str]
    preempted_req_ids: set[str] | None = None
    new_block_ids_to_zero: list[int] | None = None
    kv_connector_metadata: KVConnectorMetadata | None = None
    # ... additional fields for structured output, spec decode, etc.
```
```python
# scheduler.py lines 914-930 -- Final assembly
scheduler_output = SchedulerOutput(
    scheduled_new_reqs=new_reqs_data,
    scheduled_cached_reqs=cached_reqs_data,
    num_scheduled_tokens=num_scheduled_tokens,
    total_num_scheduled_tokens=total_num_scheduled_tokens,
    scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
    scheduled_encoder_inputs=scheduled_encoder_inputs,
    num_common_prefix_blocks=num_common_prefix_blocks,
    preempted_req_ids={req.request_id for req in preempted_reqs},
    finished_req_ids=self.finished_req_ids,
    free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
    new_block_ids_to_zero=new_block_ids_to_zero,
)
```
The _make_cached_request_data() method (scheduler.py lines 1055-1113) efficiently packs only the diff. There are three subtle tricks:
```python
# scheduler.py lines 1072-1103 -- _make_cached_request_data() core loop
for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
    req_id = req.request_id
    req_ids.append(req_id)

    # TRICK 1: In PP (pipeline parallelism) without async scheduling,
    # sampled tokens have to travel scheduler -> first-stage worker,
    # so we ship just the chunk being fed this step (excluding drafts).
    if self.use_pp and not self.scheduler_config.async_scheduling:
        num_tokens = num_scheduled_tokens[req_id] - len(
            spec_decode_tokens.get(req_id, ())
        )
        token_ids = req.all_token_ids[
            req.num_computed_tokens : req.num_computed_tokens + num_tokens
        ]
        new_token_ids.append(token_ids)

    # TRICK 2: Distinguish "new in running list" vs "previously scheduled".
    scheduled_in_prev_step = req_id in self.prev_step_scheduled_req_ids
    if idx >= num_running_reqs:
        assert not scheduled_in_prev_step
        resumed_req_ids.add(req_id)  # This is a resumed (preempted) request

    # TRICK 3: Only ship full all_token_ids for entries worker doesn't have.
    # Running continuations that were in prev_step_scheduled_req_ids get
    # only block-id + num_computed_tokens on the wire.
    if not scheduled_in_prev_step:
        all_token_ids[req_id] = req.all_token_ids.copy()

    new_block_ids.append(
        req_to_new_blocks[req_id].get_block_ids(allow_none=True)
    )
    num_computed_tokens.append(req.num_computed_tokens)
    num_output_tokens.append(req.num_output_tokens + req.num_output_placeholders)
```
For a running request that was also scheduled in the previous step, only new_block_ids (new blocks appended this step) plus num_computed_tokens cross the process boundary. The worker already has the prompt and earlier output tokens cached in its input batch. Resumed (preempted → re-admitted) requests go to resumed_req_ids so the worker replaces its block table rather than appending to it.
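The three-way classification in the loop can be replayed on plain request IDs. pack_cached() is a hypothetical helper. It returns which IDs land in resumed_req_ids and which get a full all_token_ids copy. A running request can be missing from prev_step_scheduled_req_ids if, for example, the token budget ran out before the previous step reached it.

```python
def pack_cached(running_ids: list[str], resumed_ids: list[str],
                prev_step_ids: set[str]) -> tuple[set[str], set[str]]:
    """Return (resumed_req_ids, ids that get a full all_token_ids copy)."""
    resumed, full_tokens = set(), set()
    num_running = len(running_ids)
    for idx, rid in enumerate(running_ids + resumed_ids):
        in_prev = rid in prev_step_ids
        if idx >= num_running:
            assert not in_prev         # resumed reqs were not scheduled last step
            resumed.add(rid)           # worker replaces its block table
        if not in_prev:
            full_tokens.add(rid)       # worker's cached copy may be stale or missing
    return resumed, full_tokens


# "a" ran last step, "b" is running but was skipped last step, "c" is resuming.
resumed, full = pack_cached(["a", "b"], ["c"], {"a"})
print(sorted(resumed), sorted(full))  # ['c'] ['b', 'c']
```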
After the output is built but before schedule() returns, _update_after_schedule() advances num_computed_tokens for every scheduled request. Doing it after building the output preserves the original token counts used to compute input indices; doing it before the model has actually run lets the next schedule() call treat this chunk as already computed -- vital for multi-step chunked prefill.
```python
# scheduler.py lines 978-1010 -- _update_after_schedule()
def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None:
    num_scheduled_tokens = scheduler_output.num_scheduled_tokens
    for req_id, num_scheduled_token in num_scheduled_tokens.items():
        request = self.requests[req_id]
        request.num_computed_tokens += num_scheduled_token

        # Recompute is_prefill_chunk: are we still in the prefill regime?
        request.is_prefill_chunk = request.num_computed_tokens < (
            request.num_tokens + request.num_output_placeholders
        )
        # Flag is forwarded so the worker can gate structured-output logits work.
        scheduler_output.has_structured_output_requests |= (
            request.use_structured_output and not request.is_prefill_chunk
        )

        if request.has_encoder_inputs:
            self._free_encoder_inputs(request)

    # NOTE: Don't .clear() -- the SchedulerOutput still holds the set object.
    self.finished_req_ids = set()
```
For speculative decoding this advance is optimistic. update_from_output() (lines 1369-1383) later subtracts the rejected-draft count from num_computed_tokens so the next step re-schedules those positions.
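The following sketch shows the optimistic advance, using a hypothetical Req dataclass with only the fields the update touches. A 1,000-token prompt split into 512-token chunks shows why advancing before the model runs matters. The next schedule() call already sees 488 tokens left and can plan the final chunk.

```python
from dataclasses import dataclass


@dataclass
class Req:
    num_computed_tokens: int
    num_tokens: int
    num_output_placeholders: int = 0
    is_prefill_chunk: bool = True


def update_after_schedule(requests: dict[str, Req],
                          num_scheduled_tokens: dict[str, int]) -> None:
    for req_id, n in num_scheduled_tokens.items():
        r = requests[req_id]
        r.num_computed_tokens += n  # optimistic: the model has not run yet
        r.is_prefill_chunk = r.num_computed_tokens < (
            r.num_tokens + r.num_output_placeholders
        )


reqs = {"p": Req(num_computed_tokens=0, num_tokens=1000)}
update_after_schedule(reqs, {"p": 512})          # step 1: first chunk
remaining = reqs["p"].num_tokens - reqs["p"].num_computed_tokens
print(remaining, reqs["p"].is_prefill_chunk)     # 488 True
update_after_schedule(reqs, {"p": remaining})    # step 2: final chunk
print(reqs["p"].is_prefill_chunk)                # False
```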
The KVCacheCoordinator sits between the scheduler-facing KVCacheManager and the per-layer-type SingleTypeKVCacheManager instances. Three concrete coordinators exist, chosen by the model's attention configuration:
| Coordinator | When Used | Prefix Caching |
|---|---|---|
| KVCacheCoordinatorNoPrefixCache | Prefix caching disabled or unsupported specs | No |
| UnitaryKVCacheCoordinator | Single KV cache group (e.g., all full attention) | Yes |
| HybridKVCacheCoordinator | Multiple cache groups (e.g., full + sliding window) | Yes, with fixed-point algorithm |
```python
# kv_cache_coordinator.py -- Factory function
def get_kv_cache_coordinator(...) -> KVCacheCoordinator:
    if not enable_caching:
        return KVCacheCoordinatorNoPrefixCache(...)
    if len(kv_cache_config.kv_cache_groups) == 1:
        return UnitaryKVCacheCoordinator(...)
    return HybridKVCacheCoordinator(...)
```
For models with mixed attention types (e.g., full attention + sliding window), each attention type may support a different cache hit length. The hybrid coordinator uses an iterative algorithm where each type either accepts or reduces the candidate length:
```python
# kv_cache_coordinator.py -- HybridKVCacheCoordinator.find_longest_cache_hit()
while True:
    curr_hit_length = hit_length
    for spec, group_ids, manager_cls in self.attention_groups:
        hit_blocks = manager_cls.find_longest_cache_hit(
            block_hashes=_get_block_hashes(spec),
            max_length=curr_hit_length,
            kv_cache_group_ids=group_ids,
            block_pool=self.block_pool,
            kv_cache_spec=spec,
            use_eagle=self.use_eagle,
            alignment_tokens=self.lcm_block_size,
        )
        curr_hit_length = len(hit_blocks[0]) * spec.block_size
    if curr_hit_length >= hit_length:
        break  # Converged -- no type reduced the length
    hit_length = curr_hit_length
```
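The fixed point can be reproduced with two toy managers. This is a hypothetical sketch. Each manager is a function from a maximum length to its longest hit length, and cache state is a list of per-block booleans. The real code's block hashes and lcm_block_size alignment are omitted. The sliding-window manager here accepts a hit end only if the window_blocks blocks just before it are all cached.

```python
from typing import Callable

BLOCK = 4  # tokens per block (hypothetical)
HitFn = Callable[[int], int]  # max_length -> longest hit length <= max_length


def full_attn(cached: list[bool]) -> HitFn:
    # Full attention needs every block from the start: longest cached prefix.
    def find_hit(max_length: int) -> int:
        n = 0
        while n < max_length // BLOCK and n < len(cached) and cached[n]:
            n += 1
        return n * BLOCK
    return find_hit


def sliding_window(cached: list[bool], window_blocks: int) -> HitFn:
    # Only the last window_blocks blocks before the hit end must be cached;
    # scan candidate ends from right to left.
    def find_hit(max_length: int) -> int:
        for end in range(min(max_length // BLOCK, len(cached)), 0, -1):
            if all(cached[max(0, end - window_blocks):end]):
                return end * BLOCK
        return 0
    return find_hit


def hybrid_hit_length(max_length: int, managers: list[HitFn]) -> int:
    hit_length = max_length
    while True:
        curr = hit_length
        for find_hit in managers:
            curr = find_hit(curr)  # each type may only shrink the candidate
        if curr >= hit_length:
            return hit_length      # converged: no type reduced the length
        hit_length = curr


full_cached = [True, True, True, True, True, False]
swa_cached = [False, False, True, True, False, True]  # early blocks already dropped
managers = [full_attn(full_cached), sliding_window(swa_cached, window_blocks=2)]
print(hybrid_hit_length(24, managers))  # 16
```

Full attention alone would accept 20 tokens. The sliding-window group cuts that to 16, and a second pass confirms that neither type shrinks 16 further.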
allocate_slots() is the most complex method in the KV cache system. A single call handles prefix cache hits, sliding-window cleanup, external (P/D disaggregation) tokens, and speculative-decode lookahead. Here is the block layout it operates on:
```python
# kv_cache_manager.py -- Block layout diagram from the source:
# ----------------------------------------------------------
# | <comp> | <new_comp> | <ext_comp> | <new> | <lookahead> |
# ----------------------------------------------------------
#                                    |   <to be computed>  |
# ----------------------------------------------------------
#                       |        <to be allocated>         |
# ----------------------------------------------------------
# comp      = request.num_computed_tokens
# new_comp  = num_new_computed_tokens (prefix cache hits this step)
# ext_comp  = num_external_computed_tokens (from P/D connector)
# new       = num_new_tokens (tokens to compute this step)
# lookahead = num_lookahead_tokens (for spec decode)
```
```python
# kv_cache_manager.py lines 257-267 -- full allocate_slots signature
def allocate_slots(
    self,
    request: Request,
    num_new_tokens: int,
    num_new_computed_tokens: int = 0,  # local prefix cache hits
    new_computed_blocks: KVCacheBlocks | None = None,
    num_lookahead_tokens: int = 0,  # spec decode drafts
    num_external_computed_tokens: int = 0,  # P/D connector hits
    delay_cache_blocks: bool = False,  # async KV load path
    num_encoder_tokens: int = 0,  # cross-attn (Whisper etc.)
) -> KVCacheBlocks | None:
    ...
```
The allocation has three phases. The key invariants: (a) no new request state is mutated until we've proven allocation will succeed, and (b) blocks are cached after allocation so the hash table never holds dangling references.
Phase 1 (size the request): calculate the total slot count (comp + new_comp + ext_comp + new + lookahead), clipped to max_model_len. Before checking capacity, prune blocks that have fallen outside the sliding window; they can be freed even if the allocation is later abandoned.
```python
# kv_cache_manager.py lines 352-385 -- Phase 1 details
num_local_computed_tokens = (
    request.num_computed_tokens + num_new_computed_tokens
)
total_computed_tokens = min(
    num_local_computed_tokens + num_external_computed_tokens,
    self.max_model_len,
)
num_tokens_main_model = total_computed_tokens + num_new_tokens
num_tokens_need_slot = min(
    num_tokens_main_model + num_lookahead_tokens,
    self.max_model_len,
)

# Sliding-window cleanup BEFORE checking capacity, which may reduce
# num_blocks_to_allocate and save us from a spurious OOM.
self.coordinator.remove_skipped_blocks(
    request.request_id, total_computed_tokens
)

num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
    request_id=request.request_id,
    num_tokens=num_tokens_need_slot,
    new_computed_blocks=new_computed_block_list,
    num_encoder_tokens=num_encoder_tokens,
    total_computed_tokens=num_local_computed_tokens + num_external_computed_tokens,
    num_tokens_main_model=num_tokens_main_model,
)
```
Phase 2 (capacity check): if the required block count exceeds what the pool can serve, return None so the scheduler knows to preempt a victim and retry:
```python
# kv_cache_manager.py lines 387-389
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
    return None  # OOM -- caller preempts and retries
```
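For a single full-attention group, the Phase 1 and Phase 2 arithmetic reduces to a few lines. This sketch is hypothetical. It ignores sliding-window pruning, encoder tokens, and the extra capacity that free cache-hit blocks consume. It treats num_blocks_assigned as the number of blocks the request already holds.

```python
def cdiv(a: int, b: int) -> int:
    return -(-a // b)  # ceiling division for non-negative ints


def num_tokens_need_slot(num_computed: int, num_new_computed: int,
                         num_external: int, num_new: int,
                         num_lookahead: int, max_model_len: int) -> int:
    # Phase 1: comp + new_comp + ext_comp (clipped), then + new + lookahead (clipped).
    total_computed = min(num_computed + num_new_computed + num_external, max_model_len)
    return min(total_computed + num_new + num_lookahead, max_model_len)


def num_blocks_to_allocate(need_slot: int, num_blocks_assigned: int,
                           block_size: int) -> int:
    return max(0, cdiv(need_slot, block_size) - num_blocks_assigned)


def fits(num_blocks: int, num_free_blocks: int) -> bool:
    # Phase 2: refuse before touching request state; the caller then preempts.
    return num_blocks <= num_free_blocks


# Decode step with 3 lookahead slots: 1023 computed tokens in 64 blocks of 16.
slots = num_tokens_need_slot(1023, 0, 0, 1, 3, max_model_len=4096)
need = num_blocks_to_allocate(slots, num_blocks_assigned=64, block_size=16)
print(slots, need, fits(need, num_free_blocks=0))  # 1027 1 False
```

A single lookahead slot that spills past a block boundary is enough to need a fresh block. When the pool is empty, that triggers the preemption loop.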
Phase 3 (commit) runs in three steps:
1. Attach any prefix-cache hits. touch() raises their ref_cnt and removes them from the free queue.
2. Pull fresh blocks from the pool.
3. Commit hashes for any newly-full blocks:
```python
# kv_cache_manager.py lines 391-425 -- Phase 3 details
if (
    new_computed_block_list is not self.empty_kv_cache_blocks.blocks
    or num_external_computed_tokens > 0
):
    # Attach cache hits and externally-loaded slots; also touches blocks.
    self.coordinator.allocate_new_computed_blocks(
        request_id=request.request_id,
        new_computed_blocks=new_computed_block_list,
        num_local_computed_tokens=num_local_computed_tokens,
        num_external_computed_tokens=num_external_computed_tokens,
    )

new_blocks = self.coordinator.allocate_new_blocks(
    request.request_id,
    num_tokens_need_slot,
    num_tokens_main_model,
    num_encoder_tokens,
)

# P/D: async KV load defers caching until the transfer finishes.
if not self.enable_caching or delay_cache_blocks:
    return self.create_kv_cache_blocks(new_blocks)

# Cap cache commit at request.num_tokens to exclude unverified drafts.
num_tokens_to_cache = min(
    total_computed_tokens + num_new_tokens,
    request.num_tokens,
)
self.coordinator.cache_blocks(request, num_tokens_to_cache)
return self.create_kv_cache_blocks(new_blocks)
```
The min(total_computed + num_new, request.num_tokens) cap is why speculative decoding doesn't pollute the prefix cache. request.num_tokens counts only verified tokens. Draft tokens beyond it occupy slots in blocks but get no hash until they are accepted.
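The effect of the cap can be checked with a hypothetical cache_commit() helper. It assumes, as in vLLM's prefix cache, that only full blocks receive a hash. In the example, num_new includes the drafts scheduled this step.

```python
def cache_commit(total_computed: int, num_new: int, num_verified: int,
                 block_size: int) -> tuple[int, int]:
    """Return (tokens committed, full blocks hashed) under the min() cap."""
    num_tokens_to_cache = min(total_computed + num_new, num_verified)
    # Only full blocks receive a hash; a partial tail block stays unhashed.
    return num_tokens_to_cache, num_tokens_to_cache // block_size


# block_size=4, 8 verified tokens, 7 computed; this step feeds the last
# verified token plus 4 drafts (num_new = 5). Without the cap, block 2
# (tokens 8-11) would be hashed even though it holds only drafts.
print(cache_commit(7, 5, 8, 4))  # (8, 2)
```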
Block ID 0 is reserved as a "null block", a sentinel that represents positions where blocks have been freed (e.g., outside a sliding window). Its ref_cnt is not tracked:
```python
# block_pool.py -- BlockPool.__init__()
self.null_block = self.free_block_queue.popleft()  # Takes block 0
self.null_block.is_null = True
```
When get_new_blocks() pops a block from the free queue, that block might be a cached eviction candidate. The method calls _maybe_evict_cached_block() to clean up its hash from the lookup table:
```python
# block_pool.py -- _maybe_evict_cached_block()
def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
    block_hash = block.block_hash
    if block_hash is None:
        return False  # Not cached, nothing to evict
    if self.cached_block_hash_to_block.pop(block_hash, block.block_id) is None:
        return False  # Already removed from cache
    block.reset_hash()  # Clear the block's hash metadata
    return True
```
reset_prefix_cache() is used in RLHF flows after weight updates to invalidate all cached blocks. It only succeeds if no requests are actively using blocks:
```python
# block_pool.py -- reset_prefix_cache()
def reset_prefix_cache(self) -> bool:
    num_used_blocks = self.num_gpu_blocks - self.get_num_free_blocks()
    if num_used_blocks != 1:  # Only the null block should be "used"
        return False
    self.cached_block_hash_to_block = BlockHashToBlockMap()
    for block in self.blocks:
        block.reset_hash()
    return True
```
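The following toy version of the gating logic assumes a pool of 8 blocks in which block 0 is permanently taken as the null block. MiniPool and its num_free counter are hypothetical. The point is that any block in use beyond the null block makes the reset refuse.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(eq=False)
class Block:
    block_id: int
    block_hash: Optional[str] = None


class MiniPool:
    """Hypothetical pool where block 0 is permanently taken as the null block."""

    def __init__(self, num_blocks: int) -> None:
        self.blocks = [Block(i) for i in range(num_blocks)]
        self.num_free = num_blocks - 1  # everything except the null block
        self.hash_to_block: dict[str, Block] = {}

    def reset_prefix_cache(self) -> bool:
        num_used = len(self.blocks) - self.num_free
        if num_used != 1:  # something besides the null block is in use
            return False
        self.hash_to_block = {}
        for b in self.blocks:
            b.block_hash = None
        return True


pool = MiniPool(8)
pool.blocks[3].block_hash = "h3"
pool.hash_to_block["h3"] = pool.blocks[3]
pool.num_free -= 2                  # a running request holds two blocks
print(pool.reset_prefix_cache())    # False: live requests still read cached KV
pool.num_free += 2                  # request finished; its blocks are free again
print(pool.reset_prefix_cache(), pool.hash_to_block)  # True {}
```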
scheduler.py — Hand-by-Hand Section Map
This section is a companion reading guide. Open one of the two files below on PACE in your editor, then use the tables to jump directly to the line range you want to understand. Every range corresponds to a self-contained functional block; the prose below each table explains what the block does and what invariants it maintains.
Upstream: /storage/home/hcoda1/1/hlin464/r-rs275-0/vllm-pace/vllm/v1/core/sched/scheduler.py (2312 lines)
Fork: /storage/home/hcoda1/1/hlin464/r-rs275-0/vllm-agent-kvcache/vllm/v1/core/sched/scheduler.py (2538 lines, +226 lines, +11 insertion sites)
The file has one class, Scheduler(SchedulerInterface), starting near the top. Everything else is a method on that class. The table below lists every method in the order it appears, with the line range in both files so you can compare.
| Upstream lines | Fork lines | Block / method | What it does (one sentence) |
|---|---|---|---|
| 1–64 | 1–66 | imports + logger | Pull in Request, KVCacheManager, EncoderCacheManager, KVConnector factories, and in the fork two extra modules: sched_trace and ToolCallEstimator. |
| 67–296 | 69–322 | __init__ | Cache all configs, build KVCacheManager / EncoderCacheManager / KVConnector / ECConnector, create waiting / skipped_waiting / running queues. Fork also initializes pinned_requests and tool_call_estimator. |
| — | 328–372 | _pin_request / _unpin_request / _unpin_expired_requests | (Fork-only) Continuum KV-pinning helpers. Pin holds a request’s blocks in VRAM across the tool-call gap of an agentic turn; unpin releases them; the sweep runs at the top of every schedule(). |
| 298–346 | 376–424 | _mamba_block_aligned_split | Hybrid-model (Mamba + attention) helper that rounds chunked-prefill sizes so Mamba state is block-aligned. |
| 348–949 | 426–1142 | schedule() — the heart | Produces one SchedulerOutput per step. Decomposed in Sub-section 3 below. |
| 951–954 | 1144–1147 | _build_kv_connector_meta | Thin wrapper: delegates to connector.build_connector_meta(scheduler_output). |
| 956–976 | 1149–1169 | _preempt_request | Frees a running request’s KV + encoder caches, resets num_computed_tokens to 0, increments num_preemptions, prepends back to waiting queue. |
| 978–1010 | 1171–1203 | _update_after_schedule | Called at the end of every schedule(), before it returns: advances num_computed_tokens by the amount just scheduled, updates is_prefill_chunk, frees encoder inputs when prompt is done, clears finished_req_ids. |
| 1012–1053 | 1205–1246 | _update_request_as_session | Streaming-input session handling: takes the next chunk from the client and re-queues the session as WAITING. |
| 1055–1113 | 1248–1306 | _make_cached_request_data | Build the CachedRequestData for running + resumed reqs — IPC-efficient delta for workers. |
| 1115–1276 | 1308–1469 | _try_schedule_encoder_inputs | Multi-modal encoder scheduling: pick which MM items to encode this step under encoder_compute_budget. |
| 1278–1300 | 1471–1493 | get_grammar_bitmask | Collect per-request structured-output grammar bitmasks for the next forward pass. |
| 1302–1557 | 1495–1750 | update_from_output() | Consume one ModelRunnerOutput from the worker: append sampled tokens, apply spec-decode acceptance/rejection, run stop criteria, publish streaming/pool outputs, trigger _free_request. |
| 1559–1564 | 1752–1757 | _is_blocked_waiting_status | True if status is WAITING_FOR_REMOTE_KVS / FSM_COMPILATION / STREAMING_REQ — i.e. needs a skip-ahead. |
| 1566–1570 | 1759–1763 | _enqueue_waiting_request | Push one request into either self.waiting or self.skipped_waiting depending on its status. |
| 1572–1582 | 1765–1775 | _select_waiting_queue_for_scheduling | Interleave picks between self.waiting (normal) and self.skipped_waiting (previously-deferred) so skipped reqs can catch up. |
| 1584–1600 | 1777–1793 | _handle_stopped_request | Called from update_from_output when a stop condition hits; emits FINISHED event and kicks off cleanup. |
| 1602–1625 | 1795–1818 | _get_routed_experts | MoE helper: reads per-token routed-experts indices from the shared buffer. |
| 1627–1643 | 1820–1836 | _update_request_with_output | Append new sampled token(s) to the request, emit cumulative delta for streaming. |
| 1645–1667 | 1838–1860 | _free_encoder_inputs | Release encoder cache entries once the prompt past them has been computed. |
| 1669–1689 | 1862–1882 | update_draft_token_ids | Fold async-drafter’s newly produced speculative tokens into each running request. |
| 1691–1727 | 1884–1920 | update_draft_token_ids_in_output | Same as above but when spec tokens arrive alongside the main ModelRunnerOutput. |
| 1729–1731 | 1922–1924 | get_request_counts | Return (num_waiting, num_running) for external observers. |
| 1733–1753 | 1926–1951 | add_request | Register a new request. Fork additionally tells the ToolCallEstimator about the arrival under CONTINUUM policy. |
| 1755–1816 | 1953–2014 | finish_requests | External-facing cancel/abort entry; sets status and triggers _free_request. |
| 1818–1834 | 2016–2060 | _free_request | Release encoder + KV blocks for a finished request. Fork interposes Continuum-pin logic here: if the next agent turn is predicted to arrive soon, keep blocks via _pin_request instead of freeing. |
| 1836–1840 | 2062–2065 | _free_blocks | Final step of freeing: tell KVCacheManager to drop blocks and remove from self.requests. |
| 1842–1858 | 2067–2084 | pause_state / set_pause_state / get_num_unfinished_requests | Trivial property accessors for engine pause control and unfinished-count reporting. |
| 1860–1861 | 2086–2087 | has_finished_requests | Returns True if there is anything to flush to clients. |
| 1863–1907 | 2089–2133 | reset_prefix_cache | Wipe prefix cache; gated on no running requests (otherwise stale hits). |
| 1909–1921 | 2135–2147 | reset_connector_cache / reset_encoder_cache | Drop connector / encoder caches (used by eval harnesses and tests). |
| 1931–1967 | 2157–2193 | make_stats | Build the per-step SchedulerStats that Prometheus loggers consume. |
| 1969–1975 | 2195–2201 | _get_encoder_cache_usage | Encoder-cache utilization ratio reported via make_stats. |
| 1977–1994 | 2203–2220 | make_spec_decoding_stats | Aggregate spec-decode accept/reject counts into SpecDecodingStats. |
| 1996–2004 | 2222–2230 | shutdown | Flush KV-event publisher, close the connector. |
| 2006–2008 | 2232–2233 | get_kv_connector | Simple accessor used by other subsystems. |
| 2009–2038 | 2235–2264 | _connector_finished | Ask the KVConnector whether it still needs the blocks (for delayed P/D free). |
| 2040–2076 | 2266–2302 | _update_waiting_for_remote_kv | Finalize a request that was WAITING_FOR_REMOTE_KVS after connector reports done. |
| 2078–2109 | 2304–2335 | _try_promote_blocked_waiting_request | If a WAITING_FOR_X request can now be promoted to WAITING, do so. |
| 2111–2138 | 2337–2364 | _update_from_kv_xfer_finished | Handle the per-step list of finished KV transfers from the connector. |
| 2140–2241 | 2366–2467 | _update_requests_with_invalid_blocks | P/D connector failure recovery: mark invalid blocks, rewind num_computed_tokens. |
| 2243–end | 2469–end | _handle_invalid_blocks | Decide whether to recompute or abort each request hit by invalid blocks. |
__init__ Anatomy
The constructor is long but linear. Here is what each chunk does:
| Upstream | Fork | What this block sets up |
|---|---|---|
| 78–115 | 80–117 | Cache vllm_config + sub-configs (scheduler/cache/lora/parallel/kv_events/observability). Decide max_num_running_reqs, max_num_scheduled_tokens, max_model_len, and KV-event flag. |
| 117–147 | 119–149 | Build self.connector (KV) + self.ec_connector (encoder) via factories. Record KV-load-failure policy. |
| 149–170 | 151–172 | self.block_size, dcp_world_size, pcp_world_size, init self.requests: dict[str, Request], parse policy = SchedulingPolicy(...), create waiting / skipped_waiting queues, init self.running. Fork inserts line 169: self._sched_step_counter = 0. |
| 171–202 | 173–204 | finished_req_ids set, streaming-input counter, async-KV recv/fail sets; MultiModalBudget computation. |
| 203–223 | 205–225 | max_num_encoder_input_tokens, encoder_cache_manager, speculative-decode flags (use_eagle, num_spec_tokens, num_lookahead_tokens). |
| 224–243 | 226–245 | Build self.kv_cache_manager = KVCacheManager(...), then bind the GPU block pool back to the KV connector so it can issue loads/stores against concrete blocks. |
| 244–296 | 247–299 | Misc flags: use_pp, use_v2_model_runner, scheduler_reserve_full_isl, Mamba-hybrid flags, perf_metrics, optional routed-experts reader, and self._pause_state = PauseState.UNPAUSED. |
| — | 300–322 | (Fork-only) Create self.pinned_requests: list[tuple[Request, float]] = [] and, when policy == SchedulingPolicy.CONTINUUM, instantiate self.tool_call_estimator. A broad try/except falls back to ToolCallEstimator(model_name=None) so the scheduler never fails to start. |
schedule() Anatomy — Nine Functional Blocks
This is the single most important method. Read it top-to-bottom in one of the two files while cross-checking the block map below.
| # | Upstream | Fork | Block name | What happens |
|---|---|---|---|---|
| 3.1 | 348–379 | 426–468 | Preamble & local bookkeeping | Declare per-step locals: scheduled_new_reqs, scheduled_resumed_reqs, scheduled_running_reqs, preempted_reqs, req_to_new_blocks, num_scheduled_tokens, scheduled_encoder_inputs, scheduled_spec_decode_tokens, token_budget, encoder_compute_budget, scheduled_timestamp = time.monotonic(). Call self.kv_cache_manager.new_step_starts(). |
| 3.2 | — | 438–446 | (Fork) Continuum pin sweep | Under policy == CONTINUUM: call self._unpin_expired_requests(), then push the current set of pinned job_ids into the waiting queue via waiting.update_pinned_state({job_ids}) so the ContinuumQueue can prioritise turns whose prefixes are still warm. |
| 3.3 | — | 471–522 | (Fork) sched_trace step_snapshot | If sched_trace.enabled(): probe GPU free blocks, GPU cached blocks, CPU tier stats (if SimpleCPUOffloadConnector present), and compute per-pin block counts. Call sched_trace.step_snapshot(...) with everything, then self._sched_step_counter += 1. This is the "before" half of the per-step trace record. |
| 3.4 | 383–551 | 524–667 | RUNNING queue loop | For each request already RUNNING: compute num_new_tokens, cap by long_prefill_token_threshold, token_budget, and max_model_len - 1 - num_computed_tokens. Handle encoder inputs. Call allocate_slots; if it fails, preempt according to policy (fork adds CONTINUUM branch at 640–658 preferring non-pinned victims). Append to scheduled_running_reqs, record blocks, consume budget. |
| 3.5 | 553–850 | 669–1015 | WAITING queue loop | Only runs if no preemptions happened in 3.4 and _pause_state == UNPAUSED. Interleaves self.waiting and self.skipped_waiting, checks prefix-cache hits (local + external via connector), LoRA limit, mamba-alignment, encoder budget, then calls allocate_slots. Each admitted request is appended to self.running and recorded in scheduled_resumed_reqs (if it was PREEMPTED) or scheduled_new_reqs (if it was WAITING). Fork inserts sched_trace prefix_cache_event at 806–817. |
| 3.6 | 852–863 | 1025–1043 | Sanity asserts | assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens, assert token_budget >= 0, assert len(self.running) <= max_num_running_reqs. Fork downgrades the ‘scheduled ≤ running’ assertion to a logger.warning: under CONTINUUM with A100-40GB KV pressure the counts can transiently disagree, and the hard assert was killing experiment runs. |
| 3.7 | 865–902 | 1045–1082 | Common prefix & request-data construction | Compute num_common_prefix_blocks (used for cascade attention), then build new_reqs_data via NewRequestData.from_request(...) and cached_reqs_data via _make_cached_request_data(...). |
| 3.8 | 904–948 | 1084–1128 | Build SchedulerOutput & connector metadata | Clear and refill prev_step_scheduled_req_ids; compute new_block_ids_to_zero when needed; construct the final SchedulerOutput dataclass; attach KVConnector and ECConnector metadata via _build_kv_connector_meta(...) and ec_connector.build_connector_meta(...); call self._update_after_schedule(scheduler_output). |
| 3.9 | — | 1130–1141 | (Fork) sched_trace step_decision | If sched_trace.enabled(), emit sched_trace.step_decision(step=_sched_step_counter - 1, ...) with the final per-request verdict (new/resumed/running/preempted IDs and token counts). This is the "after" half of the per-step trace record; the −1 reconciles the counter that was incremented in 3.3. |
| 3.10 | 949 | 1142 | return scheduler_output | One SchedulerOutput leaves the scheduler and is pickled over to the worker. |
_sched_step_counter is incremented at the end of the snapshot block (3.3), yet later blocks read it as _sched_step_counter - 1. This is intentional. The snapshot in 3.3 records the state entering the step, and the decision in 3.9 records the outcome of that same logical step. Both records must carry the same step number, hence the −1 after the early increment.
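The numbering protocol can be reproduced with a toy class. StepTracer and ToyScheduler are hypothetical stand-ins, not the fork's sched_trace API. Only the counter arithmetic follows the fork code quoted in this guide: the snapshot is written at N, the counter is incremented early, and mid-step events and the decision are written at counter - 1.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class StepTracer:
    """Hypothetical trace sink that just collects records in memory."""

    records: list[dict[str, Any]] = field(default_factory=list)

    def emit(self, kind: str, step: int, **fields: Any) -> None:
        self.records.append({"kind": kind, "step": step, **fields})


class ToyScheduler:
    def __init__(self, tracer: StepTracer) -> None:
        self.tracer = tracer
        self._sched_step_counter = 0

    def schedule(self, req_ids: list[str]) -> None:
        # "Before" half: record the incoming state, then bump the counter early.
        self.tracer.emit("snapshot", self._sched_step_counter, num_waiting=len(req_ids))
        self._sched_step_counter += 1
        # Mid-step events belong to the step in progress, hence the -1.
        for rid in req_ids:
            self.tracer.emit("prefix_cache", self._sched_step_counter - 1, req_id=rid)
        # "After" half: same logical step number as the snapshot.
        self.tracer.emit("decision", self._sched_step_counter - 1, scheduled=list(req_ids))
```

After two calls, every record from the first call carries step 0 and every record from the second carries step 1. A trace reader can therefore group snapshot, events, and decision by step alone.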
update_from_output() Anatomy
This is the mirror of schedule(). schedule() decides what to compute; update_from_output consumes the worker’s answer and threads it back into request state, applying spec-decode acceptance/rejection and stop criteria along the way. The body has seven blocks.
| # | Fork lines | Block name | What happens |
|---|---|---|---|
| 4.1 | 1495–1525 | Unpack ModelRunnerOutput | Read req_ids, req_id_to_index, sampled_token_ids, logprobs, prompt_logprobs, pooler_outputs, optional finished_sending/recving from the KV connector; allocate output containers. |
| 4.2 | 1527–1610 | Per-request result loop | For each scheduled request, read its index into the output tensors, pull the newly sampled token, and append it via _update_request_with_output. Also append prompt logprobs if present. |
| 4.3 | 1612–1660 | Spec-decode accept / reject accounting | If the request had scheduled spec tokens, compare sampled vs drafted; adjust num_computed_tokens down by the rejected count so the next step resumes at the correct position. |
| 4.4 | 1662–1690 | Stop-criteria + structured-output hook | check_stop(request, max_model_len) handles EOS / stop_token_ids / max_tokens. Structured-output FSM advances per-token (for grammar-constrained decode). |
| 4.5 | 1692–1720 | Emit EngineCoreOutputs | Build one EngineCoreOutput per request and group them into one EngineCoreOutputs per client_index. Include the finished_req_ids flush. |
| 4.6 | 1722–1740 | Connector xfer-finished handling | If the worker reported finished_sending / finished_recving, call _update_from_kv_xfer_finished to promote WAITING_FOR_REMOTE_KVS reqs or release sender-side blocks. |
| 4.7 | 1742–1750 | Return | Return the dict[int, EngineCoreOutputs]; EngineCore sends each bucket to the matching client via ZMQ. |
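The rollback in block 4.3 is plain arithmetic. The sketch below makes two assumptions. First, schedule() has already advanced num_computed_tokens by every scheduled token, meaning the last sampled token plus all drafts. Second, the rejection sampler returns the accepted draft prefix plus exactly one extra token. ToyRequest and apply_spec_decode_result are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class ToyRequest:
    # Already advanced by schedule() for every scheduled token, drafts included.
    num_computed_tokens: int


def apply_spec_decode_result(
    req: ToyRequest, draft_token_ids: list[int], generated_token_ids: list[int]
) -> int:
    """Roll num_computed_tokens back by the number of rejected drafts.

    Assumes generated_token_ids = accepted draft prefix + one extra token
    (the corrected token on a rejection, or a bonus token if all drafts pass).
    Returns the number of rejected drafts.
    """
    num_draft = len(draft_token_ids)
    num_accepted = len(generated_token_ids) - 1
    num_rejected = num_draft - num_accepted
    if req.num_computed_tokens > 0:
        # KV written for rejected drafts is invalid; recompute from here next step.
        req.num_computed_tokens -= num_rejected
    return num_rejected
```

Take a request at position 100 scheduled with 3 drafts. After scheduling, num_computed_tokens is 104. If only the first draft is accepted, the worker returns 2 tokens, 2 drafts are rejected, and the request resumes at 102.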
When you are inside schedule() and see a call like self._foo(...), use this quick jump table:
| Call you see | Jump to (fork line) | Lives in |
|---|---|---|
| self.kv_cache_manager.new_step_starts() | — | vllm/v1/core/kv_cache_manager.py |
| self.kv_cache_manager.allocate_slots(...) | — | vllm/v1/core/kv_cache_manager.py |
| self.kv_cache_manager.get_computed_blocks(...) | — | vllm/v1/core/kv_cache_manager.py |
| self.connector.get_num_new_matched_tokens(...) | — | vllm/distributed/kv_transfer/kv_connector/v1/ |
| self._preempt_request(r, ts) | 1149–1169 | Same file |
| self._try_schedule_encoder_inputs(...) | 1308–1469 | Same file |
| self._mamba_block_aligned_split(...) | 376–424 | Same file |
| self._make_cached_request_data(...) | 1248–1306 | Same file |
| self._update_after_schedule(...) | 1171–1203 | Same file |
| self._build_kv_connector_meta(...) | 1144–1147 | Same file |
| self._try_promote_blocked_waiting_request(...) | 2304–2335 | Same file |
| self._is_blocked_waiting_status(...) | 1752–1757 | Same file |
| self._select_waiting_queue_for_scheduling() | 1765–1775 | Same file |
| self._unpin_expired_requests() (fork) | 358–372 | Same file |
| self._pin_request(r, pin_ttl) (fork) | 328–346 | Same file |
| sched_trace.step_snapshot(...) (fork) | — | vllm/v1/core/sched/trace.py |
| sched_trace.prefix_cache_event(...) (fork) | — | vllm/v1/core/sched/trace.py |
| sched_trace.step_decision(...) (fork) | — | vllm/v1/core/sched/trace.py |
| self.tool_call_estimator.request_arrives(r) (fork) | — | vllm/v1/core/estimate_with_func.py |
| self.tool_call_estimator.request_finished(r) (fork) | — | vllm/v1/core/estimate_with_func.py |
| self.tool_call_estimator.set_up_pin(r) (fork) | — | vllm/v1/core/estimate_with_func.py |
If you git pull or edit either scheduler.py, use grep -n 'def schedule' vllm/v1/core/sched/scheduler.py to re-anchor. The block names are stable; only the line numbers change.
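You might prefer a parser to grep, for example to refresh every "Same file" entry in the jump table at once. Python's standard ast module reports the def line of each method. find_method_lines is a hypothetical helper, not part of vLLM.

```python
import ast


def find_method_lines(source: str, class_name: str) -> dict[str, int]:
    """Map each method of class_name to the 1-based line of its 'def'."""
    tree = ast.parse(source)
    for node in ast.walk(tree):
        if isinstance(node, ast.ClassDef) and node.name == class_name:
            return {
                item.name: item.lineno  # def line, not the decorator line (3.8+)
                for item in node.body
                if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
            }
    raise LookupError(f"class {class_name!r} not found")
```

From the repository root, find_method_lines(open("vllm/v1/core/sched/scheduler.py").read(), "Scheduler")["schedule"] should give the same line as the grep above.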
The fork at vllm-agent-kvcache keeps the upstream scheduling contract intact and layers on two independent features:
Feature A: Continuum KV pinning. When a turn of a multi-step agent finishes and a tool call is about to run (2–10 s on the wire), don’t free the request’s KV blocks. Instead, pin them for a predicted TTL so the next turn reuses the prefix without recomputing. Requires policy == CONTINUUM; a no-op for FCFS/PRIORITY.
Feature B: scheduler tracing (sched_trace). Emit a JSONL record for every call to schedule(): queue sizes, free/cached block counts (both GPU and CPU tiers), per-request prefix-cache hit/miss, and the final scheduling decision. Fully off when the VLLM_SCHED_TRACE_PATH env var is unset, so production cost is a single if sched_trace.enabled() check.
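Here is a minimal sketch of the env-gated JSONL pattern. Only the VLLM_SCHED_TRACE_PATH variable and the step_snapshot / step_decision names come from this guide. The record layout, the "event" key, and opening the file once per record are assumptions, not the fork's sched_trace implementation.

```python
import json
import os
from typing import Any

_TRACE_ENV = "VLLM_SCHED_TRACE_PATH"  # env var named in this guide


def enabled() -> bool:
    """True only when a trace path is configured."""
    return bool(os.environ.get(_TRACE_ENV))


def _write(record: dict[str, Any]) -> None:
    # Append one JSON object per line; default=str stringifies non-JSON values.
    with open(os.environ[_TRACE_ENV], "a", encoding="utf-8") as f:
        f.write(json.dumps(record, default=str) + "\n")


def step_snapshot(step: int, **fields: Any) -> None:
    _write({"event": "step_snapshot", "step": step, **fields})


def step_decision(step: int, **fields: Any) -> None:
    _write({"event": "step_decision", "step": step, **fields})
```

Reading the environment on every call keeps the sketch easy to test. A production module would more likely resolve the path once at import time and keep the file handle open.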
Change 1: two new imports
Inserted between upstream lines 52 and 53:
```python
# Fork scheduler.py lines 53-54 (no upstream equivalent)
from vllm.v1.core.sched import trace as sched_trace
from vllm.v1.core.estimate_with_func import ToolCallEstimator
```
Both modules live in the fork’s own tree. sched_trace is a small observability module; ToolCallEstimator is a tokenizer-backed predictor for “how long will the next tool call take?”
Change 2: self._sched_step_counter: int = 0
One line inserted at fork line 169 (no upstream equivalent), right after self.waiting = create_request_queue(...). Used to number trace records; see Changes 6 and 10.
Change 3: new state at the end of __init__
Inserted after upstream line 296 (just before the end of __init__); fork lines 300–322. No upstream equivalent.
```python
# Fork scheduler.py lines 301-322
# Continuum: track requests whose KV blocks are pinned across
# tool-call gaps. Each entry is (request, unpin_wall_clock_time).
self.pinned_requests: list[tuple[Request, float]] = []

# Continuum: tool-call execution time predictor for dynamic pin TTL.
# Only instantiated when policy == CONTINUUM to avoid the tokenizer
# load cost on other scheduling modes.
self.tool_call_estimator: ToolCallEstimator | None = None
if self.policy == SchedulingPolicy.CONTINUUM:
    mc = self.vllm_config.model_config
    model_name = (
        getattr(mc, "tokenizer", None)
        or getattr(mc, "model", None)
        or getattr(mc, "served_model_name", None)
    )
    try:
        self.tool_call_estimator = ToolCallEstimator(model_name=model_name)
    except Exception as e:
        logger.warning(
            "Continuum: failed to init ToolCallEstimator (%s); falling back "
            "to no-pin mode.",
            e,
        )
        self.tool_call_estimator = ToolCallEstimator(model_name=None)
```
hf-mirror). A no-name estimator still works — it just returns a default TTL. This is more important than it looks: failing to start the scheduler kills the entire engine.
Change 4: three helper methods (_pin_request, _unpin_request, _unpin_expired_requests)
Inserted between upstream lines 297 and 298 (between the end of __init__ and _mamba_block_aligned_split); fork lines 328–372.
```python
# Fork scheduler.py lines 328-346
def _pin_request(self, request: Request, pin_ttl: float = 10.0) -> None:
    """Pin request's KV blocks in VRAM for pin_ttl seconds.

    Evict any stale pin for the same job_id first so each job holds at
    most one pinned turn at a time. Shared prefix blocks are refcounted,
    so freeing the previous pin decrements without losing reusable KV.
    """
    if getattr(request, "job_id", None):
        stale = [(r, t) for r, t in self.pinned_requests
                 if getattr(r, "job_id", None) == request.job_id]
        for r, t in stale:
            self._unpin_request(r, t)
    self.pinned_requests.append((request, time.time() + pin_ttl))

# Fork scheduler.py lines 348-356
def _unpin_request(self, request: Request, end_time: float) -> None:
    self.pinned_requests.remove((request, end_time))
    self.kv_cache_manager.free(request)
    del self.requests[request.request_id]

# Fork scheduler.py lines 358-372
def _unpin_expired_requests(self) -> None:
    now = time.time()
    expired = [(req, t) for req, t in self.pinned_requests if now >= t]
    for req, t in expired:
        logger.debug("Continuum: unpinning expired request %s (job_id=%s)",
                     req.request_id, req.job_id)
        self._unpin_request(req, t)
```
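At most one pin per job_id. _pin_request evicts the job’s stale pin before appending, so the list never holds two entries for the same job. Freeing a pin decrements the reference count of each of its blocks. Blocks still referenced by another request (typically the shared prefix, now used by the next turn) stay allocated. Blocks that drop to zero go back to the free queue with their hashes intact, so they remain prefix-cache hits until they are actually evicted.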
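The refcount argument can be checked with a toy model. Block and ToyPool are hypothetical and far simpler than vLLM's BlockPool. They share two behaviours with it. A freed block keeps its hash until it is reused. A request's blocks are released tail first, so suffix blocks become eviction candidates before the prefix.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(eq=False)  # identity equality: each Block is a distinct physical block
class Block:
    block_id: int
    ref_cnt: int = 0
    block_hash: int | None = None  # survives free until the block is reused


class ToyPool:
    """Toy refcounted prefix cache; not vLLM's BlockPool."""

    def __init__(self) -> None:
        self.cached: dict[int, Block] = {}   # block_hash -> block
        self.free: list[Block] = []          # eviction would take from the front
        self.req_blocks: dict[str, list[Block]] = {}

    def touch(self, req_id: str, blocks: list[Block]) -> None:
        """Give req_id a reference to each block (fresh allocation or cache hit)."""
        for b in blocks:
            if b.ref_cnt == 0 and b in self.free:
                self.free.remove(b)  # a hit on a free block pulls it back out
            b.ref_cnt += 1
            if b.block_hash is not None:
                self.cached[b.block_hash] = b
        self.req_blocks.setdefault(req_id, []).extend(blocks)

    def free_request(self, req_id: str) -> None:
        """Drop req_id's references, tail first, so suffix blocks are evicted first."""
        for b in reversed(self.req_blocks.pop(req_id)):
            b.ref_cnt -= 1
            if b.ref_cnt == 0:
                self.free.append(b)  # still findable via self.cached
```

Suppose the next turn already references the shared prefix. Then freeing the expired pin only moves the old suffix block to the free list, and even that block stays reachable through its hash.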
Change 5: pin sweep at the top of schedule()
Inserted at fork lines 438–446, immediately after the method’s docstring and before the per-step locals are declared. No upstream equivalent.
```python
# Fork scheduler.py lines 438-446
if self.policy == SchedulingPolicy.CONTINUUM:
    self._unpin_expired_requests()
    if hasattr(self.waiting, "update_pinned_state"):
        self.waiting.update_pinned_state(
            {req.job_id for req, _ in self.pinned_requests if req.job_id}
        )
```
Why here: TTL expiry has to fire before any allocation decisions this step. Otherwise an expired pin’s blocks could still be treated as occupied and trigger spurious preemption. The update_pinned_state hand-off lets ContinuumQueue reorder waiting requests by whether their prefix is still warm.
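A queue could use the update_pinned_state hint by serving warm jobs first while keeping FCFS order within each group. WarmFirstQueue and WaitingReq are hypothetical. The fork's ContinuumQueue is not shown in this guide and may order requests differently.

```python
from __future__ import annotations

from collections import deque
from dataclasses import dataclass


@dataclass
class WaitingReq:
    request_id: str
    job_id: str | None


class WarmFirstQueue:
    """Hypothetical waiting queue: requests of pinned jobs first, FCFS otherwise."""

    def __init__(self) -> None:
        self._q: deque[WaitingReq] = deque()
        self._pinned: set[str] = set()

    def add_request(self, req: WaitingReq) -> None:
        self._q.append(req)

    def update_pinned_state(self, job_ids: set[str]) -> None:
        # Called once per step with the job_ids that currently hold a pin.
        self._pinned = set(job_ids)

    def pop_request(self) -> WaitingReq:
        # Oldest request whose job is warm wins; otherwise plain FCFS.
        for i, req in enumerate(self._q):
            if req.job_id in self._pinned:
                del self._q[i]
                return req
        return self._q.popleft()
```

The linear scan is O(n) per pop. That is fine for a sketch, but a real queue with many waiting requests would keep warm and cold requests in separate deques.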
Change 6: sched_trace.step_snapshot at the top of schedule()
Inserted at fork lines 471–522. Probes the block pool for free and cached block counts and the KV connector for CPU-tier stats, then computes per-pin block counts. Everything is wrapped in if sched_trace.enabled():, so with tracing off the only cost is that one check.
```python
# Fork scheduler.py lines 471-522 (condensed)
if sched_trace.enabled():
    try:
        free_blocks = self.kv_cache_manager.block_pool.get_num_free_blocks()
    except AttributeError:
        free_blocks = None
    try:
        gpu_cached_blocks = len(self.kv_cache_manager.block_pool.cached_block_hash_to_block)
    except Exception:
        gpu_cached_blocks = None
    # Probe CPU offload connector (SimpleCPUOffloadConnector)...
    # (per-pin block counting into pinned_blocks is also elided here)
    sched_trace.step_snapshot(
        step=self._sched_step_counter,
        running=self.running,
        waiting=self.waiting,
        free_blocks=free_blocks,
        total_blocks=self.cache_config.num_gpu_blocks,
        gpu_cached_blocks=gpu_cached_blocks,
        cpu_free_blocks=...,
        cpu_total_blocks=...,
        cpu_cached_blocks=...,
        num_pinned=len(self.pinned_requests),
        pinned_blocks=pinned_blocks,
        pinned_job_ids=[req.job_id for req, _ in self.pinned_requests if req.job_id],
        max_num_running_reqs=self.max_num_running_reqs,
        max_num_scheduled_tokens=self.max_num_scheduled_tokens,
    )
    self._sched_step_counter += 1
```
Change 7: CONTINUUM preemption branch
Inserted at fork lines 640–658, between the upstream PRIORITY branch and the fallback FCFS branch. No upstream equivalent.
```python
# Fork scheduler.py lines 640-658
# (excerpt: follows the upstream PRIORITY `if` branch, precedes the FCFS `else`)
elif self.policy == SchedulingPolicy.CONTINUUM:
    # Continuum: prefer to preempt requests that do NOT
    # belong to a pinned job, preserving warm KV blocks.
    pinned_job_ids = {req.job_id for req, _ in self.pinned_requests if req.job_id}
    unpinned = [r for r in self.running
                if getattr(r, "job_id", None) not in pinned_job_ids]
    if unpinned:
        preempted_req = unpinned[-1]
        self.running.remove(preempted_req)
    else:
        # No unpinned candidates; fall back to FCFS order.
        preempted_req = self.running.pop()
```
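Running requests whose job currently holds a warm pin are the last candidates for preemption, so KV pressure is absorbed by unrelated requests first. Pinned turns have already finished and are not in self.running, so this branch never evicts a pin directly. In the code shown here, a pin is released only by TTL expiry or by a newer pin for the same job. If every running request belongs to a pinned job, the branch falls back to upstream FCFS LIFO.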
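Here is the victim rule in isolation, as a standalone function. pick_preemption_victim and RunningReq are hypothetical names. The selection logic matches the quoted branch: take the newest request whose job_id is not pinned, else running.pop(). The difference is that a single reverse scan replaces the list comprehension plus list.remove().

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class RunningReq:
    request_id: str
    job_id: str | None = None


def pick_preemption_victim(running: list[RunningReq], pinned_job_ids: set[str]) -> RunningReq:
    """Remove and return the preemption victim from running (mutates the list)."""
    # Newest request first, matching unpinned[-1] in the fork's branch.
    for i in range(len(running) - 1, -1, -1):
        if running[i].job_id not in pinned_job_ids:
            return running.pop(i)
    # Every running request belongs to a pinned job: upstream FCFS LIFO.
    return running.pop()
```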
Change 8: sched_trace.prefix_cache_event
Inserted at fork lines 806–817, after the num_computed_tokens assertion in the WAITING-queue loop. Records per-request prefix-cache telemetry.
```python
# Fork scheduler.py lines 806-817
if sched_trace.enabled():
    sched_trace.prefix_cache_event(
        step=self._sched_step_counter - 1,
        req_id=request.request_id,
        job_id=getattr(request, "job_id", None),
        local_hit_tokens=num_new_local_computed_tokens,
        external_hit_tokens=num_external_computed_tokens,
        total_prompt_tokens=request.num_prompt_tokens,
        num_tokens=request.num_tokens,
        connector_name=type(self.connector).__name__ if self.connector is not None else None,
        load_kv_async=load_kv_async,
    )
```
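With tracing on, the prefix_cache_event fields are enough to compute a per-job cache hit ratio. The function below assumes each event has been loaded as a dict keyed by the same names passed to prefix_cache_event (job_id, local_hit_tokens, external_hit_tokens, num_tokens). The actual on-disk JSONL layout is an assumption. It divides by num_tokens rather than total_prompt_tokens because a resumed request's cached tokens can extend into its generated tokens.

```python
from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable
from typing import Any


def prefix_hit_ratio_by_job(events: Iterable[dict[str, Any]]) -> dict[str | None, float]:
    """Fraction of tokens served from cache (local + external), grouped by job_id."""
    hits: defaultdict[str | None, int] = defaultdict(int)
    totals: defaultdict[str | None, int] = defaultdict(int)
    for ev in events:
        job = ev.get("job_id")  # None groups requests that carry no job_id
        hits[job] += ev["local_hit_tokens"] + ev["external_hit_tokens"]
        totals[job] += ev["num_tokens"]
    return {job: hits[job] / totals[job] for job in totals if totals[job] > 0}
```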
Change 9: assert scheduled <= running downgraded to a warning
Upstream (lines 861–863):
```python
# Upstream scheduler.py lines 861-863 (HARD ASSERT)
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
    scheduled_running_reqs
) <= len(self.running)
```
Fork (lines 1033–1043):
```python
# Fork scheduler.py lines 1033-1043 (SOFT WARNING)
# len(self.running). Downgraded to warning from assert: under heavy KV
# pressure (A100 40GB + Continuum pin) the scheduler may transiently
# mis-account; a hard assert kills the whole run.
_sched_count = (
    len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs)
)
if _sched_count > len(self.running):
    logger.warning(
        "Scheduler count mismatch: scheduled=%d > running=%d (trimming to avoid crash).",
        _sched_count,
        len(self.running),
    )
```
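self.running removal). Despite the “(trimming to avoid crash)” wording in the log message, the fork does not trim anything. Replace the logger.warning with actual trimming only if you can reproduce a persistent mismatch.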
Change 10: sched_trace.step_decision before return
Inserted at fork lines 1130–1141, after _update_after_schedule and right before the final return scheduler_output.
```python
# Fork scheduler.py lines 1130-1141
if sched_trace.enabled():
    # _sched_step_counter was incremented after the snapshot; the
    # decision applies to the SAME logical step, so subtract 1.
    sched_trace.step_decision(
        step=self._sched_step_counter - 1,
        scheduled_new_req_ids=[r.request_id for r in scheduled_new_reqs],
        scheduled_resumed_req_ids=[r.request_id for r in scheduled_resumed_reqs],
        scheduled_running_req_ids=[r.request_id for r in scheduled_running_reqs],
        preempted_req_ids=[r.request_id for r in preempted_reqs],
        num_scheduled_tokens=num_scheduled_tokens,
        total_scheduled_tokens=total_num_scheduled_tokens,
    )
```
Change 11: estimator hook in add_request
Inserted at fork lines 1943–1947, inside the "non-streaming, non-resumable" branch of add_request.
```python
# Fork scheduler.py lines 1943-1947
if (
    self.policy == SchedulingPolicy.CONTINUUM
    and self.tool_call_estimator is not None
):
    self.tool_call_estimator.request_arrives(request)
```
Tells the estimator the wall-clock time at which the next turn of this job_id arrived, so it can learn the actual tool-call latency distribution for that job.
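Here is a hypothetical learner for the tool-call gap the text describes: the time between one turn finishing and the next turn of the same job arriving. ToyToolCallEstimator is not the fork's tokenizer-backed ToolCallEstimator. It takes job IDs instead of Request objects, and its default prior, EMA weight, and TTL margin are arbitrary assumptions. Only the 2.0 s cutoff mirrors FIXED_THRESHOLD_CONTINUUM.

```python
from __future__ import annotations

import time
from dataclasses import dataclass, field


@dataclass
class ToyToolCallEstimator:
    """Hypothetical per-job latency learner (not the fork's ToolCallEstimator)."""

    threshold_s: float = 2.0  # mirrors FIXED_THRESHOLD_CONTINUUM (2.0 s)
    default_s: float = 1.0    # assumed prior for a job with no history
    alpha: float = 0.5        # EMA weight of the newest observed gap
    margin: float = 1.5       # assumed TTL slack over the prediction
    _finished_at: dict[str, float] = field(default_factory=dict, init=False, repr=False)
    _ema: dict[str, float] = field(default_factory=dict, init=False, repr=False)

    def request_finished(self, job_id: str, now: float | None = None) -> None:
        # Turn ended; the tool call starts now.
        self._finished_at[job_id] = time.time() if now is None else now

    def request_arrives(self, job_id: str, now: float | None = None) -> None:
        # Next turn arrived; the gap since the last finish is the tool-call time.
        start = self._finished_at.pop(job_id, None)
        if start is None:
            return  # first turn of this job: nothing to learn yet
        gap = (time.time() if now is None else now) - start
        prev = self._ema.get(job_id)
        self._ema[job_id] = gap if prev is None else self.alpha * gap + (1 - self.alpha) * prev

    def set_up_pin(self, job_id: str) -> float:
        """Pin TTL in seconds, or 0.0 when the next turn is predicted to come too late."""
        predicted = self._ema.get(job_id, self.default_s)
        if predicted > self.threshold_s:
            return 0.0
        return predicted * self.margin
```

A 1 s gap keeps the job pinnable. One 5 s gap pulls the moving average to 3 s, above the cutoff, so set_up_pin returns 0.0 and the blocks are freed immediately.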
Change 12: pin-on-finish in _free_request (the heart of Continuum)
Inserted at fork lines 2022–2048, right after _connector_finished and before the normal free path.
```python
# Fork scheduler.py lines 2022-2048
# Continuum: for non-final turns of a multi-step agentic job, consult
# the tool-call estimator to decide whether pinning the KV blocks is
# worthwhile. set_up_pin() returns 0 when the predicted tool-call
# exec time exceeds FIXED_THRESHOLD_CONTINUUM (2.0s), in which case
# the next turn won't arrive in time to reuse the pin, so we free
# blocks immediately like FCFS would.
if (
    self.policy == SchedulingPolicy.CONTINUUM
    and not request.is_last_step
    and request.job_id
    and self.tool_call_estimator is not None
):
    self.tool_call_estimator.request_finished(request)
    pin_ttl = self.tool_call_estimator.set_up_pin(request)
    if pin_ttl > 0.0:
        self.encoder_cache_manager.free(request)
        request_id = request.request_id
        self.finished_req_ids.add(request_id)
        if self.finished_req_ids_dict is not None:
            self.finished_req_ids_dict[request.client_index].add(request_id)
        self._pin_request(request, pin_ttl=pin_ttl)
        logger.debug(
            "Continuum: pinning request %s (job_id=%s) for %.2fs",
            request.request_id, request.job_id, pin_ttl,
        )
        return kv_xfer_params
```
The main gates on the pin path:

- policy == CONTINUUM: otherwise it is a no-op and falls through to the normal free path.
- not request.is_last_step: the final answer doesn’t need a pin; pinning here would waste VRAM.
- request.job_id truthy: no job_id means no notion of a “next turn”.
- pin_ttl > 0.0: estimator returned zero ⇒ predicted tool-call time > 2 s ⇒ the pin would be evicted by TTL anyway, so free now.

| # | Upstream anchor | Fork lines | Change | Feature |
|---|---|---|---|---|
| 1 | after 52 | 53–54 | +2 imports | A + B |
| 2 | after 166 | 169 | +_sched_step_counter | B |
| 3 | after 296 | 300–322 | +pinned_requests, +tool_call_estimator | A |
| 4 | after 297 | 328–372 | +3 helper methods | A |
| 5 | inside schedule() | 438–446 | Pin sweep + waiting-queue hint | A |
| 6 | inside schedule() | 471–522 | step_snapshot probe + counter bump | B |
| 7 | inside preempt loop | 640–658 | CONTINUUM preemption branch | A |
| 8 | inside WAITING loop | 806–817 | prefix_cache_event | B |
| 9 | replaces 861–863 | 1033–1043 | Assert → warning | stability |
| 10 | before return | 1130–1141 | step_decision | B |
| 11 | inside add_request | 1943–1947 | estimator.request_arrives | A |
| 12 | inside _free_request | 2022–2048 | Pin-on-finish (the heart) | A |
policy == CONTINUUM or sched_trace.enabled(), so pulling in upstream fixes should not conflict with the fork’s behaviour.
| File | Key Classes / Functions | Responsibility |
|---|---|---|
| sched/scheduler.py | Scheduler.schedule() | Main scheduling loop: running phase, waiting phase, preemption, output construction |
| sched/output.py | SchedulerOutput, NewRequestData, CachedRequestData | IPC-optimized scheduler-to-worker communication protocol |
| kv_cache_manager.py | KVCacheManager | Scheduler-facing facade for block allocation, prefix cache, and freeing |
| kv_cache_coordinator.py | UnitaryKVCacheCoordinator, HybridKVCacheCoordinator | Coordinate multi-group cache hits, dispatch to per-type managers |
| single_type_kv_cache_manager.py | FullAttentionManager, SlidingWindowManager | Per-attention-type block management, find_longest_cache_hit(), req_to_blocks |
| block_pool.py | BlockPool, BlockHashToBlockMap | GPU block pool with free queue, hash-to-block cache, eviction logic |
| kv_cache_utils.py | KVCacheBlock, FreeKVCacheBlockQueue, hash_block_tokens() | Block metadata, O(1) doubly-linked free list, chained hash computation |
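The "chained hash computation" listed for kv_cache_utils.py means each full block's hash folds in its parent's hash. Equal tokens at the same block index therefore only match when the whole prefix before them matches too. Below is a simplified analogue of hash_block_tokens(). SHA-256 over a pickled tuple is an assumption for the sketch; vLLM's hash function is configurable and its serialization may differ.

```python
from __future__ import annotations

import hashlib
import pickle


def chain_hash_block(parent_hash: bytes | None, block_tokens: tuple[int, ...],
                     extra_keys: tuple = ()) -> bytes:
    """Hash one full block, folding in the parent block's hash."""
    payload = pickle.dumps((parent_hash, block_tokens, extra_keys), protocol=5)
    return hashlib.sha256(payload).digest()


def block_hashes(token_ids: list[int], block_size: int) -> list[bytes]:
    """Hash every *full* block; a trailing partial block is not cached."""
    hashes: list[bytes] = []
    parent: bytes | None = None
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        parent = chain_hash_block(parent, tuple(token_ids[start:start + block_size]))
        hashes.append(parent)
    return hashes
```

Because block i's hash depends on every earlier block, a lookup can walk the hashes left to right and stop at the first miss. That is what makes the longest-prefix search a sequence of O(1) dictionary probes.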
Prefill and decode share one token budget. A single scheduling loop handles all request types uniformly.
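Here is a stripped-down sketch of that single loop. It covers only the caps listed for block 3.4: long_prefill_token_threshold, token_budget, and max_model_len - 1 - num_computed_tokens. Req and schedule_step are hypothetical names. Block allocation, encoder inputs, spec tokens, and preemption are left out.

```python
from dataclasses import dataclass


@dataclass
class Req:
    req_id: str
    num_tokens: int           # prompt + generated so far
    num_computed_tokens: int  # tokens whose KV is already cached


def schedule_step(running: list[Req], token_budget: int, max_model_len: int,
                  long_prefill_token_threshold: int = 0) -> dict[str, int]:
    """One pass over running requests sharing a single token budget.

    Prefill chunks and decodes are handled identically: each request asks for
    num_tokens - num_computed_tokens new tokens and gets as many as fit.
    """
    scheduled: dict[str, int] = {}
    for req in running:
        if token_budget == 0:
            break
        num_new = req.num_tokens - req.num_computed_tokens
        if 0 < long_prefill_token_threshold < num_new:
            num_new = long_prefill_token_threshold  # cap one long prefill chunk
        num_new = min(num_new, token_budget)
        num_new = min(num_new, max_model_len - 1 - req.num_computed_tokens)
        if num_new <= 0:
            continue
        scheduled[req.req_id] = num_new
        token_budget -= num_new
    return scheduled
```

Take a 512-token budget with a decode, a fresh 1000-token prompt, and a second decode. The first decode gets 1 token and the prompt takes the remaining 511, starving the second decode. Setting long_prefill_token_threshold=256 caps the chunk so both decodes fit.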
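Block IDs are never reordered within a request; new blocks are only appended. Because existing entries never move, each step only has to ship the newly appended block IDs to the worker, which extends its block table in place instead of rebuilding it.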
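V1 dropped CPU swap: preemption frees a request's blocks and later recomputes its KV. Recomputation is usually cheap because freed blocks keep their hashes until they are evicted, so a resumed request can often recover much of its prefix from the cache.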
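The free queue embeds its linked-list pointers (prev_free_block, next_free_block) in KVCacheBlock itself. No per-block wrapper node is allocated, so append, popleft, and remove allocate nothing and add no GC pressure.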