Scheduler & Overlap Scheduling

A code-level deep dive into SGLang's scheduler architecture: the two event loops, continuous batching, overlap scheduling innovation, request lifecycle management, chunked prefill, and data-parallel dispatch.


Table of Contents

  1. Architecture Overview & Data Flow
  2. event_loop_normal vs event_loop_overlap
  3. Overlap Scheduling Innovation
  4. get_next_batch_to_run() -- Batch Construction
  5. process_batch_result() -- Result Handling
  6. ForwardMode: Extend vs Decode
  7. Request Lifecycle & State Machine
  8. Prefix Cache Hits & Scheduling
  9. Chunked Prefill & Long Prompts
  10. Data-Parallel Scheduling
  11. Key Data Structures Reference

1. Architecture Overview & Data Flow

SGLang's serving engine is organized around a Scheduler that manages one or more tensor-parallel GPU workers. The scheduler runs a tight event loop on the CPU, receiving requests, constructing batches, dispatching forward passes, and processing results. The key data flow documented in the source is:

"""
The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related
  to the model forward on GPU. It will be transformed from CPU scheduler
  to GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
"""
Data Flow: Request to GPU Forward

  TokenizedGenerateReqInput          ScheduleBatch            ModelWorkerBatch          ForwardBatch
  ========================     =====================    =======================    ==================
  | rid, input_ids,      |    | reqs: List[Req]    |   | forward_mode         |   | GPU tensors:   |
  | sampling_params,     | -> | forward_mode       | ->| input_ids (tensor)   | ->| input_ids      |
  | mm_inputs, ...       |    | seq_lens, out_cache |   | seq_lens (tensor)    |   | positions      |
  ========================    | sampling_info       |   | sampling_info        |   | attn metadata  |
       (from tokenizer)       =====================    =======================    ==================
                                   (CPU-side)              (CPU->GPU bridge)          (GPU-side)

  recv_requests()         get_next_batch_to_run()    get_model_worker_batch()     ForwardBatch.init_new()

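The narrowing from ScheduleBatch to ModelWorkerBatch can be sketched with toy dataclasses. This is a minimal sketch: field names beyond those listed in the docstring, and the to_worker_batch helper, are illustrative stand-ins, not SGLang's actual attributes.

```python
from dataclasses import dataclass
from typing import List

@dataclass
class ToyScheduleBatch:            # CPU-side, owned by scheduler.py::Scheduler
    reqs: List[str]                # stand-in for List[Req]
    forward_mode: str
    seq_lens: List[int]

@dataclass
class ToyModelWorkerBatch:         # CPU->GPU bridge: forward-relevant data only
    forward_mode: str
    input_ids: List[int]           # would be a GPU-bound tensor in SGLang
    seq_lens: List[int]

def to_worker_batch(sb: ToyScheduleBatch, input_ids: List[int]) -> ToyModelWorkerBatch:
    # Drop scheduler-only bookkeeping (the per-Req objects); keep model inputs.
    return ToyModelWorkerBatch(sb.forward_mode, input_ids, sb.seq_lens)

sb = ToyScheduleBatch(reqs=["r0", "r1"], forward_mode="EXTEND", seq_lens=[5, 3])
wb = to_worker_batch(sb, input_ids=[11, 12, 13, 14, 15, 21, 22, 23])
```

Each stage strips data the next stage does not need, which is why the worker batch carries no Req objects at all.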
The Scheduler class itself is built via mixin composition:

class Scheduler(
    SchedulerOutputProcessorMixin,
    SchedulerUpdateWeightsMixin,
    SchedulerProfilerMixin,
    SchedulerMetricsMixin,
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
    SchedulerMultiplexMixin,
    SchedulerRuntimeCheckerMixin,
    SchedulerPPMixin,
    SchedulerDPAttnMixin,
    SchedulerDllmMixin,
):
    """A scheduler that manages a tensor parallel GPU worker."""

2. event_loop_normal vs event_loop_overlap

SGLang provides two event loop implementations. The dispatcher selects between them at startup based on scheduler.enable_overlap (derived from server_args.disable_overlap_schedule):

def dispatch_event_loop(scheduler: Scheduler):
    server_args = scheduler.server_args
    disaggregation_mode = scheduler.disaggregation_mode
    if disaggregation_mode == DisaggregationMode.NULL:
        if scheduler.enable_pdmux:
            scheduler.event_loop_pdmux()
        elif server_args.pp_size > 1:
            scheduler.event_loop_pp()
        elif scheduler.enable_overlap:
            scheduler.event_loop_overlap()   # <-- overlap scheduling (DEFAULT)
        else:
            scheduler.event_loop_normal()    # <-- simple sequential loop

2.1 The Normal Event Loop

The normal loop is synchronous -- it waits for the GPU forward to finish before doing any CPU processing for the next batch:

@DynamicGradMode()
def event_loop_normal(self):
    """A normal scheduler loop."""
    while True:
        # Receive requests
        recv_reqs = self.recv_requests()
        self.process_input_requests(recv_reqs)
        if self._engine_paused:
            self.cancel_bubble_timer()
            continue

        # Get the next batch to run
        batch = self.get_next_batch_to_run()
        self.cur_batch = batch

        # Launch the current batch
        if batch:
            result = self.run_batch(batch)
            self.process_batch_result(batch, result)
        else:
            # When the server is idle, do self-check and re-init some states.
            self.self_check_during_idle()

        # Update last_batch
        self.last_batch = batch

2.2 The Overlap Event Loop

The overlap loop is the default and recommended mode. It pipelines CPU processing of the previous batch's results with the GPU forward of the current batch. This is the key performance innovation in SGLang:

@DynamicGradMode()
def event_loop_overlap(self):
    """A scheduler loop that overlaps CPU processing and GPU computation."""
    self.result_queue: Deque[
        Tuple[ScheduleBatch, Union[GenerationBatchResult, EmbeddingBatchResult]]
    ] = deque()

    def pop_and_process():
        # Process the results of the last batch
        tmp_batch, tmp_result = self.result_queue.popleft()
        self.process_batch_result(tmp_batch, tmp_result)

    while True:
        # Receive requests
        recv_reqs = self.recv_requests()
        self.process_input_requests(recv_reqs)
        if self._engine_paused:
            continue

        # Get the next batch to run
        batch = self.get_next_batch_to_run()
        self.cur_batch = batch
        disable_overlap_for_batch = self.is_disable_overlap_for_batch(batch)

        # If we do not need to overlap the current batch with the last batch,
        # we can process the last batch immediately.
        if disable_overlap_for_batch:
            pop_and_process()

        # Launch the current batch
        if batch:
            batch_result = self.run_batch(batch)
            self.result_queue.append((batch.copy(), batch_result))
        else:
            batch_result = None
            self.cancel_bubble_timer()

        # Process the last batch (while current batch runs on GPU)
        if self.last_batch:
            if not disable_overlap_for_batch:
                pop_and_process()
        elif batch is None:
            self.self_check_during_idle()

        # Run sample of the current batch (depends on last batch's grammar)
        if self.is_generation:
            self.launch_batch_sample_if_needed(batch_result)

        self.last_batch = batch

Key Insight: In the overlap loop, process_batch_result() for batch N happens after batch N+1 has already been launched on the GPU via run_batch(). The CPU does useful work (result processing, token sampling, streaming output) while the GPU is computing, hiding the CPU overhead almost entirely.
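The ordering can be demonstrated with a minimal sketch, assuming toy launch()/process() functions in place of run_batch()/process_batch_result(): batch N's results are processed only after batch N+1 has already been launched.

```python
from collections import deque

log = []

def launch(batch):
    # Stand-in for run_batch(): kicks off the GPU forward, returns a handle.
    log.append(f"launch {batch}")
    return f"result-{batch}"

def process(batch, result):
    # Stand-in for process_batch_result(): CPU-side token handling.
    log.append(f"process {batch}")

result_queue = deque()
last_batch = None
for batch in ["N", "N+1", "N+2"]:
    result = launch(batch)                # GPU starts on the current batch...
    result_queue.append((batch, result))
    if last_batch is not None:
        process(*result_queue.popleft())  # ...while the CPU digests the last one
    last_batch = batch
```

Tracing the log shows "launch N+1" happens before "process N", which is exactly the one-batch pipeline depth the real loop maintains via self.result_queue.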

3. Overlap Scheduling Innovation

The overlap scheduling mechanism uses CUDA streams and a FutureMap to decouple the CPU scheduler from GPU execution. The core idea: the GPU can start working on the next batch while the CPU processes the previous batch's output tokens.

(a) Normal vs Overlap Event Loop Timeline

  NORMAL EVENT LOOP (Sequential):
  ================================================================
  CPU  : [recv] [get_batch] [         wait         ] [process_result] [recv] [get_batch] ...
  GPU  :                    [====== forward N ======]                                    [=== forward N+1 ===]
  Time : ---t0----------t1--t2-----------------------t3--------------t4-----t5-----------t6------------------->
                                                                      ^^^^^               ^^^^^^^^^
                                                                     CPU idle            CPU idle
                                                                     (GPU done)          (GPU done)

  OVERLAP EVENT LOOP (Pipelined):
  ================================================================
  CPU  : [recv] [get_batch_N+1] [process_result_N] [sample_N+1]  [recv] [get_batch_N+2] [process_result_N+1]
  GPU  : [======= forward N ========][======== forward N+1 =========][======= forward N+2 ==========]
  Time : ---t0----t1---------t2------t3-----------t4---------t5------t6---------t7------t8-----------t9----->
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                              CPU work overlaps with GPU work!
                              No CPU idle time between batches.

  Legend:  [recv]           = recv_requests() + process_input_requests()
           [get_batch]      = get_next_batch_to_run()
           [process_result] = process_batch_result() -- token checks, streaming output
           [forward]        = GPU model forward pass
           [sample]         = launch_batch_sample_if_needed() -- delayed sampling

3.1 CUDA Streams and Future Tokens

The overlap mode uses separate CUDA streams. The scheduler constructs the batch on the schedule_stream, then launches the forward on the forward_stream. Result tokens are stored as "futures" in a FutureMap -- negative indices that get resolved later when the GPU forward completes:

def init_overlap(self):
    self.device_module = torch.get_device_module(self.device)
    self.forward_stream_ctx = self.device_module.stream(self.forward_stream)
    self.copy_stream = self.device_module.Stream()
    self.copy_stream_ctx = self.device_module.stream(self.copy_stream)

    if not self.enable_overlap:
        self.future_map = None
        return

    self.future_map = FutureMap(
        self.max_running_requests,
        self.chunked_prefill_size,
        self.model_config.context_len,
        self.device,
        self.spec_algorithm,
    )
    self.batch_record_buf = [None] * 2
    self.batch_record_ct = 0

class FutureMap:
    def __init__(self, max_running_requests, chunked_prefill_size,
                 context_len, device, spec_algo=None):
        self.future_ct = 0
        # Circular buffer layout (wraps in this order):
        # Running decode batch -> Prefill chunk 1 -> ... -> Prefill chunk N
        max_num_chunks = (
            (context_len + chunked_prefill_size - 1) // chunked_prefill_size
            if chunked_prefill_size else 0
        )
        self.future_limit = max_running_requests * (3 + max_num_chunks)
        self.future_buffer_len = self.future_limit + 2 * max_running_requests
        self.token_ids_buf = torch.empty(
            (self.future_buffer_len,), dtype=torch.int64, device=device
        )

3.2 The run_batch Overlap Path

When overlap is enabled, run_batch() launches the forward on the forward_stream and returns immediately; the GPU-to-CPU copy of the results is issued as a non-blocking operation:

if self.enable_overlap:
    model_worker_batch = worker_batch_or_batch
    self.record_batch_in_overlap(model_worker_batch)

    # Sampling info will be modified during forward, so we store a copy.
    model_worker_batch.sampling_info = (
        model_worker_batch.sampling_info.copy_for_forward()
    )

    bs = len(model_worker_batch.seq_lens)
    future_indices = self.future_map.alloc_future_indices(bs)

    with self.forward_stream_ctx, self.record_bubble_metrics(batch):
        self.forward_stream.wait_stream(self.schedule_stream)
        self.future_map.resolve_future(model_worker_batch)
        with self.record_forward_metrics(batch):
            batch_result = self.model_worker.forward_batch_generation(
                model_worker_batch
            )
        batch_result.copy_done = self.device_module.Event()
        if batch_result.delay_sample_func is None:
            self.future_map.store_to_map(future_indices, batch_result)
            batch_result.copy_to_cpu(return_logprob=batch.return_logprob)
        else:
            batch_result.future_indices = future_indices

    # Use negative indices as future references
    future_indices_or_next_token_ids = -future_indices.indices

3.3 Disabling Overlap for Specific Cases

Overlap is disabled for consecutive prefill batches (to improve TTFT) and for speculative decoding with grammar constraints:

def is_disable_overlap_for_batch(self, batch):
    # For two consecutive prefill batches, disable overlap
    # to improve the TTFT of the first batch.
    batch_is_extend = is_extend(batch)
    last_batch_is_extend = is_extend(self.last_batch)

    disable_overlap_for_batch = (
        envs.SGLANG_DISABLE_CONSECUTIVE_PREFILL_OVERLAP.get()
        and batch_is_extend
        and last_batch_is_extend
    )

    # We do not support overlap + spec + grammar yet
    need_grammar_sync = (
        batch and batch.is_spec_v2 and batch.has_grammar
        and batch.forward_mode.is_decode()
        and len(self.result_queue) > 0
    )
    return disable_overlap_for_batch or need_grammar_sync

4. get_next_batch_to_run() -- Batch Construction

This is the core scheduling function. It decides whether to run a prefill (extend) batch for new requests or a decode batch for ongoing generation. Prefill is always prioritized over decode when new requests are waiting:

def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
    self._abort_on_waiting_timeout()
    self._abort_on_running_timeout()

    # Merge the prefill batch into the running batch
    if self.last_batch and self.last_batch.forward_mode.is_extend():
        # Filter completed requests out of the last prefill batch
        self.last_batch.filter_batch(
            chunked_req_to_exclude=list(chunked_req_to_exclude)
        )
        # Merge into continuous running batch
        if not self.last_batch.is_empty():
            if self.running_batch.is_empty():
                self.running_batch = self.last_batch
            else:
                self.running_batch.merge_batch(self.last_batch)

    # Try to build a new prefill batch
    new_batch = self.get_new_batch_prefill()

    if new_batch is not None:
        # Run prefill first if possible
        ret = new_batch
    else:
        # Run decode (skip for prefill-only batches)
        if (not self.running_batch.is_empty()
            and not self.running_batch.is_prefill_only):
            self.running_batch = self.update_running_batch(self.running_batch)
            ret = self.running_batch if not self.running_batch.is_empty() else None
        else:
            ret = None

    return ret

(c) Batch Construction Flow

  get_next_batch_to_run()
      |
      v
  +--[Merge last prefill batch into running_batch]--+
  |                                                  |
  v                                                  |
  get_new_batch_prefill()                            |
      |                                              |
      +---> waiting_queue empty                      |
      |     OR batch_is_full?  ---- YES ------------>+---> update_running_batch()
      |                                              |          |
      NO                                             |          +---> filter_batch() -- remove finished
      |                                              |          +---> check_decode_mem() -- OOM check
      v                                              |          +---> retract_decode() if OOM
  policy.calc_priority(waiting_queue)                |          +---> prepare_for_decode()
      |                                              |          |
      v                                              |          v
  PrefillAdder.add_one_req() for each req            |     Return decode ScheduleBatch
      |                                              |
      +---> Check token budget                       |
      +---> Match prefix (cache hit)                 |
      +---> Allocate KV cache                        |
      +---> Maybe chunk if too long                  |
      |                                              |
      v                                              |
  ScheduleBatch.init_new(can_run_list)               |
      |                                              |
      v                                              |
  prepare_for_extend()                               |
      |                                              |
      v                                              |
  Return prefill ScheduleBatch  <------(priority)----+

4.1 The PrefillAdder -- Token Budget Manager

The PrefillAdder is the core component that decides which waiting requests can fit into a new prefill batch, respecting token budgets, memory limits, and chunked prefill:

class PrefillAdder:
    def __init__(self, page_size, tree_cache, token_to_kv_pool_allocator,
                 running_batch, new_token_ratio, rem_input_tokens,
                 rem_chunk_tokens, mixed_with_decode_tokens=0, ...):
        self.page_size = page_size
        self.tree_cache = tree_cache
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.new_token_ratio = new_token_ratio
        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
        self.rem_chunk_tokens = rem_chunk_tokens

        self.can_run_list = []       # Requests that will be in the batch
        self.preempt_list = []       # Requests preempted by priority
        self.new_chunked_req = None  # Request being chunked
        self.log_hit_tokens = 0      # Prefix cache hits logged
        self.log_input_tokens = 0    # Input tokens logged

def add_one_req(self, req, has_chunked_req, truncation_align_size):
    # Calculate total token cost: input + estimated output
    total_tokens = req.extend_input_len + min(
        max(req.sampling_params.max_new_tokens - len(req.output_ids), 0),
        CLIP_MAX_NEW_TOKENS,  # Default 4096 -- prevents over-conservatism
    )

    # Check global token budget
    if total_tokens >= self.rem_total_tokens:
        return AddReqResult.NO_TOKEN

    # Check input token budget
    if real_input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
        return AddReqResult.OTHER

    # Lock the cache node and try adding
    with self._lock_node(req.last_node):
        if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
            # Non-chunked prefill -- add entire request
            self.can_run_list.append(req)
        else:
            # Chunked prefill -- truncate to remaining budget
            trunc_len = self.rem_chunk_tokens // self.page_size * self.page_size
            req.set_extend_input_len(trunc_len)
            self.can_run_list.append(req)
            self.new_chunked_req = req   # Mark for continuation

    return self.budget_state()
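The chunking branch rounds the truncation length down to a page boundary so no partial KV-cache page is ever allocated. A hedged sketch of that decision (chunk_decision is an illustrative helper, not an SGLang function):

```python
def chunk_decision(input_tokens: int, rem_chunk_tokens: int, page_size: int):
    if input_tokens <= rem_chunk_tokens:
        return ("full", input_tokens)      # the whole request fits this round
    trunc_len = rem_chunk_tokens // page_size * page_size
    if trunc_len <= 0:
        return ("defer", 0)                # no room left in this chunk budget
    return ("chunked", trunc_len)          # prefill trunc_len now, rest later

full = chunk_decision(900, 8192, 16)        # fits entirely
chunked = chunk_decision(10_000, 1000, 16)  # truncated to 992 = 62 pages of 16
```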

5. process_batch_result() -- Result Handling

After a forward pass completes, process_batch_result() dispatches to mode-specific handlers. In overlap mode, this runs on the CPU while the next batch is already executing on the GPU:

def process_batch_result(self, batch, result):
    if batch.forward_mode.is_decode():
        self.process_batch_result_decode(batch, result)
    elif batch.forward_mode.is_extend():
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self.process_batch_result_disagg_prefill(batch, result)
        else:
            self.process_batch_result_prefill(batch, result)
    elif batch.forward_mode.is_prebuilt():
        self.process_batch_result_prebuilt(batch)
    elif batch.forward_mode.is_idle():
        self.process_batch_result_idle(batch, result)

    self.log_batch_result_stats(batch, result)
    self._maybe_clear_mm_inputs(batch)

5.1 Decode Result Processing

def process_batch_result_decode(self, batch, result):
    if result.copy_done is not None:
        result.copy_done.synchronize()  # Wait for GPU->CPU copy

    logits_output, next_token_ids, can_run_cuda_graph = (
        result.logits_output, result.next_token_ids, result.can_run_cuda_graph,
    )
    next_token_ids = next_token_ids.tolist()

    # Process each request in the batch
    for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
        req.output_ids.append(next_token_id)          # Append new token
        req.check_finished()                            # Check stop conditions
        if req.finished():
            release_kv_cache(req, self.tree_cache)  # Free KV cache
        if req.grammar is not None:
            req.grammar.accept_token(next_token_id)    # Update grammar state

    self.stream_output(batch.reqs, batch.return_logprob)

5.2 Prefill Result Processing

def process_batch_result_prefill(self, batch, result):
    if result.copy_done is not None:
        result.copy_done.synchronize()

    next_token_ids = result.next_token_ids.tolist()
    skip_stream_req = None  # Stays None unless a request is still chunking

    for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
        if req.is_chunked <= 0:
            # Prefill complete -- append first output token
            req.output_ids.append(next_token_id)
            req.check_finished()
            if req.finished():
                release_kv_cache(req, self.tree_cache)
            else:
                self.tree_cache.cache_unfinished_req(req)
        else:
            # Still chunking -- decrement counter, don't stream yet
            req.is_chunked -= 1
            skip_stream_req = req

    self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

6. ForwardMode: Extend vs Decode

ForwardMode is an IntEnum that controls how the model processes each batch. The two primary modes are EXTEND (prefill) and DECODE:

class ForwardMode(IntEnum):
    # Extend a sequence. The KV cache of the beginning part is already
    # computed (e.g., system prompt). Also called "prefill".
    EXTEND = auto()

    # Decode one token.
    DECODE = auto()

    # Contains both EXTEND and DECODE when doing chunked prefill.
    MIXED = auto()

    # No sequence to forward. For DP attention idle workers.
    IDLE = auto()

    # Used in speculative decoding: verify a batch in the target model.
    TARGET_VERIFY = auto()

    # Used in speculative decoding: extend a batch in the draft model.
    DRAFT_EXTEND = auto()
    DRAFT_EXTEND_V2 = auto()

    # Used in disaggregated decode worker -- KV cache is pre-built
    PREBUILT = auto()

    # Split Prefill for PD multiplexing
    SPLIT_PREFILL = auto()

    # Used in diffusion LLM
    DLLM_EXTEND = auto()

    def is_extend(self, include_draft_extend_v2=False):
        return (self == ForwardMode.EXTEND
            or self == ForwardMode.MIXED
            or self == ForwardMode.DRAFT_EXTEND
            or self == ForwardMode.TARGET_VERIFY
            or self == ForwardMode.SPLIT_PREFILL
            or self == ForwardMode.DLLM_EXTEND)
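A toy reproduction of the mode predicates for the four primary modes shows how dispatch keys off the enum (simplified from ForwardMode; the real class covers the speculative and disaggregation modes too):

```python
from enum import IntEnum, auto

class ToyForwardMode(IntEnum):
    EXTEND = auto()
    DECODE = auto()
    MIXED = auto()
    IDLE = auto()

    def is_extend(self) -> bool:
        # MIXED batches contain extend tokens, so they count as extend.
        return self in (ToyForwardMode.EXTEND, ToyForwardMode.MIXED)

    def is_decode(self) -> bool:
        return self == ToyForwardMode.DECODE

extend_modes = [m.name for m in ToyForwardMode if m.is_extend()]
```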

  Mode    | When Used                                                  | Input Shape                                      | KV Cache Behavior
  --------+------------------------------------------------------------+--------------------------------------------------+---------------------------------------------------------------
  EXTEND  | New request's prefill (or continuation after retraction)   | Variable-length: all new tokens                  | Fills KV cache for all input tokens; reuses prefix cache hits
  DECODE  | Autoregressive generation -- one token per request         | Exactly 1 token per request                      | Appends 1 KV entry per request
  MIXED   | Chunked prefill mixed with decode requests                 | Variable -- both extend and decode in one batch  | Both extend and append operations
  IDLE    | DP attention: idle workers with no sequences               | Empty                                            | No-op

Extend vs Decode in Batch Preparation

def prepare_for_extend(self):
    self.forward_mode = ForwardMode.EXTEND
    reqs = self.reqs
    # Input = all tokens after the cached prefix
    input_ids = [r.fill_ids[len(r.prefix_indices):] for r in reqs]
    extend_num_tokens = sum(len(ids) for ids in input_ids)
    seq_lens = [len(r.fill_ids) for r in reqs]
    prefix_lens = [len(r.prefix_indices) for r in reqs]
    extend_lens = [r.extend_input_len for r in reqs]

def prepare_for_decode(self):
    self.forward_mode = ForwardMode.DECODE
    bs = len(self.reqs)
    # Decode embeds the last output token via embed_tokens;
    # clear stale prefill-time tensor
    self.input_embeds = None
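The extend bookkeeping can be made concrete with a small numeric example (token values are arbitrary): with a cached prefix, only the uncached suffix is fed to the forward pass, while attention still covers the full sequence length.

```python
fill_ids = list(range(100, 112))    # 12 tokens: prompt + outputs so far
prefix_indices = list(range(8))     # 8 tokens already present in the radix cache

input_ids = fill_ids[len(prefix_indices):]  # the 4 uncached tokens
seq_len = len(fill_ids)                     # 12 -- attention still sees everything
extend_len = len(input_ids)                 # 4  -- compute happens only for these
```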

7. Request Lifecycle & State Machine

Every request (Req) follows a well-defined lifecycle through the scheduler. The key states are: Waiting, Running (Extend), Running (Decode), Chunked, Retracted, and Finished.

(b) Request State Machine

                                      +-----------+
            handle_generate_request() |           |  abort / error
  TokenizedReq ------>------->--------| WAITING   |---------->----+
                                      | (queue)   |               |
                                      +-----+-----+               |
                                            |                     |
                      get_new_batch_prefill()|                     |
                      PrefillAdder.add_one() |                     |
                                            v                     |
                                   +--------+--------+            |
                                   |    EXTENDING     |            |
                                   | (prefill batch)  |            |
                                   +---+---------+---+            |
                                       |         |                |
                          is_chunked > 0|         | is_chunked == 0
                                       |         |                |
                                       v         v                |
                               +-------+--+  +--+--------+       |
                               | CHUNKED  |  | check_    |       |
                               | (partial |  | finished()|       |
                               |  prefill)|  +--+--------+       |
                               +----+-----+     |       |        |
                                    |    finished|       |not     |
                       next round   |            v       |finished|
                       prefill      |     +------+-----+ |        |
                       +------------+     |  FINISHED   | |        |
                                          | (release KV)| |        |
                                          +------+-----+ |        |
                                                 ^       v        |
                                                 |  +----+-----+  |
                          retract_decode()       |  | DECODING  |  |
                          +-------<-------+------+--+ (running  |  |
                          |               |  decode  |  batch)   |  |
                          v               |  token   +----+-----+  |
                  +-------+------+        |  appended     |        |
                  | RETRACTED    |        |               |        |
                  | (back to     |        +------<--------+        |
                  |  waiting)    |                                  |
                  +-------+------+                                 |
                          |                                        |
                          +---------->---------->----->---------+  |
                                      (re-queue)               |  |
                                                               v  v
                                                         +-----+--+---+
                                                         | STREAM      |
                                                         | OUTPUT      |
                                                         | (to client) |
                                                         +-------------+

7.1 Key Req Fields

class Req(ReqDllmMixin):
    """The input and output status of a request."""

    def __init__(self, rid, origin_input_text, origin_input_ids,
                 sampling_params, ...):
        # Identity
        self.rid = rid
        self.origin_input_ids = origin_input_ids

        # Output tracking
        self.output_ids = []               # Generated token IDs
        self.fill_ids = []                 # origin_input_ids + output_ids

        # KV cache management
        self.kv_committed_len = 0          # Committed KV length
        self.kv_allocated_len = 0          # Allocated KV length
        self.req_pool_idx = None           # Index in req_to_token_pool

        # Prefix cache matching
        self.prefix_indices = torch.empty((0,), dtype=torch.int64)
        self.extend_input_len = 0         # Tokens to run prefill
        self.last_node = None              # RadixCache node
        self.cached_tokens = 0            # Cache hit count

        # Finish state
        self.finished_reason = None        # None = still running
        self.to_finish = None              # Deferred finish (safe mid-loop)

        # Chunked prefill tracking
        self.is_chunked = 0               # >0 means still chunking
        self.is_retracted = False         # Currently retracted

7.2 Finish Conditions

def check_finished(self, new_accepted_len=1):
    if self.finished():
        return

    if self.to_finish:
        self.finished_reason = self.to_finish
        self.to_finish = None
        return

    if len(self.output_ids) >= self.sampling_params.max_new_tokens:
        self.finished_reason = FINISH_LENGTH(len(self.output_ids))
    # ... also checks stop tokens, stop strings, regex matches

def finished(self) -> bool:
    return self.finished_reason is not None

  Finish Reason            | Trigger
  -------------------------+-------------------------------------------
  FINISH_LENGTH            | len(output_ids) >= max_new_tokens
  FINISH_MATCHED_TOKEN     | Output matches a stop token ID
  FINISH_MATCHED_STR       | Decoded text matches a stop string
  FINISHED_MATCHED_REGEX   | Output matches a regex pattern
  FINISH_ABORT             | Error, timeout, or manual abort

8. Prefix Cache Hits & Scheduling Decisions

The prefix cache (RadixCache) is deeply integrated with the scheduler. When a new request arrives, init_next_round_input() queries the cache for matching prefixes to avoid redundant computation:

def init_next_round_input(self, tree_cache=None, cow_mamba=None):
    self.fill_ids = self.origin_input_ids + self.output_ids
    input_len = len(self.fill_ids)

    # Max prefix = input_len - 1 (need at least 1 new token for logprob)
    max_prefix_len = input_len - 1
    if self.return_logprob and self.logprob_start_len >= 0:
        max_prefix_len = min(max_prefix_len, self.logprob_start_len)

    token_ids = self.fill_ids[:max_prefix_len]
    if tree_cache is not None:
        match_result = tree_cache.match_prefix(
            MatchPrefixParams(key=RadixKey(token_ids=token_ids, ...))
        )
        self.prefix_indices = match_result.device_indices
        self.last_node = match_result.last_device_node
    # extend_input_len = total - cached prefix
    self.extend_input_len = input_len - len(self.prefix_indices)
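A hedged sketch of the match step, using a toy longest-common-prefix lookup in place of RadixCache.match_prefix, shows how extend_input_len is derived and why the match is capped at input_len - 1 (at least one token must still be computed to produce a logit):

```python
def match_prefix(cached_seqs, token_ids):
    # Toy stand-in for the radix tree: longest common prefix over all entries.
    best = 0
    for seq in cached_seqs:
        n = 0
        while n < min(len(seq), len(token_ids)) and seq[n] == token_ids[n]:
            n += 1
        best = max(best, n)
    return best

cache = [[1, 2, 3, 4, 5], [1, 2, 9]]
fill_ids = [1, 2, 3, 4, 5, 6]
max_prefix_len = len(fill_ids) - 1                    # keep >= 1 token to run
hit = match_prefix(cache, fill_ids[:max_prefix_len])  # 5 cached tokens reused
extend_input_len = len(fill_ids) - hit                # only 1 token to prefill
```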

8.1 Scheduling Policies

The SchedulePolicy class supports both cache-aware and cache-agnostic policies that affect how the waiting queue is prioritized:

class CacheAwarePolicy(Enum):
    """Scheduling policies that are aware of the tree cache."""
    LPM = "lpm"          # Longest prefix match
    DFS_WEIGHT = "dfs-weight"  # Depth-first search weighting

class CacheAgnosticPolicy(Enum):
    """Scheduling policies that are not aware of the tree cache."""
    FCFS = "fcfs"         # First come first serve
    LOF = "lof"           # Longest output first
    RANDOM = "random"
    ROUTING_KEY = "routing-key"  # Prioritize by routing key frequency

Performance Optimization: The LPM policy automatically falls back to FCFS when the waiting queue exceeds 128 requests, because the O(N*M) prefix matching becomes too expensive:

if self.policy == CacheAwarePolicy.LPM and len(waiting_queue) > 128:
    return CacheAgnosticPolicy.FCFS
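A hedged sketch of LPM prioritization with that fallback, using toy dicts for requests (calc_priority here is a simplified stand-in for SchedulePolicy.calc_priority):

```python
FALLBACK_THRESHOLD = 128

def calc_priority(waiting_queue):
    if len(waiting_queue) > FALLBACK_THRESHOLD:
        return waiting_queue  # FCFS: keep arrival order, skip O(N*M) matching
    # LPM: requests that reuse the most cached tokens run first
    return sorted(waiting_queue, key=lambda r: r["prefix_len"], reverse=True)

queue = [{"rid": "a", "prefix_len": 2},
         {"rid": "b", "prefix_len": 40},
         {"rid": "c", "prefix_len": 7}]
order = [r["rid"] for r in calc_priority(queue)]  # longest prefix hit first
```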

8.2 In-Batch Prefix Caching

The scheduler also performs in-batch prefix caching -- if multiple waiting requests share the same prefix but have small existing cache matches, only one is scheduled first so that its KV cache can benefit subsequent requests:

# If a request has a small match from existing cache,
# check for in-batch prefix sharing
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
    match_result = self.waiting_queue_radix_tree.match_prefix(...)
    in_batch_matching_prefixes = match_result.device_indices
    if len(in_batch_matching_prefixes) >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD:
        temporary_deprioritized.add(r.rid)  # Skip this request for now
    else:
        # Insert into waiting-queue radix tree for future matching
        self.waiting_queue_radix_tree.insert(...)
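A hypothetical sketch of the deduplication idea above (the threshold values and the split_waiting helper are invented for illustration; the real logic uses a radix tree inside SchedulePolicy):

```python
# Among waiting requests that share a prefix and have no useful existing
# cache match, schedule one representative first so its KV cache serves
# the rest. Thresholds below are placeholders, not SGLang's constants.
CHECK_THRESHOLD = 32
DEPRIORITIZE_THRESHOLD = 32

def split_waiting(requests):
    scheduled, deprioritized, seen_prefixes = [], [], set()
    for rid, prefix_match_len, first_tokens in requests:
        key = tuple(first_tokens[:DEPRIORITIZE_THRESHOLD])
        if prefix_match_len <= CHECK_THRESHOLD and key in seen_prefixes:
            deprioritized.append(rid)   # wait for the representative's KV cache
        else:
            seen_prefixes.add(key)      # acts like the waiting-queue radix insert
            scheduled.append(rid)
    return scheduled, deprioritized

reqs = [("r1", 0, [5, 6, 7]), ("r2", 0, [5, 6, 7]), ("r3", 0, [9, 9])]
print(split_waiting(reqs))  # -> (['r1', 'r3'], ['r2'])
```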

9. Chunked Prefill & Long Prompt Handling

Long prompts that exceed chunked_prefill_size are split across multiple forward passes. This prevents a single large prefill from blocking all decode requests, enabling low-latency decoding even during long prefills:

def init_chunked_prefill(self):
    self.chunked_prefill_size = self.server_args.chunked_prefill_size
    if self.chunked_prefill_size <= 0:  # -1 means disable
        self.chunked_prefill_size = None
    self.chunked_req = None
    self.is_mixed_chunk = (
        self.chunked_prefill_size is not None
        and self.server_args.enable_mixed_chunk
    )
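The number of forward passes a long prompt needs follows directly from the chunk size. A small helper illustrating the arithmetic (not part of SGLang):

```python
import math
from typing import Optional

def num_prefill_chunks(prompt_len: int, chunked_prefill_size: Optional[int]) -> int:
    """How many prefill forward passes a prompt needs under chunked prefill."""
    if chunked_prefill_size is None:  # chunking disabled
        return 1
    return math.ceil(prompt_len / chunked_prefill_size)

print(num_prefill_chunks(20_000, 8192))  # -> 3 passes for a 20k-token prompt
```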

9.1 How Chunking Works in PrefillAdder

When a request's input exceeds the remaining chunk budget:

else:
    # Chunked prefill -- truncate to remaining budget
    trunc_len = self.rem_chunk_tokens // self.page_size * self.page_size

    if trunc_len <= 0:
        return AddReqResult.OTHER

    # When deterministic inference is enabled, align truncation
    if truncation_align_size is not None:
        trunc_len = truncation_align_size * (trunc_len // truncation_align_size)

    # Truncate the request's fill_ids
    req.set_extend_input_len(trunc_len)
    req.fill_ids = req.fill_ids[:len(req.prefix_indices) + trunc_len]

    self.can_run_list.append(req)
    self.new_chunked_req = req  # Track for next iteration
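The truncation arithmetic above can be checked in isolation. A stand-alone mirror of it (page_aligned_trunc is a name invented here for illustration):

```python
from typing import Optional

def page_aligned_trunc(rem_chunk_tokens: int, page_size: int,
                       truncation_align_size: Optional[int] = None) -> int:
    """Round the remaining chunk budget down to a page boundary, then
    (optionally) down again to the deterministic-inference alignment size."""
    trunc_len = rem_chunk_tokens // page_size * page_size
    if truncation_align_size is not None:
        trunc_len = truncation_align_size * (trunc_len // truncation_align_size)
    return trunc_len

print(page_aligned_trunc(1000, 16))        # -> 992  (62 full pages)
print(page_aligned_trunc(1000, 16, 128))   # -> 896  (7 * 128)
```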

9.2 Chunked Request Flow Through Iterations

# In get_new_batch_prefill: resume the chunked request first
if self.chunked_req is not None:
    self.chunked_req.init_next_round_input()  # Recompute fill_ids/prefix
    self.chunked_req = adder.add_chunked_req(self.chunked_req)
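A hypothetical trace of how a chunked request advances across iterations, assuming chunked_prefill_size=512, a 1200-token prompt, no prefix cache hit, and page_size=1:

```python
# Each loop iteration models one scheduler round: the chunked request is
# resumed, truncated to the remaining budget, and re-queued until done.
prompt_len, budget, processed = 1200, 512, 0
rounds = []
while processed < prompt_len:
    step = min(budget, prompt_len - processed)  # this round's extend length
    processed += step                           # becomes the cached prefix next round
    rounds.append((step, processed))
print(rounds)  # -> [(512, 512), (512, 1024), (176, 1200)]
```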

9.3 Mixed Chunked Prefill

When enable_mixed_chunk is true, the scheduler can combine a chunked prefill with ongoing decode requests in a single batch:

# Mixed-style chunked prefill
if (self.is_mixed_chunk
    and not self.running_batch.is_empty()
    and not (new_batch.return_logprob or self.running_batch.return_logprob)
    and new_batch.input_embeds is None):
    self.running_batch.filter_batch(v1_spec_info_filtered=True)
    if not self.running_batch.is_empty():
        self.running_batch.prepare_for_decode()
        new_batch.mix_with_running(self.running_batch)
        new_batch.decoding_reqs = self.running_batch.reqs
Mixed Chunked Prefill: This allows decode tokens to be processed in the same forward pass as prefill tokens. The batch uses ForwardMode.MIXED and the attention kernel handles both extend and decode positions in a single call.
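Token accounting for a MIXED batch is simple: the flattened input_ids holds the prefill chunk plus one token per decode request. A toy example with assumed sizes:

```python
# Hypothetical MIXED batch: one 512-token prefill chunk riding with
# 8 decode requests (each contributing exactly one new token).
prefill_tokens, decode_reqs = 512, 8
total_input_tokens = prefill_tokens + decode_reqs  # length of flattened input_ids
print(total_input_tokens)  # -> 520
```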

10. Data-Parallel Scheduling

When running with dp_size > 1, a DataParallelController dispatches requests across multiple scheduler instances. It supports several load-balancing strategies:

class LoadBalanceMethod(Enum):
    """Load balance method."""
    ROUND_ROBIN = auto()
    FOLLOW_BOOTSTRAP_ROOM = auto()
    TOTAL_REQUESTS = auto()     # Dispatch to worker with fewest requests
    TOTAL_TOKENS = auto()       # Dispatch to worker with fewest tokens
class DPBudget:
    def __init__(self, dp_size):
        self.dp_size = dp_size
        self.total_requests = [0] * dp_size
        self.total_tokens = [0] * dp_size

    def dispatch(self, method):
        if method == LoadBalanceMethod.TOTAL_REQUESTS:
            target_rank = self.total_requests.index(min(self.total_requests))
        elif method == LoadBalanceMethod.TOTAL_TOKENS:
            target_rank = min(
                range(self.dp_size),
                key=lambda i: (self.total_tokens[i], self.total_requests[i]),
            )
        self.total_requests[target_rank] += 1  # Heuristic increment
        return target_rank
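A self-contained restatement of the two budget heuristics above, with a short usage trace (the token counts are made up; the real controller updates its counters from scheduler reports):

```python
from enum import Enum, auto

class LoadBalanceMethod(Enum):
    TOTAL_REQUESTS = auto()
    TOTAL_TOKENS = auto()

class DPBudget:
    def __init__(self, dp_size: int):
        self.dp_size = dp_size
        self.total_requests = [0] * dp_size
        self.total_tokens = [0] * dp_size

    def dispatch(self, method: LoadBalanceMethod) -> int:
        if method == LoadBalanceMethod.TOTAL_REQUESTS:
            target = self.total_requests.index(min(self.total_requests))
        else:  # TOTAL_TOKENS: break token ties by request count
            target = min(range(self.dp_size),
                         key=lambda i: (self.total_tokens[i], self.total_requests[i]))
        self.total_requests[target] += 1  # heuristic increment until the next report
        return target

budget = DPBudget(dp_size=2)
budget.total_tokens = [900, 300]  # assumed load: worker 1 is lighter on tokens
print(budget.dispatch(LoadBalanceMethod.TOTAL_TOKENS))    # -> 1
print(budget.dispatch(LoadBalanceMethod.TOTAL_REQUESTS))  # -> 0
```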
class DataParallelController:
    """A controller that dispatches requests to multiple DP workers."""

    def __init__(self, server_args, port_args, run_scheduler_process_func):
        self.load_balance_method = LoadBalanceMethod.from_str(
            server_args.load_balance_method
        )
        # ZMQ sockets to each DP worker
        self.workers = [None] * server_args.dp_size
        self.status = [True] * server_args.dp_size

        # Launch scheduler processes for each DP rank
        if server_args.enable_dp_attention:
            self.launch_dp_attention_schedulers(server_args, port_args)
        else:
            self.launch_dp_schedulers(server_args, port_args)
DP Attention Mode: When enable_dp_attention is set, all DP workers share the same tensor-parallel group but split the attention computation across workers. This requires careful synchronization -- the scheduler broadcasts work requests via the attention TP group and uses maybe_prepare_mlp_sync_batch() to coordinate MLP operations.

11. Key Data Structures Reference

11.1 Req -- Request Object

Field            | Type              | Description
-----------------|-------------------|------------
rid              | str               | Unique request identifier
origin_input_ids | List[int]         | Original tokenized input (with MM padding)
output_ids       | List[int]         | Generated token IDs so far
fill_ids         | List[int]         | origin_input_ids + output_ids (updated each round)
prefix_indices   | Tensor            | KV cache indices from prefix cache match
extend_input_len | int               | Number of new tokens to process in this extend
last_node        | TreeNode          | Last matched node in RadixCache
req_pool_idx     | int               | Index into req_to_token_pool
kv_committed_len | int               | Length of committed KV cache
is_chunked       | int               | Chunk counter (>0 means still chunking)
finished_reason  | BaseFinishReason  | None = still running; set when done
sampling_params  | SamplingParams    | Temperature, top_p, max_new_tokens, etc.
grammar          | BaseGrammarObject | Grammar constraint state (for structured output)

11.2 ScheduleBatch -- CPU-Side Batch

Field            | Type              | Description
-----------------|-------------------|------------
reqs             | List[Req]         | Requests in this batch
forward_mode     | ForwardMode       | EXTEND, DECODE, MIXED, IDLE, etc.
input_ids        | Tensor [N]        | Flattened input token IDs
req_pool_indices | Tensor [B]        | Indices into req_to_token_pool
seq_lens         | Tensor [B]        | Sequence lengths (GPU)
seq_lens_cpu     | Tensor [B]        | Sequence lengths (CPU mirror)
out_cache_loc    | Tensor [N]        | Output locations in KV pool
sampling_info    | SamplingBatchInfo | Batched sampling parameters
batch_is_full    | bool              | Skip prefill check if True
prefix_lens      | List[int]         | Cached prefix lengths (extend mode)
extend_lens      | List[int]         | New tokens per request (extend mode)

11.3 ModelWorkerBatch -- CPU-to-GPU Bridge

Field              | Type              | Description
-------------------|-------------------|------------
forward_mode       | ForwardMode       | Forward pass type
input_ids          | Tensor            | Token IDs for this batch
req_pool_indices   | Tensor            | Request pool indices
seq_lens           | Tensor            | Sequence lengths (GPU)
out_cache_loc      | Tensor            | KV cache output locations
extend_num_tokens  | int               | Total extend tokens in batch
extend_seq_lens    | List[int]         | Per-request extend sequence lengths
extend_prefix_lens | List[int]         | Per-request cached prefix lengths
sampling_info      | SamplingBatchInfo | Sampling parameters for GPU
multimodal_inputs  | List              | Multimodal data (images, video, audio)
lora_ids           | List[str]         | LoRA adapter IDs per request

11.4 TpModelWorker -- GPU Worker

class TpModelWorker(BaseTpWorker):
    """A tensor parallel model worker."""

    def __init__(self, server_args, gpu_id, tp_rank, ...):
        self.tp_rank = tp_rank
        self.gpu_id = gpu_id
        self._init_model_runner()         # Creates ModelRunner
        self.enable_overlap = not server_args.disable_overlap_schedule

    def forward_batch_generation(self, model_worker_batch, ...):
        # Convert ModelWorkerBatch -> ForwardBatch -> GPU forward
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        out = self.model_runner.forward(forward_batch, ...)
        # Sample next tokens
        next_token_ids = self.model_runner.sample(out.logits_output, forward_batch)
        return GenerationBatchResult(
            logits_output=out.logits_output,
            next_token_ids=next_token_ids,
        )

11.5 Running Status Summary

def init_running_status(self):
    self.waiting_queue: List[Req] = []
    # The running decoding batch for continuous batching
    self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
    # The current forward batch
    self.cur_batch: Optional[ScheduleBatch] = None
    # The last forward batch
    self.last_batch: Optional[ScheduleBatch] = None
    self.forward_ct = 0

Summary

Key Takeaways:
  • SGLang's overlap scheduling (enabled by default) pipelines CPU result processing with GPU forward passes, eliminating idle CPU bubbles between batches.
  • The FutureMap stores token IDs as negative-index futures that are resolved when the GPU forward completes, enabling the scheduler to prepare the next batch immediately.
  • get_next_batch_to_run() always prioritizes prefill over decode, using a PrefillAdder to fit requests within token budgets and chunk long prompts.
  • The SchedulePolicy supports cache-aware policies (LPM, DFS-weight) that prioritize requests with longer prefix cache hits, plus in-batch prefix caching to improve hit rates.
  • Chunked prefill splits long prompts across iterations, and mixed-chunk mode allows decode tokens to ride along with prefill, keeping decode latency low.
  • The DataParallelController distributes requests across DP workers using pluggable load-balancing strategies (round-robin, min-requests, min-tokens).