Attention Backends Deep Dive

How SGLang selects, configures, and dispatches across 8+ attention kernel backends -- from FlashInfer on Ampere to TRT-LLM MHA on Blackwell, covering MHA, MLA, sliding window, CUDA graph capture, and hybrid prefill/decode strategies.

Source: python/sglang/srt/layers/attention/
Backends: FlashInfer | FA3 | FlashMLA | CUTLASS MLA | TRT-LLM MLA | TRT-LLM MHA | Triton | Hybrid

Table of Contents

  1. Backend Selection Logic
  2. Backend Registry and Dispatch
  3. RadixAttention Layer
  4. FlashInfer Backend (Default for Ampere/Ada)
  5. FlashAttention 3 Backend (Default for Hopper)
  6. MLA Backends: FlashMLA, CUTLASS MLA, TRT-LLM MLA
  7. TRT-LLM MHA Backend (Default for Blackwell B200)
  8. Triton Backend (Portable Fallback)
  9. Hybrid Attention Backend
  10. Backend Comparison Table
  11. Architecture Diagram

1. Backend Selection Logic

SGLang automatically selects the optimal attention backend based on GPU architecture, model type (MHA vs MLA), and user configuration. This logic lives in _get_default_attn_backend() within server_args.py.

1.1 The Default Selection Algorithm

The function follows a priority chain. It branches first on attention architecture (MHA models such as Llama, Qwen, and Gemma vs. MLA models such as DeepSeek), then on GPU architecture:

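# server_args.py: _get_default_attn_backend() (simplified excerpt)
# is_hopper_with_cuda_12_3(), is_sm100_supported(), is_hip(), is_flashinfer_available()
# and is_no_spec_infer_or_topk_one() are SGLang helper predicates defined elsewhere.
# The excerpt omits the page_size condition mentioned in 1.1 of the docstring, and the
# docstring omits the AMD (is_hip -> "aiter") branch that the code includes.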
def _get_default_attn_backend(self, use_mla_backend: bool, model_config):
    """
    Auto select the fastest attention backend.

    1. Models with MHA Architecture (e.g: Llama, QWen)
        1.1 FA3 on Hopper unless spec decode with topk > 1 or page_size > 1.
        1.2 trtllm_mha for SM100/SM103 (Blackwell B200/GB200/B300)
            Note: trtllm_mha does not support SM120
        1.3 Otherwise flashinfer if available, else triton.
    2. Models with MLA Architecture
        2.1 FA3 on Hopper.
        2.2 FlashInfer on Blackwell.
        2.3 Otherwise triton.
    """
    if not use_mla_backend:
        # MHA architecture
        if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(self):
            return "fa3"
        elif is_sm100_supported() and is_no_spec_infer_or_topk_one(self):
            return "trtllm_mha"
        elif is_hip():
            return "aiter"
        else:
            return "flashinfer" if is_flashinfer_available() else "triton"
    else:
        # MLA architecture (DeepSeek models)
        if is_hopper_with_cuda_12_3():
            return "fa3"
        elif is_sm100_supported():
            return "flashinfer"
        elif is_hip():
            return "aiter"
        else:
            return "triton"
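The priority chain is easier to test when the hardware checks are turned into plain inputs. The sketch below is a minimal, hypothetical version. HardwareInfo, pick_default_backend and spec_topk_is_one are illustrative names, not SGLang APIs. spec_topk_is_one stands in for is_no_spec_infer_or_topk_one(), meaning that no speculative decoding is used or topk == 1. Like the excerpt, it omits the page_size condition.

```python
from dataclasses import dataclass


@dataclass
class HardwareInfo:
    # Hypothetical stand-ins for SGLang's helper predicates
    hopper_cuda_12_3: bool = False      # is_hopper_with_cuda_12_3()
    sm100: bool = False                 # is_sm100_supported()
    hip: bool = False                   # is_hip()
    flashinfer_available: bool = True   # is_flashinfer_available()


def pick_default_backend(hw: HardwareInfo, use_mla: bool,
                         spec_topk_is_one: bool = True) -> str:
    """Mirrors the simplified priority chain of _get_default_attn_backend()."""
    if not use_mla:
        # MHA: FA3 / TRT-LLM MHA only when spec decode is off or topk == 1
        if hw.hopper_cuda_12_3 and spec_topk_is_one:
            return "fa3"
        if hw.sm100 and spec_topk_is_one:
            return "trtllm_mha"
        if hw.hip:
            return "aiter"
        return "flashinfer" if hw.flashinfer_available else "triton"
    # MLA: no speculative-decoding condition in the excerpt
    if hw.hopper_cuda_12_3:
        return "fa3"
    if hw.sm100:
        return "flashinfer"
    if hw.hip:
        return "aiter"
    return "triton"
```

One consequence of this ordering is worth noting. An MHA model on Hopper with speculative decoding at topk > 1 skips FA3 and falls through to flashinfer, or to triton when FlashInfer is not installed.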

1.2 Decision Flow Summary

Hopper (SM90) MHA

FlashAttention 3 (fa3) is the default unless speculative decoding is enabled with topk > 1. FA3 uses asynchronous TMA copies and warp-specialized execution to maximize throughput on H100/H200.

Hopper (SM90) MLA

FlashAttention 3 (fa3) is also the default. The FA3 backend has an MLA-specific path that uses absorbed Q with split nope/rope dimensions.

Blackwell (SM100) MHA

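TRT-LLM MHA (trtllm_mha) is the default on SM100/SM103 unless speculative decoding is enabled with topk > 1. It uses NVIDIA TensorRT-LLM attention kernels exposed through FlashInfer: the TRTLLM-GEN implementation on SM100, with XQA kernels used only on SM90/SM120. It supports sliding window and attention sinks natively.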

Blackwell (SM100) MLA

FlashInfer (flashinfer) is the default for MLA on Blackwell. The MLA-specific FlashInfer backend handles absorbed KV projections.

Ampere/Ada (SM80/89)

FlashInfer (flashinfer) is the default when available, and SGLang falls back to Triton if FlashInfer is not installed. FlashInfer provides FlashAttention-2-class prefill and decode kernels with paged KV cache support.

AMD ROCm

The AITER (aiter) backend is the default on ROCm (HIP). It falls back to Triton for unsupported head configurations.

1.3 Hybrid Backend Creation

When --prefill-attention-backend and --decode-attention-backend differ, SGLang automatically wraps them in a HybridAttnBackend:

# model_runner.py: _get_attention_backend()
if self.decode_attention_backend_str != self.prefill_attention_backend_str:
    from sglang.srt.layers.attention.hybrid_attn_backend import HybridAttnBackend

    attn_backend = HybridAttnBackend(
        self,
        decode_backend=self._get_attention_backend_from_str(
            self.decode_attention_backend_str,
        ),
        prefill_backend=self._get_attention_backend_from_str(
            self.prefill_attention_backend_str,
        ),
    )

2. Backend Registry and Dispatch

All attention backends are registered via a decorator pattern in attention_registry.py. This registry maps string names to factory functions:

# attention_registry.py
ATTENTION_BACKENDS = {}

def register_attention_backend(name):
    def decorator(fn):
        ATTENTION_BACKENDS[name] = fn
        return fn
    return decorator

# Each backend registers itself:
@register_attention_backend("flashinfer")
def create_flashinfer_backend(runner):
    if not runner.use_mla_backend:
        return FlashInferAttnBackend(runner)
    else:
        return FlashInferMLAAttnBackend(runner)

@register_attention_backend("fa3")
def create_flashattention_v3_backend(runner):
    assert torch.cuda.get_device_capability()[0] >= 8  # SM80+
    return FlashAttentionBackend(runner)

@register_attention_backend("trtllm_mha")
def create_trtllm_mha_backend(runner):
    if runner.use_mla_backend:
        raise ValueError("trtllm_mha only for non-MLA models.")
    return TRTLLMHAAttnBackend(runner)

@register_attention_backend("trtllm_mla")
def create_trtllm_mla_backend(runner):
    if not runner.use_mla_backend:
        raise ValueError("trtllm_mla only for MLA models.")
    return TRTLLMMLABackend(runner)

# The remaining backends register the same way (factories omitted), e.g.:
#   "flashmla"     -> FlashMLA for DeepSeek MLA
#   "cutlass_mla"  -> CUTLASS MLA kernel
#   "triton"       -> portable fallback
#   "fa4"          -> FlashAttention 4
# ... 14+ backends total

Backend instantiation then follows a single dispatch path from ModelRunner:

# model_runner.py: _get_attention_backend_from_str()
def _get_attention_backend_from_str(self, backend_str):
    if backend_str not in ATTENTION_BACKENDS:
        raise ValueError(f"Invalid attention backend: {backend_str}")
    full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
    return attn_backend_wrapper(self, full_attention_backend)

The attn_backend_wrapper() function further wraps the backend for hybrid models (e.g., Mamba-2 + attention, GDN + attention) by checking if the model has mixed linear-attention/full-attention layers.
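The registry-plus-lookup pattern can be exercised on its own. The minimal sketch below reuses the decorator shape shown above. DummyBackend, create_dummy_backend, get_backend_from_str and the "dummy_*" names are hypothetical placeholders. The sketch also skips the attn_backend_wrapper() step.

```python
ATTENTION_BACKENDS = {}


def register_attention_backend(name):
    """Decorator that records a factory function under a string key."""
    def decorator(fn):
        ATTENTION_BACKENDS[name] = fn
        return fn  # return fn unchanged so decorators can be stacked
    return decorator


class DummyBackend:
    # Hypothetical placeholder standing in for a real backend class
    def __init__(self, runner, kind):
        self.runner = runner
        self.kind = kind


@register_attention_backend("dummy_a")
@register_attention_backend("dummy_b")  # stacked: one factory, two names
def create_dummy_backend(runner):
    return DummyBackend(runner, kind="dummy")


def get_backend_from_str(runner, backend_str):
    # Same validation-then-call shape as _get_attention_backend_from_str()
    if backend_str not in ATTENTION_BACKENDS:
        raise ValueError(f"Invalid attention backend: {backend_str}")
    return ATTENTION_BACKENDS[backend_str](runner)
```

Because each factory returns the function unchanged, one factory can serve several names. Adding a backend then takes only a decorated factory, with no edits to the dispatch code.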

3. RadixAttention Layer

The RadixAttention module in radix_attention.py is the nn.Module that every transformer layer calls. It does not implement any kernel -- it dispatches to the current backend.

# radix_attention.py
from torch import nn

class RadixAttention(nn.Module):
    def __init__(self,
        num_heads: int, head_dim: int, scaling: float,
        num_kv_heads: int, layer_id: int,
        logit_cap: float = 0.0,
        v_head_dim: int = -1,
        sliding_window_size: int = -1,
        is_cross_attention: bool = False,
        # (further optional arguments omitted)
    ):
        super().__init__()
        self.tp_q_head_num = num_heads
        self.tp_k_head_num = num_kv_heads
        self.tp_v_head_num = num_kv_heads   # used by forward() below
        self.head_dim = head_dim
        self.qk_head_dim = head_dim         # used by forward() below
        self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
        self.scaling = scaling
        self.layer_id = layer_id
        self.sliding_window_size = sliding_window_size or -1

3.1 The Forward Dispatch

The forward method has two paths: a torch.compile-compatible path for extend mode (using unified_attention_with_output) and a direct dispatch path:

def forward(self, q, k, v, forward_batch, save_kv_cache=True, **kwargs):
    if k is not None:
        if "k_rope" not in kwargs:
            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
        else:
            k = k.view(-1, self.tp_k_head_num, self.v_head_dim)

    # Piecewise CUDA graph path (extend mode + torch.compile)
    if forward_batch.forward_mode.is_extend() and get_forward_context() is not None:
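        # Output has v_head_dim per head, which differs from qk_head_dim for MLA
        if self.qk_head_dim != self.v_head_dim:
            output = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
        else:
            output = torch.empty_like(q)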
        unified_attention_with_output(
            q, k, v, output, save_kv_cache, self.layer_id, **kwargs
        )
        return output
    else:
        # Direct backend dispatch
        return forward_batch.attn_backend.forward(
            q, k, v, self, forward_batch, save_kv_cache, **kwargs
        )

Key Insight: The backend is accessed via forward_batch.attn_backend, which is taken from the ModelRunner when the ForwardBatch is built. The RadixAttention layer itself has no knowledge of which kernel will execute. It passes itself (with layer_id, head counts, and scaling) to whichever backend is active.

3.2 Unified Attention (torch.compile path)

The unified_attention_with_output function is registered as a custom op for piecewise CUDA graph capture. It retrieves the attention layer and backend from a thread-local forward context:

@register_custom_op(mutates_args=["output"])
@register_split_op()
def unified_attention_with_output(query, key, value, output,
                                  save_kv_cache, layer_id, **kwargs):
    context = get_forward_context()
    forward_batch = context.forward_batch
    attention_layer = context.attention_layers[layer_id]
    ret = forward_batch.attn_backend.forward(
        query, key, value, attention_layer, forward_batch, save_kv_cache, **kwargs
    )
    output.view(ret.shape).copy_(ret)

4. FlashInfer Backend (Default for Ampere/Ada)

flashinfer_backend.py -- the most mature and feature-rich backend. FlashInfer wraps optimized CUDA attention kernels with sophisticated paged KV cache management, CUDA graph support, and sliding window attention.

4.1 Wrapper Architecture

FlashInfer uses three types of wrapper objects that encapsulate kernel state:

from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
    fast_decode_plan,
)
from flashinfer.cascade import merge_state

4.2 Workspace and Buffer Initialization

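A single global workspace buffer is shared by all FlashInfer wrappers in the process. For Qwen models (the excerpt shows the Qwen2ForCausalLM check), the backend raises the workspace size to 512 MB by overriding SGLANG_FLASHINFER_WORKSPACE_SIZE before the buffer is allocated: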

# Global workspace -- shared across all wrappers
global_workspace_buffer = None

class FlashInferAttnBackend(AttentionBackend):
    def __init__(self, model_runner, skip_prefill=False, ...):
        # Qwen2/3 require higher workspace size
        if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
            envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(512 * 1024 * 1024)

        # Allocate shared workspace
        global global_workspace_buffer
        if global_workspace_buffer is None:
            global_workspace_buffer = torch.empty(
                envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
                dtype=torch.uint8, device=model_runner.device,
            )
        self.workspace_buffer = global_workspace_buffer

4.3 Decode/Prefill Wrapper Lifecycle

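Models with sliding-window attention or cross-attention (encoder-decoder) get two wrapper slots, so the two kinds of attention layers in the model can be planned separately. All other models use a single slot. Each slot holds one paged prefill wrapper and one decode wrapper: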

# Wrapper dispatch reasoning
if model_runner.sliding_window_size is not None:
    self.num_wrappers = 2
    self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
elif model_runner.model_config.is_encoder_decoder:
    self.num_wrappers = 2
    self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
else:
    self.num_wrappers = 1
    self.dispatch_reason = None

# Create wrapper instances per wrapper slot
self.prefill_wrappers_paged = []
self.decode_wrappers = []
for _ in range(self.num_wrappers):
    self.prefill_wrappers_paged.append(
        BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
    )
    self.decode_wrappers.append(
        BatchDecodeWithPagedKVCacheWrapper(
            self.workspace_buffer, "NHD",
            use_tensor_cores=self.decode_use_tensor_cores,
        )
    )

4.4 Page Table Management

FlashInfer uses KV indptr (pointer arrays) and KV indices (page indices) to manage paged KV cache. The FlashInferIndicesUpdaterDecode and FlashInferIndicesUpdaterPrefill build these arrays via Triton kernels for GPU-side construction:

# Indices are built per-batch
self.kv_indptr = [
    torch.zeros((max_bs + 1,), dtype=torch.int32, device=device)
    for _ in range(self.num_wrappers)
]
self.kv_last_page_len = torch.ones(
    (max_bs,), dtype=torch.int32, device=device
)

# init_forward_metadata for decode
def init_forward_metadata(self, forward_batch):
    if forward_batch.forward_mode.is_decode_or_idle():
        self.indices_updater_decode.update(
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            forward_batch.seq_lens_cpu,
            forward_batch.seq_lens_sum,
            decode_wrappers=self.decode_wrappers,
            encoder_lens=forward_batch.encoder_lens,
            spec_info=forward_batch.spec_info,
        )
        self.forward_metadata = DecodeMetadata(self.decode_wrappers)
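The CSR-style layout behind kv_indptr and kv_indices can be shown with a small host-side NumPy sketch. It assumes page_size = 1, where each "page" is a single token slot, which matches the all-ones kv_last_page_len buffer above. build_kv_csr is a hypothetical helper, and req_to_token models SGLang's per-request token-slot table. The real updaters build the same arrays on the GPU with Triton kernels.

```python
import numpy as np


def build_kv_csr(req_to_token, req_pool_indices, seq_lens):
    """Flatten per-request KV slots into (indptr, indices) CSR arrays."""
    # indptr[i]..indptr[i+1] delimits request i's slots inside kv_indices
    kv_indptr = np.zeros(len(seq_lens) + 1, dtype=np.int32)
    kv_indptr[1:] = np.cumsum(seq_lens)
    # Concatenate the first seq_len slots of each request's row
    kv_indices = np.concatenate(
        [req_to_token[r, :n] for r, n in zip(req_pool_indices, seq_lens)]
    ).astype(np.int32)
    # With page_size = 1 every last page holds exactly one token
    kv_last_page_len = np.ones(len(seq_lens), dtype=np.int32)
    return kv_indptr, kv_indices, kv_last_page_len
```

For example, requests stored in pool rows 2 and 0 with lengths 4 and 3 produce kv_indptr = [0, 4, 7]. The kernel then reads request 1's KV slots from kv_indices[4:7].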

4.5 CUDA Graph Capture

For CUDA graph support, FlashInfer creates dedicated wrappers with pre-allocated index buffers that can be replayed without re-planning:

def init_forward_metadata_capture_cuda_graph(
    self, bs, num_tokens, req_pool_indices, seq_lens,
    encoder_lens, forward_mode, spec_info,
):
    if forward_mode.is_decode_or_idle():
        decode_wrappers = []
        for i in range(self.num_wrappers):
            decode_wrappers.append(
                BatchDecodeWithPagedKVCacheWrapper(
                    self.workspace_buffer, "NHD",
                    use_cuda_graph=True,  # Key flag
                    use_tensor_cores=self.decode_use_tensor_cores,
                    paged_kv_indptr_buffer=self.kv_indptr[i][:num_tokens+1],
                    paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
                    paged_kv_last_page_len_buffer=self.kv_last_page_len[:num_tokens],
                )
            )
        # Replace begin_forward with fast_decode_plan
        for i in range(self.num_wrappers):
            decode_wrappers[i].begin_forward = partial(
                fast_decode_plan, decode_wrappers[i]
            )

4.6 Tensor Core Decision

def should_use_tensor_core(kv_cache_dtype, num_attention_heads, num_kv_heads):
    # Use tensor cores if FP8 KV cache or GQA ratio >= 4
    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return True
    if num_attention_heads // num_kv_heads >= 4:
        return True
    return False

5. FlashAttention 3 Backend (Default for Hopper)

flashattention_backend.py wraps the FA3 and FA4 kernel APIs. It is the default for Hopper GPUs (SM90) for both MHA and MLA models.

5.1 Kernel Selection (FA3 vs FA4)

from sgl_kernel.flash_attn import (
    flash_attn_varlen_func as flash_attn_varlen_func_fa3,
    flash_attn_with_kvcache as flash_attn_with_kvcache_fa3,
)
from sglang.jit_kernel.flash_attention_v4 import (
    flash_attn_varlen_func as flash_attn_varlen_func_fa4,
    flash_attn_with_kvcache as flash_attn_with_kvcache_fa4,
)

# The constructor selects between FA3 and FA4 based on fa_impl_ver parameter
class FlashAttentionBackend(AttentionBackend):
    def __init__(self, model_runner, fa_impl_ver=3):
        # fa_impl_ver=3 for FA3 (default), fa_impl_ver=4 for FA4
        ...  # remaining initialization omitted

5.2 Metadata Preparation

The FlashAttentionMetadata dataclass holds all per-batch state needed by the kernels:

@dataclass
class FlashAttentionMetadata:
    cache_seqlens_int32: torch.Tensor = None   # Sequence lengths
    max_seq_len_q: int = 1                    # Max query length
    max_seq_len_k: int = 0                    # Max key length
    cu_seqlens_q: torch.Tensor = None         # Cumulative query lengths
    cu_seqlens_k: torch.Tensor = None         # Cumulative key lengths
    window_size: tuple = (-1, -1)             # Sliding window
    page_table: torch.Tensor = None           # KV cache page indices
    swa_page_table: torch.Tensor = None       # SWA-specific pages

    # Nested metadata for local attention (long context)
    @dataclass
    class LocalAttentionMetadata:
        local_query_start_loc: torch.Tensor = None
        local_seqused_k: torch.Tensor = None
        local_block_table: torch.Tensor = None
        local_max_query_len: int = 0
        local_max_seq_len: int = 0

5.3 Local Attention for Long Context

For models that use chunked local attention (e.g., Llama 4, whose attention_chunk_size is 8192), FA3 splits each sequence into virtual local batches. This reproduces a block-diagonal causal mask without materializing any mask tensor:

def make_local_attention_virtual_batches(
    attn_chunk_size, query_start_loc_np, seq_lens_np, block_table, page_size
):
    """
    Break sequences into local attention blocks.
    E.g., with attn_chunk_size=4, q_seqlens=[4,10,5], k_seqlens=[6,17,9]:
      seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
      seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
    Each local block becomes an independent "virtual" batch item.
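    """
    # Per-sequence query lengths from the cumulative start offsets
    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    # Query tokens that land in each sequence's first (possibly partial) chunk
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
        q_seqlens
    )
    # Chunks spanned by each sequence's queries (cdiv = ceiling division)
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)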
    ... # Build virtual batch arrays

Performance Note: For long-context models that use chunked local attention, splitting sequences into virtual batches means each "batch item" attends only to its own chunk. This cuts attention cost from O(n^2) to O(n * chunk_size) while reusing the same variable-length kernel.
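The docstring example can be checked with a small NumPy sketch. It reuses the excerpt's first-block and block-count formulas, then walks each sequence chunk by chunk with a plain loop. The loop is our own simplification: the excerpt builds these arrays vectorized and also produces the block table. local_virtual_seqlens and cdiv are hypothetical helper names.

```python
import numpy as np


def cdiv(a, b):
    # Ceiling division; also works element-wise on NumPy integer arrays
    return -(-a // b)


def local_virtual_seqlens(attn_chunk_size, q_seqlens, k_seqlens):
    """Return (q_local, k_local) lengths of the virtual local batches.

    Assumes every q_len >= 1 and that queries are the last q_len of k_len tokens.
    """
    q_seqlens = np.asarray(q_seqlens, dtype=np.int64)
    k_seqlens = np.asarray(k_seqlens, dtype=np.int64)
    # Same formulas as the excerpt
    q_first = np.minimum(
        attn_chunk_size - ((k_seqlens - q_seqlens) % attn_chunk_size), q_seqlens
    )
    local_blocks = 1 + cdiv(q_seqlens - q_first, attn_chunk_size)

    q_local, k_local = [], []
    for q_len, k_len, first, nblocks in zip(q_seqlens, k_seqlens, q_first, local_blocks):
        pos = k_len - q_len          # absolute position of the first query token
        remaining = q_len
        for b in range(nblocks):
            take = first if b == 0 else min(attn_chunk_size, remaining)
            chunk_start = (pos // attn_chunk_size) * attn_chunk_size
            last = pos + take - 1    # last query position in this chunk
            q_local.append(int(take))
            # Causal within the chunk: keys from chunk_start up to the last query
            k_local.append(int(last - chunk_start + 1))
            pos += take
            remaining -= take
    return q_local, k_local


print(local_virtual_seqlens(4, [4, 10, 5], [6, 17, 9]))
```

The first sequence, for instance, has queries at positions 2 to 5. Positions 2 and 3 fall in chunk [0, 3], so that virtual item sees 4 keys. Positions 4 and 5 fall in chunk [4, 7], so that item sees only 2 keys.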

5.4 Decode Path: flash_attn_with_kvcache

The decode path calls flash_attn_with_kvcache, which reads the paged KV cache in place rather than first gathering each request's K and V into a contiguous buffer. The page table maps logical token positions to physical cache pages.
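To make the page-table mapping concrete, here is a NumPy reference for single-token decode attention over a paged cache. It is a readability sketch, not the FA3 kernel. For clarity it gathers each request's K and V with fancy indexing, which a real kernel avoids by reading pages in place. The function name and the [num_pages, page_size, heads, dim] cache layout are assumptions.

```python
import numpy as np


def paged_decode_attention(q, k_cache, v_cache, page_table, seq_lens, scale):
    """Reference decode attention over a paged KV cache.

    q:          [bs, num_heads, head_dim]   (one new query token per request)
    k_cache:    [num_pages, page_size, num_heads, head_dim]
    v_cache:    [num_pages, page_size, num_heads, head_dim]
    page_table: [bs, max_pages]  physical page id for each logical page
    seq_lens:   [bs]             number of cached tokens per request
    """
    bs = q.shape[0]
    page_size = k_cache.shape[1]
    out = np.zeros_like(q)
    for b in range(bs):
        pos = np.arange(int(seq_lens[b]))
        pages = page_table[b, pos // page_size]   # logical -> physical page
        slots = pos % page_size                   # offset within the page
        k = k_cache[pages, slots]                 # [n, num_heads, head_dim]
        v = v_cache[pages, slots]
        scores = np.einsum("hd,nhd->hn", q[b], k) * scale
        scores -= scores.max(axis=-1, keepdims=True)   # numerically stable softmax
        p = np.exp(scores)
        p /= p.sum(axis=-1, keepdims=True)
        out[b] = np.einsum("hn,nhd->hd", p, v)
    return out
```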

6. MLA Backends: FlashMLA, CUTLASS MLA, TRT-LLM MLA

Multi-head Latent Attention (MLA) is DeepSeek's innovation that compresses KV projections into a low-rank space. SGLang provides three specialized decode kernels for MLA models, all sharing FlashInfer's MLA backend for prefill.

6.1 MLA Architecture Overview

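MLA Key Dimensions: All MLA backends share these configuration values from the model config: kv_lora_rank (typically 512), qk_nope_head_dim (128), qk_rope_head_dim (64), v_head_dim (128). The KV cache stores one compressed vector of dimension kv_lora_rank + qk_rope_head_dim = 576 per token instead of separate K and V. In the absorbed decode path, the attention layer is therefore built with head_dim = 576 and v_head_dim = kv_lora_rank = 512 (not the model's v_head_dim of 128). This is why the decode kernels pass kv_lora_rank as the value dimension.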
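A quick per-token, per-layer calculation shows what the 576-dim latent saves. It assumes 128 attention heads (DeepSeek-V3's configuration) and a 2-byte BF16 cache. The "uncompressed" figure is what caching full per-head K (nope + rope) and V would cost. The ratio does not depend on the element size.

```python
# MLA dimensions from the text; num_heads and bytes_per_elem are assumptions
kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 512, 128, 64, 128
num_heads = 128
bytes_per_elem = 2  # BF16

# MLA caches one latent vector (compressed KV + shared rope key) per token
mla_elems = kv_lora_rank + qk_rope_head_dim
# Uncompressed MHA-style cache: per-head K (nope + rope) plus per-head V
mha_elems = num_heads * (qk_nope_head_dim + qk_rope_head_dim) + num_heads * v_head_dim

mla_bytes = mla_elems * bytes_per_elem
mha_bytes = mha_elems * bytes_per_elem
ratio = mha_bytes / mla_bytes
print(mla_elems, mha_elems, mla_bytes, mha_bytes, round(ratio, 1))
```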

6.2 FlashMLA Backend

flashmla_backend.py -- Inherits from FlashInferMLAAttnBackend for prefill, overrides decode with the flash_mla_with_kvcache kernel. Uses PAGE_SIZE = 64.

from sgl_kernel.flash_mla import flash_mla_with_kvcache, get_mla_metadata

PAGE_SIZE = 64  # FlashMLA fixed page size

class FlashMLABackend(FlashInferMLAAttnBackend):
    def init_forward_metadata(self, forward_batch):
        if forward_batch.forward_mode.is_decode_or_idle():
            bs = forward_batch.batch_size
            device = forward_batch.seq_lens.device
            # Build page-aligned block indices via Triton kernel
            max_seqlen_pad = triton.cdiv(
                forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
            )
            block_kv_indices = torch.full(
                (bs, max_seqlen_pad), -1, dtype=torch.int32, device=device
            )
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token, forward_batch.req_pool_indices,
                forward_batch.seq_lens, None, block_kv_indices,
                self.req_to_token.stride(0), max_seqlen_pad,
            )
            # Get tile scheduler metadata for split-k attention
            mla_metadata, num_splits = get_mla_metadata(
                forward_batch.seq_lens.to(torch.int32),
                self.num_q_heads, 1,
                is_fp8_kvcache=self.is_fp8_kvcache,
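            )
            # block_kv_indices, mla_metadata and num_splits are then stored in
            # self.forward_metadata (as block_kv_indices, flashmla_metadata, num_splits)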
        else:
            # Fall back to FlashInfer MLA for prefill
            super().init_forward_metadata(forward_batch)

FlashMLA Decode Kernel Call

def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True):
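    bs = forward_batch.batch_size
    cache_loc = forward_batch.out_cache_loc  # KV slots for the new tokens
    # Save KV to cache
    if k is not None and save_kv_cache:
        forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)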

    k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
    reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)

    o, _ = flash_mla_with_kvcache(
        q=reshape_q,
        k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
        block_table=self.forward_metadata.block_kv_indices[:bs],
        cache_seqlens=forward_batch.seq_lens.to(torch.int32),
        head_dim_v=self.kv_lora_rank,  # Output dim is kv_lora_rank, not head_dim
        tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
        num_splits=self.forward_metadata.num_splits,
        softmax_scale=layer.scaling,
        causal=True,
    )
    return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

6.3 CUTLASS MLA Backend

cutlass_mla_backend.py -- Uses NVIDIA CUTLASS kernels for MLA decode. Key difference: PAGE_SIZE = 128 (fixed by CUTLASS requirements). Separates Q into nope and rope components before the kernel call.

from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size

PAGE_SIZE = 128  # CUTLASS MLA only supports page_size=128

class CutlassMLABackend(FlashInferMLAAttnBackend):
    def init_forward_metadata(self, forward_batch):
        if forward_batch.forward_mode.is_decode_or_idle():
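            # bs and max_seqlen_pad (max sequence length in pages) are computed
            # as in the FlashMLA backend; the workspace is allocated per batch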
            workspace_size = cutlass_mla_get_workspace_size(
                max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
            )
            workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
        else:
            super().init_forward_metadata(forward_batch)  # FlashInfer for prefill

CUTLASS MLA Kernel: Separate Q_nope and Q_rope

def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True,
                   q_rope=None, k_rope=None):
    # CUTLASS MLA takes nope/rope as separate arguments
    if q_rope is not None:
        q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
        q_rope = q_rope.view(-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim)
    else:
        reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
        q_nope = reshaped_q[:, :, :layer.v_head_dim]
        q_rope = reshaped_q[:, :, layer.v_head_dim:]

    k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
    o = cutlass_mla_decode(
        q_nope=q_nope, q_pe=q_rope,
        kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
        seq_lens=forward_batch.seq_lens.to(torch.int32),
        page_table=self.forward_metadata.block_kv_indices,
        workspace=self.forward_metadata.workspace,
        sm_scale=layer.scaling, num_kv_splits=1,
    )
    return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

6.4 TRT-LLM MLA Backend

trtllm_mla_backend.py -- Uses TensorRT-LLM's MLA kernel from flashinfer. Handles complex block constraints for page alignment:

# Block constraint from flashinfer requirements:
# block_num % (128 / block_size) == 0
TRTLLM_BLOCK_CONSTRAINT = 128

class TRTLLMMLABackend(FlashInferMLAAttnBackend):
    def _calc_padded_blocks(self, max_seq_len):
        """Pad block count to satisfy both TRT-LLM and Triton constraints."""
        blocks = triton.cdiv(max_seq_len, self.page_size)
        # LCM of TRT-LLM constraint and Triton page constraint
        trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
        triton_constraint = get_num_page_per_block_flashmla(self.page_size)
        constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)
        if blocks % constraint_lcm != 0:
            blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
        return blocks
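The padding rule is plain integer arithmetic, so it can be run without Triton. In this pure-Python sketch, triton_pages_per_block is a hypothetical stand-in for the value returned by get_num_page_per_block_flashmla(page_size), whose actual values are not shown in the excerpt. The guard assumes page_size divides 128, as the "128 / block_size" constraint implies.

```python
import math

TRTLLM_BLOCK_CONSTRAINT = 128


def cdiv(a, b):
    return -(-a // b)


def calc_padded_blocks(max_seq_len, page_size, triton_pages_per_block):
    """Round the page count up to a multiple of both kernels' constraints."""
    # Assumption: page_size divides 128, so the TRT-LLM constraint is >= 1
    assert TRTLLM_BLOCK_CONSTRAINT % page_size == 0
    blocks = cdiv(max_seq_len, page_size)
    trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // page_size
    constraint_lcm = math.lcm(trtllm_constraint, triton_pages_per_block)
    if blocks % constraint_lcm != 0:
        blocks = cdiv(blocks, constraint_lcm) * constraint_lcm
    return blocks
```

For example, with page_size = 32 and max_seq_len = 1100, the raw count is 35 pages. With a hypothetical Triton constraint of 2, the LCM is lcm(4, 2) = 4, so the block table is padded to 36 pages.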

The TRT-LLM MLA backend also includes Triton kernels for padding/unpadding draft queries during speculative decoding (pad_draft_extend_query_kernel and unpad_draft_extend_output_kernel).

6.5 MLA Backend Comparison

Aspect | FlashMLA | CUTLASS MLA | TRT-LLM MLA
Page size | 64 (fixed) | 128 (fixed) | Configurable (with constraints)
Prefill | FlashInfer MLA (inherited) | FlashInfer MLA (inherited) | FlashInfer MLA (inherited)
Q format | Concatenated [nope, rope] | Separate q_nope, q_pe | Absorbed Q
FP8 KV cache | Yes (quantize Q to match) | No | Yes (with descale)
Spec decode | topk=1 only | Falls back to FlashInfer | Full support with padding
CUDA graph | Full support | Full support | Full support
Parent class | FlashInferMLAAttnBackend | FlashInferMLAAttnBackend | FlashInferMLAAttnBackend

Key Pattern: All three MLA backends inherit from FlashInferMLAAttnBackend and only override the decode path (and sometimes target_verify). Prefill always uses FlashInfer's MLA implementation, making these backends effectively "decode-only" specializations.

7. TRT-LLM MHA Backend (Default for Blackwell B200)

trtllm_mha_backend.py -- This is the default attention backend for non-MLA models on SM100/SM103 (Blackwell B200/GB200/B300) GPUs. It extends FlashInferAttnBackend with TensorRT-LLM attention kernels exposed through FlashInfer.

7.1 Why Default for Blackwell?

SM100 Optimization: On SM100 the backend runs TensorRT-LLM's TRTLLM-GEN attention kernels, which target Blackwell's 5th-generation Tensor Cores and TMA (Tensor Memory Accelerator). The XQA implementation is selected only on SM90 and SM120, via self.is_xqa_impl = is_sm90_supported() or is_sm120_supported(). The two paths differ in query precision under an FP8 KV cache: XQA keeps queries in BF16, while the TRTLLM-GEN path casts queries to FP8 to match the cache (see forward_decode in 7.4).

7.2 Initialization and Workspace

# Module-level buffer shared by all TRTLLMHAAttnBackend instances
global_zero_init_workspace_buffer = None

class TRTLLMHAAttnBackend(FlashInferAttnBackend):
    """TRTLLM MHA attention kernel from flashinfer.
    The kernel supports sm100 only, with sliding window and attention sink."""

    def __init__(self, model_runner, skip_prefill=False, **kwargs):  # other optional args elided
        # Preserve user's workspace size before parent __init__ overrides it
        # (env_var: handle to the workspace-size environment variable, defined elsewhere)
        workspace_size_bytes = (
            env_var.get() if env_var.is_set()
            else DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024  # 512 MB
        )
        super().__init__(model_runner, skip_prefill, **kwargs)
        self.workspace_size = workspace_size_bytes

        # Zero-initialized workspace (required by TRT-LLM)
        global global_zero_init_workspace_buffer
        if global_zero_init_workspace_buffer is None:
            global_zero_init_workspace_buffer = torch.zeros(
                self.workspace_size, dtype=torch.uint8, device=model_runner.device,
            )
        self.workspace_buffer = global_zero_init_workspace_buffer

        # Detect XQA vs TRTLLM-GEN implementation
        self.is_xqa_impl = is_sm90_supported() or is_sm120_supported()

7.3 Sliding Window Attention with SWA KV Pool

A distinguishing feature is native sliding window attention support with separate KV pools. For hybrid SWA models (like Gemma-3), the backend maintains translated page tables for SWA layers that index into a separate, smaller KV pool:

# SWA hybrid model support -- two KV pools with separate indices
allocator = model_runner.token_to_kv_pool_allocator
self.use_sliding_window_kv_pool = isinstance(
    allocator, SWATokenToKVPoolAllocator
)
self._swa_kv_pool = (
    allocator.get_kvcache() if self.use_sliding_window_kv_pool else None
)

def _get_layer_page_table(self, layer, forward_batch):
    """Return correct page_table for the given layer (SWA or full)."""
    swa_pt = self.forward_metadata.swa_page_table
    if swa_pt is not None:
        _, is_swa = self._swa_kv_pool.layers_mapping[layer.layer_id]
        if is_swa:
            return swa_pt
    return self.forward_metadata.page_table

7.4 Decode Kernel Call

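def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs):
    # (KV cache write and derivation of q_scale, k_scale and attention_sink omitted;
    #  q_scale/k_scale are FP8 dequantization factors, attention_sink is optional)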
    # For XQA: q_dtype stays bf16 even with FP8 KV
    if self.data_type == torch.float8_e4m3fn and (not self.is_xqa_impl):
        q = q.to(torch.float8_e4m3fn)

    q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
    k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
    # Permute: [pages, page_size, heads, dim] -> [pages, heads, page_size, dim]
    k_cache = k_cache.view(-1, self.page_size, layer.tp_k_head_num, layer.head_dim
        ).permute(0, 2, 1, 3)
    v_cache = v_cache.view(-1, self.page_size, layer.tp_v_head_num, layer.head_dim
        ).permute(0, 2, 1, 3)

    page_table = self._get_layer_page_table(layer, forward_batch)

    o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
        query=q,
        kv_cache=(k_cache, v_cache),
        workspace_buffer=self.workspace_buffer,
        block_tables=page_table,
        seq_lens=self.forward_metadata.cache_seqlens_int32,
        max_seq_len=self.max_context_len,
        bmm1_scale=q_scale * k_scale * layer.scaling,
        bmm2_scale=1.0,
        window_left=layer.sliding_window_size,  # Native SWA support
        sinks=attention_sink,                    # Attention sink support
        out_dtype=self.q_data_type,
    )
    return o.view(-1, layer.tp_q_head_num * layer.head_dim)
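The bmm1_scale argument folds quantization scales and the softmax scaling into one factor. Assume per-tensor scales, where the stored FP8 values are $\hat q = q / q_{\text{scale}}$ and $\hat k = k / k_{\text{scale}}$ (our assumption about how the scales are defined). The attention logit then factors as:

```latex
s_{ij} = \text{scaling}\,\big(q_{\text{scale}}\,\hat q_i\big)^{\top}\big(k_{\text{scale}}\,\hat k_j\big)
       = \underbrace{q_{\text{scale}}\,k_{\text{scale}}\,\text{scaling}}_{\texttt{bmm1\_scale}}\;\hat q_i^{\top}\hat k_j,
\qquad
o_i = \sum_j \operatorname{softmax}_j\!\left(s_{ij}\right) v_j
```

Because the factor is a scalar, the kernel can run the first matmul on the raw stored values and apply one multiply to the logits. With bmm2_scale = 1.0, as in the excerpt, no extra factor is applied to the P·V product.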

7.5 CUDA Graph State Management

TRT-LLM MHA pre-allocates page tables, cumulative sequence length arrays, and optional SWA translation tables for CUDA graph capture across decode, target_verify, and draft_extend modes:

def init_cuda_graph_state(self, max_bs, max_num_tokens):
    max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
    self.decode_cuda_graph_metadata = {
        "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device="cuda"),
        "page_table": torch.zeros(max_bs, max_num_pages, dtype=torch.int32, device="cuda"),
        "swa_page_table": self._alloc_swa_page_table(max_bs, max_num_pages),
        "strided_indices": torch.arange(0, self.max_context_len, self.page_size, device="cuda"),
    }
    # Also pre-allocates target_verify_metadata and draft_extend_metadata dicts
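One plausible reading of strided_indices, which is our assumption and not shown in the excerpt, is that it samples SGLang's token-level req_to_token map once per page. Dividing each sampled slot by page_size then yields a page id. The NumPy sketch below shows that conversion. build_page_table is a hypothetical helper, and the sketch assumes each page's token slots are allocated contiguously.

```python
import numpy as np


def build_page_table(req_to_token, req_pool_indices, max_context_len, page_size):
    """Convert a token->slot map into a [bs, max_pages] page table."""
    # Positions of the first token of every page: 0, page_size, 2*page_size, ...
    strided_indices = np.arange(0, max_context_len, page_size)
    rows = np.asarray(req_pool_indices)[:, None]
    first_slots = req_to_token[rows, strided_indices[None, :]]
    # Slot s lives in page s // page_size (assumes page-contiguous allocation)
    return first_slots // page_size
```

For a request whose slots are 12..15, 0..3 and 20..23 with page_size = 4, this yields the page table [3, 0, 5].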

8. Triton Backend (Portable Fallback)

triton_backend.py -- A fully custom attention implementation using Triton JIT kernels. This backend runs on any CUDA GPU and needs no attention library beyond Triton itself, which makes it the universal fallback.

8.1 Architecture

class TritonAttnBackend(AttentionBackend):
    def __init__(self, model_runner, skip_prefill=False):
        # Lazy imports -- avoid CUDA context init at import time
        from sglang.srt.layers.attention.triton_ops.decode_attention import (
            decode_attention_fwd,
        )
        from sglang.srt.layers.attention.triton_ops.extend_attention import (
            extend_attention_fwd, extend_attention_fwd_unified,
            build_unified_kv_indices,
        )
        # Disable torch.compile for triton kernels
        self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)

8.2 Split-K Decode with Dynamic KV Splits

The Triton decode kernel uses a split-K strategy where the KV cache is partitioned across multiple splits. The number of splits is computed dynamically to saturate GPU SMs:

def get_num_kv_splits(self, num_kv_splits, seq_lens):
    num_token, num_seq = num_kv_splits.shape[0], seq_lens.shape[0]
    num_group = num_token // num_seq

    # Deterministic mode: fixed splits per sequence length
    if self.split_tile_size is not None and self.enable_deterministic:
        expanded_seq_lens = seq_lens.repeat_interleave(num_group)
        num_kv_splits[:] = (
            expanded_seq_lens + self.split_tile_size - 1
        ) // self.split_tile_size
        return

    # Dynamic mode: use GPU SM count to determine splits
    get_num_kv_splits_triton[(1,)](
        num_kv_splits, seq_lens, num_seq, num_group,
        self.num_head, self.num_kv_head,
        self.max_kv_splits, self.device_core_count,
    )
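Split-K is exact because each split's partial result can be merged through its log-sum-exp. The single-head NumPy sketch below illustrates the idea, not the Triton kernel. Each split produces a normalized partial output and an LSE value, which appear to correspond to the attn_logits and attn_lse buffers in ForwardMetadata (8.3). A final softmax over the LSEs combines them. split_k_decode is a hypothetical name.

```python
import numpy as np


def split_k_decode(q, k, v, num_splits, scale):
    """Decode attention for one query via split-K + log-sum-exp merge.

    q: [D]; k, v: [N, D]. Returns the attention output [D].
    """
    n = k.shape[0]
    split_len = -(-n // num_splits)  # ceil division
    partial_out, partial_lse = [], []
    for s in range(num_splits):
        ks = k[s * split_len:(s + 1) * split_len]
        vs = v[s * split_len:(s + 1) * split_len]
        if ks.shape[0] == 0:         # trailing splits may be empty
            continue
        scores = ks @ q * scale
        m = scores.max()
        p = np.exp(scores - m)
        partial_out.append(p @ vs / p.sum())      # softmax within this split
        partial_lse.append(m + np.log(p.sum()))   # log(sum(exp(scores)))
    lse = np.array(partial_lse)
    # Weight each split by exp(lse_s) / sum_t exp(lse_t)
    w = np.exp(lse - lse.max())
    w /= w.sum()
    return w @ np.array(partial_out)
```

The merged result equals single-pass softmax attention for any split count. The split count therefore only changes how work is spread across SMs (and the floating-point summation order), not the mathematical result.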

8.3 Metadata Structure

@dataclass
class ForwardMetadata:
    attn_logits: torch.Tensor    # (bs, num_heads, max_kv_splits, v_head_dim)
    attn_lse: torch.Tensor       # (bs, num_heads, max_kv_splits)
    max_extend_len: int
    num_kv_splits: torch.Tensor  # (bs,)
    kv_indptr: torch.Tensor      # CSR-format pointers into kv_indices
    kv_indices: torch.Tensor     # Flat array of token indices
    qo_indptr: torch.Tensor
    custom_mask: torch.Tensor    # For speculative decoding verification
    mask_indptr: torch.Tensor
    # Sliding window variants
    window_kv_indptr: torch.Tensor
    window_kv_indices: torch.Tensor
    window_num_kv_splits: torch.Tensor
    window_kv_offsets: torch.Tensor

8.4 Deterministic Inference

When --enable-deterministic-inference is set, the Triton backend uses fixed split tile sizes instead of dynamic SM-based scheduling, ensuring bitwise-identical outputs across runs:

if self.enable_deterministic:
    self.split_tile_size = get_int_env_var(
        "SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE", 256
    )
    self.static_kv_splits = False  # Use deterministic logic instead
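The contrast can be sketched in a few lines. deterministic_splits applies the same ceil(seq_len / split_tile_size) formula as the deterministic branch of get_num_kv_splits, with the default tile of 256 shown above. hypothetical_dynamic_splits is an invented heuristic, not SGLang's get_num_kv_splits_triton logic. It only illustrates how an SM-driven policy can give the same request a different split count when the batch changes. The SM count of 132 is an assumption.

```python
def deterministic_splits(seq_lens, split_tile_size=256):
    # Deterministic mode: splits = ceil(seq_len / split_tile_size), per sequence
    return [(n + split_tile_size - 1) // split_tile_size for n in seq_lens]


def hypothetical_dynamic_splits(seq_lens, num_sms=132, max_kv_splits=16):
    # Hypothetical heuristic (illustration only): share SMs across the batch,
    # so fewer requests get more splits each, capped by max_kv_splits.
    per_req = max(1, num_sms // len(seq_lens))
    return [min(max_kv_splits, per_req, n) for n in seq_lens]


# The same 1000-token request, alone vs. in a batch of 32:
print(deterministic_splits([1000]), deterministic_splits([1000] * 32)[0])
print(hypothetical_dynamic_splits([1000]), hypothetical_dynamic_splits([1000] * 32)[0])
```

Under the fixed tile, the request is always reduced over 4 splits. Under the batch-dependent policy, its split count, and therefore its summation order, changes with the company it keeps.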

9. Hybrid Attention Backend

hybrid_attn_backend.py -- Allows mixing different backends for prefill and decode phases. This is used when --prefill-attention-backend differs from --decode-attention-backend.

9.1 Backend Selection by Forward Mode

class HybridAttnBackend(AttentionBackend):
    """Support different backends for prefill and decode."""

    def __init__(self, model_runner, prefill_backend, decode_backend):
        self.model_runner = model_runner
        self.prefill_backend = prefill_backend
        self.decode_backend = decode_backend

    def _select_backend(self, forward_mode) -> AttentionBackend:
        """
        Select backend based on forward mode:
          - decode_or_idle       -> decode backend (always)
          - target_verify/draft  -> decode if speculative_attention_mode=="decode"
                                    else prefill
          - prefill              -> prefill backend (always)
        """
        if forward_mode.is_decode_or_idle():
            return self.decode_backend
        elif forward_mode.is_target_verify() or forward_mode.is_draft_extend():
            return (
                self.decode_backend
                if self.model_runner.server_args.speculative_attention_mode == "decode"
                else self.prefill_backend
            )
        else:
            return self.prefill_backend
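The dispatch rules can be exercised without a GPU. In this hypothetical stand-in, a small ForwardMode enum (a subset of the modes named in Figure 1) and string placeholders replace SGLang's real classes. select_backend follows the same branch order as _select_backend above.

```python
from enum import Enum, auto

class ForwardMode(Enum):
    # Hypothetical subset of SGLang's ForwardMode values
    EXTEND = auto()
    DECODE = auto()
    TARGET_VERIFY = auto()
    DRAFT_EXTEND = auto()
    IDLE = auto()


def select_backend(mode, speculative_attention_mode,
                   prefill="prefill_backend", decode="decode_backend"):
    """Mirror of HybridAttnBackend._select_backend's branch order."""
    if mode in (ForwardMode.DECODE, ForwardMode.IDLE):            # decode_or_idle
        return decode
    if mode in (ForwardMode.TARGET_VERIFY, ForwardMode.DRAFT_EXTEND):
        return decode if speculative_attention_mode == "decode" else prefill
    return prefill                                                # plain extend/prefill


for spec_mode in ("prefill", "decode"):
    print(spec_mode, {m.name: select_backend(m, spec_mode) for m in ForwardMode})
```

Only the speculative modes (TARGET_VERIFY, DRAFT_EXTEND) change their answer when speculative_attention_mode flips. Plain decode and plain extend are fixed.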

9.2 Transparent Delegation

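Each method delegates to one of the two underlying backends, and the hybrid backend itself stores no attention-specific state. Most methods route through _select_backend, while forward_decode and init_cuda_graph_state target specific backends directly: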

def init_forward_metadata(self, forward_batch):
    backend = self._select_backend(forward_batch.forward_mode)
    backend.init_forward_metadata(forward_batch)

def init_cuda_graph_state(self, max_bs, max_num_tokens):
    self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
    # Also init prefill backend if speculative decoding uses prefill mode
    if (self.model_runner.server_args.speculative_algorithm is not None
        and self.model_runner.server_args.speculative_attention_mode == "prefill"):
        self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)

def forward(self, q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs):
    backend = self._select_backend(forward_batch.forward_mode)
    return backend.forward(q, k, v, layer, forward_batch, save_kv_cache, **kwargs)

def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs):
    # Decode always goes to decode_backend directly
    return self.decode_backend.forward_decode(
        q, k, v, layer, forward_batch, save_kv_cache, **kwargs
    )

def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs):
    # Extend may go to either backend depending on speculative mode
    backend = self._select_backend(forward_batch.forward_mode)
    return backend.forward_extend(q, k, v, layer, forward_batch, save_kv_cache, **kwargs)
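The asymmetry in init_cuda_graph_state is easy to miss. The decode backend always gets CUDA graph buffers. The prefill backend gets them only when speculative decoding is enabled and its verify/draft steps are routed to prefill. This hypothetical sketch replaces real backends with a RecordingBackend stub and uses a minimal ServerArgs dataclass. The field names match the excerpt above, but the defaults are assumptions.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ServerArgs:
    # Hypothetical subset of server arguments (defaults are assumptions)
    speculative_algorithm: Optional[str] = None
    speculative_attention_mode: str = "prefill"


@dataclass
class RecordingBackend:
    # Hypothetical stub that records which CUDA-graph setup calls it received
    name: str
    calls: list = field(default_factory=list)

    def init_cuda_graph_state(self, max_bs, max_num_tokens):
        self.calls.append(("init_cuda_graph_state", max_bs, max_num_tokens))


def init_cuda_graph_state(args, prefill_backend, decode_backend, max_bs, max_num_tokens):
    """Same condition as HybridAttnBackend.init_cuda_graph_state above."""
    decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
    if args.speculative_algorithm is not None and args.speculative_attention_mode == "prefill":
        prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)


prefill, decode = RecordingBackend("fa3"), RecordingBackend("trtllm_mha")
init_cuda_graph_state(ServerArgs(speculative_algorithm="EAGLE"), prefill, decode, 8, 32)
print(prefill.calls, decode.calls)   # both backends were initialized
```

Skipping the prefill backend in the common case avoids allocating graph buffers that would never be replayed.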

9.3 Example Use Cases

FlashMLA decode + FlashInfer prefill

Use --decode-attention-backend flashmla --prefill-attention-backend flashinfer for DeepSeek models, where the FlashMLA decode kernel is faster but does not handle prefill.

TRT-LLM MHA decode + FA3 prefill

Combine Blackwell-optimized decode with FA3's async TMA prefill. Use --decode-attention-backend trtllm_mha --prefill-attention-backend fa3.

10. Comprehensive Backend Comparison

Backend | GPU Arch | Model Type | Prefill | Decode | SWA | CUDA Graph | FP8 KV
--- | --- | --- | --- | --- | --- | --- | ---
FlashInfer (default) | SM80+ (Ampere/Ada) | MHA + MLA | Ragged + Paged | Paged, tensor cores | Yes (2 wrappers) | Full | Yes
FA3 (default) | SM90 (Hopper) | MHA + MLA | varlen + kvcache | flash_attn_with_kvcache | Yes + Local Attn | Full | Yes
TRT-LLM MHA (default) | SM100 (Blackwell) | MHA only | FlashInfer (inherited) | trtllm XQA kernel | Native + SWA pool | Full | Yes (XQA mixed)
FlashMLA | SM90+ (Hopper) | MLA only | FlashInfer MLA | flash_mla_with_kvcache | N/A | Full | Yes
CUTLASS MLA | SM100 (Blackwell) | MLA only | FlashInfer MLA | cutlass_mla_decode | N/A | Full | No
TRT-LLM MLA | SM100 (Blackwell) | MLA only | FlashInfer MLA | trtllm_gen_mla kernel | N/A | Full | Yes
Triton | Any CUDA | MHA + MLA | Custom Triton | Split-K Triton | Yes | Full | No
Hybrid | Any | Any | Delegated | Delegated | Backend-dependent | Both backends | Backend-dependent

10.1 Feature Support Matrix

Feature | FlashInfer | FA3 | TRT-LLM MHA | Triton | FlashMLA | CUTLASS MLA
--- | --- | --- | --- | --- | --- | ---
Cross-Attention | Yes | Yes | Inherited | No | No | No
Speculative Decode | Full | Full | topk=1 | Full | topk=1 | Falls back
Deterministic Mode | Yes (fixed splits) | N/A | N/A | Yes (fixed tiles) | N/A | N/A
Logit Capping | Via kernel param | Via kernel param | N/A | tanh cap | N/A | N/A
Multi-Item Scoring | Yes | N/A | N/A | N/A | N/A | N/A
Attention Sinks | N/A | N/A | Yes (native) | N/A | N/A | N/A
Page Size | Configurable | Configurable | Configurable | Configurable | 64 (fixed) | 128 (fixed)
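The "tanh cap" entry refers to logit soft-capping, used by models such as Gemma 2. It squashes each attention score into (-cap, cap) before the softmax. Below is a minimal sketch of the formula, independent of any backend. The function name is hypothetical, and treating cap <= 0 as "disabled" is an assumption of this sketch.

```python
import math


def soft_cap(score: float, cap: float) -> float:
    """Logit soft-capping: cap * tanh(score / cap).

    Assumption of this sketch: cap <= 0 means capping is disabled.
    """
    if cap <= 0:
        return score
    return cap * math.tanh(score / cap)


for s in (1.0, 10.0, 100.0):
    print(s, soft_cap(s, cap=50.0))   # small scores pass nearly unchanged; large ones saturate
```

Because tanh is smooth and odd, capping keeps score ordering and sign while bounding the largest logits. That is why a backend can apply it inside the kernel just before exponentiation.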

11. Architecture Diagram

[Figure 1 (diagram): SGLang Attention Backend Architecture. RadixAttention layer -> forward_batch.attn_backend.forward(); backend chosen by _get_default_attn_backend() from GPU arch + MHA/MLA + user config. Backends shown: FlashInfer (Ampere/Ada default; paged wrappers, ragged prefill, SWA + cross-attn), FA3/FA4 (Hopper default; async TMA, local attention, MHA + MLA), TRT-LLM MHA (Blackwell default; XQA kernels, native SWA + sinks, SWA KV pool), Triton (universal fallback; split-K decode, custom extend, deterministic), Hybrid (mixes prefill/decode; mode-based selection, spec-decode aware). MLA decode backends (DeepSeek), all inheriting from FlashInferMLAAttnBackend for prefill: FlashMLA (flash_mla_with_kvcache(), PAGE_SIZE = 64, FP8 KV via quantized Q, spec decode topk=1 only), CUTLASS MLA (cutlass_mla_decode(), PAGE_SIZE = 128, Q split into q_nope + q_pe, per-batch workspace), TRT-LLM MLA (trtllm_gen_mla_decode(), configurable page size, LCM block padding, full spec decode support). Common infrastructure: req_to_token_pool (token-to-page mapping), token_to_kv_pool (paged KV cache storage), ForwardBatch (batch metadata container), CUDA graph capture/replay state. ForwardMode enum: EXTEND, DECODE, TARGET_VERIFY, DRAFT_EXTEND, IDLE.]

Figure 1: The complete attention backend dispatch architecture. RadixAttention delegates to the selected backend, which is chosen based on GPU architecture, model type, and configuration. MLA backends all share FlashInfer for prefill.

[Figure 2 (diagram): Per-batch lifecycle, init_forward_metadata -> forward. Normal path: Scheduler creates ForwardBatch with an attn_backend ref -> init_forward_metadata() (build KV indices and page tables; plan FlashInfer wrappers) -> model forward pass (each layer calls RadixAttention.forward()) -> kernel execution via forward_decode() or forward_extend(). CUDA graph path (decode): init_cuda_graph_state() (allocate index buffers, create CG wrappers) -> capture_cuda_graph() (build metadata for bs=N, record graph with real ops) -> replay_cuda_graph() (update indices in place, replay the pre-recorded graph with no CPU-GPU sync and no re-planning). Page table panel: req_to_token (flat token indices, e.g. 0 1 2 ... 64 65 ...) -> block_kv_indices (per-request page indices, padded with -1 to max_blocks).]

Figure 2: Attention backend lifecycle per batch. The top row shows the normal path; the middle row shows the CUDA graph capture/replay path that eliminates CPU overhead on decode. The bottom illustrates how token-level indices are grouped into page-level block indices for paged attention kernels.
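The page-table step in Figure 2's bottom panel can be sketched in a few lines. The sketch assumes each page holds page_size consecutive, page-aligned KV slots, so slot s lives in page s // page_size. Under that assumption, a request's page list comes from reading its token indices every page_size positions and then padding with -1 up to max_blocks. The helper build_block_kv_indices is hypothetical.

```python
def build_block_kv_indices(req_to_token_row, seq_len, page_size, max_blocks):
    """Group one request's token-level KV indices into page indices, padded with -1.

    Assumes slots within a page are contiguous and page-aligned (paged allocator).
    """
    num_pages = (seq_len + page_size - 1) // page_size   # ceil(seq_len / page_size)
    if num_pages > max_blocks:
        raise ValueError("request needs more pages than max_blocks")
    # The first token of each page identifies it: slot // page_size
    pages = [req_to_token_row[i * page_size] // page_size for i in range(num_pages)]
    return pages + [-1] * (max_blocks - num_pages)


# page_size=4: tokens 0-3 live in page 2, tokens 4-7 in page 5, tokens 8-9 in page 1
row = [8, 9, 10, 11, 20, 21, 22, 23, 4, 5]
print(build_block_kv_indices(row, seq_len=10, page_size=4, max_blocks=4))  # [2, 5, 1, -1]
```

Fixed page sizes such as FlashMLA's 64 or CUTLASS MLA's 128 only change page_size here. The -1 padding gives every request in the batch a block table of the same width.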

Source File Reference

File | Class / Symbol | Role
--- | --- | ---
radix_attention.py | RadixAttention | nn.Module attention layer; dispatches to the backend
attention_registry.py | ATTENTION_BACKENDS | Registry mapping string names to factory functions
flashinfer_backend.py | FlashInferAttnBackend | FlashInfer wrappers for paged MHA attention
flashattention_backend.py | FlashAttentionBackend | FA3/FA4 kernels with local attention support
flashmla_backend.py | FlashMLABackend | FlashMLA decode for DeepSeek MLA (page_size=64)
cutlass_mla_backend.py | CutlassMLABackend | CUTLASS MLA decode (page_size=128, separate q_nope/q_pe)
trtllm_mla_backend.py | TRTLLMMLABackend | TRT-LLM MLA decode with LCM block padding
trtllm_mha_backend.py | TRTLLMHAAttnBackend | TRT-LLM XQA decode for Blackwell (SWA + sinks)
triton_backend.py | TritonAttnBackend | Portable Triton JIT kernels with split-K decode
hybrid_attn_backend.py | HybridAttnBackend | Mixes different backends for prefill vs decode
server_args.py | _get_default_attn_backend() | Auto-selection logic based on GPU arch + model type