Attention Backends

How vLLM selects, configures, and dispatches to FlashAttention, FlashInfer, and MLA kernels -- traced through actual source code.

vllm/v1/attention/backends/ -- Source-Level Analysis

Table of Contents

  1. Backend Selection & Priority System
  2. FlashAttention Backend (FA2 / FA3)
  3. FlashInfer Backend
  4. Multi-Latent Attention (MLA) Overview
  5. MLA Backend Implementations
  6. Attention Metadata Structures
  7. Block Table & KV Cache Layout
  8. Performance: Hardware-to-Backend Mapping

1. Backend Selection & Priority System

Every attention layer in vLLM calls get_attn_backend() to determine which kernel to use. The entry point is in vllm/v1/attention/selector.py:

# vllm/v1/attention/selector.py

def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    ...
) -> type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it."""
    vllm_config = get_current_vllm_config()
    cache_config = vllm_config.cache_config
    # If the user set --block-size explicitly, pass it to validation
    block_size = (cache_config.block_size
                  if cache_config.user_specified_block_size else None)

    attn_selector_config = AttentionSelectorConfig(
        head_size=head_size, dtype=dtype, kv_cache_dtype=kv_cache_dtype,
        block_size=block_size, use_mla=use_mla, has_sink=has_sink, ...
    )
    return _cached_get_attn_backend(
        backend=vllm_config.attention_config.backend,  # user override or None
        attn_selector_config=attn_selector_config,
    )

Priority-Based Auto-Selection

When no backend is explicitly requested via --attention-backend, the CUDA platform iterates through a priority-ordered list and picks the first one that passes validation. This logic lives in vllm/platforms/cuda.py:

# vllm/platforms/cuda.py -- _get_backend_priorities()

def _get_backend_priorities(
    use_mla: bool,
    device_capability: DeviceCapability,
    num_heads: int | None = None,
    kv_cache_dtype: CacheDType | None = None,
) -> list[AttentionBackendEnum]:
    if use_mla:
        if device_capability.major == 10:  # Blackwell (SM100)
            return [
                FLASHINFER_MLA,   # Highest priority on Blackwell
                CUTLASS_MLA,
                FLASH_ATTN_MLA,
                FLASHMLA,
                TRITON_MLA,
                *sparse_backends,  # FLASHMLA_SPARSE / FLASHINFER_MLA_SPARSE
            ]
        else:  # Hopper (SM90) and older
            return [FLASH_ATTN_MLA, FLASHMLA, FLASHINFER_MLA, TRITON_MLA, ...]
    else:  # Standard MHA
        if device_capability.major == 10:  # Blackwell
            return [FLASHINFER, FLASH_ATTN, TRITON_ATTN, FLEX_ATTENTION]
        else:  # Ampere / Hopper
            return [FLASH_ATTN, FLASHINFER, TRITON_ATTN, FLEX_ATTENTION]
Key insight: The priority list depends on two dimensions: (1) whether the model uses MLA (e.g., DeepSeek-V2/V3), and (2) the GPU generation. On Blackwell, FlashInfer takes top priority for both MHA and MLA; on Hopper, FlashAttention leads for MHA while FlashMLA leads for MLA.

Validation Gate

Each backend class exposes a validate_configuration() method (inherited from AttentionBackend). The platform checks every backend in priority order:

# vllm/platforms/cuda.py -- get_valid_backends()

for priority, backend in enumerate(backend_priorities):
    backend_class = backend.get_class()
    invalid_reasons = backend_class.validate_configuration(
        device_capability=device_capability,
        **attn_selector_config._asdict(),
    )
    if invalid_reasons:
        invalid_reasons_map[backend] = (priority, invalid_reasons)
    else:
        valid_backends_priorities.append((backend, priority))

# Select the backend with the lowest priority index (index 0 = highest priority)
sorted_indices = sorted(range(len(valid_backends_priorities)),
                        key=lambda i: valid_backends_priorities[i][1])
selected_backend = valid_backends_priorities[sorted_indices[0]][0]
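The loop above reduces to "first backend whose validation returns no reasons wins." A self-contained sketch of that selection logic, with hypothetical names standing in for the real backend classes:

```python
# Minimal sketch of priority-ordered validation: each backend reports a list
# of invalid reasons; an empty list means the backend is usable.

def select_backend(priorities, validators):
    """priorities: backend names, highest priority first.
    validators: dict name -> callable returning list[str] of invalid reasons."""
    valid = []
    for priority, name in enumerate(priorities):
        reasons = validators[name]()
        if not reasons:
            valid.append((name, priority))
    if not valid:
        raise RuntimeError("no valid attention backend")
    # Lowest priority index = highest priority
    return min(valid, key=lambda t: t[1])[0]

priorities = ["FLASH_ATTN", "FLASHINFER", "TRITON_ATTN"]
validators = {
    "FLASH_ATTN": lambda: ["head_size 80 not supported"],  # fails validation
    "FLASHINFER": lambda: [],                              # passes
    "TRITON_ATTN": lambda: [],                             # passes, lower priority
}
print(select_backend(priorities, validators))  # FLASHINFER
```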
[Diagram: backend selection flow. get_attn_backend() first checks for a user override and validates it directly; otherwise _get_backend_priorities() branches on MLA vs MHA and GPU generation, then validate_configuration() runs per backend in priority order and the first valid backend (lowest priority index) wins. Blackwell MHA: FlashInfer > FlashAttention > Triton; Hopper MHA: FlashAttention > FlashInfer > Triton; Blackwell MLA: FlashInfer > CUTLASS > FlashAttention; Hopper MLA: FlashAttention > FlashMLA > FlashInfer; fallback: TRITON_MLA.]
Backend selection: user override takes precedence; otherwise auto-select by GPU generation + model type.

The Backend Registry

All backends are registered in an enum at vllm/v1/attention/backends/registry.py. Each entry maps a name to a fully qualified class path:

# vllm/v1/attention/backends/registry.py

class AttentionBackendEnum(Enum):
    FLASH_ATTN     = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
    FLASHINFER     = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
    TRITON_MLA     = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
    CUTLASS_MLA    = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
    FLASHMLA       = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
    FLASHINFER_MLA = "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
    FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
    TRITON_ATTN    = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
    # ... plus ROCm, CPU, sparse variants, etc.
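Each enum value is a dotted path that get_class() resolves lazily, so heavyweight kernel modules are only imported when actually selected. A minimal stand-alone sketch of the same pattern, using stdlib classes in place of backend classes (ToyBackendEnum is hypothetical, not vLLM's code):

```python
# Sketch: an enum of dotted class paths resolved lazily via importlib.
import importlib
from enum import Enum

class ToyBackendEnum(Enum):
    ORDERED_DICT = "collections.OrderedDict"  # stdlib stand-ins for backend classes
    DEQUE = "collections.deque"

    def get_class(self):
        # Split "pkg.module.ClassName" into module path and class name,
        # import the module only now, and fetch the class from it.
        module_name, _, class_name = self.value.rpartition(".")
        return getattr(importlib.import_module(module_name), class_name)

from collections import OrderedDict
print(ToyBackendEnum.ORDERED_DICT.get_class() is OrderedDict)  # True
```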

2. FlashAttention Backend (FA2 / FA3)

The FlashAttention backend (vllm/v1/attention/backends/flash_attn.py) is the workhorse for standard multi-head attention on Ampere and Hopper GPUs. It auto-detects the FA version and adjusts its behavior accordingly.

Backend Class: Capabilities

# vllm/v1/attention/backends/flash_attn.py

class FlashAttentionBackend(AttentionBackend):
    accept_output_buffer: bool = True
    forward_includes_kv_cache_update: bool = False  # KV update separate from attn

    supported_dtypes = [torch.float16, torch.bfloat16]
    supported_kv_cache_dtypes = ["auto", "float16", "bfloat16"]

    @staticmethod
    def get_supported_kernel_block_sizes():
        # For hybrid models with FP32 Mamba state, restrict block sizes
        if model_config.is_hybrid:
            return [16, 32, 64]  # NaN propagation issue
        return [MultipleOf(16)]       # Any multiple of 16

    @classmethod
    def supports_compute_capability(cls, capability):
        return capability >= DeviceCapability(8, 0)  # Ampere+

    @classmethod
    def supports_per_head_quant_scales(cls):
        return get_flash_attn_version() >= 3  # FA3 only

KV Cache Shape

FlashAttention stores the KV cache with the key and value in a unified tensor. The first dimension selects K vs V:

# KV cache shape for FlashAttention:
# (2, num_blocks, block_size, num_kv_heads, head_size)
#  ^--- 0=key, 1=value

@staticmethod
def get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size, ...):
    if block_size % 16 != 0:
        raise ValueError("Block size must be a multiple of 16.")
    return (2, num_blocks, block_size, num_kv_heads, head_size)
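For intuition about what this shape costs, a back-of-envelope sketch with assumed Llama-3-8B-like numbers (8 KV heads, head_size 128, fp16); kv_block_bytes is a hypothetical helper, not vLLM code:

```python
# Bytes per KV cache block for the unified (2, ...) FlashAttention layout.
def kv_block_bytes(block_size, num_kv_heads, head_size, dtype_bytes=2):
    # Each block holds both K and V: 2 * block_size * heads * head_size elements
    return 2 * block_size * num_kv_heads * head_size * dtype_bytes

per_block = kv_block_bytes(block_size=16, num_kv_heads=8, head_size=128)
print(per_block)  # 65536 bytes = 64 KiB per block, per layer
```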

Metadata Builder: AOT Scheduling

The FlashAttentionMetadataBuilder pre-computes scheduling decisions ahead of time (AOT) when running with FlashAttention 3. This is critical for CUDA graph compatibility:

# FlashAttentionMetadataBuilder.__init__()

self.aot_schedule = get_flash_attn_version() == 3  # Only FA3

if self.use_full_cuda_graph and self.aot_schedule:
    # FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
    # The +1 is for tile_count_semaphore (synchronization)
    # 4 slots per batch: prepare_varlen + dynamic_split + sort + head_swizzle
    max_batch_size = max(max_num_seqs, max_cudagraph_capture_size or 0)
    self.scheduler_metadata = torch.zeros(
        1 + round_up(max_batch_size, 4) * 4,
        dtype=torch.int32, device=self.device,
    )
    # Upper bound for intermediate buffer splits
    self.max_num_splits = attention_config.flash_attn_max_num_splits_for_cuda_graph
FA2 vs FA3 CUDA Graph support: FA3 supports AttentionCGSupport.ALWAYS (full cudagraph for all cases). FA2 only supports UNIFORM_BATCH because its special max_query_len=1 packed-GQA handling breaks when mixed prefill-decode batches are replayed from a graph captured with decode-only shapes.
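The buffer-size formula quoted in the comment above can be checked directly; scheduler_metadata_len and round_up below are hypothetical helpers mirroring the expression, not vLLM code:

```python
# Sketch of the FA3 scheduler_metadata sizing: 1 + round_up(batch, 4) * 4
# int32 slots (the +1 is the tile_count_semaphore).
def round_up(x, n):
    return (x + n - 1) // n * n

def scheduler_metadata_len(max_batch_size):
    return 1 + round_up(max_batch_size, 4) * 4

print(scheduler_metadata_len(1))    # 17  (batch rounds up to 4)
print(scheduler_metadata_len(256))  # 1025
```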

The build() Method: Metadata Preparation

# FlashAttentionMetadataBuilder.build()

def build(self, common_prefix_len, common_attn_metadata, fast_build=False):
    # fast_build disables AOT scheduling -- used for spec-decode
    aot_schedule = self.aot_schedule and not fast_build

    # Inner function that invokes FA3's get_scheduler_metadata
    def schedule(batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal):
        if aot_schedule:
            return get_scheduler_metadata(
                batch_size=batch_size, max_seqlen_q=max_query_len,
                max_seqlen_k=max_seq_len, num_heads_q=self.num_heads_q,
                num_heads_kv=self.num_heads_kv, headdim=self.headdim,
                cache_seqlens=seqlens, qkv_dtype=qkv_dtype,
                cu_seqlens_q=cu_query_lens, page_size=self.block_size,
                causal=causal, window_size=self.aot_sliding_window,
                num_splits=max_num_splits,
            )
        return None

    # Three paths: DCP, cascade, or standard
    if self.dcp_world_size > 1:
        scheduler_metadata = schedule(..., causal=False)
    elif use_cascade:
        prefix_scheduler = schedule(batch_size=1, ..., causal=False)
        suffix_scheduler = schedule(batch_size=num_reqs, ..., causal=True)
    else:
        scheduler_metadata = schedule(..., causal=causal)

    # For FA3 + CUDA graph: copy into persistent buffer, zero rest
    if self.use_full_cuda_graph and scheduler_metadata is not None:
        n = scheduler_metadata.shape[0]
        self.scheduler_metadata[:n] = scheduler_metadata
        self.scheduler_metadata[n:] = 0  # Zero rest for correctness

The forward() Method: Kernel Dispatch

# FlashAttentionImpl.forward()

def forward(self, layer, query, key, value, kv_cache, attn_metadata, output, ...):
    num_actual_tokens = attn_metadata.num_actual_tokens
    key_cache, value_cache = kv_cache.unbind(0)   # Split K/V

    if self.kv_cache_dtype.startswith("fp8"):
        key_cache = key_cache.view(dtype)          # Reinterpret as FP8
        value_cache = value_cache.view(dtype)

    if not attn_metadata.use_cascade:
        # Standard path: single flash_attn_varlen_func call
        flash_attn_varlen_func(
            q=query[:num_actual_tokens],
            k=key_cache, v=value_cache,
            out=output[:num_actual_tokens],
            cu_seqlens_q=attn_metadata.query_start_loc,
            max_seqlen_q=attn_metadata.max_query_len,
            seqused_k=attn_metadata.seq_lens,
            max_seqlen_k=attn_metadata.max_seq_len,
            softmax_scale=self.scale, causal=attn_metadata.causal,
            block_table=attn_metadata.block_table,      # Paged attention!
            scheduler_metadata=attn_metadata.scheduler_metadata,  # AOT
            fa_version=self.vllm_flash_attn_version,
            q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
            num_splits=attn_metadata.max_num_splits,
            s_aux=self.sinks,                            # Attention sinks
        )
    else:
        # Cascade attention: prefix + suffix kernels, merged via LSE
        cascade_attention(output, query, key_cache, value_cache, ...)

Cascade Attention

When multiple requests share a common KV prefix (e.g., system prompt), vLLM computes attention over the shared prefix once and merges with per-request suffix attention using log-sum-exp (LSE) values. The FlashAttention backend dispatches two separate flash_attn_varlen_func calls: one for the prefix (non-causal, batch_size=1) and one for the per-request suffixes (causal).
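The merge step can be sketched in scalar form: it is the standard numerically stable log-sum-exp combination of two partial softmax-weighted sums (a toy per-element sketch; the real merge kernel operates per head and per query token over whole tensors):

```python
# Merge two partial attention results (output, LSE) into one, as if the
# softmax had been computed over the union of both key sets.
import math

def merge_lse(o1, lse1, o2, lse2):
    m = max(lse1, lse2)                # stabilize before exponentiating
    w1 = math.exp(lse1 - m)
    w2 = math.exp(lse2 - m)
    o = (o1 * w1 + o2 * w2) / (w1 + w2)
    lse = m + math.log(w1 + w2)
    return o, lse

# Equal-weight partials (same LSE) simply average the outputs:
o, lse = merge_lse(1.0, 0.0, 3.0, 0.0)
print(o)    # 2.0
print(lse)  # log(2) ~= 0.693
```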

KV Cache Update

# FlashAttentionImpl.do_kv_cache_update() -- called BEFORE forward()

def do_kv_cache_update(self, layer, key, value, kv_cache, slot_mapping):
    key_cache, value_cache = kv_cache.unbind(0)
    # reshape_and_cache_flash writes K/V tokens into paged blocks
    # using slot_mapping to determine the block + offset for each token
    reshape_and_cache_flash(
        key[:num_actual_tokens], value[:num_actual_tokens],
        key_cache, value_cache, slot_mapping[:num_actual_tokens],
        self.kv_cache_dtype, layer._k_scale, layer._v_scale,
    )

3. FlashInfer Backend

FlashInfer (vllm/v1/attention/backends/flashinfer.py) takes top priority on Blackwell GPUs and offers two internal dispatch paths: the native FlashInfer wrappers and TRT-LLM attention kernels.

Backend Class: Blackwell-First Design

# vllm/v1/attention/backends/flashinfer.py

class FlashInferBackend(AttentionBackend):
    accept_output_buffer: bool = True
    supported_kv_cache_dtypes = ["auto", "float16", "bfloat16",
                                 "fp8", "fp8_e4m3", "fp8_e5m2"]

    @staticmethod
    def get_supported_kernel_block_sizes():
        return [16, 32, 64]  # Fixed sizes, not MultipleOf

    @classmethod
    def supports_compute_capability(cls, capability):
        return capability >= DeviceCapability(7, 5) \
           and capability <= DeviceCapability(12, 1)

    @classmethod
    def get_required_kv_cache_layout(cls):
        capability = current_platform.get_device_capability()
        if capability.major == 10:  # Blackwell requires HND layout
            return "HND"
        return None  # Use default (NHD)

The Plan/Execute Pattern

FlashInfer uses a plan-then-execute pattern. During build(), wrappers call .plan() to precompute indices and scheduling. During forward(), they call .run() to execute the actual attention.

Workspace Management

# Workspace buffer: allocated once, reused across calls
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024  # 2 GB

def _get_workspace_buffer(self):
    if self._workspace_buffer is None:
        buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE
        if envs.VLLM_BATCH_INVARIANT:
            buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
        self._workspace_buffer = torch.zeros(
            buffer_size, dtype=torch.uint8, device=self.device
        )
    return self._workspace_buffer

Wrapper Initialization

# Separate wrappers for prefill vs decode

def _get_prefill_wrapper(self):
    if self.use_dcp:
        return BatchDCPPrefillWrapper(workspace_buffer=..., dcp_a2a=...)
    else:
        return BatchPrefillWithPagedKVCacheWrapper(
            self._get_workspace_buffer(), get_kv_cache_layout()
        )

def _get_decode_wrapper(self, batch_size, use_cudagraph=False):
    # For CUDA graphs: one wrapper PER batch size, with pinned buffers
    decode_wrapper = (
        self._decode_wrappers_cudagraph.get(batch_size) if use_cudagraph else None
    )
    if decode_wrapper is None:
        decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
            self._get_workspace_buffer(), get_kv_cache_layout(),
            use_cuda_graph=use_cudagraph,
            paged_kv_indptr_buffer=self.paged_kv_indptr.gpu[:batch_size+1],
            paged_kv_indices_buffer=self.paged_kv_indices.gpu,
            paged_kv_last_page_len_buffer=self.paged_kv_last_page_len.gpu[:batch_size],
            use_tensor_cores=True,  # Always on for modern GPUs
        )
        if use_cudagraph:
            self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
    return decode_wrapper

The build() Method: Dual-Path Dispatch

FlashInfer's build() decides per-phase whether to use native FlashInfer or TRT-LLM kernels:

# FlashInferMetadataBuilder.build()

# Step 1: Decide dispatch mode for each phase
prefill_use_trtllm = use_trtllm_attention(
    self.num_qo_heads, self.num_kv_heads, num_prefill_tokens,
    max_seq_len, self.dcp_world_size, self.cache_dtype, ...
)
decode_use_trtllm = self.use_trtllm_decode_attention and self.dcp_world_size <= 1

# Step 2: Build metadata skeleton
attn_metadata = FlashInferMetadata(
    num_actual_tokens=num_actual_tokens,
    slot_mapping=slot_mapping,
    q_data_type=self.q_data_type,
    num_decodes=num_decodes, num_prefills=num_prefills,
    prefill=None, decode=None,  # populated below
    use_cascade=use_cascade, cascade_wrapper=None,
)

# Step 3a: PREFILL -- either FI native or TRTLLM
if prefill_use_trtllm:
    attn_metadata.prefill = TRTLLMPrefill(
        block_tables=block_table_tensor[prefill_start:],
        seq_lens=seq_lens[prefill_start:],
        cum_seq_lens_q=qo_indptr_prefill_gpu,
        max_q_len=max_q_len_prefill, max_seq_len=max_seq_len,
    )
else:
    prefill_wrapper = self._get_prefill_wrapper()
    prefill_wrapper.plan(...)   # Plan phase -- precompute indices
    attn_metadata.prefill = FIPrefill(wrapper=prefill_wrapper)

# Step 3b: DECODE -- either FI native or TRTLLM
if decode_use_trtllm:
    attn_metadata.decode = TRTLLMDecode(
        block_tables=block_table_tensor[:num_decodes],
        seq_lens=seq_lens[:num_decodes], max_seq_len=max_seq_len,
    )
else:
    decode_wrapper = self._get_decode_wrapper(num_decode_tokens, use_cudagraph)
    fast_plan_decode(decode_wrapper, ...)  # Plan phase
    attn_metadata.decode = FIDecode(wrapper=decode_wrapper)

Metadata Types: FI Native vs TRT-LLM

FI Native Metadata

@dataclass
class FIPrefill:
    wrapper: BatchPrefillWith...Wrapper
    # Wrapper has plan() already called
    # Contains paged_kv_indptr,
    # paged_kv_indices, last_page_len

@dataclass
class FIDecode:
    wrapper: BatchDecodeWith...Wrapper
    # Same pattern: plan() pre-called

TRT-LLM Metadata

@dataclass
class TRTLLMPrefill:
    block_tables: torch.Tensor  # [num_prefills, max_blocks]
    seq_lens: torch.Tensor      # [num_prefills]
    cum_seq_lens_q: torch.Tensor
    cum_seq_lens_kv: torch.Tensor
    max_q_len: int
    max_seq_len: int

@dataclass
class TRTLLMDecode:
    block_tables: torch.Tensor  # [num_decodes, max_blocks]
    seq_lens: torch.Tensor      # [num_decodes]
    max_seq_len: int

CUDA Graph Support

# FlashInfer CUDA graph strategy depends on TRT-LLM availability

@classmethod
def get_cudagraph_support(cls, vllm_config, kv_cache_spec):
    # Check if TRT-LLM attention works for ALL spec layers
    has_trtllm_support = True
    for spec in kv_specs:
        if not can_use_trtllm_attention(num_qo_heads, spec.num_kv_heads):
            has_trtllm_support = False
            break
    if has_trtllm_support:
        return AttentionCGSupport.UNIFORM_BATCH       # Full CG with mixed batches
    else:
        return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE  # Decode-only CG

4. Multi-Latent Attention (MLA) Overview

MLA, introduced in DeepSeek-V2, compresses the KV cache into a single latent vector per token. This dramatically reduces memory bandwidth during decode. The design is documented in extensive comments in vllm/model_executor/layers/attention/mla_attention.py:

MLA Key Dimensions (DeepSeek-V3)

| Symbol | Meaning                   | Value |
|--------|---------------------------|-------|
| N      | Number of attention heads | 128   |
| Lkv    | Latent KV dimension       | 512   |
| P      | No-RoPE head dimension    | 128   |
| R      | RoPE head dimension       | 64    |
| V      | Value head dimension      | 128   |
The KV cache stores only Lkv + R = 576 values per token per layer, versus 2 * N * P = 32768 for standard MHA: a ~57x compression.
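The arithmetic is easy to verify from the dimensions above:

```python
# DeepSeek-V3 dimensions: MLA caches one latent vector (plus RoPE part)
# per token, while MHA caches full K and V for every head.
N, Lkv, P, R = 128, 512, 128, 64
mla_per_token = Lkv + R        # latent + RoPE portion
mha_per_token = 2 * N * P      # full K and V across all heads
print(mla_per_token)                            # 576
print(mha_per_token)                            # 32768
print(round(mha_per_token / mla_per_token, 1))  # 56.9
```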

Two Computation Paths

The source code defines two paths explicitly (from the docstring in mla_attention.py):

[Diagram: the two MLA computation paths. Compute-friendly (forward_mha, prefill): decompress the KV latent into full per-head K/V and run standard MHA via FlashAttention or FlashInfer; memory cost O(Skv * N * P). Memory-friendly (forward_mqa, decode): absorb W_UK into the query and run MQA directly on the latent vectors (a single KV head) via FlashMLA, CUTLASS MLA, or Triton MLA; memory cost O(Skv * (Lkv + R)), at the price of more FLOPs per token since Lkv > P.]
Prefill uses standard MHA with decompressed heads; decode uses MQA on compressed latent vectors to save bandwidth.
# From vllm/model_executor/layers/attention/mla_attention.py docstring:

## Compute Friendly Approach (forward_mha) -- used for prefill:
# Decompress KV latent into per-head K and V
k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
v      = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)
# Standard MHA with QK headdim = P + R, V headdim = V
sdpa_o = scaled_dot_product_attention(
    cat([q_nope, q_pe], dim=-1),   # [Sq, N, P+R]
    cat([k_nope, k_pe_expanded], dim=-1),  # [Skv, N, P+R]
    v                                       # [Skv, N, V]
)

## Data-Movement Friendly Approach (forward_mqa) -- used for decode:
# Absorb W_UK into query -- avoid decompressing KV
ql_nope = einsum("snh,lnh->snl", q, W_UK)
# MQA with QK headdim = Lkv + R, V headdim = Lkv
sdpa_o = scaled_dot_product_attention(
    cat([ql_nope, q_pe], dim=-1),  # [Sq, N, Lkv+R] = [Sq, N, 576]
    cat([kv_c, k_pe], dim=-1),      # [Skv, 1, 576] -- MQA, single KV head
    kv_c                            # [Skv, 1, 512]
)
o = einsum("snl,lnv->snv", sdpa_o, W_UV)  # Project back to V dimension
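The two paths produce identical attention scores because matrix multiplication is associative: scoring the query against decompressed keys equals scoring the absorbed query against the raw latents. A NumPy sanity check for a single head with toy dimensions (a sketch assuming only NumPy, not vLLM code):

```python
# Check (q @ W_UK.T) @ kv_c.T == q @ (kv_c @ W_UK).T for one head.
import numpy as np

rng = np.random.default_rng(0)
Sq, Skv, P, Lkv = 2, 5, 8, 16           # tiny single-head dimensions
q = rng.standard_normal((Sq, P))         # no-RoPE query portion
W_UK = rng.standard_normal((Lkv, P))     # latent -> key up-projection
kv_c = rng.standard_normal((Skv, Lkv))   # compressed KV latents

k_nope = kv_c @ W_UK                     # compute-friendly: decompress to [Skv, P]
scores_mha = q @ k_nope.T

ql_nope = q @ W_UK.T                     # memory-friendly: absorb W_UK into q
scores_mqa = ql_nope @ kv_c.T            # attend over latents directly

print(np.allclose(scores_mha, scores_mqa))  # True
```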

5. MLA Backend Implementations

vLLM provides multiple MLA decode kernels under vllm/v1/attention/backends/mla/. All inherit from MLACommonBackend / MLACommonImpl and implement forward_mqa() for the decode path.

5a. FlashMLA -- Hopper & Blackwell

# vllm/v1/attention/backends/mla/flashmla.py

class FlashMLABackend(MLACommonBackend):
    @staticmethod
    def get_supported_kernel_block_sizes():
        return [64]  # Fixed block size of 64

    @staticmethod
    def get_name():
        return "FLASHMLA"

    @classmethod
    def supports_compute_capability(cls, capability):
        return capability.major in [9, 10]  # Hopper + Blackwell only

FlashMLA Decode Kernel

# FlashMLAImpl.forward_mqa()

def forward_mqa(self, q, kv_c_and_k_pe_cache, attn_metadata, layer):
    q = reshape_query_for_spec_decode(q, attn_metadata.num_decodes)

    scheduler_metadata = attn_metadata.decode.scheduler_metadata

    if self.kv_cache_dtype.startswith("fp8"):
        o, lse = flash_mla_with_kvcache_fp8(
            q=q,
            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            block_table=attn_metadata.decode.block_table,
            cache_seqlens=attn_metadata.decode.seq_lens,
            head_dim_v=self.kv_lora_rank,        # 512
            tile_scheduler_metadata=scheduler_metadata.tile_scheduler_metadata,
            num_splits=scheduler_metadata.num_splits,
            softmax_scale=self.scale, causal=True,
            descale_q=layer._q_scale, descale_k=layer._k_scale,
        )
    else:
        o, lse = flash_mla_with_kvcache(
            q=q,
            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # MQA: 1 KV head
            block_table=attn_metadata.decode.block_table,
            cache_seqlens=attn_metadata.decode.seq_lens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=scheduler_metadata,
            softmax_scale=self.scale, causal=True,
        )
    return reshape_attn_output_for_spec_decode(o), lse

FlashMLA Metadata Builder

# FlashMLAMetadataBuilder._build_decode()

def _build_decode(self, block_table, seq_lens, max_seq_len, ...):
    # Compute tile scheduler metadata for the FlashMLA kernel
    scheduler_metadata, _ = get_mla_metadata(
        seq_lens, num_q_tokens_per_head_k, 1,  # MQA: 1 KV head
        is_fp8_kvcache=self.is_fp8_kvcache,
    )
    if self.is_fp8_kvcache:
        # FP8 needs separate tile_scheduler_metadata and num_splits
        tile_scheduler_metadata, num_splits = get_mla_metadata_dense_fp8(...)
        # Copy into CUDA graph persistent buffers
        if self.cudagraph_mode.has_full_cudagraphs():
            self.cg_buf_tile_scheduler_metadata[:n].copy_(tile_scheduler_metadata)
            self.cg_buf_num_splits[:n].copy_(num_splits)

5b. CUTLASS MLA -- Blackwell Only

# vllm/v1/attention/backends/mla/cutlass_mla.py

class CutlassMLABackend(MLACommonBackend):
    @staticmethod
    def get_supported_kernel_block_sizes():
        return [128]   # Forced to 128

    @classmethod
    def supports_compute_capability(cls, capability):
        return capability.major == 10  # Blackwell (SM100) ONLY

SM100 Workspace

# Manages GPU workspace buffer for CUTLASS MLA

class SM100Workspace:
    def __init__(self, initial_workspace_size):
        self._workspace_buf = torch.empty(
            initial_workspace_size, device="cuda", dtype=torch.uint8
        )
        self._block_size = 128  # Forced to 128
        self._sm_count = num_compute_units(0)

    def ensure_size(self, attn_metadata, num_kv_splits):
        workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
            max_seq_len * self._block_size, batch_size,
            self._sm_count, num_kv_splits=num_kv_splits,
        )
        if self._workspace_buf.shape[0] < workspace_size:
            self._workspace_buf.resize_(workspace_size)

g_sm100_workspace = SM100Workspace(128 * 1024 * 1024)  # 128 MB initial
MAX_HEADS = 128

CUTLASS MLA Decode Kernel

# CutlassMLAImpl.forward_mqa()

def forward_mqa(self, q, kv_c_and_k_pe_cache, attn_metadata, layer):
    if type(q) is tuple:
        q_nope, q_pe = q               # Already split
    else:
        q_nope, q_pe = torch.split(q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

    self._workspace.ensure_size(attn_metadata, self._num_kv_splits)

    # q_nope: [B, H, 512], q_pe: [B, H, 64]
    # kv_c_and_k_pe_cache: [num_blocks, 128, 576]
    ops.sm100_cutlass_mla_decode(
        out, lse, q_nope, q_pe,
        kv_c_and_k_pe_cache,
        seq_lens, page_table, workspace,
        sm_scale=self.scale,
        num_kv_splits=self._num_kv_splits,  # -1 = auto-detect
    )
    return out[:, :H], lse[:, :H]

5c. Triton MLA -- Universal Fallback

# vllm/v1/attention/backends/mla/triton_mla.py

class TritonMLABackend(MLACommonBackend):
    @staticmethod
    def get_supported_kernel_block_sizes():
        return [MultipleOf(16)]  # Any multiple of 16

    @classmethod
    def supports_compute_capability(cls, capability):
        return True  # Works on ANY GPU

Triton MLA Decode: Split-K with Explicit Workspace

# TritonMLAImpl.forward_mqa()

def forward_mqa(self, q, kv_c_and_k_pe_cache, attn_metadata, layer):
    B, q_num_heads = q.shape[0], q.shape[1]
    o = torch.zeros(B, q_num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device)
    lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)

    num_kv_splits = 1 if envs.VLLM_BATCH_INVARIANT else 4

    # Workspace for split-K: stores partial outputs + LSE per split
    attn_logits = torch.empty(
        (B, q_num_heads, num_kv_splits, self.kv_lora_rank + 1),
        #                                   ^ +1 for LSE per split
        dtype=torch.float32, device=q.device,
    )

    kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)  # Add head dim of 1
    decode_attention_fwd(
        q, kv_c_and_k_pe_cache,
        kv_c_and_k_pe_cache[..., :self.kv_lora_rank],  # V = kv_c portion
        o, lse, block_table, seq_lens, attn_logits,
        num_kv_splits, self.scale, PAGE_SIZE,
        k_scale=layer._k_scale, v_scale=layer._k_scale,
    )
    return o, lse
Key difference: Triton MLA uses decode_attention_fwd with an explicit split-K workspace (attn_logits) that stores per-split partial outputs and LSE values. The +1 in the last dimension holds the LogSumExp for each split, which the stage-2 kernel uses to merge results.

6. Attention Metadata Structures

Each backend defines its own metadata dataclass. These are populated by the builder's build() method and consumed by the implementation's forward().

FlashAttentionMetadata

# vllm/v1/attention/backends/flash_attn.py

@dataclass
class FlashAttentionMetadata:
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------|
    # |- tokenA -|..............|-- newTokens --|
    # |---------- context_len --|
    # |-------------- seq_len ------------------|
    #                          |-- query_len ---|

    num_actual_tokens: int          # Tokens excluding padding
    max_query_len: int
    query_start_loc: torch.Tensor   # [num_reqs + 1] cumsum of query lens
    max_seq_len: int
    seq_lens: torch.Tensor          # [num_reqs] total sequence lengths
    block_table: torch.Tensor       # [num_reqs, max_blocks_per_seq]
    slot_mapping: torch.Tensor      # [num_actual_tokens] -> cache slot idx

    # Cascade attention fields
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: torch.Tensor | None
    prefix_kv_lens: torch.Tensor | None
    suffix_kv_lens: torch.Tensor | None

    # AOT scheduling (FA3 only)
    scheduler_metadata: torch.Tensor | None = None
    prefix_scheduler_metadata: torch.Tensor | None = None
    max_num_splits: int = 0           # 0 = use FA3's heuristics
    causal: bool = True
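query_start_loc is the exclusive prefix sum of per-request query lengths, i.e. the cu_seqlens_q tensor the varlen kernel uses to find each request's slice of the packed token batch. A sketch (build_query_start_loc is a hypothetical helper):

```python
# Build the [num_reqs + 1] cumulative-start tensor from query lengths:
# request i owns packed tokens loc[i]:loc[i+1].
def build_query_start_loc(query_lens):
    loc = [0]
    for n in query_lens:
        loc.append(loc[-1] + n)
    return loc

# Mixed batch: one prefill of 5 tokens plus two single-token decodes
print(build_query_start_loc([5, 1, 1]))  # [0, 5, 6, 7]
```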

FlashInferMetadata

# vllm/v1/attention/backends/flashinfer.py

@dataclass
class FlashInferMetadata:
    num_actual_tokens: int
    slot_mapping: torch.Tensor    # [num_actual_tokens]
    q_data_type: torch.dtype      # May be FP8 if TRT-LLM path

    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    # Polymorphic: FI native OR TRT-LLM, depending on dispatch
    prefill: FIPrefill | TRTLLMPrefill | None
    decode:  FIDecode  | TRTLLMDecode  | None

    # Cascade attention (uses FlashInfer's MultiLevelCascadeAttentionWrapper)
    use_cascade: bool
    cascade_wrapper: MultiLevelCascadeAttentionWrapper | None

MLA Common Metadata

# From vllm/model_executor/layers/attention/mla_attention.py

@dataclass
class MLACommonDecodeMetadata:
    block_table: torch.Tensor     # [num_decodes, max_blocks]
    seq_lens: torch.Tensor        # [num_decodes]
    dcp_tot_seq_lens: torch.Tensor | None

# FlashMLA extends this with scheduler metadata:
@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
    scheduler_metadata: FlashMLASchedMeta
    # Contains tile_scheduler_metadata and num_splits
    # for the FlashMLA kernel's work distribution

7. Block Table & KV Cache Layout

Block Table Format

The block table is a 2D tensor of shape [num_requests, max_blocks_per_seq]. Each entry is a physical block index into the KV cache pool. The attention kernel uses it to map logical token positions to physical memory locations:

[Diagram: block table mapping logical sequence positions to physical KV cache blocks. block_table[req_id][block_idx] holds physical block IDs into the shared pool (e.g. request 0 uses blocks 7, 14, 3). Per-backend block layouts: FlashAttention (2, num_blocks, block_size, num_kv_heads, head_size); FlashInfer (num_blocks, 2, block_size, num_kv_heads, head_size); MLA (num_blocks, block_size, kv_lora_rank + rope_dim) = (num_blocks, 64, 576) for DeepSeek-V3. slot_mapping[token_i] = block_idx * block_size + offset_in_block.]
Each entry in the block table points to a physical block in the KV cache pool. slot_mapping maps individual tokens to cache slots.

NHD vs HND Layout

vLLM supports two memory layouts for KV cache blocks, selectable via VLLM_KV_CACHE_LAYOUT:

# vllm/v1/attention/backends/utils.py

KVCacheLayoutType = Literal["NHD", "HND"]

@functools.lru_cache
def get_kv_cache_layout():
    # Priority: code override > env var > connector default
    if _KV_CACHE_LAYOUT_OVERRIDE is not None:
        return _KV_CACHE_LAYOUT_OVERRIDE
    cache_layout = envs.VLLM_KV_CACHE_LAYOUT
    if cache_layout is None:
        cache_layout = get_kv_connector_cache_layout()
    return cache_layout

NHD (Default)

(num_blocks, block_size, num_kv_heads, head_size)

Tokens within a block are contiguous, heads interleaved. Good for sequential token access patterns. Used by FlashAttention on Ampere/Hopper.

HND (Blackwell)

(num_blocks, num_kv_heads, block_size, head_size)

Heads within a block are contiguous, tokens interleaved. Required by FlashInfer on Blackwell (SM100) to get optimal memory access patterns with the TRT-LLM kernels.
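A small NumPy sketch makes the difference concrete: the two layouts hold the same data and differ only in which axis is contiguous within a block (toy dimensions, not real vLLM tensors):

```python
import numpy as np

# Toy dimensions: 4 blocks, block_size=16, 2 KV heads, head_size=8
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8

# NHD: tokens contiguous within a block, heads interleaved
kv_nhd = np.arange(
    num_blocks * block_size * num_kv_heads * head_size, dtype=np.float32
).reshape(num_blocks, block_size, num_kv_heads, head_size)

# HND: heads contiguous within a block -- swap the token and head axes
kv_hnd = kv_nhd.transpose(0, 2, 1, 3)

assert kv_hnd.shape == (num_blocks, num_kv_heads, block_size, head_size)
# Same element, two addressing orders:
# (block, token, head, dim) vs (block, head, token, dim)
assert kv_nhd[1, 5, 0, 3] == kv_hnd[1, 0, 5, 3]
```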

FlashInfer's Paged KV Metadata

FlashInfer requires three auxiliary tensors to describe the paged KV structure:

# FlashInferMetadataBuilder._compute_flashinfer_kv_metadata()

def _compute_flashinfer_kv_metadata(self, num_blocks_np, seq_lens_np,
                                      block_table_tensor, num_reqs, page_size):
    # 1. paged_kv_indptr: cumulative sum of blocks per request
    np.cumsum(num_blocks_np, dtype=np.int32,
              out=self.paged_kv_indptr.np[1:num_reqs+1])

    # 2. paged_kv_indices: flattened list of physical block IDs
    _copy_page_indices_kernel[(num_reqs,)](
        paged_kv_indices, block_table_tensor,
        block_table_tensor.stride(0), paged_kv_indptr,
    )

    # 3. paged_kv_last_page_len: how many valid tokens in last block
    paged_kv_last_page_len_np = seq_lens_np % page_size
    self.paged_kv_last_page_len.np[:num_reqs] = np.where(
        (paged_kv_last_page_len_np == 0) & (seq_lens_np != 0),
        page_size, paged_kv_last_page_len_np,
    )
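The three tensors can be recomputed in plain NumPy as a hedged sketch -- the real builder writes into preallocated buffers and offloads the index copy to a Triton kernel, and flashinfer_kv_metadata here is an illustrative helper, not vLLM API:

```python
import numpy as np

def flashinfer_kv_metadata(block_table, seq_lens, page_size):
    """Recompute FlashInfer's three paged-KV tensors (simplified sketch)."""
    seq_lens = np.asarray(seq_lens)
    num_blocks = -(-seq_lens // page_size)  # ceil-div: blocks per request

    # 1. paged_kv_indptr: exclusive prefix sum over blocks per request
    indptr = np.zeros(len(seq_lens) + 1, dtype=np.int32)
    indptr[1:] = np.cumsum(num_blocks)

    # 2. paged_kv_indices: valid block IDs of every row, flattened
    indices = np.concatenate(
        [np.asarray(row[:n]) for row, n in zip(block_table, num_blocks)])

    # 3. paged_kv_last_page_len: valid tokens in each request's final block
    last = seq_lens % page_size
    last = np.where((last == 0) & (seq_lens != 0), page_size, last)
    return indptr, indices, last

indptr, indices, last = flashinfer_kv_metadata(
    block_table=[[7, 14, 3, -1], [1, 9, 22, 5]],
    seq_lens=[50, 96], page_size=32)
# req 0: ceil(50/32) = 2 blocks, last page holds 50 - 32 = 18 tokens
# req 1: 96/32 = 3 full blocks, so the last page length is the full 32
```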

8. Performance: Hardware-to-Backend Mapping

Standard MHA: Backend Selection Matrix

| GPU Generation | SM | Default Backend | Fallback | Key Feature |
| --- | --- | --- | --- | --- |
| Blackwell (B200, B100) | SM100 | FlashInfer + TRT-LLM kernels | FlashAttention | HND layout, FP8+FP4 fused output quant |
| Hopper (H100, H200) | SM90 | FlashAttention (FA3) | FlashInfer | AOT scheduling, full CUDA graph |
| Ampere (A100, A10) | SM80 | FlashAttention (FA2) | FlashInfer | Packed GQA, uniform-batch CG only |
| Any (fallback) | Any | Triton Attention | FlexAttention | Pure Triton kernel, portable |
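The first-fit selection behind this matrix can be sketched as follows (FakeBackend, PRIORITIES, and the validation criteria are simplified illustrations, not the real classes in cuda.py):

```python
from dataclasses import dataclass

@dataclass
class FakeBackend:
    """Stand-in for a real AttentionBackend class (illustrative only)."""
    name: str
    min_sm: int          # minimum compute capability, e.g. 80 for Ampere
    head_sizes: tuple

    def validate(self, sm: int, head_size: int) -> bool:
        return sm >= self.min_sm and head_size in self.head_sizes

# Priority order loosely mirroring the matrix above (criteria simplified)
PRIORITIES = [
    FakeBackend("FLASH_ATTN", min_sm=80, head_sizes=(64, 128, 256)),
    FakeBackend("FLASHINFER", min_sm=80, head_sizes=(64, 128, 256)),
    FakeBackend("TRITON_ATTN", min_sm=70, head_sizes=(32, 64, 128, 256)),
]

def select_backend(sm: int, head_size: int) -> str:
    """First backend in priority order that passes validation wins."""
    for backend in PRIORITIES:
        if backend.validate(sm, head_size):
            return backend.name
    raise RuntimeError("no attention backend supports this configuration")
```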

MLA: Backend Selection Matrix

| GPU Generation | SM | Prefill Backend | Decode Backend (Dense) | Block Size | FP8 KV Cache |
| --- | --- | --- | --- | --- | --- |
| Blackwell | SM100 | FlashInfer MLA / FlashAttn MLA | FlashInfer MLA > CUTLASS MLA > FlashMLA | 64 / 128 | Yes (FlashInfer preferred for FP8) |
| Hopper | SM90 | FlashAttn MLA | FlashMLA > FlashAttn MLA | 64 | Yes (FlashMLA FP8 + dense path) |
| Ampere | SM80 | FlashAttn MLA | Triton MLA (fallback) | any ×16 | Yes (dequant in kernel) |
Blackwell MLA priority nuance: When using FP8 KV cache, FlashInfer MLA Sparse gets highest priority. For BF16 KV cache with fewer than 16 attention heads, FlashInfer MLA Sparse also leads. Otherwise, FlashMLA Sparse takes the top sparse slot. The code handles this with head-count-dependent reordering in _get_backend_priorities().
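That reordering rule can be condensed into a hedged sketch (function and backend names here are illustrative; the actual logic lives in _get_backend_priorities() in vllm/platforms/cuda.py):

```python
def blackwell_sparse_mla_order(fp8_kv_cache: bool, num_heads: int) -> list:
    """Condensed sketch of the SM100 sparse-MLA priority rule above.

    FP8 KV cache, or BF16 with fewer than 16 attention heads, puts
    FlashInfer MLA Sparse first; otherwise FlashMLA Sparse leads.
    """
    if fp8_kv_cache or num_heads < 16:
        return ["FLASHINFER_MLA_SPARSE", "FLASHMLA_SPARSE"]
    return ["FLASHMLA_SPARSE", "FLASHINFER_MLA_SPARSE"]
```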

CUDA Graph Compatibility

| Backend | CG Support Level | Meaning |
| --- | --- | --- |
| FlashAttention (FA3) | ALWAYS | Full CG for all batch types (decode, mixed prefill-decode) |
| FlashAttention (FA2) | UNIFORM_BATCH | CG works for uniform-length batches only (not mixed) |
| FlashInfer (with TRT-LLM) | UNIFORM_BATCH | CG works for decode and uniform prefill |
| FlashInfer (without TRT-LLM) | UNIFORM_SINGLE_TOKEN_DECODE | CG only for pure single-token decode |
| FlashMLA | UNIFORM_BATCH | Persistent CUDA graph buffers for tile scheduler metadata |
| CUTLASS MLA | UNIFORM_SINGLE_TOKEN_DECODE | Decode-only CUDA graph with 128 MB workspace |
| Triton MLA | UNIFORM_SINGLE_TOKEN_DECODE | Default (inherited from MLACommonMetadataBuilder) |
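The support levels form a strictness ordering, which can be modeled with a small enum. This is an illustrative sketch: the level names mirror the table, but the checking function is assumed, not vLLM's actual dispatch code:

```python
from enum import IntEnum

class CGSupport(IntEnum):
    """CUDA-graph support levels, from most restrictive to most permissive."""
    NEVER = 0
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    UNIFORM_BATCH = 2
    ALWAYS = 3

def can_capture(level: CGSupport, uniform_batch: bool,
                single_token_decode: bool) -> bool:
    """Decide whether a batch shape may run under a captured CUDA graph."""
    if level is CGSupport.ALWAYS:
        return True                          # any batch, including mixed
    if level is CGSupport.UNIFORM_BATCH:
        return uniform_batch                 # uniform-length batches only
    if level is CGSupport.UNIFORM_SINGLE_TOKEN_DECODE:
        return single_token_decode           # pure one-token decode only
    return False
```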

Key Configuration Environment Variables

| Variable | Effect |
| --- | --- |
| VLLM_KV_CACHE_LAYOUT | Force NHD or HND layout (auto-detected per backend otherwise) |
| VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE | FlashInfer workspace size in bytes (default varies by platform) |
| VLLM_BATCH_INVARIANT | Force deterministic output (limits splits to 1, enlarges workspace to 2 GB) |
| --attention-backend | Override auto-selection with a specific backend name |
| FORCE_NUM_KV_SPLITS | Override CUTLASS MLA split-K count (default: auto-detect) |

Architecture Overview

[Diagram: vLLM V1 attention architecture. An attention layer (model_executor/layers/attention/) calls get_attn_backend() (selector.py + cuda.py), which returns an AttentionBackend class (get_impl_cls, get_builder_cls, validation). Concrete backends: FlashAttention (FA2 on SM80+, FA3 on SM90+; AOT scheduling, cascade), FlashInfer (native + TRT-LLM; plan/execute, HND), the DeepSeek MLA backends (FlashMLA, CUTLASS MLA, FlashInfer MLA, Triton MLA fallback, plus sparse variants), and fallbacks (Triton Attn, Flex Attn; ROCm, CPU, ...). Per step, each backend's MetadataBuilder.build() produces its backend-specific metadata: FlashAttentionMetadata, FlashInferMetadata, MLACommonMetadata + decode metadata, and so on.]
From model layer to kernel dispatch: the selector picks a backend, the builder prepares metadata, and the impl runs the kernel.

Source File Reference

| File | Purpose | Key Classes |
| --- | --- | --- |
| v1/attention/selector.py | Entry point for backend selection | get_attn_backend() |
| v1/attention/backends/registry.py | Enum of all backend class paths | AttentionBackendEnum |
| platforms/cuda.py | GPU-specific priority ordering | _get_backend_priorities() |
| v1/attention/backends/flash_attn.py | FlashAttention 2/3 implementation | FlashAttentionBackend, FlashAttentionImpl |
| v1/attention/backends/flashinfer.py | FlashInfer + TRT-LLM implementation | FlashInferBackend, FlashInferImpl |
| v1/attention/backends/utils.py | KV cache layout, per-layer params, local attention | get_kv_cache_layout(), PerLayerParameters |
| v1/attention/backends/mla/flashmla.py | FlashMLA for DeepSeek (SM90/100) | FlashMLABackend, FlashMLAImpl |
| v1/attention/backends/mla/cutlass_mla.py | CUTLASS MLA for Blackwell only | CutlassMLABackend, CutlassMLAImpl |
| v1/attention/backends/mla/triton_mla.py | Triton MLA fallback (any GPU) | TritonMLABackend, TritonMLAImpl |
| model_executor/layers/attention/mla_attention.py | MLA common base classes + algorithm docs | MLACommonBackend, MLACommonImpl |