Life of a Request in SGLang

An end-to-end trace through the SGLang Runtime (SRT), from HTTP ingestion to streamed token delivery, with deep dives into every process, data structure, and innovation along the way.

Topics: Multi-Process Architecture, Overlap Scheduling, RadixCache Prefix Sharing, Extend Mode

Multi-Process Architecture

SGLang Runtime organizes itself into distinct processes communicating via ZeroMQ IPC sockets. The Engine class in entrypoints/engine.py orchestrates the launch of all subprocesses. From the docstring of the Engine class itself:

# From engine.py - Engine class docstring
class Engine(EngineBase):
    """
    The entry point to the inference engine.

    - The engine consists of three components:
        1. TokenizerManager: Tokenizes the requests and sends them
           to the scheduler.
        2. Scheduler (subprocess): Receives requests from the Tokenizer
           Manager, schedules batches, forwards them, and sends the
           output tokens to the Detokenizer Manager.
        3. DetokenizerManager (subprocess): Detokenizes the output
           tokens and sends the result back to the Tokenizer Manager.

    Note:
    1. The HTTP server, Engine, and TokenizerManager all run in the
       main process.
    2. Inter-process communication is done through IPC (each process
       uses a different port) via the ZMQ library.
    """

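When Data Parallel (DP) mode is enabled, a Data Parallel Controller process is launched to distribute requests across multiple Scheduler instances, each with its own GPU worker. That gives up to six distinct components. They map onto fewer OS processes, because the HTTP server and TokenizerManager share the main process and each TpModelWorker lives inside its Scheduler's process: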

Process | Source File | Runs In | Responsibility
--- | --- | --- | ---
HTTP Server | entrypoints/http_server.py | Main process | FastAPI endpoints, OpenAI-compatible API, request validation
TokenizerManager | managers/tokenizer_manager.py | Main process | Tokenization, chat template, multimodal processing, request state tracking
DP Controller | managers/data_parallel_controller.py | Subprocess (optional) | Routes requests to DP replicas based on load
Scheduler | managers/scheduler.py | Subprocess (per GPU) | Batching, RadixCache, memory management, KV cache allocation
TpModelWorker | managers/tp_worker.py | Same process as Scheduler | Tensor-parallel model forward, sampling
DetokenizerManager | managers/detokenizer_manager.py | Subprocess | Token ID to text conversion, incremental decoding
[Figure: SGLang process architecture. Main process: HTTP Server (FastAPI, /v1/chat/completions, entrypoints/http_server.py) and TokenizerManager (tokenize + chat template, multimodal processing, ReqState in rid_to_state[rid] with an asyncio.Event to await; managers/tokenizer_manager.py). Subprocesses (per DP rank): optional DP Controller (load-aware routing to DP ranks); Scheduler (batching, RadixCache prefix-sharing tree, KV memory allocation, event_loop_overlap; managers/scheduler.py) containing TpModelWorker (forward_batch_generation(), sampling + token selection; managers/tp_worker.py) and ModelRunner (forward() + CUDA graph on the GPU: attention, MLP, logits); DetokenizerManager (token IDs -> text, incremental decoding; managers/detokenizer_manager.py). Flow: request -> TokenizerManager -ZMQ-> Scheduler -> batch -> GPU; Scheduler -ZMQ-> DetokenizerManager -ZMQ-> TokenizerManager -> response.]

ZMQ IPC Channel Setup

Each process creates its ZMQ sockets during initialization. TokenizerManager.init_ipc_channels() sets up the tokenizer's two channels. It opens a PULL socket that receives results from the detokenizer and a PUSH socket that sends requests to the scheduler:

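# From tokenizer_manager.py - TokenizerManager.init_ipc_channels()
def init_ipc_channels(self, port_args: PortArgs):
    # asyncio-aware context (2 I/O threads): the main process runs an event loop
    context = zmq.asyncio.Context(2)
    # Last argument is the bind flag: True means this side binds the IPC
    # endpoint and the peer process connects to it.
    self.recv_from_detokenizer = get_zmq_socket(
        context, zmq.PULL, port_args.tokenizer_ipc_name, True
    )
    self.send_to_scheduler = get_zmq_socket(
        context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
    )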

The Scheduler sets up the matching ends. It opens a PULL socket for requests from the tokenizer and a PUSH socket toward the detokenizer:

# From scheduler.py - Scheduler.init_ipc_channels()
def init_ipc_channels(self, port_args: PortArgs):
    # Plain (blocking) context: the scheduler loop is synchronous
    context = zmq.Context(2)
    # bind=False: connect to endpoints that the peer processes bind
    self.recv_from_tokenizer = get_zmq_socket(
        context, zmq.PULL, port_args.scheduler_input_ipc_name, False
    )
    send_to_detokenizer = get_zmq_socket(
        context, zmq.PUSH, port_args.detokenizer_ipc_name, False
    )
    self.send_to_detokenizer = SenderWrapper(send_to_detokenizer)

Core Data Structures

A single request transforms through several data structure representations as it traverses the pipeline. The schedule_batch.py docstring describes this explicitly:

# From schedule_batch.py - Module docstring
"""
The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of ScheduleBatch that only contains data related
  to the model forward on GPU.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of
  GPU tensors.
"""

Data Structure | Owner | Location | Key Fields
--- | --- | --- | ---
GenerateReqInput | HTTP Server | io_struct.py | text, input_ids, sampling_params, image_data, stream
TokenizedGenerateReqInput | TokenizerManager | io_struct.py | input_text, input_ids, mm_inputs, sampling_params
Req | Scheduler | schedule_batch.py | rid, origin_input_ids, output_ids, fill_ids, kv_committed_len
ScheduleBatch | Scheduler | schedule_batch.py | reqs[], forward_mode, input_ids, seq_lens, out_cache_loc
ModelWorkerBatch | TpModelWorker | schedule_batch.py | forward_mode, input_ids, req_pool_indices, seq_lens, out_cache_loc
ForwardBatch | ModelRunner | forward_batch_info.py | GPU tensors for attention computation
BatchTokenIDOutput | Scheduler | io_struct.py | rids[], output_ids[], finished_reasons[]
BatchStrOutput | DetokenizerManager | io_struct.py | rids[], output_strs[], finished_reasons[]
ReqState | TokenizerManager | tokenizer_manager.py | out_list, finished, event (asyncio.Event), text, output_ids

End-to-End Request Flow

Step 1. HTTP Server Receives Request

entrypoints/http_server.py

A user sends a POST to /v1/chat/completions. FastAPI deserializes the JSON body into a ChatCompletionRequest, then the serving layer converts it to a GenerateReqInput -- the internal lingua franca for all generation requests.

# http_server.py
@app.post("/v1/chat/completions",
         dependencies=[Depends(validate_json_request)])
async def openai_v1_chat_completions(
    request: ChatCompletionRequest, raw_request: Request
):
    """OpenAI-compatible chat completion endpoint."""
    return await raw_request.app.state.openai_serving_chat.handle_request(
        request, raw_request
    )

The OpenAIServingChat.handle_request() applies the chat template, converts messages to a prompt string, builds GenerateReqInput, and calls tokenizer_manager.generate_request().

Step 2. TokenizerManager: Tokenize and Dispatch

managers/tokenizer_manager.py

The generate_request() method is an async generator that orchestrates the entire request lifecycle from the main process perspective:

# tokenizer_manager.py - TokenizerManager.generate_request()
async def generate_request(self, obj, request=None):
    self.auto_create_handle_loop()

    # Normalize the request
    obj.normalize_batch_and_arguments()
    self._set_default_priority(obj)
    self._validate_rid(obj)

    # ... log, validate lora ...

    async with self.model_update_lock.reader_lock:
        await self._validate_and_resolve_lora(obj)

        # Tokenize the request and send it to the scheduler
        if obj.is_single:
            tokenized_obj = await self._tokenize_one_request(obj)
            state = self.rid_to_state[obj.rid]
            self._send_one_request(tokenized_obj)
            async for response in self._wait_one_response(obj, state, request):
                yield response

Tokenization Detail

The _tokenize_one_request() method handles three input forms: raw text, pre-tokenized IDs, and input embeddings. For text, it calls the HuggingFace tokenizer. For multimodal inputs, it uses AsyncMMDataProcessor to parallelize image/video/audio preprocessing:

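# tokenizer_manager.py - _tokenize_one_request() (simplified)
async def _tokenize_one_request(self, obj):
    input_text = obj.text
    # is_cross_encoder_request is derived from obj earlier (elided);
    # the input_embeds branch is also elided here.
    if obj.input_ids is not None:
        input_ids = obj.input_ids  # pre-tokenized input: skip the tokenizer
    else:
        input_ids, token_type_ids = await self._tokenize_texts(
            input_text, is_cross_encoder_request
        )

    if self.mm_processor and obj.contains_mm_input():
        # Preprocess images/audio (other modalities elided)
        mm_inputs = await self.mm_data_processor.process(
            image_data=obj.image_data,
            audio_data=obj.audio_data,
            input_text_or_ids=(input_text or input_ids),
            request_obj=obj,
        )
    # ... build and return the TokenizedGenerateReqInput ...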

ZMQ Dispatch

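The tokenized request is sent to the scheduler over a ZMQ PUSH socket as a TokenizedGenerateReqInput. The request's ReqState holds an asyncio.Event. It is already registered in rid_to_state (note the lookup in generate_request() before the send). After dispatch, _wait_one_response() awaits that event for outputs: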

# tokenizer_manager.py - _send_one_request()
def _send_one_request(self, tokenized_obj):
    tokenized_obj.time_stats.set_api_server_dispatch_time()
    tokenized_obj = wrap_shm_features(tokenized_obj)
    self.send_to_scheduler.send_pyobj(tokenized_obj)  # ZMQ IPC

Step 3. Scheduler: Batch Formation and Memory Allocation

managers/scheduler.py

The Scheduler runs an event loop (either event_loop_normal or event_loop_overlap) that continuously: (a) receives incoming requests, (b) builds the next batch, (c) runs the forward pass, and (d) processes results.

Request Reception

When a TokenizedGenerateReqInput arrives via ZMQ, the scheduler's dispatcher routes it to handle_generate_request(), which creates a Req object:

# scheduler.py - handle_generate_request()
def handle_generate_request(self, recv_req: TokenizedGenerateReqInput):
    req = Req(
        recv_req.rid,
        recv_req.input_text,
        recv_req.input_ids,
        recv_req.sampling_params,
        return_logprob=recv_req.return_logprob,
        top_logprobs_num=recv_req.top_logprobs_num,
        stream=recv_req.stream,
        lora_id=recv_req.lora_id,
        eos_token_ids=self.model_config.hf_eos_token_id,
        # ... more fields ...
    )

RadixCache Prefix Matching

When the scheduler considers a waiting request for a prefill batch, it queries the RadixCache for the longest cached prefix of the request's tokens. This is the heart of SGLang's automatic prefix sharing innovation:

# radix_cache.py - RadixCache.match_prefix()
def match_prefix(self, params: MatchPrefixParams) -> MatchResult:
    """Find the longest cached prefix of key in the radix tree.

    The logical namespace for prefix matching is determined by both
    the token id sequence and the optional extra_key carried by
    RadixKey. Entries that share identical leading token ids but
    have *different* extra_key values are intentionally kept disjoint.
    """
    key = params.key
    key, _ = self.maybe_bigram_convert(key)

    if self.disable or len(key) == 0:
        return empty_match_result()

    if self.page_size != 1:
        page_aligned_len = len(key) // self.page_size * self.page_size
        key = key[:page_aligned_len]
    # ... walk the tree from the root and return the matched KV indices (elided) ...

SGLang Innovation: RadixCache

The RadixCache stores a radix tree where each node holds a segment of token IDs and their associated KV cache pool indices. When a new request arrives, the tree is walked from the root, matching tokens page-by-page. The matched prefix KV cache is reused directly -- the request only needs to compute KV for the unmatched suffix. This is what makes SGLang's "Extend" mode possible: instead of computing the full prompt from scratch (traditional prefill), only the new tokens after the cached prefix need processing.
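The walk-and-reuse idea can be made concrete with a toy radix tree. This is a minimal sketch that assumes page_size = 1 and uses Python lists for token IDs and KV indices. The class and method names are hypothetical, not SGLang's API. match_prefix() returns the KV indices of the longest cached prefix. insert() splits a segment when a new sequence diverges partway through it.

```python
from typing import Dict, List


class RadixNode:
    def __init__(self) -> None:
        self.children: Dict[int, "RadixNode"] = {}  # keyed by first token of the child's segment
        self.key: List[int] = []    # token IDs on the edge leading into this node
        self.value: List[int] = []  # one KV pool index per token in key


def common_len(a: List[int], b: List[int]) -> int:
    """Length of the shared leading run of a and b."""
    n = 0
    while n < min(len(a), len(b)) and a[n] == b[n]:
        n += 1
    return n


class TinyRadixCache:
    def __init__(self) -> None:
        self.root = RadixNode()

    def match_prefix(self, key: List[int]) -> List[int]:
        """Return the KV indices of the longest cached prefix of key."""
        node, matched = self.root, []
        while key and key[0] in node.children:
            child = node.children[key[0]]
            n = common_len(child.key, key)
            matched += child.value[:n]
            if n < len(child.key):  # diverged inside this segment: stop here
                break
            node, key = child, key[n:]
        return matched

    def insert(self, key: List[int], value: List[int]) -> None:
        """Cache key -> value, splitting a segment where the new key diverges."""
        node = self.root
        while key:
            child = node.children.get(key[0])
            if child is None:  # nothing shared from here on: attach a new leaf
                leaf = RadixNode()
                leaf.key, leaf.value = list(key), list(value)
                node.children[key[0]] = leaf
                return
            n = common_len(child.key, key)
            if n < len(child.key):
                # Split: a new middle node keeps the shared part of the segment.
                mid = RadixNode()
                mid.key, mid.value = child.key[:n], child.value[:n]
                child.key, child.value = child.key[n:], child.value[n:]
                mid.children[child.key[0]] = child
                node.children[key[0]] = mid
                child = mid
            node, key, value = child, key[n:], value[n:]


cache = TinyRadixCache()
cache.insert([1, 2, 3, 4], [10, 11, 12, 13])  # first request's tokens and KV slots
cache.insert([1, 2, 5], [10, 11, 20])         # shares tokens 1, 2 with the first
print(cache.match_prefix([1, 2, 3, 9]))       # [10, 11, 12]: only token 9 needs extend
```

The matched indices point at KV entries already on the GPU. Only the unmatched tail (here, token 9) would go through the forward pass.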

The RadixCache supports multiple eviction policies via a strategy pattern: LRU, LFU, FIFO, MRU, FILO, Priority, and SLRU.

Batch Construction: get_next_batch_to_run()

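After an extend (prefill) step, the scheduler merges the just-prefilled requests into the running batch. It then tries to build a new prefill batch from the waiting queue. If none can be formed, it runs another decode step over the running batch. The PrefillAdder determines how many new requests to admit based on available memory and the configured max_prefill_tokens budget: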

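# scheduler.py - get_next_batch_to_run() (simplified)
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
    # Merge the just-prefilled (extend) batch into the running batch
    if self.last_batch and self.last_batch.forward_mode.is_extend():
        self.last_batch.filter_batch(...)  # drop requests that already finished
        if not self.last_batch.is_empty():
            if self.running_batch.is_empty():
                self.running_batch = self.last_batch
            else:
                self.running_batch.merge_batch(self.last_batch)

    # Prefer a new prefill batch built from the waiting queue
    new_batch = self.get_new_batch_prefill()
    if new_batch is not None:
        return new_batch

    # Otherwise run one decode step over the running requests
    if not self.running_batch.is_empty():
        self.running_batch = self.update_running_batch(self.running_batch)
        if not self.running_batch.is_empty():
            return self.running_batch
    return None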
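The admission budget can be pictured as a greedy loop. This is a hypothetical, simplified stand-in for PrefillAdder. It assumes each waiting request is described only by its prompt length and matched prefix length. It ignores chunked prefill, memory reserved for running decodes, and priorities. The point it illustrates is that only uncached tokens count against the budgets.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class WaitingReq:
    rid: str
    input_len: int   # tokens in the prompt
    prefix_len: int  # tokens already cached in the RadixCache


def admit_prefill(waiting: List[WaitingReq], max_prefill_tokens: int,
                  free_kv_tokens: int) -> List[WaitingReq]:
    """Greedily admit requests while both budgets hold (hypothetical policy)."""
    admitted: List[WaitingReq] = []
    used = 0
    for req in waiting:
        extend = req.input_len - req.prefix_len  # only uncached tokens cost compute and memory
        if used + extend > max_prefill_tokens or extend > free_kv_tokens:
            break  # stop at the first request that does not fit, keeping queue order
        admitted.append(req)
        used += extend
        free_kv_tokens -= extend
    return admitted


queue = [WaitingReq("a", 1000, 900), WaitingReq("b", 3000, 0), WaitingReq("c", 200, 0)]
print([r.rid for r in admit_prefill(queue, max_prefill_tokens=2048, free_kv_tokens=10_000)])
# ['a']: 'b' needs 3000 new tokens, which exceeds the 2048-token budget
```

Request "a" is cheap despite its 1000-token prompt because 900 tokens are cached. With chunked prefill enabled, a real scheduler could admit part of "b" instead of stopping.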

KV Cache Memory Allocation

For each request in a batch, the scheduler allocates KV cache memory from the token pool. Extend requests only allocate for the uncached suffix -- the prefix is already in the RadixCache:

# schedule_batch.py / common.py - alloc_for_extend()
# Allocates out_cache_loc for only the new (extend) tokens:
# extend_num_tokens = total_seq_len - prefix_len (from radix cache match)
# This is what makes "extend" different from full "prefill"
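The arithmetic in those comments can be made concrete with a toy free-list allocator. TokenPool and toy_alloc_for_extend are hypothetical names for this sketch. SGLang's real allocators operate on GPU tensors and may hand out whole pages.

```python
from typing import List


class TokenPool:
    """Toy KV slot allocator handing out integer slot indices from a free list."""

    def __init__(self, size: int) -> None:
        self.free_slots: List[int] = list(range(size))

    def alloc(self, n: int) -> List[int]:
        if n > len(self.free_slots):
            raise MemoryError(f"need {n} slots, only {len(self.free_slots)} free")
        out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return out


def toy_alloc_for_extend(pool: TokenPool, seq_lens: List[int],
                         prefix_lens: List[int]) -> List[int]:
    """Allocate out_cache_loc slots for the uncached suffix of each request only."""
    extend_lens = [s - p for s, p in zip(seq_lens, prefix_lens)]
    extend_num_tokens = sum(extend_lens)  # sum of (total_seq_len - prefix_len)
    return pool.alloc(extend_num_tokens)  # slots laid out request after request


pool = TokenPool(16)
# Request A: 10 tokens, 8 cached; request B: 4 tokens, nothing cached.
out_cache_loc = toy_alloc_for_extend(pool, seq_lens=[10, 4], prefix_lens=[8, 0])
print(out_cache_loc)         # [0, 1, 2, 3, 4, 5]: 2 slots for A, 4 for B
print(len(pool.free_slots))  # 10
```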

Step 4. TpModelWorker and ModelRunner: GPU Forward Pass

managers/tp_worker.py and model_executor/model_runner.py

The scheduler calls run_batch(), which prepares a ModelWorkerBatch from the ScheduleBatch and dispatches it to the TpModelWorker:

# scheduler.py - run_batch() (simplified)
def run_batch(self, batch: ScheduleBatch) -> GenerationBatchResult:
    self.forward_ct += 1
    if self.is_generation:
        worker_batch_or_batch = batch.get_model_worker_batch()
        result = self.tp_worker.forward_batch_generation(worker_batch_or_batch)
    else:
        result = ...  # embedding / reward-model path (elided)
    return result

Inside TpModelWorker.forward_batch_generation(), the ModelWorkerBatch is converted to a ForwardBatch (GPU tensors), then the model runner executes the forward pass:

# tp_worker.py - TpModelWorker.forward_batch_generation()
def forward_batch_generation(self, model_worker_batch, ...):
    # Convert to ForwardBatch (GPU tensors)
    forward_batch = ForwardBatch.init_new(
        model_worker_batch, self.model_runner
    )

    # Run model forward
    out = self.model_runner.forward(forward_batch)
    logits_output, can_run_cuda_graph = out.logits_output, out.can_run_graph

    batch_result = GenerationBatchResult(
        logits_output=logits_output,
        can_run_cuda_graph=can_run_cuda_graph,
    )

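    # Sample next tokens. With overlap scheduling, sampling is deferred:
    # the scheduler launches it later via launch_batch_sample_if_needed().
    if not self.enable_overlap:
        batch_result.sample_output = self.model_runner.sample(...)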
    return batch_result

ModelRunner.forward() -- CUDA Graph or Eager

The model runner decides whether to use a captured CUDA graph (for decode) or run the model eagerly (for extend/prefill):

# model_runner.py - ModelRunner._forward_raw()
def _forward_raw(self, forward_batch, ...):
    can_run_graph = bool(
        forward_batch.forward_mode.is_cuda_graph()
        and self.graph_runner
        and self.graph_runner.can_run(forward_batch)
    )

    if can_run_graph:
        ret = self.graph_runner.replay(forward_batch)
        return ModelRunnerOutput(logits_output=ret, can_run_graph=True)

    # Eager forward path (used for extend/prefill)
    # Runs full transformer: embedding -> attention -> MLP -> logits
    ...
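A CUDA graph replays launches recorded for one fixed batch shape. A common scheme is therefore to capture graphs for a set of batch sizes and pad each decode batch up to the nearest captured size. The sketch below assumes that scheme. capture_bs and pick_graph_batch_size are illustrative names, not SGLang's exact API.

```python
import bisect
from typing import List, Optional


def pick_graph_batch_size(capture_bs: List[int], raw_bs: int) -> Optional[int]:
    """Smallest captured batch size >= raw_bs, or None to fall back to eager mode."""
    idx = bisect.bisect_left(capture_bs, raw_bs)  # capture_bs must be sorted ascending
    return capture_bs[idx] if idx < len(capture_bs) else None


capture_bs = [1, 2, 4, 8, 16, 24, 32]
print(pick_graph_batch_size(capture_bs, 5))   # 8: pad with 3 dummy rows, then replay
print(pick_graph_batch_size(capture_bs, 32))  # 32
print(pick_graph_batch_size(capture_bs, 40))  # None: too large, run eagerly
```

Padding wastes a little compute on dummy rows. In exchange, it avoids the per-kernel launch overhead that dominates small decode batches.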
SGLang Innovation: Extend Mode vs Standard Prefill

Traditional LLM serving engines process each new request by running a full "prefill" pass over all input tokens. SGLang introduces Extend mode -- when the RadixCache finds a matching prefix, only the unmatched suffix tokens go through the forward pass. The KV cache for the matched prefix is already in GPU memory and is reused directly.

This is implemented through the ForwardMode.EXTEND flag. Within one EXTEND batch, requests with different prefix-match lengths coexist as ragged tensors, with each request contributing a different number of tokens. The attention kernel handles these variable-length sequences using the batch-wide extend_num_tokens together with per-request seq_lens metadata. By default, decode requests run in separate DECODE batches. SGLang can fold them into an extend batch when mixed chunked prefill is enabled.
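A ragged EXTEND batch can be described by concatenating every request's new tokens and recording per-request lengths and offsets. seq_lens and extend_num_tokens are the names mentioned above. extend_seq_lens and extend_start_loc are assumed names for this sketch. It uses plain Python lists, not SGLang's tensor layout.

```python
from itertools import accumulate
from typing import List


def build_extend_metadata(fill_ids: List[List[int]], prefix_lens: List[int]) -> dict:
    """Flatten each request's uncached suffix into one ragged EXTEND batch."""
    input_ids: List[int] = []
    extend_seq_lens: List[int] = []
    for ids, prefix_len in zip(fill_ids, prefix_lens):
        suffix = ids[prefix_len:]  # tokens not covered by the cached prefix
        input_ids.extend(suffix)
        extend_seq_lens.append(len(suffix))
    return {
        "input_ids": input_ids,                      # flat, length extend_num_tokens
        "seq_lens": [len(ids) for ids in fill_ids],  # full length incl. prefix
        "extend_seq_lens": extend_seq_lens,          # new tokens per request
        # Offset of each request's first new token inside input_ids.
        "extend_start_loc": [0] + list(accumulate(extend_seq_lens))[:-1],
        "extend_num_tokens": sum(extend_seq_lens),
    }


meta = build_extend_metadata([[1, 2, 3, 4, 5], [7, 8, 9]], prefix_lens=[3, 0])
print(meta["input_ids"])         # [4, 5, 7, 8, 9]
print(meta["extend_start_loc"])  # [0, 2]
```

The attention kernel reads queries for the new tokens from the flat input. It then attends over each request's full seq_len of keys, which includes the cached prefix.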

Step 5. Scheduler: Process Results and Send Output

managers/scheduler.py

After the GPU forward pass returns, the scheduler processes the results: it appends the sampled token to each request's output_ids, checks stop conditions (EOS token, max length, stop strings), and for finished or streaming requests, constructs a BatchTokenIDOutput that is sent to the DetokenizerManager via ZMQ:

# scheduler.py - process_batch_result() sends to detokenizer
# The batch result contains sampled token IDs for each request.
# For streaming or finished requests, a BatchTokenIDOutput is
# constructed and sent via ZMQ to the detokenizer:
#   self.send_to_detokenizer.send_pyobj(output)
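The stop checks described above can be sketched as one function that returns a finish reason or None. It mirrors the FINISH_* categories used by SGLang but is a hypothetical simplification. Plain strings stand in for the BaseFinishReason classes, stop strings are searched in the decoded text, and the order of the checks is an assumption.

```python
from typing import List, Optional, Set


def check_finished(output_ids: List[int], decoded_text: str, eos_token_ids: Set[int],
                   max_new_tokens: int, stop_strs: List[str]) -> Optional[str]:
    """Return a finish reason, or None if the request should keep decoding."""
    if len(output_ids) >= max_new_tokens:
        return "length"           # cf. FINISH_LENGTH
    if output_ids and output_ids[-1] in eos_token_ids:
        return "matched_token"    # cf. FINISH_MATCHED_TOKEN
    for stop in stop_strs:
        if stop and stop in decoded_text:
            return "matched_str"  # cf. FINISH_MATCHED_STR
    return None


print(check_finished([5, 2], "Hi", {2}, 16, []))                    # matched_token
print(check_finished([5, 6], "Answer: 42\n\n", {2}, 16, ["\n\n"]))  # matched_str
print(check_finished([5, 6], "Answer", {2}, 16, ["\n\n"]))          # None
```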

Step 6. DetokenizerManager: Tokens to Text

managers/detokenizer_manager.py

The detokenizer runs a simple synchronous event loop, receiving BatchTokenIDOutput objects from the scheduler and converting token IDs to text using the HuggingFace tokenizer:

# detokenizer_manager.py - DetokenizerManager.event_loop()
def event_loop(self):
    """The event loop that handles requests"""
    while True:
        with self.soft_watchdog.disable():
            recv_obj = self.recv_from_scheduler.recv_pyobj()
        output = self._request_dispatcher(recv_obj)
        if output is not None:
            self.send_to_tokenizer.send_pyobj(output)
        self.soft_watchdog.feed()

The handle_batch_token_id_out() method performs incremental detokenization, maintaining a DecodeStatus per request to handle partial UTF-8 sequences and special tokens correctly:

# detokenizer_manager.py - handle_batch_token_id_out()
def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
    output_strs = self._decode_batch_token_id_output(recv_obj)
    return BatchStrOutput(
        rids=recv_obj.rids,
        finished_reasons=recv_obj.finished_reasons,
        output_strs=output_strs,
        output_ids=recv_obj.output_ids,
        prompt_tokens=recv_obj.prompt_tokens,
        completion_tokens=recv_obj.completion_tokens,
        cached_tokens=recv_obj.cached_tokens,
        # ... logprobs, metrics, etc. ...
    )

The DecodeStatus tracks the state of incremental decoding for each request:

# detokenizer_manager.py
@dataclasses.dataclass
class DecodeStatus:
    """Store the status of incremental decoding."""
    decoded_text: str
    decode_ids: List[int]
    surr_offset: int
    read_offset: int
    sent_offset: int = 0
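The offsets in DecodeStatus support a common incremental-decoding technique. Each step re-decodes a short window of recent tokens and emits only the new suffix. Text is held back while it ends in an incomplete UTF-8 sequence. The sketch below assumes a toy byte-level tokenizer (one token per UTF-8 byte) so it runs without a model. DecodeState and step are hypothetical names. The meanings given to surr_offset and read_offset are the usual ones for this technique, not a line-by-line copy of SGLang.

```python
from dataclasses import dataclass, field
from typing import List


def toy_decode(ids: List[int]) -> str:
    """Toy tokenizer: each token ID is one UTF-8 byte."""
    return bytes(ids).decode("utf-8", errors="replace")


@dataclass
class DecodeState:
    decode_ids: List[int] = field(default_factory=list)
    surr_offset: int = 0  # start of the context window re-decoded each step
    read_offset: int = 0  # text for tokens before this index was already emitted


def step(state: DecodeState, new_ids: List[int]) -> str:
    """Append new token IDs and return only the newly completed text."""
    state.decode_ids.extend(new_ids)
    prefix = toy_decode(state.decode_ids[state.surr_offset:state.read_offset])
    full = toy_decode(state.decode_ids[state.surr_offset:])
    if full.endswith("\ufffd"):
        return ""  # incomplete multi-byte character: wait for more tokens
    delta = full[len(prefix):]
    state.surr_offset, state.read_offset = state.read_offset, len(state.decode_ids)
    return delta


s = DecodeState()
print([step(s, [t]) for t in "hé!".encode("utf-8")])  # ['h', '', 'é', '!']
```

Decoding a small window instead of each token in isolation matters with real tokenizers. A token's rendering (leading spaces, merged bytes) can depend on its neighbors. One known limitation of this toy version is that a genuine U+FFFD in the output would stall emission.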

Step 7. TokenizerManager: Deliver Response

managers/tokenizer_manager.py

The BatchStrOutput arrives back at the TokenizerManager via ZMQ. A background loop matches each response to its ReqState (keyed by rid), updates the state, and signals the asyncio.Event. The _wait_one_response() coroutine that has been awaiting this event wakes up and yields the response to the FastAPI handler, which sends it to the client:

# tokenizer_manager.py - ReqState dataclass
@dataclasses.dataclass
class ReqState:
    """Store the state a request."""
    out_list: List[Dict]
    finished: bool
    event: asyncio.Event      # Signaled when output arrives
    obj: Union[GenerateReqInput, EmbeddingReqInput]
    time_stats: APIServerReqTimeStats
    text: str = ""
    output_ids: List[int] = field(default_factory=list)
    # ... logprobs fields ...

For streaming requests, the tokenizer manager yields a partial response each time new tokens arrive. The HTTP layer forwards each one to the client as a Server-Sent Event (SSE). For non-streaming requests, it waits until the finished flag is set and returns the complete response as a single JSON object.
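The handoff between the background receive loop and _wait_one_response() is an asyncio producer/consumer pattern built on a per-request Event. This minimal sketch assumes a simplified state with only out_list, finished, and event. MiniReqState, handle_output, and wait_one_response are hypothetical names. The real code also handles aborts, timeouts, and non-streaming aggregation.

```python
import asyncio
from dataclasses import dataclass, field
from typing import AsyncIterator, List


@dataclass
class MiniReqState:
    out_list: List[str] = field(default_factory=list)
    finished: bool = False
    event: asyncio.Event = field(default_factory=asyncio.Event)


def handle_output(state: MiniReqState, text: str, finished: bool) -> None:
    """Receive-loop side: store the new chunk and wake the waiting coroutine."""
    state.out_list.append(text)
    state.finished = finished
    state.event.set()


async def wait_one_response(state: MiniReqState) -> AsyncIterator[str]:
    """Request side: yield chunks as soon as they are signaled."""
    while True:
        await state.event.wait()
        # Snapshot chunks and the finished flag together (no await in between).
        chunks, state.out_list = state.out_list, []
        finished = state.finished
        state.event.clear()
        for chunk in chunks:
            yield chunk
        if finished:
            return


async def demo() -> List[str]:
    state = MiniReqState()  # created inside the running event loop

    async def fake_receive_loop() -> None:
        for i, piece in enumerate(["Hel", "lo", "!"]):
            await asyncio.sleep(0)  # give the consumer a chance to run
            handle_output(state, piece, finished=(i == 2))

    task = asyncio.create_task(fake_receive_loop())
    received = [chunk async for chunk in wait_one_response(state)]
    await task
    return received


print(asyncio.run(demo()))  # ['Hel', 'lo', '!']
```

Chunks that arrive while the consumer is busy accumulate in out_list and are drained on the next wake-up, so nothing is lost between signals.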

Overlap Scheduling

SGLang Innovation: CPU/GPU Overlap

SGLang's overlap scheduling is enabled by default (disable_overlap_schedule=False). It pipelines GPU execution with CPU-side batch preparation, so the CPU work for step N+1 happens while the GPU is computing step N. This hides scheduling, memory management, and sampling overhead behind GPU compute.
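An idealized two-stage pipeline model shows why the overlap pays off. Assume every step costs c ms of CPU work and g ms of GPU work, and ignore dependencies between consecutive steps. These are simplifying assumptions, not measurements of SGLang. Sequential execution takes n * (c + g). An ideal pipeline takes c + (n - 1) * max(c, g) + g.

```python
def sequential_ms(n: int, cpu_ms: float, gpu_ms: float) -> float:
    """Each step waits for the previous one: CPU and GPU never overlap."""
    return n * (cpu_ms + gpu_ms)


def pipelined_ms(n: int, cpu_ms: float, gpu_ms: float) -> float:
    """Two-stage pipeline: after the first CPU step, the slower stage sets the pace."""
    if n == 0:
        return 0
    return cpu_ms + (n - 1) * max(cpu_ms, gpu_ms) + gpu_ms


# 100 decode steps, 2 ms of scheduling work and 8 ms of GPU work per step.
print(sequential_ms(100, 2, 8))  # 1000
print(pipelined_ms(100, 2, 8))   # 802
```

With these numbers almost all CPU time is hidden, saving about 20%. Once c exceeds g, the CPU becomes the bottleneck and overlap can no longer hide it.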

Normal Event Loop (Baseline)

The non-overlapped loop is sequential: recv -> batch -> forward -> process results -> repeat.

# scheduler.py - event_loop_normal()
def event_loop_normal(self):
    """A normal scheduler loop."""
    while True:
        # Receive requests
        recv_reqs = self.recv_requests()
        self.process_input_requests(recv_reqs)

        # Get the next batch to run
        batch = self.get_next_batch_to_run()
        self.cur_batch = batch

        # Launch the current batch
        if batch:
            result = self.run_batch(batch)
            self.process_batch_result(batch, result)
        else:
            self.self_check_during_idle()

        self.last_batch = batch

Overlapped Event Loop

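The overlapped loop uses a result_queue to defer result processing. After launching the GPU forward pass for the current batch, the CPU processes the previous batch's results while the GPU is still busy. That work covers stop checking, output bookkeeping, and ZMQ dispatch to the detokenizer. Sampling for the current batch is launched only afterwards, because it can depend on state updated by the previous batch's results, such as grammar constraints: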

# scheduler.py - event_loop_overlap()
def event_loop_overlap(self):
    """A scheduler loop that overlaps CPU processing and GPU computation."""
    self.result_queue: Deque = deque()

    def pop_and_process():
        tmp_batch, tmp_result = self.result_queue.popleft()
        self.process_batch_result(tmp_batch, tmp_result)

    while True:
        recv_reqs = self.recv_requests()
        self.process_input_requests(recv_reqs)

        # Get the next batch
        batch = self.get_next_batch_to_run()
        disable_overlap_for_batch = self.is_disable_overlap_for_batch(batch)

        # If we must NOT overlap, process last batch NOW
        if disable_overlap_for_batch:
            pop_and_process()

        # Launch GPU forward for current batch
        if batch:
            batch_result = self.run_batch(batch)
            self.result_queue.append((batch.copy(), batch_result))
        else:
            batch_result = None  # nothing launched this iteration

        # Process LAST batch results while GPU runs CURRENT batch
        if self.last_batch:
            if not disable_overlap_for_batch:
                pop_and_process()

        # Sample (depends on grammar from last batch)
        if self.is_generation:
            self.launch_batch_sample_if_needed(batch_result)

        self.last_batch = batch

When Overlap is Disabled

The scheduler disables overlap for specific scenarios to avoid correctness issues or reduce TTFT:

# scheduler.py - is_disable_overlap_for_batch()
def is_disable_overlap_for_batch(self, batch):
    # For two consecutive prefill batches, disable overlap
    # to improve the TTFT of the first batch.
    batch_is_extend = batch and batch.forward_mode.is_extend()
    last_batch_is_extend = self.last_batch and self.last_batch.forward_mode.is_extend()

    disable = (
        envs.SGLANG_DISABLE_CONSECUTIVE_PREFILL_OVERLAP.get()
        and batch_is_extend
        and last_batch_is_extend
    )

    # ... further conditions that force non-overlapped processing (elided) ...
    return disable

Overlap Timing Diagram

[Figure: Normal (sequential): CPU: Batch N -> GPU: Forward N -> CPU: Process N -> CPU: Batch N+1 -> GPU: Forward N+1. Overlap (pipelined): the CPU runs Batch N, then Process N-1 + Batch N+1, then Process N + Batch N+2, while the GPU runs Forward N, Forward N+1, Forward N+2 concurrently, so CPU work is hidden behind GPU work.]

RadixCache: Automatic Prefix Sharing

The RadixCache is a radix tree (compressed trie) where each node stores a token sequence segment and a tensor of KV cache pool indices. Matching a new request walks down from the root one segment at a time. The cost is O(n) in the length n of the matched prefix rather than in the number of cached sequences.

TreeNode Structure

# radix_cache.py - TreeNode
class TreeNode:
    def __init__(self, id=None, priority=0):
        self.children = defaultdict(TreeNode)
        self.parent: TreeNode = None
        self.key: RadixKey = None       # Token IDs for this segment
        self.value: torch.Tensor = None # KV cache pool indices
        self.lock_ref = 0               # Reference count (in-use)
        self.last_access_time = time.monotonic()
        self.hit_count = 0
        self.host_value: torch.Tensor = None  # HiCache offload
        self.hash_value: List[str] = None      # SHA256 page hashes
        self.priority = priority

Key Matching

When page_size > 1, KV cache memory is managed in pages, so a match only counts in whole pages. The function compares page_size-token chunks and stops at the first chunk that differs:

# radix_cache.py
def _key_match_paged(key0: RadixKey, key1: RadixKey, page_size: int):
    _check_extra_key(key0, key1)
    min_len = min(len(key0), len(key1))
    i = 0
    while i < min_len:
        if key0.token_ids[i : i + page_size] != key1.token_ids[i : i + page_size]:
            break
        i += page_size
    return i
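To see the page granularity in isolation, here is a standalone version operating on plain lists. The real function takes RadixKey objects and also checks extra_key. The sketch also applies the page-alignment truncation from match_prefix(), so a partially filled trailing page never counts as a hit. page_align is a hypothetical helper name.

```python
from typing import List


def key_match_paged(key0: List[int], key1: List[int], page_size: int) -> int:
    """Number of leading tokens shared by key0 and key1, counted in whole pages."""
    min_len = min(len(key0), len(key1))
    i = 0
    while i < min_len:
        if key0[i:i + page_size] != key1[i:i + page_size]:
            break
        i += page_size
    return i


def page_align(key: List[int], page_size: int) -> List[int]:
    """Drop the trailing partial page, as match_prefix() does when page_size != 1."""
    return key[: len(key) // page_size * page_size]


cached = [1, 2, 3, 4, 5, 6, 7, 8]
query = page_align([1, 2, 3, 4, 5, 6, 9, 9, 9], page_size=4)  # keeps the first 8 tokens
print(key_match_paged(cached, query, page_size=4))
# 4: the second page differs in its last two tokens, so the partial match on 5, 6 is discarded
```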

Eviction Strategies

# radix_cache.py - RadixCache.__init__()
if self.eviction_policy == "lru":
    self.eviction_strategy = LRUStrategy()
elif self.eviction_policy == "lfu":
    self.eviction_strategy = LFUStrategy()
elif self.eviction_policy == "fifo":
    self.eviction_strategy = FIFOStrategy()
elif self.eviction_policy == "priority":
    self.eviction_strategy = PriorityStrategy()
elif self.eviction_policy == "slru":
    self.eviction_strategy = SLRUStrategy()
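Whatever the policy, eviction in a radix cache typically has the same shape. Only unreferenced leaves (lock_ref == 0) can be freed, and candidates are ordered by the policy's key. A parent that becomes a childless leaf joins the candidates. The sketch below implements LRU under those assumptions. CacheNode mirrors a few TreeNode fields, and evict_lru is a hypothetical function, not SGLang's API.

```python
import heapq
import itertools
from typing import Dict, List, Optional


class CacheNode:
    def __init__(self, parent: Optional["CacheNode"], value: List[int],
                 last_access: float) -> None:
        self.parent = parent
        self.children: Dict[int, "CacheNode"] = {}
        self.value = value                  # KV pool indices owned by this node
        self.lock_ref = 0                   # > 0 while a running request uses this node
        self.last_access_time = last_access


def evict_lru(leaves: List[CacheNode], num_tokens: int) -> List[int]:
    """Free at least num_tokens KV slots, least recently used unlocked leaves first."""
    tie = itertools.count()  # tiebreaker so heapq never has to compare nodes
    heap = [(n.last_access_time, next(tie), n) for n in leaves if n.lock_ref == 0]
    heapq.heapify(heap)
    freed: List[int] = []
    while heap and len(freed) < num_tokens:
        _, _, node = heapq.heappop(heap)
        freed.extend(node.value)
        parent = node.parent
        parent.children = {k: c for k, c in parent.children.items() if c is not node}
        # A non-root parent that just lost its last child becomes a candidate.
        if parent.parent is not None and not parent.children and parent.lock_ref == 0:
            heapq.heappush(heap, (parent.last_access_time, next(tie), parent))
    return freed


root = CacheNode(None, [], 0.0)
a = CacheNode(root, [0, 1], 1.0)
root.children[10] = a
b = CacheNode(a, [2, 3], 5.0)
a.children[20] = b
c = CacheNode(a, [4], 3.0)
a.children[30] = c
print(evict_lru([b, c], num_tokens=2))  # [4, 2, 3]: c is older than b; a keeps its KV
```

Eviction works on whole nodes, so it may free more slots than requested. Swapping the heap key (hit count, insertion time, priority) is what turns the same loop into LFU, FIFO, or priority eviction.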

Extra Key Isolation

The RadixKey carries an optional extra_key field that isolates cache entries for different LoRA adapters, cache salts, or other logical namespaces:

# radix_cache.py - RadixKey
class RadixKey:
    def __init__(self, token_ids, extra_key=None, is_bigram=False):
        self.token_ids = token_ids      # Token IDs sequence
        self.extra_key = extra_key      # e.g. lora_id, cache_salt
        self.is_bigram = is_bigram      # For EAGLE speculative decoding

The Req Class: Per-Request State Machine

Every in-flight request in the scheduler is represented by a Req object. It tracks the full lifecycle from initial token IDs through decode iterations to finish conditions:

# schedule_batch.py - Req class (key fields)
class Req(ReqDllmMixin):
    """The input and output status of a request."""

    def __init__(self, rid, origin_input_text, origin_input_ids,
                 sampling_params, ...):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids = origin_input_ids
        self.output_ids = []
        self.fill_ids = []  # origin_input_ids + output_ids

        # KV cache memory management
        self.kv_committed_len = 0
        self.kv_allocated_len = 0

        # Prefix info (set during scheduling)
        self.prefix_indices = None  # From RadixCache match
        self.extend_input_len = 0   # Tokens to process (past prefix)

        # Finish conditions
        self.finished_reason = None
        self.sampling_params = sampling_params

Field | Type | Purpose
--- | --- | ---
rid | str | Unique request ID (UUID hex)
origin_input_ids | List[int] | Full tokenized prompt (after padding for multimodal)
output_ids | List[int] | Accumulated generated token IDs
fill_ids | List[int] | origin_input_ids + output_ids (for chunked prefill tracking)
kv_committed_len | int | How many tokens have committed KV cache entries
kv_allocated_len | int | How many KV cache slots are allocated for this request
prefix_indices | Tensor | KV cache indices from RadixCache prefix match
extend_input_len | int | Number of new tokens to process (total - prefix)
finished_reason | BaseFinishReason | None while running; set on EOS/length/stop/abort

Finish Conditions

# schedule_batch.py - Finish reason classes (bodies elided)
class FINISH_MATCHED_TOKEN(BaseFinishReason): ...  # EOS or stop token matched
class FINISH_MATCHED_STR(BaseFinishReason): ...    # Stop string matched
class FINISH_LENGTH(BaseFinishReason): ...         # max_new_tokens reached
class FINISH_ABORT(BaseFinishReason): ...          # Error or cancellation

ScheduleBatch and Forward Modes

The ScheduleBatch is the master batch container that the scheduler manages. It holds lists of requests and all the tensor metadata needed to drive a forward pass:

# schedule_batch.py - ScheduleBatch (key fields)
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    """Store all information of a batch on the scheduler."""

    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    forward_mode: ForwardMode = None
    enable_overlap: bool = False
    batch_is_full: bool = False

    # Batched arguments to model runner
    input_ids: torch.Tensor = None     # shape: [total_tokens], int64
    req_pool_indices: torch.Tensor = None  # shape: [batch_size], int64
    seq_lens: torch.Tensor = None     # shape: [batch_size], int64
    out_cache_loc: torch.Tensor = None # shape: [total_new_tokens], int64

ForwardMode Enum

Mode | When Used | Description
--- | --- | ---
EXTEND | New request (partial or full prompt) | Process multiple tokens per request. Ragged batch where each request may contribute a different number of tokens. This is SGLang's "prefill", but only for uncached tokens.
DECODE | Continuing generation | Process exactly 1 token per request. All requests produce the next token. CUDA graph eligible for maximum throughput.
IDLE | No work to do | Scheduler loop iteration with no active requests.

Process Launch Sequence

The Engine._launch_subprocesses() classmethod orchestrates all process creation. The launch order ensures dependencies are satisfied:

# engine.py - Engine._launch_subprocesses() (simplified)
@classmethod
def _launch_subprocesses(cls, server_args, ...):
    """Launch TokenizerManager in main process,
    Scheduler in a subprocess, DetokenizerManager in another."""

    # 1. Allocate IPC ports
    port_args = PortArgs.init_new(server_args)

    # 2. Launch scheduler subprocess(es)
    #    - If dp_size > 1, launches DP controller + N schedulers
    #    - Each scheduler creates its TpModelWorker internally
    scheduler_init_result = cls._launch_scheduler_processes(
        server_args, port_args, run_scheduler_process_func
    )

    # 3. Launch detokenizer subprocess
    detoken_proc = mp.Process(
        target=run_detokenizer_process_func,
        args=(server_args, port_args),
    )
    detoken_proc.start()

    # 4. Initialize TokenizerManager in main process
    tokenizer_manager, template_manager = init_tokenizer_manager_func(
        server_args, port_args
    )

    return tokenizer_manager, template_manager, port_args, scheduler_init_result

1. Port Allocation: PortArgs.init_new() allocates unique IPC paths for the scheduler-input, tokenizer, and detokenizer channels, plus the NCCL port.
2. Scheduler Subprocess: run_scheduler_process() creates the Scheduler. The Scheduler internally initializes TpModelWorker and ModelRunner, loads the model weights, profiles memory, captures CUDA graphs, and creates the RadixCache.
3. Detokenizer Subprocess: run_detokenizer_process() creates DetokenizerManager with its own tokenizer instance and ZMQ sockets.
4. TokenizerManager (Main Process): init_tokenizer_manager() initializes the tokenizer, processor, and ZMQ sockets in the main process alongside the HTTP server.
5. HTTP Server: uvicorn starts the FastAPI app in the same main process as the TokenizerManager.

Scheduler Initialization

The Scheduler constructor performs a complex multi-step initialization that sets up the entire GPU-side infrastructure. Here is the initialization sequence from the actual code:

# scheduler.py - Scheduler.__init__() initialization steps
class Scheduler(...):
    def __init__(self, server_args, port_args, gpu_id, tp_rank, ...):
        self.init_model_config()
        self.init_metrics(tp_rank, pp_rank, dp_rank)
        self.init_ipc_channels(port_args)
        self.init_tokenizer()
        self.init_moe_gemm_config()
        self.init_mamba_backend()

        # This launches TpModelWorker + ModelRunner
        self.init_model_worker()

        # Memory pool + RadixCache creation
        self.init_cache_with_memory_pool()

        self.init_running_status()
        self.init_chunked_prefill()
        self.init_schedule_policy()
        self.init_disaggregation()
        self.init_overlap()
        self.init_request_dispatcher()

        # Constrained generation (grammar-guided decoding)
        self.grammar_manager = GrammarManager(self)

Cache Creation

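The scheduler selects the cache implementation from the server configuration. The default is RadixCache. ChunkCache, which does no prefix sharing, is used when the radix cache is disabled under chunked prefill. Specialized variants cover hierarchical caching, hybrid SWA models, and LMCache integration: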

# scheduler.py - init_cache_with_memory_pool() (selection logic)
if chunked_prefill and disable_radix_cache:
    self.tree_cache = ChunkCache(params)
elif enable_hierarchical_cache:
    self.tree_cache = HiRadixCache(params, server_args)
elif is_hybrid_swa:
    self.tree_cache = SWARadixCache(params)
elif enable_lmcache:
    self.tree_cache = LMCRadixCache(params, ...)
else:
    self.tree_cache = RadixCache(params)  # Default
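The selection logic above is an if/elif chain, so branch order decides which cache wins when several flags are set. Below is a runnable restatement of only the branches shown in the excerpt. `CacheConfig` is a hypothetical flat config (the real flags live on the server arguments), and the function returns class names instead of constructing caches.

```python
from dataclasses import dataclass


@dataclass
class CacheConfig:
    # Hypothetical flat config mirroring the flags in the excerpt.
    chunked_prefill: bool = True
    disable_radix_cache: bool = False
    enable_hierarchical_cache: bool = False
    is_hybrid_swa: bool = False
    enable_lmcache: bool = False


def select_cache_class(cfg: CacheConfig) -> str:
    # Same precedence as the excerpt: the first matching branch wins.
    if cfg.chunked_prefill and cfg.disable_radix_cache:
        return "ChunkCache"      # no prefix sharing
    if cfg.enable_hierarchical_cache:
        return "HiRadixCache"
    if cfg.is_hybrid_swa:
        return "SWARadixCache"
    if cfg.enable_lmcache:
        return "LMCRadixCache"
    return "RadixCache"          # default
```

In this simplified chain, for example, `CacheConfig(enable_hierarchical_cache=True, is_hybrid_swa=True)` selects `HiRadixCache`, because that branch is checked first.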

TpModelWorker and ModelRunner

The TpModelWorker is instantiated inside the Scheduler and owns the ModelRunner. It bridges the scheduler's batch abstraction with the GPU execution layer:

# tp_worker.py - TpModelWorker.__init__()
class TpModelWorker(BaseTpWorker):
    def __init__(self, server_args, gpu_id, tp_rank, ...):
        self._init_model_config()
        self._init_model_runner()  # Creates ModelRunner

        # Profile available tokens
        self.max_total_num_tokens = self.model_runner.max_total_num_tokens
        self.max_running_requests = self.model_runner.max_running_requests
        self.max_req_len = min(
            self.model_config.context_len - 1,
            self.model_runner.max_token_pool_size - 1,
        )

    def _init_model_runner(self):
        self._model_runner = ModelRunner(
            model_config=self.model_config,
            mem_fraction_static=self.server_args.mem_fraction_static,
            gpu_id=self.gpu_id,
            tp_rank=self.tp_rank,
            tp_size=self.tp_size,
            # ... more distributed info ...
            server_args=self.server_args,
        )

ModelRunner.forward()

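ModelRunner.forward() is the last step before GPU kernels execute. When the batch qualifies, it replays a pre-captured CUDA graph. This is typically the case for decode batches, where replay removes per-kernel launch overhead. Otherwise it runs the model eagerly, as for extend (prefill) batches, whose variable-length shapes do not fit a pre-captured graph: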

# model_runner.py - ModelRunner.forward()
def forward(self, forward_batch, ...) -> ModelRunnerOutput:
    self.forward_pass_id += 1
    output = self._forward_raw(forward_batch, ...)
    return output

def _forward_raw(self, forward_batch, ...):
    can_run_graph = bool(
        forward_batch.forward_mode.is_cuda_graph()
        and self.graph_runner
        and self.graph_runner.can_run(forward_batch)
    )

    if can_run_graph:
        # CUDA Graph replay: fastest path for decode batches
        ret = self.graph_runner.replay(forward_batch)
        return ModelRunnerOutput(logits_output=ret, can_run_graph=True)

    # Eager forward: used for extend/prefill with variable lengths
    # Runs: Embedding -> N x (Attention + MLP) -> LogitsProcessor
    ...
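The `can_run_graph` check above is easier to reason about with a toy version. The sketch below assumes that graphs are captured for a fixed, sorted list of batch sizes, that a batch can replay only if it fits the largest captured size, and that it is padded up to the nearest captured size. This is a common CUDA-graph strategy, but `ToyGraphRunner`, `ToyBatch`, `capture_bs`, and `padded_bs` are hypothetical names, and no GPU work happens here.

```python
import bisect
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyBatch:
    is_decode: bool   # stand-in for forward_mode.is_cuda_graph()
    batch_size: int


class ToyGraphRunner:
    def __init__(self, capture_bs):
        # Batch sizes for which a graph was (hypothetically) captured.
        self.capture_bs = sorted(capture_bs)

    def can_run(self, batch: ToyBatch) -> bool:
        return batch.batch_size <= self.capture_bs[-1]

    def padded_bs(self, batch: ToyBatch) -> int:
        # Smallest captured size that is >= the real batch size.
        idx = bisect.bisect_left(self.capture_bs, batch.batch_size)
        return self.capture_bs[idx]


def dispatch(batch: ToyBatch, runner: Optional[ToyGraphRunner]):
    if batch.is_decode and runner is not None and runner.can_run(batch):
        return ("graph", runner.padded_bs(batch))
    return ("eager", batch.batch_size)
```

With captured sizes `[1, 2, 4, 8, 16]`, a decode batch of 3 replays the size-4 graph, with one padded slot. A decode batch of 17 and any extend batch fall back to eager execution.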

Request Lifecycle Summary

Stage | Process | Source File | Key Method | Data In | Data Out
1. HTTP | Main | http_server.py | openai_v1_chat_completions() | JSON body | GenerateReqInput
2. Tokenize | Main | tokenizer_manager.py | generate_request() | GenerateReqInput | TokenizedGenerateReqInput
3. Dispatch | Main | tokenizer_manager.py | _send_one_request() | TokenizedGenerateReqInput | ZMQ PUSH
4. Schedule | Scheduler | scheduler.py | handle_generate_request() | TokenizedGenerateReqInput | Req + waiting_queue
5. Batch | Scheduler | scheduler.py | get_next_batch_to_run() | waiting_queue + running_batch | ScheduleBatch
6. Forward | Scheduler | tp_worker.py | forward_batch_generation() | ModelWorkerBatch | GenerationBatchResult
7. GPU | Scheduler | model_runner.py | forward() / _forward_raw() | ForwardBatch | LogitsProcessorOutput
8. Sample | Scheduler | model_runner.py | sampler.forward() | Logits | Token IDs
9. Detokenize | Detokenizer | detokenizer_manager.py | handle_batch_token_id_out() | BatchTokenIDOutput | BatchStrOutput
10. Respond | Main | tokenizer_manager.py | _wait_one_response() | BatchStrOutput | HTTP JSON / SSE

Key SGLang Innovations

1. Extend Mode (Ragged Tensor Prefill)

Rather than always prefilling the entire prompt before decoding, SGLang runs the prefill phase in ForwardMode.EXTEND, which computes only the uncached suffix of each prompt. The RadixCache supplies the KV cache indices for the matched prefix, and the extend forward pass computes KV and logits only for the new tokens. This pays off most in multi-turn chat, where the system prompt and earlier conversation turns are reused on every request.

The extend batch is a ragged tensor: each request contributes a different number of new tokens, and these tokens are concatenated into one flat sequence rather than padded to a common length. The attention backend recovers the per-request boundaries from metadata such as seq_lens and extend_start_loc.
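To make the ragged layout concrete, the sketch below flattens the uncached suffixes of several requests and computes the start offset of each request's slice as an exclusive prefix sum of the per-request extend lengths. It assumes that `prefix_lens` (the number of tokens already matched in the RadixCache) is known per request. `build_extend_batch` and the `extend_seq_lens` key are illustrative names, and real metadata would be GPU tensors rather than Python lists.

```python
from itertools import accumulate


def build_extend_batch(token_ids, prefix_lens):
    """token_ids: full prompt token IDs per request.
    prefix_lens: how many leading tokens of each prompt are already cached."""
    input_ids, extend_seq_lens = [], []
    for toks, pre in zip(token_ids, prefix_lens):
        n_new = len(toks) - pre
        if n_new < 1:
            # At least one token must be computed to produce next-token logits.
            raise ValueError("each request needs at least one uncached token")
        extend_seq_lens.append(n_new)
        input_ids.extend(toks[pre:])          # only the uncached suffix
    # Exclusive prefix sum: where each request's slice starts in input_ids.
    extend_start_loc = list(accumulate(extend_seq_lens, initial=0))[:-1]
    return {
        "input_ids": input_ids,                       # flat, no padding
        "seq_lens": [len(t) for t in token_ids],      # prefix + new tokens
        "extend_seq_lens": extend_seq_lens,
        "extend_start_loc": extend_start_loc,
    }
```

For prompts of length 10, 6, and 8 with cached prefixes of 7, 0, and 5 tokens, only 3 + 6 + 3 = 12 tokens are computed, and the slices start at offsets 0, 3, and 9.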

2. Overlap Scheduling (CPU/GPU Pipelining)

event_loop_overlap() launches the GPU forward pass for batch N and, while that pass is running, processes the results of batch N-1 on the CPU. This hides CPU-side work such as sampling, stop-condition checking, RadixCache updates, ZMQ dispatch, and receiving new requests behind GPU compute time. The gain is largest when each GPU step is short, as with small decode batches, because CPU overhead would otherwise make up a large fraction of every step.
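The pipelining pattern itself can be sketched with a single-worker `ThreadPoolExecutor` standing in for the GPU stream. Submitting batch N returns immediately, so the loop can finish processing batch N-1 while N "runs". This is a toy model of the idea, not SGLang's loop. In particular, in the real system batch N needs batch N-1's sampled tokens before the CPU has seen them, a dependency this sketch ignores. `gpu_forward` and `overlap_loop` are hypothetical.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def gpu_forward(batch_id):
    time.sleep(0.01)                 # stand-in for GPU kernel time
    return f"tokens[{batch_id}]"


def overlap_loop(num_batches):
    processed = []
    # One worker = one in-order "GPU stream".
    with ThreadPoolExecutor(max_workers=1) as gpu:
        prev = None                  # (batch_id, future) of batch N-1
        for batch_id in range(num_batches):
            fut = gpu.submit(gpu_forward, batch_id)   # launch batch N, non-blocking
            if prev is not None:
                pid, pfut = prev
                # CPU-side result processing of N-1 overlaps with N on the "GPU".
                processed.append((pid, pfut.result()))
            prev = (batch_id, fut)
        if prev is not None:         # drain the last in-flight batch
            pid, pfut = prev
            processed.append((pid, pfut.result()))
    return processed
```

Results still come out in batch order. The difference from a naive loop is that the CPU never leaves the "GPU" idle while it post-processes.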

3. RadixCache Automatic Prefix Sharing

The RadixCache is a compressed trie (radix tree) keyed on token IDs that automatically identifies and shares KV cache prefixes across requests. When a new request matches a cached prefix, the KV cache entries for that prefix are reused without recomputation, and only the remaining suffix goes through the extend pass. The tree supports page-aligned matching (for efficient memory allocation), multiple eviction policies (LRU, LFU, FIFO, Priority, SLRU), and namespace isolation via extra_key for LoRA adapters or cache salts.
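The core data structure can be shown with a minimal compressed trie over token IDs. Each edge stores a run of tokens, and inserting a sequence that diverges partway along an edge splits that edge. The method names echo the RadixCache concepts of inserting and matching a prefix, but this `RadixTree` class is a hypothetical simplification. It stores no KV indices, has no eviction or page alignment, and returns only the matched length.

```python
def _common_len(a, b):
    # Length of the shared leading run of two token sequences.
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


class Node:
    def __init__(self):
        # first token of edge -> (edge token tuple, child node)
        self.children = {}


class RadixTree:
    def __init__(self):
        self.root = Node()

    def match_prefix(self, tokens):
        """Return how many leading tokens are already in the tree."""
        node, i = self.root, 0
        while i < len(tokens):
            entry = node.children.get(tokens[i])
            if entry is None:
                break
            edge, child = entry
            k = _common_len(edge, tokens[i:])
            i += k
            if k < len(edge):        # diverged (or ended) mid-edge
                break
            node = child
        return i

    def insert(self, tokens):
        tokens = tuple(tokens)
        node, i = self.root, 0
        while i < len(tokens):
            key = tokens[i]
            entry = node.children.get(key)
            if entry is None:        # no shared edge: add a new leaf edge
                node.children[key] = (tokens[i:], Node())
                return
            edge, child = entry
            k = _common_len(edge, tokens[i:])
            if k < len(edge):        # split the edge at the divergence point
                mid = Node()
                mid.children[edge[k]] = (edge[k:], child)
                node.children[key] = (edge[:k], mid)
                child = mid
            i += k
            node = child
```

After inserting `[1, 2, 3, 4]` and `[1, 2, 5]`, the shared `[1, 2]` is stored once, and a request for `[1, 2, 3, 4, 9]` matches 4 tokens. In the real cache, that match would yield the KV indices for those 4 tokens, leaving only token `9` for the extend pass.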

4. Multi-Process Architecture with ZMQ IPC

SGLang separates concerns across processes connected by ZMQ PUSH/PULL sockets. HTTP handling and tokenization run in the main process under asyncio, GPU scheduling and forward passes run in a dedicated scheduler subprocess (one per GPU), and detokenization runs in its own subprocess. This isolation prevents CPU-bound tokenization or detokenization from blocking GPU scheduling, and vice versa. For multi-GPU setups, a Data Parallel Controller distributes requests across replicas.
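The message flow can be sketched with pyzmq. Each receiving component owns a PULL socket, and senders connect PUSH sockets to it, forming the loop TokenizerManager → Scheduler → Detokenizer → TokenizerManager. To stay self-contained, this sketch runs all three endpoints in one thread over `inproc://` addresses; between real processes you would use an `ipc://` or `tcp://` endpoint instead. The endpoint names and message fields are hypothetical, and `ord`/`chr` stand in for a real tokenizer and model. The sketch requires the `pyzmq` package.

```python
import zmq


def demo_pipeline(prompt):
    ctx = zmq.Context()
    # Each receiver binds a PULL socket (hypothetical endpoint names).
    sched_in = ctx.socket(zmq.PULL); sched_in.bind("inproc://scheduler")
    detok_in = ctx.socket(zmq.PULL); detok_in.bind("inproc://detokenizer")
    tok_in = ctx.socket(zmq.PULL); tok_in.bind("inproc://tokenizer")
    # Each sender connects a PUSH socket to the next stage.
    to_sched = ctx.socket(zmq.PUSH); to_sched.connect("inproc://scheduler")
    to_detok = ctx.socket(zmq.PUSH); to_detok.connect("inproc://detokenizer")
    to_tok = ctx.socket(zmq.PUSH); to_tok.connect("inproc://tokenizer")
    sockets = (sched_in, detok_in, tok_in, to_sched, to_detok, to_tok)
    try:
        # TokenizerManager -> Scheduler: "tokenized" request (pickled by send_pyobj).
        to_sched.send_pyobj({"rid": "r0", "input_ids": [ord(c) for c in prompt]})
        req = sched_in.recv_pyobj()
        # Scheduler -> Detokenizer: "generated" token IDs (here just echoed).
        to_detok.send_pyobj({"rid": req["rid"], "output_ids": req["input_ids"]})
        out = detok_in.recv_pyobj()
        # Detokenizer -> TokenizerManager: text ready to stream to the client.
        to_tok.send_pyobj({"rid": out["rid"],
                           "text": "".join(chr(t) for t in out["output_ids"])})
        return tok_in.recv_pyobj()
    finally:
        for s in sockets:
            s.close(linger=0)
        ctx.term()
```

Because PUSH/PULL is one-directional, no stage waits for a reply from the next one, and each process can block on its own PULL socket independently.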

Source File Reference

File Path (relative to python/sglang/srt/) | Approx. Lines | Role
entrypoints/engine.py | ~700 | Engine class, subprocess launch orchestration
entrypoints/http_server.py | ~1300 | FastAPI endpoints, OpenAI/Anthropic/Ollama compatibility
managers/tokenizer_manager.py | ~1500 | Tokenization, multimodal preprocessing, request state
managers/scheduler.py | ~2800 | Core scheduler loop, batching, memory management
managers/schedule_batch.py | ~2600 | Req, ScheduleBatch, ModelWorkerBatch data classes
managers/tp_worker.py | ~600 | TpModelWorker: bridges scheduler to ModelRunner
model_executor/model_runner.py | ~3000 | ModelRunner: forward pass, CUDA graphs, weight loading
managers/detokenizer_manager.py | ~450 | Token ID to text, incremental decoding
managers/io_struct.py | ~900 | All inter-process data structures (GenerateReqInput, etc.)
mem_cache/radix_cache.py | ~800 | RadixCache: prefix tree for KV cache sharing