Table of Contents
- Backend Selection Logic
- Backend Registry and Dispatch
- RadixAttention Layer
- FlashInfer Backend (Default for Ampere/Ada)
- FlashAttention 3 Backend (Default for Hopper)
- MLA Backends: FlashMLA, CUTLASS MLA, TRT-LLM MLA
- TRT-LLM MHA Backend (Default for Blackwell B200)
- Triton Backend (Portable Fallback)
- Hybrid Attention Backend
- Backend Comparison Table
- Architecture Diagram
1. Backend Selection Logic
SGLang automatically selects the optimal attention backend based on GPU architecture, model type (MHA vs MLA),
and user configuration. This logic lives in _get_default_attn_backend() within server_args.py.
1.1 The Default Selection Algorithm
The function follows a clear priority chain. For MHA models (Llama, Qwen, Gemma, etc.):
```python
# server_args.py: _get_default_attn_backend()
def _get_default_attn_backend(self, use_mla_backend: bool, model_config):
    """
    Auto select the fastest attention backend.

    1. Models with MHA architecture (e.g. Llama, Qwen)
       1.1 FA3 on Hopper unless spec decode with topk > 1 or page_size > 1.
       1.2 trtllm_mha for SM100/SM103 (Blackwell B200/GB200/B300).
           Note: trtllm_mha does not support SM120.
       1.3 Otherwise flashinfer if available, else triton.
    2. Models with MLA architecture
       2.1 FA3 on Hopper.
       2.2 FlashInfer on Blackwell.
       2.3 Otherwise triton.
    """
    if not use_mla_backend:
        # MHA architecture
        if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(self):
            return "fa3"
        elif is_sm100_supported() and is_no_spec_infer_or_topk_one(self):
            return "trtllm_mha"
        elif is_hip():
            return "aiter"
        else:
            return "flashinfer" if is_flashinfer_available() else "triton"
    else:
        # MLA architecture (DeepSeek models)
        if is_hopper_with_cuda_12_3():
            return "fa3"
        elif is_sm100_supported():
            return "flashinfer"
        elif is_hip():
            return "aiter"
        else:
            return "triton"
```
1.2 Decision Flow Summary
Hopper (SM90) MHA
FlashAttention 3 (fa3) is default. FA3 uses asynchronous TMA copies and warp-specialized execution for maximum throughput on H100/H200.
Hopper (SM90) MLA
FlashAttention 3 (fa3) is also default. The FA3 backend has an MLA-specific path that uses absorbed Q with split nope/rope dimensions.
Blackwell (SM100) MHA
TRT-LLM MHA (trtllm_mha) is default. Uses NVIDIA TensorRT-LLM XQA kernels via flashinfer for SM100/SM103. Supports sliding window and attention sinks natively.
Blackwell (SM100) MLA
FlashInfer (flashinfer) is default for MLA on Blackwell. The MLA-specific FlashInfer backend handles absorbed KV projections.
Ampere/Ada (SM80/SM89)
FlashInfer (flashinfer) is default when available. Falls back to Triton if not installed. FlashInfer wraps FA2-level kernels with paged KV cache support.
AMD ROCm
AIter (aiter) backend is default for HIP. Falls back to Triton for unsupported head configurations.
1.3 Hybrid Backend Creation
When --prefill-attention-backend and --decode-attention-backend differ, SGLang automatically wraps them in a HybridAttnBackend:
```python
# model_runner.py: _get_attention_backend()
if self.decode_attention_backend_str != self.prefill_attention_backend_str:
    from sglang.srt.layers.attention.hybrid_attn_backend import HybridAttnBackend

    attn_backend = HybridAttnBackend(
        self,
        decode_backend=self._get_attention_backend_from_str(
            self.decode_attention_backend_str,
        ),
        prefill_backend=self._get_attention_backend_from_str(
            self.prefill_attention_backend_str,
        ),
    )
```
2. Backend Registry and Dispatch
All attention backends are registered via a decorator pattern in attention_registry.py. This registry maps string names to factory functions:
```python
# attention_registry.py
ATTENTION_BACKENDS = {}

def register_attention_backend(name):
    def decorator(fn):
        ATTENTION_BACKENDS[name] = fn
        return fn
    return decorator

# Each backend registers itself:
@register_attention_backend("flashinfer")
def create_flashinfer_backend(runner):
    if not runner.use_mla_backend:
        return FlashInferAttnBackend(runner)
    else:
        return FlashInferMLAAttnBackend(runner)

@register_attention_backend("fa3")
def create_flashattention_v3_backend(runner):
    assert torch.cuda.get_device_capability()[0] >= 8  # SM80+
    return FlashAttentionBackend(runner)

@register_attention_backend("trtllm_mha")
def create_trtllm_mha_backend(runner):
    if runner.use_mla_backend:
        raise ValueError("trtllm_mha only for non-MLA models.")
    return TRTLLMHAAttnBackend(runner)

@register_attention_backend("trtllm_mla")
def create_trtllm_mla_backend(runner):
    if not runner.use_mla_backend:
        raise ValueError("trtllm_mla only for MLA models.")
    return TRTLLMMLABackend(runner)

# Further registrations (abbreviated):
#   "flashmla"     -- FlashMLA for DeepSeek MLA
#   "cutlass_mla"  -- CUTLASS MLA kernel
#   "triton"       -- Portable fallback
#   "fa4"          -- FlashAttention 4
#   ...            -- 14+ backends total
```
Backend instantiation then follows a single dispatch path from ModelRunner:
```python
# model_runner.py: _get_attention_backend_from_str()
def _get_attention_backend_from_str(self, backend_str):
    if backend_str not in ATTENTION_BACKENDS:
        raise ValueError(f"Invalid attention backend: {backend_str}")
    full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
    return attn_backend_wrapper(self, full_attention_backend)
```
The attn_backend_wrapper() function further wraps the backend for hybrid models (e.g., Mamba-2 + attention, GDN + attention) by checking if the model has mixed linear-attention/full-attention layers.
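The registry-plus-dispatch pattern is plain Python and easy to see in isolation. A minimal, self-contained sketch follows; the `DummyRunner` class and the string return values are stand-ins for illustration, not SGLang's real backend objects:

```python
# Minimal sketch of the decorator-based backend registry.
# DummyRunner and the string "backends" are illustrative stand-ins.
ATTENTION_BACKENDS = {}

def register_attention_backend(name):
    def decorator(fn):
        ATTENTION_BACKENDS[name] = fn
        return fn
    return decorator

class DummyRunner:
    use_mla_backend = False

@register_attention_backend("dummy")
def create_dummy_backend(runner):
    # A real factory inspects the runner and returns a backend instance.
    return "mla-backend" if runner.use_mla_backend else "mha-backend"

def get_backend(name, runner):
    if name not in ATTENTION_BACKENDS:
        raise ValueError(f"Invalid attention backend: {name}")
    return ATTENTION_BACKENDS[name](runner)

print(get_backend("dummy", DummyRunner()))  # -> mha-backend
```

The decorator runs at import time, so merely importing a backend module is enough to make it selectable by name.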
3. RadixAttention Layer
The RadixAttention module in radix_attention.py is the nn.Module that every transformer layer calls.
It does not implement any kernel -- it dispatches to the current backend.
```python
# radix_attention.py
class RadixAttention(nn.Module):
    def __init__(
        self,
        num_heads: int, head_dim: int, scaling: float,
        num_kv_heads: int, layer_id: int,
        logit_cap: float = 0.0,
        v_head_dim: int = -1,
        sliding_window_size: int = -1,
        is_cross_attention: bool = False,
        ...
    ):
        self.tp_q_head_num = num_heads
        self.tp_k_head_num = num_kv_heads
        self.head_dim = head_dim
        self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
        self.scaling = scaling
        self.layer_id = layer_id
        self.sliding_window_size = sliding_window_size or -1
```
3.1 The Forward Dispatch
The forward method has two paths: a torch.compile-compatible path for extend mode
(using unified_attention_with_output) and a direct dispatch path:
```python
def forward(self, q, k, v, forward_batch, save_kv_cache=True, **kwargs):
    if k is not None:
        if "k_rope" not in kwargs:
            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
        else:
            k = k.view(-1, self.tp_k_head_num, self.v_head_dim)

    # Piecewise CUDA graph path (extend mode + torch.compile)
    if forward_batch.forward_mode.is_extend() and get_forward_context() is not None:
        output = torch.empty_like(q)
        unified_attention_with_output(
            q, k, v, output, save_kv_cache, self.layer_id, **kwargs
        )
        return output
    else:
        # Direct backend dispatch
        return forward_batch.attn_backend.forward(
            q, k, v, self, forward_batch, save_kv_cache, **kwargs
        )
```
The direct path reaches the kernel through forward_batch.attn_backend, which is set by the scheduler. The RadixAttention layer itself has zero knowledge of which kernel will execute -- it just passes itself (with layer_id, head counts, and scaling) to whichever backend is active.
3.2 Unified Attention (torch.compile path)
The unified_attention_with_output function is registered as a custom op for piecewise CUDA graph capture.
It retrieves the attention layer and backend from a thread-local forward context:
```python
@register_custom_op(mutates_args=["output"])
@register_split_op()
def unified_attention_with_output(query, key, value, output,
                                  save_kv_cache, layer_id, **kwargs):
    context = get_forward_context()
    forward_batch = context.forward_batch
    attention_layer = context.attention_layers[layer_id]
    ret = forward_batch.attn_backend.forward(
        query, key, value, attention_layer, forward_batch, save_kv_cache, **kwargs
    )
    # Copy the result into the pre-allocated output buffer in place
    output.view(ret.shape).copy_(ret)
```
4. FlashInfer Backend (Default for Ampere/Ada)
flashinfer_backend.py -- the most mature and feature-rich backend. FlashInfer wraps optimized CUDA attention kernels with sophisticated paged KV cache management, CUDA graph support, and sliding window attention.
4.1 Wrapper Architecture
FlashInfer uses three types of wrapper objects that encapsulate kernel state:
```python
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
    fast_decode_plan,
)
from flashinfer.cascade import merge_state
```
4.2 Workspace and Buffer Initialization
A global workspace buffer is shared across all FlashInfer wrappers in the process. Qwen models require a larger workspace (512 MB instead of the default):
```python
# Global workspace -- shared across all wrappers
global_workspace_buffer = None

class FlashInferAttnBackend(AttentionBackend):
    def __init__(self, model_runner, skip_prefill=False, ...):
        # Qwen2/3 require a larger workspace size
        if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
            envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(512 * 1024 * 1024)

        # Allocate shared workspace
        global global_workspace_buffer
        if global_workspace_buffer is None:
            global_workspace_buffer = torch.empty(
                envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
                dtype=torch.uint8, device=model_runner.device,
            )
        self.workspace_buffer = global_workspace_buffer
```
4.3 Decode/Prefill Wrapper Lifecycle
The backend allocates one wrapper slot by default, or two when the model uses sliding window attention or cross-attention. Each slot holds a paged prefill wrapper plus a decode wrapper:
```python
# Wrapper dispatch reasoning
if model_runner.sliding_window_size is not None:
    self.num_wrappers = 2
    self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
elif model_runner.model_config.is_encoder_decoder:
    self.num_wrappers = 2
    self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
else:
    self.num_wrappers = 1

# Create wrapper instances per wrapper slot
for _ in range(self.num_wrappers):
    self.prefill_wrappers_paged.append(
        BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
    )
    self.decode_wrappers.append(
        BatchDecodeWithPagedKVCacheWrapper(
            self.workspace_buffer, "NHD",
            use_tensor_cores=self.decode_use_tensor_cores,
        )
    )
```
4.4 Page Table Management
FlashInfer uses KV indptr arrays (CSR-style pointers) and KV indices arrays (page indices) to manage the paged KV cache. The FlashInferIndicesUpdaterDecode and FlashInferIndicesUpdaterPrefill classes build these arrays with Triton kernels, so construction stays on the GPU:
```python
# Indices are built per-batch
self.kv_indptr = [
    torch.zeros((max_bs + 1,), dtype=torch.int32, device=device)
    for _ in range(self.num_wrappers)
]
self.kv_last_page_len = torch.ones(
    (max_bs,), dtype=torch.int32, device=device
)

# init_forward_metadata for decode
def init_forward_metadata(self, forward_batch):
    if forward_batch.forward_mode.is_decode_or_idle():
        self.indices_updater_decode.update(
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            forward_batch.seq_lens_cpu,
            forward_batch.seq_lens_sum,
            decode_wrappers=self.decode_wrappers,
            encoder_lens=forward_batch.encoder_lens,
            spec_info=forward_batch.spec_info,
        )
        self.forward_metadata = DecodeMetadata(self.decode_wrappers)
```
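The indptr/indices pair is an ordinary CSR layout: kv_indices concatenates every sequence's page list, and kv_indptr[i]..kv_indptr[i+1] bounds sequence i's slice. A pure-Python sketch with illustrative page numbers (not SGLang code, which builds these on-GPU):

```python
# CSR-style layout for the paged KV cache (illustrative values).
def build_kv_arrays(page_lists):
    kv_indptr = [0]
    kv_indices = []
    for pages in page_lists:
        kv_indices.extend(pages)          # flat page list
        kv_indptr.append(len(kv_indices)) # running offset per sequence
    return kv_indptr, kv_indices

# Three sequences occupying 2, 3, and 1 physical pages:
indptr, indices = build_kv_arrays([[7, 2], [0, 4, 9], [5]])
print(indptr)   # [0, 2, 5, 6]
print(indices)  # [7, 2, 0, 4, 9, 5]
```

With this layout, the kernel for sequence i reads exactly `indices[indptr[i]:indptr[i+1]]`, so a single flat buffer serves a ragged batch.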
4.5 CUDA Graph Capture
For CUDA graph support, FlashInfer creates dedicated wrappers with pre-allocated index buffers that can be replayed without re-planning:
```python
def init_forward_metadata_capture_cuda_graph(self, bs, ...):
    if forward_mode.is_decode_or_idle():
        decode_wrappers = []
        for i in range(self.num_wrappers):
            decode_wrappers.append(
                BatchDecodeWithPagedKVCacheWrapper(
                    self.workspace_buffer, "NHD",
                    use_cuda_graph=True,  # Key flag
                    use_tensor_cores=self.decode_use_tensor_cores,
                    paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
                    paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
                    paged_kv_last_page_len_buffer=self.kv_last_page_len[:num_tokens],
                )
            )
        # Replace begin_forward with fast_decode_plan
        for i in range(self.num_wrappers):
            decode_wrappers[i].begin_forward = partial(
                fast_decode_plan, decode_wrappers[i]
            )
```
4.6 Tensor Core Decision
```python
def should_use_tensor_core(kv_cache_dtype, num_attention_heads, num_kv_heads):
    # Use tensor cores if FP8 KV cache or GQA ratio >= 4
    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return True
    if num_attention_heads // num_kv_heads >= 4:
        return True
    return False
```
5. FlashAttention 3 Backend (Default for Hopper)
flashattention_backend.py wraps the FA3 and FA4 kernel APIs. It is the default for Hopper GPUs (SM90) for both MHA and MLA models.
5.1 Kernel Selection (FA3 vs FA4)
```python
from sgl_kernel.flash_attn import (
    flash_attn_varlen_func as flash_attn_varlen_func_fa3,
    flash_attn_with_kvcache as flash_attn_with_kvcache_fa3,
)
from sglang.jit_kernel.flash_attention_v4 import (
    flash_attn_varlen_func as flash_attn_varlen_func_fa4,
    flash_attn_with_kvcache as flash_attn_with_kvcache_fa4,
)

# The constructor selects between FA3 and FA4 based on the fa_impl_ver parameter
class FlashAttentionBackend(AttentionBackend):
    def __init__(self, model_runner, fa_impl_ver=3):
        # fa_impl_ver=3 for FA3 (default), fa_impl_ver=4 for FA4
        ...
```
5.2 Metadata Preparation
The FlashAttentionMetadata dataclass holds all per-batch state needed by the kernels:
```python
@dataclass
class FlashAttentionMetadata:
    cache_seqlens_int32: torch.Tensor = None  # Sequence lengths
    max_seq_len_q: int = 1                    # Max query length
    max_seq_len_k: int = 0                    # Max key length
    cu_seqlens_q: torch.Tensor = None         # Cumulative query lengths
    cu_seqlens_k: torch.Tensor = None         # Cumulative key lengths
    window_size: tuple = (-1, -1)             # Sliding window
    page_table: torch.Tensor = None           # KV cache page indices
    swa_page_table: torch.Tensor = None       # SWA-specific pages

    # Nested metadata for local attention (long context)
    @dataclass
    class LocalAttentionMetadata:
        local_query_start_loc: torch.Tensor = None
        local_seqused_k: torch.Tensor = None
        local_block_table: torch.Tensor = None
        local_max_query_len: int = 0
        local_max_seq_len: int = 0
```
5.3 Local Attention for Long Context
For models with chunked local attention (e.g., Llama 4's 8K attention chunks), FA3 breaks sequences into virtual local batches. This simulates a block-diagonal attention mask without materializing any mask tensor:
```python
def make_local_attention_virtual_batches(
    attn_chunk_size, query_start_loc_np, seq_lens_np, block_table, page_size
):
    """
    Break sequences into local attention blocks.

    E.g., with attn_chunk_size=4, q_seqlens=[4,10,5], k_seqlens=[6,17,9]:
      seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
      seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
    Each local block becomes an independent "virtual" batch item.
    """
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
        q_seqlens,
    )
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
    ...  # Build virtual batch arrays
```
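The docstring's numbers can be reproduced with a small pure-Python walkthrough (illustrative only; the real implementation builds these arrays vectorized with NumPy):

```python
# Pure-Python walkthrough of the virtual-batch split (illustrative,
# not the SGLang implementation).
def split_local(attn_chunk_size, q_seqlens, k_seqlens):
    seqlens_q_local, seqlens_k_local = [], []
    for q_len, k_len in zip(q_seqlens, k_seqlens):
        # Query tokens that land in the first (possibly partial) chunk
        first = min(attn_chunk_size - ((k_len - q_len) % attn_chunk_size), q_len)
        q_blocks = [first]
        remaining = q_len - first
        while remaining > 0:
            q_blocks.append(min(attn_chunk_size, remaining))
            remaining -= q_blocks[-1]
        # K tokens visible to each block: from its chunk start to its last token
        k_end = k_len - q_len  # k index just before the first query token
        for q_blk in q_blocks:
            k_end += q_blk
            chunk_start = ((k_end - 1) // attn_chunk_size) * attn_chunk_size
            seqlens_q_local.append(q_blk)
            seqlens_k_local.append(k_end - chunk_start)
    return seqlens_q_local, seqlens_k_local

q_loc, k_loc = split_local(4, [4, 10, 5], [6, 17, 9])
print(q_loc)  # [2, 2, 1, 4, 4, 1, 4, 1]
print(k_loc)  # [4, 2, 4, 4, 4, 1, 4, 1]
```

Each (q_local, k_local) pair then runs as an independent varlen batch item, which is exactly what the block-diagonal mask would have computed.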
5.4 Decode Path: flash_attn_with_kvcache
The decode path calls flash_attn_with_kvcache, which operates on the paged KV cache in place, avoiding extra copies of cached keys and values. The page table maps logical token positions to physical cache pages.
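As a mental model of what the kernel computes, here is a tiny pure-Python reference of single-head, single-token paged decode: gather K/V through the page table, then do ordinary softmax attention. This is illustrative only, not the FA3 kernel:

```python
import math

# Reference paged decode for one head and one query token (illustrative).
# kv_pages maps a physical page id to a list of (k_vec, v_vec) slots.
def paged_decode(q, kv_pages, page_table, seq_len, page_size, scale):
    scores, values = [], []
    for pos in range(seq_len):
        page = kv_pages[page_table[pos // page_size]]  # logical -> physical
        k_vec, v_vec = page[pos % page_size]
        scores.append(scale * sum(a * b for a, b in zip(q, k_vec)))
        values.append(v_vec)
    # Numerically stable softmax over the gathered scores
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / denom
            for d in range(dim)]

# Two logical tokens stored in physical page 3 (page_size=2):
pages = {3: [([1.0, 0.0], [2.0, 0.0]), ([0.0, 1.0], [0.0, 2.0])]}
out = paged_decode([1.0, 1.0], pages, {0: 3}, seq_len=2, page_size=2, scale=1.0)
print(out)  # [1.0, 1.0] -- equal scores average the two value vectors
```

The real kernel fuses the gather, dot products, and softmax into one pass over pages; the page-table indirection is the only extra work relative to contiguous attention.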
6. MLA Backends: FlashMLA, CUTLASS MLA, TRT-LLM MLA
Multi-head Latent Attention (MLA) is DeepSeek's innovation that compresses KV projections into a low-rank space. SGLang provides three specialized decode kernels for MLA models, all sharing FlashInfer's MLA backend for prefill.
6.1 MLA Architecture Overview
MLA models define four head dimensions: kv_lora_rank (typically 512), qk_nope_head_dim (128), qk_rope_head_dim (64), and v_head_dim (128). The KV cache stores a compressed representation of dimension kv_lora_rank + qk_rope_head_dim = 576 per token instead of separate K and V tensors.
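The compression is dramatic. A back-of-the-envelope comparison, assuming DeepSeek-V3-like dimensions (128 heads, 128-dim nope + 64-dim rope keys, 128-dim values) against a hypothetical uncompressed MHA cache with the same head count:

```python
# Per-token, per-layer KV cache elements (illustrative arithmetic,
# assuming DeepSeek-V3-like dimensions; not measured figures).
kv_lora_rank = 512
qk_rope_head_dim = 64
mla_cache = kv_lora_rank + qk_rope_head_dim          # shared across all heads

num_heads = 128
qk_head_dim = 128 + 64   # nope + rope
v_head_dim = 128
mha_cache = num_heads * (qk_head_dim + v_head_dim)   # per-head K and V

print(mla_cache, mha_cache)  # 576 vs 40960: roughly a 71x reduction
```

The key point is that the 576-dim latent is shared across all heads, whereas an MHA cache scales linearly with the head count.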
6.2 FlashMLA Backend
flashmla_backend.py -- Inherits from FlashInferMLAAttnBackend for prefill,
overrides decode with the flash_mla_with_kvcache kernel. Uses PAGE_SIZE = 64.
```python
from sgl_kernel.flash_mla import flash_mla_with_kvcache, get_mla_metadata

PAGE_SIZE = 64  # FlashMLA fixed page size

class FlashMLABackend(FlashInferMLAAttnBackend):
    def init_forward_metadata(self, forward_batch):
        if forward_batch.forward_mode.is_decode_or_idle():
            # Build page-aligned block indices via Triton kernel
            max_seqlen_pad = triton.cdiv(
                forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
            )
            block_kv_indices = torch.full(
                (bs, max_seqlen_pad), -1, dtype=torch.int32, device=device
            )
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token, forward_batch.req_pool_indices,
                forward_batch.seq_lens, None, block_kv_indices,
                self.req_to_token.stride(0), max_seqlen_pad,
            )
            # Get tile scheduler metadata for split-k attention
            mla_metadata, num_splits = get_mla_metadata(
                forward_batch.seq_lens.to(torch.int32),
                self.num_q_heads, 1,
                is_fp8_kvcache=self.is_fp8_kvcache,
            )
        else:
            # Fall back to FlashInfer MLA for prefill
            super().init_forward_metadata(forward_batch)
```
FlashMLA Decode Kernel Call
```python
def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True):
    # Save KV to cache
    if k is not None and save_kv_cache:
        forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
    k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

    reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
    o, _ = flash_mla_with_kvcache(
        q=reshape_q,
        k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
        block_table=self.forward_metadata.block_kv_indices[:bs],
        cache_seqlens=forward_batch.seq_lens.to(torch.int32),
        head_dim_v=self.kv_lora_rank,  # Output dim is kv_lora_rank, not head_dim
        tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
        num_splits=self.forward_metadata.num_splits,
        softmax_scale=layer.scaling,
        causal=True,
    )
    return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
```
6.3 CUTLASS MLA Backend
cutlass_mla_backend.py -- Uses NVIDIA CUTLASS kernels for MLA decode. Key difference: PAGE_SIZE = 128 (fixed by CUTLASS requirements). Separates Q into nope and rope components before the kernel call.
```python
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size

PAGE_SIZE = 128  # CUTLASS MLA only supports page_size=128

class CutlassMLABackend(FlashInferMLAAttnBackend):
    def init_forward_metadata(self, forward_batch):
        if forward_batch.forward_mode.is_decode_or_idle():
            # Workspace allocated per-batch
            workspace_size = cutlass_mla_get_workspace_size(
                max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
            )
            workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
        else:
            super().init_forward_metadata(forward_batch)  # FlashInfer for prefill
```
CUTLASS MLA Kernel: Separate Q_nope and Q_rope
```python
def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True,
                   q_rope=None, k_rope=None):
    # CUTLASS MLA takes nope/rope as separate arguments
    if q_rope is not None:
        q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
        q_rope = q_rope.view(-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim)
    else:
        reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
        q_nope = reshaped_q[:, :, : layer.v_head_dim]
        q_rope = reshaped_q[:, :, layer.v_head_dim :]

    o = cutlass_mla_decode(
        q_nope=q_nope, q_pe=q_rope,
        kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
        seq_lens=forward_batch.seq_lens.to(torch.int32),
        page_table=self.forward_metadata.block_kv_indices,
        workspace=self.forward_metadata.workspace,
        sm_scale=layer.scaling, num_kv_splits=1,
    )
```
6.4 TRT-LLM MLA Backend
trtllm_mla_backend.py -- Uses TensorRT-LLM's MLA kernel from flashinfer. Handles complex block constraints for page alignment:
```python
# Block constraint from flashinfer requirements:
#   block_num % (128 / block_size) == 0
TRTLLM_BLOCK_CONSTRAINT = 128

class TRTLLMMLABackend(FlashInferMLAAttnBackend):
    def _calc_padded_blocks(self, max_seq_len):
        """Pad block count to satisfy both TRT-LLM and Triton constraints."""
        blocks = triton.cdiv(max_seq_len, self.page_size)

        # LCM of TRT-LLM constraint and Triton page constraint
        trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
        triton_constraint = get_num_page_per_block_flashmla(self.page_size)
        constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)

        if blocks % constraint_lcm != 0:
            blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
        return blocks
```
The TRT-LLM MLA backend also includes Triton kernels for padding/unpadding draft queries during
speculative decoding (pad_draft_extend_query_kernel and unpad_draft_extend_output_kernel).
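The padding arithmetic is easy to check in isolation. A standalone sketch with hypothetical values: page_size=32 makes the TRT-LLM constraint 128/32 = 4, and the Triton constraint of 2 is an assumed stand-in for whatever get_num_page_per_block_flashmla returns:

```python
import math

# Standalone version of the block-padding logic (hypothetical constraints).
def calc_padded_blocks(max_seq_len, page_size, triton_constraint):
    cdiv = lambda a, b: -(-a // b)  # ceiling division
    blocks = cdiv(max_seq_len, page_size)
    trtllm_constraint = 128 // page_size
    constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)
    if blocks % constraint_lcm != 0:
        blocks = cdiv(blocks, constraint_lcm) * constraint_lcm
    return blocks

# 1000 tokens / 32-token pages -> 32 raw blocks; lcm(4, 2) = 4 divides 32,
# so no padding is needed:
print(calc_padded_blocks(1000, 32, 2))  # 32
# 1050 tokens -> 33 raw blocks, padded up to the next multiple of 4:
print(calc_padded_blocks(1050, 32, 2))  # 36
```

Padding to the LCM means one block count satisfies both kernels, at the cost of a few wasted page-table slots per sequence.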
6.5 MLA Backend Comparison
| Aspect | FlashMLA | CUTLASS MLA | TRT-LLM MLA |
|---|---|---|---|
| Page Size | 64 (fixed) | 128 (fixed) | Configurable (with constraints) |
| Prefill | FlashInfer MLA (inherited) | FlashInfer MLA (inherited) | FlashInfer MLA (inherited) |
| Q Format | Concatenated [nope\|rope] | Separate q_nope, q_pe | Absorbed Q |
| FP8 KV Cache | Yes (quantize Q to match) | No | Yes (with descale) |
| Spec Decode | topk=1 only | Falls back to FlashInfer | Full support with padding |
| CUDA Graph | Full support | Full support | Full support |
| Parent Class | FlashInferMLAAttnBackend | FlashInferMLAAttnBackend | FlashInferMLAAttnBackend |
All three backends inherit from FlashInferMLAAttnBackend and only override the decode path (and sometimes target_verify). Prefill always uses FlashInfer's MLA implementation, making these backends effectively "decode-only" specializations.
7. TRT-LLM MHA Backend (Default for Blackwell B200)
trtllm_mha_backend.py -- This is the default attention backend for SM100 (Blackwell B200/GB200/B300) GPUs on non-MLA models. It extends FlashInfer's backend with TRT-LLM's XQA kernels.
7.1 Why Default for Blackwell?
The backend distinguishes two kernel implementations via self.is_xqa_impl = is_sm90_supported() or is_sm120_supported(): the XQA path on SM90/SM120, and the TRTLLM-GEN path otherwise. The XQA path keeps queries in BF16 even with an FP8 KV cache, leveraging hardware mixed-precision matmul for optimal throughput.
7.2 Initialization and Workspace
```python
class TRTLLMHAAttnBackend(FlashInferAttnBackend):
    """TRTLLM MHA attention kernel from flashinfer.

    The kernel supports sm100 only, with sliding window and attention sink."""

    def __init__(self, model_runner, skip_prefill=False, ...):
        # Preserve user's workspace size before parent __init__ overrides it
        workspace_size_bytes = (
            env_var.get() if env_var.is_set()
            else DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024  # 512 MB
        )
        super().__init__(model_runner, skip_prefill, ...)

        # Zero-initialized workspace (required by TRT-LLM)
        global global_zero_init_workspace_buffer
        if global_zero_init_workspace_buffer is None:
            global_zero_init_workspace_buffer = torch.zeros(
                self.workspace_size, dtype=torch.uint8, device=device,
            )
        self.workspace_buffer = global_zero_init_workspace_buffer

        # Detect XQA vs TRTLLM-GEN implementation
        self.is_xqa_impl = is_sm90_supported() or is_sm120_supported()
```
7.3 Sliding Window Attention with SWA KV Pool
A distinguishing feature is native sliding window attention support with separate KV pools. For hybrid SWA models (like Gemma-3), the backend maintains translated page tables for SWA layers that index into a separate, smaller KV pool:
```python
# SWA hybrid model support -- two KV pools with separate indices
allocator = model_runner.token_to_kv_pool_allocator
self.use_sliding_window_kv_pool = isinstance(
    allocator, SWATokenToKVPoolAllocator
)
self._swa_kv_pool = (
    allocator.get_kvcache() if self.use_sliding_window_kv_pool else None
)

def _get_layer_page_table(self, layer, forward_batch):
    """Return the correct page_table for the given layer (SWA or full)."""
    swa_pt = self.forward_metadata.swa_page_table
    if swa_pt is not None:
        _, is_swa = self._swa_kv_pool.layers_mapping[layer.layer_id]
        if is_swa:
            return swa_pt
    return self.forward_metadata.page_table
```
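Conceptually, the per-layer routing is just a lookup into a layer-to-pool mapping. A hypothetical sketch (the function and mapping here are illustrative, not SGLang's):

```python
# Hypothetical sketch: route each layer to its KV pool's page table.
# layers_mapping: layer_id -> True when the layer uses sliding-window attention.
def get_layer_page_table(layer_id, layers_mapping, full_page_table, swa_page_table):
    if swa_page_table is not None and layers_mapping[layer_id]:
        return swa_page_table
    return full_page_table

# Gemma-3-style interleaving: five SWA layers per full-attention layer.
layers_mapping = {i: (i % 6 != 5) for i in range(12)}
full_pt, swa_pt = "full-pool-table", "swa-pool-table"
print(get_layer_page_table(0, layers_mapping, full_pt, swa_pt))  # swa-pool-table
print(get_layer_page_table(5, layers_mapping, full_pt, swa_pt))  # full-pool-table
```

Because SWA layers only ever need window-sized history, their pool (and its page table) can be much smaller than the full-attention pool.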
7.4 Decode Kernel Call
```python
def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs):
    # For XQA: q_dtype stays bf16 even with FP8 KV
    if self.data_type == torch.float8_e4m3fn and (not self.is_xqa_impl):
        q = q.to(torch.float8_e4m3fn)
    q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)

    k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
    # Permute: [pages, page_size, heads, dim] -> [pages, heads, page_size, dim]
    k_cache = k_cache.view(
        -1, self.page_size, layer.tp_k_head_num, layer.head_dim
    ).permute(0, 2, 1, 3)
    v_cache = v_cache.view(
        -1, self.page_size, layer.tp_v_head_num, layer.head_dim
    ).permute(0, 2, 1, 3)

    page_table = self._get_layer_page_table(layer, forward_batch)
    o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
        query=q,
        kv_cache=(k_cache, v_cache),
        workspace_buffer=self.workspace_buffer,
        block_tables=page_table,
        seq_lens=self.forward_metadata.cache_seqlens_int32,
        max_seq_len=self.max_context_len,
        bmm1_scale=q_scale * k_scale * layer.scaling,
        bmm2_scale=1.0,
        window_left=layer.sliding_window_size,  # Native SWA support
        sinks=attention_sink,                   # Attention sink support
        out_dtype=self.q_data_type,
    )
```
7.5 CUDA Graph State Management
TRT-LLM MHA pre-allocates page tables, cumulative sequence length arrays, and optional SWA translation tables for CUDA graph capture across decode, target_verify, and draft_extend modes:
```python
def init_cuda_graph_state(self, max_bs, max_num_tokens):
    max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
    self.decode_cuda_graph_metadata = {
        "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device="cuda"),
        "page_table": torch.zeros(
            max_bs, max_num_pages, dtype=torch.int32, device="cuda"
        ),
        "swa_page_table": self._alloc_swa_page_table(max_bs, max_num_pages),
        "strided_indices": torch.arange(
            0, self.max_context_len, self.page_size, device="cuda"
        ),
    }
    # Also pre-allocates target_verify_metadata and draft_extend_metadata dicts
```
8. Triton Backend (Portable Fallback)
triton_backend.py -- A fully custom attention implementation using Triton JIT kernels. This backend works on any CUDA GPU without external dependencies, making it the universal fallback.
8.1 Architecture
```python
class TritonAttnBackend(AttentionBackend):
    def __init__(self, model_runner, skip_prefill=False):
        # Lazy imports -- avoid CUDA context init at import time
        from sglang.srt.layers.attention.triton_ops.decode_attention import (
            decode_attention_fwd,
        )
        from sglang.srt.layers.attention.triton_ops.extend_attention import (
            extend_attention_fwd, extend_attention_fwd_unified,
            build_unified_kv_indices,
        )

        # Disable torch.compile for triton kernels
        self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
```
8.2 Split-K Decode with Dynamic KV Splits
The Triton decode kernel uses a split-K strategy where the KV cache is partitioned across multiple splits. The number of splits is computed dynamically to saturate GPU SMs:
```python
def get_num_kv_splits(self, num_kv_splits, seq_lens):
    num_token, num_seq = num_kv_splits.shape[0], seq_lens.shape[0]
    num_group = num_token // num_seq

    # Deterministic mode: fixed splits per sequence length
    if self.split_tile_size is not None and self.enable_deterministic:
        expanded_seq_lens = seq_lens.repeat_interleave(num_group)
        num_kv_splits[:] = (
            expanded_seq_lens + self.split_tile_size - 1
        ) // self.split_tile_size
        return

    # Dynamic mode: use GPU SM count to determine splits
    get_num_kv_splits_triton[(1,)](
        num_kv_splits, seq_lens, num_seq, num_group,
        self.num_head, self.num_kv_head,
        self.max_kv_splits, self.device_core_count,
    )
```
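Split-K decode works because partial softmax results can be merged exactly: each split produces a locally normalized output plus its log-sum-exp, and a final reduction combines them with correct global weights. A pure-Python sketch of that reduction (illustrative, not the Triton kernel):

```python
import math

# Merge per-split partial attention results using log-sum-exp.
# Each split i contributes (partial_out_i, lse_i): its locally normalized
# output vector and the log of its local softmax denominator.
def merge_splits(partial_outs, lses):
    m = max(lses)
    weights = [math.exp(l - m) for l in lses]  # relative denominators
    denom = sum(weights)
    dim = len(partial_outs[0])
    return [
        sum(w * o[d] for w, o in zip(weights, partial_outs)) / denom
        for d in range(dim)
    ]

# Splitting a uniform softmax over values [1,2,3,4] into two halves:
# each half is locally normalized with lse = log(2), and merging
# recovers the global mean of all four values.
half1 = [(1.0 + 2.0) / 2]  # locally normalized output of split 1
half2 = [(3.0 + 4.0) / 2]  # locally normalized output of split 2
print(merge_splits([half1, half2], [math.log(2), math.log(2)]))  # [2.5]
```

This is the same reduction the backend stores attn_logits/attn_lse buffers for: splits run in parallel across SMs, and a cheap second kernel performs the merge.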
8.3 Metadata Structure
```python
@dataclass
class ForwardMetadata:
    attn_logits: torch.Tensor     # (bs, num_heads, max_kv_splits, v_head_dim)
    attn_lse: torch.Tensor        # (bs, num_heads, max_kv_splits)
    max_extend_len: int
    num_kv_splits: torch.Tensor   # (bs,)
    kv_indptr: torch.Tensor       # CSR-format pointers into kv_indices
    kv_indices: torch.Tensor      # Flat array of token indices
    qo_indptr: torch.Tensor
    custom_mask: torch.Tensor     # For speculative decoding verification
    mask_indptr: torch.Tensor
    # Sliding window variants
    window_kv_indptr: torch.Tensor
    window_kv_indices: torch.Tensor
    window_num_kv_splits: torch.Tensor
    window_kv_offsets: torch.Tensor
```
8.4 Deterministic Inference
When --enable-deterministic-inference is set, the Triton backend uses fixed split tile sizes
instead of dynamic SM-based scheduling, ensuring bitwise-identical outputs across runs:
```python
if self.enable_deterministic:
    self.split_tile_size = get_int_env_var(
        "SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE", 256
    )
    self.static_kv_splits = False  # Use deterministic logic instead
```
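With the default tile size of 256, the split count reduces to a ceiling division of sequence length by tile size, matching the formula in get_num_kv_splits above:

```python
# Deterministic split count: ceil(seq_len / split_tile_size).
def deterministic_splits(seq_len, split_tile_size=256):
    return (seq_len + split_tile_size - 1) // split_tile_size

print(deterministic_splits(100))   # 1
print(deterministic_splits(1000))  # 4
print(deterministic_splits(1024))  # 4
print(deterministic_splits(1025))  # 5
```

Because the split boundaries depend only on the sequence length, the floating-point reduction order is identical on every run, which is what makes the output bitwise reproducible.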
9. Hybrid Attention Backend
hybrid_attn_backend.py -- Allows mixing different backends for prefill and decode
phases. This is used when --prefill-attention-backend differs from --decode-attention-backend.
9.1 Backend Selection by Forward Mode
```python
class HybridAttnBackend(AttentionBackend):
    """Support different backends for prefill and decode."""

    def __init__(self, model_runner, prefill_backend, decode_backend):
        self.model_runner = model_runner
        self.prefill_backend = prefill_backend
        self.decode_backend = decode_backend

    def _select_backend(self, forward_mode) -> AttentionBackend:
        """
        Select backend based on forward mode:
          - decode_or_idle      -> decode backend (always)
          - target_verify/draft -> decode if speculative_attention_mode == "decode",
                                   else prefill
          - prefill             -> prefill backend (always)
        """
        if forward_mode.is_decode_or_idle():
            return self.decode_backend
        elif forward_mode.is_target_verify() or forward_mode.is_draft_extend():
            return (
                self.decode_backend
                if self.model_runner.server_args.speculative_attention_mode == "decode"
                else self.prefill_backend
            )
        else:
            return self.prefill_backend
```
9.2 Transparent Delegation
Every method simply delegates to the selected backend. The hybrid backend itself stores no attention-specific state:
```python
def init_forward_metadata(self, forward_batch):
    backend = self._select_backend(forward_batch.forward_mode)
    backend.init_forward_metadata(forward_batch)

def init_cuda_graph_state(self, max_bs, max_num_tokens):
    self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
    # Also init prefill backend if speculative decoding uses prefill mode
    if (self.model_runner.server_args.speculative_algorithm is not None
            and self.model_runner.server_args.speculative_attention_mode == "prefill"):
        self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)

def forward(self, q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs):
    backend = self._select_backend(forward_batch.forward_mode)
    return backend.forward(q, k, v, layer, forward_batch, save_kv_cache, **kwargs)

def forward_decode(self, q, k, v, layer, forward_batch, ...):
    # Decode always goes to decode_backend directly
    return self.decode_backend.forward_decode(q, k, v, layer, forward_batch, ...)

def forward_extend(self, q, k, v, layer, forward_batch, ...):
    # Extend may go to either backend depending on speculative mode
    backend = self._select_backend(forward_batch.forward_mode)
    return backend.forward_extend(q, k, v, layer, forward_batch, ...)
```
9.3 Example Use Cases
FlashMLA decode + FlashInfer prefill
Use --decode-attention-backend flashmla --prefill-attention-backend flashinfer
for DeepSeek models where FlashMLA decode kernel is faster but does not handle prefill.
TRT-LLM MHA decode + FA3 prefill
Combine Blackwell-optimized decode with FA3's async TMA prefill.
Use --decode-attention-backend trtllm_mha --prefill-attention-backend fa3.
10. Comprehensive Backend Comparison
| Backend | GPU Arch | Model Type | Prefill | Decode | SWA | CUDA Graph | FP8 KV |
|---|---|---|---|---|---|---|---|
| FlashInfer default | SM80+ (Ampere/Ada) | MHA + MLA | Ragged + Paged | Paged, tensor cores | Yes (2 wrappers) | Full | Yes |
| FA3 default | SM90 (Hopper) | MHA + MLA | varlen + kvcache | flash_attn_with_kvcache | Yes + Local Attn | Full | Yes |
| TRT-LLM MHA default | SM100 (Blackwell) | MHA only | FlashInfer (inherited) | trtllm XQA kernel | Native + SWA pool | Full | Yes (XQA mixed) |
| FlashMLA | SM80+ | MLA only | FlashInfer MLA | flash_mla_with_kvcache | N/A | Full | Yes |
| CUTLASS MLA | SM80+ | MLA only | FlashInfer MLA | cutlass_mla_decode | N/A | Full | No |
| TRT-LLM MLA | SM80+ | MLA only | FlashInfer MLA | trtllm_gen_mla kernel | N/A | Full | Yes |
| Triton | Any CUDA | MHA + MLA | Custom Triton | Split-K Triton | Yes | Full | No |
| Hybrid | Any | Any | Delegated | Delegated | Backend-dependent | Both backends | Backend-dependent |
10.1 Feature Support Matrix
| Feature | FlashInfer | FA3 | TRT-LLM MHA | Triton | FlashMLA | CUTLASS MLA |
|---|---|---|---|---|---|---|
| Cross-Attention | Yes | Yes | Inherited | No | No | No |
| Speculative Decode | Full | Full | topk=1 | Full | topk=1 | Falls back |
| Deterministic Mode | Yes (fixed splits) | N/A | N/A | Yes (fixed tiles) | N/A | N/A |
| Logit Capping | Via kernel param | Via kernel param | N/A | tanh cap | N/A | N/A |
| Multi-Item Scoring | Yes | N/A | N/A | N/A | N/A | N/A |
| Attention Sinks | N/A | N/A | Yes (native) | N/A | N/A | N/A |
| Page Size | Configurable | Configurable | Configurable | Configurable | 64 (fixed) | 128 (fixed) |
11. Architecture Diagram
Figure 1: The complete attention backend dispatch architecture. RadixAttention delegates to the selected backend, which is chosen based on GPU architecture, model type, and configuration. MLA backends all share FlashInfer for prefill.
Figure 2: Attention backend lifecycle per batch. The top row shows the normal path; the middle row shows the CUDA graph capture/replay path that eliminates CPU overhead on decode. The bottom illustrates how token-level indices are grouped into page-level block indices for paged attention kernels.
Source File Reference
| File | Class | Role |
|---|---|---|
| radix_attention.py | RadixAttention | nn.Module attention layer, dispatches to backend |
| attention_registry.py | ATTENTION_BACKENDS | Registry mapping string names to factory functions |
| flashinfer_backend.py | FlashInferAttnBackend | FlashInfer wrappers for paged MHA attention |
| flashattention_backend.py | FlashAttentionBackend | FA3/FA4 kernels with local attention support |
| flashmla_backend.py | FlashMLABackend | FlashMLA decode for DeepSeek MLA (page_size=64) |
| cutlass_mla_backend.py | CutlassMLABackend | CUTLASS MLA decode (page_size=128, separate q_nope/q_pe) |
| trtllm_mla_backend.py | TRTLLMMLABackend | TRT-LLM MLA decode with LCM block padding |
| trtllm_mha_backend.py | TRTLLMHAAttnBackend | TRT-LLM XQA decode for Blackwell (SWA + sinks) |
| triton_backend.py | TritonAttnBackend | Portable Triton JIT kernels with split-K decode |
| hybrid_attn_backend.py | HybridAttnBackend | Mixes different backends for prefill vs decode |
| server_args.py | _get_default_attn_backend() | Auto-selection logic based on GPU arch + model type |