How vLLM selects, configures, and dispatches to FlashAttention, FlashInfer, and MLA kernels -- traced through actual source code.
vllm/v1/attention/backends/ -- Source-Level Analysis

Every attention layer in vLLM calls get_attn_backend() to determine which kernel to use. The entry point is in vllm/v1/attention/selector.py:
# vllm/v1/attention/selector.py
def get_attn_backend(
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
...
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
vllm_config = get_current_vllm_config()
# If user set --block-size explicitly, pass it to validation
block_size = cache_config.block_size if cache_config.user_specified_block_size else None
attn_selector_config = AttentionSelectorConfig(
head_size=head_size, dtype=dtype, kv_cache_dtype=kv_cache_dtype,
block_size=block_size, use_mla=use_mla, has_sink=has_sink, ...
)
return _cached_get_attn_backend(
backend=vllm_config.attention_config.backend, # user override or None
attn_selector_config=attn_selector_config,
)
When no backend is explicitly requested via --attention-backend, the CUDA platform iterates through a priority-ordered list and picks the first one that passes validation. This logic lives in vllm/platforms/cuda.py:
# vllm/platforms/cuda.py -- _get_backend_priorities()
def _get_backend_priorities(
use_mla: bool,
device_capability: DeviceCapability,
num_heads: int | None = None,
kv_cache_dtype: CacheDType | None = None,
) -> list[AttentionBackendEnum]:
if use_mla:
if device_capability.major == 10: # Blackwell (SM100)
return [
FLASHINFER_MLA, # Highest priority on Blackwell
CUTLASS_MLA,
FLASH_ATTN_MLA,
FLASHMLA,
TRITON_MLA,
*sparse_backends, # FLASHMLA_SPARSE / FLASHINFER_MLA_SPARSE
]
else: # Hopper (SM90) and older
return [FLASH_ATTN_MLA, FLASHMLA, FLASHINFER_MLA, TRITON_MLA, ...]
else: # Standard MHA
if device_capability.major == 10: # Blackwell
return [FLASHINFER, FLASH_ATTN, TRITON_ATTN, FLEX_ATTENTION]
else: # Ampere / Hopper
return [FLASH_ATTN, FLASHINFER, TRITON_ATTN, FLEX_ATTENTION]
Each backend class exposes a validate_configuration() method (inherited from AttentionBackend). The platform checks every backend in priority order:
# vllm/platforms/cuda.py -- get_valid_backends()
for priority, backend in enumerate(backend_priorities):
backend_class = backend.get_class()
invalid_reasons = backend_class.validate_configuration(
device_capability=device_capability,
**attn_selector_config._asdict(),
)
if invalid_reasons:
invalid_reasons_map[backend] = (priority, invalid_reasons)
else:
valid_backends_priorities.append((backend, priority))
# Select backend with lowest priority index (highest priority)
sorted_indices = sorted(range(len(valid)), key=lambda i: valid[i][1])
selected_backend = valid[sorted_indices[0]][0]
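The first-match walk above can be modeled in a few lines. Everything here (the Backend class, validate(), the capability tuples, the priority list contents) is an illustrative stand-in, not vLLM's real API:

```python
# Sketch of priority-ordered, first-valid backend selection.
# All names here are hypothetical stand-ins for the vLLM classes.
from dataclasses import dataclass, field

@dataclass
class Backend:
    name: str
    min_capability: tuple  # (major, minor), e.g. (8, 0) for Ampere
    supported_head_sizes: list = field(default_factory=lambda: [64, 128, 256])

    def validate(self, capability, head_size):
        # Mirror validate_configuration(): return reasons; empty list = valid.
        reasons = []
        if capability < self.min_capability:
            reasons.append(f"requires SM{self.min_capability[0]}x+")
        if head_size not in self.supported_head_sizes:
            reasons.append(f"head_size {head_size} unsupported")
        return reasons

def select_backend(priorities, capability, head_size):
    # Walk the priority-ordered list; the first backend that reports no
    # invalid reasons wins, as in get_valid_backends().
    for backend in priorities:
        if not backend.validate(capability, head_size):
            return backend.name
    return None  # no valid backend

PRIORITIES = [
    Backend("FLASHINFER", (10, 0)),   # Blackwell-first entry
    Backend("FLASH_ATTN", (8, 0)),    # Ampere and newer
    Backend("TRITON_ATTN", (0, 0)),   # portable fallback
]
```

On a Hopper-class (9, 0) device this sketch skips FLASHINFER and lands on FLASH_ATTN, matching the non-Blackwell priority list above.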
All backends are registered in an enum at vllm/v1/attention/backends/registry.py. Each entry maps a name to a fully qualified class path:
# vllm/v1/attention/backends/registry.py
class AttentionBackendEnum(Enum):
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
FLASHINFER_MLA = "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
# ... plus ROCm, CPU, sparse variants, etc.
The FlashAttention backend (vllm/v1/attention/backends/flash_attn.py) is the workhorse for standard multi-head attention on Ampere and Hopper GPUs. It auto-detects the FA version and adjusts its behavior accordingly.
# vllm/v1/attention/backends/flash_attn.py
class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
forward_includes_kv_cache_update: bool = False # KV update separate from attn
supported_dtypes = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes = ["auto", "float16", "bfloat16"]
@staticmethod
def get_supported_kernel_block_sizes():
# For hybrid models with FP32 Mamba state, restrict block sizes
if model_config.is_hybrid:
return [16, 32, 64] # NaN propagation issue
return [MultipleOf(16)] # Any multiple of 16
@classmethod
def supports_compute_capability(cls, capability):
return capability >= DeviceCapability(8, 0) # Ampere+
@classmethod
def supports_per_head_quant_scales(cls):
return get_flash_attn_version() >= 3 # FA3 only
FlashAttention stores key and value together in a single unified KV cache tensor; the first dimension selects K vs V:
# KV cache shape for FlashAttention:
# (2, num_blocks, block_size, num_kv_heads, head_size)
# ^--- 0=key, 1=value
@staticmethod
def get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size, ...):
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
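For concreteness, the memory cost implied by this shape can be computed directly. The model dimensions below (8 KV heads, head size 128, fp16) are illustrative assumptions, not values from the source:

```python
# Back-of-envelope sizing for the FlashAttention KV layout
# (2, num_blocks, block_size, num_kv_heads, head_size).
def kv_block_bytes(block_size, num_kv_heads, head_size, dtype_bytes=2):
    # One block holds both K and V for block_size tokens of one layer.
    return 2 * block_size * num_kv_heads * head_size * dtype_bytes

# Illustrative Llama-8B-like config: 8 KV heads, head_size 128, fp16.
bytes_per_block = kv_block_bytes(block_size=16, num_kv_heads=8, head_size=128)
per_token_per_layer = bytes_per_block // 16  # divide out block_size
```

With these numbers each 16-token block costs 64 KiB per layer, i.e. 4 KiB of KV cache per token per layer.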
The FlashAttentionMetadataBuilder computes scheduling decisions ahead of time (AOT) when running with FlashAttention 3. This is critical for CUDA graph compatibility:
# FlashAttentionMetadataBuilder.__init__()
self.aot_schedule = get_flash_attn_version() == 3 # Only FA3
if self.use_full_cuda_graph and self.aot_schedule:
# FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
# The +1 is for tile_count_semaphore (synchronization)
# 4 slots per batch: prepare_varlen + dynamic_split + sort + head_swizzle
max_batch_size = max(max_num_seqs, max_cudagraph_capture_size or 0)
self.scheduler_metadata = torch.zeros(
1 + round_up(max_batch_size, 4) * 4,
dtype=torch.int32, device=self.device,
)
# Upper bound for intermediate buffer splits
self.max_num_splits = attention_config.flash_attn_max_num_splits_for_cuda_graph
FA3 reports AttentionCGSupport.ALWAYS (full CUDA graphs in all cases). FA2 only supports UNIFORM_BATCH because its special max_query_len=1 packed-GQA handling breaks when mixed prefill-decode batches are replayed from a graph captured with decode-only shapes.
# FlashAttentionMetadataBuilder.build()
def build(self, common_prefix_len, common_attn_metadata, fast_build=False):
# fast_build disables AOT scheduling -- used for spec-decode
aot_schedule = self.aot_schedule and not fast_build
# Inner function that invokes FA3's get_scheduler_metadata
def schedule(batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal):
if aot_schedule:
return get_scheduler_metadata(
batch_size=batch_size, max_seqlen_q=max_query_len,
max_seqlen_k=max_seq_len, num_heads_q=self.num_heads_q,
num_heads_kv=self.num_heads_kv, headdim=self.headdim,
cache_seqlens=seqlens, qkv_dtype=qkv_dtype,
cu_seqlens_q=cu_query_lens, page_size=self.block_size,
causal=causal, window_size=self.aot_sliding_window,
num_splits=max_num_splits,
)
return None
# Three paths: DCP, cascade, or standard
if self.dcp_world_size > 1:
scheduler_metadata = schedule(..., causal=False)
elif use_cascade:
prefix_scheduler = schedule(batch_size=1, ..., causal=False)
suffix_scheduler = schedule(batch_size=num_reqs, ..., causal=True)
else:
scheduler_metadata = schedule(..., causal=causal)
# For FA3 + CUDA graph: copy into persistent buffer, zero rest
if self.use_full_cuda_graph and scheduler_metadata is not None:
n = scheduler_metadata.shape[0]
self.scheduler_metadata[:n] = scheduler_metadata
self.scheduler_metadata[n:] = 0 # Zero rest for correctness
# FlashAttentionImpl.forward()
def forward(self, layer, query, key, value, kv_cache, attn_metadata, output, ...):
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(0) # Split K/V
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(dtype) # Reinterpret as FP8
value_cache = value_cache.view(dtype)
if not attn_metadata.use_cascade:
# Standard path: single flash_attn_varlen_func call
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache, v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=attn_metadata.query_start_loc,
max_seqlen_q=attn_metadata.max_query_len,
seqused_k=attn_metadata.seq_lens,
max_seqlen_k=attn_metadata.max_seq_len,
softmax_scale=self.scale, causal=attn_metadata.causal,
block_table=attn_metadata.block_table, # Paged attention!
scheduler_metadata=attn_metadata.scheduler_metadata, # AOT
fa_version=self.vllm_flash_attn_version,
q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks, # Attention sinks
)
else:
# Cascade attention: prefix + suffix kernels, merged via LSE
cascade_attention(output, query, key_cache, value_cache, ...)
When multiple requests share a common KV prefix (e.g., system prompt), vLLM computes attention over the shared prefix once and merges with per-request suffix attention using log-sum-exp (LSE) values. The FlashAttention backend dispatches two separate flash_attn_varlen_func calls: one for the prefix (non-causal, batch_size=1) and one for the per-request suffixes (causal).
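The LSE merge itself is a standard identity and can be verified in a few lines of numpy. This is a self-contained sketch of the math, not vLLM's merge kernel:

```python
# Verify that prefix + suffix attention, merged via log-sum-exp, equals
# attention over the full KV. Single head, no scaling, pure numpy.
import numpy as np

def attn_with_lse(q, k, v):
    # Attention output plus the log-sum-exp of the raw scores per query.
    s = q @ k.T                        # [Sq, Skv]
    lse = np.log(np.exp(s).sum(-1))    # [Sq]
    p = np.exp(s - lse[:, None])       # normalized softmax weights
    return p @ v, lse

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8))
k = rng.standard_normal((10, 8))
v = rng.standard_normal((10, 8))

# Split KV into a shared "prefix" and a per-request "suffix".
o1, lse1 = attn_with_lse(q, k[:6], v[:6])
o2, lse2 = attn_with_lse(q, k[6:], v[6:])

# Merge: weight each partial output by its share of total softmax mass.
lse = np.logaddexp(lse1, lse2)
merged = np.exp(lse1 - lse)[:, None] * o1 + np.exp(lse2 - lse)[:, None] * o2

o_full, _ = attn_with_lse(q, k, v)
assert np.allclose(merged, o_full)
```

Because the merge is exact, the prefix computation can be amortized across every request sharing it with no change in output.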
# FlashAttentionImpl.do_kv_cache_update() -- called BEFORE forward()
def do_kv_cache_update(self, layer, key, value, kv_cache, slot_mapping):
key_cache, value_cache = kv_cache.unbind(0)
# reshape_and_cache_flash writes K/V tokens into paged blocks
# using slot_mapping to determine the block + offset for each token
reshape_and_cache_flash(
key[:num_actual_tokens], value[:num_actual_tokens],
key_cache, value_cache, slot_mapping[:num_actual_tokens],
self.kv_cache_dtype, layer._k_scale, layer._v_scale,
)
FlashInfer (vllm/v1/attention/backends/flashinfer.py) takes top priority on Blackwell GPUs and offers two internal dispatch paths: the native FlashInfer wrappers and TRT-LLM attention kernels.
# vllm/v1/attention/backends/flashinfer.py
class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_kv_cache_dtypes = ["auto", "float16", "bfloat16",
"fp8", "fp8_e4m3", "fp8_e5m2"]
@staticmethod
def get_supported_kernel_block_sizes():
return [16, 32, 64] # Fixed sizes, not MultipleOf
@classmethod
def supports_compute_capability(cls, capability):
return capability >= DeviceCapability(7, 5) \
and capability <= DeviceCapability(12, 1)
@classmethod
def get_required_kv_cache_layout(cls):
capability = current_platform.get_device_capability()
if capability.major == 10: # Blackwell requires HND layout
return "HND"
return None # Use default (NHD)
FlashInfer uses a plan-then-execute pattern. During build(), wrappers call .plan() to precompute indices and scheduling. During forward(), they call .run() to execute the actual attention.
# Workspace buffer: allocated once, reused across calls
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024 # 2 GB
def _get_workspace_buffer(self):
if self._workspace_buffer is None:
buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE
if envs.VLLM_BATCH_INVARIANT:
buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
self._workspace_buffer = torch.zeros(
buffer_size, dtype=torch.uint8, device=self.device
)
return self._workspace_buffer
# Separate wrappers for prefill vs decode
def _get_prefill_wrapper(self):
if self.use_dcp:
return BatchDCPPrefillWrapper(workspace_buffer=..., dcp_a2a=...)
else:
return BatchPrefillWithPagedKVCacheWrapper(
self._get_workspace_buffer(), get_kv_cache_layout()
)
def _get_decode_wrapper(self, batch_size, use_cudagraph=False):
# For CUDA graphs: one wrapper PER batch size, with pinned buffers
if use_cudagraph:
decode_wrapper = self._decode_wrappers_cudagraph.get(batch_size)
if decode_wrapper is None:
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(), get_kv_cache_layout(),
use_cuda_graph=use_cudagraph,
paged_kv_indptr_buffer=self.paged_kv_indptr.gpu[:batch_size+1],
paged_kv_indices_buffer=self.paged_kv_indices.gpu,
paged_kv_last_page_len_buffer=self.paged_kv_last_page_len.gpu[:batch_size],
use_tensor_cores=True, # Always on for modern GPUs
)
if use_cudagraph:
self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
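The plan-then-execute split can be sketched abstractly. The class and method bodies below are purely illustrative (FlashInfer's real wrappers do far more); they show only the control-flow contract of plan() in build() vs run() in forward():

```python
# Hypothetical miniature of the plan/run pattern -- not FlashInfer's API.
class PlannedWrapper:
    def __init__(self, workspace_bytes):
        # Workspace allocated once and reused, as in _get_workspace_buffer().
        self.workspace = bytearray(workspace_bytes)
        self._plan = None

    def plan(self, seq_lens, page_size):
        # Host-side, shape-dependent precomputation (done once per batch
        # in the metadata builder's build()).
        self._plan = {
            "num_pages": [-(-s // page_size) for s in seq_lens],  # ceil-div
            "page_size": page_size,
        }

    def run(self, num_tokens):
        # Hot path (forward()): execute using the precomputed plan only.
        assert self._plan is not None, "plan() must be called before run()"
        return {"tokens": num_tokens, **self._plan}

w = PlannedWrapper(1 << 20)
w.plan(seq_lens=[17, 32], page_size=16)
result = w.run(num_tokens=5)
```

Keeping all shape-dependent work in plan() is what lets run() be captured into a CUDA graph: replay only needs the pinned buffers the plan wrote into.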
FlashInfer's build() decides per-phase whether to use native FlashInfer or TRT-LLM kernels:
# FlashInferMetadataBuilder.build()
# Step 1: Decide dispatch mode for each phase
prefill_use_trtllm = use_trtllm_attention(
self.num_qo_heads, self.num_kv_heads, num_prefill_tokens,
max_seq_len, self.dcp_world_size, self.cache_dtype, ...
)
decode_use_trtllm = self.use_trtllm_decode_attention and self.dcp_world_size <= 1
# Step 2: Build metadata skeleton
attn_metadata = FlashInferMetadata(
num_actual_tokens=num_actual_tokens,
slot_mapping=slot_mapping,
q_data_type=self.q_data_type,
num_decodes=num_decodes, num_prefills=num_prefills,
prefill=None, decode=None, # populated below
use_cascade=use_cascade, cascade_wrapper=None,
)
# Step 3a: PREFILL -- either FI native or TRTLLM
if prefill_use_trtllm:
attn_metadata.prefill = TRTLLMPrefill(
block_tables=block_table_tensor[prefill_start:],
seq_lens=seq_lens[prefill_start:],
cum_seq_lens_q=qo_indptr_prefill_gpu,
max_q_len=max_q_len_prefill, max_seq_len=max_seq_len,
)
else:
prefill_wrapper = self._get_prefill_wrapper()
prefill_wrapper.plan(...) # Plan phase -- precompute indices
attn_metadata.prefill = FIPrefill(wrapper=prefill_wrapper)
# Step 3b: DECODE -- either FI native or TRTLLM
if decode_use_trtllm:
attn_metadata.decode = TRTLLMDecode(
block_tables=block_table_tensor[:num_decodes],
seq_lens=seq_lens[:num_decodes], max_seq_len=max_seq_len,
)
else:
decode_wrapper = self._get_decode_wrapper(num_decode_tokens, use_cudagraph)
fast_plan_decode(decode_wrapper, ...) # Plan phase
attn_metadata.decode = FIDecode(wrapper=decode_wrapper)
@dataclass
class FIPrefill:
wrapper: BatchPrefillWith...Wrapper
# Wrapper has plan() already called
# Contains paged_kv_indptr,
# paged_kv_indices, last_page_len
@dataclass
class FIDecode:
wrapper: BatchDecodeWith...Wrapper
# Same pattern: plan() pre-called
@dataclass
class TRTLLMPrefill:
block_tables: torch.Tensor # [num_prefills, max_blocks]
seq_lens: torch.Tensor # [num_prefills]
cum_seq_lens_q: torch.Tensor
cum_seq_lens_kv: torch.Tensor
max_q_len: int
max_seq_len: int
@dataclass
class TRTLLMDecode:
block_tables: torch.Tensor # [num_decodes, max_blocks]
seq_lens: torch.Tensor # [num_decodes]
max_seq_len: int
# FlashInfer CUDA graph strategy depends on TRT-LLM availability
@classmethod
def get_cudagraph_support(cls, vllm_config, kv_cache_spec):
# Check if TRT-LLM attention works for ALL spec layers
has_trtllm_support = True
for spec in kv_specs:
if not can_use_trtllm_attention(num_qo_heads, spec.num_kv_heads):
has_trtllm_support = False
break
if has_trtllm_support:
return AttentionCGSupport.UNIFORM_BATCH # Full CG with mixed batches
else:
return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE # Decode-only CG
MLA, introduced in DeepSeek-V2, compresses the KV cache into a single latent vector per token. This dramatically reduces memory bandwidth during decode. The design is documented in extensive comments in vllm/model_executor/layers/attention/mla_attention.py:
| Symbol | Meaning | Value |
|---|---|---|
| N | Number of attention heads | 128 |
| Lkv | Latent KV dimension | 512 |
| P | No-RoPE head dimension | 128 |
| R | RoPE head dimension | 64 |
| V | Value head dimension | 128 |
KV cache stores only [Lkv + R] = 576 values per token per layer, vs 2 * N * P = 32768 for standard MHA. A 57x compression.
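The compression figure follows directly from the table:

```python
# The 57x figure, computed from the DeepSeek-V2 shapes above.
N, Lkv, P, R = 128, 512, 128, 64

mla_per_token = Lkv + R    # compressed latent + RoPE key: 576 values
mha_per_token = 2 * N * P  # full per-head K and V: 32768 values
ratio = mha_per_token / mla_per_token  # ~56.9, quoted as 57x
```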
The source code defines two paths explicitly (from the docstring in mla_attention.py):
# From vllm/model_executor/layers/attention/mla_attention.py docstring:
## Compute Friendly Approach (forward_mha) -- used for prefill:
# Decompress KV latent into per-head K and V
k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)
# Standard MHA with QK headdim = P + R, V headdim = V
spda_o = scaled_dot_product_attention(
cat([q_nope, q_pe], dim=-1), # [Sq, N, P+R]
cat([k_nope, k_pe_expanded], dim=-1), # [Skv, N, P+R]
v # [Skv, N, V]
)
## Data-Movement Friendly Approach (forward_mqa) -- used for decode:
# Absorb W_UK into query -- avoid decompressing KV
ql_nope = einsum("snh,lnh->snl", q, W_UK)
# MQA with QK headdim = Lkv + R, V headdim = Lkv
spda_o = scaled_dot_product_attention(
cat([ql_nope, q_pe], dim=-1), # [Sq, N, Lkv+R] = [Sq, N, 576]
cat([kv_c, k_pe], dim=-1), # [Skv, 1, 576] -- MQA, single KV head
kv_c # [Skv, 1, 512]
)
o = einsum("snl,lnv->snv", spda_o, W_UV) # Project back to V dimension
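That the two paths agree on attention scores is pure associativity, checkable with small random tensors. A numpy sketch covering only the no-RoPE portion, with shapes scaled down from the real values:

```python
# Check: decompress-then-dot (forward_mha) and absorb-W_UK-into-q
# (forward_mqa) give identical attention logits.
import numpy as np

rng = np.random.default_rng(1)
Sq, Skv, N, Lkv, P = 3, 5, 4, 16, 8    # toy sizes, not DeepSeek's
q_nope = rng.standard_normal((Sq, N, P))
kv_c = rng.standard_normal((Skv, Lkv))   # compressed latent per token
W_UK = rng.standard_normal((Lkv, N, P))  # key up-projection

# Compute-friendly path: decompress latent into per-head keys, then QK^T.
k_nope = np.einsum("kl,lnp->knp", kv_c, W_UK)           # [Skv, N, P]
scores_mha = np.einsum("qnp,knp->qnk", q_nope, k_nope)  # [Sq, N, Skv]

# Data-movement-friendly path: absorb W_UK into q, dot latents directly.
ql_nope = np.einsum("qnp,lnp->qnl", q_nope, W_UK)       # [Sq, N, Lkv]
scores_mqa = np.einsum("qnl,kl->qnk", ql_nope, kv_c)    # [Sq, N, Skv]

assert np.allclose(scores_mha, scores_mqa)
```

The decode path wins because the per-token data read from the cache shrinks from N full heads to one 576-wide latent, at the cost of the one-time einsum against W_UK.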
vLLM provides multiple MLA decode kernels under vllm/v1/attention/backends/mla/. All inherit from MLACommonBackend / MLACommonImpl and implement forward_mqa() for the decode path.
# vllm/v1/attention/backends/mla/flashmla.py
class FlashMLABackend(MLACommonBackend):
@staticmethod
def get_supported_kernel_block_sizes():
return [64] # Fixed block size of 64
@staticmethod
def get_name():
return "FLASHMLA"
@classmethod
def supports_compute_capability(cls, capability):
return capability.major in [9, 10] # Hopper + Blackwell only
# FlashMLAImpl.forward_mqa()
def forward_mqa(self, q, kv_c_and_k_pe_cache, attn_metadata, layer):
q = reshape_query_for_spec_decode(q, attn_metadata.num_decodes)
scheduler_metadata = attn_metadata.decode.scheduler_metadata
if self.kv_cache_dtype.startswith("fp8"):
o, lse = flash_mla_with_kvcache_fp8(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank, # 512
tile_scheduler_metadata=scheduler_metadata.tile_scheduler_metadata,
num_splits=scheduler_metadata.num_splits,
softmax_scale=self.scale, causal=True,
descale_q=layer._q_scale, descale_k=layer._k_scale,
)
else:
o, lse = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # MQA: 1 KV head
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=scheduler_metadata,
softmax_scale=self.scale, causal=True,
)
return reshape_attn_output_for_spec_decode(o), lse
# FlashMLAMetadataBuilder._build_decode()
def _build_decode(self, block_table, seq_lens, max_seq_len, ...):
# Compute tile scheduler metadata for the FlashMLA kernel
scheduler_metadata, _ = get_mla_metadata(
seq_lens, num_q_tokens_per_head_k, 1, # MQA: 1 KV head
is_fp8_kvcache=self.is_fp8_kvcache,
)
if self.is_fp8_kvcache:
# FP8 needs separate tile_scheduler_metadata and num_splits
tile_scheduler_metadata, num_splits = get_mla_metadata_dense_fp8(...)
# Copy into CUDA graph persistent buffers
if self.cudagraph_mode.has_full_cudagraphs():
self.cg_buf_tile_scheduler_metadata[:n].copy_(tile_scheduler_metadata)
self.cg_buf_num_splits[:n].copy_(num_splits)
# vllm/v1/attention/backends/mla/cutlass_mla.py
class CutlassMLABackend(MLACommonBackend):
@staticmethod
def get_supported_kernel_block_sizes():
return [128] # Forced to 128
@classmethod
def supports_compute_capability(cls, capability):
return capability.major == 10 # Blackwell (SM100) ONLY
# Manages GPU workspace buffer for CUTLASS MLA
class SM100Workspace:
def __init__(self, initial_workspace_size):
self._workspace_buf = torch.empty(
initial_workspace_size, device="cuda", dtype=torch.uint8
)
self._block_size = 128 # Forced to 128
self._sm_count = num_compute_units(0)
def ensure_size(self, attn_metadata, num_kv_splits):
workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
max_seq_len * self._block_size, batch_size,
self._sm_count, num_kv_splits=num_kv_splits,
)
if self._workspace_buf.shape[0] < workspace_size:
self._workspace_buf.resize_(workspace_size)
g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128 MB initial
MAX_HEADS = 128
# CutlassMLAImpl.forward_mqa()
def forward_mqa(self, q, kv_c_and_k_pe_cache, attn_metadata, layer):
if type(q) is tuple:
q_nope, q_pe = q # Already split
else:
q_nope, q_pe = torch.split(q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
self._workspace.ensure_size(attn_metadata, self._num_kv_splits)
# q_nope: [B, H, 512], q_pe: [B, H, 64]
# kv_c_and_k_pe_cache: [num_blocks, 128, 576]
ops.sm100_cutlass_mla_decode(
out, lse, q_nope, q_pe,
kv_c_and_k_pe_cache,
seq_lens, page_table, workspace,
sm_scale=self.scale,
num_kv_splits=self._num_kv_splits, # -1 = auto-detect
)
return out[:, :H], lse[:, :H]
# vllm/v1/attention/backends/mla/triton_mla.py
class TritonMLABackend(MLACommonBackend):
@staticmethod
def get_supported_kernel_block_sizes():
return [MultipleOf(16)] # Any multiple of 16
@classmethod
def supports_compute_capability(cls, capability):
return True # Works on ANY GPU
# TritonMLAImpl.forward_mqa()
def forward_mqa(self, q, kv_c_and_k_pe_cache, attn_metadata, layer):
B, q_num_heads = q.shape[0], q.shape[1]
o = torch.zeros(B, q_num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device)
lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
num_kv_splits = 1 if envs.VLLM_BATCH_INVARIANT else 4
# Workspace for split-K: stores partial outputs + LSE per split
attn_logits = torch.empty(
(B, q_num_heads, num_kv_splits, self.kv_lora_rank + 1),
# ^ +1 for LSE per split
dtype=torch.float32, device=q.device,
)
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) # Add head dim of 1
decode_attention_fwd(
q, kv_c_and_k_pe_cache,
kv_c_and_k_pe_cache[..., :self.kv_lora_rank], # V = kv_c portion
o, lse, block_table, seq_lens, attn_logits,
num_kv_splits, self.scale, PAGE_SIZE,
k_scale=layer._k_scale, v_scale=layer._k_scale,
)
return o, lse
Triton MLA calls decode_attention_fwd with an explicit split-K workspace (attn_logits) that stores per-split partial outputs and LSE values. The +1 in the last dimension holds the log-sum-exp for each split, which the stage-2 kernel uses to merge results.
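The stage-2 merge is the same LSE identity used by cascade attention, applied N-way across splits. A numpy sketch of the math, not the actual Triton kernel:

```python
# Split-K decode for one (token, head): each split attends over its KV
# slice and records its LSE; stage 2 merges partials exactly.
import numpy as np

rng = np.random.default_rng(2)
scores = rng.standard_normal(12)       # raw QK scores over 12 KV positions
v = rng.standard_normal((12, 4))

# Reference: full softmax attention.
p = np.exp(scores - scores.max())
p /= p.sum()
ref = p @ v

# Stage 1: each split computes partial output + its own LSE
# (what the "+1" slot in attn_logits stores per split).
parts, lses = [], []
for sl in (slice(0, 4), slice(4, 8), slice(8, 12)):
    s = scores[sl]
    lse = np.log(np.exp(s).sum())
    parts.append(np.exp(s - lse) @ v[sl])
    lses.append(lse)

# Stage 2: weight each split by its share of the total softmax mass.
lse_total = np.logaddexp.reduce(lses)
merged = sum(np.exp(l - lse_total) * o for l, o in zip(lses, parts))
assert np.allclose(merged, ref)
```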
Each backend defines its own metadata dataclass. These are populated by the builder's build() method and consumed by the implementation's forward().
# vllm/v1/attention/backends/flash_attn.py
@dataclass
class FlashAttentionMetadata:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------|
# |- tokenA -|..............|-- newTokens --|
# |---------- context_len --|
# |-------------- seq_len ------------------|
# |-- query_len ---|
num_actual_tokens: int # Tokens excluding padding
max_query_len: int
query_start_loc: torch.Tensor # [num_reqs + 1] cumsum of query lens
max_seq_len: int
seq_lens: torch.Tensor # [num_reqs] total sequence lengths
block_table: torch.Tensor # [num_reqs, max_blocks_per_seq]
slot_mapping: torch.Tensor # [num_actual_tokens] -> cache slot idx
# Cascade attention fields
use_cascade: bool
common_prefix_len: int
cu_prefix_query_lens: torch.Tensor | None
prefix_kv_lens: torch.Tensor | None
suffix_kv_lens: torch.Tensor | None
# AOT scheduling (FA3 only)
scheduler_metadata: torch.Tensor | None = None
prefix_scheduler_metadata: torch.Tensor | None = None
max_num_splits: int = 0 # 0 = use FA3's heuristics
causal: bool = True
# vllm/v1/attention/backends/flashinfer.py
@dataclass
class FlashInferMetadata:
num_actual_tokens: int
slot_mapping: torch.Tensor # [num_actual_tokens]
q_data_type: torch.dtype # May be FP8 if TRT-LLM path
num_decodes: int
num_decode_tokens: int
num_prefills: int
num_prefill_tokens: int
# Polymorphic: FI native OR TRT-LLM, depending on dispatch
prefill: FIPrefill | TRTLLMPrefill | None
decode: FIDecode | TRTLLMDecode | None
# Cascade attention (uses FlashInfer's MultiLevelCascadeAttentionWrapper)
use_cascade: bool
cascade_wrapper: MultiLevelCascadeAttentionWrapper | None
# From vllm/model_executor/layers/attention/mla_attention.py
@dataclass
class MLACommonDecodeMetadata:
block_table: torch.Tensor # [num_decodes, max_blocks]
seq_lens: torch.Tensor # [num_decodes]
dcp_tot_seq_lens: torch.Tensor | None
# FlashMLA extends this with scheduler metadata:
@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
scheduler_metadata: FlashMLASchedMeta
# Contains tile_scheduler_metadata and num_splits
# for the FlashMLA kernel's work distribution
The block table is a 2D tensor of shape [num_requests, max_blocks_per_seq]. Each entry is a physical block index into the KV cache pool. The attention kernel uses it to map logical token positions to physical memory locations.
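The lookup itself is simple index arithmetic. A pure-Python model (the helper name is hypothetical, not a vLLM function):

```python
# Map a logical token position to a physical KV cache slot via the
# block table row for one request.
def physical_slot(block_table_row, pos, block_size):
    # Which logical block the token falls in, then its physical home.
    block = block_table_row[pos // block_size]
    return block * block_size + pos % block_size

# A request whose logical blocks live at physical blocks 7, 2, 9:
row = [7, 2, 9]
slots = [physical_slot(row, p, block_size=16) for p in (0, 16, 20, 47)]
```

This is also exactly what slot_mapping precomputes per token on the write side, so reshape_and_cache_flash never touches the block table itself.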
vLLM supports two memory layouts for KV cache blocks, selectable via VLLM_KV_CACHE_LAYOUT:
# vllm/v1/attention/backends/utils.py
KVCacheLayoutType = Literal["NHD", "HND"]
@functools.lru_cache
def get_kv_cache_layout():
# Priority: code override > env var > connector default
if _KV_CACHE_LAYOUT_OVERRIDE is not None:
return _KV_CACHE_LAYOUT_OVERRIDE
cache_layout = envs.VLLM_KV_CACHE_LAYOUT
if cache_layout is None:
cache_layout = get_kv_connector_cache_layout()
return cache_layout
NHD: (num_blocks, block_size, num_kv_heads, head_size). Tokens within a block are contiguous, heads interleaved. Good for sequential token access patterns; used by FlashAttention on Ampere/Hopper.
HND: (num_blocks, num_kv_heads, block_size, head_size). Heads within a block are contiguous, tokens interleaved. Required by FlashInfer on Blackwell (SM100) for optimal memory access patterns with the TRT-LLM kernels.
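The two layouts hold the same data; converting between them is a transpose of the two middle axes. A numpy sketch with illustrative sizes:

```python
# NHD vs HND are the same block viewed with token and head axes swapped.
import numpy as np

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 128

nhd = np.zeros((num_blocks, block_size, num_kv_heads, head_size))
hnd = nhd.transpose(0, 2, 1, 3)  # swap token and head axes -> HND view

assert hnd.shape == (num_blocks, num_kv_heads, block_size, head_size)
# A materialized HND buffer makes each head's tokens contiguous:
hnd_contig = np.ascontiguousarray(hnd)
```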
FlashInfer requires three auxiliary tensors to describe the paged KV structure:
# FlashInferMetadataBuilder._compute_flashinfer_kv_metadata()
def _compute_flashinfer_kv_metadata(self, num_blocks_np, seq_lens_np,
block_table_tensor, num_reqs, page_size):
# 1. paged_kv_indptr: cumulative sum of blocks per request
np.cumsum(num_blocks_np, dtype=np.int32,
out=self.paged_kv_indptr.np[1:num_reqs+1])
# 2. paged_kv_indices: flattened list of physical block IDs
_copy_page_indices_kernel[(num_reqs,)](
paged_kv_indices, block_table_tensor,
block_table_tensor.stride(0), paged_kv_indptr,
)
# 3. paged_kv_last_page_len: how many valid tokens in last block
paged_kv_last_page_len_np = seq_lens_np % page_size
self.paged_kv_last_page_len.np[:num_reqs] = np.where(
(paged_kv_last_page_len_np == 0) & (seq_lens_np != 0),
page_size, paged_kv_last_page_len_np,
)
| GPU Generation | SM | Default Backend | Fallback | Key Feature |
|---|---|---|---|---|
| Blackwell (B200, B100) | SM100 | FlashInfer + TRT-LLM kernels | FlashAttention | HND layout, FP8+FP4 fused output quant |
| Hopper (H100, H200) | SM90 | FlashAttention (FA3) | FlashInfer | AOT scheduling, full CUDA graph |
| Ampere (A100, A10) | SM80 | FlashAttention (FA2) | FlashInfer | Packed GQA, uniform-batch CG only |
| Any (fallback) | Any | Triton Attention | FlexAttention | Pure Triton kernel, portable |
| GPU Generation | SM | Prefill Backend | Decode Backend (Dense) | Block Size | FP8 KV Cache |
|---|---|---|---|---|---|
| Blackwell | SM100 | FlashInfer MLA / FlashAttn MLA | FlashInfer MLA > CUTLASS MLA > FlashMLA | 64 / 128 | Yes (FlashInfer preferred for FP8) |
| Hopper | SM90 | FlashAttn MLA | FlashMLA > FlashAttn MLA | 64 | Yes (FlashMLA FP8 + dense path) |
| Ampere | SM80 | FlashAttn MLA | Triton MLA (fallback) | any*16 | Yes (dequant in kernel) |
Decode backend order within each row follows the MLA branch of _get_backend_priorities().
| Backend | CG Support Level | Meaning |
|---|---|---|
| FlashAttention (FA3) | ALWAYS | Full CG for all batch types (decode, mixed prefill-decode) |
| FlashAttention (FA2) | UNIFORM_BATCH | CG works for uniform-length batches only (not mixed) |
| FlashInfer (with TRT-LLM) | UNIFORM_BATCH | CG works for decode and uniform prefill |
| FlashInfer (without TRT-LLM) | UNIFORM_SINGLE_TOKEN_DECODE | CG only for pure single-token decode |
| FlashMLA | UNIFORM_BATCH | Persistent CUDA graph buffers for tile scheduler metadata |
| CUTLASS MLA | UNIFORM_SINGLE_TOKEN_DECODE | Decode-only CUDA graph with 128 MB workspace |
| Triton MLA | UNIFORM_SINGLE_TOKEN_DECODE | Default (inherited from MLACommonMetadataBuilder) |
| Variable | Effect |
|---|---|
| VLLM_KV_CACHE_LAYOUT | Force NHD or HND layout (auto-detected per backend otherwise) |
| VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE | FlashInfer workspace size in bytes (default varies by platform) |
| VLLM_BATCH_INVARIANT | Force deterministic output (limits splits to 1, enlarges workspace to 2 GB) |
| --attention-backend | Override auto-selection with a specific backend name |
| FORCE_NUM_KV_SPLITS | Override CUTLASS MLA split-K count (default: auto-detect) |
| File | Purpose | Key Classes |
|---|---|---|
| v1/attention/selector.py | Entry point for backend selection | get_attn_backend() |
| v1/attention/backends/registry.py | Enum of all backend class paths | AttentionBackendEnum |
| platforms/cuda.py | GPU-specific priority ordering | _get_backend_priorities() |
| v1/attention/backends/flash_attn.py | FlashAttention 2/3 implementation | FlashAttentionBackend, FlashAttentionImpl |
| v1/attention/backends/flashinfer.py | FlashInfer + TRT-LLM implementation | FlashInferBackend, FlashInferImpl |
| v1/attention/backends/utils.py | KV cache layout, per-layer params, local attention | get_kv_cache_layout(), PerLayerParameters |
| v1/attention/backends/mla/flashmla.py | FlashMLA for DeepSeek (SM90/100) | FlashMLABackend, FlashMLAImpl |
| v1/attention/backends/mla/cutlass_mla.py | CUTLASS MLA for Blackwell only | CutlassMLABackend, CutlassMLAImpl |
| v1/attention/backends/mla/triton_mla.py | Triton MLA fallback (any GPU) | TritonMLABackend, TritonMLAImpl |
| model_executor/layers/attention/mla_attention.py | MLA common base classes + algorithm docs | MLACommonBackend, MLACommonImpl |