A code-level deep dive into SGLang's scheduler architecture: the two event loops, continuous batching, overlap scheduling innovation, request lifecycle management, chunked prefill, and data-parallel dispatch.
SGLang's serving engine is organized around a Scheduler that manages one or more tensor-parallel GPU workers. The scheduler runs a tight event loop on the CPU, receiving requests, constructing batches, dispatching forward passes, and processing results. The key data flow documented in the source is:
"""
The following is the flow of data structures for a batch:
ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
- ScheduleBatch is managed by `scheduler.py::Scheduler`.
It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
It is a subset of `ScheduleBatch` that only contains data related
to the model forward on GPU. It will be transformed from CPU scheduler
to GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
It contains low-level tensor data. Most of the data consists of GPU tensors.
"""
TokenizedGenerateReqInput ScheduleBatch ModelWorkerBatch ForwardBatch
======================== ===================== ======================= ==================
| rid, input_ids, | | reqs: List[Req] | | forward_mode | | GPU tensors: |
| sampling_params, | -> | forward_mode | ->| input_ids (tensor) | ->| input_ids |
| mm_inputs, ... | | seq_lens, out_cache | | seq_lens (tensor) | | positions |
======================== | sampling_info | | sampling_info | | attn metadata |
(from tokenizer) ===================== ======================= ==================
(CPU-side) (CPU->GPU bridge) (GPU-side)
recv_requests() get_next_batch_to_run() get_model_worker_batch() ForwardBatch.init_new()
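This narrowing can be sketched with toy dataclasses (hypothetical, heavily simplified field subsets; the real classes carry far more state) — the point is that each hop keeps a narrower, more GPU-oriented subset of the batch:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Req:
    rid: str
    input_ids: List[int]

@dataclass
class ModelWorkerBatch:           # CPU -> GPU bridge: only forward-relevant data
    forward_mode: str
    input_ids: List[int]          # flattened across requests
    seq_lens: List[int]

@dataclass
class ScheduleBatch:              # CPU-side scheduling state
    reqs: List[Req]
    forward_mode: str

    def get_model_worker_batch(self) -> ModelWorkerBatch:
        # Drop per-request Python objects; keep only what the forward needs.
        flat = [t for r in self.reqs for t in r.input_ids]
        return ModelWorkerBatch(self.forward_mode, flat,
                                [len(r.input_ids) for r in self.reqs])

batch = ScheduleBatch([Req("a", [1, 2, 3]), Req("b", [4, 5])], "EXTEND")
mwb = batch.get_model_worker_batch()
print(mwb.input_ids, mwb.seq_lens)   # [1, 2, 3, 4, 5] [3, 2]
```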
The Scheduler class itself is built via mixin composition:
class Scheduler(
SchedulerOutputProcessorMixin,
SchedulerUpdateWeightsMixin,
SchedulerProfilerMixin,
SchedulerMetricsMixin,
SchedulerDisaggregationDecodeMixin,
SchedulerDisaggregationPrefillMixin,
SchedulerMultiplexMixin,
SchedulerRuntimeCheckerMixin,
SchedulerPPMixin,
SchedulerDPAttnMixin,
SchedulerDllmMixin,
):
"""A scheduler that manages a tensor parallel GPU worker."""
SGLang provides two event loop implementations. The dispatcher selects between them
at startup based on the server configuration: pipeline parallelism and PD multiplexing
get dedicated loops, and server_args.disable_overlap_schedule controls whether the
overlap loop is used:
def dispatch_event_loop(scheduler: Scheduler):
server_args = scheduler.server_args
disaggregation_mode = scheduler.disaggregation_mode
if disaggregation_mode == DisaggregationMode.NULL:
if scheduler.enable_pdmux:
scheduler.event_loop_pdmux()
elif server_args.pp_size > 1:
scheduler.event_loop_pp()
elif scheduler.enable_overlap:
scheduler.event_loop_overlap() # <-- overlap scheduling (DEFAULT)
else:
scheduler.event_loop_normal() # <-- simple sequential loop
The normal loop is synchronous -- it waits for the GPU forward to finish before doing any CPU processing for the next batch:
@DynamicGradMode()
def event_loop_normal(self):
"""A normal scheduler loop."""
while True:
# Receive requests
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
if self._engine_paused:
self.cancel_bubble_timer()
continue
# Get the next batch to run
batch = self.get_next_batch_to_run()
self.cur_batch = batch
# Launch the current batch
if batch:
result = self.run_batch(batch)
self.process_batch_result(batch, result)
else:
# When the server is idle, do self-check and re-init some states.
self.self_check_during_idle()
# Update last_batch
self.last_batch = batch
The overlap loop is the default and recommended mode. It pipelines CPU processing of the previous batch's results with the GPU forward of the current batch. This is the key performance innovation in SGLang:
@DynamicGradMode()
def event_loop_overlap(self):
"""A scheduler loop that overlaps CPU processing and GPU computation."""
self.result_queue: Deque[
Tuple[ScheduleBatch, Union[GenerationBatchResult, EmbeddingBatchResult]]
] = deque()
def pop_and_process():
# Process the results of the last batch
tmp_batch, tmp_result = self.result_queue.popleft()
self.process_batch_result(tmp_batch, tmp_result)
while True:
# Receive requests
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
if self._engine_paused:
continue
# Get the next batch to run
batch = self.get_next_batch_to_run()
self.cur_batch = batch
disable_overlap_for_batch = self.is_disable_overlap_for_batch(batch)
# If we do not need to overlap the current batch with the last batch,
# we can process the last batch immediately.
if disable_overlap_for_batch:
pop_and_process()
# Launch the current batch
if batch:
batch_result = self.run_batch(batch)
self.result_queue.append((batch.copy(), batch_result))
else:
batch_result = None
self.cancel_bubble_timer()
# Process the last batch (while current batch runs on GPU)
if self.last_batch:
if not disable_overlap_for_batch:
pop_and_process()
elif batch is None:
self.self_check_during_idle()
# Run sample of the current batch (depends on last batch's grammar)
if self.is_generation:
self.launch_batch_sample_if_needed(batch_result)
self.last_batch = batch
The crucial property of this loop is that process_batch_result() for batch N
happens after batch N+1 has already been launched on the GPU via run_batch().
The CPU does useful work (result processing, token sampling, streaming output)
while the GPU is computing, hiding the CPU overhead almost entirely.
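The deferral pattern at the heart of the overlap loop can be traced with a plain deque (illustrative event names, no real GPU work):

```python
from collections import deque

# A toy trace of the overlap loop: batch N's results are processed only
# after batch N+1 has been launched. Events are recorded so the
# interleaving can be inspected.

events = []
result_queue = deque()
last_batch = None

for batch in ["N", "N+1", "N+2"]:
    events.append(f"launch {batch}")          # run_batch() on the GPU stream
    result_queue.append((batch, f"result({batch})"))
    if last_batch is not None:
        done, _ = result_queue.popleft()      # pop_and_process()
        events.append(f"process {done}")      # runs while `batch` is on the GPU
    last_batch = batch

print(events)
# ['launch N', 'launch N+1', 'process N', 'launch N+2', 'process N+1']
```

Note that one result ("N+2") is still queued when the trace ends — in the real loop it would be processed on the next iteration.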
The overlap scheduling mechanism uses CUDA streams and a FutureMap to decouple the CPU scheduler from GPU execution. The core idea: the GPU can start working on the next batch while the CPU processes the previous batch's output tokens.
NORMAL EVENT LOOP (Sequential):
================================================================
CPU : [recv] [get_batch] [ wait ] [process_result] [recv] [get_batch] ...
GPU : [====== forward N ======] [=== forward N+1 ===]
Time : ---t0----------t1--t2-----------------------t3--------------t4-----t5-----------t6------------------->
^^^^^ ^^^^^^^^^
CPU idle CPU idle
(GPU done) (GPU done)
OVERLAP EVENT LOOP (Pipelined):
================================================================
CPU : [recv] [get_batch_N+1] [process_result_N] [sample_N+1] [recv] [get_batch_N+2] [process_result_N+1]
GPU : [======= forward N ========][======== forward N+1 =========][======= forward N+2 ==========]
Time : ---t0----t1---------t2------t3-----------t4---------t5------t6---------t7------t8-----------t9----->
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
CPU work overlaps with GPU work!
No CPU idle time between batches.
Legend: [recv] = recv_requests() + process_input_requests()
[get_batch] = get_next_batch_to_run()
[process_result] = process_batch_result() -- token checks, streaming output
[forward] = GPU model forward pass
[sample] = launch_batch_sample_if_needed() -- delayed sampling
The overlap mode uses separate CUDA streams. The scheduler constructs the batch on the
schedule_stream, then launches the forward on the forward_stream. Result tokens
are stored as "futures" in a FutureMap -- negative indices that get resolved later
when the GPU forward completes:
def init_overlap(self):
self.device_module = torch.get_device_module(self.device)
self.forward_stream_ctx = self.device_module.stream(self.forward_stream)
self.copy_stream = self.device_module.Stream()
self.copy_stream_ctx = self.device_module.stream(self.copy_stream)
if not self.enable_overlap:
self.future_map = None
return
self.future_map = FutureMap(
self.max_running_requests,
self.chunked_prefill_size,
self.model_config.context_len,
self.device,
self.spec_algorithm,
)
self.batch_record_buf = [None] * 2
self.batch_record_ct = 0
class FutureMap:
def __init__(self, max_running_requests, chunked_prefill_size,
context_len, device, spec_algo=None):
self.future_ct = 0
# Circular buffer layout (wraps in this order):
# Running decode batch -> Prefill chunk 1 -> ... -> Prefill chunk N
max_num_chunks = (
(context_len + chunked_prefill_size - 1) // chunked_prefill_size
if chunked_prefill_size else 0
)
self.future_limit = max_running_requests * (3 + max_num_chunks)
self.future_buffer_len = self.future_limit + 2 * max_running_requests
self.token_ids_buf = torch.empty(
(self.future_buffer_len,), dtype=torch.int64, device=device
)
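A toy FutureMap sketches the negative-index convention (plain Python lists stand in for the GPU token_ids_buf; indices are 1-based here purely to avoid the -0 ambiguity, and the circular wrap-around is omitted):

```python
# Slot indices are handed out before the forward finishes, encoded as
# negative token IDs, and resolved later by indexing into the buffer.

class ToyFutureMap:
    def __init__(self, buffer_len: int):
        self.future_ct = 0
        self.token_ids_buf = [0] * buffer_len

    def alloc_future_indices(self, bs: int):
        # 1-based so -idx is always strictly negative (avoids the -0 case).
        start = self.future_ct + 1
        self.future_ct += bs
        return list(range(start, start + bs))

    def store_to_map(self, indices, token_ids):
        # Called when the GPU forward produces real next-token IDs.
        for i, t in zip(indices, token_ids):
            self.token_ids_buf[i] = t

    def resolve(self, input_ids):
        # Negative entries are future references: -idx -> buffered token.
        return [self.token_ids_buf[-x] if x < 0 else x for x in input_ids]

fm = ToyFutureMap(16)
idx = fm.alloc_future_indices(2)     # scheduler allocates slots: [1, 2]
futures = [-i for i in idx]          # next batch's input_ids carry [-1, -2]
fm.store_to_map(idx, [101, 202])     # forward completes; tokens land in slots
print(fm.resolve(futures))           # [101, 202]
```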
When overlap is enabled, run_batch() launches the forward on the forward_stream
and returns immediately; the result's GPU-to-CPU copy is issued as a non-blocking
operation tracked by a CUDA event:
if self.enable_overlap:
model_worker_batch = worker_batch_or_batch
self.record_batch_in_overlap(model_worker_batch)
# Sampling info will be modified during forward, so we store a copy.
model_worker_batch.sampling_info = (
model_worker_batch.sampling_info.copy_for_forward()
)
bs = len(model_worker_batch.seq_lens)
future_indices = self.future_map.alloc_future_indices(bs)
with self.forward_stream_ctx, self.record_bubble_metrics(batch):
self.forward_stream.wait_stream(self.schedule_stream)
self.future_map.resolve_future(model_worker_batch)
with self.record_forward_metrics(batch):
batch_result = self.model_worker.forward_batch_generation(
model_worker_batch
)
batch_result.copy_done = self.device_module.Event()
if batch_result.delay_sample_func is None:
self.future_map.store_to_map(future_indices, batch_result)
batch_result.copy_to_cpu(return_logprob=batch.return_logprob)
else:
batch_result.future_indices = future_indices
# Use negative indices as future references
future_indices_or_next_token_ids = -future_indices.indices
Overlap is disabled for consecutive prefill batches (to improve TTFT) and for speculative decoding with grammar constraints:
def is_disable_overlap_for_batch(self, batch):
# For two consecutive prefill batches, disable overlap
# to improve the TTFT of the first batch.
batch_is_extend = is_extend(batch)
last_batch_is_extend = is_extend(self.last_batch)
disable_overlap_for_batch = (
envs.SGLANG_DISABLE_CONSECUTIVE_PREFILL_OVERLAP.get()
and batch_is_extend
and last_batch_is_extend
)
# We do not support overlap + spec + grammar yet
need_grammar_sync = (
batch and batch.is_spec_v2 and batch.has_grammar
and batch.forward_mode.is_decode()
and len(self.result_queue) > 0
)
return disable_overlap_for_batch or need_grammar_sync
This is the core scheduling function. It decides whether to run a prefill (extend) batch for new requests or a decode batch for ongoing generation. Prefill is always prioritized over decode when new requests are waiting:
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
self._abort_on_waiting_timeout()
self._abort_on_running_timeout()
# Merge the prefill batch into the running batch
if self.last_batch and self.last_batch.forward_mode.is_extend():
# Filter completed requests out of the last prefill batch
self.last_batch.filter_batch(
chunked_req_to_exclude=list(chunked_req_to_exclude)
)
# Merge into continuous running batch
if not self.last_batch.is_empty():
if self.running_batch.is_empty():
self.running_batch = self.last_batch
else:
self.running_batch.merge_batch(self.last_batch)
# Try to build a new prefill batch
new_batch = self.get_new_batch_prefill()
if new_batch is not None:
# Run prefill first if possible
ret = new_batch
else:
# Run decode (skip for prefill-only batches)
if (not self.running_batch.is_empty()
and not self.running_batch.is_prefill_only):
self.running_batch = self.update_running_batch(self.running_batch)
ret = self.running_batch if not self.running_batch.is_empty() else None
else:
ret = None
return ret
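The prefill-over-decode priority reduces to a small decision skeleton — a hedged sketch with dict-based requests and a single token budget; the real function also merges the last prefill batch, retracts on OOM, and skips decode for prefill-only batches:

```python
def next_batch(waiting_queue, running_reqs, token_budget):
    can_run = []
    for req in waiting_queue:               # admit new requests first
        if req["input_len"] <= token_budget:
            can_run.append(req)
            token_budget -= req["input_len"]
    if can_run:
        return {"mode": "EXTEND", "reqs": can_run}      # prefill wins
    if running_reqs:
        return {"mode": "DECODE", "reqs": running_reqs}  # otherwise decode
    return None                              # idle: self-check in the real loop

b = next_batch([{"rid": "a", "input_len": 8}], [{"rid": "b"}], token_budget=16)
print(b["mode"])   # EXTEND -- a waiting request takes priority over decode
```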
get_next_batch_to_run()
|
v
+--[Merge last prefill batch into running_batch]--+
| |
v |
get_new_batch_prefill() |
| |
+---> waiting_queue empty |
| OR batch_is_full? ---- YES ------------>+---> update_running_batch()
| | |
NO | +---> filter_batch() -- remove finished
| | +---> check_decode_mem() -- OOM check
v | +---> retract_decode() if OOM
policy.calc_priority(waiting_queue) | +---> prepare_for_decode()
| | |
v | v
PrefillAdder.add_one_req() for each req | Return decode ScheduleBatch
| |
+---> Check token budget |
+---> Match prefix (cache hit) |
+---> Allocate KV cache |
+---> Maybe chunk if too long |
| |
v |
ScheduleBatch.init_new(can_run_list) |
| |
v |
prepare_for_extend() |
| |
v |
Return prefill ScheduleBatch <------(priority)----+
The PrefillAdder is the core component that decides which waiting requests can fit
into a new prefill batch, respecting token budgets, memory limits, and chunked prefill:
class PrefillAdder:
def __init__(self, page_size, tree_cache, token_to_kv_pool_allocator,
running_batch, new_token_ratio, rem_input_tokens,
rem_chunk_tokens, mixed_with_decode_tokens=0, ...):
self.page_size = page_size
self.tree_cache = tree_cache
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.new_token_ratio = new_token_ratio
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
self.rem_chunk_tokens = rem_chunk_tokens
self.can_run_list = [] # Requests that will be in the batch
self.preempt_list = [] # Requests preempted by priority
self.new_chunked_req = None # Request being chunked
self.log_hit_tokens = 0 # Prefix cache hits logged
self.log_input_tokens = 0 # Input tokens logged
def add_one_req(self, req, has_chunked_req, truncation_align_size):
# Calculate total token cost: input + estimated output
total_tokens = req.extend_input_len + min(
max(req.sampling_params.max_new_tokens - len(req.output_ids), 0),
CLIP_MAX_NEW_TOKENS, # Default 4096 -- prevents over-conservatism
)
# Check global token budget
if total_tokens >= self.rem_total_tokens:
return AddReqResult.NO_TOKEN
# Check input token budget
if real_input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
return AddReqResult.OTHER
# Lock the cache node and try adding
with self._lock_node(req.last_node):
if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
# Non-chunked prefill -- add entire request
self.can_run_list.append(req)
else:
# Chunked prefill -- truncate to remaining budget
trunc_len = self.rem_chunk_tokens // self.page_size * self.page_size
req.set_extend_input_len(trunc_len)
self.can_run_list.append(req)
self.new_chunked_req = req # Mark for continuation
return self.budget_state()
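The footprint estimate in add_one_req() is worth spelling out — a sketch using the CLIP_MAX_NEW_TOKENS value from the excerpt above; the function name is illustrative:

```python
# Per-request cost = input tokens + the *remaining* output budget, clipped
# so an enormous max_new_tokens does not make admission overly conservative.

CLIP_MAX_NEW_TOKENS = 4096  # default clip value from the excerpt above

def estimated_total_tokens(extend_input_len, max_new_tokens, num_output_so_far):
    remaining_output = max(max_new_tokens - num_output_so_far, 0)
    return extend_input_len + min(remaining_output, CLIP_MAX_NEW_TOKENS)

# A fresh request with a 100k-token output limit is only charged 4096:
print(estimated_total_tokens(512, 100_000, 0))   # 4608
# A nearly finished request is charged its true remainder:
print(estimated_total_tokens(0, 128, 120))       # 8
```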
After a forward pass completes, process_batch_result() dispatches to mode-specific
handlers. In overlap mode, this runs on the CPU while the next batch is already
executing on the GPU:
def process_batch_result(self, batch, result):
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result)
elif batch.forward_mode.is_extend():
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.process_batch_result_disagg_prefill(batch, result)
else:
self.process_batch_result_prefill(batch, result)
elif batch.forward_mode.is_prebuilt():
self.process_batch_result_prebuilt(batch)
elif batch.forward_mode.is_idle():
self.process_batch_result_idle(batch, result)
self.log_batch_result_stats(batch, result)
self._maybe_clear_mm_inputs(batch)
def process_batch_result_decode(self, batch, result):
if result.copy_done is not None:
result.copy_done.synchronize() # Wait for GPU->CPU copy
logits_output, next_token_ids, can_run_cuda_graph = (
result.logits_output, result.next_token_ids, result.can_run_cuda_graph,
)
next_token_ids = next_token_ids.tolist()
# Process each request in the batch
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
req.output_ids.append(next_token_id) # Append new token
req.check_finished() # Check stop conditions
if req.finished():
release_kv_cache(req, self.tree_cache) # Free KV cache
if req.grammar is not None:
req.grammar.accept_token(next_token_id) # Update grammar state
self.stream_output(batch.reqs, batch.return_logprob)
def process_batch_result_prefill(self, batch, result):
if result.copy_done is not None:
result.copy_done.synchronize()
next_token_ids = result.next_token_ids.tolist()
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
if req.is_chunked <= 0:
# Prefill complete -- append first output token
req.output_ids.append(next_token_id)
req.check_finished()
if req.finished():
release_kv_cache(req, self.tree_cache)
else:
self.tree_cache.cache_unfinished_req(req)
else:
# Still chunking -- decrement counter, don't stream yet
req.is_chunked -= 1
skip_stream_req = req
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
ForwardMode is an IntEnum that controls how the model processes each batch.
The two primary modes are EXTEND (prefill) and DECODE:
class ForwardMode(IntEnum):
# Extend a sequence. The KV cache of the beginning part is already
# computed (e.g., system prompt). Also called "prefill".
EXTEND = auto()
# Decode one token.
DECODE = auto()
# Contains both EXTEND and DECODE when doing chunked prefill.
MIXED = auto()
# No sequence to forward. For DP attention idle workers.
IDLE = auto()
# Used in speculative decoding: verify a batch in the target model.
TARGET_VERIFY = auto()
# Used in speculative decoding: extend a batch in the draft model.
DRAFT_EXTEND = auto()
DRAFT_EXTEND_V2 = auto()
# Used in disaggregated decode worker -- KV cache is pre-built
PREBUILT = auto()
# Split Prefill for PD multiplexing
SPLIT_PREFILL = auto()
# Used in diffusion LLM
DLLM_EXTEND = auto()
def is_extend(self, include_draft_extend_v2=False):
return (self == ForwardMode.EXTEND
or self == ForwardMode.MIXED
or self == ForwardMode.DRAFT_EXTEND
or self == ForwardMode.TARGET_VERIFY
or self == ForwardMode.SPLIT_PREFILL
or self == ForwardMode.DLLM_EXTEND)
| Mode | When Used | Input Shape | KV Cache Behavior |
|---|---|---|---|
| EXTEND | New request's prefill (or continuation after retraction) | Variable-length: all new tokens | Fills KV cache for all input tokens; reuses prefix cache hits |
| DECODE | Autoregressive generation -- one token per request | Exactly 1 token per request | Appends 1 KV entry per request |
| MIXED | Chunked prefill mixed with decode requests | Variable -- both extend and decode in one batch | Both extend and append operations |
| IDLE | DP attention: idle workers with no sequences | Empty | No-op |
def prepare_for_extend(self):
self.forward_mode = ForwardMode.EXTEND
reqs = self.reqs
# Input = all tokens after the cached prefix
input_ids = [r.fill_ids[len(r.prefix_indices):] for r in reqs]
extend_num_tokens = sum(len(ids) for ids in input_ids)
seq_lens = [len(r.fill_ids) for r in reqs]
prefix_lens = [len(r.prefix_indices) for r in reqs]
extend_lens = [r.extend_input_len for r in reqs]
def prepare_for_decode(self):
self.forward_mode = ForwardMode.DECODE
bs = len(self.reqs)
# Decode embeds the last output token via embed_tokens;
# clear stale prefill-time tensor
self.input_embeds = None
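The slicing in prepare_for_extend() can be checked with toy values (hypothetical numbers, no tensors): the forward input is everything in fill_ids after the cached prefix.

```python
fill_ids = [[10, 11, 12, 13, 14], [20, 21, 22]]  # origin_input_ids + output_ids
prefix_lens = [3, 0]   # first request hit a 3-token prefix in the radix cache

input_ids = [ids[plen:] for ids, plen in zip(fill_ids, prefix_lens)]
extend_lens = [len(x) for x in input_ids]
seq_lens = [len(ids) for ids in fill_ids]
extend_num_tokens = sum(extend_lens)

print(input_ids)          # [[13, 14], [20, 21, 22]]
print(extend_lens)        # [2, 3]
print(seq_lens)           # [5, 3]
print(extend_num_tokens)  # 5
```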
Every request (Req) follows a well-defined lifecycle through the scheduler.
The key states are: Waiting, Running (Extend), Running (Decode),
Chunked, Retracted, and Finished.
+-----------+
handle_generate_request() | | abort / error
TokenizedReq ------>------->--------| WAITING |---------->----+
| (queue) | |
+-----+-----+ |
| |
get_new_batch_prefill()| |
PrefillAdder.add_one() | |
v |
+--------+--------+ |
| EXTENDING | |
| (prefill batch) | |
+---+---------+---+ |
| | |
is_chunked > 0| | is_chunked == 0
| | |
v v |
+-------+--+ +--+--------+ |
| CHUNKED | | check_ | |
| (partial | | finished()| |
| prefill)| +--+--------+ |
+----+-----+ | | |
| finished| |not |
next round | v |finished|
prefill | +------+-----+ | |
+------------+ | FINISHED | | |
| (release KV)| | |
+------+-----+ | |
^ v |
| +----+-----+ |
retract_decode() | | DECODING | |
+-------<-------+------+--+ (running | |
| | decode | batch) | |
v | token +----+-----+ |
+-------+------+ | appended | |
| RETRACTED | | | |
| (back to | +------<--------+ |
| waiting) | |
+-------+------+ |
| |
+---------->---------->----->---------+ |
(re-queue) | |
v v
+-----+--+---+
| STREAM |
| OUTPUT |
| (to client) |
+-------------+
class Req(ReqDllmMixin):
"""The input and output status of a request."""
def __init__(self, rid, origin_input_text, origin_input_ids,
sampling_params, ...):
# Identity
self.rid = rid
self.origin_input_ids = origin_input_ids
# Output tracking
self.output_ids = [] # Generated token IDs
self.fill_ids = [] # origin_input_ids + output_ids
# KV cache management
self.kv_committed_len = 0 # Committed KV length
self.kv_allocated_len = 0 # Allocated KV length
self.req_pool_idx = None # Index in req_to_token_pool
# Prefix cache matching
self.prefix_indices = torch.empty((0,), dtype=torch.int64)
self.extend_input_len = 0 # Tokens to run prefill
self.last_node = None # RadixCache node
self.cached_tokens = 0 # Cache hit count
# Finish state
self.finished_reason = None # None = still running
self.to_finish = None # Deferred finish (safe mid-loop)
# Chunked prefill tracking
self.is_chunked = 0 # >0 means still chunking
self.is_retracted = False # Currently retracted
def check_finished(self, new_accepted_len=1):
if self.finished():
return
if self.to_finish:
self.finished_reason = self.to_finish
self.to_finish = None
return
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
self.finished_reason = FINISH_LENGTH(len(self.output_ids))
# ... also checks stop tokens, stop strings, regex matches
def finished(self) -> bool:
return self.finished_reason is not None
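A minimal sketch of the check_finished() logic, assuming dict-based requests and a stand-in FinishLength class (the real code also checks stop tokens, stop strings, and regexes):

```python
class FinishLength:
    def __init__(self, length):
        self.length = length

def check_finished(req):
    if req["finished_reason"] is not None:
        return                                     # already finished
    if req.get("to_finish") is not None:
        req["finished_reason"] = req["to_finish"]  # deferred finish lands here
        req["to_finish"] = None
        return
    if len(req["output_ids"]) >= req["max_new_tokens"]:
        req["finished_reason"] = FinishLength(len(req["output_ids"]))

req = {"output_ids": [1, 2, 3], "max_new_tokens": 3,
       "finished_reason": None, "to_finish": None}
check_finished(req)
print(type(req["finished_reason"]).__name__)   # FinishLength
```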
| Finish Reason | Trigger |
|---|---|
| FINISH_LENGTH | len(output_ids) >= max_new_tokens |
| FINISH_MATCHED_TOKEN | Output matches a stop token ID |
| FINISH_MATCHED_STR | Decoded text matches a stop string |
| FINISHED_MATCHED_REGEX | Output matches a regex pattern |
| FINISH_ABORT | Error, timeout, or manual abort |
The prefix cache (RadixCache) is deeply integrated with the scheduler. When a new request
arrives, init_next_round_input() queries the cache for matching prefixes to avoid
redundant computation:
def init_next_round_input(self, tree_cache=None, cow_mamba=None):
self.fill_ids = self.origin_input_ids + self.output_ids
input_len = len(self.fill_ids)
# Max prefix = input_len - 1 (need at least 1 new token for logprob)
max_prefix_len = input_len - 1
if self.return_logprob and self.logprob_start_len >= 0:
max_prefix_len = min(max_prefix_len, self.logprob_start_len)
token_ids = self.fill_ids[:max_prefix_len]
if tree_cache is not None:
match_result = tree_cache.match_prefix(
MatchPrefixParams(key=RadixKey(token_ids=token_ids, ...))
)
self.prefix_indices = match_result.device_indices
self.last_node = match_result.last_device_node
# extend_input_len = total - cached prefix
self.extend_input_len = input_len - len(self.prefix_indices)
The SchedulePolicy class supports both cache-aware and cache-agnostic policies
that affect how the waiting queue is prioritized:
class CacheAwarePolicy(Enum):
"""Scheduling policies that are aware of the tree cache."""
LPM = "lpm" # Longest prefix match
DFS_WEIGHT = "dfs-weight" # Depth-first search weighting
class CacheAgnosticPolicy(Enum):
"""Scheduling policies that are not aware of the tree cache."""
FCFS = "fcfs" # First come first serve
LOF = "lof" # Longest output first
RANDOM = "random"
ROUTING_KEY = "routing-key" # Prioritize by routing key frequency
# Fall back to FCFS when the waiting queue is very long,
# since LPM must compute a prefix match for every request.
if self.policy == CacheAwarePolicy.LPM and len(waiting_queue) > 128:
    return CacheAgnosticPolicy.FCFS
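Cache-aware LPM ordering can be sketched as a sort by matched-prefix length, with the large-queue fallback shown above (names are illustrative except the 128 cut-off):

```python
def order_waiting_queue(waiting_queue, policy="lpm"):
    if policy == "lpm" and len(waiting_queue) > 128:
        policy = "fcfs"   # skip the per-request prefix matching cost
    if policy == "lpm":
        # Longer cached prefix first -> maximizes radix-cache reuse.
        waiting_queue.sort(key=lambda r: len(r["prefix_indices"]), reverse=True)
    return policy         # fcfs keeps arrival order

queue = [{"rid": "a", "prefix_indices": [1]},
         {"rid": "b", "prefix_indices": [1, 2, 3]}]
order_waiting_queue(queue)
print([r["rid"] for r in queue])   # ['b', 'a'] -- longer prefix match first
```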
The scheduler also performs in-batch prefix caching -- if multiple waiting requests share the same prefix but have small existing cache matches, only one is scheduled first so that its KV cache can benefit subsequent requests:
# If a request has a small match from existing cache,
# check for in-batch prefix sharing
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
match_result = self.waiting_queue_radix_tree.match_prefix(...)
in_batch_matching_prefixes = match_result.device_indices
if len(in_batch_matching_prefixes) >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD:
temporary_deprioritized.add(r.rid) # Skip this request for now
else:
# Insert into waiting-queue radix tree for future matching
self.waiting_queue_radix_tree.insert(...)
Long prompts that exceed chunked_prefill_size are split across multiple forward passes.
This prevents a single large prefill from blocking all decode requests, enabling
low-latency decoding even during long prefills:
def init_chunked_prefill(self):
self.chunked_prefill_size = self.server_args.chunked_prefill_size
if self.chunked_prefill_size <= 0: # -1 means disable
self.chunked_prefill_size = None
self.chunked_req = None
self.is_mixed_chunk = (
self.chunked_prefill_size is not None
and self.server_args.enable_mixed_chunk
)
When a request's input exceeds the remaining chunk budget:
else:
# Chunked prefill -- truncate to remaining budget
trunc_len = self.rem_chunk_tokens // self.page_size * self.page_size
if trunc_len <= 0:
return AddReqResult.OTHER
# When deterministic inference is enabled, align truncation
if truncation_align_size is not None:
trunc_len = truncation_align_size * (trunc_len // truncation_align_size)
# Truncate the request's fill_ids
req.set_extend_input_len(trunc_len)
req.fill_ids = req.fill_ids[:len(req.prefix_indices) + trunc_len]
self.can_run_list.append(req)
self.new_chunked_req = req # Track for next iteration
# In get_new_batch_prefill: resume the chunked request first
if self.chunked_req is not None:
self.chunked_req.init_next_round_input() # Recompute fill_ids/prefix
self.chunked_req = adder.add_chunked_req(self.chunked_req)
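The truncation arithmetic can be checked in isolation — a sketch of the rounding steps shown above; the function name is hypothetical:

```python
# Round the remaining chunk budget down to whole KV-cache pages, then
# (optionally) down to a deterministic alignment boundary.

def chunk_trunc_len(rem_chunk_tokens, page_size, truncation_align_size=None):
    trunc_len = rem_chunk_tokens // page_size * page_size
    if trunc_len <= 0:
        return 0   # no room left in this chunk
    if truncation_align_size is not None:
        trunc_len = truncation_align_size * (trunc_len // truncation_align_size)
    return trunc_len

print(chunk_trunc_len(1000, 16))        # 992  (62 whole pages of 16)
print(chunk_trunc_len(1000, 16, 256))   # 768  (aligned down to 256)
print(chunk_trunc_len(10, 16))          # 0    (less than one page)
```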
When enable_mixed_chunk is true, the scheduler can combine a chunked prefill with
ongoing decode requests in a single batch:
# Mixed-style chunked prefill
if (self.is_mixed_chunk
and not self.running_batch.is_empty()
and not (new_batch.return_logprob or self.running_batch.return_logprob)
and new_batch.input_embeds is None):
self.running_batch.filter_batch(v1_spec_info_filtered=True)
if not self.running_batch.is_empty():
self.running_batch.prepare_for_decode()
new_batch.mix_with_running(self.running_batch)
new_batch.decoding_reqs = self.running_batch.reqs
The combined batch then runs with ForwardMode.MIXED, and the attention
kernel handles both extend and decode positions in a single call.
When running with dp_size > 1, a DataParallelController dispatches requests
across multiple scheduler instances. It supports several load-balancing strategies:
class LoadBalanceMethod(Enum):
"""Load balance method."""
ROUND_ROBIN = auto()
FOLLOW_BOOTSTRAP_ROOM = auto()
TOTAL_REQUESTS = auto() # Dispatch to worker with fewest requests
TOTAL_TOKENS = auto() # Dispatch to worker with fewest tokens
class DPBudget:
def __init__(self, dp_size):
self.dp_size = dp_size
self.total_requests = [0] * dp_size
self.total_tokens = [0] * dp_size
def dispatch(self, method):
if method == LoadBalanceMethod.TOTAL_REQUESTS:
target_rank = self.total_requests.index(min(self.total_requests))
elif method == LoadBalanceMethod.TOTAL_TOKENS:
target_rank = min(
range(self.dp_size),
key=lambda i: (self.total_tokens[i], self.total_requests[i]),
)
self.total_requests[target_rank] += 1 # Heuristic increment
return target_rank
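The TOTAL_TOKENS dispatch heuristic can be exercised with a self-contained toy (a sketch mirroring the DPBudget excerpt; the class name is hypothetical):

```python
# Pick the worker with the fewest tokens, breaking ties by fewest requests,
# and increment the counters heuristically until real counts are reported.

class ToyDPBudget:
    def __init__(self, dp_size):
        self.total_requests = [0] * dp_size
        self.total_tokens = [0] * dp_size

    def dispatch(self, new_tokens):
        ranks = range(len(self.total_tokens))
        target = min(ranks, key=lambda i: (self.total_tokens[i],
                                           self.total_requests[i]))
        self.total_requests[target] += 1
        self.total_tokens[target] += new_tokens
        return target

budget = ToyDPBudget(dp_size=2)
print([budget.dispatch(t) for t in (100, 10, 50)])   # [0, 1, 1]
```

The second and third requests both land on rank 1 because rank 0 is still "charged" 100 tokens from the first dispatch.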
class DataParallelController:
"""A controller that dispatches requests to multiple DP workers."""
def __init__(self, server_args, port_args, run_scheduler_process_func):
self.load_balance_method = LoadBalanceMethod.from_str(
server_args.load_balance_method
)
# ZMQ sockets to each DP worker
self.workers = [None] * server_args.dp_size
self.status = [True] * server_args.dp_size
# Launch scheduler processes for each DP rank
if server_args.enable_dp_attention:
self.launch_dp_attention_schedulers(server_args, port_args)
else:
self.launch_dp_schedulers(server_args, port_args)
When enable_dp_attention is set, all DP workers share the
same tensor-parallel group but split the attention computation across workers. This requires
careful synchronization -- the scheduler broadcasts work requests via the attention TP group
and uses maybe_prepare_mlp_sync_batch() to coordinate MLP operations.
Key fields of Req:

| Field | Type | Description |
|---|---|---|
| rid | str | Unique request identifier |
| origin_input_ids | List[int] | Original tokenized input (with MM padding) |
| output_ids | List[int] | Generated token IDs so far |
| fill_ids | List[int] | origin_input_ids + output_ids (updated each round) |
| prefix_indices | Tensor | KV cache indices from prefix cache match |
| extend_input_len | int | Number of new tokens to process in this extend |
| last_node | TreeNode | Last matched node in RadixCache |
| req_pool_idx | int | Index into req_to_token_pool |
| kv_committed_len | int | Length of committed KV cache |
| is_chunked | int | Chunk counter (>0 means still chunking) |
| finished_reason | BaseFinishReason | None = still running; set when done |
| sampling_params | SamplingParams | Temperature, top_p, max_new_tokens, etc. |
| grammar | BaseGrammarObject | Grammar constraint state (for structured output) |
Key fields of ScheduleBatch:

| Field | Type | Description |
|---|---|---|
| reqs | List[Req] | Requests in this batch |
| forward_mode | ForwardMode | EXTEND, DECODE, MIXED, IDLE, etc. |
| input_ids | Tensor [N] | Flattened input token IDs |
| req_pool_indices | Tensor [B] | Indices into req_to_token_pool |
| seq_lens | Tensor [B] | Sequence lengths (GPU) |
| seq_lens_cpu | Tensor [B] | Sequence lengths (CPU mirror) |
| out_cache_loc | Tensor [N] | Output locations in KV pool |
| sampling_info | SamplingBatchInfo | Batched sampling parameters |
| batch_is_full | bool | Skip prefill check if True |
| prefix_lens | List[int] | Cached prefix lengths (extend mode) |
| extend_lens | List[int] | New tokens per request (extend mode) |
Key fields of ModelWorkerBatch:

| Field | Type | Description |
|---|---|---|
| forward_mode | ForwardMode | Forward pass type |
| input_ids | Tensor | Token IDs for this batch |
| req_pool_indices | Tensor | Request pool indices |
| seq_lens | Tensor | Sequence lengths (GPU) |
| out_cache_loc | Tensor | KV cache output locations |
| extend_num_tokens | int | Total extend tokens in batch |
| extend_seq_lens | List[int] | Per-request extend sequence lengths |
| extend_prefix_lens | List[int] | Per-request cached prefix lengths |
| sampling_info | SamplingBatchInfo | Sampling parameters for GPU |
| multimodal_inputs | List | Multimodal data (images, video, audio) |
| lora_ids | List[str] | LoRA adapter IDs per request |
class TpModelWorker(BaseTpWorker):
"""A tensor parallel model worker."""
def __init__(self, server_args, gpu_id, tp_rank, ...):
self.tp_rank = tp_rank
self.gpu_id = gpu_id
self._init_model_runner() # Creates ModelRunner
self.enable_overlap = not server_args.disable_overlap_schedule
def forward_batch_generation(self, model_worker_batch, ...):
# Convert ModelWorkerBatch -> ForwardBatch -> GPU forward
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
out = self.model_runner.forward(forward_batch, ...)
# Sample next tokens
next_token_ids = self.model_runner.sample(out.logits_output, forward_batch)
return GenerationBatchResult(
logits_output=out.logits_output,
next_token_ids=next_token_ids,
)
def init_running_status(self):
self.waiting_queue: List[Req] = []
# The running decoding batch for continuous batching
self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
# The current forward batch
self.cur_batch: Optional[ScheduleBatch] = None
# The last forward batch
self.last_batch: Optional[ScheduleBatch] = None
self.forward_ct = 0
- FutureMap stores token IDs as negative-index futures that are resolved when the GPU forward completes, enabling the scheduler to prepare the next batch immediately.
- get_next_batch_to_run() always prioritizes prefill over decode, using a PrefillAdder to fit requests within token budgets and chunk long prompts.
- SchedulePolicy supports cache-aware policies (LPM, DFS-weight) that prioritize requests with longer prefix cache hits, plus in-batch prefix caching to improve hit rates.
- DataParallelController distributes requests across DP workers using pluggable load-balancing strategies (round-robin, min-requests, min-tokens).