A source-code deep dive into vLLM's multi-process engine design, ZeroMQ messaging,
msgpack serialization, and data-parallel coordination
vLLM V1 separates the engine into two process-level concerns: a frontend process (the API server or LLMEngine) that handles tokenization, detokenization, and client-facing I/O, and a backend process (the EngineCore) that runs the GPU-bound scheduling and model execution loop. These two processes communicate via ZeroMQ sockets with msgpack serialization.
The key insight of V1's architecture is process isolation: the GPU-bound EngineCore runs in its own process with dedicated I/O threads, while tokenization, detokenization, and client management live in the frontend. This separation is bridged by ZeroMQ for message passing and msgpack for efficient serialization.
vllm/v1/engine/core.py (EngineCore + EngineCoreProc), vllm/v1/engine/core_client.py (all client variants), vllm/v1/engine/async_llm.py (AsyncLLM frontend), vllm/v1/engine/output_processor.py (detokenization), vllm/v1/engine/coordinator.py (DP coordinator), vllm/v1/serial_utils.py (msgpack).
The EngineCore class in core.py implements the fundamental execution loop. It is the "inner loop of vLLM's Engine", as the source code states.
The EngineCoreProc subclass wraps EngineCore with ZMQ I/O and runs the busy loop in a background process:
def run_busy_loop(self):
"""Core busy loop of the EngineCore.

Alternates two phases until shutdown is signalled: drain the input
queue (requests already deserialized by the input IO thread), then
perform one engine step.
"""
while self._handle_shutdown():
# 1) Poll the input queue until there is work to do.
self._process_input_queue()
# 2) Step the engine core and return the outputs.
self._process_engine_step()
# _handle_shutdown() returned falsy: terminate the background process.
raise SystemExit
This loop has only two phases: drain the input queue (which receives deserialized requests from the I/O thread) and then step the engine. The loop continues until a shutdown signal is received.
The core computation happens in step() -- vLLM's per-iteration heartbeat:
def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
"""Schedule, execute, and make output.

Returns a tuple of (outputs keyed by client index, whether any
tokens were scheduled this iteration).
"""
# Check for any requests remaining in the scheduler
if not self.scheduler.has_requests():
return {}, False
scheduler_output = self.scheduler.schedule()
# Dispatch the forward pass without blocking so the CPU can overlap
# grammar-bitmask computation with GPU execution.
future = self.model_executor.execute_model(scheduler_output, non_block=True)
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with (
self.log_error_detail(scheduler_output),
self.log_iteration_details(scheduler_output),
):
model_output = future.result()
if model_output is None:
# NOTE(review): executor apparently deferred sampling; run it now
# with the grammar bitmask -- confirm the two-phase protocol.
model_output = self.model_executor.sample_tokens(grammar_output)
# Before processing the model output, process any aborts that happened
# during the model execution.
self._process_aborts_queue()
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
# Second element is True iff the scheduler batched any tokens.
return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
scheduler.schedule() -- selects which requests to batch together; model_executor.execute_model() -- dispatches GPU computation (non-blocking); scheduler.update_from_output() -- processes results, advances request state.
When pipeline parallelism is enabled (max_concurrent_batches > 1), vLLM uses a batch queue to overlap scheduling with execution, eliminating pipeline bubbles:
# Batch-queue setup: enabled only when the executor supports more than
# one in-flight batch (pipeline parallelism).
self.batch_queue_size = self.model_executor.max_concurrent_batches
self.batch_queue: deque[...] | None = None
if self.batch_queue_size > 1:
logger.debug("Batch queue is enabled with size %d", self.batch_queue_size)
self.batch_queue = deque(maxlen=self.batch_queue_size)
# Select the per-iteration step function once, at init time.
self.step_fn = (
self.step if self.batch_queue is None else self.step_with_batch_queue
)
def _process_engine_step(self) -> bool:
    """Run a single engine step and enqueue any produced outputs.

    Returns:
        True if the model actually executed during this step.
    """
    outputs, model_executed = self.step_fn()
    # Hand every (client_index, EngineCoreOutputs) pair to the
    # dedicated output IO thread via the output queue.
    if outputs:
        for item in outputs.items():
            self.output_queue.put_nowait(item)
    # Post-step hook.
    self.post_step(model_executed)
    if not model_executed and self.scheduler.has_unfinished_requests():
        # Nothing ran but requests are still pending: yield the GIL
        # briefly so background IO threads can make progress.
        time.sleep(0.001)
    return model_executed
After each step, outputs are placed into self.output_queue (a queue.Queue), which is drained by the dedicated output I/O thread and sent over ZMQ to the frontend.
def _process_input_queue(self):
"""Exits when an engine step needs to be performed.

Blocks on the input queue while idle, then drains all remaining
client messages non-blocking before returning to the step loop.
"""
# NOTE(review): `waited` is unused in this excerpt -- likely used by
# code elided here; confirm against the full source.
waited = False
while not self.has_work() and self.is_running():
# Notify callbacks waiting for engine to become idle.
self._notify_idle_state_callbacks()
if self.input_queue.empty():
# Drain aborts queue; clear all pending aborts.
# (Nothing is running, so pending aborts are moot.)
with self.aborts_queue.mutex:
self.aborts_queue.queue.clear()
block = self.process_input_queue_block
try:
# Sleeps the thread while idle (when block=True).
req = self.input_queue.get(block=block)
self._handle_client_request(*req)
except queue.Empty:
break
# Handle any more client requests.
while not self.input_queue.empty():
req = self.input_queue.get_nowait()
self._handle_client_request(*req)
When the engine has no work, it blocks on the input queue (sleeping the thread). Once there is work, it drains all remaining messages non-blocking before proceeding to the engine step.
EngineCoreProc.run_engine_core() is the entry point for the background process. It is invoked by the process management infrastructure:
@staticmethod
def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
"""Launch EngineCore busy loop in background process.

Entry point invoked by the process management infrastructure.
Selects the DP-aware or regular EngineCoreProc, runs the busy loop,
and guarantees shutdown (and an ENGINE_CORE_DEAD notification on
failure) in all cases.
"""
maybe_register_config_serialize_by_value()
engine_core: EngineCoreProc | None = None
try:
vllm_config: VllmConfig = kwargs["vllm_config"]
parallel_config = vllm_config.parallel_config
data_parallel = parallel_config.data_parallel_size > 1 or dp_rank > 0
if data_parallel:
parallel_config.data_parallel_rank_local = local_dp_rank
# Include the DP rank in the process title for observability.
process_title = f"EngineCore_DP{dp_rank}"
else:
process_title = "EngineCore"
set_process_title(process_title)
# Choose DPEngineCoreProc for MoE models with DP,
# regular EngineCoreProc otherwise.
if data_parallel and vllm_config.model_config.is_moe:
parallel_config.data_parallel_rank = dp_rank
engine_core = DPEngineCoreProc(*args, **kwargs)
else:
# Non-MoE: each engine runs independently (DP size forced to 1).
parallel_config.data_parallel_size = 1
engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)
engine_core.run_busy_loop()
except Exception as e:
# Tell the frontend this engine died before propagating the error.
if engine_core is not None:
engine_core._send_engine_dead()
raise e
finally:
if engine_core is not None:
engine_core.shutdown()
EngineCoreProc.__init__() spawns two daemon threads that bridge ZMQ sockets with the main loop's queue.Queue objects:
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
# Input side: ZMQ sockets -> input_queue / aborts_queue.
input_thread = threading.Thread(
target=self.process_input_sockets,
args=(addresses.inputs, addresses.coordinator_input, identity, ready_event),
daemon=True,
)
input_thread.start()
# Output side: output_queue -> ZMQ sockets.
# NOTE(review): output_thread.start() is not shown in this excerpt --
# presumably started shortly after; confirm against the full source.
self.output_thread = threading.Thread(
target=self.process_output_sockets,
args=(addresses.outputs, addresses.coordinator_output, self.engine_index),
daemon=True,
)
Before the busy loop starts, EngineCoreProc performs a handshake with the frontend using a dedicated ZMQ DEALER socket. The handshake exchanges configuration, ZMQ addresses, and DP coordination info:
@staticmethod
def startup_handshake(handshake_socket, local_client, headless, parallel_config=None):
"""Perform the startup handshake with the frontend.

Sends a HELLO registration, receives EngineHandshakeMetadata
(ZMQ addresses plus parallel-config overrides), applies any
parallel-config updates, and returns the addresses.
"""
# Send registration message.
handshake_socket.send(
msgspec.msgpack.encode({"status": "HELLO", "local": local_client, "headless": headless})
)
# Receive initialization message.
init_bytes = handshake_socket.recv()
init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(init_bytes, type=EngineHandshakeMetadata)
if parallel_config is not None:
# Apply frontend-provided parallel-config overrides in place.
for key, value in init_message.parallel_config.items():
setattr(parallel_config, key, value)
return init_message.addresses
When the EngineCore encounters a fatal error, it sends the sentinel value ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD" through the output socket, which the client detects and raises EngineDeadError.
The MPClient base class sets up two ZMQ sockets for bidirectional communication:
vllm/v1/engine/core_client.py : MPClient.__init__()
# ZMQ setup.
# One shared context; wrapped in an asyncio context for async clients.
sync_ctx = zmq.Context(io_threads=2)
self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx
# Input socket: ROUTER pattern -- routes to specific engine by identity
self.input_socket = make_zmq_socket(
self.ctx, addresses.inputs[0], zmq.ROUTER, bind=True,
router_handover=enable_input_socket_handover,
)
# Output socket: PULL pattern -- pulls results from engine
self.resources.output_socket = make_zmq_socket(
self.ctx, addresses.outputs[0], zmq.PULL
)
def _send_input(self, request_type, request, engine=None):
# Encode and dispatch a request to an engine (default engine if
# none specified). encoder.encode() returns the main msgpack buffer
# plus any zero-copy aux buffers.
if engine is None:
engine = self.core_engine
message = (request_type.value, *self.encoder.encode(request))
return self._send_input_message(message, engine, request)
def _send_input_message(self, message, engine, objects):
# Send a pre-encoded multipart message, tracking buffer lifetimes
# when the payload carries zero-copy tensor buffers.
self.ensure_alive()
self.free_pending_messages()
msg = (engine,) + message # (Identity, RequestType, SerializedData...)
if not objects or len(msg) <= 3:
# No auxiliary buffers => no tensor backing buffers in request.
return self.input_socket.send_multipart(msg, copy=False)
# Track message for zero-copy buffer lifetime management.
# The tracker keeps the tensor memory alive until ZMQ is done.
future = self.input_socket.send_multipart(msg, copy=False, track=True)
...
The message format over ZMQ is: (engine_identity, request_type_bytes, msgpack_payload, [aux_buffer_1, aux_buffer_2, ...]). The ROUTER socket uses the identity byte prefix to route to the correct engine process.
async def process_outputs_socket():
# Asyncio task: continuously read multipart frames from the PULL
# socket, decode them, and route results.
try:
while True:
frames = await output_socket.recv_multipart(copy=False)
resources.validate_alive(frames)
outputs: EngineCoreOutputs = decoder.decode(frames)
if outputs.utility_output:
# RPC-style responses are resolved to pending futures,
# not delivered through the outputs queue.
_process_utility_output(outputs.utility_output, utility_results)
continue
if outputs.outputs or outputs.scheduler_stats:
outputs_queue.put_nowait(outputs)
except Exception as e:
# Surface failures to the consumer of the queue.
outputs_queue.put_nowait(e)
This runs as an asyncio.Task that continuously reads from the PULL socket. Decoded outputs go to the outputs_queue, while utility results (RPC responses) are routed to their pending futures.
# Msgpack serialization decoding with optional tensor IPC receiver.
# Typed decoder for ADD requests; generic decoder for everything else.
add_request_decoder = MsgpackDecoder(
EngineCoreRequest, oob_tensor_provider=self.tensor_ipc_receiver
)
generic_decoder = MsgpackDecoder(oob_tensor_provider=self.tensor_ipc_receiver)
while True:
for input_socket, _ in poller.poll():
# Frame layout: (request_type, payload, [aux buffers...]).
type_frame, *data_frames = input_socket.recv_multipart(copy=False)
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
if request_type == EngineCoreRequestType.ADD:
req = add_request_decoder.decode(data_frames)
request = self.preprocess_add_request(req)
else:
request = generic_decoder.decode(data_frames)
if request_type == EngineCoreRequestType.ABORT:
# Aborts go to BOTH queues for eager processing
self.aborts_queue.put_nowait(request)
self.input_queue.put_nowait((request_type, request))
vLLM uses msgspec (a high-performance msgpack library) with custom extensions for tensors, numpy arrays, and multimodal data. The implementation is in vllm/v1/serial_utils.py.
class MsgpackEncoder:
"""Encoder with custom torch tensor and numpy array serialization.
By default, arrays below 256B are serialized inline. Larger will get sent
via dedicated messages (zero-copy)."""
def encode(self, obj) -> Sequence[bytestr]:
# Returns [main_msgpack_buffer, aux_buffer_1, ...]; aux buffers are
# raw views into tensor memory for zero-copy ZMQ sends.
if self.oob_tensor_consumer is not None:
self.oob_tensor_consumer.new_message()
self.aux_buffers = bufs = [b""]
bufs[0] = self.encoder.encode(obj)
# bufs collects direct pointers to backing buffers of
# tensors and np arrays, returned alongside the main buffer
return bufs
def _encode_tensor(self, obj: torch.Tensor):
# Returns (dtype_str, shape, data) where data is inline bytes,
# an OOB metadata dict, or an index into aux_buffers.
if obj.nbytes < self.size_threshold and obj.is_cpu:
# Small tensors: inline as raw bytes
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, tensor_data(obj))
# NOTE(review): `oob_consumer` is not defined in this excerpt --
# presumably a local bound from self.oob_tensor_consumer; confirm.
elif oob_consumer is not None and (data := oob_consumer(obj)) is not None:
# Out-of-band tensor IPC (shared memory)
assert isinstance(data, dict)
else:
# Large tensors: store index to aux buffer (zero-copy send)
data = len(self.aux_buffers)
self.aux_buffers.append(tensor_data(obj))
dtype = str(obj.dtype).removeprefix("torch.")
return dtype, obj.shape, data
| Tier | Condition | Method | Performance |
|---|---|---|---|
| Inline | nbytes < threshold (default 256B) |
Embed as msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, ...) |
Single buffer, minimal overhead |
| Zero-copy | nbytes ≥ threshold, no OOB consumer |
Store buffer index in main message, send buffer as separate ZMQ frame | No copy -- ZMQ sends directly from tensor memory |
| OOB IPC | OOB consumer accepts tensor (e.g., shared memory) | Pass tensor via IPC mechanism, include metadata dict in message | Best for large multimodal tensors (images, video) |
class MsgpackDecoder:
# Counterpart of MsgpackEncoder: reconstructs tensors from inline
# bytes, aux-buffer frames, or OOB IPC metadata.
def decode(self, bufs):
if isinstance(bufs, bytestr):
# Single-frame message: decode directly.
return self.decoder.decode(bufs)
self.aux_buffers = bufs # Multi-frame message
try:
return self.decoder.decode(bufs[0]) # Main frame
finally:
# Drop buffer references so ZMQ frames can be released.
self.aux_buffers = ()
def _decode_tensor(self, arr):
dtype, shape, data = arr
if isinstance(data, dict):
# OOB tensor: retrieve from shared memory
return self.oob_tensor_provider(dtype, shape, data)
# int => index into aux frames; otherwise inline raw bytes.
is_aux = isinstance(data, int)
buffer = self.aux_buffers[data] if is_aux else data
arr = torch.frombuffer(buffer, dtype=torch.uint8)
# Pin larger tensors for efficient CPU->GPU transfer
if not is_aux:
# Inline data: clone so the tensor owns its memory.
arr = arr.clone()
elif not self.share_mem:
arr = arr.pin_memory() if self.pin_tensors else arr.clone()
# NOTE(review): `torch_dtype` is presumably derived from the dtype
# string earlier in the full source -- confirm.
return arr.view(torch_dtype).view(shape)
The output thread in EngineCoreProc reuses bytearray buffers to minimize allocation overhead:
encoder = MsgpackEncoder()
# Pool of bytearrays recycled between sends to avoid reallocation.
reuse_buffers: list[bytearray] = []
# In-flight sends: (zmq tracker, payload ref, buffer to reclaim).
pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()
while True:
output = self.output_queue.get()
if output == EngineCoreProc.ENGINE_CORE_DEAD:
# Fatal-error sentinel: notify every client, then exit the thread.
for socket in sockets:
socket.send(output)
break
client_index, outputs = output
outputs.engine_index = engine_index
# Reclaim buffers that zmq is finished with.
# NOTE(review): checking pending[-1] assumes new entries are
# appended on the left (oldest at the right) -- confirm.
while pending and pending[-1][0].done:
reuse_buffers.append(pending.pop()[2])
buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
buffers = encoder.encode_into(outputs, buffer)
# track=True yields a MessageTracker so the buffer can be reused
# only once ZMQ has finished sending it.
tracker = sockets[client_index].send_multipart(buffers, copy=False, track=True)
@staticmethod
def make_client(multiprocess_mode, asyncio_mode, vllm_config, executor_class, log_stats):
    """Select and construct the EngineCoreClient variant for the requested mode."""
    if multiprocess_mode:
        # Background EngineCore process: pick async or sync client.
        if asyncio_mode:
            return EngineCoreClient.make_async_mp_client(vllm_config, executor_class, log_stats)
        return SyncMPClient(vllm_config, executor_class, log_stats)
    # In-process EngineCore (V0-style LLMEngine).
    return InprocClient(vllm_config, executor_class, log_stats)
@staticmethod
def make_async_mp_client(vllm_config, executor_class, log_stats, ...):
# Select the async multiprocess client variant based on the
# data-parallel configuration.
parallel_config = vllm_config.parallel_config
if parallel_config.data_parallel_size > 1:
if parallel_config.data_parallel_external_lb:
return DPAsyncMPClient(...) # External load balancer
return DPLBAsyncMPClient(...) # Internal load balancer
return AsyncMPClient(...)
| Client | Process Model | Async | DP Support | Use Case |
|---|---|---|---|---|
InprocClient |
Single process | No | No | V0-compatible LLMEngine |
SyncMPClient |
Multi-process | No | Basic | Sync LLM with background EngineCore |
AsyncMPClient |
Multi-process | Yes | No (DP=1) | AsyncLLM / API server |
DPAsyncMPClient |
Multi-process | Yes | External LB | DP with external load balancer |
DPLBAsyncMPClient |
Multi-process | Yes | Internal LB | DP with built-in load balancing |
The simplest variant -- no IPC, no threads. The EngineCore runs in the same process:
vllm/v1/engine/core_client.py : InprocClient
class InprocClient(EngineCoreClient):
"""InprocClient: client for in-process EngineCore.
Intended for use in LLMEngine for V0-style add_request() and step()
EngineCore setup in this process (no busy loop)."""
def __init__(self, *args, **kwargs):
self.engine_core = EngineCore(*args, **kwargs)
def get_output(self) -> EngineCoreOutputs:
outputs, model_executed = self.engine_core.step_fn()
self.engine_core.post_step(model_executed=model_executed)
return outputs and outputs.get(0) or EngineCoreOutputs()
def add_request(self, request):
req, request_wave = self.engine_core.preprocess_add_request(request)
self.engine_core.add_request(req, request_wave)
The most sophisticated variant implements load-balanced request routing across multiple engine processes:
vllm/v1/engine/core_client.py : DPLBAsyncMPClient.get_core_engine_for_request()
def get_core_engine_for_request(self, request):
# Check for explicit DP rank assignment or late-interaction routing.
if (eng_index := request.data_parallel_rank) is None and (
eng_index := get_late_interaction_engine_index(...)
) is None:
# No explicit target: pick the least-loaded engine.
current_counts = self.lb_engines
min_score = sys.maxsize
eng_index = 0
# Rotate the scan start so ties don't always favor engine 0.
for i in range(num_engines):
idx = (self.eng_start_index + i) % num_engines
waiting, running = current_counts[idx]
score = waiting * 4 + running # Weighted score: prioritize empty queues
if score < min_score:
min_score = score
eng_index = idx
# Increment local waiting count for better balancing
# between stats updates from the coordinator (every ~100ms).
current_counts[eng_index][0] += self.client_count
chosen_engine = self.core_engines[eng_index]
# Remember the routing so aborts/results can target this engine.
self.reqs_in_flight[request.request_id] = chosen_engine
return chosen_engine
The load-balancing score is computed as score = waiting * 4 + running. The 4x weight on waiting requests means requests are steered towards engines with shorter queues. Between coordinator updates (~100ms), the client also locally increments the chosen engine's count to prevent burst routing.
The OutputProcessor lives in the frontend process and converts raw EngineCoreOutput (token IDs) into user-facing RequestOutput objects with decoded text.
def process_outputs(self, engine_core_outputs, engine_core_timestamp=None, iteration_stats=None):
"""Process the EngineCoreOutputs:
1) Compute stats for logging
2) Detokenize
3) Create and handle RequestOutput objects:
- If queue exists (AsyncLLM): put into queue for generate()
- If no queue (LLMEngine): return list of RequestOutputs
NOTE: vLLM V1 minimizes the number of python loops over the full
batch. This is the ONLY function that loops over EngineCoreOutputs."""
request_outputs = []
reqs_to_abort = []
for engine_core_output in engine_core_outputs:
req_state = self.request_states.get(engine_core_output.request_id)
if req_state is None:
continue # Ignore output for already-aborted request
# 1) Compute stats for this iteration.
self._update_stats_from_output(req_state, engine_core_output, ...)
# 2) Detokenize the token ids into text and perform stop checks.
# NOTE(review): pooling_output/new_token_ids/finish_reason are bound
# by code elided from this excerpt -- confirm against full source.
if pooling_output is None:
stop_string = req_state.detokenizer.update(
new_token_ids, finish_reason == FinishReason.STOP
)
if stop_string:
# Stop string found in decoded text: finalize the request.
finish_reason = FinishReason.STOP
stop_reason = stop_string
# 3) Compute sample and prompt logprobs if required.
req_state.logprobs_processor.update_from_output(engine_core_output)
# 4) Create and handle RequestOutput objects.
if request_output := req_state.make_request_output(...):
if req_state.queue is not None:
req_state.queue.put(request_output) # AsyncLLM path
else:
request_outputs.append(request_output) # LLMEngine path
# 5) Free completed requests.
if finish_reason is not None:
if not engine_core_output.finished:
# Detokenizer detected stop string -- abort in EngineCore
reqs_to_abort.append(req_id)
As the docstring notes, this is the only function that loops over the full batch of EngineCoreOutputs. This minimizes Python-level overhead on the hot path -- all per-request operations (detokenization, logprobs, stats, output creation) are done in a single pass.
For AsyncLLM, each request has a RequestOutputCollector that uses asyncio.Event for backpressure:
class RequestOutputCollector:
"""Collects streamed RequestOutputs per request.
When streaming deltas, RequestOutputs are merged if the
producer gets ahead of the consumer."""
def __init__(self, output_kind, request_id):
# Merge (aggregate) only in DELTA streaming mode.
self.aggregate = output_kind == RequestOutputKind.DELTA
# Holds at most one pending RequestOutput or Exception.
self.output = None
self.ready = asyncio.Event()
def put(self, output):
"""Non-blocking put operation."""
if self.output is None or isinstance(output, Exception):
# Empty slot, or an error which overrides anything pending.
self.output = output
self.ready.set()
elif isinstance(self.output, RequestOutput):
# Merge outputs if producer ahead of consumer
self.output.add(output, aggregate=self.aggregate)
async def get(self):
"""Get operation blocks on put event."""
while (self.output) is None:
await self.ready.wait()
# Take the pending item and reset for the next put().
output = self.output
self.output = None
self.ready.clear()
if isinstance(output, Exception):
# Propagate producer-side failures to the awaiting consumer.
raise output
return output
The AsyncLLM frontend initializes both an InputProcessor and OutputProcessor in the same process, and delegates actual GPU computation to the EngineCore via the AsyncMPClient:
# Convert EngineInput --> EngineCoreRequest.
self.input_processor = InputProcessor(self.vllm_config, renderer)
# Converts EngineCoreOutputs --> RequestOutput.
# Runs in this (frontend) process; detokenization happens here.
self.output_processor = OutputProcessor(
renderer.tokenizer,
log_stats=self.log_stats,
stream_interval=self.vllm_config.scheduler_config.stream_interval,
)
# EngineCore (starts the engine in background process).
self.engine_core = EngineCoreClient.make_async_mp_client(
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=self.log_stats,
)
def step(self):
"""One frontend iteration: fetch engine outputs, post-process them,
and relay any stop-string aborts back to the EngineCore."""
# 1) Get EngineCoreOutput from the EngineCore.
outputs = self.engine_core.get_output()
# 2) Process EngineCoreOutputs.
iteration_stats = IterationStats() if self.log_stats else None
processed_outputs = self.output_processor.process_outputs(
outputs.outputs,
engine_core_timestamp=outputs.timestamp,
iteration_stats=iteration_stats,
)
self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
# 3) Abort any reqs that finished due to stop strings.
# (The EngineCore doesn't see decoded text, so it can't stop itself.)
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
return processed_outputs.request_outputs
For data-parallel deployments (DP > 1), vLLM runs a DPCoordinator in its own process. It aggregates queue statistics from all engine processes and manages the "request wave" protocol that coordinates their running/paused state.
class DPCoordinator:
"""Coordinator process used for data-parallel deployments (DP>1).
* Collects stats from each DP engine (waiting and running queue lengths)
* Publishes stats to all front-ends for load-balancing decisions
* Keeps track of the current DP "request wave" number
* Broadcasts START_DP_WAVE messages to engines"""
def __init__(self, parallel_config, enable_wave_coordination=True):
dp_size = parallel_config.data_parallel_size
context = get_mp_context()
# Pipe for the child to report its bound ZMQ addresses back.
parent_zmq_addr_pipe, child_zmq_addr_pipe = context.Pipe(duplex=False)
# NOTE(review): the *_address variables are not defined in this
# excerpt -- presumably derived above; confirm against full source.
self.proc = context.Process(
target=DPCoordinatorProc.run_coordinator,
name="VLLM_DP_Coordinator",
kwargs={
"engine_count": parallel_config.data_parallel_size,
"front_publish_address": front_publish_address,
"back_output_address": back_output_address,
"back_publish_address": back_publish_address,
},
daemon=True,
)
self.proc.start()
The coordinator process uses three ZMQ sockets:
Front publish socket: publishes aggregated stats and wave state to all frontend API server processes; frontends subscribe via the XSUB pattern.
Back output socket (PULL): receives stats and wave notifications from engine processes.
Back publish socket (XPUB): broadcasts START_DP_WAVE messages to all engines.
The "request wave" is a coordination mechanism for MoE models with data parallelism, where all engines must synchronize:
vllm/v1/engine/coordinator.py : DPCoordinatorProc.process_input_socket()
# Wave coordination: handle wave completion and start notifications
if (wave := outputs.wave_complete) is not None:
# Notification from rank 0 engine that we've moved into
# the global paused state (engines_running==False).
if current_wave <= wave:
# Advance past the completed wave exactly once.
new_wave = wave + 1
logger.debug("Moving DP wave from %d to %d.", current_wave, new_wave)
current_wave = new_wave
engines_running = False
wave_state_changed = True
elif (wave := outputs.start_wave) is not None:
# Engine received request for a non-current wave --
# ensure other engines progress to next wave.
engines_running = True
wave_state_changed = True
self._send_start_wave(publish_back, wave, eng_index)
# Engines send their stats to the coordinator.
if scheduler_stats:
# request_counts layout: [waiting, running].
stats = self.engines[eng_index].request_counts
stats[0] = scheduler_stats.num_waiting_reqs
stats[1] = scheduler_stats.num_running_reqs
stats_changed = True
# Publish to front-ends at interval (default 100ms).
to_publish = (engine_req_counts_list, current_wave, engines_running)
publish_front.send(msgspec.msgpack.encode(to_publish))
The coordinator also supports dynamic scaling of engine processes. When a SCALE_ELASTIC_EP message is received, it adds or removes EngineState entries:
if decoded[0] == "SCALE_ELASTIC_EP":
new_engine_count = decoded[1]
current_count = len(self.engines)
if new_engine_count > current_count:
for _ in range(new_engine_count - current_count):
self.engines.append(EngineState())
logger.info("DPCoordinator scaled up from %s to %s engines",
current_count, new_engine_count)
else:
self.engines = self.engines[:new_engine_count]
logger.info("DPCoordinator scaled down from %s to %s engines",
current_count, new_engine_count)
vllm/v1/engine/core.py -- EngineCore, EngineCoreProc, DPEngineCoreProc (busy loop, I/O threads, handshake)
vllm/v1/engine/core_client.py -- InprocClient, SyncMPClient, AsyncMPClient, DPAsyncMPClient, DPLBAsyncMPClient
vllm/v1/engine/async_llm.py -- AsyncLLM frontend
vllm/v1/engine/llm_engine.py -- LLMEngine sync frontend
vllm/v1/engine/output_processor.py -- OutputProcessor, RequestState, RequestOutputCollector
vllm/v1/engine/coordinator.py -- DPCoordinator, DPCoordinatorProc
vllm/v1/serial_utils.py -- MsgpackEncoder, MsgpackDecoder