Tracing a single HTTP chat-completion request end-to-end through vLLM's multi-process architecture -- from FastAPI handler to GPU forward pass and back, with actual source code.
Source: vllm/v1/ -- live codebase analysis

| Step | Process | Component | What Happens |
|---|---|---|---|
| 1 | API | FastAPI Router | HTTP POST /v1/chat/completions received |
| 2 | API | OpenAIServingChat | Parse request, apply chat template, tokenize prompt |
| 3 | API | AsyncLLM.generate() | Create EngineCoreRequest, register with OutputProcessor |
| 4 | API | AsyncMPClient | Serialize via msgpack, send over ZMQ ROUTER socket |
| 5 | Core | EngineCoreProc input thread | Receive from ZMQ DEALER, deserialize, push to input_queue |
| 6 | Core | Scheduler.schedule() | Allocate KV blocks, create token budget, build SchedulerOutput |
| 7 | Core | Executor.execute_model() | Send SchedulerOutput to Worker process(es) |
| 8 | Core | GPUModelRunner.execute_model() | Build attention metadata, run model forward, sample tokens |
| 9 | Core | Scheduler.update_from_output() | Update request state, detect stop conditions, build EngineCoreOutputs |
| 10 | Core | EngineCoreProc output thread | Serialize EngineCoreOutputs via msgpack, send over ZMQ PUSH |
| 11 | API | AsyncLLM output_handler | Receive outputs, run OutputProcessor (detokenize) |
| 12 | API | FastAPI StreamingResponse | Yield SSE chunks back to client |
Every request begins at a FastAPI route. The create_chat_completion handler is decorated with @with_cancellation (for client disconnect handling) and @load_aware_call.
@router.post("/v1/chat/completions", ...)
@with_cancellation
@load_aware_call
async def create_chat_completion(
request: ChatCompletionRequest, raw_request: Request
):
handler = chat(raw_request) # Get OpenAIServingChat from app state
generator = await handler.create_chat_completion(request, raw_request)
if isinstance(generator, ChatCompletionResponse):
return JSONResponse(content=generator.model_dump(), ...)
return StreamingResponse(content=generator, media_type="text/event-stream")
The OpenAIServingChat.create_chat_completion() method applies the chat template, tokenizes the prompt, builds sampling parameters, then calls engine_client.generate().
async def create_chat_completion(self, request, raw_request):
# Apply chat template and tokenize
result = await self.render_chat_request(request)
conversation, engine_inputs = result
request_id = f"chatcmpl-{self._base_request_id(...)}"
sampling_params = request.to_sampling_params(max_tokens, ...)
# engine_client is AsyncLLM -- calls generate() below
generator = self.engine_client.generate(
prompt=engine_input,
sampling_params=sampling_params,
request_id=sub_request_id,
lora_request=lora_request,
trace_headers=trace_headers,
data_parallel_rank=data_parallel_rank,
)
AsyncLLM is the beating heart of the API process. It holds three critical components:
class AsyncLLM(EngineClient):
def __init__(self, vllm_config, executor_class, ...):
# Convert EngineInput --> EngineCoreRequest (tokenization)
self.input_processor = InputProcessor(self.vllm_config, renderer)
# Converts EngineCoreOutputs --> RequestOutput (detokenization)
self.output_processor = OutputProcessor(
renderer.tokenizer,
log_stats=self.log_stats,
stream_interval=self.vllm_config.scheduler_config.stream_interval,
)
# EngineCore (starts the engine in background process)
self.engine_core = EngineCoreClient.make_async_mp_client(
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=self.log_stats,
)
async def generate(self, prompt, sampling_params, request_id, ...):
"""
Main function called by the API server to kick off a request
* 1) Making an AsyncStream corresponding to the Request.
* 2) Processing the Input.
* 3) Adding the Request to the Detokenizer.
* 4) Adding the Request to the EngineCore (separate process).
"""
q = await self.add_request(request_id, prompt, sampling_params, ...)
# The output_handler task pushes items into the queue.
# This task pulls from the queue and yields to caller.
finished = False
while not finished:
out = q.get_nowait() or await q.get()
finished = out.finished
if out is not STREAM_FINISHED:
yield out
async def add_request(self, request_id, prompt, params, ...):
# Convert Input --> EngineCoreRequest (includes tokenization)
request = self.input_processor.process_inputs(
request_id, prompt, params,
supported_tasks=await self.get_supported_tasks(),
arrival_time=arrival_time,
lora_request=lora_request,
)
# Create output collector for this request
queue = RequestOutputCollector(params.output_kind, request.request_id)
await self._add_request(request, prompt_text, None, 0, queue)
return queue
async def _add_request(self, request, prompt, parent_req, index, queue):
# Register with OutputProcessor (this process - for detokenization later)
self.output_processor.add_request(request, prompt, parent_req, index, queue)
# Send to EngineCore (separate process) via ZMQ
await self.engine_core.add_request_async(request)
# vllm/v1/engine/__init__.py
class EngineCoreRequest(
msgspec.Struct,
array_like=True, # compact serialization
omit_defaults=True, # skip default values to save bytes
gc=False, # disable cyclic GC tracking
):
request_id: str
prompt_token_ids: list[int] | None
mm_features: list[MultiModalFeatureSpec] | None
sampling_params: SamplingParams | None
pooling_params: PoolingParams | None
arrival_time: float
lora_request: LoRARequest | None
cache_salt: str | None
data_parallel_rank: int | None
prompt_embeds: torch.Tensor | None = None
client_index: int = 0
priority: int = 0
The AsyncMPClient sends requests to EngineCore over a ZMQ ROUTER socket. The EngineCore process connects back with DEALER sockets. Messages are serialized with msgpack for speed.
class MPClient(EngineCoreClient):
def __init__(self, asyncio_mode, vllm_config, executor_class, ...):
# ZMQ setup
sync_ctx = zmq.Context(io_threads=2)
self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx
# ROUTER socket for sending requests (API -> Core)
self.input_socket = make_zmq_socket(
self.ctx, addresses.inputs[0], zmq.ROUTER, bind=True
)
# PULL socket for receiving outputs (Core -> API)
self.resources.output_socket = make_zmq_socket(
self.ctx, addresses.outputs[0], zmq.PULL
)
# Serialization with optional tensor IPC for multimodal
self.encoder = MsgpackEncoder(oob_tensor_consumer=tensor_ipc_sender)
self.decoder = MsgpackDecoder(EngineCoreOutputs)
In the EngineCore process, a dedicated input thread reads from ZMQ and pushes into a Python queue. This thread releases the GIL during ZMQ I/O, allowing the main thread to run GPU operations concurrently.
def process_input_sockets(self, input_addresses, coord_input_address,
identity, ready_event):
"""Input socket IO thread."""
add_request_decoder = MsgpackDecoder(
EngineCoreRequest, oob_tensor_provider=self.tensor_ipc_receiver
)
# ZMQ DEALER sockets connect to the API ROUTER
input_sockets = [
make_zmq_socket(ctx, addr, zmq.DEALER, identity=identity, bind=False)
for addr in input_addresses
]
while True:
for input_socket, _ in poller.poll():
type_frame, *data_frames = input_socket.recv_multipart(copy=False)
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
if request_type == EngineCoreRequestType.ADD:
req = add_request_decoder.decode(data_frames)
request = self.preprocess_add_request(req)
# Push to input queue for core busy loop
self.input_queue.put_nowait((request_type, request))
EngineCore's main thread runs a tight busy loop that orchestrates three core operations in sequence every iteration.
class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""
def run_busy_loop(self):
"""Core busy loop of the EngineCore."""
while self._handle_shutdown():
# 1) Poll the input queue until there is work to do.
self._process_input_queue()
# 2) Step the engine core and return the outputs.
self._process_engine_step()
raise SystemExit
def _process_input_queue(self):
"""Exits when an engine step needs to be performed."""
while not self.has_work() and self.is_running():
# Block on input queue when idle
block = self.process_input_queue_block
try:
req = self.input_queue.get(block=block)
self._handle_client_request(*req)
except queue.Empty:
break
# Handle any more client requests (non-blocking drain)
while not self.input_queue.empty():
req = self.input_queue.get_nowait()
self._handle_client_request(*req)
This is the single most important function in vLLM V1. It runs every iteration of the busy loop:
def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
"""Schedule, execute, and make output."""
# Check for any requests remaining in the scheduler
if not self.scheduler.has_requests():
return {}, False
# 1. SCHEDULE: Decide which requests get tokens this step
scheduler_output = self.scheduler.schedule()
# 2. EXECUTE: Run model forward pass on the GPU
future = self.model_executor.execute_model(scheduler_output, non_block=True)
# 3. GRAMMAR: Get structured output bitmask (if needed)
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
# 4. WAIT: Block until GPU forward + sampling complete
model_output = future.result()
if model_output is None:
model_output = self.model_executor.sample_tokens(grammar_output)
# 5. UPDATE: Process results, detect stop conditions
self._process_aborts_queue()
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
def _process_engine_step(self) -> bool:
"""Called only when there are unfinished local requests."""
# Step the engine core
outputs, model_executed = self.step_fn()
# Put EngineCoreOutputs into the output queue
for output in outputs.items() if outputs else ():
self.output_queue.put_nowait(output)
# Post-step hook (e.g., update draft tokens for spec decode)
self.post_step(model_executed)
return model_executed
The V1 Scheduler uses a unified token budget that covers both prefill and decode tokens. There is no separate "prefill phase" vs "decode phase" -- each request simply has num_computed_tokens that it tries to catch up to num_tokens_with_spec.
class Scheduler(SchedulerInterface):
def __init__(self, vllm_config, kv_cache_config, ...):
# Scheduling constraints
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
self.max_num_scheduled_tokens = (
self.scheduler_config.max_num_scheduled_tokens
if self.scheduler_config.max_num_scheduled_tokens
else self.scheduler_config.max_num_batched_tokens
)
# Priority queues
self.waiting = create_request_queue(self.policy)
self.running: list[Request] = []
# Create the KV cache manager
self.kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
enable_caching=self.cache_config.enable_prefix_caching,
)
def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has num_computed_tokens and num_tokens_with_spec.
# At each step, the scheduler tries to assign tokens to the requests
# so that each request's num_computed_tokens can catch up its
# num_tokens_with_spec.
token_budget = self.max_num_scheduled_tokens
# First, schedule the RUNNING requests
req_index = 0
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
num_new_tokens = (
request.num_tokens_with_spec
+ request.num_output_placeholders
- request.num_computed_tokens
)
num_new_tokens = min(num_new_tokens, token_budget)
# Allocate KV cache blocks for new tokens
new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens,
num_lookahead_tokens=self.num_lookahead_tokens,
)
if new_blocks is None:
# Out of memory -- preempt lowest-priority request
preempted_req = self.running.pop()
self._preempt_request(preempted_req, ...)
continue
# Schedule the request
num_scheduled_tokens[request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1
# Then, schedule WAITING requests with remaining budget
# (similar loop over self.waiting queue)
...
The KVCacheManager is responsible for allocating and freeing blocks of GPU memory for KV caches. It supports prefix caching through block hashing.
class KVCacheManager:
def __init__(self, kv_cache_config, max_model_len, ...):
self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
enable_caching=self.enable_caching,
)
self.block_pool = self.coordinator.block_pool
def get_computed_blocks(self, request) -> tuple[KVCacheBlocks, int]:
"""Find prefix cache hits for a request."""
if not self.enable_caching:
return self.empty_kv_cache_blocks, 0
max_cache_hit_length = request.num_tokens - 1
computed_blocks, num_new_computed_tokens = (
self.coordinator.find_longest_cache_hit(
request.block_hashes, max_cache_hit_length
)
)
return self.create_kv_cache_blocks(computed_blocks), num_new_computed_tokens
# The KVCacheBlocks data structure
@dataclass
class KVCacheBlocks:
blocks: tuple[Sequence[KVCacheBlock], ...]
"""
blocks[i][j] refers to the i-th kv_cache_group
and the j-th block of tokens.
"""
def get_block_ids(self) -> tuple[list[int], ...]:
return tuple([blk.block_id for blk in group] for group in self.blocks)
The Worker manages the GPU device, initializes distributed communication, and hosts the GPUModelRunner. In a single-GPU setup, it runs inside the EngineCore process. With tensor parallelism, separate worker processes handle each GPU.
class Worker(WorkerBase):
def __init__(self, vllm_config, local_rank, rank, ...):
super().__init__(
vllm_config=vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
)
The GPUModelRunner is the most complex component. It manages the persistent InputBatch, builds attention metadata, runs the model forward pass, and samples output tokens.
class GPUModelRunner(
LoRAModelRunnerMixin, KVConnectorModelRunnerMixin, ECConnectorModelRunnerMixin
):
def __init__(self, vllm_config, device):
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
self.device = device
self.dtype = self.model_config.dtype
...
def execute_model(
self, scheduler_output: SchedulerOutput,
intermediate_tensors: IntermediateTensors | None = None,
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors | None:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# 1. Update persistent batch states from SchedulerOutput
deferred_state_corrections_fn = self._update_states(scheduler_output)
if not num_scheduled_tokens:
return EMPTY_MODEL_RUNNER_OUTPUT
num_reqs = self.input_batch.num_reqs
req_ids = self.input_batch.req_ids
# 2. Prepare inputs (build attention metadata, positions, etc.)
logits_indices, spec_decode_metadata = self._prepare_inputs(
scheduler_output, num_scheduled_tokens_np,
)
# 3. Determine CUDA graph mode and padding
(cudagraph_mode, batch_desc, should_ubatch,
num_tokens_across_dp, cudagraph_stats,
) = self._determine_batch_execution_and_padding(
num_tokens=num_tokens_unpadded,
num_reqs=num_reqs,
num_scheduled_tokens_np=num_scheduled_tokens_np,
max_num_scheduled_tokens=max_num_scheduled_tokens,
...
)
# 4. Build attention metadata per layer
# 5. Run the actual model forward pass
# 6. Sample tokens from logits
# 7. Return ModelRunnerOutput
# vllm/v1/worker/gpu_input_batch.py
class InputBatch:
def __init__(self, max_num_reqs, max_model_len, max_num_batched_tokens, device, ...):
# Pre-allocated CPU tensors (persistent across steps)
self.token_ids_cpu_tensor = torch.zeros(
(max_num_reqs, max_model_len), device="cpu", dtype=torch.int32
)
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.num_computed_tokens_cpu_tensor = torch.zeros(
(max_num_reqs,), device="cpu", dtype=torch.int32
)
# Block table (maps request -> KV cache blocks)
self.block_table = MultiGroupBlockTable(
max_num_reqs=max_num_reqs,
max_model_len=max_model_len,
...
)
# Sampling metadata (pre-allocated)
self.temperatures_cpu_tensor = torch.empty(...)
self.top_p_cpu_tensor = torch.empty(...)
self.top_k_cpu_tensor = torch.empty(...)
After the forward pass and sampling, the GPU produces a ModelRunnerOutput containing the sampled token IDs for every request in the batch.
# vllm/v1/outputs.py
@dataclass
class ModelRunnerOutput:
# [num_reqs]
req_ids: list[str]
req_id_to_index: dict[str, int]
# num_reqs x num_generated_tokens
sampled_token_ids: list[list[int]]
# [num_reqs, max_num_logprobs + 1]
logprobs: LogprobsLists | None = None
# req_id -> (token_ids, logprobs, ranks)
prompt_logprobs_dict: dict[str, LogprobsTensors | None]
# [num_reqs, hidden_size]
pooler_output: list[torch.Tensor | None] | None = None
After model execution, the scheduler updates every request's state with the newly generated tokens, checks stop conditions, and produces EngineCoreOutput objects.
def update_from_output(
self,
scheduler_output: SchedulerOutput,
model_runner_output: ModelRunnerOutput,
) -> dict[int, EngineCoreOutputs]:
sampled_token_ids = model_runner_output.sampled_token_ids
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
request = self.requests.get(req_id)
if request is None or request.is_finished():
continue
req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[req_index]
# Update request with new tokens, check stop conditions
new_token_ids, stopped = self._update_request_with_output(
request, generated_token_ids
)
if stopped:
finish_reason = request.get_finished_reason()
finished = self._handle_stopped_request(request)
if finished:
kv_transfer_params = self._free_request(request)
# Build EngineCoreOutput for this request
output = EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids,
finish_reason=finish_reason,
...
)
outputs[client_index].append(output)
# EngineCoreOutput -- what crosses the ZMQ boundary back to the API process
class EngineCoreOutput(
msgspec.Struct, array_like=True, omit_defaults=True, gc=False
):
request_id: str
new_token_ids: list[int]
new_logprobs: LogprobsLists | None = None
finish_reason: FinishReason | None = None
stop_reason: int | str | None = None
num_cached_tokens: int = 0
# Wrapped in EngineCoreOutputs for batch sending
class EngineCoreOutputs(
msgspec.Struct, array_like=True, omit_defaults=True, gc=False
):
engine_index: int = 0
outputs: list[EngineCoreOutput] = []
scheduler_stats: SchedulerStats | None = None
timestamp: float = 0.0
The EngineCore's output thread serializes EngineCoreOutputs with msgpack and sends them over ZMQ PUSH sockets back to the API process.
def process_output_sockets(self, output_paths, coord_output_path, engine_index):
"""Output socket IO thread."""
encoder = MsgpackEncoder()
reuse_buffers: list[bytearray] = []
pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()
sockets = [
make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)
for output_path in output_paths
]
while True:
output = self.output_queue.get()
client_index, outputs = output
outputs.engine_index = engine_index
# Reclaim buffers that zmq is finished with
while pending and pending[-1][0].done:
reuse_buffers.append(pending.pop()[2])
buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
buffers = encoder.encode_into(outputs, buffer)
tracker = sockets[client_index].send_multipart(
buffers, copy=False, track=True
)
Back in the API process, a background asyncio task continuously pulls outputs from EngineCore and pushes them through the OutputProcessor for detokenization.
async def output_handler():
try:
while True:
# 1) Pull EngineCoreOutputs from the EngineCore
outputs = await engine_core.get_output_async()
num_outputs = len(outputs.outputs)
iteration_stats = IterationStats() if (log_stats and num_outputs) else None
# Split outputs into chunks to avoid blocking the event loop
engine_core_outputs = outputs.outputs
for start in range(0, num_outputs, chunk_size):
end = start + chunk_size
outputs_slice = engine_core_outputs[start:end]
# 2) Process: detokenize + build RequestOutput objects
processed_outputs = output_processor.process_outputs(
outputs_slice, outputs.timestamp, iteration_stats
)
# Allow other asyncio tasks to run between chunks
if end < num_outputs:
await asyncio.sleep(0)
# 3) Abort any reqs that finished due to stop strings
if processed_outputs.reqs_to_abort:
await engine_core.abort_requests_async(
processed_outputs.reqs_to_abort
)
except Exception as e:
output_processor.propagate_error(e)
This is the only function that loops over every request output per iteration. It performs detokenization via the IncrementalDetokenizer and produces RequestOutput objects.
class OutputProcessor:
"""Process EngineCoreOutputs into RequestOutputs."""
def __init__(self, tokenizer, *, log_stats, stream_interval=1, ...):
self.tokenizer = tokenizer
self.request_states: dict[str, RequestState] = {}
def process_outputs(
self, engine_core_outputs: list[EngineCoreOutput],
engine_core_timestamp: float | None = None,
iteration_stats: IterationStats | None = None,
) -> OutputProcessorOutput:
"""
Process the EngineCoreOutputs:
1) Compute stats for logging
2) Detokenize
3) Create and handle RequestOutput objects
"""
for engine_core_output in engine_core_outputs:
req_id = engine_core_output.request_id
req_state = self.request_states.get(req_id)
# 1) Compute stats for this iteration
self._update_stats_from_output(req_state, engine_core_output, ...)
# 2) Detokenize new token IDs into text
# (via IncrementalDetokenizer in RequestState)
# ...
# 3) Build RequestOutput and push to per-request queue
request_output = req_state.make_request_output(
new_token_ids=new_token_ids,
finish_reason=finish_reason,
...
)
if request_output and req_state.queue:
req_state.queue.put(request_output)
# vllm/v1/engine/output_processor.py
class RequestOutputCollector:
def __init__(self, output_kind, request_id):
self.aggregate = output_kind == RequestOutputKind.DELTA
self.request_id = request_id
self.output = None
self.ready = asyncio.Event()
def put(self, output):
"""Non-blocking put operation."""
if self.output is None:
self.output = output
self.ready.set() # Wake up the generate() task
elif isinstance(self.output, RequestOutput):
self.output.add(output, aggregate=self.aggregate) # Merge deltas
async def get(self):
while (output := self.output) is None:
await self.ready.wait()
self.output = None
self.ready.clear()
return output
The generate() coroutine in AsyncLLM yields RequestOutput objects. The OpenAI serving layer wraps these into SSE-formatted JSON chunks that stream back through the FastAPI StreamingResponse.
# Back in AsyncLLM.generate():
while not finished:
# Fast path: drain queue without await if possible
out = q.get_nowait() or await q.get()
finished = out.finished
if out is not STREAM_FINISHED:
yield out # --> back to OpenAIServingChat --> SSE chunk
# The API server converts RequestOutput to SSE events:
# data: {"id":"chatcmpl-...","choices":[{"delta":{"content":"Hello"},...}]}
# data: {"id":"chatcmpl-...","choices":[{"delta":{"content":" world"},...}]}
# data: [DONE]
vLLM V1 splits the API server and EngineCore into separate OS processes. The API process runs the asyncio event loop (HTTP handling, tokenization, detokenization) while the EngineCore process runs the GPU scheduler and model in a tight busy loop. This eliminates GIL contention.
| Aspect | V0 | V1 |
|---|---|---|
| Process model | Single process | Multi-process (API + EngineCore) |
| IPC | In-process queues | ZeroMQ + msgpack |
| GIL impact | Asyncio and GPU compete | Fully overlapped |
| Serialization | N/A (in-process) | msgspec Structs (zero-copy where possible) |
Communication uses ZMQ sockets with msgspec for serialization. Key design choices:
# Request types are single-byte hex for minimal overhead
class EngineCoreRequestType(enum.Enum):
ADD = b"\x00"
ABORT = b"\x01"
START_DP_WAVE = b"\x02"
UTILITY = b"\x03"
EXECUTOR_FAILED = b"\x04"
WAKEUP = b"\x05"
The scheduler makes no distinction between "prefill tokens" and "decode tokens". Both consume from the same token budget per step. This enables:
# From the schedule() docstring:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and
# num_tokens_with_spec. At each step, the scheduler tries to assign
# tokens to the requests so that each request's num_computed_tokens
# can catch up its num_tokens_with_spec.
V1's InputBatch is a pre-allocated set of CPU tensors that persists across scheduling steps. Instead of building a new batch tensor every iteration:
# Key tensors in InputBatch (all pre-allocated):
self.token_ids_cpu_tensor = torch.zeros((max_num_reqs, max_model_len), dtype=torch.int32)
self.num_computed_tokens_cpu_tensor = torch.zeros((max_num_reqs,), dtype=torch.int32)
self.temperatures_cpu_tensor = torch.empty((max_num_reqs,), dtype=torch.float32)
self.top_p_cpu_tensor = torch.empty((max_num_reqs,), dtype=torch.float32)
self.block_table = MultiGroupBlockTable(max_num_reqs=max_num_reqs, ...)
The core scheduling loop is deliberately minimal: schedule() -> execute_model() -> update_from_output(). There is no complex state machine or mode switching. Background threads handle I/O while the main thread focuses exclusively on the schedule-execute-update cycle.
# The entire busy loop is just these lines:
while self._handle_shutdown():
self._process_input_queue() # Drain new requests from ZMQ
self._process_engine_step() # schedule() + execute() + update()
| Component | File | Purpose |
|---|---|---|
| HTTP Entrypoint | vllm/entrypoints/openai/api_server.py | FastAPI app, lifespan, build_async_engine_client |
| Chat Handler | vllm/entrypoints/openai/chat_completion/serving.py | OpenAIServingChat.create_chat_completion() |
| Chat Router | vllm/entrypoints/openai/chat_completion/api_router.py | POST /v1/chat/completions route |
| AsyncLLM | vllm/v1/engine/async_llm.py | API-facing engine: tokenize, queue, output_handler |
| Input Processor | vllm/v1/engine/input_processor.py | Convert EngineInput to EngineCoreRequest |
| Output Processor | vllm/v1/engine/output_processor.py | Detokenize EngineCoreOutput to RequestOutput |
| Core Client | vllm/v1/engine/core_client.py | ZMQ IPC client: AsyncMPClient, SyncMPClient, InprocClient |
| EngineCore | vllm/v1/engine/core.py | EngineCore + EngineCoreProc (busy loop, I/O threads) |
| Data Structures | vllm/v1/engine/__init__.py | EngineCoreRequest, EngineCoreOutput, EngineCoreOutputs |
| Scheduler | vllm/v1/core/sched/scheduler.py | Unified scheduler: schedule(), update_from_output() |
| Scheduler Output | vllm/v1/core/sched/output.py | SchedulerOutput, NewRequestData, CachedRequestData |
| KV Cache Manager | vllm/v1/core/kv_cache_manager.py | Block allocation, prefix caching, eviction |
| GPU Worker | vllm/v1/worker/gpu_worker.py | Device init, weight loading, profiling |
| GPU Model Runner | vllm/v1/worker/gpu_model_runner.py | Forward pass, attention metadata, sampling |
| Input Batch | vllm/v1/worker/gpu_input_batch.py | Persistent batch with pre-allocated tensors |
| Model Outputs | vllm/v1/outputs.py | ModelRunnerOutput, SamplerOutput, LogprobsLists |
| Serialization | vllm/v1/serial_utils.py | MsgpackEncoder / MsgpackDecoder |
create_chat_completion() in api_router.py receives the ChatCompletionRequest.
OpenAIServingChat.render_chat_request() applies the chat template and tokenizes the conversation into prompt_token_ids.
InputProcessor.process_inputs() builds an EngineCoreRequest with the tokenized prompt. OutputProcessor.add_request() creates a RequestState with an IncrementalDetokenizer. A RequestOutputCollector queue is created.
AsyncMPClient.add_request_async() serializes the EngineCoreRequest with MsgpackEncoder and sends it over the ZMQ ROUTER socket with request type byte b"\x00" (ADD).
The input I/O thread receives the multipart ZMQ message, deserializes it with MsgpackDecoder(EngineCoreRequest), calls preprocess_add_request() to create a Request object, and pushes it to input_queue.
_process_input_queue() drains all pending requests from the queue. EngineCore.add_request() passes the Request to Scheduler.add_request() which adds it to the waiting queue.
Iterates over running requests first, then waiting requests. For each: calculates num_new_tokens, calls KVCacheManager.allocate_slots(), deducts from the token budget. Produces a SchedulerOutput with all scheduling decisions.
Updates the persistent InputBatch with the scheduler's decisions via _update_states(). Prepares attention metadata, determines CUDA graph mode, runs the model forward pass, then samples output tokens. Returns ModelRunnerOutput.
For each request: appends generated tokens, checks stop conditions (EOS, max_tokens, stop strings), updates num_computed_tokens. Frees KV cache blocks for finished requests. Builds EngineCoreOutput per request, wrapped in EngineCoreOutputs.
The output I/O thread serializes EngineCoreOutputs with MsgpackEncoder.encode_into() (reusing buffers) and sends via ZMQ PUSH socket. The API process receives on the PULL socket.
The background asyncio task calls OutputProcessor.process_outputs() which: (a) incremental-detokenizes new token IDs into text, (b) builds RequestOutput objects, (c) pushes them into per-request RequestOutputCollector queues.
The generate() coroutine yields the RequestOutput. The OpenAI serving layer formats it as an SSE event (data: {...}\n\n). FastAPI's StreamingResponse sends it over HTTP. Steps 7-12 repeat for each token until the request finishes.
vLLM provides three client variants depending on the execution mode. The factory method selects the appropriate one:
class EngineCoreClient(ABC):
"""
EngineCoreClient: subclasses handle different methods for pushing
and pulling from the EngineCore for asyncio / multiprocessing.
Subclasses:
* InprocClient: In process EngineCore (for V0-style LLMEngine use)
* SyncMPClient: ZMQ + background proc EngineCore (for LLM)
* AsyncMPClient: ZMQ + background proc EngineCore w/ asyncio (for AsyncLLM)
"""
@staticmethod
def make_client(
multiprocess_mode: bool,
asyncio_mode: bool,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
) -> "EngineCoreClient":
if multiprocess_mode and asyncio_mode:
return EngineCoreClient.make_async_mp_client(
vllm_config, executor_class, log_stats
)
if multiprocess_mode and not asyncio_mode:
return SyncMPClient(vllm_config, executor_class, log_stats)
return InprocClient(vllm_config, executor_class, log_stats)
| Client | Mode | Use Case | IPC |
|---|---|---|---|
| AsyncMPClient | Multi-process + asyncio | OpenAI API server (AsyncLLM) | ZMQ ROUTER/DEALER + PUSH/PULL |
| SyncMPClient | Multi-process + sync | Offline batch inference (LLM class) | ZMQ ROUTER/DEALER + PUSH/PULL |
| InprocClient | In-process | Debugging, V0-style LLMEngine | Direct function calls |
| DPAsyncMPClient | Data parallel + external LB | Multi-engine with external load balancer | ZMQ per DP rank |
| DPLBAsyncMPClient | Data parallel + internal LB | Multi-engine with built-in load balancing | ZMQ + DPCoordinator |
# The InprocClient is the simplest -- direct in-process calls
class InprocClient(EngineCoreClient):
    """EngineCoreClient that runs the EngineCore in the calling process.

    No serialization or IPC is involved: every call is a direct method
    call on the wrapped EngineCore instance.
    """

    def __init__(self, *args, **kwargs):
        # All constructor arguments are forwarded to EngineCore unchanged.
        self.engine_core = EngineCore(*args, **kwargs)

    def get_output(self) -> EngineCoreOutputs:
        """Run one engine step and return the outputs for client 0.

        Falls back to an empty EngineCoreOutputs when the step produced
        no outputs, or no entry for client 0.
        """
        outputs, model_executed = self.engine_core.step_fn()
        self.engine_core.post_step(model_executed=model_executed)
        # Explicit conditional instead of the original
        # `outputs and outputs.get(0) or EngineCoreOutputs()` -- the
        # `a and b or c` idiom is behaviorally identical here but easy
        # to misread and is discouraged Python style.
        if not outputs:
            return EngineCoreOutputs()
        return outputs.get(0) or EngineCoreOutputs()

    def add_request(self, request: EngineCoreRequest) -> None:
        """Preprocess and enqueue a new request on the in-process EngineCore."""
        req, request_wave = self.engine_core.preprocess_add_request(request)
        self.engine_core.add_request(req, request_wave)
The EngineCore process has a carefully designed three-thread architecture to maximize overlap between I/O and computation:
| Thread | Role | GIL Behavior |
|---|---|---|
| Main Thread | Runs run_busy_loop(): drains input_queue, calls step() (schedule + execute + update), pushes to output_queue | Holds GIL during Python logic; releases during GPU kernel execution (CUDA is async) |
| Input I/O Thread | Runs process_input_sockets(): reads from ZMQ DEALER, deserializes with msgpack, pushes to input_queue | Releases GIL during zmq.recv_multipart() and poller.poll() |
| Output I/O Thread | Runs process_output_sockets(): pulls from output_queue, serializes with msgpack, sends via ZMQ PUSH | Releases GIL during zmq.send_multipart() |
# In EngineCoreProc.__init__():
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# NOTE(review): excerpt from inside EngineCoreProc.__init__ -- `self`,
# `addresses`, `identity`, and `ready_event` are defined by the enclosing
# method and are not shown here.
input_thread = threading.Thread(
    target=self.process_input_sockets,
    args=(addresses.inputs, addresses.coordinator_input, identity, ready_event),
    daemon=True,  # daemon: don't block interpreter shutdown on IO threads
)
input_thread.start()
self.output_thread = threading.Thread(
    target=self.process_output_sockets,
    args=(addresses.outputs, ...),  # remaining args elided in this excerpt
    daemon=True,
)
self.output_thread.start()
All cross-process data structures use msgspec.Struct with gc=False to disable Python's cyclic garbage collector tracking. Combined with array_like=True for compact serialization and omit_defaults=True to skip default-valued fields, this minimizes both serialization overhead and GC pause times.
After initialization, EngineCore calls freeze_gc_heap() to mark all existing objects as permanent. This prevents the GC from scanning startup allocations (model weights, config objects) during collection cycles, significantly reducing GC pause times.
# At the end of EngineCore.__init__():
# Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections.
freeze_gc_heap()
# Enable environment variable cache
enable_envs_cache()
The output I/O thread reuses serialization buffers via encode_into(outputs, buffer) and tracks ZMQ message delivery with zmq.MessageTracker. Buffers are only reclaimed when ZMQ confirms the message has been sent, preventing premature deallocation of backing memory for numpy arrays and tensors.
# Output thread buffer management:
# Encoded-output buffers are pooled and reused across sends; a buffer is
# only returned to the pool once ZMQ reports the zero-copy send finished,
# so the bytes backing an in-flight message are never mutated.
reuse_buffers: list[bytearray] = []
# Each pending entry: (zmq tracker, object kept alive until sent, buffer).
pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()
# Reclaim buffers that zmq is finished with
# NOTE(review): only the right end of `pending` is polled here -- the
# append side is outside this excerpt; confirm ordering against the full
# loop in the repository.
while pending and pending[-1][0].done:
    reuse_buffers.append(pending.pop()[2])
buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
# encode_into() serializes into `buffer` in place, growing it as needed.
buffers = encoder.encode_into(outputs, buffer)
# copy=False sends zero-copy; track=True returns a MessageTracker whose
# .done becomes true once ZMQ no longer needs the buffer.
tracker = socket.send_multipart(buffers, copy=False, track=True)
GPU-to-CPU copies of sampled tokens are performed on a separate CUDA stream, overlapping the copy with the next forward pass. The AsyncGPUModelRunnerOutput wrapper lazily materializes the CPU result only when needed.
# vllm/v1/worker/gpu_model_runner.py
class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
    """Model-runner output whose sampled token ids are copied to the CPU
    asynchronously on a side CUDA stream; get_output() blocks until the
    copy has completed.  (Excerpt -- the constructor signature and parts
    of both bodies are elided with `...`.)
    """

    def __init__(self, model_runner_output, sampled_token_ids, ...):
        # Initiate copy on a separate stream (non-blocking)
        with torch.cuda.stream(async_output_copy_stream):
            # Make the copy stream wait for the default stream, so the
            # copy reads tokens only after the producing kernels finish.
            async_output_copy_stream.wait_stream(default_stream)
            # NOTE(review): self._sampled_token_ids is assigned in elided
            # code -- presumably from the sampled_token_ids argument.
            self.sampled_token_ids_cpu = self._sampled_token_ids.to(
                "cpu", non_blocking=True
            )
            # Event recorded on the copy stream; synchronized in get_output().
            self.async_copy_ready_event.record()

    def get_output(self) -> ModelRunnerOutput:
        """Blocks until copy is finished, then returns."""
        self.async_copy_ready_event.synchronize()
        # Convert to Python lists
        valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
        ...
The API process splits output processing into chunks of VLLM_V1_OUTPUT_PROC_CHUNK_SIZE to avoid blocking the asyncio event loop when there are many concurrent requests. Between chunks, it yields control with await asyncio.sleep(0).
# In output_handler():
# Process engine outputs in bounded chunks so that detokenizing a large
# batch cannot starve the asyncio event loop.
chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
for start in range(0, num_outputs, chunk_size):
    end = start + chunk_size  # may exceed num_outputs; slicing clamps safely
    outputs_slice = engine_core_outputs[start:end]
    processed_outputs = output_processor.process_outputs(outputs_slice, ...)
    # Allow other asyncio tasks to run between chunks
    if end < num_outputs:
        await asyncio.sleep(0)  # cooperative yield back to the event loop
Generated from vLLM V1 source code analysis. All code snippets are from the actual repository.
Back to Overview | EngineCore Deep Dive | Scheduler Deep Dive | ModelRunner Deep Dive