A source-code deep dive into vLLM's multi-process engine design, ZeroMQ messaging, msgpack serialization, and data-parallel coordination
vLLM V1 separates the engine into two process-level concerns: a frontend process (the API server or LLMEngine) that handles tokenization, detokenization, and client-facing I/O, and a backend process (the EngineCore) that runs the GPU-bound scheduling and model execution loop. These two processes communicate via ZeroMQ sockets with msgpack serialization.
The key insight of V1's architecture is process isolation. The GPU-bound EngineCore runs in its own process with dedicated I/O threads, so CPU-heavy frontend work (tokenization, detokenization, client management) never contends for the same Python interpreter and GIL as the scheduling and model-execution loop. ZeroMQ bridges the two processes for message passing, and msgpack keeps serialization cheap.
Key source files: vllm/v1/engine/core.py (EngineCore and EngineCoreProc), vllm/v1/engine/core_client.py (all client variants), vllm/v1/engine/async_llm.py (the AsyncLLM frontend), vllm/v1/engine/output_processor.py (detokenization), vllm/v1/engine/coordinator.py (the DP coordinator), and vllm/v1/serial_utils.py (msgpack serialization).
The EngineCore class in core.py implements the fundamental execution loop. It is the "inner loop of vLLM's Engine", as the source code states.
The EngineCoreProc subclass wraps EngineCore with ZMQ I/O and runs the busy loop in a background process. The loop is driven by _handle_shutdown() (line 1202): it returns True while shutdown_state == RUNNING, so normally the loop just alternates between draining the input queue and stepping the engine. When SIGINT / SIGTERM flips shutdown_state to REQUESTED, the helper drains or aborts in-flight requests and eventually returns False, letting the loop fall through to raise SystemExit:
```python
def run_busy_loop(self):
    """Core busy loop of the EngineCore."""
    # _handle_shutdown() returns True while shutdown_state == RUNNING.
    # On SIGTERM/SIGINT the handler flips it to REQUESTED -- the helper
    # then drains/aborts in-flight reqs and eventually returns False.
    while self._handle_shutdown():
        # 1) Poll the input queue until there is work to do.
        #    Blocks while has_work() is False and is_running() is True;
        #    releases the GIL so the Input I/O thread can run.
        self._process_input_queue()
        # 2) Step the engine core and return the outputs.
        #    step_fn is either step() or step_with_batch_queue()
        #    depending on pipeline-parallel config (see §2.3).
        self._process_engine_step()
    # shutdown_state == COMPLETED reached -- exit cleanly so the
    # try/finally in run_engine_core() can call engine_core.shutdown().
    raise SystemExit
```
The loop has only two phases: drain the input queue (blocking on input_queue.get() while idle; see §2.5), then step the engine once. When the SystemExit propagates, run_engine_core() logs it and re-raises, and its finally block calls engine_core.shutdown() to tear down the executor and scheduler cleanly.
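To see this control flow in isolation, here is a minimal, runnable sketch. It assumes a toy engine whose "scheduler" is a plain list and whose step finishes one request per iteration. ToyEngineCore, ShutdownState, request_shutdown(), and the finish-queued-work-then-stop policy are all hypothetical. Only the method names echo vLLM, and vLLM's real _handle_shutdown() may abort requests instead of finishing them.

```python
import enum
import queue


class ShutdownState(enum.Enum):
    RUNNING = enum.auto()
    REQUESTED = enum.auto()
    COMPLETED = enum.auto()


class ToyEngineCore:
    """Hypothetical stand-in for EngineCoreProc's busy-loop skeleton."""

    def __init__(self):
        self.input_queue: queue.Queue = queue.Queue()
        self.shutdown_state = ShutdownState.RUNNING
        self.pending: list[str] = []      # stand-in for scheduler requests
        self.finished: list[str] = []
        self.was_shut_down = False

    def _handle_client_request(self, req):
        if req != "WAKEUP":               # WAKEUP only unblocks get()
            self.pending.append(req)

    def _drain_nowait(self):
        while not self.input_queue.empty():
            self._handle_client_request(self.input_queue.get_nowait())

    def _handle_shutdown(self) -> bool:
        if self.shutdown_state is ShutdownState.RUNNING:
            return True
        if self.shutdown_state is ShutdownState.REQUESTED:
            # Assumed policy: finish everything already queued, then stop.
            self._drain_nowait()
            if self.pending:
                return True
            self.shutdown_state = ShutdownState.COMPLETED
        return False

    def _process_input_queue(self):
        # Block (releasing the GIL) only while idle and still running.
        while not self.pending and self.shutdown_state is ShutdownState.RUNNING:
            self._handle_client_request(self.input_queue.get())
        self._drain_nowait()

    def _process_engine_step(self):
        if self.pending:
            self.finished.append(self.pending.pop(0))

    def run_busy_loop(self):
        while self._handle_shutdown():
            self._process_input_queue()
            self._process_engine_step()
        raise SystemExit

    def request_shutdown(self):
        # What the signal handler does: flip the state, then wake the loop.
        self.shutdown_state = ShutdownState.REQUESTED
        self.input_queue.put_nowait("WAKEUP")

    def shutdown(self):
        self.was_shut_down = True


def run_engine_core(engine: ToyEngineCore):
    try:
        engine.run_busy_loop()
    except SystemExit:
        pass                              # vLLM logs and re-raises here
    finally:
        engine.shutdown()
```

Calling request_shutdown() from a signal handler or another thread makes run_engine_core() finish whatever is queued, leave the loop through SystemExit, and still reach shutdown() via the finally block.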
The core computation happens in step() — vLLM's per-iteration heartbeat. In 0.19.0 the path dispatches the forward pass asynchronously (non_block=True), builds the grammar bitmask while the model is still computing, then waits on future.result() under logging context managers. An aborts queue is drained between the forward pass and the scheduler update so late aborts don't leak into the next iteration:
```python
def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
    """Schedule, execute, and make output.

    Returns tuple of outputs and a flag indicating whether the model
    was executed.
    """
    # Fast-exit: nothing to do if scheduler has no live requests
    # (neither running nor waiting nor finished-pending-cleanup).
    if not self.scheduler.has_requests():
        return {}, False
    # 1) Scheduler picks the next batch: which running requests to
    #    decode, which waiting requests to prefill, which blocks to
    #    allocate / evict. Returns a SchedulerOutput.
    scheduler_output = self.scheduler.schedule()
    # 2) Dispatch the forward pass non-blocking (future-based).
    #    execute_model returns immediately; the worker is now running.
    future = self.model_executor.execute_model(scheduler_output,
                                               non_block=True)
    # 3) While the GPU is busy, compute the grammar bitmask (structured
    #    output / JSON schema constraint) on the CPU -- pure overlap.
    grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
    # 4) Join on the forward pass and emit iteration logging / error
    #    context. log_error_detail dumps the SchedulerOutput on failure;
    #    log_iteration_details records timing / shapes.
    with (
        self.log_error_detail(scheduler_output),
        self.log_iteration_details(scheduler_output),
    ):
        model_output = future.result()
        # If the worker returned None it means sampling was deferred;
        # drive sampling here with the grammar bitmask.
        if model_output is None:
            model_output = self.model_executor.sample_tokens(grammar_output)
    # 5) Drain any aborts that arrived DURING model execution. Doing
    #    this before update_from_output prevents freed requests from
    #    receiving token appends they can no longer consume.
    self._process_aborts_queue()
    # 6) Feed the ModelRunnerOutput back into the scheduler so it
    #    appends sampled tokens, detects finished requests, frees
    #    blocks, and emits per-client EngineCoreOutput records.
    engine_core_outputs = self.scheduler.update_from_output(
        scheduler_output, model_output
    )
    return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
```
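The overlap in steps 2 through 5 can be reproduced with the standard library alone. In this sketch a ThreadPoolExecutor stands in for the non-blocking executor, time.sleep() stands in for GPU time, and a dict of "allowed" flags stands in for the grammar bitmask. fake_forward() and toy_step() are hypothetical names, not vLLM APIs.

```python
import queue
import time
from concurrent.futures import ThreadPoolExecutor


def fake_forward(batch, log):
    """Hypothetical forward pass: sleeps to mimic GPU time."""
    time.sleep(0.05)
    log.append("forward_done")
    return {req: f"tok_{req}" for req in batch}


def toy_step(running, aborts_queue, executor, log):
    """Sketch of step(): dispatch, overlap CPU work, join, drain aborts."""
    if not running:
        return {}, False
    batch = sorted(running)                              # 1) "schedule"
    future = executor.submit(fake_forward, batch, log)   # 2) non-blocking
    log.append("dispatched")
    allowed = {req: True for req in batch}               # 3) overlapped CPU work
    log.append("bitmask_built")
    model_output = future.result()                       # 4) join
    # 5) Drain aborts that arrived during the forward pass.
    while True:
        try:
            running.discard(aborts_queue.get_nowait())
        except queue.Empty:
            break
    # 6) Only requests that are still running receive a token.
    outputs = {req: model_output[req] for req in batch
               if req in running and allowed[req]}
    return outputs, True
```

The event log shows the bitmask being built before the forward pass finishes. The abort queued during the "forward pass" keeps r2 from receiving a token.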
When pipeline parallelism is enabled (max_concurrent_batches > 1), vLLM swaps in step_with_batch_queue(), which keeps a deque of in-flight (future, scheduler_output) pairs. Scheduling can run ahead of execution by up to batch_queue_size batches, so the GPU does not sit idle waiting for the next schedule() to finish:
```python
# --- in __init__: pick the step function based on PP config ---
self.batch_queue_size = self.model_executor.max_concurrent_batches
self.batch_queue: Optional[deque[...]] = None
if self.batch_queue_size > 1:
    logger.debug("Batch queue is enabled with size %d",
                 self.batch_queue_size)
    self.batch_queue = deque(maxlen=self.batch_queue_size)
# step_fn is bound once at construction time -- the busy loop just
# calls self.step_fn() and never branches per iteration.
self.step_fn = (
    self.step if self.batch_queue is None
    else self.step_with_batch_queue
)


# --- step_with_batch_queue body (abridged) ---
def step_with_batch_queue(self):
    batch_queue = self.batch_queue
    model_executed = False
    if self.scheduler.has_requests():
        # Schedule a NEW batch and dispatch it without waiting.
        scheduler_output = self.scheduler.schedule()
        future = self.model_executor.execute_model(scheduler_output)
        batch_queue.appendleft((future, scheduler_output))
        model_executed = scheduler_output.total_num_scheduled_tokens > 0
        # Don't block on the OLDEST batch unless the queue is full
        # (i.e. we already have batch_queue_size outstanding) or
        # the oldest one's future is already done.
        if len(batch_queue) < self.batch_queue_size \
                and not batch_queue[-1][0].done():
            return None, True
    elif not batch_queue:
        # Nothing to schedule and nothing in flight.
        return None, False
    # Block on the OLDEST in-flight batch and feed its output back.
    future, scheduler_output = batch_queue.pop()
    model_output = self.execute_model_with_error_logging(
        lambda _: future.result(), scheduler_output)
    engine_core_outputs = self.scheduler.update_from_output(
        scheduler_output, model_output)
    return engine_core_outputs, model_executed
```
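Here is a stdlib-only model of the run-ahead rule. It assumes concurrent.futures.Future objects that the caller resolves by hand in place of worker responses. ToyBatchQueue and its waiting deque are hypothetical, and the branch structure follows the abridged step_with_batch_queue() above.

```python
from collections import deque
from concurrent.futures import Future


class ToyBatchQueue:
    """Hypothetical model of step_with_batch_queue()'s run-ahead logic."""

    def __init__(self, batch_queue_size: int, batches):
        self.batch_queue_size = batch_queue_size
        self.batch_queue: deque = deque(maxlen=batch_queue_size)
        self.waiting = deque(batches)          # stand-in for the scheduler
        self.dispatched: list[Future] = []     # the caller resolves these

    def execute_model(self, batch) -> Future:
        future: Future = Future()              # stands in for a worker reply
        self.dispatched.append(future)
        return future

    def step_with_batch_queue(self):
        model_executed = False
        if self.waiting:
            # deque(maxlen=...) would silently evict the oldest entry.
            assert len(self.batch_queue) < self.batch_queue_size
            batch = self.waiting.popleft()
            self.batch_queue.appendleft((self.execute_model(batch), batch))
            model_executed = True
            # Run ahead unless the queue is full or the oldest batch is done.
            if (len(self.batch_queue) < self.batch_queue_size
                    and not self.batch_queue[-1][0].done()):
                return None, True
        elif not self.batch_queue:
            return None, False                 # nothing scheduled or in flight
        future, batch = self.batch_queue.pop() # oldest in-flight batch
        return (batch, future.result()), model_executed
```

With batch_queue_size=2, the first call returns immediately with one batch in flight. Each later call dispatches a new batch and only then blocks on the oldest one, which keeps scheduling one batch ahead of execution.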
```python
def _process_engine_step(self) -> bool:
    """Called only when there are unfinished local requests."""
    # Run one step. step_fn is either self.step (single-batch)
    # or self.step_with_batch_queue (pipeline-parallel, see §2.3).
    # outputs: dict[client_index, EngineCoreOutputs] | None
    outputs, model_executed = self.step_fn()
    # Hand each per-client EngineCoreOutputs tuple off to the
    # Output I/O thread via an unbounded queue.Queue. The thread
    # serializes with msgpack and sends on a ZMQ PUSH socket.
    # put_nowait() never blocks the main thread.
    for output in outputs.items() if outputs else ():
        self.output_queue.put_nowait(output)
    # Post-step hook: in EngineCore base class this picks up
    # draft token IDs for speculative decoding and hands them to
    # the scheduler. DP subclass also publishes request counts here.
    self.post_step(model_executed)
    # If no model execution happened but there are waiting requests
    # (e.g., WAITING_FOR_REMOTE_KVS during disaggregated prefill),
    # yield the GIL briefly so background threads -- NIXL handshake,
    # connector polling, etc. -- can make progress. Without this
    # sleep, the tight poll loop can starve those threads.
    if not model_executed and self.scheduler.has_unfinished_requests():
        time.sleep(0.001)
    return model_executed
```
After each step, outputs are placed into self.output_queue (a queue.Queue). The dedicated Output I/O thread pops them, msgpack-encodes them, and sends them over a ZMQ PUSH socket to the frontend. The time.sleep(0.001) guard only fires when the scheduler has unfinished requests but no forward pass ran — typical with KV-transfer stalls (WAITING_FOR_REMOTE_KVS) where background threads need CPU time to make progress.
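The main-thread/output-thread handoff can be sketched with a queue.Queue and a plain thread. Several stand-ins are assumed here: json replaces msgpack, appending to a list replaces the ZMQ PUSH send, and output_thread_main() and process_engine_step() are hypothetical names.

```python
import json
import queue
import threading

ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"


def output_thread_main(output_queue, sent_frames, engine_index):
    """Hypothetical Output I/O thread: queue -> encode -> "socket"."""
    while True:
        item = output_queue.get()          # blocking get releases the GIL
        if item == ENGINE_CORE_DEAD:
            sent_frames.append(item)       # forward the sentinel, then exit
            break
        client_index, outputs = item
        # Tag with the engine index, then encode (json stands in for msgpack).
        payload = json.dumps(dict(outputs, engine_index=engine_index)).encode()
        sent_frames.append((client_index, payload))   # stands in for PUSH send


def process_engine_step(outputs_by_client, output_queue):
    """Main-thread side: hand off per-client outputs without blocking."""
    for item in outputs_by_client.items():
        output_queue.put_nowait(item)
```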
```python
def _process_input_queue(self):
    """Exits when an engine step needs to be performed."""
    waited = False
    # Loop while has_work() returns False and shutdown hasn't been
    # requested. has_work() is True iff any of:
    #   engines_running (DP wave active) or
    #   scheduler.has_requests() (work to step on) or
    #   batch_queue (in-flight PP batches).
    while not self.has_work() and self.is_running():
        # Fire any pending "engine became idle" callbacks -- used by
        # DP coordinator handshake and KV-transfer waiters.
        self._notify_idle_state_callbacks()
        if self.input_queue.empty():
            # Clear the aborts_queue while we have nothing else to do.
            # Aborts also travel through input_queue -- this mirror is
            # only used by step() for mid-forward-pass draining -- so
            # clearing it here is safe and prevents unbounded growth.
            with self.aborts_queue.mutex:
                self.aborts_queue.queue.clear()
            if logger.isEnabledFor(DEBUG):
                logger.debug("EngineCore waiting for work.")
                waited = True
        # process_input_queue_block: True for the normal busy loop,
        # False for tight polling paths where we don't want to sleep.
        block = self.process_input_queue_block
        try:
            # Blocking get releases the GIL so Input I/O thread runs.
            req = self.input_queue.get(block=block)
            self._handle_client_request(*req)
        except queue.Empty:
            # Only reached when block=False -- exit and let caller retry.
            break
        if not block:
            break
    if waited:
        logger.debug("EngineCore loop active.")
    # Drain any messages that arrived while we were already running.
    # Coalescing multiple ADDs into a single schedule() call is what
    # lets late requests join the current batch.
    while not self.input_queue.empty():
        req = self.input_queue.get_nowait()
        self._handle_client_request(*req)
```
When the engine has nothing to do (has_work() is False and is_running() is True), it parks on input_queue.get() — the main thread sleeps, the GIL is released, and the Input I/O thread keeps reading from ZMQ in the background. As soon as any message is pushed, the loop resumes. The aborts_queue is a lightweight mirror cleared here on idle; the real abort dispatch goes through input_queue (see §4.4). After work arrives, the final non-blocking drain coalesces multiple ADDs into the same schedule() call so they can be batched together.
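Below is a minimal sketch of the block-then-drain pattern. It assumes the scheduler's request list is a plain Python list and that every queued item is an ADD. drain_into_scheduler() is a hypothetical helper. One blocking get() is followed by non-blocking gets until queue.Empty, so requests that arrive together land in the same schedule() call.

```python
import queue


def drain_into_scheduler(input_queue: queue.Queue, scheduler_requests: list):
    """Hypothetical block-then-drain loop (all queued items treated as ADDs)."""
    # Idle: park on a blocking get(), which releases the GIL for I/O threads.
    while not scheduler_requests:
        scheduler_requests.append(input_queue.get())
    # Busy: pull in everything that arrived meanwhile without blocking,
    # so these requests can join the same schedule() call.
    while True:
        try:
            scheduler_requests.append(input_queue.get_nowait())
        except queue.Empty:
            return
```

vLLM's idle condition is has_work() rather than an empty list, and non-ADD messages go through _handle_client_request(). This toy collapses both.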
EngineCoreProc.run_engine_core() is the static entry point the child process runs after spawn. It sets the process title, optionally switches to the MoE-aware DPEngineCoreProc, installs signal handlers that wake the engine via input_queue, and finally enters the busy loop:
```python
@staticmethod
def run_engine_core(*args, dp_rank: int = 0,
                    local_dp_rank: int = 0, **kwargs):
    """Launch EngineCore busy loop in background process."""
    # Ensure custom transformer configs survive pickling across spawn.
    maybe_register_config_serialize_by_value()
    engine_core: EngineCoreProc | None = None
    signal_callback: SignalCallback | None = None
    try:
        vllm_config: VllmConfig = kwargs["vllm_config"]
        parallel_config: ParallelConfig = vllm_config.parallel_config
        data_parallel = parallel_config.data_parallel_size > 1 or dp_rank > 0
        if data_parallel:
            parallel_config.data_parallel_rank_local = local_dp_rank
            process_title = f"EngineCore_DP{dp_rank}"
        else:
            process_title = "EngineCore"
        set_process_title(process_title)
        maybe_init_worker_tracer("vllm.engine_core", "engine_core", process_title)
        decorate_logs()
        # Per-DP-rank kv_transfer_config.engine_id uniqueness, so remote
        # KV peers can route correctly in disaggregated setups.
        if data_parallel and vllm_config.kv_transfer_config is not None:
            vllm_config.kv_transfer_config.engine_id = (
                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
            )
        parallel_config.data_parallel_index = dp_rank
        # MoE + DP takes the DPEngineCoreProc path so wave protocol and
        # expert-parallel all-reduce can synchronize across ranks.
        if data_parallel and vllm_config.model_config.is_moe:
            parallel_config.data_parallel_rank = dp_rank
            engine_core = DPEngineCoreProc(*args, **kwargs)
        else:
            # Non-MoE DP ranks are independent -- drop to DP=1 semantics
            # but keep data_parallel_index for logging / observability.
            parallel_config.data_parallel_size = 1
            parallel_config.data_parallel_size_local = 1
            parallel_config.data_parallel_rank = 0
            engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)
        assert engine_core is not None

        def wakeup_engine():
            # Signal-safe wake path: pushing a WAKEUP message so the
            # blocking input_queue.get() in _process_input_queue returns.
            # Doing this from a signal handler directly would be unsafe --
            # the main thread may be holding input_queue.mutex mid-put.
            engine_core.input_queue.put_nowait((EngineCoreRequestType.WAKEUP, None))

        signal_callback = SignalCallback(wakeup_engine)

        def signal_handler(signum, frame):
            # Set shutdown state; _handle_shutdown() will notice it and
            # start draining/aborting in-flight requests next iteration.
            engine_core.shutdown_state = EngineShutdownState.REQUESTED
            signal_callback.trigger()

        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)
        engine_core.run_busy_loop()  # blocks until _handle_shutdown returns False
    except SystemExit:
        logger.debug("EngineCore exiting.")
        raise
    except Exception as e:
        # Abnormal path -- send ENGINE_CORE_DEAD so the client
        # fails fast instead of hanging on ZMQ recv.
        if engine_core is None:
            logger.exception("EngineCore failed to start.")
        else:
            logger.exception("EngineCore encountered a fatal error.")
            engine_core._send_engine_dead()
        raise e
    finally:
        # Restore default signal dispositions before shutting down so
        # a second SIGINT during cleanup kills the process quickly.
        signal.signal(signal.SIGTERM, signal.SIG_DFL)
        signal.signal(signal.SIGINT, signal.SIG_DFL)
        if signal_callback is not None:
            signal_callback.stop()
        if engine_core is not None:
            engine_core.shutdown()
```
Two details are worth calling out. (a) The DPEngineCoreProc path requires both data_parallel (DP size > 1, or a nonzero dp_rank) and model_config.is_moe; non-MoE DP ranks are treated as independent DP=1 processes. (b) The signal handler does not raise. It flips shutdown_state to REQUESTED and triggers the SignalCallback, which pushes a WAKEUP message onto input_queue outside the handler itself. The blocking get() then returns, and _handle_shutdown() sees the state change on the next loop iteration.
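The source does not show SignalCallback's internals. One common way to get the same effect is the self-pipe trick. The handler only flips a flag and writes a byte to a pipe, and a helper thread blocked on the read end performs the queue.put_nowait(). The sketch below assumes that design; SelfPipeWaker, Engine, and WAKEUP are hypothetical names.

```python
import os
import queue
import signal
import threading

WAKEUP = "WAKEUP"


class SelfPipeWaker:
    """Hypothetical SignalCallback stand-in built on the self-pipe trick."""

    def __init__(self, callback):
        self._r, self._w = os.pipe()
        self._callback = callback
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self):
        # Helper thread: sleeps in os.read() until the handler writes a byte.
        while True:
            data = os.read(self._r, 1)
            if not data or data == b"x":   # EOF or stop() request
                break
            self._callback()               # runs outside the signal handler

    def trigger(self):
        os.write(self._w, b"t")            # one syscall, no queue locks taken

    def stop(self):
        os.write(self._w, b"x")
        self._thread.join(timeout=5)
        os.close(self._r)
        os.close(self._w)


class Engine:
    def __init__(self):
        self.input_queue: queue.Queue = queue.Queue()
        self.shutdown_requested = False


engine = Engine()
waker = SelfPipeWaker(lambda: engine.input_queue.put_nowait(WAKEUP))


def signal_handler(signum, frame):
    engine.shutdown_requested = True       # cheap flag flip
    waker.trigger()                        # queue put happens on the helper thread

# In a real process (main thread only):
#     signal.signal(signal.SIGTERM, signal_handler)
```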
EngineCoreProc.__init__() spawns two daemon threads that bridge the ZMQ sockets and the main loop's queue.Queue objects. The threads release the GIL whenever they block on a ZMQ call or a queue, so socket I/O overlaps with the main thread. CPU-bound work such as msgpack encode/decode and request preprocessing still needs the GIL, but it can run while the main thread is waiting on the forward pass.
```python
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
ready_event = threading.Event()
# Input thread: ZMQ DEALER recv_multipart -> MsgpackDecoder
#   -> preprocess_add_request() -> input_queue.put_nowait().
# Preprocess includes MM-cache lookup, block hashing,
# grammar compilation -- all of which run in parallel with GPU.
input_thread = threading.Thread(
    target=self.process_input_sockets,
    args=(addresses.inputs, addresses.coordinator_input,
          identity, ready_event),
    daemon=True)
input_thread.start()
# Output thread: output_queue.get() -> MsgpackEncoder
#   -> ZMQ PUSH send_multipart (zero-copy when possible).
# Owns a bytearray reuse pool to avoid allocation in the hot path.
self.output_thread = threading.Thread(
    target=self.process_output_sockets,
    args=(addresses.outputs, addresses.coordinator_output,
          self.engine_index),
    daemon=True)
self.output_thread.start()
# Wait for the Input thread to finish its initial handshake
# with the DP coordinator (if one is configured) before
# returning from __init__.
while not ready_event.wait(timeout=10):
    if not input_thread.is_alive():
        raise RuntimeError("Input socket thread died during startup")
    logger.info("Waiting for READY message from DP Coordinator...")
```
Before the busy loop starts, EngineCoreProc performs a two-phase handshake with the frontend over a dedicated ZMQ DEALER socket. Phase 1 (HELLO) announces the engine and receives the ZMQ addresses for its input, output, and coordinator sockets. Phase 2 (READY) reports back the KV-cache block count so the frontend knows the engine has finished profiling and is ready for traffic:
```python
@staticmethod
def startup_handshake(
    handshake_socket: zmq.Socket,
    local_client: bool,
    headless: bool,
    parallel_config: ParallelConfig | None = None,
) -> EngineZmqAddresses:
    # Phase 1: HELLO. Advertise this engine's identity + deployment
    # flags so the frontend can route to us and pick DP mode.
    handshake_socket.send(
        msgspec.msgpack.encode({
            "status": "HELLO",
            "local": local_client,  # colocated with front-end?
            "headless": headless,   # no API server in this proc?
        })
    )
    # Receive initialization message.
    logger.debug("Waiting for init message from front-end.")
    # Hard timeout protects us from a silently-dead parent.
    if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
        raise RuntimeError(
            "Did not receive response from front-end "
            f"process within {HANDSHAKE_TIMEOUT_MINS} "
            f"minutes"
        )
    init_bytes = handshake_socket.recv()
    init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
        init_bytes, type=EngineHandshakeMetadata
    )
    logger.debug("Received init message: %s", init_message)
    # Frontend can push parallel-config overrides (e.g. the actual
    # DP mesh size after launcher resolves auto-discovery).
    if parallel_config is not None:
        for key, value in init_message.parallel_config.items():
            setattr(parallel_config, key, value)
    # Returns the EngineZmqAddresses struct: inputs, outputs,
    # coordinator_input, coordinator_output, frontend_stats_publish.
    return init_message.addresses


# Phase 2 (inside _perform_handshake around line 970, after profiling):
# Report KV cache block count back to the frontend so the client
# knows the engine finished warmup and is ready for traffic.
handshake_socket.send(msgspec.msgpack.encode({
    "status": "READY",
    "local": local_client,
    "headless": headless,
    "num_gpu_blocks": vllm_config.cache_config.num_gpu_blocks,
    "dp_stats_address": dp_stats_address,
}))
```
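The two-phase exchange can be modeled over a stdlib socket.socketpair(). This sketch assumes length-prefixed JSON frames as a stand-in for ZMQ DEALER messages and msgspec, and sock.settimeout() as a stand-in for poll(timeout=...). engine_side(), frontend_side(), and the helper functions are hypothetical.

```python
import json
import socket
import struct
import threading


def send_msg(sock, obj):
    data = json.dumps(obj).encode()
    sock.sendall(struct.pack("!I", len(data)) + data)   # length-prefixed frame


def _recv_exact(sock, n):
    buf = b""
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise ConnectionError("peer closed during handshake")
        buf += chunk
    return buf


def recv_msg(sock):
    (size,) = struct.unpack("!I", _recv_exact(sock, 4))
    return json.loads(_recv_exact(sock, size))


def engine_side(sock, num_gpu_blocks, timeout_s=5.0):
    send_msg(sock, {"status": "HELLO", "local": True, "headless": False})
    sock.settimeout(timeout_s)        # a silent frontend raises a timeout error
    init = recv_msg(sock)
    # ... set up ZMQ sockets from init["addresses"], run profiling ...
    send_msg(sock, {"status": "READY", "num_gpu_blocks": num_gpu_blocks})
    return init["addresses"]


def frontend_side(sock, addresses):
    hello = recv_msg(sock)
    assert hello["status"] == "HELLO"
    send_msg(sock, {"addresses": addresses, "parallel_config": {}})
    ready = recv_msg(sock)
    assert ready["status"] == "READY"
    return ready["num_gpu_blocks"]   # frontend now knows the KV-cache size
```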
If EngineCore hits a fatal error after startup, run_engine_core() calls _send_engine_dead(), which enqueues the sentinel ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD" on the output queue. The Output I/O thread sends the sentinel on every output socket (PUSH, linger=4s) and exits. The client's output reader checks for this exact sentinel in validate_alive() and surfaces it as EngineDeadError, so the API server fails fast instead of hanging.
The MPClient base class sets up two ZMQ sockets for bidirectional communication. The client binds both: a ROUTER for input and a PULL for output. Each engine connects to them with a DEALER and a PUSH socket, respectively:
```python
# ZMQ context: 2 I/O threads is enough because most work
# happens in the per-client asyncio loop, not in ZMQ threads.
sync_ctx = zmq.Context(io_threads=2)
# For AsyncLLM we wrap the sync context so sockets return
# awaitables (via zmq.asyncio); SyncMPClient uses sync_ctx directly.
self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx
# Input socket: ROUTER pattern, BIND.
# - Client binds; each engine's DEALER socket connects here.
# - First frame of every outgoing multipart message is the 2-byte
#   engine identity (dp_rank in little-endian), which ROUTER
#   consumes to pick the destination engine.
(input_address, ) = addresses.inputs
self.input_socket = self.resources.input_socket = make_zmq_socket(
    self.ctx, input_address, zmq.ROUTER, bind=True)
# Output socket: PULL pattern, BIND (make_zmq_socket's default for PULL).
# - Client binds; each engine's PUSH socket connects here.
# - ZMQ fair-queues across all connected PUSH peers, so a
#   single PULL socket is enough even with DP>1.
(output_address, ) = addresses.outputs
self.resources.output_socket = make_zmq_socket(
    self.ctx, output_address, zmq.PULL)
# Engine identities this client talks to. For DP>1 with
# internal LB, this is [0, 1, ..., dp_size-1] in little-endian 2-byte form.
self.core_engines: list[EngineIdentity] = [
    rank.to_bytes(2, "little")
    for rank in self.engine_ranks_managed
]


def _send_input(self, request_type, request, engine=None):
    if engine is None:
        engine = self.core_engine
    message = (request_type.value, *self.encoder.encode(request))
    return self._send_input_message(message, engine, request)


def _send_input_message(self, message, engine, objects):
    self.ensure_alive()
    self.free_pending_messages()
    msg = (engine,) + message  # (Identity, RequestType, SerializedData...)
    if not objects or len(msg) <= 3:
        # No auxiliary buffers => no tensor backing buffers in request.
        return self.input_socket.send_multipart(msg, copy=False)
    # Track message for zero-copy buffer lifetime management.
    future = self.input_socket.send_multipart(msg, copy=False, track=True)
    ...
```
The on-wire message shape is (engine_identity, request_type_bytes, msgpack_payload, [aux_buffer_1, aux_buffer_2, ...]). The ROUTER consumes the 2-byte little-endian engine identity to pick the destination DEALER, so the engine receives only (request_type_bytes, msgpack_payload, aux...). The aux buffers after the payload carry large tensor memory without copying.
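Here is an in-memory sketch of the routing rule. Identities are built exactly as in core_engines (rank.to_bytes(2, "little")), and ToyRouter mimics how a ROUTER socket consumes the first frame to pick a peer. The router class and the ADD type byte are hypothetical. Real ZMQ silently drops messages for unknown identities by default (unless ROUTER_MANDATORY is set), whereas this toy raises KeyError.

```python
def engine_identity(rank: int) -> bytes:
    # 2-byte little-endian identity, as in MPClient.core_engines.
    return rank.to_bytes(2, "little")


class ToyRouter:
    """Hypothetical in-memory model of ROUTER -> DEALER routing."""

    def __init__(self, ranks):
        self.inboxes = {engine_identity(r): [] for r in ranks}

    def send_multipart(self, frames):
        identity, *rest = frames
        # ROUTER consumes the identity frame; the DEALER never sees it.
        self.inboxes[identity].append(rest)


router = ToyRouter(range(2))
ADD = b"\x00"                          # hypothetical request-type byte
payload = b"<msgpack bytes>"           # placeholder for the encoded request
aux = memoryview(b"\x01\x02\x03\x04")  # a tensor's backing buffer
router.send_multipart((engine_identity(1), ADD, payload, aux))
```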
```python
async def process_outputs_socket():
    try:
        while True:
            frames = await output_socket.recv_multipart(copy=False)
            resources.validate_alive(frames)
            outputs: EngineCoreOutputs = decoder.decode(frames)
            if outputs.utility_output:
                _process_utility_output(outputs.utility_output, utility_results)
                continue
            if outputs.outputs or outputs.scheduler_stats:
                outputs_queue.put_nowait(outputs)
    except Exception as e:
        outputs_queue.put_nowait(e)
```
This runs as an asyncio.Task that continuously reads from the PULL socket. Decoded EngineCoreOutputs go to outputs_queue for get_output_async() to consume; utility replies are routed to their pending futures in utility_results; any exception from the socket (including the ENGINE_CORE_DEAD sentinel) is pushed into outputs_queue so every waiter fails fast.
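The reader's control flow, including the sentinel check, can be exercised with asyncio alone. In this sketch an asyncio.Queue of frame lists stands in for the PULL socket, and plain dicts stand in for decoded EngineCoreOutputs. validate_alive() here is a simplified, hypothetical version of the resources method, and EngineDeadError is redefined locally.

```python
import asyncio

ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"


class EngineDeadError(Exception):
    pass


def validate_alive(frames):
    # The engine sends the sentinel as a single-frame message.
    if len(frames) == 1 and frames[0] == ENGINE_CORE_DEAD:
        raise EngineDeadError("EngineCore died unexpectedly")


async def process_outputs_socket(socket_q, outputs_queue, utility_results):
    """socket_q (an asyncio.Queue of frame lists) stands in for the PULL socket."""
    try:
        while True:
            frames = await socket_q.get()
            validate_alive(frames)
            msg = frames[0]                  # "decoded" message: a dict here
            if "utility_output" in msg:
                call_id, result = msg["utility_output"]
                utility_results.pop(call_id).set_result(result)
                continue
            if msg.get("outputs"):
                outputs_queue.put_nowait(msg)
    except Exception as e:
        outputs_queue.put_nowait(e)          # every waiter sees the failure
```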
```python
# Two decoders. The strongly-typed one populates EngineCoreRequest
# in C; both wire up oob_tensor_provider so multimodal tensors that
# travel via shared memory can be reconstructed (see §5).
add_request_decoder = MsgpackDecoder(
    EngineCoreRequest, oob_tensor_provider=self.tensor_ipc_receiver
)
generic_decoder = MsgpackDecoder(oob_tensor_provider=self.tensor_ipc_receiver)

while True:
    for input_socket, _ in poller.poll():
        # Frame 0: request type byte string (ADD / ABORT / UTILITY / ...).
        # Frames 1+: msgpack payload + optional zero-copy tensor buffers.
        type_frame, *data_frames = input_socket.recv_multipart(copy=False)
        # Coordinator may send b"READY" on the XSUB channel to wake new
        # engines; ignore it -- nothing to enqueue.
        if type_frame.buffer == b"READY":
            assert input_socket == coord_socket
            continue
        request_type = EngineCoreRequestType(bytes(type_frame.buffer))
        # Deserialize the request data.
        request: Any
        if request_type == EngineCoreRequestType.ADD:
            req: EngineCoreRequest = add_request_decoder.decode(data_frames)
            try:
                # preprocess_add_request() does MM cache lookup, KV block
                # hash computation, and grammar init -- all CPU work that
                # overlaps with the main thread's forward pass.
                request = self.preprocess_add_request(req)
            except Exception:
                # Bad request: surface the error to the client immediately
                # and skip enqueue so the engine never sees it.
                self._handle_request_preproc_error(req)
                continue
        else:
            request = generic_decoder.decode(data_frames)
            if request_type == EngineCoreRequestType.ABORT:
                # Aborts go to BOTH queues: the eager aborts_queue lets
                # step() drain in-flight aborts mid-forward-pass; the
                # input_queue copy preserves ordering for FIFO dispatch.
                # Idempotent on the scheduler side, so dual-delivery is safe.
                self.aborts_queue.put_nowait(request)
        # Push to input queue for the main loop to pick up.
        self.input_queue.put_nowait((request_type, request))
```
EngineCoreRequestType.ABORT messages are written to both aborts_queue (eager mid-step drain) and input_queue (FIFO ordering); the scheduler-side abort is idempotent so dual delivery is safe.
Tensor serialization uses three tiers. The inline cutoff is set by VLLM_MSGPACK_ZERO_COPY_THRESHOLD, and the OOB tier is active only when an oob_tensor_consumer is wired up (the multimodal fast path).
vLLM uses msgspec (a high-performance msgpack library) with custom extension types for tensors, numpy arrays, slices, and multimodal kwargs. The implementation is in vllm/v1/serial_utils.py. The oob_tensor_consumer / oob_tensor_provider pair is what enables true OOB IPC for CUDA / shared-memory tensors — see lines 152 and 327 of the source.
```python
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.

    By default, arrays below 256B are serialized inline. Larger will get
    sent via dedicated messages. Note that this is a per-tensor limit.

    When a ``oob_tensor_consumer`` is provided, tensors (CUDA and CPU) will
    be offered to it for out-of-band handling.
    """

    def __init__(
        self,
        size_threshold: int | None = None,
        oob_tensor_consumer: OOBTensorConsumer | None = None,
    ):
        if size_threshold is None:
            # VLLM_MSGPACK_ZERO_COPY_THRESHOLD, default 256.
            size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # Local stash so enc_hook (no custom-data param) can find the
        # aux_buffers list during a single encode() call.
        self.aux_buffers: list[bytestr] | None = None
        self.size_threshold = size_threshold
        self.oob_tensor_consumer = oob_tensor_consumer

    def encode(self, obj: Any) -> Sequence[bytestr]:
        try:
            # OOB consumer needs per-message state (e.g. shared-mem cursor).
            if self.oob_tensor_consumer is not None:
                self.oob_tensor_consumer.new_message()
            # bufs[0] reserved for main msgpack frame; bufs[1:] grow as
            # enc_hook appends tensor backing buffers (aux frames).
            self.aux_buffers = bufs = [b""]
            bufs[0] = self.encoder.encode(obj)
            return bufs
        finally:
            self.aux_buffers = None

    def _encode_tensor(self, obj: torch.Tensor):
        oob_consumer = self.oob_tensor_consumer
        if obj.nbytes < self.size_threshold and obj.is_cpu:
            # Tier 1: inline. Ext type carries the bytes inside the main
            # msgpack frame; decoder returns a zero-copy memoryview.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, tensor_data(obj))
        elif oob_consumer is not None and (data := oob_consumer(obj)) is not None:
            # Tier 2: out-of-band IPC. Consumer returns a dict describing
            # where the tensor lives (e.g. shared-mem segment id, offset).
            assert isinstance(data, dict)
        else:
            # Tier 3: zero-copy aux frame. Encode integer index;
            # ZMQ send_multipart will transmit tensor_data(obj) as a
            # separate frame without copying.
            assert self.aux_buffers is not None
            data = len(self.aux_buffers)
            self.aux_buffers.append(tensor_data(obj))
        dtype = str(obj.dtype).removeprefix("torch.")
        return dtype, obj.shape, data
```
| Tier | Condition | Method | Performance |
|---|---|---|---|
| Inline | obj.is_cpu and nbytes < size_threshold (default 256 B) | Embed as msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, tensor_data(obj)) inside the main frame | Single ZMQ frame; the decoder gets a memoryview into the frame and clones it into a tensor that owns its memory |
| OOB IPC | oob_tensor_consumer is wired up and accepts the tensor (returns a dict) | Tensor stays in shared memory / CUDA IPC; the main frame carries only the metadata dict | Best for large multimodal tensors (vision encoder embeddings, video frames) |
| Zero-copy aux frame | Fall-through: large CPU tensors, or any tensor the OOB consumer rejects | Main frame stores an integer index; tensor bytes go out as a separate ZMQ frame (send_multipart(copy=False)) | No copy on send; the receiver either views the frame in place (share_mem) or copies it, into pinned memory when pin_tensors is set |
```python
class MsgpackDecoder:
    def __init__(
        self,
        t: Any | None = None,
        oob_tensor_provider: OOBTensorProvider | None = None,
    ):
        # Provider is the symmetric peer of the encoder's consumer:
        # it knows how to look a tensor up in shared memory / CUDA IPC.
        self.oob_tensor_provider = oob_tensor_provider
        ...

    def decode(self, bufs: bytestr | Sequence[bytestr]) -> Any:
        if isinstance(bufs, bytestr):
            # Single-frame message (no aux tensor frames).
            return self.decoder.decode(bufs)
        # Multi-frame: aux_buffers stashes the extra frames so the
        # dec_hook can pull them by index when decoding tensors.
        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            self.aux_buffers = ()

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        dtype, shape, data = arr
        if isinstance(data, dict):
            # Tier 2 path: tensor lives out-of-band; ask provider to fetch it.
            assert self.oob_tensor_provider, (
                "Received OOB tensor but tensor provider is not set")
            return self.oob_tensor_provider(dtype, shape, data)
        is_aux = isinstance(data, int)
        buffer = self.aux_buffers[data] if is_aux else data
        buffer = buffer if isinstance(buffer, memoryview) else memoryview(buffer)
        torch_dtype = getattr(torch, dtype)
        if not buffer.nbytes:
            # torch.frombuffer refuses empty buffers; produce empty tensor.
            assert 0 in shape
            return torch.empty(shape, dtype=torch_dtype)
        # Zero-copy view of the bytes as uint8.
        arr = torch.frombuffer(buffer, dtype=torch.uint8)
        # Inline tensors: clone() so the returned tensor owns its memory
        # (decoupled from the message buffer that may be recycled).
        # Aux tensors: keep the zero-copy view when share_mem is set;
        # otherwise pin_memory() (if pin_tensors, for fast H2D) or clone().
        if not is_aux:
            arr = arr.clone()
        elif not self.share_mem:
            arr = arr.pin_memory() if self.pin_tensors else arr.clone()
        return arr.view(torch_dtype).view(shape)
```
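To make the three tiers concrete without torch or msgspec, here is a round-trip sketch over raw byte buffers. It assumes JSON for the main frame (with hex strings standing in for the msgpack Ext payload), a dict as fake shared memory for the OOB tier, and the documented 256-byte default threshold. ToyEncoder and toy_decode() are hypothetical; a "tensor" here is a (dtype, shape, bytes, is_cpu) tuple.

```python
import json

SIZE_THRESHOLD = 256   # mirrors the documented default


class ToyEncoder:
    """Hypothetical 3-tier encoder over raw byte buffers (not vLLM's API)."""

    def __init__(self, size_threshold=SIZE_THRESHOLD, oob_consumer=None):
        self.size_threshold = size_threshold
        self.oob_consumer = oob_consumer

    def encode(self, tensors):
        frames = [b""]                    # frames[0] becomes the main frame
        descs = [self._encode_one(t, frames) for t in tensors]
        frames[0] = json.dumps(descs).encode()
        return frames

    def _encode_one(self, tensor, frames):
        dtype, shape, buf, is_cpu = tensor
        if len(buf) < self.size_threshold and is_cpu:
            data = {"inline": buf.hex()}          # tier 1: inside main frame
        elif self.oob_consumer and (handle := self.oob_consumer(buf)) is not None:
            data = {"oob": handle}                # tier 2: only a handle travels
        else:
            data = {"aux": len(frames)}           # tier 3: index of a new frame
            frames.append(buf)
        return [dtype, shape, data]


def toy_decode(frames, oob_provider=None):
    out = []
    for dtype, shape, data in json.loads(frames[0]):
        if "inline" in data:
            buf = bytes.fromhex(data["inline"])
        elif "oob" in data:
            buf = oob_provider(data["oob"])
        else:
            buf = memoryview(frames[data["aux"]])  # view into the frame, no copy
        out.append((dtype, tuple(shape), buf))
    return out
```

In the check below, the 4-byte tensor travels inside the main frame, the 300-byte one becomes aux frame 1, and the 2 KiB one never leaves "shared memory"; only its handle is serialized.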
The Output I/O thread reuses bytearray buffers to keep msgpack encoding off the allocator hot path. It only reclaims a buffer back into the pool once ZMQ's MessageTracker reports that the socket has finished with it, so the zero-copy send is safe:
```python
encoder = MsgpackEncoder()
# Pool of bytearrays we can hand back to MsgpackEncoder.encode_into()
# instead of allocating a fresh one for every message.
reuse_buffers: list[bytearray] = []
# Outstanding sends: (tracker, retained_object, buffer). We hold
# onto the retained object + buffer until tracker.done says ZMQ
# is finished reading from them (zero-copy sends read lazily).
pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()
max_reuse_bufs = len(sockets) + 1

while True:
    output = self.output_queue.get()
    # Sentinel path: engine crashed or is shutting down.
    if output == EngineCoreProc.ENGINE_CORE_DEAD:
        for socket in sockets:
            socket.send(output)
        break
    client_index, outputs = output
    # Tag with our engine index so the client can route aborts
    # and stats back to the right engine (useful with DP).
    outputs.engine_index = engine_index
    # Walk the pending deque from oldest side and reclaim any
    # bytearrays ZMQ has finished with.
    while pending and pending[-1][0].done:
        reuse_buffers.append(pending.pop()[2])
    # encode_into() writes into `buffer` in place and returns the
    # list of frames: [main_buffer, aux_tensor_bufs...].
    buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
    buffers = encoder.encode_into(outputs, buffer)
    # copy=False + track=True: zero-copy send with a tracker so we
    # know when ZMQ is done, without blocking this loop on the send.
    tracker = sockets[client_index].send_multipart(buffers,
                                                   copy=False,
                                                   track=True)
    if not tracker.done:
        # Keep buffer + (optional) outputs ref alive until send completes.
        ref = outputs if len(buffers) > 1 else None
        pending.appendleft((tracker, ref, buffer))
    elif len(reuse_buffers) < max_reuse_bufs:
        reuse_buffers.append(buffer)
```
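The reclaim rule can be isolated from ZMQ. This sketch assumes a FakeTracker whose done attribute is flipped by hand (the real zmq.MessageTracker exposes done as a property) and a hypothetical BufferPool. Like the loop above, it reclaims from the oldest end of pending and caps the pool at max_reuse_bufs; the buf.clear() reset is an added assumption.

```python
from collections import deque


class FakeTracker:
    """Stand-in for zmq.MessageTracker: done flips once 'ZMQ' is finished."""

    def __init__(self):
        self.done = False


class BufferPool:
    """Hypothetical reuse pool mirroring the output loop (not vLLM's API)."""

    def __init__(self, max_reuse_bufs: int):
        self.reuse_buffers: list[bytearray] = []
        self.pending: deque = deque()     # newest on the left, oldest on the right
        self.max_reuse_bufs = max_reuse_bufs

    def acquire(self) -> bytearray:
        # Reclaim buffers whose sends have completed, oldest first.
        while self.pending and self.pending[-1][0].done:
            self.reuse_buffers.append(self.pending.pop()[1])
        buf = self.reuse_buffers.pop() if self.reuse_buffers else bytearray()
        buf.clear()                       # assumption: reset before reuse
        return buf

    def release(self, buf: bytearray, tracker: FakeTracker):
        if not tracker.done:
            self.pending.appendleft((tracker, buf))   # ZMQ may still read it
        elif len(self.reuse_buffers) < self.max_reuse_bufs:
            self.reuse_buffers.append(buf)
```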
@staticmethod
def make_client(
    multiprocess_mode: bool,
    asyncio_mode: bool,
    vllm_config: VllmConfig,
    executor_class: type[Executor],
    log_stats: bool,
) -> "EngineCoreClient":
    # asyncio_mode without multiproc is not supported -- there is
    # no InprocClient variant that drives the engine from an
    # asyncio event loop in the same process.
    if asyncio_mode and not multiprocess_mode:
        raise NotImplementedError("Running EngineCore in asyncio without "
                                  "multiprocessing is not currently supported.")
    # Main async path -- used by AsyncLLM / API server.
    if multiprocess_mode and asyncio_mode:
        return EngineCoreClient.make_async_mp_client(
            vllm_config, executor_class, log_stats)
    # Sync multi-proc -- used by the offline LLM() API when the user
    # doesn't want an asyncio loop.
    if multiprocess_mode and not asyncio_mode:
        return SyncMPClient(vllm_config, executor_class, log_stats)
    # Everything in one process -- used for LLMEngine V0-compat and tests.
    return InprocClient(vllm_config, executor_class, log_stats)
@staticmethod
def make_async_mp_client(
    vllm_config: VllmConfig,
    executor_class: type[Executor],
    log_stats: bool,
    client_addresses: Optional[dict[str, str]] = None,
    client_count: int = 1,
    client_index: int = 0,
) -> "MPClient":
    parallel_config = vllm_config.parallel_config
    client_args = (vllm_config, executor_class, log_stats,
                   client_addresses, client_count, client_index)
    if parallel_config.data_parallel_size > 1:
        if parallel_config.data_parallel_external_lb:
            # One client per DP rank; an external LB (e.g. K8s Service,
            # Nginx) fans requests out across API server replicas.
            return DPAsyncMPClient(*client_args)
        # One client talks to all DP ranks and load-balances internally
        # using the score = waiting*4 + running heuristic (see §6.4).
        return DPLBAsyncMPClient(*client_args)
    # DP=1: plain async multi-proc client, no LB logic needed.
    return AsyncMPClient(*client_args)
| Client | Process Model | Async | DP Support | Use Case |
|---|---|---|---|---|
| InprocClient | Single process | No | No | V0-compatible LLMEngine |
| SyncMPClient | Multi-process | No | Basic | Sync LLM with background EngineCore |
| AsyncMPClient | Multi-process | Yes | No (DP=1) | AsyncLLM / API server |
| DPAsyncMPClient | Multi-process | Yes | External LB | DP with external load balancer |
| DPLBAsyncMPClient | Multi-process | Yes | Internal LB | DP with built-in load balancing |
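The dispatch in make_client() and make_async_mp_client() reduces to a small decision function over four inputs. The sketch below is hypothetical: select_client is an illustrative name, not vLLM API, and it returns class names as strings instead of constructing clients:

```python
def select_client(multiprocess_mode: bool, asyncio_mode: bool,
                  data_parallel_size: int = 1,
                  external_lb: bool = False) -> str:
    """Hypothetical mirror of the make_client()/make_async_mp_client() dispatch."""
    if asyncio_mode and not multiprocess_mode:
        # No in-process client can be driven from an asyncio loop.
        raise NotImplementedError("asyncio mode requires multiprocessing")
    if not multiprocess_mode:
        return "InprocClient"
    if not asyncio_mode:
        return "SyncMPClient"
    if data_parallel_size > 1:
        # External LB: one client per rank. Internal LB: one client, many ranks.
        return "DPAsyncMPClient" if external_lb else "DPLBAsyncMPClient"
    return "AsyncMPClient"
```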
InprocClient is the simplest variant: no IPC, no threads, no msgpack. The EngineCore instance runs in the same Python process as LLMEngine, and get_output() drives it by calling step_fn() directly. It is used by LLMEngine.step() in offline scripts and tests.
class InprocClient(EngineCoreClient):
    """
    InprocClient: client for in-process EngineCore. Intended
    for use in LLMEngine for V0-style add_request() and step()
    EngineCore setup in this process (no busy loop).
    * pushes EngineCoreRequest directly into the EngineCore
    * pulls EngineCoreOutputs by stepping the EngineCore
    """
    def __init__(self, *args, **kwargs):
        # Construct EngineCore in-process: model load, KV profiling,
        # scheduler setup, all synchronous.
        self.engine_core = EngineCore(*args, **kwargs)
    def get_output(self) -> EngineCoreOutputs:
        # step_fn is bound to either step() or step_with_batch_queue()
        # in EngineCore.__init__, depending on pipeline-parallel config.
        outputs, model_executed = self.engine_core.step_fn()
        # post_step() drives speculative-decoding draft-token handoff;
        # also where DP rank publishes request counts to coordinator.
        self.engine_core.post_step(model_executed=model_executed)
        # InprocClient has only client_index 0. If outputs is None or
        # has no key 0, hand back an empty struct so the caller's
        # downstream loop stays uniform.
        return outputs and outputs.get(0) or EngineCoreOutputs()
    def add_request(self, request: EngineCoreRequest) -> None:
        # Preprocess (MM cache lookup, KV block hashing, grammar init)
        # and push straight into the scheduler -- no IPC boundary.
        req, request_wave = self.engine_core.preprocess_add_request(request)
        self.engine_core.add_request(req, request_wave)
    def abort_requests(self, request_ids: list[str]) -> None:
        if len(request_ids) > 0:
            self.engine_core.abort_requests(request_ids)
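The in-process contract can be seen with a toy engine: add_request() and get_output() are plain method calls, and get_output() falls back to an empty struct when a step yields nothing for client 0. ToyEngineCore, ToyOutputs and ToyInprocClient are hypothetical stand-ins, and "one token per running request" is an assumption made only to give each step some output:

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class ToyOutputs:
    outputs: list[str] = field(default_factory=list)


class ToyEngineCore:
    """Hypothetical engine: each step emits one token per running request."""

    def __init__(self) -> None:
        self.running: list[str] = []

    def add_request(self, request_id: str) -> None:
        self.running.append(request_id)

    def step(self) -> dict[int, ToyOutputs] | None:
        if not self.running:
            return None  # nothing scheduled this step
        return {0: ToyOutputs([f"{rid}:tok" for rid in self.running])}


class ToyInprocClient:
    def __init__(self) -> None:
        self.engine_core = ToyEngineCore()  # same process, no IPC

    def add_request(self, request_id: str) -> None:
        self.engine_core.add_request(request_id)

    def get_output(self) -> ToyOutputs:
        outputs = self.engine_core.step()
        # Same idiom as InprocClient.get_output(): empty struct when the
        # step produced nothing for client index 0.
        return outputs and outputs.get(0) or ToyOutputs()
```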
DPLBAsyncMPClient, the most sophisticated client variant, routes each request to one of the DP engine processes using recent queue-length stats published by the DPCoordinator. Callers can pin a request by setting request.data_parallel_rank, and late-interaction pooling requests are pinned by their pooling params; everything else goes through a linear-scan argmin over a per-engine score:
def get_core_engine_for_request(
        self, request: EngineCoreRequest) -> EngineIdentity:
    # Engines are in rank order.
    if (eng_index := request.data_parallel_rank) is None and (
        # Late-interaction models (e.g. ColBERT-style pooling) pin
        # requests to a specific DP rank based on the pooling params,
        # so embeddings for one document always land on the same engine.
        eng_index := get_late_interaction_engine_index(
            request.pooling_params, len(self.core_engines)
        )
    ) is None:
        # Neither explicit pin nor late-interaction routing -- fall
        # through to load-based selection.
        current_counts = self.lb_engines
        # TODO use P2C alg for larger DP sizes
        num_engines = len(current_counts)
        min_score = sys.maxsize
        eng_index = 0
        for i in range(num_engines):
            # Rotate the scan start across clients so N simultaneous
            # requests with empty queues don't all pick engine 0.
            idx = (self.eng_start_index + i) % num_engines
            waiting, running = current_counts[idx]
            # 4x weight on waiting: a queued request costs more than one
            # already running, so engines with a backlog are avoided.
            score = waiting * 4 + running
            if score < min_score:
                min_score = score
                eng_index = idx
        # Optimistically bump our mirror so concurrent pick calls
        # between coordinator stats publishes (~100ms) don't all
        # pile onto the same "empty" engine.
        current_counts[eng_index][0] += self.client_count
    chosen_engine = self.core_engines[eng_index]
    # Record choice so abort_requests() knows where to route the ABORT.
    self.reqs_in_flight[request.request_id] = chosen_engine
    return chosen_engine
The per-engine load score is score = waiting * 4 + running. Weighting waiting requests 4x steers new requests toward engines with shorter queues. Between coordinator updates (~100ms), the client also increments the chosen engine's local waiting count so that a burst of requests is not routed to the same engine.
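A self-contained sketch of the selection loop, including the rotated scan start and the optimistic bump. ToyBalancer is a hypothetical class; counts holds [waiting, running] per engine, as in lb_engines:

```python
import sys


class ToyBalancer:
    """Hypothetical sketch of the waiting*4 + running engine selection."""

    def __init__(self, counts: list[list[int]], client_count: int = 1,
                 start_index: int = 0) -> None:
        self.counts = counts            # [[waiting, running], ...] per engine
        self.client_count = client_count
        self.start_index = start_index  # differs per client to spread ties

    def pick(self) -> int:
        n = len(self.counts)
        best, best_score = 0, sys.maxsize
        for i in range(n):
            idx = (self.start_index + i) % n
            waiting, running = self.counts[idx]
            score = waiting * 4 + running
            if score < best_score:  # strict '<': first in scan order wins ties
                best, best_score = idx, score
        # Optimistic local bump until the next coordinator publish.
        self.counts[best][0] += self.client_count
        return best
```

With counts [[0, 5], [1, 0], [0, 2]] the scores are 5, 4 and 2, so engine 2 is picked; its bumped waiting count raises its score to 6, and the next pick goes to engine 1.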
The OutputProcessor lives in the frontend process and converts raw EngineCoreOutput (token IDs) into user-facing RequestOutput objects with decoded text.
def process_outputs(self, engine_core_outputs, engine_core_timestamp=None, iteration_stats=None):
    """Process the EngineCoreOutputs:
    1) Compute stats for logging
    2) Detokenize
    3) Create and handle RequestOutput objects:
        - If queue exists (AsyncLLM): put into queue for generate()
        - If no queue (LLMEngine): return list of RequestOutputs
    NOTE: vLLM V1 minimizes the number of python loops over the full
    batch. This is the ONLY function that loops over EngineCoreOutputs."""
    request_outputs = []
    reqs_to_abort = []
    for engine_core_output in engine_core_outputs:
        req_id = engine_core_output.request_id
        req_state = self.request_states.get(req_id)
        if req_state is None:
            continue  # Ignore output for already-aborted request
        # 1) Compute stats for this iteration.
        self._update_stats_from_output(req_state, engine_core_output, ...)
        new_token_ids = engine_core_output.new_token_ids
        pooling_output = engine_core_output.pooling_output
        finish_reason = engine_core_output.finish_reason
        stop_reason = engine_core_output.stop_reason
        # 2) Detokenize the token ids into text and perform stop checks.
        if pooling_output is None:
            stop_string = req_state.detokenizer.update(
                new_token_ids, finish_reason == FinishReason.STOP
            )
            if stop_string:
                finish_reason = FinishReason.STOP
                stop_reason = stop_string
            # 3) Compute sample and prompt logprobs if required.
            req_state.logprobs_processor.update_from_output(engine_core_output)
        # 4) Create and handle RequestOutput objects.
        if request_output := req_state.make_request_output(...):
            if req_state.queue is not None:
                req_state.queue.put(request_output)  # AsyncLLM path
            else:
                request_outputs.append(request_output)  # LLMEngine path
        # 5) Free completed requests.
        if finish_reason is not None:
            self.request_states.pop(req_id)
            if not engine_core_output.finished:
                # Detokenizer detected stop string -- abort in EngineCore
                reqs_to_abort.append(req_id)
    return OutputProcessorOutput(
        request_outputs=request_outputs,
        reqs_to_abort=reqs_to_abort,
    )
process_outputs() is the only function that loops over the full batch of EngineCoreOutputs. This minimizes Python-level overhead on the hot path: all per-request operations (detokenization, logprobs, stats, output creation) happen in a single pass.
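The single-pass structure, including the "frontend found a stop string, so tell the engine to abort" feedback, fits in a short sketch. Everything here is hypothetical: text chunks stand in for token IDs, and truncating at the stop string is a simplification of what vLLM's detokenizer does:

```python
from dataclasses import dataclass, field


@dataclass
class ToyEngineOutput:
    request_id: str
    new_text: str           # stands in for decoded new_token_ids
    finished: bool = False  # engine-side finish (e.g. max_tokens)


@dataclass
class ToyRequestState:
    stop: list[str]
    text: str = ""


@dataclass
class ToyProcessed:
    request_outputs: list[tuple[str, str, bool]] = field(default_factory=list)
    reqs_to_abort: list[str] = field(default_factory=list)


def process_outputs(states: dict[str, ToyRequestState],
                    outputs: list[ToyEngineOutput]) -> ToyProcessed:
    result = ToyProcessed()
    for out in outputs:  # the only loop over the batch
        state = states.get(out.request_id)
        if state is None:
            continue  # already aborted
        state.text += out.new_text
        finished = out.finished
        for s in state.stop:  # stop-string check
            if s in state.text:
                state.text = state.text[: state.text.index(s)]
                finished = True
                break
        result.request_outputs.append((out.request_id, state.text, finished))
        if finished:
            del states[out.request_id]  # free frontend state
            if not out.finished:
                # Frontend-only finish: the engine must be told to abort.
                result.reqs_to_abort.append(out.request_id)
    return result
```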
For AsyncLLM, each request has a RequestOutputCollector that signals its generate() consumer through an asyncio.Event. The producer never blocks: if the OutputProcessor produces faster than generate() consumes, pending outputs are merged into a single slot instead of queued, so memory per request stays bounded:
class RequestOutputCollector:
    """
    Collects streamed RequestOutputs per individual request,
    for hand-off to the consuming asyncio generate task.
    When streaming deltas, RequestOutputs are merged if the
    producer gets ahead of the consumer.
    """
    def __init__(self, output_kind: RequestOutputKind, request_id: str):
        # DELTA mode streams per-step diffs that can be merged;
        # FINAL_ONLY and CUMULATIVE always overwrite.
        self.aggregate = output_kind == RequestOutputKind.DELTA
        # request_id kept so debug logs / cancel paths can identify
        # the owner without walking up to the Processor.
        self.request_id = request_id
        self.output: RequestOutput | PoolingRequestOutput | Exception | None = None
        self.ready = asyncio.Event()
        # Reserved for input-stream driven flows (e.g. audio/video
        # prompt streaming). Populated by AsyncLLM when applicable.
        self._input_stream_task: asyncio.Task | None = None
    def put(self, output) -> None:
        """Non-blocking put operation."""
        if self.output is None or isinstance(output, Exception):
            # First delta after a get(), or an error that trumps any
            # pending data. Wake up the waiting consumer immediately.
            self.output = output
            self.ready.set()
        elif isinstance(self.output, RequestOutput) and isinstance(
            output, RequestOutput
        ):
            # Consumer hasn't drained yet -- merge into pending output so
            # memory stays bounded and the consumer sees a single coalesced
            # chunk on its next await. add() handles per-request-index
            # merging when n>1 so parallel samples don't overwrite.
            self.output.add(output, aggregate=self.aggregate)
        elif isinstance(self.output, PoolingRequestOutput) and isinstance(
            output, PoolingRequestOutput
        ):
            # Pooling outputs are point-in-time embeddings -- latest wins.
            self.output = output
    async def get(self) -> RequestOutput | PoolingRequestOutput:
        """Get operation blocks on put event."""
        # Sleep efficiently until put() fires the event; no busy wait.
        while (output := self.output) is None:
            await self.ready.wait()
        self.output = None
        self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output
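The one-slot, merge-when-behind behavior is easy to check in isolation. This sketch is hypothetical (ToyCollector is not vLLM code): plain strings stand in for DELTA RequestOutputs, and string concatenation stands in for RequestOutput.add():

```python
from __future__ import annotations

import asyncio


class ToyCollector:
    """Hypothetical sketch: one pending slot, string deltas merged in place."""

    def __init__(self) -> None:
        self.output: str | Exception | None = None
        self.ready = asyncio.Event()

    def put(self, output: str | Exception) -> None:
        if self.output is None or isinstance(output, Exception):
            self.output = output
            self.ready.set()
        elif isinstance(self.output, str):
            self.output += output  # consumer is behind: coalesce
        # If an exception is already pending, later outputs are dropped.

    async def get(self) -> str:
        while (output := self.output) is None:
            await self.ready.wait()
        self.output = None
        self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output


async def demo() -> tuple[str, str]:
    c = ToyCollector()  # created inside the running loop
    c.put("Hel")
    c.put("lo")  # consumer has not run yet, so this merges
    first = await c.get()
    waiter = asyncio.create_task(c.get())
    await asyncio.sleep(0)  # let the waiter block on the event
    c.put(" world")
    second = await waiter
    return first, second
```

Because put() never awaits, the OutputProcessor loop is never slowed down by a slow client; the cost of a slow consumer is only that it receives larger, coalesced chunks.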
AsyncLLM lives in the frontend process and owns three collaborators: a Processor (tokenization + MM input handling), an OutputProcessor (detokenization + streaming), and an engine_core client (ZMQ bridge to the background EngineCore process):
# Tokenizer: may be None for text-free pooling models.
self.tokenizer = init_tokenizer_from_configs(...)
# Processor converts raw user inputs (text, chat messages,
# multimodal payloads) into EngineCoreRequest objects, doing
# tokenization + MM preprocessing synchronously here.
self.processor = Processor(
vllm_config=vllm_config,
tokenizer=self.tokenizer,
mm_registry=mm_registry,
)
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
# Receives token-id outputs over ZMQ and streams them to users.
self.output_processor = OutputProcessor(self.tokenizer,
log_stats=self.log_stats)
# EngineCore client: spawns the background engine process (via
# launch_core_engines), sets up ROUTER/PULL sockets, and returns
# an async client whose methods await on ZMQ round-trips.
self.engine_core = EngineCoreClient.make_async_mp_client(
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=self.log_stats,
client_addresses=client_addresses,
client_count=client_count,
client_index=client_index,
)
The step() method below belongs to the synchronous LLMEngine (vllm/v1/engine/llm_engine.py), which drives the same collaborators inline, one call per step:
def step(self) -> list[RequestOutput | PoolingRequestOutput]:
    # Dummy-batch bypass: runs an empty forward pass to keep CUDA
    # graphs warm and to satisfy DP sync barriers when this rank
    # has nothing scheduled.
    if self.should_execute_dummy_batch:
        self.should_execute_dummy_batch = False
        self.engine_core.execute_dummy_batch()
        return []
    # 1) Pull EngineCoreOutputs from the EngineCore: InprocClient steps
    #    the engine directly, SyncMPClient blocks on a queue fed from the
    #    ZMQ output socket. The call is wrapped in
    #    record_function_or_nullcontext to get a PyTorch Profiler span.
    with record_function_or_nullcontext("llm_engine step: get_output"):
        outputs = self.engine_core.get_output()
    # 2) Detokenize + stop-string check + build RequestOutput.
    #    Also updates scheduler-stats mirrors used for logging.
    with record_function_or_nullcontext("llm_engine step: process_outputs"):
        iteration_stats = IterationStats() if self.log_stats else None
        processed_outputs = self.output_processor.process_outputs(
            outputs.outputs,
            engine_core_timestamp=outputs.timestamp,
            iteration_stats=iteration_stats,
        )
        self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
    # 3) Stop-string detection happens locally; engine needs to
    #    hear about it so it can free KV blocks on the next iter.
    with record_function_or_nullcontext("llm_engine step: abort_requests"):
        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
    # 4) Emit Prometheus / stdout metrics. Skip when outputs.outputs
    #    is empty; there is no point recording a zero-token iteration.
    with record_function_or_nullcontext("llm_engine step: record_stats"):
        if (self.logger_manager is not None
                and outputs.scheduler_stats is not None
                and len(outputs.outputs) > 0):
            self.logger_manager.record(
                scheduler_stats=outputs.scheduler_stats,
                iteration_stats=iteration_stats,
                mm_cache_stats=self.renderer.stat_mm_cache(),
            )
            self.do_log_stats_with_interval()
    return processed_outputs.request_outputs
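The shape of that loop (pull, process, feed aborts back, with a dummy-batch early exit) can be exercised end to end with fakes. FakeCore and ToyLLMEngine are hypothetical; a single stop string and text chunks per step are simplifying assumptions:

```python
from dataclasses import dataclass, field


@dataclass
class FakeCore:
    """Hypothetical engine core: emits one text chunk per running request."""
    running: dict[str, list[str]] = field(default_factory=dict)
    dummy_batches: int = 0

    def get_output(self) -> list[tuple[str, str]]:
        return [(rid, chunks.pop(0))
                for rid, chunks in self.running.items() if chunks]

    def abort_requests(self, request_ids: list[str]) -> None:
        for rid in request_ids:
            self.running.pop(rid, None)  # the real engine frees KV blocks

    def execute_dummy_batch(self) -> None:
        self.dummy_batches += 1


class ToyLLMEngine:
    def __init__(self, core: FakeCore, stop: str) -> None:
        self.core, self.stop = core, stop
        self.text: dict[str, str] = {}
        self.should_execute_dummy_batch = False

    def step(self) -> list[tuple[str, str]]:
        if self.should_execute_dummy_batch:
            self.should_execute_dummy_batch = False
            self.core.execute_dummy_batch()
            return []
        outputs = self.core.get_output()            # 1) pull
        results, to_abort = [], []
        for rid, chunk in outputs:                  # 2) process
            text = self.text.get(rid, "") + chunk
            if self.stop in text:
                text = text.split(self.stop, 1)[0]
                to_abort.append(rid)
            self.text[rid] = text
            results.append((rid, text))
        self.core.abort_requests(to_abort)          # 3) feed aborts back
        return results
```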
For data-parallel deployments (DP > 1), vLLM runs a DPCoordinator in its own process. It aggregates queue statistics from all engine processes and manages the "request wave" protocol that coordinates their running/paused state.
class DPCoordinator:
    """Coordinator process used for data-parallel deployments (DP>1).
    Intermediates between multiple DP engine rank processes and one or
    more front-end API server processes.
    * Collects stats from each DP engine (waiting / running queue lengths)
      and publishes them to all front-ends for LB decisions.
    * Tracks the current DP request-wave number. This transition is
      synchronized via the all-reduce in
      DPEngineCoreProc._has_global_unfinished_reqs().
    * Broadcasts START_DP_WAVE to engines when any one receives a request
      during the paused state.
    """
    def __init__(
        self, parallel_config: ParallelConfig,
        enable_wave_coordination: bool = True,
    ):
        dp_size = parallel_config.data_parallel_size
        assert dp_size > 1, "Coordinator only used for data parallel"
        host = parallel_config.data_parallel_master_ip
        # Colocation signal from parallel_config (internal LB means the
        # coordinator shares a host with front-ends, so IPC is fine).
        local_only = not parallel_config.local_engines_only
        local_only_eng = dp_size == parallel_config.data_parallel_size_local
        # Elastic-EP needs to scale from intra-node to inter-node, so force TCP.
        if parallel_config.enable_elastic_ep:
            local_only_eng = False
        def bind_address(local_only: bool) -> str:
            # IPC when colocated, TCP port-0 (kernel-assigned) otherwise.
            return (get_engine_client_zmq_addr(local_only=True, host=host)
                    if local_only else get_tcp_uri(host, 0))
        front_publish_address = bind_address(local_only)
        back_publish_address = bind_address(local_only_eng)
        back_output_address = bind_address(local_only_eng)
        # Duplex=False one-way pipe: child sends back the actual bound
        # addresses (after kernel port-0 assignment) so the parent knows
        # where to tell engines and front-ends to connect.
        context = get_mp_context()
        parent_zmq_addr_pipe, child_zmq_addr_pipe = context.Pipe(duplex=False)
        self.proc: multiprocessing.Process = context.Process(
            target=DPCoordinatorProc.run_coordinator,
            name="VLLM_DP_Coordinator",
            kwargs={
                "engine_count": parallel_config.data_parallel_size,
                "front_publish_address": front_publish_address,
                "back_output_address": back_output_address,
                "back_publish_address": back_publish_address,
                "zmq_addr_pipe": child_zmq_addr_pipe,
                "enable_wave_coordination": enable_wave_coordination,
            },
            daemon=True,
        )
        self.proc.start()
        child_zmq_addr_pipe.close()
        # Block (with 30s timeout) until the child reports back the
        # bound addresses; also trips if the child dies during startup.
        (front_publish_address, back_output_address,
         back_publish_address) = self._wait_for_zmq_addrs(parent_zmq_addr_pipe)
        self.stats_publish_address = front_publish_address
        self.coord_in_address = back_publish_address
        self.coord_out_address = back_output_address
        self._finalizer = weakref.finalize(self, shutdown, [self.proc])
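The port-0 handshake (the child binds, the kernel chooses the port, and the child reports the real address back over a one-way pipe) can be tried with the standard library alone. This is a minimal sketch, assuming a plain TCP listener in place of the ZMQ sockets and a thread in place of the coordinator process so it runs without fork/spawn concerns; bind_and_report and wait_for_addr are hypothetical names. Unlike the real helper, which per the excerpt also notices a child that dies during startup, this sketch only times out:

```python
import multiprocessing
import socket
import threading
from multiprocessing.connection import Connection


def bind_and_report(conn: Connection) -> None:
    """Plays the coordinator child: bind to port 0, report the real address."""
    srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    srv.bind(("127.0.0.1", 0))  # port 0: the kernel picks a free port
    host, port = srv.getsockname()
    conn.send(f"tcp://{host}:{port}")  # send back the *actual* address
    conn.close()
    srv.close()  # a real coordinator keeps its sockets open and serves


def wait_for_addr(conn: Connection, timeout: float = 30.0) -> str:
    """Plays the parent: block until the child reports, or give up."""
    if not conn.poll(timeout):
        raise TimeoutError("child never reported its bound address")
    return conn.recv()


# duplex=False: the first end can only receive, the second can only send.
parent_end, child_end = multiprocessing.Pipe(duplex=False)
worker = threading.Thread(target=bind_and_report, args=(child_end,))
worker.start()
address = wait_for_addr(parent_end)
worker.join()
```

Binding to port 0 and reporting back avoids the race of picking a "free" port in the parent and hoping it is still free when the child binds.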
The coordinator process owns three ZMQ sockets, one facing the API servers and two facing the engines. Both publishing sockets are XPUB rather than plain PUB, so the coordinator can track subscriptions:
- XPUB publish_front: publishes aggregated stats and wave state to all frontend API server processes. Frontends subscribe via XSUB and can also send SCALE_ELASTIC_EP on this channel.
- PULL output_back: receives scheduler_stats, wave_complete, and start_wave from engine processes.
- XPUB publish_back: broadcasts START_DP_WAVE notifications back to all engines.
The "request wave" is the DP barrier mechanism: all engines alternate between a global running state and a global paused state, and the count of running→paused transitions is the wave number. Rank 0 drives wave_complete via an all-reduce in DPEngineCoreProc._has_global_unfinished_reqs(); the coordinator uses start_wave to wake engines that received a request during a paused wave.
# Wave coordination: handle wave completion and start notifications
if (wave := outputs.wave_complete) is not None:
    # Notification from rank 0 engine that we've moved into
    # the global paused state (engines_running==False).
    if current_wave <= wave:
        new_wave = wave + 1
        logger.debug("Moving DP wave from %d to %d.", current_wave, new_wave)
        current_wave = new_wave
        engines_running = False
        wave_state_changed = True
elif (wave := outputs.start_wave) is not None and (
        wave > current_wave
        or (wave == current_wave and not engines_running)):
    # Engine received request for a non-current wave (or for the
    # current wave while paused) -- ensure other engines progress to
    # that wave. Stale start_wave notifications fall through.
    current_wave = wave
    engines_running = True
    wave_state_changed = True
    self._send_start_wave(publish_back, wave, eng_index)
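The wave protocol is a small state machine over (current_wave, engines_running). A hypothetical sketch of the coordinator side: WaveState and on_engine_message are illustrative names, and the guard that ignores stale start_wave notifications (only newer waves, or the current wave while paused, restart the engines) is an assumption modeled on the race-condition handling described above:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class WaveState:
    current_wave: int = 0
    engines_running: bool = False


def on_engine_message(state: WaveState,
                      wave_complete: Optional[int] = None,
                      start_wave: Optional[int] = None) -> Optional[int]:
    """Apply one engine notification; return a wave to broadcast, if any."""
    if wave_complete is not None:
        # Rank 0 reports the global pause; late reports of old waves are no-ops.
        if state.current_wave <= wave_complete:
            state.current_wave = wave_complete + 1
            state.engines_running = False
        return None
    if start_wave is not None and (
            start_wave > state.current_wave
            or (start_wave == state.current_wave
                and not state.engines_running)):
        state.current_wave = start_wave
        state.engines_running = True
        return start_wave  # coordinator would send START_DP_WAVE here
    return None
```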
Engines PUSH their per-step SchedulerStats to output_back. The coordinator keeps each engine's latest waiting/running counts, tracks the newest (wave, step) pair seen across all engines, and publishes to front-ends only when either (a) the wave state changed or (b) the rate-limited publish timer fires. This keeps the fan-out cheap: a DP=8 deployment does not produce 8 messages per engine step.
# (inside the output_back (PULL) branch, after decoding EngineCoreOutputs)
eng_index = outputs.engine_index
scheduler_stats = outputs.scheduler_stats
if scheduler_stats:
    stats = self.engines[eng_index].request_counts
    stats_step = scheduler_stats.step_counter
    stats_wave = scheduler_stats.current_wave
    # Advance the global (wave, step_counter) high-water mark. The first
    # message of a newer step snapshots the previous step's counts (if
    # they changed) so a publish reports one consistent step.
    if (stats_wave > last_stats_wave
            or stats_wave == last_stats_wave
            and stats_step > last_stats_step):
        if stats_changed:
            last_step_counts = self._get_engine_counts(do_copy=True)
        last_stats_step = stats_step
        last_stats_wave = stats_wave
    # Record the fresh counts in our mirror.
    stats[0] = scheduler_stats.num_waiting_reqs
    stats[1] = scheduler_stats.num_running_reqs
    stats_changed = True
# ---- Publish wave-state changes immediately ----
# Count updates go out separately on the rate-limited timer path.
if wave_state_changed:
    # Tuple format: (engine_counts_or_None, current_wave, engines_running).
    # When only wave/running changed (not counts), first element is None.
    message = (None, current_wave, engines_running)
    publish_front.send(msgspec.msgpack.encode(message))
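The two publish paths (wave changes sent at once with counts omitted, counts sent at most once per interval) can be simulated with an explicit clock. ToyStatsPublisher is hypothetical, the list sent stands in for publish_front.send(), and the timer logic is simplified relative to the real coordinator loop:

```python
class ToyStatsPublisher:
    """Hypothetical sketch: wave changes go out at once, counts on a timer."""

    def __init__(self, n_engines: int, interval_ms: int = 100) -> None:
        self.counts = [[0, 0] for _ in range(n_engines)]
        self.interval_ms = interval_ms
        self.last_publish_ms = 0
        self.changed = False
        self.sent: list[tuple] = []  # stands in for publish_front.send()

    def on_stats(self, eng: int, waiting: int, running: int) -> None:
        self.counts[eng] = [waiting, running]
        self.changed = True

    def on_wave_change(self, wave: int, running: bool) -> None:
        self.sent.append((None, wave, running))  # counts omitted

    def tick(self, now_ms: int, wave: int, running: bool) -> None:
        # Publish counts only if they changed and the interval has elapsed.
        if self.changed and now_ms - self.last_publish_ms >= self.interval_ms:
            snapshot = [list(c) for c in self.counts]
            self.sent.append((snapshot, wave, running))
            self.last_publish_ms = now_ms
            self.changed = False
```

Eight per-step updates from two engines inside one 100 ms window collapse into a single published snapshot.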
The coordinator also supports dynamic scaling of the engine mesh. When an API server sends a SCALE_ELASTIC_EP message on publish_front, the coordinator resizes its per-engine EngineState list and resets the global running flag so newly-added engines (which start at wave 0) will participate in the next START_DP_WAVE broadcast:
# Received on the publish_front XPUB socket from an API server.
# Payload is a 2-tuple: ("SCALE_ELASTIC_EP", new_engine_count).
if isinstance(decoded, (list, tuple)) and len(decoded) == 2 \
        and decoded[0] == "SCALE_ELASTIC_EP":
    new_engine_count = decoded[1]
    current_count = len(self.engines)
    if new_engine_count > current_count:
        # Scale up: create fresh EngineState slots. Their
        # request_counts start at [0, 0] so the LB will route
        # requests to them on the next publish.
        for _ in range(new_engine_count - current_count):
            self.engines.append(EngineState())
        # Force engines_running=False so the new engines (wave=0)
        # are included in the next START_DP_WAVE handshake; without
        # this, wave-mismatch could strand them.
        engines_running = False
        logger.info("DPCoordinator scaled up from %s to %s engines",
                    current_count, new_engine_count)
    else:
        # Scale down: drop trailing slots. Safe because the ray
        # backend is the only caller that supports this, and it
        # removes engines before notifying the coordinator.
        self.engines = self.engines[:new_engine_count]
        logger.info("DPCoordinator scaled down from %s to %s engines",
                    current_count, new_engine_count)
    continue  # Skip the normal engine-notification path
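The resize logic on its own, as a pure function. EngineSlot is a hypothetical stand-in for EngineState and handle_scale is an illustrative name. The (list, tuple) check matters because a msgpack array typically decodes to a list, not the tuple the API server encoded:

```python
from dataclasses import dataclass, field


@dataclass
class EngineSlot:
    """Hypothetical stand-in for EngineState: [waiting, running] counts."""
    request_counts: list[int] = field(default_factory=lambda: [0, 0])


def handle_scale(engines: list[EngineSlot], decoded,
                 engines_running: bool) -> tuple[list[EngineSlot], bool]:
    if not (isinstance(decoded, (list, tuple)) and len(decoded) == 2
            and decoded[0] == "SCALE_ELASTIC_EP"):
        return engines, engines_running  # not a scale message
    new_count = decoded[1]
    if new_count > len(engines):
        # Fresh slots start empty, so the LB favors them right away.
        engines = engines + [EngineSlot()
                             for _ in range(new_count - len(engines))]
        engines_running = False  # pull new wave-0 engines into the next wave
    else:
        engines = engines[:new_count]  # drop trailing slots
    return engines, engines_running
```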
- OutputProcessor.process_outputs() loops over the whole batch exactly once: stats, detokenize, logprobs, stop-check, build RequestOutput.
- The output I/O thread reuses bytearray buffers for encode_into(), reclaiming each one only after ZMQ's MessageTracker reports that the zero-copy send has finished with it.
- When max_concurrent_batches > 1, step_with_batch_queue() schedules the next PP batch before the previous one's future resolves, eliminating pipeline bubbles.

- vllm/v1/engine/core.py -- EngineCore, EngineCoreProc, DPEngineCoreProc (busy loop, I/O threads, handshake)
- vllm/v1/engine/core_client.py -- InprocClient, SyncMPClient, AsyncMPClient, DPAsyncMPClient, DPLBAsyncMPClient
- vllm/v1/engine/async_llm.py -- AsyncLLM frontend
- vllm/v1/engine/llm_engine.py -- LLMEngine sync frontend
- vllm/v1/engine/output_processor.py -- OutputProcessor, RequestState, RequestOutputCollector
- vllm/v1/engine/coordinator.py -- DPCoordinator, DPCoordinatorProc
- vllm/v1/serial_utils.py -- MsgpackEncoder, MsgpackDecoder