A comprehensive source-code walkthrough of SGLang's GPU execution engine -- from model loading and CUDA graph capture, through FlashInfer/FlashAttention kernel dispatch, to EAGLE speculative decoding and fused Mixture-of-Experts Triton kernels.
python/sglang/srt/model_executor/model_runner.py
ModelRunner is the central class that owns the model weights, KV cache memory pools,
attention backend, CUDA graph runner, and sampler. It inherits from ModelRunnerKVCacheMixin
which provides KV cache management methods.
class ModelRunner(ModelRunnerKVCacheMixin):
"""ModelRunner runs the forward passes of the models."""
def __init__(
self,
model_config: ModelConfig,
mem_fraction_static: float,
gpu_id: int,
tp_rank: int,
tp_size: int,
moe_ep_rank: int,
moe_ep_size: int,
pp_rank: int,
pp_size: int,
nccl_port: int,
server_args: ServerArgs,
dp_rank: Optional[int] = None,
is_draft_worker: bool = False,
req_to_token_pool: Optional[ReqToTokenPool] = None,
token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
...
):
Key fields parsed during construction:
- self.spec_algorithm -- parsed from server_args.speculative_algorithm (EAGLE, EAGLE3, NGRAM, etc.)
- self.use_mla_backend -- set when model_config.attention_arch == AttentionArch.MLA (DeepSeek-V2/V3 MLA)
- self.page_size -- paged KV cache page size
- self.is_hybrid_swa -- sliding-window attention hybrid models

initialize() Pipeline

The constructor calls self.initialize(pre_model_load_memory) which executes a strict sequence:
# Strict, ordered startup pipeline. The order matters: weights must be loaded
# before KV pools are sized against remaining memory, and the attention
# backend must exist before CUDA graph capture can record its kernels.
def initialize(self, pre_model_load_memory):
# 1. Create sampler
self.sampler = create_sampler()
# 2. Load model weights (VLLM loader, safetensors, etc.)
self.load_model()
# 3. Configure KV cache dtype (fp8_e4m3, bf16, etc.)
self.configure_kv_cache_dtype()
# 4. Allocate KV cache memory pools
self.init_memory_pool(pre_model_load_memory)
# 5. Init attention backend (FlashInfer, FA3, Triton, etc.)
self.init_attention_backend()
# 6. Warmup kernels (FlashInfer autotune)
self.kernel_warmup()
# 7. Capture CUDA graphs for decode batch sizes
self.init_device_graphs()
# 8. Initialize piecewise CUDA graphs for extend
self.init_piecewise_cuda_graphs()
# Load model weights via the pluggable loader system, then read back
# model-level attention properties needed by the attention backends.
def load_model(self):
# Uses the configurable model loader system
self.loader = get_model_loader(
load_config=self.load_config,
model_config=self.model_config,
)
self.model = self.loader.load_model(
model_config=self.model_config,
device_config=DeviceConfig(self.device, self.gpu_id),
)
# After loading: parse sliding_window_size, dtype, etc.
# sliding_window_size stays None for models that do not expose the hook,
# which downstream selects single-wrapper (non-SWA) attention.
self.sliding_window_size = None
if hasattr(self.model, 'get_attention_sliding_window_size'):
self.sliding_window_size = self.model.get_attention_sliding_window_size()
The backend is chosen via a registry pattern defined in attention_registry.py:
# attention_registry.py
# Global registry: backend name -> factory taking a ModelRunner and
# returning an attention backend instance.
ATTENTION_BACKENDS = {}


def register_attention_backend(name):
    """Decorator factory: record the decorated function under *name*.

    The decorated factory is returned unchanged, so it can still be
    called or stacked with other decorators.
    """

    def _register(fn):
        ATTENTION_BACKENDS[name] = fn
        return fn

    return _register
@register_attention_backend("flashinfer")
def create_flashinfer_backend(runner):
if not runner.use_mla_backend:
return FlashInferAttnBackend(runner)
else:
return FlashInferMLAAttnBackend(runner)
@register_attention_backend("fa3")
def create_flashattention_v3_backend(runner):
return FlashAttentionBackend(runner)
@register_attention_backend("fa4")
def create_flashattention_v4_backend(runner):
return FlashAttentionBackend(runner, fa_impl_ver=4)
@register_attention_backend("triton")
def create_triton_backend(runner): ...
@register_attention_backend("flashmla")
def create_flashmla_backend(runner): ...
@register_attention_backend("cutlass_mla")
def create_cutlass_mla_backend(runner): ...
The model runner dispatches to the selected backend:
def _get_attention_backend_from_str(self, backend_str, init_new_workspace=False):
    """Resolve *backend_str* through the registry and return the wrapped backend.

    Raises ValueError for names that were never registered.
    """
    factory = ATTENTION_BACKENDS.get(backend_str)
    if factory is None:
        raise ValueError(f"Invalid attention backend: {backend_str}")
    full_attention_backend = factory(self)
    # attn_backend_wrapper may substitute a hybrid/wrapper implementation.
    return attn_backend_wrapper(self, full_attention_backend)
Separate prefill and decode backends can be combined via HybridAttnBackend,
allowing e.g. FlashMLA for decode + FlashInfer for prefill.
The supported MLA backends list: ["aiter", "flashinfer", "fa3", "fa4", "triton", "flashmla", "cutlass_mla", "trtllm_mla", "ascend", "nsa"].
python/sglang/srt/model_executor/model_runner.py : lines 2507-2588
_forward_raw() Dispatcher

def _forward_raw(self, forward_batch, skip_attn_backend_init, pp_proxy_tensors, ...):
# Check if we can use CUDA graph (decode + graph runner available)
can_run_graph = bool(
forward_batch.forward_mode.is_cuda_graph()
and self.graph_runner
and self.graph_runner.can_run(forward_batch)
)
if can_run_graph:
ret = self.graph_runner.replay(forward_batch, ...)
return ModelRunnerOutput(logits_output=ret, can_run_graph=can_run_graph)
# Fallback to eager execution
if forward_batch.forward_mode.is_decode():
ret = self.forward_decode(forward_batch, ...)
elif forward_batch.forward_mode.is_extend(include_draft_extend_v2=True):
ret, can_run_graph = self.forward_extend(forward_batch, ...)
elif forward_batch.forward_mode.is_idle():
ret = self.forward_idle(forward_batch, ...)
elif forward_batch.forward_mode.is_split_prefill():
ret = self.forward_split_prefill(forward_batch, ...)
Decode processes one new token per request. It is the hot path that benefits most from CUDA graphs.
# Decode step: one new token per request (hot path).
# skip_attn_backend_init is True when the caller has already prepared the
# attention metadata (e.g. the CUDA graph replay path) -- presumably; confirm
# against callers, which are elided from this excerpt.
def forward_decode(self, forward_batch, skip_attn_backend_init=False, pp_proxy_tensors=None):
if not skip_attn_backend_init:
self.attn_backend.init_forward_metadata(forward_batch)
# NOTE(review): `kwargs` is assembled in code elided from this excerpt.
return self.model.forward(
forward_batch.input_ids,
forward_batch.positions,
forward_batch,
**kwargs,
)
Extend (prefill) processes variable-length token sequences. It also checks if piecewise CUDA graphs can be used:
# Extend (prefill) step for variable-length sequences. Returns a pair
# (model output, can_run_graph) so the dispatcher can report whether the
# piecewise CUDA graph path was taken.
def forward_extend(self, forward_batch, skip_attn_backend_init=False, pp_proxy_tensors=None):
# Prefer the piecewise CUDA graph runner when it can serve this batch.
can_run_graph = (
self.piecewise_cuda_graph_runner is not None
and self.piecewise_cuda_graph_runner.can_run(forward_batch)
)
if can_run_graph:
return (
self.piecewise_cuda_graph_runner.replay(forward_batch, **kwargs),
can_run_graph,
)
# Eager fallback: initialize attention metadata unless already done.
if not skip_attn_backend_init:
self.attn_backend.init_forward_metadata(forward_batch)
return (
self.model.forward(forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs),
can_run_graph,
)
In data-parallel attention, some workers may have no sequences. IDLE batches are padded for MLP sync:
# IDLE forward for data-parallel attention: workers with no real sequences
# still run the model so MLP all-reduces stay in sync across DP ranks.
def forward_idle(self, forward_batch, pp_proxy_tensors=None):
# Reinit metadata for padded IDLE batches
# NOTE(review): truly empty batches (batch_size == 0) skip attention
# metadata init entirely -- confirm this matches the padding contract.
if forward_batch.batch_size > 0:
self.attn_backend.init_forward_metadata(forward_batch)
return self.model.forward(forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs)
python/sglang/srt/model_executor/cuda_graph_runner.py
class CudaGraphRunner:
"""Runs the forward pass with cuda graph and torch.compile."""
def __init__(self, model_runner: ModelRunner):
self.graphs = {} # bs -> captured CUDAGraph
self.output_buffers = {} # bs -> LogitsProcessorOutput
# Determine capture forward mode
self.capture_forward_mode = ForwardMode.DECODE
self.num_tokens_per_bs = 1
if model_runner.spec_algorithm.is_eagle():
self.capture_forward_mode = ForwardMode.TARGET_VERIFY
self.num_tokens_per_bs = server_args.speculative_num_draft_tokens
# Batch sizes to capture
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
# e.g., [1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, ...]
# Select which batch sizes get a captured CUDA graph. Candidate sizes come
# from server_args and are filtered down to divisibility and capacity limits.
def get_batch_sizes_to_capture(model_runner, num_tokens_per_bs=1):
capture_bs = server_args.cuda_graph_bs
num_max_requests = model_runner.req_to_token_pool.size
# Filter: bs * tokens_per_bs must be multiple of attn_tp_size
mul_base = 1
if server_args.enable_two_batch_overlap:
mul_base *= 2
if require_gathered_buffer(server_args):
mul_base *= get_attention_tp_size()
capture_bs = [bs for bs in capture_bs if bs * num_tokens_per_bs % mul_base == 0]
# A graph larger than the request pool could never be replayed.
capture_bs = [bs for bs in capture_bs if bs <= num_max_requests]
# NOTE(review): compile_bs is not defined in this excerpt; in the full
# source it is presumably the subset of sizes selected for torch.compile.
return sorted(set(capture_bs)), compile_bs
capture_one_batch_size()

For each batch size, a full forward pass is captured into a CUDA graph:
def capture_one_batch_size(self, bs, forward, stream_idx=None):
buffers = self.buffers # DecodeInputBuffers -- preallocated GPU tensors
graph = self._create_device_graph()
num_tokens = bs * self.num_tokens_per_bs
# Slice pre-allocated buffers to this batch size
input_ids = buffers.input_ids[:num_tokens]
req_pool_indices = buffers.req_pool_indices[:bs]
seq_lens = buffers.seq_lens[:bs]
out_cache_loc = buffers.out_cache_loc[:num_tokens]
positions = buffers.positions[:num_tokens]
# Create ForwardBatch for capture
forward_batch = ForwardBatch(
forward_mode=self.capture_forward_mode,
batch_size=bs,
input_ids=input_ids,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
out_cache_loc=out_cache_loc,
...
)
# Init attention backend for capture
attn_backend.init_forward_metadata_capture_cuda_graph(
bs, num_tokens, req_pool_indices, seq_lens, encoder_lens,
forward_batch.forward_mode, forward_batch.spec_info,
)
# Warm up (2 iterations) then capture
for _ in range(2):
self.device_module.synchronize()
run_once()
# Actual capture
out = self._capture_graph(graph, get_global_graph_memory_pool(), stream, run_once)
return graph, out
replay()

def replay(self, forward_batch, skip_attn_backend_init=False, pp_proxy_tensors=None):
if not skip_attn_backend_init:
self.replay_prepare(forward_batch, pp_proxy_tensors)
# Find the right graph by batch size (bisect to nearest captured bs)
graph_key = self.bs
self.graphs[graph_key].replay()
output = self.output_buffers[graph_key]
# Slice output to actual batch size (remove padding)
return LogitsProcessorOutput(
next_token_logits=output.next_token_logits[:self.raw_num_token],
hidden_states=output.hidden_states[:self.raw_num_token] if ... else None,
)
def replay_prepare(self, forward_batch, pp_proxy_tensors=None):
raw_bs = forward_batch.batch_size
# Round up to nearest captured batch size
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
# Copy forward_batch tensors into pre-allocated buffers
buffers.populate_from_forward_batch(
forward_batch=forward_batch,
raw_bs=raw_bs,
bs=bs,
seq_len_fill_value=self.seq_len_fill_value,
...
)
# Update attention backend metadata for replay
attn_backend.init_forward_metadata_replay_cuda_graph(
bs, buffers.req_pool_indices[:bs], buffers.seq_lens[:bs], ...
)
All captured graphs share a single global_graph_memory_pool to avoid
allocating duplicate GPU memory. The pool handle is created on first
capture via torch.cuda.graph_pool_handle().
python/sglang/srt/model_executor/forward_batch_info.py
The data flow is: ScheduleBatch (CPU scheduler) → ModelWorkerBatch (TP worker) → ForwardBatch (GPU model runner).
# NOTE: abridged field listing from this walkthrough; see
# forward_batch_info.py for the complete dataclass.
@dataclass
class ForwardBatch(ForwardBatchDeepSeekMHAMixin):
"""Store all inputs of a forward pass."""
# Core fields
forward_mode: ForwardMode # EXTEND, DECODE, MIXED, IDLE, TARGET_VERIFY, ...
batch_size: int
input_ids: torch.Tensor # [num_tokens] -- flattened token IDs
req_pool_indices: torch.Tensor # [batch_size] -- index into req_to_token_pool
seq_lens: torch.Tensor # [batch_size] -- current sequence lengths
out_cache_loc: torch.Tensor # [num_tokens] -- where to write KV cache
seq_lens_sum: int # sum of all seq_lens
# Position encoding
positions: torch.Tensor # [num_tokens] -- absolute positions
# Extend-specific (prefill)
extend_num_tokens: Optional[int] # total tokens being extended
extend_seq_lens: Optional[torch.Tensor] # [batch_size] -- new tokens per request
extend_prefix_lens: Optional[torch.Tensor] # [batch_size] -- cached prefix per request
extend_start_loc: Optional[torch.Tensor] # [batch_size] -- start index in flattened
# Attention backend reference
# (shared runtime objects set by the model runner, not per-batch data)
attn_backend: Optional[AttentionBackend]
req_to_token_pool: Optional[ReqToTokenPool]
token_to_kv_pool: Optional[KVCache]
# Speculative decoding
spec_info: Optional[SpecInput]
spec_algorithm: Optional[SpeculativeAlgorithm]
# Sampling metadata
sampling_info: Optional[SamplingBatchInfo]
temperature: torch.Tensor
top_p: torch.Tensor
For CUDA graph capture, all input tensors must be pre-allocated. DecodeInputBuffers allocates
maximally-sized tensors once:
@dataclass
class DecodeInputBuffers(ForwardInputBuffers):
input_ids: torch.Tensor # [max_num_token]
req_pool_indices: torch.Tensor # [max_bs]
seq_lens: torch.Tensor # [max_bs]
out_cache_loc: torch.Tensor # [max_num_token]
positions: torch.Tensor # [max_num_token]
next_token_logits_buffer: torch.Tensor # [max_num_token, vocab_size]
@classmethod
def create(cls, *, device, max_bs, max_num_token, hidden_size, vocab_size, dtype, ...):
with torch.device(device):
input_ids = torch.zeros((max_num_token,), dtype=torch.int64)
req_pool_indices = torch.zeros((max_bs,), dtype=torch.int64)
seq_lens = torch.full((max_bs,), seq_len_fill_value, dtype=torch.int32)
...
During CUDA graph replay, data is copied from ForwardBatch into these buffers using
a batched torch._foreach_copy_ grouped by dtype pairs:
def _grouped_foreach_copy_(dsts, srcs):
"""Call torch._foreach_copy_ grouped by (dst_dtype, src_dtype) pairs."""
groups = {}
for dst, src in zip(dsts, srcs):
key = (dst.dtype, src.dtype)
groups.setdefault(key, ([], []))
groups[key][0].append(dst)
groups[key][1].append(src)
for group_dsts, group_srcs in groups.values():
torch._foreach_copy_(group_dsts, group_srcs)
python/sglang/srt/layers/attention/flashinfer_backend.py
FlashInfer uses paged KV cache with plan/execute separation. It wraps three types of wrappers:
- BatchPrefillWithRaggedKVCacheWrapper -- for extend when all prefix lengths are zero (new sequences)
- BatchPrefillWithPagedKVCacheWrapper -- for extend with cached prefix
- BatchDecodeWithPagedKVCacheWrapper -- for decode (one token per request)

class FlashInferAttnBackend(AttentionBackend):
def __init__(self, model_runner, skip_prefill=False, ...):
self.decode_use_tensor_cores = should_use_tensor_core(
kv_cache_dtype=model_runner.kv_cache_dtype,
num_attention_heads=...,
num_kv_heads=...,
)
self.max_context_len = model_runner.model_config.context_len
# Allocate shared workspace buffer (default 256MB, up to 2GB for deterministic)
global global_workspace_buffer
if global_workspace_buffer is None:
global_workspace_buffer = torch.empty(
envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
dtype=torch.uint8, device=model_runner.device,
)
self.workspace_buffer = global_workspace_buffer
# Allocate kv_indptr for tracking page boundaries
max_bs = model_runner.req_to_token_pool.size
self.kv_indptr = [
torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
for _ in range(self.num_wrappers)
]
# Create wrappers
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
self.workspace_buffer, "NHD", backend=fmha_backend
)
for _ in range(self.num_wrappers):
self.prefill_wrappers_paged.append(
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
)
self.decode_wrappers.append(
BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer, "NHD",
use_tensor_cores=self.decode_use_tensor_cores,
)
)
This method dispatches based on the forward mode, updating wrapper indices:
def init_forward_metadata(self, forward_batch):
if forward_batch.forward_mode.is_decode_or_idle():
self.indices_updater_decode.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_cpu,
forward_batch.seq_lens_sum,
decode_wrappers=self.decode_wrappers,
...
)
self.forward_metadata = DecodeMetadata(self.decode_wrappers)
else: # extend
use_ragged = not self.enable_deterministic and not is_in_piecewise_cuda_graph()
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_cpu,
forward_batch.seq_lens_sum,
prefix_lens,
prefill_wrappers=self.prefill_wrappers_paged,
use_ragged=use_ragged,
...
)
self.forward_metadata = PrefillMetadata(
self.prefill_wrappers_paged, use_ragged, extend_no_prefix
)
FlashInfer supports dual wrappers for hybrid attention models:
if model_runner.sliding_window_size is not None:
self.num_wrappers = 2
self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
elif model_runner.model_config.is_encoder_decoder:
self.num_wrappers = 2
self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
else:
self.num_wrappers = 1
def init_cuda_graph_state(self, max_bs, max_num_tokens):
# Pre-allocate KV indices buffer for CUDA graph
# Sized for the worst case: every token attending to max_context_len
# KV entries, so replay never needs a fresh allocation.
cuda_graph_kv_indices = torch.zeros(
(max_num_tokens * self.max_context_len,),
dtype=torch.int32, device="cuda",
)
# One buffer per wrapper; clone() so each wrapper owns distinct storage.
self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [
cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
]
python/sglang/srt/layers/attention/flashattention_backend.py
Key ingredients:
- kernels from sgl_kernel.flash_attn (FA3) or flash_attention_v4 (FA4)
- paged KV cache addressed through block_table tensors
- per-batch metadata collected in a FlashAttentionMetadata dataclass

class FlashAttentionBackend(AttentionBackend):
# Constructor: stores model limits and speculative-decoding parameters.
# The `speculative_*` and `topk` arguments configure the EAGLE draft path.
def __init__(self, model_runner, skip_prefill=False,
speculative_step_id=0, topk=0,
speculative_num_steps=0, fa_impl_ver=3):
# Populated later by init_forward_metadata(); None until the first batch.
self.forward_metadata: Optional[FlashAttentionMetadata] = None
self.max_context_len = model_runner.model_config.context_len
# Cache of per-batch-size metadata reused across CUDA graph replays.
self.decode_cuda_graph_metadata = {}
# fa_impl_ver selects between FA3 and FA4 kernel implementations
# fa_impl_ver selects between FA3 and FA4 kernel implementations
# Per-batch metadata consumed by the varlen FlashAttention kernels.
@dataclass
class FlashAttentionMetadata:
cache_seqlens_int32: torch.Tensor # [batch_size] -- KV cache seq lens
max_seq_len_q: int # max query length
max_seq_len_k: int # max key length
cu_seqlens_q: torch.Tensor # cumulative query lengths
cu_seqlens_k: torch.Tensor # cumulative key lengths
window_size: tuple # sliding window (-1, -1) for full
page_table: torch.Tensor # [batch_size, max_pages] -- page indices
# Local attention metadata (for chunked models like Qwen-1M)
# Mirrors the fields above, restricted to the per-chunk "virtual batches"
# produced by make_local_attention_virtual_batches.
@dataclass
class LocalAttentionMetadata:
local_query_start_loc: torch.Tensor
local_seqused_k: torch.Tensor
local_block_table: torch.Tensor
local_max_query_len: int
local_max_seq_len: int
FlashAttention implements local attention by breaking sequences into "virtual batches" aligned to attn_chunk_size:
def make_local_attention_virtual_batches(
attn_chunk_size, query_start_loc_np, seq_lens_np, block_table, page_size
):
# Example: chunk_size=4, q_seqlens=[4,10,5], k_seqlens=[6,17,9]
# Result: seqlens_q_local=[2,2,1,4,4,1,4,1] (8 virtual batches)
q_tokens_in_first_block = np.minimum(
attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
q_seqlens
)
local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
...
The kernel version is selected via the fa_impl_ver parameter.
FA3 uses sgl_kernel.flash_attn.flash_attn_varlen_func, while FA4 uses
sglang.jit_kernel.flash_attention_v4.flash_attn_varlen_func. Both share the same
FlashAttentionBackend class and metadata structure.
python/sglang/srt/layers/sampler.py
class Sampler(nn.Module):
def forward(self, logits_output, sampling_info, return_logprob, top_logprobs_nums, ...):
logits = logits_output.next_token_logits
# 1. Preprocess: custom logit processors + NaN detection
logits = self._preprocess_logits(logits, sampling_info)
# 2. Greedy fast path
if sampling_info.is_all_greedy:
batch_next_token_ids = torch.argmax(logits, -1)
# 3. Standard sampling path
else:
# Apply temperature scaling (in-place division)
logits.div_(sampling_info.temperatures)
# Softmax to get probabilities (in-place for memory savings)
logits[:] = torch.softmax(logits, dim=-1)
probs = logits
# Sample from probabilities
batch_next_token_ids = self._sample_from_probs(
probs, sampling_info, positions, simple_sampling_case
)
def _sample_from_probs(self, probs, sampling_info, positions, simple_sampling_case):
if simple_sampling_case:
# Direct multinomial (no top-k/top-p/min-p)
batch_next_token_ids = sampling_from_probs_torch(
probs, sampling_seed=sampling_info.sampling_seed, positions=positions
)
else:
backend = get_global_server_args().sampling_backend
if backend == "flashinfer":
if sampling_info.need_min_p_sampling:
# Top-k renorm -> Top-p renorm -> Min-p sampling
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
batch_next_token_ids = min_p_sampling_from_probs(probs, sampling_info.min_ps)
else:
# Joint top-k + top-p sampling
batch_next_token_ids = top_k_top_p_sampling_from_probs(
probs.contiguous(),
sampling_info.top_ks,
sampling_info.top_ps,
filter_apply_order="joint",
)
elif backend == "pytorch":
# Fallback with torch native operations
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps, ...
)
# For RL training: sample from log-softmax using Gumbel trick
def _sample_from_logprobs(self, logprobs, sampling_info, positions):
"""Deterministic sampling with seeded randomness."""
assert sampling_info.sampling_seed is not None
sampled_index = multinomial_with_seed(logprobs, sampling_info.sampling_seed, positions)
return sampled_index.view(-1).to(torch.int32)
# Force all TP ranks to agree on the sampled token ids. The MIN all-reduce
# deterministically picks one id per position when ranks diverge
# (presumably due to nondeterministic sampling kernels -- confirm).
def _sync_token_ids_across_tp(self, batch_next_token_ids, sampling_info):
# By default, SGLang skips this sync for performance.
# Enable via SYNC_TOKEN_IDS_ACROSS_TP env var or when grammars are used.
if SYNC_TOKEN_IDS_ACROSS_TP or sampling_info.grammars:
torch.distributed.all_reduce(
batch_next_token_ids,
op=dist.ReduceOp.MIN,
group=self.tp_sync_group,
)
python/sglang/srt/speculative/eagle_worker.py
EAGLEWorker extends TpModelWorker and holds both a draft model
(lightweight EAGLE head) and a reference to the target model:
class EAGLEWorker(TpModelWorker):
def __init__(self, server_args, gpu_id, tp_rank, dp_rank, ..., target_worker):
self.topk = server_args.speculative_eagle_topk
self.speculative_num_steps = server_args.speculative_num_steps
self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
self.target_worker = target_worker
# Share req_to_token_pool and token_to_kv_pool_allocator with target
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
target_worker.get_memory_pool()
)
# Init draft model as a TpModelWorker (is_draft_worker=True)
super().__init__(
server_args=server_args, ..., is_draft_worker=True,
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
# Share embeddings and lm_head from target model
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
self.draft_model_runner.model.set_embed_and_head(embed, head)
def forward_batch_generation(self, batch):
if batch.forward_mode.is_extend():
# PREFILL: target extend -> draft extend (fill draft KV cache)
logits_output, next_token_ids, seq_lens_cpu, can_run_graph = (
self.forward_target_extend(batch)
)
self.forward_draft_extend(batch, logits_output.hidden_states, next_token_ids, ...)
return GenerationBatchResult(logits_output=logits_output, ...)
else:
# DECODE: draft multi-step -> target verify -> accept/reject
spec_info = self.draft(batch) # draft multi-step forward
logits_output, verify_output, _, can_run_graph = (
self.verify(batch, spec_info) # target verify
)
self.forward_draft_extend_after_decode(batch) # update draft state
return GenerationBatchResult(
next_token_ids=verify_output.verified_id,
num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu),
)
def draft_forward(self, forward_batch):
spec_info = forward_batch.spec_info
out_cache_loc = forward_batch.out_cache_loc
topk_p, topk_index, hidden_states = spec_info.topk_p, spec_info.topk_index, spec_info.hidden_states
score_list, token_list, parents_list = [], [], []
# Forward multiple speculative steps
for i in range(self.speculative_num_steps):
# Select top-k tokens from current step
input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
i, topk_p, topk_index, hidden_states, scores, self.topk
)
score_list.append(tree_info[0])
token_list.append(tree_info[1])
parents_list.append(tree_info[2])
if i == self.speculative_num_steps - 1:
break
# Run draft model forward
forward_batch.input_ids = input_ids
forward_batch.out_cache_loc = out_cache_loc[i]
forward_batch.positions.add_(1)
forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
logits_output = self.draft_model_runner.forward(forward_batch, skip_attn_backend_init=True)
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
hidden_states = logits_output.hidden_states
# Organize results into tree structure
parent_list, top_scores_index, draft_tokens = organize_draft_results(
score_list, token_list, parents_list, self.speculative_num_draft_tokens
)
return parent_list, top_scores_index, draft_tokens
# After draft, build the verification tree
(tree_mask, position, retrive_index, retrive_next_token,
retrive_next_sibling, draft_tokens) = build_tree_kernel_efficient(
spec_info.verified_id, parent_list, top_scores_index, draft_tokens,
batch.seq_lens, batch.seq_lens_sum,
self.topk, self.speculative_num_steps, self.speculative_num_draft_tokens,
)
# Verify: run target model on the tree
def verify(self, batch, spec_info):
spec_info.prepare_for_verify(batch, self.page_size)
batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.spec_info = spec_info
# Forward with target model
batch_result = self.target_worker.forward_batch_generation(model_worker_batch)
logits_output = batch_result.logits_output
# Verify: compare draft predictions with target logits
res = spec_info.verify(batch, logits_output, self.token_to_kv_pool_allocator, ...)
# Pick accepted indices
logits_output.next_token_logits = logits_output.next_token_logits[res.accepted_indices]
logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
# Update batch for next draft iteration
batch.forward_mode = ForwardMode.DECODE
batch.spec_info = res.draft_input
return logits_output, res, model_worker_batch, can_run_cuda_graph
EAGLE3 draft models can restrict their output vocabulary using a hot_token_id mapping built into the draft model config.
They may also use use_aux_hidden_state to capture auxiliary hidden states from specific target model layers
(configured via eagle_aux_hidden_state_layer_ids).
python/sglang/srt/layers/moe/fused_moe_triton/
fused_moe_triton/
__init__.py
configs/ # Kernel config tuning files per model
fused_moe.py # Main fused MoE dispatch logic
fused_moe_triton_config.py # Config dtype selection and optimal config lookup
fused_moe_triton_kernels.py # Triton kernel definitions (@triton.jit)
layer.py # FusedMoE nn.Module
moe_align_block_size.py # Token-to-expert alignment for blocked execution
fused_marlin_moe.py # Marlin-quantized MoE dispatch
triton_kernels_moe.py # Additional Triton MoE kernels
# Register as custom op for torch.compile compatibility
@register_custom_op(mutates_args=["hidden_states"])
def inplace_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor, # gate/up projection weights [num_experts, 2*intermediate, hidden]
w2: torch.Tensor, # down projection weights [num_experts, hidden, intermediate]
topk_weights: torch.Tensor, # [num_tokens, topk] -- routing weights
topk_ids: torch.Tensor, # [num_tokens, topk] -- selected expert IDs
activation: str = "silu",
is_gated: bool = True,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
...
):
# fused_moe_triton_kernels.py
# SM90 optimization: swap_ab for better performance with small M, large N
# Cached decision: swap A/B GEMM operands on SM90 (Hopper) when the M tile
# is small and the N tile is large (see the swap_ab note above).
@functools.lru_cache(maxsize=8)
def should_enable_swap_ab(BLOCK_SIZE_M, BLOCK_SIZE_N):
return is_sm90_supported() and BLOCK_SIZE_M < 64 and BLOCK_SIZE_N >= 64
# Activation fused kernels from sgl_kernel:
# - silu_and_mul (SiLU gating)
# - gelu_and_mul (GELU gating)
# - moe_sum_reduce (reduce expert outputs)
# Quantization support:
# - per_token_group_quant_fp8 (FP8 token-group quantization)
# - per_token_group_quant_int8 (INT8 token-group quantization)
# - scaled_fp8_quant (scaled FP8 quantization)
SGLang supports multiple expert parallelism strategies:
- Expert parallelism degree set via --ep-size.
- Expert-parallel load balancing via EPLBManager. Enabled via --enable-eplb.
- Elastic expert parallelism via ElasticEPStateManager.
- DeepEP communication via DeepEPBuffer.
- MoE data parallelism set via --moe-dp-size.

# From model_runner.py -- EPLB rebalance on rank faults
# EPLB rebalance on rank faults: when the set of active EP ranks changed
# since the last snapshot, record the new set and re-run expert placement.
elastic_ep_state = ElasticEPStateManager.instance()
if elastic_ep_state is not None and not elastic_ep_state.is_active_equal_last():
    elastic_ep_state.snapshot_active_to_last()
    elastic_ep_state.sync_active_to_cpu()
    # rebalance() is a generator; drain it to completion. A `for` loop
    # replaces the manual while/next()/StopIteration drain.
    for _ in self.eplb_manager.rebalance():
        pass
Key source files:
- model_executor/model_runner.py -- ModelRunner class, forward dispatch
- model_executor/forward_batch_info.py -- ForwardBatch, ForwardMode enum
- model_executor/cuda_graph_runner.py -- CudaGraphRunner, DecodeInputBuffers
- layers/attention/base_attn_backend.py -- AttentionBackend ABC
- layers/attention/attention_registry.py -- Backend registry
- layers/attention/flashinfer_backend.py -- FlashInfer
- layers/attention/flashattention_backend.py -- FA3/FA4
- layers/sampler.py -- Sampler class, temp/top-p/top-k
- speculative/eagle_worker.py -- EAGLEWorker, draft/verify
- speculative/eagle_utils.py -- Tree building
- layers/moe/fused_moe_triton/fused_moe.py -- Dispatch
- layers/moe/fused_moe_triton/fused_moe_triton_kernels.py -- Triton JIT
- layers/moe/fused_moe_triton/layer.py -- FusedMoE Module

class ForwardMode(IntEnum):
EXTEND = auto() # Prefill / extend with cached prefix
DECODE = auto() # Decode one token per request
MIXED = auto() # Chunked prefill: both extend and decode
IDLE = auto() # No sequences (DP attention padding)
TARGET_VERIFY = auto() # Speculative: verify in target model
DRAFT_EXTEND = auto() # Speculative: extend in draft model
DRAFT_EXTEND_V2 = auto() # Eagle V2 fixed-shape logits
PREBUILT = auto() # Disagg decode: KV cache ready
SPLIT_PREFILL = auto() # PD multiplexing split prefill
DLLM_EXTEND = auto() # Draft LLM extend
def is_cuda_graph(self):
    """Return True if this forward mode runs via a captured CUDA graph.

    Only fixed-shape, decode-like modes are captured; extend/prefill
    modes have variable token counts and run eagerly (or through the
    piecewise CUDA graph path).
    """
    # Members must be qualified: bare names like DECODE are not in scope
    # inside a method body at call time and would raise NameError.
    return self in (
        ForwardMode.DECODE,
        ForwardMode.TARGET_VERIFY,
        ForwardMode.IDLE,
        ForwardMode.DLLM_EXTEND,
    )