vLLM V1 GPUModelRunner Deep Dive

Attention Backends, Sampling Pipeline, CUDA Graphs, and Memory Management -- traced through actual source code from vllm/v1/worker/


1. Architecture Overview

execute_model() Pipeline -- Four Key Stages

1. _update_states() -- add/remove requests in the persistent batch
2. _prepare_inputs() -- build GPU input tensors
3. model.forward() -- transformer layers, with CUDA graph replay where possible
4. sample_tokens() -- greedy / top-k / top-p sampling

Two-phase split: execute_model() returns None, then sample_tokens() runs separately so grammar processing can overlap with the forward pass.

vLLM V1 restructures the worker-side execution around a persistent batch paradigm. Instead of rebuilding input tensors from scratch each iteration, the GPUModelRunner keeps a long-lived InputBatch on the CPU, only applying deltas (adds/removes) between steps. This eliminates the O(batch_size) per-step overhead that dominated V0.
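
The delta-update idea can be sketched in a few lines. This is a conceptual toy, not vLLM's actual API: rows live in preallocated arrays, adding a request writes one row, and removing a request swaps the last row into the hole, so per-step cost scales with the number of changed requests rather than the batch size.

```python
import numpy as np

class PersistentBatch:
    """Toy persistent batch: preallocated arrays + swap-with-last removal."""

    def __init__(self, max_num_reqs: int, max_model_len: int):
        self.token_ids = np.zeros((max_num_reqs, max_model_len), dtype=np.int32)
        self.req_ids: list[str] = []  # row index -> request id

    def add_request(self, req_id: str, prompt: list[int]) -> int:
        row = len(self.req_ids)
        self.req_ids.append(req_id)
        self.token_ids[row, : len(prompt)] = prompt  # write into preallocated row
        return row

    def remove_request(self, req_id: str) -> None:
        row, last = self.req_ids.index(req_id), len(self.req_ids) - 1
        if row != last:
            # Condense: move the last row into the vacated slot
            self.token_ids[row] = self.token_ids[last]
            self.req_ids[row] = self.req_ids[last]
        self.req_ids.pop()
```

Consecutive steps touch only the added/removed rows; everything else stays in place, ready for the next slice copy to GPU.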

[Diagram: the execute_model() pipeline -- _update_states() (remove finished requests, add new/resumed ones) feeds _prepare_inputs() (build positions and token_ids, commit the block table to GPU), then build_attn_metadata() (FlashAttention / FlashInfer, cascade and TRTLLM paths), then model.forward() (transformer layers, CUDA graph replay). compute_logits() multiplies hidden[logits_indices] by lm_head, and sample_tokens() performs greedy / temperature / top-k,p sampling. The SamplingMetadata copy and grammar bitmask application happen CPU-side, overlapped with the GPU forward pass; _update_states_after_model_execute() does the final bookkeeping before ModelRunnerOutput is returned.]
Key V1 Innovation: execute_model() now returns None and stores ephemeral state in ExecuteModelState. The engine then calls sample_tokens(grammar_output) separately, allowing structured-output grammar bitmasks to be applied between the forward pass and sampling. This two-phase design enables overlap between GPU computation and CPU-side grammar processing.
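
The protocol can be modeled with a toy stand-in class (not vLLM's API) to make the handshake concrete: phase one stashes state and returns None, the engine does CPU-side grammar work, and phase two consumes the state plus an optional bitmask.

```python
class TwoPhaseRunner:
    """Toy model of the execute_model()/sample_tokens() split."""

    def __init__(self):
        self._state = None

    def execute_model(self, logits):
        self._state = logits  # stand-in for ExecuteModelState
        return None           # signal: call sample_tokens() next

    def sample_tokens(self, bitmask=None):
        logits, self._state = self._state, None
        if bitmask is not None:
            # Grammar bitmask applied between forward pass and sampling
            logits = [x if keep else float("-inf")
                      for x, keep in zip(logits, bitmask)]
        return max(range(len(logits)), key=logits.__getitem__)  # greedy argmax
```

With logits [1.0, 3.0, 2.0], sample_tokens() picks token 1; masking token 1 out via the bitmask forces token 2 instead, which is exactly the leverage structured output needs between the two phases.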

Source File Map

Core Runner

vllm/v1/worker/gpu_model_runner.py -- 6000+ lines. Contains GPUModelRunner: the main orchestrator for all V1 GPU inference.

Persistent Batch

vllm/v1/worker/gpu_input_batch.py -- InputBatch: maintains CPU-side arrays for token IDs, sampling params, block tables.

Block Table

vllm/v1/worker/block_table.py -- BlockTable and MultiGroupBlockTable: maps virtual blocks to physical KV cache pages.

Worker

vllm/v1/worker/gpu_worker.py -- Worker: initializes device, loads model, profiles memory, delegates to GPUModelRunner.

2. execute_model() -- The Main Entry Point

The execute_model method is the heartbeat of each inference step. It receives a SchedulerOutput from the engine core and orchestrates the entire forward pass. The two-phase split (execute_model + sample_tokens) was introduced to support structured output grammars that need to compute bitmasks in parallel with the forward pass.

# gpu_model_runner.py line 3755
def execute_model(self,
    scheduler_output: "SchedulerOutput",
    intermediate_tensors: IntermediateTensors | None = None,
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors | None:

    with record_function_or_nullcontext("gpu_model_runner: preprocess"):
        # Phase 1: Update persistent batch states
        deferred_state_corrections_fn = self._update_states(scheduler_output)

        # Phase 2: Build GPU input tensors from persistent batch
        logits_indices, spec_decode_metadata = self._prepare_inputs(
            scheduler_output, num_scheduled_tokens_np)

        # Phase 3: Build attention metadata for the chosen backend
        attn_metadata, spec_decode_common_attn_metadata = (
            self._build_attention_metadata(...))

    # Phase 4: Run the actual model forward
    with set_forward_context(attn_metadata, ...):
        model_output = self._model_forward(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds, **model_kwargs)

    # Phase 5: Extract hidden states, compute logits
    sample_hidden_states = hidden_states[logits_indices]
    logits = self.model.compute_logits(sample_hidden_states)

    # Store state for sample_tokens() to pick up
    self.execute_model_state = ExecuteModelState(
        scheduler_output, logits, spec_decode_metadata, ...)
    return None  # Signal: call sample_tokens() next

ExecuteModelState: The Bridge Between Phases

# gpu_model_runner.py line 372
class ExecuteModelState(NamedTuple):
    """Ephemeral cached state transferred between execute_model()
    and sample_tokens(), after execute_model() returns None."""
    scheduler_output: "SchedulerOutput"
    logits: torch.Tensor
    spec_decode_metadata: SpecDecodeMetadata | None
    spec_decode_common_attn_metadata: CommonAttentionMetadata | None
    hidden_states: torch.Tensor
    sample_hidden_states: torch.Tensor
    aux_hidden_states: list[torch.Tensor] | None
    ec_connector_output: ECConnectorOutput | None
    cudagraph_stats: CUDAGraphStat | None
    slot_mappings: dict[str, torch.Tensor] | None

sample_tokens() -- Phase 2

# gpu_model_runner.py line 4107
@torch.inference_mode
def sample_tokens(self, grammar_output: "GrammarOutput | None"):
    # Unpack ephemeral state from execute_model()
    (scheduler_output, logits, spec_decode_metadata, ...) = self.execute_model_state
    self.execute_model_state = None

    # Apply grammar bitmask if structured output is enabled
    if grammar_output is not None:
        apply_grammar_bitmask(scheduler_output, grammar_output,
                              self.input_batch, logits)

    # Actually sample
    sampler_output = self._sample(logits, spec_decode_metadata)

    # Async bookkeeping: update output_token_ids, num_computed_tokens
    self._update_states_after_model_execute(
        sampler_output.sampled_token_ids, scheduler_output)

3. InputBatch -- Persistent CPU-Side State

[Diagram: InputBatch dual-array pattern -- a pinned CPU torch tensor (temperature_cpu_tensor) exposes a zero-copy numpy view (temperature_cpu) for fast writes; a non_blocking copy_() then mirrors it into the GPU tensor (temperature) read by the sampler kernel. The pattern is repeated for temperature, top_p, top_k, frequency_penalties, presence_penalties, and repetition_penalties.]

InputBatch is one of V1's key optimizations. Rather than reconstructing every tensor from scratch each step, it maintains pre-allocated numpy/torch arrays sized to max_num_reqs x max_model_len. Requests are inserted/removed via index manipulation, and only the relevant slices are copied to GPU.

Why This Matters: In V0, each step rebuilt input tensors for the entire batch. With a batch of 256 requests, that's 256 allocations + copies. In V1, consecutive batches share ~95% of the same requests, so only deltas are applied. The persistent arrays also enable pinned memory for async CPU-to-GPU transfers, overlapping data prep with GPU computation.

Pre-allocated Buffers

# gpu_input_batch.py line 81
class InputBatch:
    def __init__(self, max_num_reqs, max_model_len, ...):
        # Token IDs: [max_num_reqs, max_model_len] -- NOT pinned (too large)
        self.token_ids_cpu_tensor = torch.zeros(
            (max_num_reqs, max_model_len), device="cpu", dtype=torch.int32,
            pin_memory=False)
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()

        # Sequence lengths: [max_num_reqs] -- PINNED for fast GPU copy
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs,), device="cpu", dtype=torch.int32,
            pin_memory=pin_memory)

        # Sampling parameters -- each has CPU + GPU mirrors
        self.temperature = torch.empty((max_num_reqs,), dtype=torch.float32, device=device)
        self.temperature_cpu_tensor = torch.empty(
            (max_num_reqs,), dtype=torch.float32, device="cpu",
            pin_memory=pin_memory)
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()

        # Top-p, top-k, penalties follow the same pattern:
        # GPU tensor + pinned CPU tensor + numpy view
        self.top_p = torch.empty(...)    # GPU
        self.top_p_cpu = ...              # numpy view of pinned CPU
        self.top_k = torch.empty(...)
        self.frequency_penalties = torch.empty(...)
        self.presence_penalties = torch.empty(...)
        self.repetition_penalties = torch.empty(...)

        # Track which reqs use which sampling type
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

Adding a Request to the Batch

# gpu_input_batch.py line 316
def add_request(self, request: CachedRequestState) -> int:
    req_index = self._register_add_request(request)

    # Copy prompt token ids into the persistent buffer
    num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
        request.prompt_token_ids, request.prompt_embeds)
    self.num_prompt_tokens[req_index] = num_prompt_tokens

    if request.prompt_token_ids is not None:
        self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids

    # Copy output token ids (for continued requests)
    start_idx = num_prompt_tokens
    end_idx = start_idx + len(request.output_token_ids)
    self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids

    # Set sampling parameters at the request index
    self.temperature_cpu[req_index] = request.sampling_params.temperature
    self.top_p_cpu[req_index] = request.sampling_params.top_p
    self.top_k_cpu[req_index] = request.sampling_params.top_k
    ...

The Dual-Array Pattern


This pattern is repeated for every sampling parameter. The numpy view enables zero-copy writes on the CPU, while the pinned CPU tensor enables DMA transfers to the GPU without involving the CPU during the copy.
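
The shared-memory relationship can be illustrated with NumPy alone (in the real code the base buffer is a pinned torch tensor and the view comes from .numpy(); the names below mirror that but are stand-ins):

```python
import numpy as np

temperature_cpu_tensor = np.zeros(8, dtype=np.float32)  # stand-in for pinned tensor
temperature_cpu = temperature_cpu_tensor.view()         # shares the same memory

temperature_cpu[3] = 0.7          # fast per-request scalar write...
print(temperature_cpu_tensor[3])  # ...lands directly in the base buffer: 0.7
# The real code then issues temperature.copy_(temperature_cpu_tensor[:n],
# non_blocking=True) to mirror the slice onto the GPU.
```

Because the view aliases the base buffer, scalar writes cost nothing extra, and the subsequent device copy reads one contiguous pinned region.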

4. _prepare_inputs() -- Building GPU Tensors

This is the performance-critical bridge between CPU state and GPU computation. It reads from the persistent InputBatch and produces the exact tensors the model needs: input_ids, positions, attention metadata, and slot mappings.

# gpu_model_runner.py line 1784
def _prepare_inputs(self, scheduler_output, num_scheduled_tokens):
    num_reqs = self.input_batch.num_reqs

    # OPTIMIZATION: Start block table GPU copy first (overlaps with CPU work)
    self.input_batch.block_table.commit_block_table(num_reqs)

    # Build request indices: [2,5,3] -> [0,0,1,1,1,1,1,2,2,2]
    req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)

    # Cumulative tokens and per-token query position
    cu_num_tokens = self._get_cumsum_and_arange(
        num_scheduled_tokens, self.query_pos.np)

    # Positions: num_computed_tokens[req] + query_offset
    positions_np = (
        self.input_batch.num_computed_tokens_cpu[req_indices]
        + self.query_pos.np[:cu_num_tokens[-1]])

    # Flatten token_ids from 2D [req, pos] to 1D via index_select
    token_indices = positions_np + req_indices * max_model_len
    torch.index_select(
        self.input_batch.token_ids_cpu_tensor.flatten(), 0,
        torch.from_numpy(token_indices),
        out=self.input_ids.cpu[:total_num_scheduled_tokens])

    # Prepare query_start_loc for attention
    self.query_start_loc.np[0] = 0
    self.query_start_loc.np[1:num_reqs+1] = cu_num_tokens
    self.query_start_loc.copy_to_gpu()

    # Compute slot_mapping via Triton kernel (GPU-side)
    self.input_batch.block_table.compute_slot_mapping(
        num_reqs, self.query_start_loc.gpu[:num_reqs+1],
        self.positions[:total_num_scheduled_tokens])
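
The index arithmetic above can be replayed in plain NumPy on a tiny example (3 requests scheduled for 2, 5, and 3 tokens this step; the computed-token counts are made up):

```python
import numpy as np

num_scheduled = np.array([2, 5, 3])
num_computed = np.array([10, 0, 7])   # tokens already in the KV cache per request
max_model_len = 32

req_indices = np.repeat(np.arange(3), num_scheduled)
# -> [0 0 1 1 1 1 1 2 2 2]

cu = np.cumsum(num_scheduled)                       # [2 7 10]
query_offsets = np.arange(cu[-1]) - np.repeat(cu - num_scheduled, num_scheduled)
# -> [0 1 0 1 2 3 4 0 1 2] : each token's offset within its request's chunk

positions = num_computed[req_indices] + query_offsets
# -> [10 11 0 1 2 3 4 7 8 9]

# Flatten [req, pos] into 1D indices over token_ids_cpu.flatten()
token_indices = positions + req_indices * max_model_len
```

This is the whole trick: one np.repeat, one cumsum, and one gather replace a Python loop over requests.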

Block Table and Slot Mapping

The BlockTable maintains a 2D mapping of [request, block_index] to physical block IDs. The slot mapping kernel converts (position) into (block_id * block_size + offset) for each token, which tells the attention kernel exactly where to read/write KV cache entries.
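
A NumPy reference for what the kernel computes per token, over a single request's block-table row (the block IDs are made up for illustration):

```python
import numpy as np

def compute_slots(block_table_row: np.ndarray, positions: np.ndarray,
                  block_size: int) -> np.ndarray:
    block_ids = block_table_row[positions // block_size]    # which physical block
    return block_ids * block_size + positions % block_size  # slot in KV cache

block_table_row = np.array([7, 3, 9])   # logical block i -> physical block id
positions = np.array([0, 15, 16, 40])
print(compute_slots(block_table_row, positions, block_size=16))
# -> [112 127  48 152]
```

Position 16, for example, crosses into logical block 1 (physical block 3), so its slot is 3 * 16 + 0 = 48.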

# block_table.py -- Triton kernel for slot mapping
# _compute_slot_mapping_kernel computes:
#   slot = block_table[req][position // block_size] * block_size + position % block_size
# This runs on GPU to avoid CPU-GPU sync for position data.

# block_table.py line 141
def compute_slot_mapping(self, num_reqs, query_start_loc, positions):
    _compute_slot_mapping_kernel[(num_reqs + 1,)](
        num_tokens, max_num_batched_tokens,
        query_start_loc, positions,
        self.block_table.gpu, self.block_table.gpu.stride(0),
        self.block_size, self.slot_mapping.gpu,
        PAD_ID=PAD_SLOT_ID)

Hybrid Block Sizes

# block_table.py line 173 -- map_to_kernel_blocks
# When allocation block size != kernel block size:
# e.g., 32-token allocation blocks with 16-token kernel blocks
#   kv_manager_block_id 0 -> kernel blocks [0, 1]
#   kv_manager_block_id 1 -> kernel blocks [2, 3]
kernel_block_ids = (
    kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block
    + kernel_block_arange).reshape(-1)
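
The same mapping on a concrete example (block IDs chosen for illustration): each 32-token allocation block expands to two consecutive 16-token kernel blocks.

```python
import numpy as np

kv_manager_block_ids = np.array([0, 1, 5])
blocks_per_kv_block = 2            # 32-token alloc blocks / 16-token kernel blocks
kernel_block_arange = np.arange(blocks_per_kv_block)

kernel_block_ids = (
    kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block
    + kernel_block_arange).reshape(-1)
print(kernel_block_ids)  # -> [ 0  1  2  3 10 11]
```

The broadcast against kernel_block_arange does the fan-out in one vectorized expression, with no Python loop over blocks.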

5. Attention Backends -- FlashAttention vs FlashInfer

[Diagram: backend comparison at a glance -- FlashAttention (FA2/FA3): SM >= 8.0 (Ampere+), block size MultipleOf(16), unified varlen API for prefill and decode, CUDA graphs ALWAYS on FA3, AOT scheduling support on FA3. FlashInfer: SM 7.5 - 12.1 (Turing through Blackwell), block sizes 16/32/64, separate prefill and decode wrappers, TRTLLM decode on SM >= 10.0, per-batch-size CUDA graph wrappers.]

vLLM V1 supports multiple attention backends through a clean abstraction layer. Each backend provides three components: an AttentionBackend (static capabilities), an AttentionMetadataBuilder (builds per-step metadata), and an AttentionImpl (the actual kernel dispatch).

Backend Selection Architecture

[Diagram: attention backend hierarchy -- the abstract AttentionBackend (get_name(), get_builder_cls(), get_impl_cls()) is specialized by FlashAttentionBackend (FA2/FA3, SM >= 8.0, block_size MultipleOf(16)) and FlashInferBackend (SM 7.5 - 12.1, TRTLLM decode, block sizes 16/32/64), each paired with its own MetadataBuilder and AttentionImpl; FlashInfer adds a TRTLLM decode path on SM 10.0+ (Blackwell).]

FlashAttention Metadata

# flash_attn.py line 200
@dataclass
class FlashAttentionMetadata:
    num_actual_tokens: int       # Tokens excluding padding
    max_query_len: int          # Max query length in batch
    query_start_loc: torch.Tensor   # Cumulative query lengths [num_reqs+1]
    max_seq_len: int            # Max full sequence length
    seq_lens: torch.Tensor          # Per-request sequence lengths
    block_table: torch.Tensor       # Block table [num_reqs, max_blocks]
    slot_mapping: torch.Tensor      # Where to write K/V entries

    # Cascade attention (shared prefix optimization)
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: torch.Tensor | None

    # FA3 AOT scheduling metadata (pre-computed work distribution)
    scheduler_metadata: torch.Tensor | None
    max_num_splits: int = 0      # 0 = use FA3 heuristics

FlashInfer Metadata -- Split Prefill/Decode

# flashinfer.py line 486
@dataclass
class FlashInferMetadata:
    num_actual_tokens: int
    slot_mapping: torch.Tensor
    q_data_type: torch.dtype

    num_decodes: int              # Number of decode requests
    num_decode_tokens: int
    num_prefills: int             # Number of prefill requests
    num_prefill_tokens: int

    # Separate dispatch paths for prefill and decode
    prefill: FIPrefill | TRTLLMPrefill | None
    decode: FIDecode | TRTLLMDecode | None

    # Cascade attention (entire batch, distinct from prefill/decode)
    use_cascade: bool
    cascade_wrapper: MultiLevelCascadeAttentionWrapper | None

Comparison Table

| Feature | FlashAttention (FA2/FA3) | FlashInfer |
|---|---|---|
| Compute Capability | SM >= 8.0 (Ampere+) | SM 7.5 - 12.1 (Turing through Blackwell) |
| KV Cache Layout | (2, num_blocks, block_size, num_kv_heads, head_size) | (num_blocks, 2, block_size, num_kv_heads, head_size) |
| Block Sizes | MultipleOf(16) -- any multiple of 16 | Fixed: 16, 32, 64 |
| FP8 KV Cache | FA3 only (fp8_e4m3) | fp8_e4m3, fp8_e5m2 |
| CUDA Graph Support | FA3: ALWAYS; FA2: UNIFORM_BATCH | UNIFORM_BATCH (with TRTLLM decode) or UNIFORM_SINGLE_TOKEN_DECODE |
| KV Cache Update | Separate (reshape_and_cache_flash) | Separate (forward_includes_kv_cache_update = False) |
| Prefill/Decode Split | Unified varlen API handles both | Separate BatchPrefillWrapper + BatchDecodeWrapper |
| AOT Scheduling | FA3 only -- pre-computes work distribution | No AOT; uses TRTLLM fast_decode_plan |
| Cascade Attention | Via prefix KV + suffix KV split | Via MultiLevelCascadeAttentionWrapper |
| Attention Sinks | FA3 on SM >= 9.0 | Via TRTLLM on SM >= 10.0 (Blackwell) |

FlashInfer's TRTLLM Decode Path

On Blackwell (SM 10.0+), FlashInfer can use the TensorRT-LLM attention kernels for decode, which perform better on uniform decode batches. The builder checks for this support at initialization:

# flashinfer.py line 627
# Prefer TRTLLM attention for decoding in all cases.
# This allows AttentionCGSupport.UNIFORM_BATCH mode.
self.use_trtllm_decode_attention = can_use_trtllm_attention(
    self.num_qo_heads, self.num_kv_heads)

# FlashInfer builder tracks separate wrappers per batch size for CG
if self.enable_cuda_graph:
    self._decode_wrappers_cudagraph: dict[int, BatchDecodeWrapper] = {}
    self._decode_cudagraph_max_bs = (1 + num_spec_tokens) * max_num_reqs

FlashInfer Decode Wrapper Setup

# flashinfer.py line 745
def _get_decode_wrapper(self, batch_size, use_cudagraph=False):
    if use_cudagraph:
        decode_wrapper = self._decode_wrappers_cudagraph.get(batch_size)
    else:
        decode_wrapper = self._decode_wrapper

    if decode_wrapper is None:
        decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
            self._get_workspace_buffer(),
            get_kv_cache_layout(),
            use_cuda_graph=use_cudagraph,
            paged_kv_indptr_buffer=paged_kv_indptr,  # Pre-allocated for CG
            paged_kv_indices_buffer=paged_kv_indices,
            paged_kv_last_page_len_buffer=paged_kv_last_page_len,
            use_tensor_cores=True)  # Always use tensor cores
    return decode_wrapper

6. Sampling Pipeline

Sampling Funnel -- From Logits to Token

1. Raw logits (vocab_size floats) from compute_logits(hidden_states)
2. log_softmax -- raw logprobs saved BEFORE any modifications
3. Logits processors -- allowed_ids mask, bad_words, penalties
4. Temperature scaling -- logits / temp, in place
5. Top-k filter -- keep the k highest logits
6. Top-p nucleus -- keep tokens with cumulative probability <= p
7. Sample the token

Greedy path: argmax, skipping temperature/top-k/top-p entirely. Mixed batches merge per request via torch.where(temp < 1e-5, greedy, random).

The Sampler in V1 follows a strict ordering of operations to ensure correctness. Logprobs are computed before any modifications to logits, and the sampling itself uses a clean split between greedy and random paths.

Sampler.forward() -- Step by Step

# sampler.py line 67
def forward(self, logits, sampling_metadata, ...):
    # Step 1: Compute raw logprobs BEFORE any modifications
    if num_logprobs is not None:
        if logprobs_mode == "raw_logprobs":
            raw_logprobs = self.compute_logprobs(logits)  # log_softmax
        elif logprobs_mode == "raw_logits":
            raw_logprobs = logits.to(torch.float32)

    # Step 2: Convert to float32
    logits = logits.to(torch.float32)

    # Step 3-6: Apply processors and penalties
    logits = self.apply_logits_processors(logits, sampling_metadata, ...)

    # Step 7: Sample the next token
    sampled, processed_logprobs = self.sample(logits, sampling_metadata)
    sampled = sampled.long()  # FlashInfer returns int32, need int64 for indexing

    # Step 8: Gather top-k logprobs
    if num_logprobs:
        logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, sampled)

    # Return as GPU tensors (async copy to CPU happens later)
    return SamplerOutput(
        sampled_token_ids=sampled.to(torch.int32).unsqueeze(-1),
        logprobs_tensors=logprobs_tensors)

The sample() Method -- Greedy vs Random

# sampler.py line 147
def sample(self, logits, sampling_metadata):
    greedy_sampled = None
    # Fast path: batch contains at least one greedy request
    if not sampling_metadata.all_random:
        greedy_sampled = logits.argmax(dim=-1).view(-1)
        if sampling_metadata.all_greedy:
            return greedy_sampled, None

    # Apply temperature (in-place division).
    # Greedy requests get temp=1.0 to avoid division by zero.
    temp = torch.where(sampling_metadata.temperature < 1e-5,
                       1.0, sampling_metadata.temperature)
    logits.div_(temp.unsqueeze(dim=1))

    # Apply argmax-invariant logits processors (e.g., min_p)
    for processor in sampling_metadata.logitsprocs.argmax_invariant:
        logits = processor.apply(logits)

    # Top-k and/or Top-p sampling via dedicated sampler
    random_sampled, processed_logprobs = self.topk_topp_sampler(
        logits, sampling_metadata.generators,
        sampling_metadata.top_k, sampling_metadata.top_p)

    # Merge: use greedy for temp < epsilon, random otherwise
    if greedy_sampled is None:
        return random_sampled, processed_logprobs

    sampled = torch.where(
        sampling_metadata.temperature < 1e-5,
        greedy_sampled, random_sampled, out=greedy_sampled)
    return sampled, processed_logprobs

Logits Processing Order

# sampler.py line 266 -- apply_logits_processors
def apply_logits_processors(self, logits, sampling_metadata, ...):
    # 1. Allowed token IDs whitelist (masked_fill_ with -inf)
    if sampling_metadata.allowed_token_ids_mask is not None:
        logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))

    # 2. Bad words exclusion
    if bad_words_token_ids:
        apply_bad_words(logits, bad_words_token_ids, output_token_ids)

    # 3. Non-argmax-invariant processors (min_tokens, logit_bias)
    for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
        logits = processor.apply(logits)

    # 4. Penalties (repetition, frequency, presence)
    logits = self.apply_penalties(logits, sampling_metadata, output_token_ids)
    return logits

Greedy Sampling

Simple argmax(dim=-1), used when temperature < 1e-5; no distribution sampling needed. V1 computes the argmax whenever the batch contains any greedy request (as a fallback for mixed batches), then uses torch.where to select per request.

Temperature Scaling

In-place logits.div_(temp.unsqueeze(1)). For greedy requests in mixed batches, temperature is set to 1.0 to avoid division by zero (the greedy result is used anyway via torch.where).

Top-k / Top-p

Handled by TopKTopPSampler which uses optimized kernels. Top-k first filters to k candidates, then top-p (nucleus) further filters by cumulative probability. Per-request generators support seeded sampling.
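
The nucleus step can be sketched in NumPy for a single request; vLLM's TopKTopPSampler does this with fused GPU kernels, but the math is the same:

```python
import numpy as np

def top_p_filter(logits: np.ndarray, p: float) -> np.ndarray:
    """Keep the smallest set of tokens whose cumulative probability >= p."""
    shifted = logits - logits.max()                # numerically stable softmax
    probs = np.exp(shifted) / np.exp(shifted).sum()
    order = np.argsort(-probs)                     # descending probability
    cdf = np.cumsum(probs[order])
    cutoff = int(np.searchsorted(cdf, p)) + 1      # always keeps the top token
    filtered = np.full_like(logits, -np.inf)
    filtered[order[:cutoff]] = logits[order[:cutoff]]
    return filtered

logits = np.log(np.array([0.5, 0.3, 0.15, 0.05]))
kept = np.isfinite(top_p_filter(logits, p=0.75))
# tokens 0 and 1 cover cumulative probability 0.8 >= 0.75; the rest become -inf
```

Sampling then proceeds over the surviving logits only; top-k is the same idea with a fixed count instead of a probability mass.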

Penalties

Applied via apply_all_penalties(): repetition penalty multiplies logits of repeated tokens, frequency penalty subtracts by count, presence penalty subtracts by boolean presence. All use vectorized operations.
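
A single-request NumPy sketch of the three penalties (vLLM applies them batched on GPU; the function and parameter names here are illustrative):

```python
import numpy as np

def apply_penalties(logits, output_token_ids, rep_pen, freq_pen, pres_pen):
    counts = np.bincount(output_token_ids, minlength=logits.shape[0])
    seen = counts > 0
    # Repetition: shrink positive logits / amplify negative ones for seen tokens
    logits = np.where(seen & (logits > 0), logits / rep_pen, logits)
    logits = np.where(seen & (logits <= 0), logits * rep_pen, logits)
    logits = logits - freq_pen * counts   # frequency: scales with occurrence count
    logits = logits - pres_pen * seen     # presence: flat, once per seen token
    return logits

out = apply_penalties(np.array([2.0, -1.0, 0.5, 0.0]),
                      output_token_ids=[0, 0, 1],
                      rep_pen=2.0, freq_pen=0.1, pres_pen=0.5)
# out is approximately [0.3, -2.6, 0.5, 0.0]
```

Token 0 (sampled twice) pays all three penalties; token 2 (never sampled) is untouched.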

7. CUDA Graph Capture and Replay

CUDA Graph -- Capture Phase vs Replay Phase

Capture phase (warmup):
1. Run a dummy forward pass
2. Begin capture with torch.cuda.CUDAGraph
3. Record the kernel sequence as a graph
4. Pre-allocate FA3 scheduler_metadata
Cost: one-time, per batch-size configuration.

Replay phase (inference):
1. Update input tensors in place
2. graph.replay() -- a single launch
3. All kernels execute as one unit
4. Zero kernel launch overhead
Benefit: eliminates per-kernel CPU dispatch cost.

Modes: FULL (entire forward pass in one graph), PIECEWISE (per layer via torch.compile), NONE (eager execution, used for profiling).

CUDA graphs eliminate kernel launch overhead by recording a sequence of GPU operations and replaying them as a single unit. vLLM V1 captures graphs for common batch configurations during warmup, then replays them during inference. The tricky part: attention metadata changes every step (different seq_lens, block_tables), so the system must carefully manage which fields are graph-compatible.

Graph Capture Modes

| CUDAGraphMode | Description | When Used |
|---|---|---|
| NONE | No CUDA graphs; eager execution | Profiling, first step with KV scale calibration |
| FULL | Entire forward pass captured in one graph | FA3 (ALWAYS support), uniform decode batches |
| PIECEWISE | Graph captured per layer via torch.compile | Mixed prefill-decode with FA2 |

Capture Process

# gpu_model_runner.py line 6077
def _capture_cudagraphs(self, batch_descriptors, cudagraph_runtime_mode):
    for batch_desc in batch_descriptors:
        # Only capture ubatched graphs for FULL + uniform decode
        allow_microbatching = (
            self.parallel_config.use_ubatching
            and cudagraph_runtime_mode == CUDAGraphMode.FULL
            and uniform_decode
            and check_ubatch_thresholds(...))

        self._warmup_and_capture(
            batch_desc,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            allow_microbatching=allow_microbatching)
        torch.accelerator.synchronize()

AttentionCGSupport Levels

# Each attention backend declares its CUDA graph compatibility

# FlashAttention FA3: full cudagraph for all cases
_cudagraph_support = AttentionCGSupport.ALWAYS

# FlashAttention FA2: only works for uniform decode batches
# (special max_query_len=1 packed-GQA handling breaks mixed prefill-decode)
_cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

# FlashInfer with TRTLLM: UNIFORM_BATCH
# FlashInfer without TRTLLM: UNIFORM_SINGLE_TOKEN_DECODE
if has_trtllm_support:
    return AttentionCGSupport.UNIFORM_BATCH
else:
    return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

FA3 AOT Scheduling for CUDA Graphs

# flash_attn.py line 326 -- Pre-allocate scheduler_metadata for CUDA graphs
if self.use_full_cuda_graph and self.aot_schedule:
    # FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
    # The +1 is for tile_count_semaphore (synchronization)
    # 4 slots per batch element: prepare_varlen + dynamic_split
    #   + sort_batches + head_swizzle
    max_batch_size = max(max_num_seqs, max_cudagraph_capture_size or 0)
    self.scheduler_metadata = torch.zeros(
        1 + round_up(max_batch_size, 4) * 4,
        dtype=torch.int32, device=self.device)

    # Upper bound splits for pre-allocation of intermediate buffers
    self.max_num_splits = (
        self.attention_config.flash_attn_max_num_splits_for_cuda_graph)

Why max_num_splits matters: Setting max_num_splits > 1 during CUDA graph capture forces FA3 to pre-allocate larger intermediate buffers [num_splits, num_heads, num_tokens, head_size]. This ensures the graph has enough memory for any runtime configuration, but increases peak memory usage. Outside of graph capture, max_num_splits = 0 lets FA3 use its own heuristics.
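
The buffer-size arithmetic from the capture code, as plain Python (round_up here is the usual ceil-to-multiple helper; the wrapper function name is illustrative):

```python
def round_up(x: int, multiple: int) -> int:
    """Round x up to the nearest multiple."""
    return (x + multiple - 1) // multiple * multiple

def fa3_scheduler_metadata_len(max_batch_size: int) -> int:
    # 1 slot for the tile_count_semaphore + 4 int32 slots per padded batch element
    return 1 + round_up(max_batch_size, 4) * 4

print(fa3_scheduler_metadata_len(6))  # -> 33  (6 pads to 8; 1 + 8 * 4)
```

Padding to a multiple of 4 means the same buffer serves any batch size up to the padded bound, which is what makes it safe to bake into a captured graph.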

FlashInfer CUDA Graph -- Per-Batch-Size Wrappers

Unlike FlashAttention, FlashInfer requires a separate decode wrapper per batch size when using CUDA graphs. This is because the wrapper's plan() method pre-allocates fixed buffers sized to the batch:

# flashinfer.py line 568
if self.enable_cuda_graph:
    # One BatchDecodeWithPagedKVCacheWrapper per batch size
    self._decode_wrappers_cudagraph: dict[int, BatchDecodeWrapper] = {}
    self._decode_cudagraph_max_bs = (1 + num_spec_tokens) * max_num_reqs

8. GPU Memory Pre-allocation and Budget Calculation

GPU Memory Layout -- Stacked Allocation

From address zero up to gpu_memory_utilization × total, the allocation stacks: model weights (tag="weights"), then activations (profiled peak), then CUDA graph buffers (estimated), then the KV cache, which fills whatever budget remains; the rest of the device stays reserved.

available_kv_cache = (gpu_memory_utilization × total) - weights - peak_activations - cudagraph_estimate

vLLM's memory management follows a strict protocol: profile peak usage, calculate remaining budget, then allocate the KV cache to fill exactly that budget. This happens during initialization in Worker.determine_available_memory().
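
The budget arithmetic with hypothetical numbers (an 80 GiB device; every figure below is made up for illustration):

```python
GiB = 1 << 30
total_gpu_memory = 80 * GiB
gpu_memory_utilization = 0.90

weights_memory = 14 * GiB       # tagged "weights" at load time
torch_peak_increase = 6 * GiB   # profiled activation peak
non_torch_increase = 1 * GiB    # CUDA context, NCCL buffers, ...
cudagraph_estimate = 2 * GiB

requested_memory = int(gpu_memory_utilization * total_gpu_memory)
non_kv_cache_memory = non_torch_increase + torch_peak_increase + weights_memory
available_kv_cache = requested_memory - non_kv_cache_memory - cudagraph_estimate
print(available_kv_cache / GiB)  # -> 49.0
```

Everything left after weights, activations, and graph buffers becomes KV cache, which is why raising gpu_memory_utilization directly increases the number of cacheable tokens.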

Budget calculation (Worker.determine_available_memory):

requested_memory    = gpu_memory_utilization * total_gpu_memory
non_kv_cache_memory = non_torch_increase + torch_peak_increase + weights_memory
available_kv_cache  = requested_memory - non_kv_cache_memory - cudagraph_estimate

Memory Profiling Code

# gpu_worker.py line 331
def determine_available_memory(self) -> int:
    # Execute a forward pass with dummy inputs to profile peak usage
    with memory_profiling(
        self.init_snapshot,
        weights_memory=int(self.model_runner.model_memory_usage),
    ) as profile_result:
        self.model_runner.profile_run()

        profile_torch_peak = torch.accelerator.memory_stats(self.device).get(
            "allocated_bytes.all.peak", 0)

        # Profile CUDA graph memory if graphs will be captured
        cudagraph_memory_estimate = 0
        if not self.model_config.enforce_eager and not current_platform.is_rocm():
            cudagraph_memory_estimate = self.model_runner.profile_cudagraph_memory()

    # Calculate budget
    self.available_kv_cache_memory_bytes = (
        self.requested_memory
        - profile_result.non_kv_cache_memory
        - cudagraph_memory_estimate)

profile_run() -- What Gets Profiled

# gpu_model_runner.py line 5712
def profile_run(self):
    # 1. Profile multimodal encoder if applicable
    if self.supports_mm_inputs:
        dummy_encoder_outputs = self.model.embed_multimodal(**batched_dummy_mm_inputs)
        # Cache them to simulate peak encoder cache usage
        for i, output in enumerate(dummy_encoder_outputs):
            self.encoder_cache[f"tmp_{i}"] = output

    # 2. Run dummy forward with max_num_tokens (worst-case activation memory)
    hidden_states, last_hidden_states = self._dummy_run(
        self.max_num_tokens, is_profile=True)

    # 3. Run dummy sampler (for its memory overhead)
    if get_pp_group().is_last_rank:
        output = self._dummy_sampler_run(last_hidden_states)

    self._sync_device()
    del hidden_states, output
    self.encoder_cache.clear()
    gc.collect()

9. Executor Hierarchy

Executors sit between the engine core and workers, handling the distributed execution strategy. The hierarchy provides three deployment modes: single-GPU, multi-process (same node), and Ray-based (multi-node).

[Diagram: executor hierarchy -- EngineCore drives an abstract Executor with three implementations: UniProcExecutor (single GPU, in-process Worker + GPUModelRunner), MultiprocExecutor (multi-GPU, one worker process per GPU communicating over a shared-memory MessageQueue), and RayExecutor (multi-node Ray worker actors, coordinated via gRPC and the Ray object store).]

UniProcExecutor -- Single GPU

# uniproc_executor.py line 26
class UniProcExecutor(Executor):
    def _init_executor(self):
        self.driver_worker = WorkerWrapperBase(rpc_rank=0)
        distributed_init_method, rank, local_rank = self._distributed_args()
        kwargs = dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank, rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=True)

        self.driver_worker.init_worker(all_kwargs=[kwargs])
        self.driver_worker.init_device()
        self.driver_worker.load_model()

    def execute_model(self, scheduler_output, non_block=False):
        # Direct method call -- no IPC overhead
        output = self.collective_rpc(
            "execute_model", args=(scheduler_output,),
            non_block=non_block, single_value=True)
        return output

    # Back in _init_executor: async scheduling uses a small thread pool
    # for output processing when batches can overlap
    if self.max_concurrent_batches > 1:
        self.async_output_thread = ThreadPoolExecutor(max_workers=1)

MultiprocExecutor -- Multi-GPU Same Node

# multiproc_executor.py line 96
class MultiprocExecutor(Executor):
    supports_pp: bool = True

    def _init_executor(self):
        tp_size, pp_size, pcp_size = self._get_parallel_sizes()
        assert self.world_size == tp_size * pp_size * pcp_size

        # Use loopback address for local communication
        distributed_init_method = get_distributed_init_method(
            get_loopback_ip(), get_open_port())

        # Set up MessageQueue for efficient SchedulerOutput broadcasting
        self.rpc_broadcast_mq = MessageQueue(
            self.world_size, self.local_world_size,
            max_chunk_bytes=max_chunk_bytes)

        # Spawn worker processes (one per GPU)
        # Each worker process runs Worker.init_device() + Worker.load_model()
        # Communication via shared memory MessageQueue
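
The max_chunk_bytes parameter hints at how large SchedulerOutput payloads travel: the serialized bytes are split into fixed-size chunks that fit the shared-memory ring buffer. A hedged pure-Python sketch of just that chunking step (the real MessageQueue manages a shared-memory ring buffer with reader synchronization; chunk_payload and reassemble are illustrative names):

```python
import pickle


def chunk_payload(obj, max_chunk_bytes):
    """Serialize obj and split it into pieces of at most max_chunk_bytes."""
    data = pickle.dumps(obj)
    return [data[i:i + max_chunk_bytes]
            for i in range(0, len(data), max_chunk_bytes)]


def reassemble(chunks):
    """Inverse on the worker side: join the chunks and deserialize."""
    return pickle.loads(b"".join(chunks))


scheduler_output = {"req_ids": list(range(100)), "tokens": [1, 2, 3] * 50}
chunks = chunk_payload(scheduler_output, max_chunk_bytes=64)
assert all(len(c) <= 64 for c in chunks)
assert reassemble(chunks) == scheduler_output
```

Chunking keeps any single write bounded, so one oversized SchedulerOutput cannot stall the ring buffer for all readers at once.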

Worker -- The GPU-Side Orchestrator

# gpu_worker.py line 105
class Worker(WorkerBase):
    def __init__(self, vllm_config, local_rank, rank, ...):
        # Configure float32 matmul precision
        precision = envs.VLLM_FLOAT32_MATMUL_PRECISION
        torch.set_float32_matmul_precision(precision)

    def init_device(self):
        # Set CUDA device, init distributed, take memory snapshot
        self.device = torch.device(f"cuda:{self.local_rank}")
        init_worker_distributed_environment(...)

        # Construct the model runner
        self.model_runner = GPUModelRunner(self.vllm_config, self.device)

        # Take baseline memory snapshot for budget calculation
        self.init_snapshot = MemorySnapshot(device=self.device)
        self.requested_memory = request_memory(self.init_snapshot, self.cache_config)

    def load_model(self):
        with self._maybe_get_memory_pool_context(tag="weights"):
            self.model_runner.load_model()

    def determine_available_memory(self) -> int:
        # Profile model, compute KV cache budget
        ...
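
The elided determine_available_memory boils down to arithmetic over memory snapshots: the KV cache gets whatever remains of the gpu_memory_utilization budget after baseline usage and the profiled peak (weights plus activations) are subtracted. A hedged sketch with made-up numbers (kv_cache_budget and its parameters are illustrative, not vLLM's exact API):

```python
def kv_cache_budget(total_gpu_bytes, gpu_memory_utilization,
                    baseline_used_bytes, peak_used_bytes):
    """KV cache bytes = requested budget minus non-KV usage.

    baseline_used_bytes: memory already in use before loading the model
    peak_used_bytes: peak usage observed while profiling a forward pass
        at maximum batch size (weights + activations, no KV cache yet).
    """
    requested = int(total_gpu_bytes * gpu_memory_utilization)
    non_kv = peak_used_bytes - baseline_used_bytes
    return max(0, requested - non_kv)


GiB = 1 << 30
budget = kv_cache_budget(
    total_gpu_bytes=80 * GiB,       # e.g. an 80 GB accelerator
    gpu_memory_utilization=0.9,     # user-facing knob
    baseline_used_bytes=1 * GiB,    # snapshot taken in init_device()
    peak_used_bytes=31 * GiB,       # after the profiling run
)
# 0.9 * 80 GiB = 72 GiB requested; 30 GiB non-KV -> 42 GiB for KV cache
```

Profiling against the real model at maximum batch size, rather than estimating activation memory analytically, is what makes this budget safe across model architectures.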

Sleep/Wake Protocol

vLLM V1 supports a sleep mode in which the worker releases GPU memory (e.g., for resource sharing between models). The Worker supports two sleep levels:

# gpu_worker.py line 157
def sleep(self, level=1):
    # Level 1: offload weights only (tag="weights")
    # Level 2: save all buffers to CPU, offload everything
    if level == 2:
        self._sleep_saved_buffers = {
            name: buffer.cpu().clone()
            for name, buffer in model.named_buffers()}

    allocator = CuMemAllocator.get_instance()
    allocator.sleep(offload_tags=("weights",) if level == 1 else tuple())

def wake_up(self, tags=None):
    allocator.wake_up(tags)
    # Restore saved buffers
    for name, buffer in model.named_buffers():
        if name in self._sleep_saved_buffers:
            buffer.data.copy_(self._sleep_saved_buffers[name].data)
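
The two levels can be summarized as a small state machine: level 1 offloads "weights"-tagged allocations to CPU (discarding the KV cache), while level 2 keeps nothing, which is why the Worker snapshots buffers by hand first. A loose pure-Python sketch of that protocol, with plain dicts standing in for the CuMemAllocator's tagged allocations (class and attribute names here are illustrative):

```python
class SleepableWorker:
    """Loose sketch: dicts stand in for tagged GPU allocations."""

    BUFFER_NAMES = ("rope.inv_freq",)   # illustrative named buffers

    def __init__(self):
        self.gpu = {"weights": {"lm_head": "W0", "rope.inv_freq": "F0"},
                    "kv_cache": {"blocks": "KV"}}
        self._cpu_backup = {}           # allocator-managed CPU copies
        self._sleep_saved_buffers = {}  # Worker's manual snapshot (level 2)

    def sleep(self, level=1):
        if level == 2:
            # Level 2 discards everything, so buffers are saved by hand.
            self._sleep_saved_buffers = {
                n: self.gpu["weights"][n] for n in self.BUFFER_NAMES}
        offload_tags = ("weights",) if level == 1 else tuple()
        for tag in list(self.gpu):
            if tag in offload_tags:
                self._cpu_backup[tag] = self.gpu[tag]  # offload to CPU
            self.gpu[tag] = {}                         # free GPU memory

    def wake_up(self):
        self.gpu.update(self._cpu_backup)   # restore offloaded allocations
        self._cpu_backup = {}
        # Copy back manually saved buffers (a no-op after level-1 sleep).
        self.gpu["weights"].update(self._sleep_saved_buffers)


w = SleepableWorker()
w.sleep(level=1); w.wake_up()
assert w.gpu["weights"]["lm_head"] == "W0"   # weights survive level 1

w2 = SleepableWorker()
w2.sleep(level=2); w2.wake_up()
assert "lm_head" not in w2.gpu["weights"]    # level 2: weights need reload
assert w2.gpu["weights"]["rope.inv_freq"] == "F0"  # buffers restored
```

In both levels the KV cache is simply discarded and must be re-initialized after waking, which is why sleep is meant for idle periods rather than mid-request pauses.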

Summary of Key Design Decisions

Persistent Batch

InputBatch maintains pre-allocated arrays across steps. Only deltas are applied -- dramatically reducing per-step CPU overhead compared to V0's per-step reconstruction.

Two-Phase Execution

execute_model() + sample_tokens() split allows grammar bitmask computation to overlap with GPU forward, and enables future optimizations like speculative decoding interleaving.

Backend Abstraction

Clean AttentionBackend / MetadataBuilder / Impl separation means FlashAttention, FlashInfer, and future backends (e.g., Triton-based) can coexist. The system even supports different backends for different layer groups.
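
The three-role split can be sketched as a small class hierarchy: the Backend names its builder and impl classes, the MetadataBuilder turns per-step batch shape info into backend-specific metadata, and the Impl consumes that metadata to run the kernel. A minimal hedged sketch of the pattern (the Toy* classes and method signatures are illustrative, not vLLM's exact interfaces):

```python
from abc import ABC, abstractmethod


class AttentionBackend(ABC):
    """Static entry point: names the builder and impl for this backend."""

    @staticmethod
    @abstractmethod
    def get_builder_cls(): ...

    @staticmethod
    @abstractmethod
    def get_impl_cls(): ...


class AttentionMetadataBuilder(ABC):
    @abstractmethod
    def build(self, num_tokens, seq_lens):
        """Turn per-step batch shape info into kernel-ready metadata."""


class AttentionImpl(ABC):
    @abstractmethod
    def forward(self, q, k, v, metadata): ...


# A toy backend wiring the three roles together.
class ToyBuilder(AttentionMetadataBuilder):
    def build(self, num_tokens, seq_lens):
        return {"num_tokens": num_tokens, "max_seq_len": max(seq_lens)}


class ToyImpl(AttentionImpl):
    def forward(self, q, k, v, metadata):
        return f"attn over {metadata['num_tokens']} tokens"


class ToyBackend(AttentionBackend):
    @staticmethod
    def get_builder_cls():
        return ToyBuilder

    @staticmethod
    def get_impl_cls():
        return ToyImpl


builder = ToyBackend.get_builder_cls()()
meta = builder.build(num_tokens=7, seq_lens=[3, 4])
out = ToyBackend.get_impl_cls()().forward(None, None, None, meta)
```

Because the runner only ever touches the abstract interface, mixing FlashAttention for some layer groups with FlashInfer for others is a matter of instantiating a different backend per group.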

Overlap Everywhere

Block table GPU copy starts before CPU input prep finishes. Sampling output copies to CPU via a separate CUDA stream. Slot mapping runs as a Triton kernel on GPU while CPU computes positions.

CUDA Graph Flexibility

Three CG modes (NONE, FULL, PIECEWISE) adapt to backend capabilities. FA3 gets ALWAYS support; FA2 and FlashInfer fall back to UNIFORM_BATCH or single-token decode only.

Memory-First Design

Memory is profiled by running the actual model at max batch size, then subtracting from the user's gpu_memory_utilization budget. CUDA graph memory is separately estimated to avoid double-counting.

File Reference Quick Index

File -- Key Classes/Functions -- Lines Read
vllm/v1/worker/gpu_model_runner.py -- GPUModelRunner, execute_model, _update_states, _prepare_inputs, _sample, _capture_cudagraphs -- 6000+
vllm/v1/worker/gpu_worker.py -- Worker, init_device, determine_available_memory, sleep/wake_up -- ~500
vllm/v1/worker/gpu_input_batch.py -- InputBatch, CachedRequestState, add_request, SamplingMetadata -- ~350
vllm/v1/worker/block_table.py -- BlockTable, MultiGroupBlockTable, compute_slot_mapping, map_to_kernel_blocks -- ~300
vllm/v1/attention/backends/flash_attn.py -- FlashAttentionBackend, FlashAttentionMetadata, FlashAttentionMetadataBuilder -- ~600
vllm/v1/attention/backends/flashinfer.py -- FlashInferBackend, FlashInferMetadata, FlashInferMetadataBuilder, TRTLLM paths -- ~1050
vllm/v1/sample/sampler.py -- Sampler, forward, sample, apply_temperature, apply_logits_processors -- ~320
vllm/v1/executor/uniproc_executor.py -- UniProcExecutor, collective_rpc, execute_model -- ~150
vllm/v1/executor/multiproc_executor.py -- MultiprocExecutor, MessageQueue IPC, FutureWrapper -- ~150