Life of a Request in vLLM V1

Tracing a single HTTP chat-completion request end-to-end through vLLM's multi-process architecture -- from FastAPI handler to GPU forward pass and back, with actual source code.

Source: vllm/v1/ -- live codebase analysis

0. TL;DR -- The Journey at a Glance

HTTP POST → Tokenize → ZMQ IPC → Schedule → GPU Forward → Update → ZMQ Back → Detokenize → SSE Stream

API Process (P0) ⇄ ZMQ IPC ⇄ EngineCore (P1) ⇄ GPU
Key insight: vLLM V1 splits work across two OS processes. The API process handles HTTP, tokenization, and detokenization. The EngineCore process runs the GPU scheduler + model. They communicate over ZeroMQ + msgpack, eliminating Python GIL contention.
Step | Process | Component                       | What Happens
-----|---------|---------------------------------|-------------------------------------------------------------
1    | API     | FastAPI Router                  | HTTP POST /v1/chat/completions received
2    | API     | OpenAIServingChat               | Parse request, apply chat template, tokenize prompt
3    | API     | AsyncLLM.generate()             | Create EngineCoreRequest, register with OutputProcessor
4    | API     | AsyncMPClient                   | Serialize via msgpack, send over ZMQ ROUTER socket
5    | Core    | EngineCoreProc input thread     | Receive from ZMQ DEALER, deserialize, push to input_queue
6    | Core    | Scheduler.schedule()            | Allocate KV blocks, create token budget, build SchedulerOutput
7    | Core    | Executor.execute_model()        | Send SchedulerOutput to Worker process(es)
8    | Core    | GPUModelRunner.execute_model()  | Build attention metadata, run model forward, sample tokens
9    | Core    | Scheduler.update_from_output()  | Update request state, detect stop conditions, build EngineCoreOutputs
10   | Core    | EngineCoreProc output thread    | Serialize EngineCoreOutputs via msgpack, send over ZMQ PUSH
11   | API     | AsyncLLM output_handler         | Receive outputs, run OutputProcessor (detokenize)
12   | API     | FastAPI StreamingResponse       | Yield SSE chunks back to client

1. Multi-Process Architecture Diagram

External Client
    HTTP Client -- POST /v1/chat/completions with {"messages": [...]}

API Process (P0) -- asyncio event loop; handles HTTP, tokenization, detokenization
    FastAPI + Uvicorn (api_server.py, chat_completion/api_router.py)
    OpenAIServingChat (chat_completion/serving.py) -- render + tokenize
    AsyncLLM (v1/engine/async_llm.py) -- InputProcessor + OutputProcessor; generate() + output_handler()

ZeroMQ + msgpack IPC
    ZMQ ROUTER -- requests ↓ (API → Core)
    ZMQ PUSH/PULL -- outputs ↑ (Core → API)
    MsgpackEncoder (serial_utils.py) -- zero-copy tensors

EngineCore Process (P1) -- tight busy loop + IO threads
    Input Thread -- ZMQ DEALER → input_queue
    EngineCoreProc.run_busy_loop() (v1/engine/core.py) -- schedule() → execute_model() → update_from_output()
    Output Thread -- output_queue → ZMQ PUSH
    Scheduler (v1/core/sched/scheduler.py) -- unified token budget
    KVCacheManager (v1/core/kv_cache_manager.py) -- block alloc + prefix cache
    Worker + GPUModelRunner (v1/worker/gpu_worker.py, v1/worker/gpu_model_runner.py) -- InputBatch + forward + sample

GPU (CUDA)
    Model forward pass, attention kernels, sampling

2. Step 1 -- HTTP Entry: FastAPI Router

Client -- POST /v1/chat/completions → FastAPI Router (@with_cancellation) → OpenAIServingChat (chat template + tokenize) → engine_client.generate() → AsyncLLM

File: vllm/entrypoints/openai/chat_completion/api_router.py

Every request begins at a FastAPI route. The create_chat_completion handler is decorated with @with_cancellation (for client disconnect handling) and @load_aware_call.

@router.post("/v1/chat/completions", ...)
@with_cancellation
@load_aware_call
async def create_chat_completion(
    request: ChatCompletionRequest, raw_request: Request
):
    handler = chat(raw_request)  # Get OpenAIServingChat from app state
    generator = await handler.create_chat_completion(request, raw_request)

    if isinstance(generator, ChatCompletionResponse):
        return JSONResponse(content=generator.model_dump(), ...)
    return StreamingResponse(content=generator, media_type="text/event-stream")

File: vllm/entrypoints/openai/chat_completion/serving.py

The OpenAIServingChat.create_chat_completion() method applies the chat template, tokenizes the prompt, builds sampling parameters, then calls engine_client.generate().

async def create_chat_completion(self, request, raw_request):
    # Apply chat template and tokenize
    result = await self.render_chat_request(request)
    conversation, engine_inputs = result

    request_id = f"chatcmpl-{self._base_request_id(...)}"
    sampling_params = request.to_sampling_params(max_tokens, ...)

    # engine_client is AsyncLLM -- calls generate() below
    generator = self.engine_client.generate(
        prompt=engine_inputs,
        sampling_params=sampling_params,
        request_id=request_id,
        lora_request=lora_request,
        trace_headers=trace_headers,
        data_parallel_rank=data_parallel_rank,
    )
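The render + tokenize step can be sketched with a toy chat template and a toy tokenizer. Both are hypothetical stand-ins: real templates come from the model's tokenizer config (typically a Jinja chat_template), and real tokenizers use BPE or similar.

```python
def render_chat(messages: list[dict[str, str]]) -> str:
    # Hypothetical template: tag each turn with its role, then open the
    # assistant turn so the model continues from there.
    parts = [f"<|{m['role']}|>{m['content']}" for m in messages]
    return "".join(parts) + "<|assistant|>"

def toy_tokenize(text: str) -> list[int]:
    # Stand-in tokenizer: one id per character.
    return [ord(c) for c in text]

prompt = render_chat([{"role": "user", "content": "Hi"}])
assert prompt == "<|user|>Hi<|assistant|>"
assert toy_tokenize("Hi") == [72, 105]
```

The token IDs, not the rendered text, are what ultimately cross the process boundary.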

3. Step 2 -- AsyncLLM: Tokenize + Queue

Input: EngineInput (raw prompt) → Processing (AsyncLLM): InputProcessor (tokenize), OutputProcessor (register request), RequestOutputCollector (create queue) → Output: EngineCoreRequest (prompt_token_ids + params)

File: vllm/v1/engine/async_llm.py -- The API-Facing Engine

AsyncLLM is the beating heart of the API process. It holds three critical components:

class AsyncLLM(EngineClient):
    def __init__(self, vllm_config, executor_class, ...):
        # Convert EngineInput --> EngineCoreRequest (tokenization)
        self.input_processor = InputProcessor(self.vllm_config, renderer)

        # Converts EngineCoreOutputs --> RequestOutput (detokenization)
        self.output_processor = OutputProcessor(
            renderer.tokenizer,
            log_stats=self.log_stats,
            stream_interval=self.vllm_config.scheduler_config.stream_interval,
        )

        # EngineCore (starts the engine in background process)
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
        )

generate() -- The Main Entry Point for Every Request

async def generate(self, prompt, sampling_params, request_id, ...):
    """
    Main function called by the API server to kick off a request
      * 1) Making an AsyncStream corresponding to the Request.
      * 2) Processing the Input.
      * 3) Adding the Request to the Detokenizer.
      * 4) Adding the Request to the EngineCore (separate process).
    """
    q = await self.add_request(request_id, prompt, sampling_params, ...)

    # The output_handler task pushes items into the queue.
    # This task pulls from the queue and yields to caller.
    finished = False
    while not finished:
        out = q.get_nowait() or await q.get()
        finished = out.finished
        if out is not STREAM_FINISHED:
            yield out

add_request() -- Tokenize and Dispatch

async def add_request(self, request_id, prompt, params, ...):
    # Convert Input --> EngineCoreRequest (includes tokenization)
    request = self.input_processor.process_inputs(
        request_id, prompt, params,
        supported_tasks=await self.get_supported_tasks(),
        arrival_time=arrival_time,
        lora_request=lora_request,
    )

    # Create output collector for this request
    queue = RequestOutputCollector(params.output_kind, request.request_id)

    await self._add_request(request, prompt_text, None, 0, queue)
    return queue

async def _add_request(self, request, prompt, parent_req, index, queue):
    # Register with OutputProcessor (this process - for detokenization later)
    self.output_processor.add_request(request, prompt, parent_req, index, queue)

    # Send to EngineCore (separate process) via ZMQ
    await self.engine_core.add_request_async(request)
Key data structure -- EngineCoreRequest: This msgspec Struct is what crosses the process boundary. It contains the tokenized prompt_token_ids, sampling_params, and metadata -- but NOT the raw text (that stays in the API process for detokenization).
# vllm/v1/engine/__init__.py
class EngineCoreRequest(
    msgspec.Struct,
    array_like=True,   # compact serialization
    omit_defaults=True, # skip default values to save bytes
    gc=False,            # disable cyclic GC tracking
):
    request_id: str
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec] | None
    sampling_params: SamplingParams | None
    pooling_params: PoolingParams | None
    arrival_time: float
    lora_request: LoRARequest | None
    cache_salt: str | None
    data_parallel_rank: int | None
    prompt_embeds: torch.Tensor | None = None
    client_index: int = 0
    priority: int = 0
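The `array_like=True` option trades field names for positions: each message is encoded as an array instead of a map, so field names are never repeated on the wire. The effect can be sketched with a dependency-free stand-in (json instead of msgpack, a hypothetical `ToyRequest` instead of EngineCoreRequest):

```python
import json
from dataclasses import astuple, dataclass, fields

@dataclass
class ToyRequest:
    # Hypothetical stand-in for EngineCoreRequest
    request_id: str
    prompt_token_ids: list[int]
    priority: int = 0

def encode_dict(req: ToyRequest) -> bytes:
    # Map-style encoding: field names repeated in every message
    return json.dumps({f.name: getattr(req, f.name) for f in fields(req)}).encode()

def encode_array(req: ToyRequest) -> bytes:
    # array_like-style encoding: positions replace field names
    return json.dumps(list(astuple(req))).encode()

req = ToyRequest("chatcmpl-1", [101, 2023, 2003])
assert len(encode_array(req)) < len(encode_dict(req))
```

The decoder must know the field order, which msgspec derives from the Struct definition; both processes import the same class, so the schemas always agree.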

4. Step 3 -- ZeroMQ IPC: Crossing the Process Boundary

API Process (AsyncMPClient)                              EngineCore Process (IO threads)
    ROUTER --- EngineCoreRequest (msgpack) ---→ DEALER
    PULL   ←--- EngineCoreOutputs (msgpack) --- PUSH

File: vllm/v1/engine/core_client.py -- MPClient (ZMQ Client Side)

The AsyncMPClient sends requests to EngineCore over a ZMQ ROUTER socket. The EngineCore process connects back with DEALER sockets. Messages are serialized with msgpack for speed.

class MPClient(EngineCoreClient):
    def __init__(self, asyncio_mode, vllm_config, executor_class, ...):
        # ZMQ setup
        sync_ctx = zmq.Context(io_threads=2)
        self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx

        # ROUTER socket for sending requests (API -> Core)
        self.input_socket = make_zmq_socket(
            self.ctx, addresses.inputs[0], zmq.ROUTER, bind=True
        )
        # PULL socket for receiving outputs (Core -> API)
        self.resources.output_socket = make_zmq_socket(
            self.ctx, addresses.outputs[0], zmq.PULL
        )

        # Serialization with optional tensor IPC for multimodal
        self.encoder = MsgpackEncoder(oob_tensor_consumer=tensor_ipc_sender)
        self.decoder = MsgpackDecoder(EngineCoreOutputs)

File: vllm/v1/engine/core.py -- EngineCoreProc Input Thread (Core Side)

In the EngineCore process, a dedicated input thread reads from ZMQ and pushes into a Python queue. This thread releases the GIL during ZMQ I/O, allowing the main thread to run GPU operations concurrently.

def process_input_sockets(self, input_addresses, coord_input_address,
                          identity, ready_event):
    """Input socket IO thread."""

    add_request_decoder = MsgpackDecoder(
        EngineCoreRequest, oob_tensor_provider=self.tensor_ipc_receiver
    )

    # ZMQ DEALER sockets connect to the API ROUTER
    input_sockets = [
        make_zmq_socket(ctx, addr, zmq.DEALER, identity=identity, bind=False)
        for addr in input_addresses
    ]

    while True:
        for input_socket, _ in poller.poll():
            type_frame, *data_frames = input_socket.recv_multipart(copy=False)
            request_type = EngineCoreRequestType(bytes(type_frame.buffer))

            if request_type == EngineCoreRequestType.ADD:
                req = add_request_decoder.decode(data_frames)
                request = self.preprocess_add_request(req)

            # Push to input queue for core busy loop
            self.input_queue.put_nowait((request_type, request))
Why ZMQ + msgpack? Python's GIL means a single process cannot truly run asyncio I/O and GPU computation in parallel. By splitting into two processes with ZMQ IPC (which releases the GIL during socket operations), vLLM V1 achieves true overlap: the API process can tokenize/detokenize while the EngineCore process runs the GPU.
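The IO-thread-plus-queue handoff can be sketched with the standard library. This is a toy stand-in for the DEALER recv loop, not vLLM's actual thread code: a daemon thread feeds a queue that the busy loop drains.

```python
import queue
import threading

input_queue: "queue.Queue[tuple[str, str]]" = queue.Queue()

def input_thread(messages: list[str]) -> None:
    # Stand-in for the ZMQ DEALER recv loop. Real socket I/O releases the
    # GIL, so this thread runs while the main thread drives the GPU.
    for msg in messages:
        input_queue.put_nowait(("ADD", msg))

t = threading.Thread(target=input_thread, args=(["req-0", "req-1"],), daemon=True)
t.start()
t.join()  # in vLLM the thread runs forever; we join here for determinism

received = []
while not input_queue.empty():
    received.append(input_queue.get_nowait())

assert received == [("ADD", "req-0"), ("ADD", "req-1")]
```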

5. Step 4 -- EngineCore: The Tight Busy Loop

drain queue → schedule() → execute_model() → update_from_output() → output_queue
(repeats every iteration -- tight loop, no mode switching)

File: vllm/v1/engine/core.py -- The Heart of vLLM

EngineCore's main thread runs a tight busy loop that orchestrates three core operations in sequence every iteration.

class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""
        while self._handle_shutdown():
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()
            # 2) Step the engine core and return the outputs.
            self._process_engine_step()

        raise SystemExit

_process_input_queue() -- Drain New Requests

def _process_input_queue(self):
    """Exits when an engine step needs to be performed."""

    while not self.has_work() and self.is_running():
        # Block on input queue when idle
        block = self.process_input_queue_block
        try:
            req = self.input_queue.get(block=block)
            self._handle_client_request(*req)
        except queue.Empty:
            break

    # Handle any more client requests (non-blocking drain)
    while not self.input_queue.empty():
        req = self.input_queue.get_nowait()
        self._handle_client_request(*req)
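The two-phase pattern above -- block for at least one item when idle, then drain the rest without blocking -- can be isolated into a small sketch (a toy version, not vLLM's method):

```python
import queue

def drain(q: queue.Queue, has_work: bool) -> list:
    """Two-phase drain mirroring _process_input_queue (toy version)."""
    items = []
    if not has_work:
        # Idle: block for at least one item so the loop doesn't spin
        items.append(q.get())
    # Then take whatever else has arrived, without blocking
    while not q.empty():
        items.append(q.get_nowait())
    return items

q = queue.Queue()
for i in range(3):
    q.put(i)
assert drain(q, has_work=True) == [0, 1, 2]
```

Blocking only when idle keeps latency low under load (no sleep between steps) without burning a core when there is nothing to do.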

step() -- Schedule, Execute, Update

This is the single most important function in vLLM V1. It runs every iteration of the busy loop:

def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
    """Schedule, execute, and make output."""

    # Check for any requests remaining in the scheduler
    if not self.scheduler.has_requests():
        return {}, False

    # 1. SCHEDULE: Decide which requests get tokens this step
    scheduler_output = self.scheduler.schedule()

    # 2. EXECUTE: Run model forward pass on the GPU
    future = self.model_executor.execute_model(scheduler_output, non_block=True)

    # 3. GRAMMAR: Get structured output bitmask (if needed)
    grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)

    # 4. WAIT: Block until GPU forward + sampling complete
    model_output = future.result()
    if model_output is None:
        model_output = self.model_executor.sample_tokens(grammar_output)

    # 5. UPDATE: Process results, detect stop conditions
    self._process_aborts_queue()
    engine_core_outputs = self.scheduler.update_from_output(
        scheduler_output, model_output
    )

    return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0

_process_engine_step() -- Send Results to Output Queue

def _process_engine_step(self) -> bool:
    """Called only when there are unfinished local requests."""

    # Step the engine core
    outputs, model_executed = self.step_fn()

    # Put EngineCoreOutputs into the output queue
    for output in outputs.items() if outputs else ():
        self.output_queue.put_nowait(output)

    # Post-step hook (e.g., update draft tokens for spec decode)
    self.post_step(model_executed)
    return model_executed
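The schedule → execute → update cycle can be reduced to a toy model (hypothetical, not vLLM's API): each request needs some number of tokens, each step grants at most `budget` tokens across the batch, and the loop repeats until everything finishes.

```python
def busy_loop(requests: dict[str, int], budget: int = 8) -> dict[str, list[int]]:
    """Toy schedule -> execute -> update loop.

    `requests` maps request_id -> number of tokens still to generate.
    """
    outputs: dict[str, list[int]] = {rid: [] for rid in requests}
    step = 0
    while any(n > 0 for n in requests.values()):
        # 1) SCHEDULE: share the token budget among unfinished requests
        remaining = budget
        scheduled: dict[str, int] = {}
        for rid, need in requests.items():
            grant = min(need, remaining)
            if grant:
                scheduled[rid] = grant
                remaining -= grant
        # 2) EXECUTE: pretend the model emitted `grant` tokens per request
        sampled = {rid: [step] * n for rid, n in scheduled.items()}
        # 3) UPDATE: append tokens, mark finished requests
        for rid, toks in sampled.items():
            outputs[rid].extend(toks)
            requests[rid] -= len(toks)
        step += 1
    return outputs

out = busy_loop({"a": 3, "b": 2}, budget=4)
assert [len(v) for v in out.values()] == [3, 2]
```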

6. Step 5 -- Scheduler: Unified Token Budget

Unified token budget (e.g., 1024 tokens per step):
    Prefill Req A (512 tokens, chunked) | Decode B (1 tok) | Decode C (1 tok) | Decode D (1 tok) | ... 509 more decode slots

File: vllm/v1/core/sched/scheduler.py

The V1 Scheduler uses a unified token budget that covers both prefill and decode tokens. There is no separate "prefill phase" vs "decode phase" -- each request simply has num_computed_tokens that it tries to catch up to num_tokens_with_spec.

class Scheduler(SchedulerInterface):
    def __init__(self, vllm_config, kv_cache_config, ...):
        # Scheduling constraints
        self.max_num_running_reqs = self.scheduler_config.max_num_seqs
        self.max_num_scheduled_tokens = (
            self.scheduler_config.max_num_scheduled_tokens
            if self.scheduler_config.max_num_scheduled_tokens
            else self.scheduler_config.max_num_batched_tokens
        )

        # Priority queues
        self.waiting = create_request_queue(self.policy)
        self.running: list[Request] = []

        # Create the KV cache manager
        self.kv_cache_manager = KVCacheManager(
            kv_cache_config=kv_cache_config,
            max_model_len=self.max_model_len,
            enable_caching=self.cache_config.enable_prefix_caching,
        )

schedule() -- The Scheduling Algorithm

def schedule(self) -> SchedulerOutput:
    # NOTE(woosuk) on the scheduling algorithm:
    # There's no "decoding phase" nor "prefill phase" in the scheduler.
    # Each request just has num_computed_tokens and num_tokens_with_spec.
    # At each step, the scheduler tries to assign tokens to the requests
    # so that each request's num_computed_tokens can catch up its
    # num_tokens_with_spec.

    token_budget = self.max_num_scheduled_tokens

    # First, schedule the RUNNING requests
    req_index = 0
    while req_index < len(self.running) and token_budget > 0:
        request = self.running[req_index]

        num_new_tokens = (
            request.num_tokens_with_spec
            + request.num_output_placeholders
            - request.num_computed_tokens
        )
        num_new_tokens = min(num_new_tokens, token_budget)

        # Allocate KV cache blocks for new tokens
        new_blocks = self.kv_cache_manager.allocate_slots(
            request, num_new_tokens,
            num_lookahead_tokens=self.num_lookahead_tokens,
        )

        if new_blocks is None:
            # Out of memory -- preempt lowest-priority request
            preempted_req = self.running.pop()
            self._preempt_request(preempted_req, ...)
            continue

        # Schedule the request
        num_scheduled_tokens[request.request_id] = num_new_tokens
        token_budget -= num_new_tokens
        req_index += 1

    # Then, schedule WAITING requests with remaining budget
    # (similar loop over self.waiting queue)
    ...
V1 Innovation -- Unified Token Budget: Instead of splitting iterations into prefill-only or decode-only batches, V1 packs prefill tokens and decode tokens into the same budget. A prefill request consuming 512 tokens and 512 decode requests each consuming 1 token can share the same 1024-token step. This is the foundation of "chunked prefill", which dramatically reduces time-to-first-token for concurrent requests.
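The effect of chunking on a long prompt follows directly from the budget arithmetic. A toy model (hypothetical numbers, assuming decodes are granted first):

```python
def steps_to_finish(prompt_len: int, num_decodes: int, budget: int) -> int:
    """Steps for one chunked prefill to complete while `num_decodes` decode
    requests each take 1 token per step (toy model of the V1 budget)."""
    assert budget > num_decodes, "prefill needs leftover budget to make progress"
    computed = 0
    steps = 0
    while computed < prompt_len:
        # Decodes take 1 token each; the prefill gets whatever is left
        prefill_chunk = min(prompt_len - computed, budget - num_decodes)
        computed += prefill_chunk
        steps += 1
    return steps

# A 1536-token prompt with 512 concurrent decodes in a 1024-token budget:
# each step leaves 512 tokens for prefill -> 3 steps.
assert steps_to_finish(1536, 512, 1024) == 3
```

The point of the unified budget is visible here: the 512 decode requests never stall while the long prefill runs, because both share every step.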

7. Step 6 -- KV Cache Management

GPU KV cache block pool:
    Active blocks:  Req A Blk 0 | Req A Blk 1 | Req B Blk 0 | Req C Blk 0
    Prefix cached (hash-matched):  blocks shared with earlier requests
    Free pool:  unallocated blocks

File: vllm/v1/core/kv_cache_manager.py

The KVCacheManager is responsible for allocating and freeing blocks of GPU memory for KV caches. It supports prefix caching through block hashing.

class KVCacheManager:
    def __init__(self, kv_cache_config, max_model_len, ...):
        self.coordinator = get_kv_cache_coordinator(
            kv_cache_config=kv_cache_config,
            max_model_len=self.max_model_len,
            enable_caching=self.enable_caching,
        )
        self.block_pool = self.coordinator.block_pool

    def get_computed_blocks(self, request) -> tuple[KVCacheBlocks, int]:
        """Find prefix cache hits for a request."""
        if not self.enable_caching:
            return self.empty_kv_cache_blocks, 0

        max_cache_hit_length = request.num_tokens - 1
        computed_blocks, num_new_computed_tokens = (
            self.coordinator.find_longest_cache_hit(
                request.block_hashes, max_cache_hit_length
            )
        )
        return self.create_kv_cache_blocks(computed_blocks), num_new_computed_tokens
# The KVCacheBlocks data structure
@dataclass
class KVCacheBlocks:
    blocks: tuple[Sequence[KVCacheBlock], ...]
    """
    blocks[i][j] refers to the i-th kv_cache_group
    and the j-th block of tokens.
    """

    def get_block_ids(self) -> tuple[list[int], ...]:
        return tuple([blk.block_id for blk in group] for group in self.blocks)
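Prefix caching hinges on hashing each full block together with everything before it, so a hash match guarantees the entire prefix matches. A minimal sketch of the idea (toy hashing and lookup, not vLLM's coordinator):

```python
import hashlib

BLOCK_SIZE = 4  # tokens per KV block (real block sizes are larger, e.g. 16)

def block_hashes(token_ids: list[int]) -> list[str]:
    """Hash each full block chained with its prefix; partial blocks are skipped."""
    hashes, prefix = [], b""
    full = len(token_ids) - len(token_ids) % BLOCK_SIZE
    for i in range(0, full, BLOCK_SIZE):
        prefix += str(token_ids[i:i + BLOCK_SIZE]).encode()
        hashes.append(hashlib.sha256(prefix).hexdigest())
    return hashes

def find_longest_cache_hit(hashes: list[str], cache: dict[str, int]) -> list[int]:
    """Return block ids for the longest cached prefix (toy lookup)."""
    hit = []
    for h in hashes:
        if h not in cache:
            break
        hit.append(cache[h])
    return hit

# Two prompts sharing a 4-token prefix hit the same first block
a = [1, 2, 3, 4, 5, 6, 7, 8]
b = [1, 2, 3, 4, 9, 9, 9, 9]
cache = {h: i for i, h in enumerate(block_hashes(a))}  # a's blocks: ids 0, 1
assert find_longest_cache_hit(block_hashes(b), cache) == [0]
```

Because each hash covers the whole prefix, a hit on block j implies hits on blocks 0..j-1, which is why the lookup can stop at the first miss.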

8. Step 7 -- Worker + GPUModelRunner: The Forward Pass

GPUModelRunner.execute_model() pipeline:
    1. _update_states (patch InputBatch) → 2. _prepare_inputs (attention metadata) → 3. CUDA graph (pad + select mode) → 4. forward pass (model(input_ids, ...)) → 5. sample (logits → token_ids)

File: vllm/v1/worker/gpu_worker.py -- Worker

The Worker manages the GPU device, initializes distributed communication, and hosts the GPUModelRunner. In a single-GPU setup, it runs inside the EngineCore process. With tensor parallelism, separate worker processes handle each GPU.

class Worker(WorkerBase):
    def __init__(self, vllm_config, local_rank, rank, ...):
        super().__init__(
            vllm_config=vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

File: vllm/v1/worker/gpu_model_runner.py -- GPUModelRunner

The GPUModelRunner is the most complex component. It manages the persistent InputBatch, builds attention metadata, runs the model forward pass, and samples output tokens.

class GPUModelRunner(
    LoRAModelRunnerMixin, KVConnectorModelRunnerMixin, ECConnectorModelRunnerMixin
):
    def __init__(self, vllm_config, device):
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs
        self.device = device
        self.dtype = self.model_config.dtype
        ...

execute_model() -- The GPU Forward Pass

def execute_model(
    self, scheduler_output: SchedulerOutput,
    intermediate_tensors: IntermediateTensors | None = None,
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors | None:

    num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens

    # 1. Update persistent batch states from SchedulerOutput
    deferred_state_corrections_fn = self._update_states(scheduler_output)

    if not num_scheduled_tokens:
        return EMPTY_MODEL_RUNNER_OUTPUT

    num_reqs = self.input_batch.num_reqs
    req_ids = self.input_batch.req_ids

    # 2. Prepare inputs (build attention metadata, positions, etc.)
    logits_indices, spec_decode_metadata = self._prepare_inputs(
        scheduler_output, num_scheduled_tokens_np,
    )

    # 3. Determine CUDA graph mode and padding
    (cudagraph_mode, batch_desc, should_ubatch,
     num_tokens_across_dp, cudagraph_stats,
    ) = self._determine_batch_execution_and_padding(
        num_tokens=num_tokens_unpadded,
        num_reqs=num_reqs,
        num_scheduled_tokens_np=num_scheduled_tokens_np,
        max_num_scheduled_tokens=max_num_scheduled_tokens,
        ...
    )

    # 4. Build attention metadata per layer
    # 5. Run the actual model forward pass
    # 6. Sample tokens from logits
    # 7. Return ModelRunnerOutput
V1 Innovation -- Persistent InputBatch: Unlike V0, which recreated input tensors every step, V1 maintains an InputBatch that persists across iterations. New requests are added to it, finished requests are removed, and only the deltas are applied. This avoids expensive tensor allocation/deallocation.
# vllm/v1/worker/gpu_input_batch.py
class InputBatch:
    def __init__(self, max_num_reqs, max_model_len, max_num_batched_tokens, device, ...):
        # Pre-allocated CPU tensors (persistent across steps)
        self.token_ids_cpu_tensor = torch.zeros(
            (max_num_reqs, max_model_len), device="cpu", dtype=torch.int32
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()

        self.num_computed_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs,), device="cpu", dtype=torch.int32
        )

        # Block table (maps request -> KV cache blocks)
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            ...
        )

        # Sampling metadata (pre-allocated)
        self.temperatures_cpu_tensor = torch.empty(...)
        self.top_p_cpu_tensor = torch.empty(...)
        self.top_k_cpu_tensor = torch.empty(...)
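The delta-only idea can be shown without tensors: pre-allocate a fixed number of rows once, then adding a request claims a free row and removing one releases it. A toy sketch (hypothetical class, not vLLM's InputBatch API):

```python
class ToyInputBatch:
    """Persistent batch with pre-allocated rows; only deltas are applied."""

    def __init__(self, max_num_reqs: int, max_model_len: int):
        # Allocated once, reused for the lifetime of the engine
        self.token_ids = [[0] * max_model_len for _ in range(max_num_reqs)]
        self.req_ids: list[str | None] = [None] * max_num_reqs

    def add_request(self, req_id: str, prompt: list[int]) -> int:
        row = self.req_ids.index(None)            # claim a free slot
        self.req_ids[row] = req_id
        self.token_ids[row][: len(prompt)] = prompt
        return row

    def remove_request(self, req_id: str) -> None:
        self.req_ids[self.req_ids.index(req_id)] = None  # release the slot

batch = ToyInputBatch(max_num_reqs=2, max_model_len=8)
assert batch.add_request("a", [1, 2, 3]) == 0
assert batch.add_request("b", [4, 5]) == 1
batch.remove_request("a")
assert batch.add_request("c", [6]) == 0   # slot 0 is reused, not reallocated
```

The real InputBatch does the same with pinned CPU tensors plus a block table, so each step only copies the changed rows to the GPU.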

9. Step 8 -- ModelRunnerOutput: Tokens Come Back

ModelRunnerOutput structure: req_ids[] (per-request identifiers) · sampled_token_ids[][] (num_reqs x num_generated) · logprobs (optional) · pooler_output (embedding models)

Data Structure: ModelRunnerOutput

After the forward pass and sampling, the GPU produces a ModelRunnerOutput containing the sampled token IDs for every request in the batch.

# vllm/v1/outputs.py
@dataclass
class ModelRunnerOutput:
    # [num_reqs]
    req_ids: list[str]
    req_id_to_index: dict[str, int]

    # num_reqs x num_generated_tokens
    sampled_token_ids: list[list[int]]

    # req_id -> (token_ids, logprobs, ranks)
    prompt_logprobs_dict: dict[str, LogprobsTensors | None]

    # [num_reqs, max_num_logprobs + 1]
    logprobs: LogprobsLists | None = None

    # [num_reqs, hidden_size]
    pooler_output: list[torch.Tensor | None] | None = None

10. Step 9 -- Scheduler.update_from_output(): State Update

ModelRunnerOutput (sampled tokens) → append tokens (update request state) → check stop (EOS / max_tokens / stop_str)
    finished? free KV blocks · still running? continue next step
    → EngineCoreOutputs → output_queue

File: vllm/v1/core/sched/scheduler.py

After model execution, the scheduler updates every request's state with the newly generated tokens, checks stop conditions, and produces EngineCoreOutput objects.

def update_from_output(
    self,
    scheduler_output: SchedulerOutput,
    model_runner_output: ModelRunnerOutput,
) -> dict[int, EngineCoreOutputs]:
    sampled_token_ids = model_runner_output.sampled_token_ids
    num_scheduled_tokens = scheduler_output.num_scheduled_tokens

    outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)

    for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
        request = self.requests.get(req_id)
        if request is None or request.is_finished():
            continue

        req_index = model_runner_output.req_id_to_index[req_id]
        generated_token_ids = sampled_token_ids[req_index]

        # Update request with new tokens, check stop conditions
        new_token_ids, stopped = self._update_request_with_output(
            request, generated_token_ids
        )

        finish_reason, kv_transfer_params = None, None
        if stopped:
            finish_reason = request.get_finished_reason()
            finished = self._handle_stopped_request(request)
            if finished:
                kv_transfer_params = self._free_request(request)

        # Build EngineCoreOutput for this request
        output = EngineCoreOutput(
            request_id=req_id,
            new_token_ids=new_token_ids,
            finish_reason=finish_reason,
            ...
        )
        outputs[request.client_index].append(output)
# EngineCoreOutput -- what crosses the ZMQ boundary back to the API process
class EngineCoreOutput(
    msgspec.Struct, array_like=True, omit_defaults=True, gc=False
):
    request_id: str
    new_token_ids: list[int]
    new_logprobs: LogprobsLists | None = None
    finish_reason: FinishReason | None = None
    stop_reason: int | str | None = None
    num_cached_tokens: int = 0

# Wrapped in EngineCoreOutputs for batch sending
class EngineCoreOutputs(
    msgspec.Struct, array_like=True, omit_defaults=True, gc=False
):
    engine_index: int = 0
    outputs: list[EngineCoreOutput] = []
    scheduler_stats: SchedulerStats | None = None
    timestamp: float = 0.0

11. Step 10 -- ZMQ IPC: Results Flow Back

output_queue (EngineCoreOutputs) → Output Thread (encode_into() + buffer reuse) → ZMQ PUSH (zero-copy send) ~~~→ ZMQ PULL → output_handler() asyncio task in the API process

File: vllm/v1/engine/core.py -- Output Thread

The EngineCore's output thread serializes EngineCoreOutputs with msgpack and sends them over ZMQ PUSH sockets back to the API process.

def process_output_sockets(self, output_paths, coord_output_path, engine_index):
    """Output socket IO thread."""

    encoder = MsgpackEncoder()
    reuse_buffers: list[bytearray] = []
    pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

    sockets = [
        make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)
        for output_path in output_paths
    ]

    while True:
        output = self.output_queue.get()
        client_index, outputs = output
        outputs.engine_index = engine_index

        # Reclaim buffers that zmq is finished with
        while pending and pending[-1][0].done:
            reuse_buffers.append(pending.pop()[2])

        buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
        buffers = encoder.encode_into(outputs, buffer)
        tracker = sockets[client_index].send_multipart(
            buffers, copy=False, track=True
        )
        # Keep the buffer alive until zmq reports the send complete
        pending.appendleft((tracker, outputs, buffer))
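The buffer-reuse pattern -- encode into a recycled bytearray, keep it alive while the send is in flight, reclaim it afterwards -- can be sketched without ZMQ (a toy pool; the real code uses a zmq.MessageTracker to know when a send has completed):

```python
from collections import deque

class BufferPool:
    """Reuse bytearrays across sends instead of reallocating (toy sketch)."""

    def __init__(self):
        self.free: list[bytearray] = []
        self.pending: deque[bytearray] = deque()

    def encode(self, payload: bytes) -> bytearray:
        buf = self.free.pop() if self.free else bytearray()
        buf[:] = payload                 # encode_into-style: reuse the allocation
        self.pending.appendleft(buf)     # in flight until the send completes
        return buf

    def reclaim(self) -> None:
        # Real code checks tracker.done; here every send is assumed finished
        while self.pending:
            self.free.append(self.pending.pop())

pool = BufferPool()
first = pool.encode(b"outputs-step-0")
pool.reclaim()
second = pool.encode(b"outputs-step-1")
assert first is second            # the same bytearray object was reused
```

In a steady-state decode loop, every iteration sends one EngineCoreOutputs message, so the pool quickly converges to a handful of buffers that are recycled forever, with zero per-step allocations.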

12. Step 11 -- OutputProcessor: Detokenize

EngineCoreOutputs (token IDs per request) → IncrementalDetokenizer (token IDs → text) → stop string check (abort if found) → RequestOutput (text + finish_reason) → queue.put() (wake generate())

File: vllm/v1/engine/async_llm.py -- output_handler()

Back in the API process, a background asyncio task continuously pulls outputs from EngineCore and pushes them through the OutputProcessor for detokenization.

async def output_handler():
    try:
        while True:
            # 1) Pull EngineCoreOutputs from the EngineCore
            outputs = await engine_core.get_output_async()
            num_outputs = len(outputs.outputs)

            iteration_stats = IterationStats() if (log_stats and num_outputs) else None

            # Split outputs into chunks to avoid blocking the event loop
            engine_core_outputs = outputs.outputs
            for start in range(0, num_outputs, chunk_size):
                end = start + chunk_size
                outputs_slice = engine_core_outputs[start:end]

                # 2) Process: detokenize + build RequestOutput objects
                processed_outputs = output_processor.process_outputs(
                    outputs_slice, outputs.timestamp, iteration_stats
                )

                # Allow other asyncio tasks to run between chunks
                if end < num_outputs:
                    await asyncio.sleep(0)

                # 3) Abort any reqs that finished due to stop strings
                if processed_outputs.reqs_to_abort:
                    await engine_core.abort_requests_async(
                        processed_outputs.reqs_to_abort
                    )
    except Exception as e:
        output_processor.propagate_error(e)

File: vllm/v1/engine/output_processor.py -- OutputProcessor.process_outputs()

This is the only function that loops over every request output per iteration. It performs detokenization via the IncrementalDetokenizer and produces RequestOutput objects.

class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(self, tokenizer, *, log_stats, stream_interval=1, ...):
        self.tokenizer = tokenizer
        self.request_states: dict[str, RequestState] = {}

    def process_outputs(
        self, engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: float | None = None,
        iteration_stats: IterationStats | None = None,
    ) -> OutputProcessorOutput:
        """
        Process the EngineCoreOutputs:
        1) Compute stats for logging
        2) Detokenize
        3) Create and handle RequestOutput objects
        """
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)

            # 1) Compute stats for this iteration
            self._update_stats_from_output(req_state, engine_core_output, ...)

            # 2) Detokenize new token IDs into text
            # (via IncrementalDetokenizer in RequestState)
            # ...

            # 3) Build RequestOutput and push to per-request queue
            request_output = req_state.make_request_output(
                new_token_ids=new_token_ids,
                finish_reason=finish_reason,
                ...
            )
            if request_output and req_state.queue:
                req_state.queue.put(request_output)
Note: The RequestOutputCollector uses a lightweight asyncio Event to signal the generate() coroutine. When outputs are ready, the generate() task wakes up and yields them to the FastAPI StreamingResponse handler.
# vllm/v1/engine/output_processor.py
class RequestOutputCollector:
    def __init__(self, output_kind, request_id):
        self.aggregate = output_kind == RequestOutputKind.DELTA
        self.request_id = request_id
        self.output = None
        self.ready = asyncio.Event()

    def put(self, output):
        """Non-blocking put operation."""
        if self.output is None:
            self.output = output
            self.ready.set()  # Wake up the generate() task
        elif isinstance(self.output, RequestOutput):
            self.output.add(output, aggregate=self.aggregate)  # Merge deltas

    async def get(self):
        while (output := self.output) is None:
            await self.ready.wait()
        self.output = None
        self.ready.clear()
        return output
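To see how put() and get() coordinate the two sides, here is a minimal, runnable sketch of the same pattern. MiniCollector is a simplified stand-in for RequestOutputCollector (the real class merges RequestOutput deltas; here outputs are plain strings concatenated on overwrite):

```python
import asyncio

# Simplified stand-in for RequestOutputCollector: the real class merges
# RequestOutput deltas; here outputs are strings joined when coalescing.
class MiniCollector:
    def __init__(self):
        self.output = None
        self.ready = asyncio.Event()

    def put(self, output):
        # Non-blocking: called from the output_handler task.
        if self.output is None:
            self.output = output
            self.ready.set()          # wake the waiting get()
        else:
            self.output += output     # coalesce if the consumer is behind

    async def get(self):
        while (output := self.output) is None:
            await self.ready.wait()
        self.output = None
        self.ready.clear()
        return output

async def main():
    q = MiniCollector()
    received = []

    async def producer():
        q.put("Hello")
        await asyncio.sleep(0)        # let the consumer drain
        q.put(" wor")
        q.put("ld")                   # arrives before the next get(): coalesced

    async def consumer():
        while len("".join(received)) < len("Hello world"):
            received.append(await q.get())

    await asyncio.gather(producer(), consumer())
    return received

chunks = asyncio.run(main())
```

Note how the second and third put() calls collapse into a single delivered chunk: if the generate() task falls behind, the collector aggregates deltas instead of queueing them individually.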

13. Step 12 -- HTTP Response: Streaming Back to Client

generate() yields RequestOutput --> OpenAI format (data: {"choices":[...]}) --> StreamingResponse (text/event-stream) --> HTTP client --> data: [DONE]

The generate() coroutine in AsyncLLM yields RequestOutput objects. The OpenAI serving layer wraps these into SSE-formatted JSON chunks that stream back through the FastAPI StreamingResponse.

# Back in AsyncLLM.generate():
while not finished:
    # Fast path: drain queue without await if possible
    out = q.get_nowait() or await q.get()
    finished = out.finished
    if out is not STREAM_FINISHED:
        yield out   # --> back to OpenAIServingChat --> SSE chunk

# The API server converts RequestOutput to SSE events:
# data: {"id":"chatcmpl-...","choices":[{"delta":{"content":"Hello"},...}]}
# data: {"id":"chatcmpl-...","choices":[{"delta":{"content":" world"},...}]}
# data: [DONE]
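The SSE framing itself is simple. Below is a hedged sketch of what wrapping one delta looks like; the helper name and construction are illustrative (the actual vLLM serving layer builds these chunks from its pydantic response models), but the field names follow the OpenAI chat-completions chunk format:

```python
import json

# Hypothetical helper: wraps one text delta in OpenAI-style SSE framing.
# The real vLLM serving layer builds these from pydantic response models.
def to_sse_chunk(request_id: str, delta_text: str, finish_reason=None) -> str:
    payload = {
        "id": request_id,
        "object": "chat.completion.chunk",
        "choices": [{
            "index": 0,
            "delta": {"content": delta_text},
            "finish_reason": finish_reason,
        }],
    }
    # SSE framing: each event is "data: <json>" followed by a blank line
    return f"data: {json.dumps(payload)}\n\n"

DONE = "data: [DONE]\n\n"

stream = (
    to_sse_chunk("chatcmpl-123", "Hello")
    + to_sse_chunk("chatcmpl-123", " world", "stop")
    + DONE
)
```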

14. Key V1 Innovations Summary

- Multi-Process: API + EngineCore split
- ZMQ + msgpack: GIL-free IPC
- Unified Budget: chunked prefill
- Persistent Batch: delta-only updates
- Tight Busy Loop: minimal overhead

14.1 Multi-Process Architecture

vLLM V1 splits the API server and EngineCore into separate OS processes. The API process runs the asyncio event loop (HTTP handling, tokenization, detokenization) while the EngineCore process runs the GPU scheduler and model in a tight busy loop. This eliminates GIL contention.

| Aspect | V0 | V1 |
|---|---|---|
| Process model | Single process | Multi-process (API + EngineCore) |
| IPC | In-process queues | ZeroMQ + msgpack |
| GIL impact | Asyncio and GPU compete | Fully overlapped |
| Serialization | N/A (in-process) | msgspec Structs (zero-copy where possible) |

14.2 ZeroMQ + msgpack IPC

Communication uses ZMQ sockets with msgspec for serialization. Key design choices:

# Request types are single-byte values for minimal framing overhead
class EngineCoreRequestType(enum.Enum):
    ADD = b"\x00"
    ABORT = b"\x01"
    START_DP_WAVE = b"\x02"
    UTILITY = b"\x03"
    EXECUTOR_FAILED = b"\x04"
    WAKEUP = b"\x05"

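The wire framing can be sketched as follows. This is an illustration of the idea only: json stands in for msgpack so the snippet stays dependency-free, and the helper names are invented (the real path uses MsgpackEncoder and multipart ZMQ sends):

```python
import json

# Single-byte request types, mirroring EngineCoreRequestType
ADD, ABORT = b"\x00", b"\x01"

# Sketch: a multipart ZMQ message carries the type byte as one frame and
# the serialized request as another. json stands in for msgpack here.
def encode_frames(req_type: bytes, request: dict) -> list[bytes]:
    return [req_type, json.dumps(request).encode()]

def decode_frames(frames: list[bytes]) -> tuple[bytes, dict]:
    req_type, payload = frames
    return req_type, json.loads(payload)

frames = encode_frames(ADD, {"request_id": "req-1", "prompt_token_ids": [1, 2, 3]})
req_type, request = decode_frames(frames)
```

The one-byte type frame lets the receiving I/O thread dispatch to the right decoder without parsing the payload first.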
14.3 Unified Token Budget (Chunked Prefill)

The scheduler makes no distinction between "prefill tokens" and "decode tokens". Both consume from the same token budget per step, so a long prefill is automatically chunked across iterations and interleaved with decodes:

# From the schedule() docstring:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and
# num_tokens_with_spec. At each step, the scheduler tries to assign
# tokens to the requests so that each request's num_computed_tokens
# can catch up its num_tokens_with_spec.
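The budget mechanics can be sketched in a few lines. Names and structure here are illustrative, not vLLM's actual scheduler code; the point is that prefill and decode requests draw from one shared per-step budget:

```python
# Sketch of the unified token budget (illustrative names). Each request
# tracks (num_computed, num_total); the scheduler assigns the gap,
# capped by the shared per-step budget.
def schedule_step(requests, token_budget):
    scheduled = {}
    for req_id, (num_computed, num_total) in requests.items():
        if token_budget == 0:
            break
        num_new = min(num_total - num_computed, token_budget)
        if num_new > 0:
            scheduled[req_id] = num_new
            token_budget -= num_new
    return scheduled

# One decode request (1 new token) and one long prefill (1000 prompt
# tokens) share a 512-token budget: the prefill is chunked automatically.
requests = {"decode": (10, 11), "prefill": (0, 1000)}
step1 = schedule_step(requests, token_budget=512)
# step1 == {"decode": 1, "prefill": 511}
```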

14.4 Persistent InputBatch

V1's InputBatch is a pre-allocated set of CPU tensors that persists across scheduling steps. Instead of building new batch tensors every iteration, the runner applies only the per-step deltas to these buffers:

# Key tensors in InputBatch (all pre-allocated):
self.token_ids_cpu_tensor = torch.zeros((max_num_reqs, max_model_len), dtype=torch.int32)
self.num_computed_tokens_cpu_tensor = torch.zeros((max_num_reqs,), dtype=torch.int32)
self.temperatures_cpu_tensor = torch.empty((max_num_reqs,), dtype=torch.float32)
self.top_p_cpu_tensor = torch.empty((max_num_reqs,), dtype=torch.float32)
self.block_table = MultiGroupBlockTable(max_num_reqs=max_num_reqs, ...)
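The persistence idea can be illustrated with a toy version, using Python arrays as a stand-in for the pre-allocated CPU tensors (class and method names here are invented for illustration):

```python
from array import array

# Toy persistent batch: buffers are sized once for max_num_reqs and
# mutated in place each step -- no per-step tensor construction.
class MiniInputBatch:
    def __init__(self, max_num_reqs: int, max_model_len: int):
        self.max_model_len = max_model_len
        # flat [max_num_reqs, max_model_len] token buffer
        self.token_ids = array("i", [0] * (max_num_reqs * max_model_len))
        self.num_computed = array("i", [0] * max_num_reqs)
        self.req_index = {}  # request_id -> row

    def add_request(self, req_id: str, prompt: list[int]) -> None:
        row = len(self.req_index)
        self.req_index[req_id] = row
        start = row * self.max_model_len
        self.token_ids[start : start + len(prompt)] = array("i", prompt)

    def append_token(self, req_id: str, token: int, pos: int) -> None:
        # Delta update: write one new token, bump the computed count.
        row = self.req_index[req_id]
        self.token_ids[row * self.max_model_len + pos] = token
        self.num_computed[row] = pos + 1

batch = MiniInputBatch(max_num_reqs=4, max_model_len=8)
batch.add_request("req-1", [101, 102, 103])
batch.append_token("req-1", 7, pos=3)  # one decode step
```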

14.5 The Tight Busy Loop

The core scheduling loop is deliberately minimal: schedule() -> execute_model() -> update_from_output(). There is no complex state machine or mode switching. Background threads handle I/O while the main thread focuses exclusively on the schedule-execute-update cycle.

# The entire busy loop is just these lines:
while self._handle_shutdown():
    self._process_input_queue()      # Drain new requests from ZMQ
    self._process_engine_step()      # schedule() + execute() + update()

15. Key Files Quick Reference

| Component | File | Purpose |
|---|---|---|
| HTTP Entrypoint | vllm/entrypoints/openai/api_server.py | FastAPI app, lifespan, build_async_engine_client |
| Chat Handler | vllm/entrypoints/openai/chat_completion/serving.py | OpenAIServingChat.create_chat_completion() |
| Chat Router | vllm/entrypoints/openai/chat_completion/api_router.py | POST /v1/chat/completions route |
| AsyncLLM | vllm/v1/engine/async_llm.py | API-facing engine: tokenize, queue, output_handler |
| Input Processor | vllm/v1/engine/input_processor.py | Convert EngineInput to EngineCoreRequest |
| Output Processor | vllm/v1/engine/output_processor.py | Detokenize EngineCoreOutput to RequestOutput |
| Core Client | vllm/v1/engine/core_client.py | ZMQ IPC client: AsyncMPClient, SyncMPClient, InprocClient |
| EngineCore | vllm/v1/engine/core.py | EngineCore + EngineCoreProc (busy loop, I/O threads) |
| Data Structures | vllm/v1/engine/__init__.py | EngineCoreRequest, EngineCoreOutput, EngineCoreOutputs |
| Scheduler | vllm/v1/core/sched/scheduler.py | Unified scheduler: schedule(), update_from_output() |
| Scheduler Output | vllm/v1/core/sched/output.py | SchedulerOutput, NewRequestData, CachedRequestData |
| KV Cache Manager | vllm/v1/core/kv_cache_manager.py | Block allocation, prefix caching, eviction |
| GPU Worker | vllm/v1/worker/gpu_worker.py | Device init, weight loading, profiling |
| GPU Model Runner | vllm/v1/worker/gpu_model_runner.py | Forward pass, attention metadata, sampling |
| Input Batch | vllm/v1/worker/gpu_input_batch.py | Persistent batch with pre-allocated tensors |
| Model Outputs | vllm/v1/outputs.py | ModelRunnerOutput, SamplerOutput, LogprobsLists |
| Serialization | vllm/v1/serial_utils.py | MsgpackEncoder / MsgpackDecoder |

16. Complete Request Timeline

Step 1 -- HTTP POST arrives at FastAPI

create_chat_completion() in api_router.py receives the ChatCompletionRequest.

Step 2 -- Chat Template + Tokenization (API Process)

OpenAIServingChat.render_chat_request() applies the chat template and tokenizes the conversation into prompt_token_ids.

Step 3 -- AsyncLLM.generate() begins (API Process)

InputProcessor.process_inputs() builds an EngineCoreRequest with the tokenized prompt. OutputProcessor.add_request() creates a RequestState with an IncrementalDetokenizer. A RequestOutputCollector queue is created.

Step 4 -- ZMQ Send (API --> EngineCore)

AsyncMPClient.add_request_async() serializes the EngineCoreRequest with MsgpackEncoder and sends it over the ZMQ ROUTER socket with request type byte b"\x00" (ADD).

Step 5 -- Input Thread Receives (EngineCore Process)

The input I/O thread receives the multipart ZMQ message, deserializes it with MsgpackDecoder(EngineCoreRequest), calls preprocess_add_request() to create a Request object, and pushes it to input_queue.

Step 6 -- Busy Loop Drains Queue (EngineCore Process)

_process_input_queue() drains all pending requests from the queue. EngineCore.add_request() passes the Request to Scheduler.add_request() which adds it to the waiting queue.

Step 7 -- Scheduler.schedule() (EngineCore Process)

Iterates over running requests first, then waiting requests. For each: calculates num_new_tokens, calls KVCacheManager.allocate_slots(), deducts from the token budget. Produces a SchedulerOutput with all scheduling decisions.

Step 8 -- GPUModelRunner.execute_model() (GPU)

Updates the persistent InputBatch with the scheduler's decisions via _update_states(). Prepares attention metadata, determines CUDA graph mode, runs the model forward pass, then samples output tokens. Returns ModelRunnerOutput.

Step 9 -- Scheduler.update_from_output() (EngineCore Process)

For each request: appends generated tokens, checks stop conditions (EOS, max_tokens, stop strings), updates num_computed_tokens. Frees KV cache blocks for finished requests. Builds EngineCoreOutput per request, wrapped in EngineCoreOutputs.

Step 10 -- ZMQ Send (EngineCore --> API)

The output I/O thread serializes EngineCoreOutputs with MsgpackEncoder.encode_into() (reusing buffers) and sends via ZMQ PUSH socket. The API process receives on the PULL socket.

Step 11 -- output_handler() Detokenizes (API Process)

The background asyncio task calls OutputProcessor.process_outputs() which: (a) incremental-detokenizes new token IDs into text, (b) builds RequestOutput objects, (c) pushes them into per-request RequestOutputCollector queues.

Step 12 -- SSE Response Streams to Client (API Process)

The generate() coroutine yields the RequestOutput. The OpenAI serving layer formats it as an SSE event (data: {...}\n\n). FastAPI's StreamingResponse sends it over HTTP. Steps 7-12 repeat for each token until the request finishes.

Iteration loop: Steps 6-12 happen repeatedly. Each iteration of the EngineCore busy loop processes ALL active requests in a single batch, generating one (or more, with speculative decoding) token per request. The API process continuously receives outputs and streams them. A request typically goes through hundreds of iterations before finishing.
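The iteration loop above can be mimicked with a toy simulation (names illustrative): each pass hands every active request one token until it hits its max_tokens, and finished requests drop out of the batch while others continue.

```python
# Toy simulation of the schedule-execute-update cycle: each iteration
# gives every active request one token until it reaches its max_tokens.
def run_engine(requests: dict[str, int]) -> dict[str, list[int]]:
    generated = {rid: [] for rid in requests}
    active = set(requests)
    step = 0
    while active:
        step += 1
        for rid in sorted(active):          # "schedule": all active requests
            generated[rid].append(step)     # "execute": one token per request
        # "update": detect stop conditions, drop finished requests
        active = {r for r in active if len(generated[r]) < requests[r]}
    return generated

out = run_engine({"a": 3, "b": 5})
# "a" finishes after 3 iterations; "b" keeps decoding alone for 2 more
```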

17. EngineCoreClient Variants

File: vllm/v1/engine/core_client.py

vLLM provides several client variants depending on the execution mode (three primary ones, plus data-parallel variants). The factory method selects the appropriate one:

class EngineCoreClient(ABC):
    """
    EngineCoreClient: subclasses handle different methods for pushing
        and pulling from the EngineCore for asyncio / multiprocessing.

    Subclasses:
    * InprocClient: In process EngineCore (for V0-style LLMEngine use)
    * SyncMPClient: ZMQ + background proc EngineCore (for LLM)
    * AsyncMPClient: ZMQ + background proc EngineCore w/ asyncio (for AsyncLLM)
    """

    @staticmethod
    def make_client(
        multiprocess_mode: bool,
        asyncio_mode: bool,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
    ) -> "EngineCoreClient":
        if multiprocess_mode and asyncio_mode:
            return EngineCoreClient.make_async_mp_client(
                vllm_config, executor_class, log_stats
            )
        if multiprocess_mode and not asyncio_mode:
            return SyncMPClient(vllm_config, executor_class, log_stats)
        return InprocClient(vllm_config, executor_class, log_stats)
| Client | Mode | Use Case | IPC |
|---|---|---|---|
| AsyncMPClient | Multi-process + asyncio | OpenAI API server (AsyncLLM) | ZMQ ROUTER/DEALER + PUSH/PULL |
| SyncMPClient | Multi-process + sync | Offline batch inference (LLM class) | ZMQ ROUTER/DEALER + PUSH/PULL |
| InprocClient | In-process | Debugging, V0-style LLMEngine | Direct function calls |
| DPAsyncMPClient | Data parallel + external LB | Multi-engine with external load balancer | ZMQ per DP rank |
| DPLBAsyncMPClient | Data parallel + internal LB | Multi-engine with built-in load balancing | ZMQ + DPCoordinator |

# The InprocClient is the simplest -- direct in-process calls
class InprocClient(EngineCoreClient):
    def __init__(self, *args, **kwargs):
        self.engine_core = EngineCore(*args, **kwargs)

    def get_output(self) -> EngineCoreOutputs:
        outputs, model_executed = self.engine_core.step_fn()
        self.engine_core.post_step(model_executed=model_executed)
        return outputs and outputs.get(0) or EngineCoreOutputs()

    def add_request(self, request: EngineCoreRequest) -> None:
        req, request_wave = self.engine_core.preprocess_add_request(request)
        self.engine_core.add_request(req, request_wave)

18. Threading Model in the EngineCore Process

EngineCore Process -- 3 Threads with GIL Overlap
- Input IO Thread: ZMQ recv (releases GIL) --> msgpack decode --> preprocess --> input_queue
- Main Thread (busy loop): drain input_queue --> schedule() + execute_model() --> update_from_output(); releases GIL during CUDA kernels
- Output IO Thread: output_queue --> encode (buffer reuse pool) --> ZMQ send (releases GIL)

The EngineCore process has a carefully designed three-thread architecture to maximize overlap between I/O and computation:

| Thread | Role | GIL Behavior |
|---|---|---|
| Main Thread | Runs run_busy_loop(): drains input_queue, calls step() (schedule + execute + update), pushes to output_queue | Holds GIL during Python logic; releases during GPU kernel execution (CUDA is async) |
| Input I/O Thread | Runs process_input_sockets(): reads from ZMQ DEALER, deserializes with msgpack, pushes to input_queue | Releases GIL during zmq.recv_multipart() and poller.poll() |
| Output I/O Thread | Runs process_output_sockets(): pulls from output_queue, serializes with msgpack, sends via ZMQ PUSH | Releases GIL during zmq.send_multipart() |
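The input-thread pattern reduces to a queue fed by a daemon thread while the main loop drains it. A minimal stdlib sketch (queue.Queue stands in for the ZMQ-fed input_queue; blocking queue operations release the GIL just as the ZMQ calls do):

```python
import queue
import threading

# A daemon thread feeds the queue while the main loop drains it in
# batches, like _process_input_queue() before each engine step.
input_queue: "queue.Queue[int]" = queue.Queue()

def input_thread_fn(n: int) -> None:
    for i in range(n):
        input_queue.put(i)          # blocking ops release the GIL
    input_queue.put(None)           # sentinel: no more requests

t = threading.Thread(target=input_thread_fn, args=(5,), daemon=True)
t.start()

drained = []
while True:
    item = input_queue.get()        # main loop blocks between steps
    if item is None:
        break
    drained.append(item)
    # ... the real busy loop would run schedule/execute/update here ...
t.join()
```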
# In EngineCoreProc.__init__():
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
input_thread = threading.Thread(
    target=self.process_input_sockets,
    args=(addresses.inputs, addresses.coordinator_input, identity, ready_event),
    daemon=True,
)
input_thread.start()

self.output_thread = threading.Thread(
    target=self.process_output_sockets,
    args=(addresses.outputs, ...),
    daemon=True,
)
self.output_thread.start()
Why threading works here: Even though Python has the GIL, ZMQ socket operations release it. This means the input thread can receive the next request while the main thread is running GPU kernels (which also release the GIL). The output thread can serialize and send results while the next scheduling step begins. This three-thread design is simpler than the two-process split between API and EngineCore, but follows the same principle: overlap I/O with computation by using operations that release the GIL.

19. Performance Optimization Details

Concurrent Operations Timeline (Single Step)
- Input thread: ZMQ recv (GIL free) --> decode + preprocess --> ZMQ recv next...
- Main thread: schedule() --> execute_model() -- GPU forward (GIL free) --> update()
- Output thread: encode prev --> ZMQ send (GIL free)
- GPU: attention kernels + MLP + sampling --> async copy

msgspec Structs with gc=False

All cross-process data structures use msgspec.Struct with gc=False to disable Python's cyclic garbage collector tracking. Combined with array_like=True for compact serialization and omit_defaults=True to skip default-valued fields, this minimizes both serialization overhead and GC pause times.

GC Heap Freezing

After initialization, EngineCore calls freeze_gc_heap() to mark all existing objects as permanent. This prevents the GC from scanning startup allocations (model weights, config objects) during collection cycles, significantly reducing GC pause times.

# At the end of EngineCore.__init__():

# Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections.
freeze_gc_heap()
# Enable environment variable cache
enable_envs_cache()

Zero-Copy Buffer Reuse in Output Thread

The output I/O thread reuses serialization buffers via encode_into(outputs, buffer) and tracks ZMQ message delivery with zmq.MessageTracker. Buffers are only reclaimed when ZMQ confirms the message has been sent, preventing premature deallocation of backing memory for numpy arrays and tensors.

# Output thread buffer management:
reuse_buffers: list[bytearray] = []
pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

# Reclaim buffers that zmq is finished with
while pending and pending[-1][0].done:
    reuse_buffers.append(pending.pop()[2])

buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
buffers = encoder.encode_into(outputs, buffer)
tracker = socket.send_multipart(buffers, copy=False, track=True)

Async Output Copy Stream

GPU-to-CPU copies of sampled tokens are performed on a separate CUDA stream, overlapping the copy with the next forward pass. The AsyncGPUModelRunnerOutput wrapper lazily materializes the CPU result only when needed.

# vllm/v1/worker/gpu_model_runner.py
class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
    def __init__(self, model_runner_output, sampled_token_ids, ...):
        # Initiate copy on a separate stream (non-blocking)
        with torch.cuda.stream(async_output_copy_stream):
            async_output_copy_stream.wait_stream(default_stream)
            self.sampled_token_ids_cpu = self._sampled_token_ids.to(
                "cpu", non_blocking=True
            )
            self.async_copy_ready_event.record()

    def get_output(self) -> ModelRunnerOutput:
        """Blocks until copy is finished, then returns."""
        self.async_copy_ready_event.synchronize()
        # Convert to Python lists
        valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
        ...

Output Processing Chunking

The API process splits output processing into chunks of VLLM_V1_OUTPUT_PROC_CHUNK_SIZE to avoid blocking the asyncio event loop when there are many concurrent requests. Between chunks, it yields control with await asyncio.sleep(0).

# In output_handler():
chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE

for start in range(0, num_outputs, chunk_size):
    end = start + chunk_size
    outputs_slice = engine_core_outputs[start:end]
    processed_outputs = output_processor.process_outputs(outputs_slice, ...)
    # Allow other asyncio tasks to run between chunks
    if end < num_outputs:
        await asyncio.sleep(0)

Generated from vLLM V1 source code analysis. All code snippets are from the actual repository.

Back to Overview | EngineCore Deep Dive | Scheduler Deep Dive | ModelRunner Deep Dive