SGLang Deep Dive: ModelRunner, Attention Backends & Speculative Decoding

A comprehensive source-code walkthrough of SGLang's GPU execution engine -- from model loading and CUDA graph capture, through FlashInfer/FlashAttention kernel dispatch, to EAGLE speculative decoding and fused Mixture-of-Experts Triton kernels.

Source: sglang/python/sglang/srt/ | Updated: 2026-03-26 | Based on SGLang main branch

Table of Contents

  1. ModelRunner Initialization -- Model Loading, Memory, Backend Selection
  2. forward_extend() vs forward_decode() -- The Two Primary Forward Modes
  3. CUDA Graph Capture and Replay
  4. ForwardBatch -- Packing Tensors for the GPU
  5. FlashInfer Backend -- Memory Pool Integration
  6. FlashAttention 3/4 Backend -- Comparison with FlashInfer
  7. Sampling -- Temperature, Top-p, Top-k, Min-p
  8. EAGLE Speculative Decoding -- Draft/Verify Loop
  9. MoE Kernel Structure and Expert Parallelism
  10. Architecture Diagrams

1. ModelRunner Initialization

python/sglang/srt/model_executor/model_runner.py

ModelRunner is the central class that owns the model weights, KV cache memory pools, attention backend, CUDA graph runner, and sampler. It inherits from ModelRunnerKVCacheMixin which provides KV cache management methods.

1.1 Constructor Signature

class ModelRunner(ModelRunnerKVCacheMixin):
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        moe_ep_rank: int,
        moe_ep_size: int,
        pp_rank: int,
        pp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        dp_rank: Optional[int] = None,
        is_draft_worker: bool = False,
        req_to_token_pool: Optional[ReqToTokenPool] = None,
        token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
        ...
    ):

Key fields such as tp_rank, pp_rank, moe_ep_rank, mem_fraction_static, and server_args are stored on the instance during construction, before initialize() runs.

1.2 The initialize() Pipeline

The constructor calls self.initialize(pre_model_load_memory) which executes a strict sequence:

def initialize(self, pre_model_load_memory):
    # 1. Create sampler
    self.sampler = create_sampler()

    # 2. Load model weights (VLLM loader, safetensors, etc.)
    self.load_model()

    # 3. Configure KV cache dtype (fp8_e4m3, bf16, etc.)
    self.configure_kv_cache_dtype()

    # 4. Allocate KV cache memory pools
    self.init_memory_pool(pre_model_load_memory)

    # 5. Init attention backend (FlashInfer, FA3, Triton, etc.)
    self.init_attention_backend()

    # 6. Warmup kernels (FlashInfer autotune)
    self.kernel_warmup()

    # 7. Capture CUDA graphs for decode batch sizes
    self.init_device_graphs()

    # 8. Initialize piecewise CUDA graphs for extend
    self.init_piecewise_cuda_graphs()

1.3 Model Loading

def load_model(self):
    # Uses the configurable model loader system
    self.loader = get_model_loader(
        load_config=self.load_config,
        model_config=self.model_config,
    )
    self.model = self.loader.load_model(
        model_config=self.model_config,
        device_config=DeviceConfig(self.device, self.gpu_id),
    )
    # After loading: parse sliding_window_size, dtype, etc.
    self.sliding_window_size = None
    if hasattr(self.model, 'get_attention_sliding_window_size'):
        self.sliding_window_size = self.model.get_attention_sliding_window_size()

1.4 Attention Backend Selection

The backend is chosen via a registry pattern defined in attention_registry.py:

# attention_registry.py
ATTENTION_BACKENDS = {}

def register_attention_backend(name):
    def decorator(fn):
        ATTENTION_BACKENDS[name] = fn
        return fn
    return decorator

@register_attention_backend("flashinfer")
def create_flashinfer_backend(runner):
    if not runner.use_mla_backend:
        return FlashInferAttnBackend(runner)
    else:
        return FlashInferMLAAttnBackend(runner)

@register_attention_backend("fa3")
def create_flashattention_v3_backend(runner):
    return FlashAttentionBackend(runner)

@register_attention_backend("fa4")
def create_flashattention_v4_backend(runner):
    return FlashAttentionBackend(runner, fa_impl_ver=4)

@register_attention_backend("triton")
def create_triton_backend(runner): ...

@register_attention_backend("flashmla")
def create_flashmla_backend(runner): ...

@register_attention_backend("cutlass_mla")
def create_cutlass_mla_backend(runner): ...

The model runner dispatches to the selected backend:

def _get_attention_backend_from_str(self, backend_str, init_new_workspace=False):
    if backend_str not in ATTENTION_BACKENDS:
        raise ValueError(f"Invalid attention backend: {backend_str}")
    full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
    return attn_backend_wrapper(self, full_attention_backend)

Key insight: SGLang supports hybrid decode/prefill backends via HybridAttnBackend, allowing e.g. FlashMLA for decode + FlashInfer for prefill. The supported MLA backends are: ["aiter", "flashinfer", "fa3", "fa4", "triton", "flashmla", "cutlass_mla", "trtllm_mla", "ascend", "nsa"].
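The registry pattern above makes adding a backend a one-decorator operation. A minimal standalone sketch (DummyBackend and the lookup helper are illustrative, not SGLang's real classes):

```python
# Minimal sketch of the decorator-based registry pattern used by
# attention_registry.py. DummyBackend is a stand-in, not a real SGLang class.
ATTENTION_BACKENDS = {}

def register_attention_backend(name):
    def decorator(fn):
        ATTENTION_BACKENDS[name] = fn
        return fn
    return decorator

class DummyBackend:
    def __init__(self, runner):
        self.runner = runner

@register_attention_backend("dummy")
def create_dummy_backend(runner):
    # Factory receives the model runner so it can inspect config (MLA, dtype, ...)
    return DummyBackend(runner)

def get_attention_backend(backend_str, runner):
    if backend_str not in ATTENTION_BACKENDS:
        raise ValueError(f"Invalid attention backend: {backend_str}")
    return ATTENTION_BACKENDS[backend_str](runner)

backend = get_attention_backend("dummy", runner="fake-runner")
assert isinstance(backend, DummyBackend)
```

Registering at import time keeps the dispatch in `_get_attention_backend_from_str` a plain dictionary lookup, with no if/elif chain to maintain.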

2. forward_extend() vs forward_decode()

python/sglang/srt/model_executor/model_runner.py : lines 2507-2588

2.1 The Central _forward_raw() Dispatcher

def _forward_raw(self, forward_batch, skip_attn_backend_init, pp_proxy_tensors, ...):
    # Check if we can use CUDA graph (decode + graph runner available)
    can_run_graph = bool(
        forward_batch.forward_mode.is_cuda_graph()
        and self.graph_runner
        and self.graph_runner.can_run(forward_batch)
    )

    if can_run_graph:
        ret = self.graph_runner.replay(forward_batch, ...)
        return ModelRunnerOutput(logits_output=ret, can_run_graph=can_run_graph)

    # Fallback to eager execution
    if forward_batch.forward_mode.is_decode():
        ret = self.forward_decode(forward_batch, ...)
    elif forward_batch.forward_mode.is_extend(include_draft_extend_v2=True):
        ret, can_run_graph = self.forward_extend(forward_batch, ...)
    elif forward_batch.forward_mode.is_idle():
        ret = self.forward_idle(forward_batch, ...)
    elif forward_batch.forward_mode.is_split_prefill():
        ret = self.forward_split_prefill(forward_batch, ...)

2.2 forward_decode()

Decode processes one new token per request. It is the hot path that benefits most from CUDA graphs.

def forward_decode(self, forward_batch, skip_attn_backend_init=False, pp_proxy_tensors=None):
    if not skip_attn_backend_init:
        self.attn_backend.init_forward_metadata(forward_batch)
    return self.model.forward(
        forward_batch.input_ids,
        forward_batch.positions,
        forward_batch,
        **kwargs,
    )

2.3 forward_extend()

Extend (prefill) processes variable-length token sequences. It also checks if piecewise CUDA graphs can be used:

def forward_extend(self, forward_batch, skip_attn_backend_init=False, pp_proxy_tensors=None):
    can_run_graph = (
        self.piecewise_cuda_graph_runner is not None
        and self.piecewise_cuda_graph_runner.can_run(forward_batch)
    )

    if can_run_graph:
        return (
            self.piecewise_cuda_graph_runner.replay(forward_batch, **kwargs),
            can_run_graph,
        )

    if not skip_attn_backend_init:
        self.attn_backend.init_forward_metadata(forward_batch)

    return (
        self.model.forward(forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs),
        can_run_graph,
    )

2.4 forward_idle()

In data-parallel attention, some workers may have no sequences. IDLE batches are padded for MLP sync:

def forward_idle(self, forward_batch, pp_proxy_tensors=None):
    # Reinit metadata for padded IDLE batches
    if forward_batch.batch_size > 0:
        self.attn_backend.init_forward_metadata(forward_batch)
    return self.model.forward(forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs)

ForwardMode enum: EXTEND (prefill), DECODE (autoregressive), MIXED (chunked prefill), IDLE (DP no-op), TARGET_VERIFY (spec-decode verification), DRAFT_EXTEND (spec-decode draft), SPLIT_PREFILL (PD multiplexing), DLLM_EXTEND (diffusion LLM).

3. CUDA Graph Capture and Replay

python/sglang/srt/model_executor/cuda_graph_runner.py

3.1 CudaGraphRunner Constructor

class CudaGraphRunner:
    """Runs the forward pass with cuda graph and torch.compile."""

    def __init__(self, model_runner: ModelRunner):
        self.graphs = {}            # bs -> captured CUDAGraph
        self.output_buffers = {}    # bs -> LogitsProcessorOutput

        # Determine capture forward mode
        self.capture_forward_mode = ForwardMode.DECODE
        self.num_tokens_per_bs = 1
        if model_runner.spec_algorithm.is_eagle():
            self.capture_forward_mode = ForwardMode.TARGET_VERIFY
            self.num_tokens_per_bs = server_args.speculative_num_draft_tokens

        # Batch sizes to capture
        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
        # e.g., [1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, ...]

3.2 Batch Size Selection

def get_batch_sizes_to_capture(model_runner, num_tokens_per_bs=1):
    capture_bs = server_args.cuda_graph_bs
    num_max_requests = model_runner.req_to_token_pool.size

    # Filter: bs * tokens_per_bs must be multiple of attn_tp_size
    mul_base = 1
    if server_args.enable_two_batch_overlap:
        mul_base *= 2
    if require_gathered_buffer(server_args):
        mul_base *= get_attention_tp_size()

    capture_bs = [bs for bs in capture_bs if bs * num_tokens_per_bs % mul_base == 0]
    capture_bs = [bs for bs in capture_bs if bs <= num_max_requests]
    return sorted(set(capture_bs)), compile_bs
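The filtering logic above is simple enough to reproduce in isolation. A pure-Python sketch of the two filter passes (function name is illustrative):

```python
def filter_capture_bs(capture_bs, num_tokens_per_bs, mul_base, num_max_requests):
    # Keep batch sizes whose total token count is a multiple of mul_base
    bs_list = [bs for bs in capture_bs if bs * num_tokens_per_bs % mul_base == 0]
    # Drop batch sizes larger than the request pool can ever hold
    bs_list = [bs for bs in bs_list if bs <= num_max_requests]
    return sorted(set(bs_list))

# With two-batch overlap enabled (mul_base=2), odd batch sizes are dropped;
# sizes above the request pool capacity (16 here) are dropped too.
assert filter_capture_bs([1, 2, 4, 8, 16, 24], 1, 2, 16) == [2, 4, 8, 16]
```

Capturing only these sizes bounds CUDA graph memory while still covering the batch sizes that actually occur at runtime (smaller batches are padded up, as shown in replay_prepare()).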

3.3 Capture: capture_one_batch_size()

For each batch size, a full forward pass is captured into a CUDA graph:

def capture_one_batch_size(self, bs, forward, stream_idx=None):
    buffers = self.buffers  # DecodeInputBuffers -- preallocated GPU tensors
    graph = self._create_device_graph()
    num_tokens = bs * self.num_tokens_per_bs

    # Slice pre-allocated buffers to this batch size
    input_ids = buffers.input_ids[:num_tokens]
    req_pool_indices = buffers.req_pool_indices[:bs]
    seq_lens = buffers.seq_lens[:bs]
    out_cache_loc = buffers.out_cache_loc[:num_tokens]
    positions = buffers.positions[:num_tokens]

    # Create ForwardBatch for capture
    forward_batch = ForwardBatch(
        forward_mode=self.capture_forward_mode,
        batch_size=bs,
        input_ids=input_ids,
        req_pool_indices=req_pool_indices,
        seq_lens=seq_lens,
        out_cache_loc=out_cache_loc,
        ...
    )

    # Init attention backend for capture
    attn_backend.init_forward_metadata_capture_cuda_graph(
        bs, num_tokens, req_pool_indices, seq_lens, encoder_lens,
        forward_batch.forward_mode, forward_batch.spec_info,
    )

    # Warm up (2 iterations) then capture
    for _ in range(2):
        self.device_module.synchronize()
        run_once()

    # Actual capture
    out = self._capture_graph(graph, get_global_graph_memory_pool(), stream, run_once)
    return graph, out

3.4 Replay: replay()

def replay(self, forward_batch, skip_attn_backend_init=False, pp_proxy_tensors=None):
    if not skip_attn_backend_init:
        self.replay_prepare(forward_batch, pp_proxy_tensors)

    # Look up the graph for self.bs (rounded up to a captured size in replay_prepare)
    graph_key = self.bs
    self.graphs[graph_key].replay()
    output = self.output_buffers[graph_key]

    # Slice output to actual batch size (remove padding)
    return LogitsProcessorOutput(
        next_token_logits=output.next_token_logits[:self.raw_num_token],
        hidden_states=output.hidden_states[:self.raw_num_token] if ... else None,
    )

3.5 replay_prepare() -- Padding and Buffer Copy

def replay_prepare(self, forward_batch, pp_proxy_tensors=None):
    raw_bs = forward_batch.batch_size
    # Round up to nearest captured batch size
    index = bisect.bisect_left(self.capture_bs, raw_bs)
    bs = self.capture_bs[index]

    # Copy forward_batch tensors into pre-allocated buffers
    buffers.populate_from_forward_batch(
        forward_batch=forward_batch,
        raw_bs=raw_bs,
        bs=bs,
        seq_len_fill_value=self.seq_len_fill_value,
        ...
    )

    # Update attention backend metadata for replay
    attn_backend.init_forward_metadata_replay_cuda_graph(
        bs, buffers.req_pool_indices[:bs], buffers.seq_lens[:bs], ...
    )

Memory pool: All CUDA graph runners share a single global_graph_memory_pool to avoid allocating duplicate GPU memory. The pool handle is created on first capture via torch.cuda.graph_pool_handle().
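The round-up step in replay_prepare() is a standard bisect lookup. A self-contained sketch with an example captured-size list:

```python
import bisect

# Example list of captured batch sizes (see get_batch_sizes_to_capture)
capture_bs = [1, 2, 4, 8, 16, 24, 32]

def round_up_to_captured_bs(raw_bs):
    # bisect_left returns the position of the first captured bs >= raw_bs,
    # so a raw batch of 3 is padded up to the bs=4 graph.
    index = bisect.bisect_left(capture_bs, raw_bs)
    return capture_bs[index]

assert round_up_to_captured_bs(3) == 4    # padded with one dummy request
assert round_up_to_captured_bs(16) == 16  # exact hit, no padding
assert round_up_to_captured_bs(17) == 24
```

Padding wastes a little compute on dummy rows, but it lets a small, fixed set of graphs serve every batch size; the extra rows are sliced off when the output is returned.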

4. ForwardBatch -- Packing Tensors for the GPU

python/sglang/srt/model_executor/forward_batch_info.py

The data flow is: ScheduleBatch (CPU scheduler) → ModelWorkerBatch (TP worker) → ForwardBatch (GPU model runner).

4.1 ForwardBatch Fields

@dataclass
class ForwardBatch(ForwardBatchDeepSeekMHAMixin):
    """Store all inputs of a forward pass."""

    # Core fields
    forward_mode: ForwardMode        # EXTEND, DECODE, MIXED, IDLE, TARGET_VERIFY, ...
    batch_size: int
    input_ids: torch.Tensor         # [num_tokens] -- flattened token IDs
    req_pool_indices: torch.Tensor  # [batch_size] -- index into req_to_token_pool
    seq_lens: torch.Tensor          # [batch_size] -- current sequence lengths
    out_cache_loc: torch.Tensor     # [num_tokens] -- where to write KV cache
    seq_lens_sum: int               # sum of all seq_lens

    # Position encoding
    positions: torch.Tensor         # [num_tokens] -- absolute positions

    # Extend-specific (prefill)
    extend_num_tokens: Optional[int]           # total tokens being extended
    extend_seq_lens: Optional[torch.Tensor]    # [batch_size] -- new tokens per request
    extend_prefix_lens: Optional[torch.Tensor] # [batch_size] -- cached prefix per request
    extend_start_loc: Optional[torch.Tensor]   # [batch_size] -- start index in flattened

    # Attention backend reference
    attn_backend: Optional[AttentionBackend]
    req_to_token_pool: Optional[ReqToTokenPool]
    token_to_kv_pool: Optional[KVCache]

    # Speculative decoding
    spec_info: Optional[SpecInput]
    spec_algorithm: Optional[SpeculativeAlgorithm]

    # Sampling metadata
    sampling_info: Optional[SamplingBatchInfo]
    temperature: torch.Tensor
    top_p: torch.Tensor
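The relationship between the extend fields is easiest to see with concrete numbers. A pure-Python sketch of the flattening the scheduler performs before building a ForwardBatch (the helper name is hypothetical):

```python
def build_extend_metadata(extend_seq_lens, extend_prefix_lens):
    """Compute flattened absolute positions and per-request start offsets.

    Illustrative pure-Python version; the real code builds these as tensors."""
    positions, extend_start_loc = [], []
    offset = 0
    for new_tokens, prefix in zip(extend_seq_lens, extend_prefix_lens):
        extend_start_loc.append(offset)
        # Each new token's absolute position continues after the cached prefix
        positions.extend(range(prefix, prefix + new_tokens))
        offset += new_tokens
    return positions, extend_start_loc

# Request 0: 3 new tokens after a 4-token cached prefix; request 1: cold start.
positions, start_loc = build_extend_metadata([3, 2], [4, 0])
assert positions == [4, 5, 6, 0, 1]
assert start_loc == [0, 3]
```

This is why input_ids and positions are [num_tokens] (flattened across requests) while extend_seq_lens, extend_prefix_lens, and extend_start_loc stay [batch_size].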

4.2 DecodeInputBuffers -- Pre-allocated GPU Buffers

For CUDA graph capture, all input tensors must be pre-allocated. DecodeInputBuffers allocates maximally-sized tensors once:

@dataclass
class DecodeInputBuffers(ForwardInputBuffers):
    input_ids: torch.Tensor          # [max_num_token]
    req_pool_indices: torch.Tensor  # [max_bs]
    seq_lens: torch.Tensor          # [max_bs]
    out_cache_loc: torch.Tensor     # [max_num_token]
    positions: torch.Tensor         # [max_num_token]
    next_token_logits_buffer: torch.Tensor  # [max_num_token, vocab_size]

    @classmethod
    def create(cls, *, device, max_bs, max_num_token, hidden_size, vocab_size, dtype, ...):
        with torch.device(device):
            input_ids = torch.zeros((max_num_token,), dtype=torch.int64)
            req_pool_indices = torch.zeros((max_bs,), dtype=torch.int64)
            seq_lens = torch.full((max_bs,), seq_len_fill_value, dtype=torch.int32)
            ...

4.3 populate_from_forward_batch() -- Efficient Bulk Copy

During CUDA graph replay, data is copied from ForwardBatch into these buffers using a batched torch._foreach_copy_ grouped by dtype pairs:

def _grouped_foreach_copy_(dsts, srcs):
    """Call torch._foreach_copy_ grouped by (dst_dtype, src_dtype) pairs."""
    groups = {}
    for dst, src in zip(dsts, srcs):
        key = (dst.dtype, src.dtype)
        groups.setdefault(key, ([], []))
        groups[key][0].append(dst)
        groups[key][1].append(src)
    for group_dsts, group_srcs in groups.values():
        torch._foreach_copy_(group_dsts, group_srcs)

5. FlashInfer Backend

python/sglang/srt/layers/attention/flashinfer_backend.py

5.1 Architecture

FlashInfer uses a paged KV cache with a plan/execute split: index structures are built during a plan phase, then the attention kernels execute against them. As the constructor below shows, the backend manages three wrapper types: BatchPrefillWithRaggedKVCacheWrapper (prefill with no cached prefix), BatchPrefillWithPagedKVCacheWrapper (prefill against the paged KV cache), and BatchDecodeWithPagedKVCacheWrapper (decode).

5.2 Constructor

class FlashInferAttnBackend(AttentionBackend):
    def __init__(self, model_runner, skip_prefill=False, ...):
        self.decode_use_tensor_cores = should_use_tensor_core(
            kv_cache_dtype=model_runner.kv_cache_dtype,
            num_attention_heads=...,
            num_kv_heads=...,
        )
        self.max_context_len = model_runner.model_config.context_len

        # Allocate shared workspace buffer (default 256MB, up to 2GB for deterministic)
        global global_workspace_buffer
        if global_workspace_buffer is None:
            global_workspace_buffer = torch.empty(
                envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
                dtype=torch.uint8, device=model_runner.device,
            )
        self.workspace_buffer = global_workspace_buffer

        # Allocate kv_indptr for tracking page boundaries
        max_bs = model_runner.req_to_token_pool.size
        self.kv_indptr = [
            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
            for _ in range(self.num_wrappers)
        ]

        # Create wrappers
        self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
            self.workspace_buffer, "NHD", backend=fmha_backend
        )
        for _ in range(self.num_wrappers):
            self.prefill_wrappers_paged.append(
                BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
            )
            self.decode_wrappers.append(
                BatchDecodeWithPagedKVCacheWrapper(
                    self.workspace_buffer, "NHD",
                    use_tensor_cores=self.decode_use_tensor_cores,
                )
            )

5.3 init_forward_metadata()

This method dispatches based on the forward mode, updating wrapper indices:

def init_forward_metadata(self, forward_batch):
    if forward_batch.forward_mode.is_decode_or_idle():
        self.indices_updater_decode.update(
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            forward_batch.seq_lens_cpu,
            forward_batch.seq_lens_sum,
            decode_wrappers=self.decode_wrappers,
            ...
        )
        self.forward_metadata = DecodeMetadata(self.decode_wrappers)
    else:  # extend
        use_ragged = not self.enable_deterministic and not is_in_piecewise_cuda_graph()
        extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

        self.indices_updater_prefill.update(
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            forward_batch.seq_lens_cpu,
            forward_batch.seq_lens_sum,
            prefix_lens,
            prefill_wrappers=self.prefill_wrappers_paged,
            use_ragged=use_ragged,
            ...
        )
        self.forward_metadata = PrefillMetadata(
            self.prefill_wrappers_paged, use_ragged, extend_no_prefix
        )
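The kv_indptr buffers follow FlashInfer's CSR-style page-table layout: kv_indptr delimits each request's slice of a flattened kv_indices array of physical page ids. A hedged pure-Python sketch (function and argument names are illustrative):

```python
def build_kv_csr(seq_lens, req_pages, page_size):
    """Build CSR-style index arrays in the spirit of FlashInfer's layout.

    req_pages[i] lists the physical page ids owned by request i, in order.
    kv_indptr[i]:kv_indptr[i+1] delimits request i's slice of kv_indices."""
    kv_indptr, kv_indices = [0], []
    for seq_len, pages in zip(seq_lens, req_pages):
        num_pages = (seq_len + page_size - 1) // page_size  # ceil division
        kv_indices.extend(pages[:num_pages])
        kv_indptr.append(len(kv_indices))
    return kv_indptr, kv_indices

# Two requests, page_size=16: 20 tokens -> 2 pages, 16 tokens -> 1 page.
indptr, indices = build_kv_csr([20, 16], [[7, 3], [9]], page_size=16)
assert indptr == [0, 2, 3]
assert indices == [7, 3, 9]
```

This is what indices_updater_decode.update() refreshes on every step: sequence lengths change, so the indptr boundaries and page lists must be rewritten before the wrappers run.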

5.4 Wrapper Dispatch (Sliding Window / Cross Attention)

FlashInfer supports dual wrappers for hybrid attention models:

if model_runner.sliding_window_size is not None:
    self.num_wrappers = 2
    self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
elif model_runner.model_config.is_encoder_decoder:
    self.num_wrappers = 2
    self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
else:
    self.num_wrappers = 1

5.5 CUDA Graph State for FlashInfer

def init_cuda_graph_state(self, max_bs, max_num_tokens):
    # Pre-allocate KV indices buffer for CUDA graph
    cuda_graph_kv_indices = torch.zeros(
        (max_num_tokens * self.max_context_len,),
        dtype=torch.int32, device="cuda",
    )
    self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [
        cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
    ]

Performance note: For Qwen2/Qwen3 models, FlashInfer workspace is automatically increased to 512MB. For deterministic inference, it goes up to 2GB with tensor-core decode forced on.

6. FlashAttention 3/4 Backend

python/sglang/srt/layers/attention/flashattention_backend.py

6.1 Key Differences from FlashInfer

FlashInfer Plan/Execute

  • Uses FlashInfer library wrappers
  • Plan phase: builds page tables, kv_indptr on GPU
  • Supports ragged KV for zero-prefix extends
  • Native paged KV cache integration
  • Workspace buffer shared across all wrappers

FlashAttention 3/4 Direct Kernel Calls

  • Uses sgl_kernel.flash_attn (FA3) or flash_attention_v4 (FA4)
  • Page table based -- builds block_table tensors
  • Native local attention via virtual batch splitting
  • SM90 (Hopper) required for FA3; SM100+ (Blackwell) for FA4
  • Metadata stored in FlashAttentionMetadata dataclass

6.2 FlashAttentionBackend Constructor

class FlashAttentionBackend(AttentionBackend):
    def __init__(self, model_runner, skip_prefill=False,
                 speculative_step_id=0, topk=0,
                 speculative_num_steps=0, fa_impl_ver=3):
        self.forward_metadata: FlashAttentionMetadata = None
        self.max_context_len = model_runner.model_config.context_len
        self.decode_cuda_graph_metadata = {}
        # fa_impl_ver selects between FA3 and FA4 kernel implementations

6.3 FlashAttentionMetadata

@dataclass
class FlashAttentionMetadata:
    cache_seqlens_int32: torch.Tensor   # [batch_size] -- KV cache seq lens
    max_seq_len_q: int                  # max query length
    max_seq_len_k: int                  # max key length
    cu_seqlens_q: torch.Tensor         # cumulative query lengths
    cu_seqlens_k: torch.Tensor         # cumulative key lengths
    window_size: tuple                  # sliding window (-1, -1) for full
    page_table: torch.Tensor           # [batch_size, max_pages] -- page indices

    # Local attention metadata (for chunked models like Qwen-1M)
    @dataclass
    class LocalAttentionMetadata:
        local_query_start_loc: torch.Tensor
        local_seqused_k: torch.Tensor
        local_block_table: torch.Tensor
        local_max_query_len: int
        local_max_seq_len: int

6.4 Local Attention (Virtual Batch Splitting)

FlashAttention implements local attention by breaking sequences into "virtual batches" aligned to attn_chunk_size:

def make_local_attention_virtual_batches(
    attn_chunk_size, query_start_loc_np, seq_lens_np, block_table, page_size
):
    # Example: chunk_size=4, q_seqlens=[4,10,5], k_seqlens=[6,17,9]
    # Result:  seqlens_q_local=[2,2,1,4,4,1,4,1]  (8 virtual batches)
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
        q_seqlens
    )
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
    ...
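The worked example in the comment above can be reproduced with a short pure-Python sketch. This simplified version computes only seqlens_q_local and omits the block-table remapping the real function also produces:

```python
def virtual_seqlens_q(attn_chunk_size, q_seqlens, k_seqlens):
    """Split each request's query tokens into chunk-aligned virtual batches
    (simplified sketch of make_local_attention_virtual_batches)."""
    seqlens_q_local = []
    for q, k in zip(q_seqlens, k_seqlens):
        # Query tokens that fall into the (possibly partial) first chunk
        first = min(attn_chunk_size - ((k - q) % attn_chunk_size), q)
        seqlens_q_local.append(first)
        remaining = q - first
        while remaining > 0:
            # Every subsequent virtual batch is a full chunk, except the tail
            step = min(attn_chunk_size, remaining)
            seqlens_q_local.append(step)
            remaining -= step
    return seqlens_q_local

# Matches the example in the source comment:
assert virtual_seqlens_q(4, [4, 10, 5], [6, 17, 9]) == [2, 2, 1, 4, 4, 1, 4, 1]
```

Each virtual batch attends only within its own chunk, so an ordinary varlen attention kernel implements local attention with no masking changes.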

FA3 vs FA4: The backend selects the kernel implementation via fa_impl_ver parameter. FA3 uses sgl_kernel.flash_attn.flash_attn_varlen_func, while FA4 uses sglang.jit_kernel.flash_attention_v4.flash_attn_varlen_func. Both share the same FlashAttentionBackend class and metadata structure.

7. Sampling -- Temperature, Top-p, Top-k, Min-p

python/sglang/srt/layers/sampler.py

7.1 Sampler Architecture

class Sampler(nn.Module):
    def forward(self, logits_output, sampling_info, return_logprob, top_logprobs_nums, ...):
        logits = logits_output.next_token_logits

        # 1. Preprocess: custom logit processors + NaN detection
        logits = self._preprocess_logits(logits, sampling_info)

        # 2. Greedy fast path
        if sampling_info.is_all_greedy:
            batch_next_token_ids = torch.argmax(logits, -1)

        # 3. Standard sampling path
        else:
            # Apply temperature scaling (in-place division)
            logits.div_(sampling_info.temperatures)

            # Softmax to get probabilities (in-place for memory savings)
            logits[:] = torch.softmax(logits, dim=-1)
            probs = logits

            # Sample from probabilities
            batch_next_token_ids = self._sample_from_probs(
                probs, sampling_info, positions, simple_sampling_case
            )

7.2 FlashInfer Sampling Backend

def _sample_from_probs(self, probs, sampling_info, positions, simple_sampling_case):
    if simple_sampling_case:
        # Direct multinomial (no top-k/top-p/min-p)
        batch_next_token_ids = sampling_from_probs_torch(
            probs, sampling_seed=sampling_info.sampling_seed, positions=positions
        )
    else:
        backend = get_global_server_args().sampling_backend
        if backend == "flashinfer":
            if sampling_info.need_min_p_sampling:
                # Top-k renorm -> Top-p renorm -> Min-p sampling
                probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                batch_next_token_ids = min_p_sampling_from_probs(probs, sampling_info.min_ps)
            else:
                # Joint top-k + top-p sampling
                batch_next_token_ids = top_k_top_p_sampling_from_probs(
                    probs.contiguous(),
                    sampling_info.top_ks,
                    sampling_info.top_ps,
                    filter_apply_order="joint",
                )
        elif backend == "pytorch":
            # Fallback with torch native operations
            batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps, ...
            )
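To make the joint filtering semantics concrete, here is a pure-Python sketch of top-k followed by top-p (nucleus) filtering with renormalization. The real FlashInfer kernels fuse this with sampling on the GPU; this only illustrates which tokens survive:

```python
def top_k_top_p_filter(probs, top_k, top_p):
    """Keep the top-k entries, take the smallest prefix reaching top_p mass,
    and renormalize. Sketch of the filtering semantics only."""
    # Indices sorted by descending probability, truncated to top-k
    order = sorted(range(len(probs)), key=lambda i: -probs[i])[:top_k]
    nucleus, mass = [], 0.0
    for i in order:
        nucleus.append(i)
        mass += probs[i]
        if mass >= top_p:  # smallest set whose mass covers top_p
            break
    return {i: probs[i] / mass for i in nucleus}

filtered = top_k_top_p_filter([0.5, 0.3, 0.1, 0.1], top_k=3, top_p=0.8)
assert set(filtered) == {0, 1}
assert abs(filtered[0] - 0.625) < 1e-9  # 0.5 / 0.8 after renormalization
```

The "joint" filter_apply_order in the FlashInfer path means both constraints are applied within a single fused kernel rather than as two separate renormalization passes.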

7.3 Deterministic RL On-Policy Sampling

# For RL training: sample from log-softmax using Gumbel trick
def _sample_from_logprobs(self, logprobs, sampling_info, positions):
    """Deterministic sampling with seeded randomness."""
    assert sampling_info.sampling_seed is not None
    sampled_index = multinomial_with_seed(logprobs, sampling_info.sampling_seed, positions)
    return sampled_index.view(-1).to(torch.int32)
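The "Gumbel trick" referenced in the comment can be sketched in pure Python: adding independent Gumbel(0, 1) noise to log-probabilities and taking the argmax draws from the underlying distribution, and deriving the RNG state from (seed, position) makes the draw reproducible. The exact hashing scheme inside multinomial_with_seed is an assumption here:

```python
import math
import random

def gumbel_argmax_sample(logprobs, seed, position):
    """Seeded Gumbel-max sampling (sketch; the real helper's seed-mixing
    scheme may differ)."""
    # Mix seed and position into one integer so each position gets its own stream
    rng = random.Random(seed * 100003 + position)
    noisy = []
    for lp in logprobs:
        u = rng.random()
        gumbel = -math.log(-math.log(u))  # standard Gumbel(0, 1) noise
        noisy.append(lp + gumbel)
    return max(range(len(noisy)), key=noisy.__getitem__)

logprobs = [math.log(0.7), math.log(0.2), math.log(0.1)]
# Same seed and position -> identical token, which on-policy RL requires.
a = gumbel_argmax_sample(logprobs, seed=1234, position=0)
b = gumbel_argmax_sample(logprobs, seed=1234, position=0)
assert a == b
```

Determinism matters for RL because the trainer must be able to reproduce exactly the tokens the inference engine sampled when computing policy gradients.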

7.4 TP Synchronization

def _sync_token_ids_across_tp(self, batch_next_token_ids, sampling_info):
    # By default, SGLang skips this sync for performance.
    # Enable via SYNC_TOKEN_IDS_ACROSS_TP env var or when grammars are used.
    if SYNC_TOKEN_IDS_ACROSS_TP or sampling_info.grammars:
        torch.distributed.all_reduce(
            batch_next_token_ids,
            op=dist.ReduceOp.MIN,
            group=self.tp_sync_group,
        )

8. EAGLE Speculative Decoding

python/sglang/srt/speculative/eagle_worker.py

8.1 EAGLEWorker Structure

EAGLEWorker extends TpModelWorker and holds both a draft model (lightweight EAGLE head) and a reference to the target model:

class EAGLEWorker(TpModelWorker):
    def __init__(self, server_args, gpu_id, tp_rank, dp_rank, ..., target_worker):
        self.topk = server_args.speculative_eagle_topk
        self.speculative_num_steps = server_args.speculative_num_steps
        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
        self.target_worker = target_worker

        # Share req_to_token_pool and token_to_kv_pool_allocator with target
        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            target_worker.get_memory_pool()
        )

        # Init draft model as a TpModelWorker (is_draft_worker=True)
        super().__init__(
            server_args=server_args, ..., is_draft_worker=True,
            req_to_token_pool=self.req_to_token_pool,
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
        )

        # Share embeddings and lm_head from target model
        embed, head = self.target_worker.model_runner.model.get_embed_and_head()
        self.draft_model_runner.model.set_embed_and_head(embed, head)

8.2 The forward_batch_generation() Entry Point

def forward_batch_generation(self, batch):
    if batch.forward_mode.is_extend():
        # PREFILL: target extend -> draft extend (fill draft KV cache)
        logits_output, next_token_ids, seq_lens_cpu, can_run_graph = (
            self.forward_target_extend(batch)
        )
        self.forward_draft_extend(batch, logits_output.hidden_states, next_token_ids, ...)
        return GenerationBatchResult(logits_output=logits_output, ...)

    else:
        # DECODE: draft multi-step -> target verify -> accept/reject
        spec_info = self.draft(batch)                  # draft multi-step forward
        logits_output, verify_output, _, can_run_graph = (
            self.verify(batch, spec_info)               # target verify
        )
        self.forward_draft_extend_after_decode(batch)  # update draft state
        return GenerationBatchResult(
            next_token_ids=verify_output.verified_id,
            num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu),
        )
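The accept/reject step at the heart of the decode branch reduces, for a single draft chain under greedy decoding, to a longest-matching-prefix check. This sketch ignores tree branching and probabilistic acceptance, both of which the real verify() handles:

```python
def greedy_verify(draft_tokens, target_tokens):
    """Greedy accept/reject for one draft chain (tree verification
    generalizes this across sibling branches).

    target_tokens has len(draft_tokens) + 1 entries: the target model's own
    prediction at every draft position, plus one more after the last draft.
    Accept the longest matching prefix, then append the target's token at the
    first mismatch -- so at least one token is produced every step."""
    n = 0
    while n < len(draft_tokens) and draft_tokens[n] == target_tokens[n]:
        n += 1
    return draft_tokens[:n] + [target_tokens[n]]

# Draft guessed [5, 7, 9]; target agrees on 5 and 7 but predicts 2 next.
assert greedy_verify([5, 7, 9], [5, 7, 2, 4]) == [5, 7, 2]
# Worst case (no matches) still yields the target's one token, as in plain decode.
assert greedy_verify([8, 8], [1, 6, 3]) == [1]
```

Because the bonus token always comes from the target model, speculative decoding never produces output the target would not have produced itself; the draft model only changes how many tokens are confirmed per target forward pass.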

8.3 Draft Multi-Step Forward

def draft_forward(self, forward_batch):
    spec_info = forward_batch.spec_info
    out_cache_loc = forward_batch.out_cache_loc
    topk_p, topk_index, hidden_states = spec_info.topk_p, spec_info.topk_index, spec_info.hidden_states

    score_list, token_list, parents_list = [], [], []

    # Forward multiple speculative steps
    for i in range(self.speculative_num_steps):
        # Select top-k tokens from current step
        input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
            i, topk_p, topk_index, hidden_states, scores, self.topk
        )
        score_list.append(tree_info[0])
        token_list.append(tree_info[1])
        parents_list.append(tree_info[2])

        if i == self.speculative_num_steps - 1:
            break

        # Run draft model forward
        forward_batch.input_ids = input_ids
        forward_batch.out_cache_loc = out_cache_loc[i]
        forward_batch.positions.add_(1)
        forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]

        logits_output = self.draft_model_runner.forward(forward_batch, skip_attn_backend_init=True)
        probs = torch.softmax(logits_output.next_token_logits, dim=-1)
        topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
        hidden_states = logits_output.hidden_states

    # Organize results into tree structure
    parent_list, top_scores_index, draft_tokens = organize_draft_results(
        score_list, token_list, parents_list, self.speculative_num_draft_tokens
    )
    return parent_list, top_scores_index, draft_tokens

8.4 Tree Construction and Verification

# After draft, build the verification tree
(tree_mask, position, retrive_index, retrive_next_token,
 retrive_next_sibling, draft_tokens) = build_tree_kernel_efficient(
    spec_info.verified_id, parent_list, top_scores_index, draft_tokens,
    batch.seq_lens, batch.seq_lens_sum,
    self.topk, self.speculative_num_steps, self.speculative_num_draft_tokens,
)

# Verify: run target model on the tree
def verify(self, batch, spec_info):
    spec_info.prepare_for_verify(batch, self.page_size)
    batch.forward_mode = ForwardMode.TARGET_VERIFY
    batch.spec_info = spec_info

    # Forward with target model
    batch_result = self.target_worker.forward_batch_generation(model_worker_batch)
    logits_output = batch_result.logits_output

    # Verify: compare draft predictions with target logits
    res = spec_info.verify(batch, logits_output, self.token_to_kv_pool_allocator, ...)

    # Pick accepted indices
    logits_output.next_token_logits = logits_output.next_token_logits[res.accepted_indices]
    logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]

    # Update batch for next draft iteration
    batch.forward_mode = ForwardMode.DECODE
    batch.spec_info = res.draft_input
    return logits_output, res, model_worker_batch, can_run_cuda_graph

EAGLE3 vs EAGLE: EAGLE3 models have their own hot_token_id mapping built into the draft model config. They may also use use_aux_hidden_state to capture auxiliary hidden states from specific target model layers (configured via eagle_aux_hidden_state_layer_ids).

9. MoE Kernel Structure and Expert Parallelism

python/sglang/srt/layers/moe/fused_moe_triton/

9.1 Directory Structure

fused_moe_triton/
  __init__.py
  configs/                       # Kernel config tuning files per model
  fused_moe.py                   # Main fused MoE dispatch logic
  fused_moe_triton_config.py     # Config dtype selection and optimal config lookup
  fused_moe_triton_kernels.py    # Triton kernel definitions (@triton.jit)
  layer.py                       # FusedMoE nn.Module
  moe_align_block_size.py        # Token-to-expert alignment for blocked execution
  fused_marlin_moe.py            # Marlin-quantized MoE dispatch
  triton_kernels_moe.py          # Additional Triton MoE kernels

9.2 Core fused_moe.py Entry Point

# Register as custom op for torch.compile compatibility
@register_custom_op(mutates_args=["hidden_states"])
def inplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,            # gate/up projection weights [num_experts, 2*intermediate, hidden]
    w2: torch.Tensor,            # down projection weights [num_experts, hidden, intermediate]
    topk_weights: torch.Tensor,  # [num_tokens, topk] -- routing weights
    topk_ids: torch.Tensor,      # [num_tokens, topk] -- selected expert IDs
    activation: str = "silu",
    is_gated: bool = True,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    ...
):
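
For orientation, the fused kernel computes the same result as this unfused reference: each token's hidden state is routed to its topk experts, run through a gated MLP (SiLU gating shown), and the expert outputs are combined with the routing weights. A hedged sketch assuming the weight layouts noted in the comments above:

```python
import torch
import torch.nn.functional as F

def moe_reference(hidden_states, w1, w2, topk_weights, topk_ids):
    """Unfused reference for a SiLU-gated MoE with the layouts above:
    hidden_states [T, H], w1 [E, 2*I, H], w2 [E, H, I],
    topk_weights / topk_ids [T, topk]."""
    out = torch.zeros_like(hidden_states)
    num_tokens, topk = topk_ids.shape
    for t in range(num_tokens):
        for k in range(topk):
            e = int(topk_ids[t, k])
            gate_up = w1[e] @ hidden_states[t]   # [2*I]: gate and up stacked
            gate, up = gate_up.chunk(2)
            act = F.silu(gate) * up              # what silu_and_mul fuses
            out[t] += topk_weights[t, k] * (w2[e] @ act)
    return out
```

The fused kernel avoids these per-token loops by grouping tokens per expert (see moe_align_block_size below) and running blocked GEMMs.
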

9.3 Triton Kernel Highlights

# fused_moe_triton_kernels.py
# SM90 optimization: swap_ab for better performance with small M, large N
@functools.lru_cache(maxsize=8)
def should_enable_swap_ab(BLOCK_SIZE_M, BLOCK_SIZE_N):
    return is_sm90_supported() and BLOCK_SIZE_M < 64 and BLOCK_SIZE_N >= 64

# Activation fused kernels from sgl_kernel:
# - silu_and_mul    (SiLU gating)
# - gelu_and_mul    (GELU gating)
# - moe_sum_reduce  (reduce expert outputs)

# Quantization support:
# - per_token_group_quant_fp8   (FP8 token-group quantization)
# - per_token_group_quant_int8  (INT8 token-group quantization)
# - scaled_fp8_quant            (scaled FP8 quantization)
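
moe_align_block_size.py handles the layout step that makes blocked expert execution possible: the flattened (token, k) slots are grouped by expert and each expert's run is padded to a multiple of the M block size, so every Triton program instance processes one block belonging to exactly one expert. A minimal CPU sketch of the idea (the sentinel-padding scheme here is illustrative, not the actual kernel's):

```python
import torch

def align_block_size_ref(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    """Sort flattened (token, k) slots by expert and pad each expert's run to
    a multiple of block_size; return the sorted slot ids and, per GEMM block,
    which expert it belongs to."""
    flat = topk_ids.flatten()
    sorted_ids, expert_of_block = [], []
    for e in range(num_experts):
        slots = (flat == e).nonzero(as_tuple=True)[0].tolist()
        pad = (-len(slots)) % block_size
        slots += [flat.numel()] * pad  # sentinel = out-of-range slot, masked off
        sorted_ids += slots
        expert_of_block += [e] * (len(slots) // block_size)
    return torch.tensor(sorted_ids), torch.tensor(expert_of_block)
```
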

9.4 Expert Parallelism (EP)

SGLang supports multiple expert parallelism strategies. One runtime piece is EPLB (the Expert Parallelism Load Balancer), which rebalances expert placement when the set of active ranks changes:

# From model_runner.py -- EPLB rebalance on rank faults
elastic_ep_state = ElasticEPStateManager.instance()
if elastic_ep_state is not None and not elastic_ep_state.is_active_equal_last():
    elastic_ep_state.snapshot_active_to_last()
    elastic_ep_state.sync_active_to_cpu()
    gen = self.eplb_manager.rebalance()
    while True:
        try:
            next(gen)
        except StopIteration:
            break

10. Architecture Diagrams

Diagram (a): Forward Pass Pipeline

SGLang Forward Pass Pipeline

ScheduleBatch (CPU scheduler)
  → ModelWorkerBatch (TP worker)
  → ForwardBatch (GPU tensors)
  → _forward_raw() dispatcher
      ├─ is_cuda_graph() = True → CUDA graph replay (graph_runner.replay())
      └─ eager path → forward_decode()  -- 1 token/req, hot path
                      forward_extend()  -- prefill, var-length
                      forward_idle()    -- DP no-op padding
  → model.forward(input_ids, positions, batch)  -- attention + MLP layers
  → LogitsProcessor + Sampler
  → ModelRunnerOutput
Figure 1: Forward pass pipeline from ScheduleBatch through model execution to output

Diagram (b): CUDA Graph Capture and Replay

CUDA Graph Lifecycle

CAPTURE PHASE (server startup) -- init_device_graphs():
    for bs in [1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, ...]:
        1. Allocate DecodeInputBuffers (max_bs pre-allocated)
        2. attn_backend.init_forward_metadata_capture_cuda_graph()
        3. Warm up: run_once() x 2 iterations
        4. torch.cuda.CUDAGraph capture of run_once (pool=global_pool)
        5. Store: self.graphs[bs] = graph, self.output_buffers[bs] = out

REPLAY PHASE (every decode step):
    1. bisect_left(capture_bs, raw_bs) -- find padded bs
    2. populate_from_forward_batch() -- _foreach_copy_ to input buffers
    3. attn_backend.init_forward_metadata_replay_cuda_graph()
    4. self.graphs[bs].replay() -- replay recorded GPU kernels
    5. Slice output[:raw_num_token] -- remove padding

Memory pool: global_graph_memory_pool is shared across all CudaGraphRunners
and created via torch.cuda.graph_pool_handle().
Figure 2: CUDA graph capture (startup) vs replay (every decode step)
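
The padded-batch-size lookup in step 1 of the replay phase is a binary search over the sorted capture list; a minimal sketch:

```python
import bisect

def padded_capture_bs(capture_bs, raw_bs):
    """Find the smallest captured batch size >= raw_bs (the graph to replay).
    Returns None when raw_bs exceeds every captured size (fall back to eager)."""
    i = bisect.bisect_left(capture_bs, raw_bs)
    return capture_bs[i] if i < len(capture_bs) else None
```

Padding to the next captured size wastes a little compute on dummy slots, but lets one recorded graph serve a whole range of batch sizes.
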

Diagram (c): Attention Backend Selection Tree

Attention Backend Selection (attention_registry.py)

--attention-backend <name>:
    "flashinfer"  → FlashInferAttn (dense) or FlashInferMLA (use_mla_backend)
    "fa3"         → FlashAttentionBackend (SM80+/SM90)
    "fa4"         → FlashAttentionBackend (SM100+, B200)
    "triton"      → TritonAttn / DoubleSparse
    also: "flashmla", "cutlass_mla", "trtllm_mla", "aiter", "wave", "nsa",
          "torch_native", "flex_attention"
    HybridAttnBackend: different backends for decode vs prefill
        (e.g. flashmla decode + flashinfer prefill)

All backends implement the AttentionBackend ABC:
    init_forward_metadata() | init_cuda_graph_state()
    | forward_extend() | forward_decode()

MLA_ATTENTION_BACKENDS: flashinfer, fa3, fa4, triton, flashmla,
    cutlass_mla, trtllm_mla, aiter, ascend, nsa

SM requirements: FlashInfer SM75+; FA3 SM80+ (H100/H200); FA4 SM100+ (B200).
Figure 3: Attention backend selection tree based on --attention-backend flag and model architecture
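
The registry behind Figure 3 boils down to a name → factory map populated by a registration decorator. A simplified sketch of the pattern (illustrative; the real registry returns backend instances and also handles hybrid and MLA-specific selection):

```python
# Name → factory registry, filled in at import time by the decorator below.
ATTENTION_BACKENDS = {}

def register_attention_backend(name):
    def decorator(factory):
        ATTENTION_BACKENDS[name] = factory
        return factory
    return decorator

@register_attention_backend("flashinfer")
def create_flashinfer_backend(runner):
    # MLA models get the MLA-specialized wrapper (per Figure 3)
    return "FlashInferMLA" if runner.use_mla_backend else "FlashInferAttn"
```

ModelRunner then resolves the backend with a single dictionary lookup on the --attention-backend value.
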

Quick Reference: Key Files

Core Execution

  • model_executor/model_runner.py -- ModelRunner class, forward dispatch
  • model_executor/forward_batch_info.py -- ForwardBatch, ForwardMode enum
  • model_executor/cuda_graph_runner.py -- CudaGraphRunner, DecodeInputBuffers

Attention Backends

  • layers/attention/base_attn_backend.py -- AttentionBackend ABC
  • layers/attention/attention_registry.py -- Backend registry
  • layers/attention/flashinfer_backend.py -- FlashInfer
  • layers/attention/flashattention_backend.py -- FA3/FA4

Sampling & Speculative

  • layers/sampler.py -- Sampler class, temp/top-p/top-k
  • speculative/eagle_worker.py -- EAGLEWorker, draft/verify
  • speculative/eagle_utils.py -- Tree building

MoE Kernels

  • layers/moe/fused_moe_triton/fused_moe.py -- Dispatch
  • layers/moe/fused_moe_triton/fused_moe_triton_kernels.py -- Triton JIT
  • layers/moe/fused_moe_triton/layer.py -- FusedMoE Module

Appendix: ForwardMode Complete Enum

class ForwardMode(IntEnum):
    EXTEND = auto()           # Prefill / extend with cached prefix
    DECODE = auto()           # Decode one token per request
    MIXED = auto()            # Chunked prefill: both extend and decode
    IDLE = auto()             # No sequences (DP attention padding)
    TARGET_VERIFY = auto()    # Speculative: verify in target model
    DRAFT_EXTEND = auto()     # Speculative: extend in draft model
    DRAFT_EXTEND_V2 = auto()  # Eagle V2 fixed-shape logits
    PREBUILT = auto()         # Disagg decode: KV cache ready
    SPLIT_PREFILL = auto()    # PD multiplexing split prefill
    DLLM_EXTEND = auto()      # Diffusion LLM (dLLM) extend

    def is_cuda_graph(self):
        # Only DECODE, TARGET_VERIFY, IDLE, and DLLM_EXTEND use CUDA graphs
        return self in (
            ForwardMode.DECODE,
            ForwardMode.TARGET_VERIFY,
            ForwardMode.IDLE,
            ForwardMode.DLLM_EXTEND,
        )