Distributed Execution, Parallelism & KV Cache Transfer

A source-level walkthrough of how vLLM V1 orchestrates multi-GPU and multi-node inference through Tensor, Pipeline, Data, and Expert Parallelism, plus the disaggregated KV Cache Transfer system for Prefill/Decode separation.

Source: vllm-project/vllm (V1 architecture)

Table of Contents

  1. Parallel Group Management & Topology
  2. Tensor Parallelism: ColumnParallelLinear & RowParallelLinear
  3. Pipeline Parallelism: PP Stage Coordination
  4. Data Parallelism: DPCoordinator & Load Balancing
  5. Expert Parallelism: FusedMoE & --enable-expert-parallel
  6. MultiprocExecutor: Subprocess Spawning & Shared-Memory MessageQueue
  7. RayDistributedExecutor: Multi-Node via Ray Compiled DAG
  8. KV Cache Transfer Interface: KVConnectorBase_V1
  9. NIXL Connector: High-Performance GPU-to-GPU KV Transfer
  10. LMCache Connector: Multi-Tier KV Offloading

1 Parallel Group Management & Topology

vLLM's distributed layer is rooted in vllm/distributed/parallel_state.py, adapted from Megatron-LM. At startup, initialize_model_parallel() constructs all process groups from a single tensor of global ranks reshaped into a multi-dimensional grid.

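Rank Layout Order: The all_ranks tensor is shaped as [ExternalDP, DP, PP, PCP, TP], so TP ranks are contiguous. Each parallelism dimension is extracted by transposing the desired axis to the last dimension, then reshaping to 2D. EP is the exception: it moves PP out of the way and flattens DP x PCP x TP into a single axis.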
vllm/distributed/parallel_state.py : initialize_model_parallel()
# the layout order is: ExternalDP x DP x PP x PCP x TP
all_ranks = torch.arange(world_size).reshape(
    -1,
    data_parallel_size,
    pipeline_model_parallel_size,
    prefill_context_model_parallel_size,
    tensor_model_parallel_size,
)

# Build TP groups -- last dimension
group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)

# Build PP groups -- transpose PP to last dim
group_ranks = (
    all_ranks.transpose(2, 4)
    .reshape(-1, pipeline_model_parallel_size).unbind(0)
)

# Build DP groups -- transpose DP to last dim
group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)

# Build EP groups -- DP * PCP * TP combined
group_ranks = (
    all_ranks.transpose(1, 2)
    .reshape(-1, data_parallel_size * prefill_context_model_parallel_size
              * tensor_model_parallel_size)
    .unbind(0)
)
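To check which ranks end up together without PyTorch, the same layout can be reproduced with plain index arithmetic. This is a minimal sketch, not vLLM code: build_parallel_groups is a hypothetical helper that flattens (ExternalDP, DP, PP, PCP, TP) coordinates row-major, TP innermost, and returns the members of each group (the order in which groups are listed may differ from vLLM's).

```python
from itertools import product

def build_parallel_groups(world_size, dp, pp, pcp, tp):
    """Enumerate TP/PP/DP/EP rank groups for an ExternalDP x DP x PP x PCP x TP grid."""
    inner = dp * pp * pcp * tp
    if world_size % inner:
        raise ValueError("world_size must be a multiple of DP * PP * PCP * TP")
    sizes = {"e": world_size // inner, "d": dp, "p": pp, "c": pcp, "t": tp}

    def rank(c):
        # Row-major flattening: TP is the innermost (fastest-varying) axis.
        return (((c["e"] * dp + c["d"]) * pp + c["p"]) * pcp + c["c"]) * tp + c["t"]

    def groups(varying):
        # Fix every axis not in `varying`; each fixed combination yields one group.
        fixed = [axis for axis in "edpct" if axis not in varying]
        result = []
        for fixed_vals in product(*(range(sizes[a]) for a in fixed)):
            coords = dict(zip(fixed, fixed_vals))
            group = []
            for var_vals in product(*(range(sizes[a]) for a in varying)):
                coords.update(zip(varying, var_vals))
                group.append(rank(coords))
            result.append(group)
        return result

    return {
        "tp": groups("t"),    # contiguous ranks
        "pp": groups("p"),    # stride PCP * TP
        "dp": groups("d"),    # stride PP * PCP * TP
        "ep": groups("dct"),  # DP * PCP * TP ranks that share one PP stage
    }

groups = build_parallel_groups(16, dp=2, pp=2, pcp=1, tp=4)
print(groups["pp"][:2])  # [[0, 4], [1, 5]]
```

With DP=2, PP=2, TP=4 the EP groups come out as [0-3, 8-11] and [4-7, 12-15]: both DP replicas of the same PP stage share one expert-parallel group.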

GroupCoordinator: the Communication Wrapper

Each process group is wrapped in a GroupCoordinator that manages both an NCCL device group (for GPU collectives) and a Gloo CPU group (for metadata exchange). It also optionally creates a device communicator (a DeviceCommunicatorBase subclass that can use NCCL or custom all-reduce) and a shared-memory MessageQueue broadcaster.

vllm/distributed/parallel_state.py : GroupCoordinator.__init__
class GroupCoordinator:
    rank: int             # global rank
    ranks: list[int]      # global ranks in the group
    world_size: int       # size of the group
    local_rank: int       # local rank for device assignment
    rank_in_group: int    # rank inside the group
    cpu_group: ProcessGroup     # group for CPU (Gloo)
    device_group: ProcessGroup  # group for GPU (NCCL)
    device_communicator: DeviceCommunicatorBase | None

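    def __init__(self, group_ranks, local_rank, backend, ...):
        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank
        # new_group() must be entered by every process for every group,
        # even for groups the caller is not a member of.
        for ranks in group_ranks:
            device_group = torch.distributed.new_group(ranks, backend=backend)
            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
                self.device_group = device_group
                self.cpu_group = cpu_group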

Collective Operations

vllm/distributed/communication_op.py provides thin wrappers that fetch the TP group and invoke collectives:

vllm/distributed/communication_op.py
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across model parallel group."""
    return get_tp_group().all_reduce(input_)

def tensor_model_parallel_all_gather(input_: torch.Tensor, dim=-1) -> torch.Tensor:
    return get_tp_group().all_gather(input_, dim)

def tensor_model_parallel_reduce_scatter(input_: torch.Tensor, dim=-1) -> torch.Tensor:
    return get_tp_group().reduce_scatter(input_, dim)

For torch.compile compatibility, these collectives are registered as custom ops with fake implementations, which let the compiler trace output shapes and dtypes without running any communication:

direct_register_custom_op(
    op_name="all_reduce",
    op_func=all_reduce,
    fake_impl=all_reduce_fake,  # returns torch.empty_like(tensor)
)

2 Tensor Parallelism: Column & Row Parallel Linear

Figure: Tensor Parallelism, column/row split across 4 GPUs. The full weight W [H x O] is split into four column shards: GPU 0 holds W[:,0:O/4], GPU 1 W[:,O/4:O/2], GPU 2 W[:,O/2:3O/4], GPU 3 W[:,3O/4:O].
ColumnParallel: split output dim → each GPU computes Y_i = X @ A_i → all-gather (only if gather_output=True)
RowParallel: split input dim → each GPU computes Y_i = X_i @ A_i → all-reduce

Tensor Parallelism (TP) splits individual weight matrices across GPUs so that each GPU holds a shard of each layer. vLLM uses two complementary patterns from Megatron-LM, implemented in vllm/model_executor/layers/linear.py.

Diagram (a): Tensor Parallelism -- Column & Row Sharding
  ColumnParallelLinear:  Y = X @ A       (A split along columns)

         Full A [H, O]            GPU 0: A_0 [H, O/N]     GPU 1: A_1 [H, O/N]
    +-----------------+          +----------+             +----------+
    |                 |          |          |             |          |
    |     A           |   ===>   |   A_0    |             |   A_1    |
    |   [H x O]       |          | [H x O/N]|             | [H x O/N]|
    |                 |          +----------+             +----------+
    +-----------------+            X @ A_0 = Y_0            X @ A_1 = Y_1

    Each GPU computes its local output shard Y_i = X @ A_i
    If gather_output=True: all-gather Y_i across GPUs => full Y
    Otherwise: pass Y_i directly to next layer (RowParallelLinear)

  RowParallelLinear:  Y = X @ A + b      (A split along rows, X split)

         Full A [I, O]            GPU 0: A_0 [I/N, O]     GPU 1: A_1 [I/N, O]
    +-----------------+          +----------+             +----------+
    | A_0             |          |          |             |          |
    +-----------------+   ===>   |   A_0    |             |   A_1    |
    | A_1             |          | [I/N x O]|             | [I/N x O]|
    +-----------------+          +----------+             +----------+
                                   X_0 @ A_0               X_1 @ A_1

    Each GPU: Y_i = X_i @ A_i
    Then: all-reduce Y_i across GPUs => Y = sum(Y_i)
    Bias added only on rank 0 to avoid double-counting

ColumnParallelLinear

Splits the weight output dimension (columns) across TP ranks. Each GPU stores output_size / tp_size columns and computes the matching slice of the output (complete values for its own columns, not partial sums). For QKV projections (QKVParallelLinear), Q/K/V heads are individually divided across ranks.

vllm/model_executor/layers/linear.py : ColumnParallelLinear.__init__
class ColumnParallelLinear(LinearBase):
    """Y = XA + b. A is parallelized along its second dimension as A = [A_1, ..., A_p]."""

    def __init__(self, input_size, output_size, bias=True,
                 gather_output=False, ...):
        # Divide the weight matrix along the last dimension
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]

        # For fused QKV: each of Q, K, V is independently sharded
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(output_size, self.tp_size) for output_size in self.output_sizes
            ]
vllm/model_executor/layers/linear.py : ColumnParallelLinear.forward
def forward(self, input_):
    bias = self.bias if not self.skip_bias_add else None
    output_parallel = self.quant_method.apply(self, input_, bias)

    if self.gather_output and self.tp_size > 1:
        # All-gather across the partitions.
        output = tensor_model_parallel_all_gather(output_parallel)
    else:
        output = output_parallel  # pass shard directly to RowParallel

    return output

RowParallelLinear

Splits the weight input dimension (rows) across TP ranks. Each GPU takes the input shard X_i produced by the preceding ColumnParallel layer (or, when input_is_parallel=False, splits the full input itself), computes the partial product Y_i = X_i @ A_i, and an all-reduce then sums the partial results.

vllm/model_executor/layers/linear.py : RowParallelLinear
class RowParallelLinear(LinearBase):
    """Y = XA + b. A split along rows, X along columns:
              | A_0 |
          A = | ... |       X = [X_0, ..., X_p]
              | A_p |
    """

    def __init__(self, input_size, output_size, ...):
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size  # full output on each GPU

    def forward(self, input_):
        if self.input_is_parallel:
            input_parallel = input_   # already split from ColumnParallel
        else:
            split_input = split_tensor_along_last_dim(input_, self.tp_size)
            input_parallel = split_input[self.tp_rank]

        # Bias only fused into GEMM on rank 0 (avoids double-count)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel
        return output
Key insight: Column and Row parallelism are used in pairs. A ColumnParallel layer (e.g., gate_up_proj, or qkv_proj in attention) feeds its sharded output directly into a RowParallel layer (e.g., down_proj, or o_proj). Each pair needs only one all-reduce, at the end of the RowParallelLinear, so a standard transformer block performs two all-reduces (one after attention, one after the MLP) instead of one per linear layer. This minimizes communication overhead.
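The claim that a Column→Row pair needs a single all-reduce can be checked numerically. The sketch below uses plain Python lists; all function names are hypothetical, and ReLU stands in for the MLP activation (real models use other activations, but any element-wise one works on a column shard without communication). Bias is added on rank 0 only, as in RowParallelLinear.

```python
def matmul(x, w):
    """Plain-Python (rows x k) @ (k x cols) matrix product."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]

def col_shard(w, rank, n):
    """ColumnParallelLinear shard: columns [rank*O/n, (rank+1)*O/n)."""
    step = len(w[0]) // n
    return [row[rank * step:(rank + 1) * step] for row in w]

def row_shard(w, rank, n):
    """RowParallelLinear shard: rows [rank*I/n, (rank+1)*I/n)."""
    step = len(w) // n
    return w[rank * step:(rank + 1) * step]

def all_reduce_sum(parts):
    """Element-wise sum of per-rank partial outputs (what all-reduce computes)."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*parts)]

def relu(m):
    return [[max(v, 0) for v in row] for row in m]

def tp_mlp(x, a, b, bias, n):
    partials = []
    for rank in range(n):
        h = relu(matmul(x, col_shard(a, rank, n)))  # column-parallel slice, no comm
        y = matmul(h, row_shard(b, rank, n))        # row-parallel partial sum
        if rank == 0:                               # bias added once, on rank 0
            y = [[v + c for v, c in zip(row, bias)] for row in y]
        partials.append(y)
    return all_reduce_sum(partials)                 # the pair's only collective

def reference_mlp(x, a, b, bias):
    h = relu(matmul(x, a))
    return [[v + c for v, c in zip(row, bias)] for row in matmul(h, b)]

x = [[1, -2, 3]]
a = [[1, 0, 2, -1], [0, 1, 1, 2], [1, 1, 0, 1]]
b = [[1, 2], [3, -1], [0, 1], [2, 2]]
bias = [10, 20]
print(tp_mlp(x, a, b, bias, n=2), reference_mlp(x, a, b, bias))  # [[17, 27]] [[17, 27]]
```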

QKVParallelLinear: Attention Head Sharding

vllm/model_executor/layers/linear.py : QKVParallelLinear
class QKVParallelLinear(ColumnParallelLinear):
    """Fused Q/K/V projection, sharded by attention head."""
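    def __init__(self, hidden_size, head_size, total_num_heads, total_num_kv_heads,
                 ..., v_head_size=None):
        tp_size = get_tensor_model_parallel_world_size()
        v_head_size = head_size if v_head_size is None else v_head_size
        self.num_heads = divide(total_num_heads, tp_size)
        if tp_size >= total_num_kv_heads:
            # Fewer KV heads than TP ranks: one KV head per rank, replicated
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, total_num_kv_heads)
        else:
            self.num_kv_heads = divide(total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1

        # Totals across ranks; ColumnParallelLinear divides each entry by tp_size
        self.output_sizes = [
            self.num_heads * head_size * tp_size,      # q_proj
            self.num_kv_heads * head_size * tp_size,    # k_proj
            self.num_kv_heads * v_head_size * tp_size,  # v_proj
        ]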
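The head-splitting branch above can be isolated as a small function. This is a hypothetical helper (not part of vLLM) that mirrors the same rule for a grouped-query-attention model, e.g., 32 query heads and 8 KV heads: below 8 ranks the KV heads are split, and from 8 ranks upward each rank keeps one KV head that is replicated tp_size / 8 times.

```python
def shard_attention_heads(total_num_heads, total_num_kv_heads, tp_size):
    """Per-rank (num_heads, num_kv_heads, kv_replicas) for a QKVParallelLinear-style split."""
    if total_num_heads % tp_size:
        raise ValueError("query heads must divide evenly across TP ranks")
    num_heads = total_num_heads // tp_size
    if tp_size >= total_num_kv_heads:
        # Not enough KV heads to split: each rank keeps one, replicated.
        if tp_size % total_num_kv_heads:
            raise ValueError("tp_size must be a multiple of the KV head count")
        return num_heads, 1, tp_size // total_num_kv_heads
    if total_num_kv_heads % tp_size:
        raise ValueError("KV heads must divide evenly across TP ranks")
    return num_heads, total_num_kv_heads // tp_size, 1

print(shard_attention_heads(32, 8, 4))   # (8, 2, 1)
print(shard_attention_heads(32, 8, 16))  # (2, 1, 2)
```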

3 Pipeline Parallelism: PP Stage Coordination

Figure: Pipeline Parallelism, 4 stages with micro-batches (uB-1 to uB-3) flowing through. Input → Stage 0 (layers 0-7, GPU 0,1) → Stage 1 (layers 8-15, GPU 2,3) → Stage 2 (layers 16-23, GPU 4,5) → Stage 3 (layers 24-31, GPU 6,7) → output (last PP stage, TP rank 0); each hop is a send_tensor.

  Micro-batch timeline (filling the pipeline):
        t=0    t=1    t=2    t=3    t=4    t=5    t=6
  S0    uB1    uB2    uB3    idle   idle   idle   idle
  S1    idle   uB1    uB2    uB3    idle   idle   idle
  S2    idle   idle   uB1    uB2    uB3    idle   idle
  S3    idle   idle   idle   uB1    uB2    uB3    idle

Pipeline Parallelism (PP) splits the model vertically across GPUs, with each GPU holding a contiguous range of transformer layers. vLLM supports PP through both the MultiprocExecutor and RayDistributedExecutor.

Diagram (b): Pipeline Parallelism -- 4 PP Stages
  Input tokens (SchedulerOutput)
        |
        v
  +------------------+     send_tensor_dict()     +------------------+
  | PP Stage 0       | -------------------------> | PP Stage 1       |
  | Layers 0-7       |   IntermediateTensors      | Layers 8-15      |
  | GPU 0,1 (TP=2)   |                            | GPU 2,3 (TP=2)   |
  +------------------+                            +------------------+
                                                         |
                                                         v
  +------------------+     send_tensor_dict()     +------------------+
  | PP Stage 3       | <------------------------- | PP Stage 2       |
  | Layers 24-31     |   IntermediateTensors      | Layers 16-23     |
  | GPU 6,7 (TP=2)   |                            | GPU 4,5 (TP=2)   |
  +------------------+                            +------------------+
        |
        v
    ModelRunnerOutput (from last PP stage, TP rank 0)

  PP Group Ranks (TP=2, PP=4): [g0, g2, g4, g6], [g1, g3, g5, g7]
  Each PP stage does: recv_tensor_dict() -> forward() -> send_tensor_dict()

PP Group and Navigation

The GroupCoordinator provides PP-stage navigation through properties:

vllm/distributed/parallel_state.py : GroupCoordinator
@property
def next_rank(self):
    """Return the global rank of the process that follows the caller"""
    return self.ranks[(self.rank_in_group + 1) % self.world_size]

@property
def prev_rank(self):
    """Return the global rank of the process that precedes the caller"""
    return self.ranks[(self.rank_in_group - 1) % self.world_size]

@property
def is_first_rank(self):
    return self.rank == self.ranks[0]

@property
def is_last_rank(self):
    return self.rank == self.ranks[-1]
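Because next_rank and prev_rank use modulo arithmetic, they wrap around: the last stage's next_rank is the first stage. Pipeline code therefore relies on is_first_rank / is_last_rank to decide whether to receive or send at all. The toy class below (hypothetical, not vLLM's GroupCoordinator) reproduces just these properties for the TP-rank-0 PP group of Diagram (b).

```python
from dataclasses import dataclass

@dataclass
class PPGroupView:
    """Toy copy of the GroupCoordinator navigation properties (sketch only)."""
    rank: int    # global rank of this process
    ranks: list  # global ranks in this PP group, ordered by stage

    @property
    def rank_in_group(self):
        return self.ranks.index(self.rank)

    @property
    def world_size(self):
        return len(self.ranks)

    @property
    def next_rank(self):
        return self.ranks[(self.rank_in_group + 1) % self.world_size]

    @property
    def prev_rank(self):
        # Python's % keeps the result non-negative, so stage 0 wraps to the last stage.
        return self.ranks[(self.rank_in_group - 1) % self.world_size]

    @property
    def is_first_rank(self):
        return self.rank == self.ranks[0]

    @property
    def is_last_rank(self):
        return self.rank == self.ranks[-1]

# PP group for TP rank 0 when TP=2, PP=4.
last = PPGroupView(rank=6, ranks=[0, 2, 4, 6])
print(last.is_last_rank, last.next_rank, last.prev_rank)  # True 0 4
```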

P2P Communication: send_tensor_dict / recv_tensor_dict

PP stages exchange IntermediateTensors using point-to-point send/recv operations. send_tensor_dict() first calls GroupCoordinator.send_object() to pickle the tensor metadata (keys, dtypes, shapes) and send it over the Gloo CPU group, then sends the GPU tensors themselves over the NCCL device group:

vllm/distributed/parallel_state.py : send_object / recv_object
def send_object(self, obj, dst: int):
    # Serialize object to tensor
    object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
    size_tensor = torch.tensor([object_tensor.numel()], dtype=torch.long)
    # Send size then object over Gloo CPU group
    torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
    torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)
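send_object() uses a two-message protocol: a fixed-size length first, then the pickled bytes, so the receiver can allocate a buffer of exactly the right size before the second receive. The sketch below reproduces that framing over a local socket pair instead of a Gloo group; send_object / recv_object here are hypothetical stand-ins, not the torch.distributed calls.

```python
import pickle
import socket
import struct

def send_object(sock, obj):
    payload = pickle.dumps(obj)
    sock.sendall(struct.pack("!q", len(payload)))  # 1) size as an 8-byte signed int
    sock.sendall(payload)                          # 2) pickled object bytes

def _recv_exact(sock, n):
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise ConnectionError("peer closed before the message was complete")
        buf.extend(chunk)
    return bytes(buf)

def recv_object(sock):
    (size,) = struct.unpack("!q", _recv_exact(sock, 8))  # read the size first
    return pickle.loads(_recv_exact(sock, size))          # then exactly `size` bytes

a, b = socket.socketpair()
send_object(a, {"keys": ["hidden_states", "residual"], "shape": (4, 4096)})
meta = recv_object(b)
a.close()
b.close()
```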

Output Rank Selection

The executor collects output only from the first TP rank of the last PP stage:

vllm/v1/executor/multiproc_executor.py
def _get_output_rank(self) -> int:
    # TP=8, PP=4, PCP=1 => world_size=32
    # output_rank = 32 - 8 * 1 = 24  (PP rank 3, PCP rank 0, TP rank 0)
    return (
        self.world_size
        - self.parallel_config.tensor_parallel_size
        * self.parallel_config.prefill_context_parallel_size
    )
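The subtraction works because, within one executor, ranks are laid out as PP x PCP x TP with TP innermost, so the last PP stage occupies the final tp_size * pcp_size ranks. A quick cross-check with two hypothetical helpers:

```python
def output_rank(world_size, tp_size, pcp_size=1):
    """First rank of the last PP stage, as computed by _get_output_rank()."""
    return world_size - tp_size * pcp_size

def last_stage_ranks(pp_size, pcp_size, tp_size):
    """Global ranks of the last PP stage in a PP x PCP x TP layout (TP innermost)."""
    stage_width = pcp_size * tp_size
    start = (pp_size - 1) * stage_width
    return list(range(start, start + stage_width))

# TP=8, PP=4, PCP=1: world_size = 32, last stage holds ranks 24..31.
assert output_rank(32, 8) == last_stage_ranks(4, 1, 8)[0] == 24
```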

4 Data Parallelism: DPCoordinator & Load Balancing

Figure: Data Parallelism, replicas with a load balancer distributing requests. Requests A-F pass through the DPCoordinator (ZMQ XPUB/PULL, load stats, wave coordination) to three full model copies: replica 0 (GPU 0-3, TP=4) serves A and D, replica 1 (GPU 4-7, TP=4) serves B and E, replica 2 (GPU 8-11, TP=4) serves C and F.

Data Parallelism (DP) replicates the entire model across GPU sets, with each replica handling different requests. vLLM's DPCoordinator (in vllm/v1/engine/coordinator.py) is a dedicated process that intermediates between front-end API servers and DP engine replicas.
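The coordinator publishes per-engine [waiting, running] request counts to the front-ends (see the event loop below), which gives them enough information to route each new request to a lightly loaded replica. The function below is only a sketch of such a policy: pick_engine and its waiting_weight are assumptions (queued requests weighted more heavily than running ones), not vLLM's exact routing rule.

```python
def pick_engine(request_counts, waiting_weight=4):
    """Choose a DP engine index from per-engine [waiting, running] counts (sketch).

    waiting_weight is an assumed knob that penalizes queued requests more
    than requests that are already running.
    """
    scores = [waiting * waiting_weight + running for waiting, running in request_counts]
    return min(range(len(scores)), key=scores.__getitem__)  # ties -> lowest index

print(pick_engine([[2, 10], [0, 12], [1, 3]]))  # 2 (scores 18, 12, 7)
```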

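DPCoordinator setup: the coordinator runs as its own process and is given the ZMQ addresses it uses to talk to front-ends and engines: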
vllm/v1/engine/coordinator.py : DPCoordinator
class DPCoordinator:
    """Coordinator process for data-parallel deployments (DP>1).
    Intermediates between DP engine rank processes and front-end API servers.
    """

    def __init__(self, parallel_config, enable_wave_coordination=True):
        dp_size = parallel_config.data_parallel_size
        assert dp_size > 1

        # Spawn the coordinator as a separate process
        self.proc = context.Process(
            target=DPCoordinatorProc.run_coordinator,
            name="VLLM_DP_Coordinator",
            kwargs={
                "engine_count": parallel_config.data_parallel_size,
                "front_publish_address": front_publish_address,  # ZMQ XPUB
                "back_output_address": back_output_address,      # ZMQ PULL
                "back_publish_address": back_publish_address,    # ZMQ XPUB
            },
        )

The Event Loop

The coordinator uses a ZMQ poller with three sockets: publish_front (XPUB to API servers), publish_back (XPUB to engines), and output_back (PULL from engines).

vllm/v1/engine/coordinator.py : DPCoordinatorProc.process_input_socket
class EngineState:
    def __init__(self):
        self.request_counts = [0, 0]  # [waiting, running]

# Main event loop
while True:
    events = poller.poll(timeout=max(min_timeout, wait_for - elapsed))
    if not events:
        # Timeout -- publish current stats to front-ends
        to_publish = (engine_req_counts_list, current_wave, engines_running)
        publish_front.send(msgspec.msgpack.encode(to_publish))
        continue
    events = dict(events)  # poll() returns (socket, flags) pairs

    if output_back in events:
        # Engine sent stats update
        outputs = decoder.decode(buffer)
        stats = self.engines[eng_index].request_counts
        # Update local engine state

    if publish_front in events:
        # Wave coordination: wake engines when new request arrives
        if not engines_running:
            engines_running = True
            self._send_start_wave(publish_back, current_wave, engine_to_exclude)
Wave coordination: Engines alternate between running and paused states. The "request wave" counter tracks transitions from running to paused. When a front-end sends a new request while engines are paused, the coordinator broadcasts a START_DP_WAVE to wake all engines.
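The wave logic reduces to a small state machine: a new request wakes the paused engines once, and the wave counter advances when the engines report that the current wave has finished. The class below is a hypothetical model of that behavior (not vLLM's DPCoordinatorProc); it also shows why the wave number matters, since a late "paused" notification for an older wave is ignored.

```python
class ToyWaveCoordinator:
    """Hypothetical model of DP wave coordination (not vLLM's class)."""

    def __init__(self):
        self.current_wave = 0
        self.engines_running = False
        self.broadcasts = []  # messages sent to all engines (publish_back)

    def on_new_request(self):
        # Front-end submitted work: wake every engine only if the group is paused.
        if not self.engines_running:
            self.engines_running = True
            self.broadcasts.append(("START_DP_WAVE", self.current_wave))

    def on_engines_paused(self, finished_wave):
        # Engines ran out of work: close the current wave; stale reports are ignored.
        if self.engines_running and finished_wave == self.current_wave:
            self.engines_running = False
            self.current_wave += 1

coord = ToyWaveCoordinator()
coord.on_new_request()
coord.on_new_request()   # already running: no second broadcast
coord.on_engines_paused(0)
coord.on_new_request()   # wakes the engines for wave 1
print(coord.broadcasts)  # [('START_DP_WAVE', 0), ('START_DP_WAVE', 1)]
```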

5 Expert Parallelism: FusedMoE & --enable-expert-parallel

Figure: Expert Parallelism, token routing to experts on different GPUs. A top-k gating router (softmax(X @ W_gate), top-k experts per token) sends tokens 0-5 with an all-to-all dispatch to 8 experts spread over 4 GPUs: GPU 0 (E0, E1) receives T0, T2, T4; GPU 1 (E2, E3) receives T1, T3; GPU 2 (E4, E5) receives T0, T5; GPU 3 (E6, E7) receives T2, T4. An all-to-all combine returns the results as a weighted sum of expert outputs.

For Mixture-of-Experts (MoE) models like Mixtral, DeepSeek-V3, etc., Expert Parallelism (EP) distributes experts across GPUs instead of slicing individual weight matrices. The core implementation lives in vllm/model_executor/layers/fused_moe/layer.py.
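The routing step in the figure can be sketched in a few lines: softmax over router logits, top-k selection, then grouping token ids by the EP rank that hosts each chosen expert (the lists an all-to-all dispatch would send). This is an illustrative sketch, not vLLM's router: route_and_dispatch is hypothetical, it assumes linear placement with equal experts per rank, and it renormalizes the top-k weights, which only some models do. The combine step (weighted sum of expert outputs) is left out.

```python
import math

def route_and_dispatch(router_logits, top_k, experts_per_rank):
    """Toy top-k routing plus the per-rank token lists an all-to-all dispatch sends.

    router_logits: one list of per-expert scores per token (stand-in for X @ W_gate).
    """
    assignments = []  # per token: [(expert_id, routing_weight), ...]
    dispatch = {}     # ep_rank -> token ids it must receive
    for token, logits in enumerate(router_logits):
        peak = max(logits)
        exps = [math.exp(v - peak) for v in logits]  # numerically stable softmax
        total = sum(exps)
        probs = [e / total for e in exps]
        chosen = sorted(range(len(probs)), key=lambda e: probs[e], reverse=True)[:top_k]
        norm = sum(probs[e] for e in chosen)         # renormalize over top-k (assumed)
        assignments.append([(e, probs[e] / norm) for e in chosen])
        for e in chosen:
            bucket = dispatch.setdefault(e // experts_per_rank, [])
            if token not in bucket:                   # send a token once per rank
                bucket.append(token)
    return assignments, dispatch

logits = [
    [5, 0, 0, 0, 0, 4, 0, 0],  # token 0 -> experts 0 and 5
    [0, 0, 3, 0, 0, 0, 0, 6],  # token 1 -> experts 7 and 2
    [0, 0, 0, 0, 0, 0, 2, 1],  # token 2 -> experts 6 and 7 (same rank)
]
assignments, dispatch = route_and_dispatch(logits, top_k=2, experts_per_rank=2)
```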

Expert Distribution: determine_expert_map()

vllm/model_executor/layers/fused_moe/layer.py
def determine_expert_map(
    ep_size, ep_rank, global_num_experts,
    expert_placement_strategy="linear", ...
) -> tuple[int, torch.Tensor | None, ...]:
    """Distribute experts evenly across EP ranks."""
    base_experts = global_num_experts // ep_size
    remainder = global_num_experts % ep_size
    local_num_experts = base_experts + 1 if ep_rank < remainder else base_experts

    # Create global-to-local expert mapping
    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
    if expert_placement_strategy == "linear":
        start_idx = ep_rank * base_experts + min(ep_rank, remainder)
        expert_map[start_idx : start_idx + local_num_experts] = torch.arange(
            0, local_num_experts, dtype=torch.int32
        )
    elif expert_placement_strategy == "round_robin":
        local_log_experts = torch.arange(ep_rank, global_num_experts, ep_size)
        expert_map[local_log_experts] = torch.arange(0, local_num_experts)

    expert_mask = ...  # construction of expert_mask elided in this excerpt
    return (local_num_experts, expert_map, expert_mask)
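The two placement strategies can be compared with a torch-free version of the same arithmetic. expert_map below is a hypothetical list-based helper, not vLLM's function. With 10 experts over 4 EP ranks, base = 2 and remainder = 2, so ranks 0 and 1 host 3 experts and ranks 2 and 3 host 2; "linear" gives each rank a contiguous block, while "round_robin" gives it every 4th expert.

```python
def expert_map(ep_size, ep_rank, global_num_experts, strategy="linear"):
    """Global expert id -> local slot (or -1 if not hosted on this rank)."""
    base, remainder = divmod(global_num_experts, ep_size)
    local_num = base + 1 if ep_rank < remainder else base
    mapping = [-1] * global_num_experts
    if strategy == "linear":
        # Contiguous block; the first `remainder` ranks take one extra expert.
        start = ep_rank * base + min(ep_rank, remainder)
        hosted = range(start, start + local_num)
    elif strategy == "round_robin":
        # Every ep_size-th expert, starting at this rank's index.
        hosted = range(ep_rank, global_num_experts, ep_size)
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    for local_id, global_id in enumerate(hosted):
        mapping[global_id] = local_id
    return local_num, mapping

for s in ("linear", "round_robin"):
    print(s, [[g for g, l in enumerate(expert_map(4, r, 10, s)[1]) if l >= 0] for r in range(4)])
# linear [[0, 1, 2], [3, 4, 5], [6, 7], [8, 9]]
# round_robin [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
```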

FusedMoE Layer: Bringing It Together

vllm/model_executor/layers/fused_moe/layer.py : FusedMoE
@CustomOp.register("fused_moe")
class FusedMoE(CustomOp):
    """FusedMoE layer containing gate_up_proj (w13) + down_proj (w2).
    Supports both TP-sharded experts and EP-distributed experts."""

    def __init__(self, num_experts, top_k, hidden_size, intermediate_size, ...):
        self.moe_parallel_config = FusedMoEParallelConfig.make(
            tp_size_=tp_size_, dp_size_=dp_size_,
            pcp_size_=pcp_size_, sp_size_=self.sp_size,
            vllm_parallel_config=vllm_config.parallel_config,
        )

        if self.use_ep:
            # Expert Parallelism: each rank gets a subset of experts
            local_num_experts, expert_map, expert_mask = determine_expert_map(
                ep_size=self.ep_size,
                ep_rank=self.ep_rank,
                global_num_experts=self.global_num_experts,
                expert_placement_strategy=self.expert_placement_strategy,
            )
            self.local_num_experts = local_num_experts
            self.register_buffer("_expert_map", expert_map)
        else:
            # TP mode: all experts replicated, weights TP-sharded
            self.local_num_experts = self.global_num_experts
            self._expert_map = None
EP group composition: The EP group spans DP * PCP * TP ranks. When --enable-expert-parallel is set, the EP group size is the product of all three, meaning experts are distributed across all parallel dimensions. Each rank's ep_rank and ep_size come from the EP group coordinator.

Fused MoE Kernel Zoo

The vllm/model_executor/layers/fused_moe/ directory contains numerous kernel backends:

File                          | Purpose
fused_moe.py                  | Core Triton fused MoE kernel
deep_gemm_moe.py              | DeepGemm-based MoE (DeepSeek)
cutlass_moe.py                | CUTLASS-based MoE kernels
fused_marlin_moe.py           | Marlin quantized MoE
nixl_ep_prepare_finalize.py   | NIXL-based EP all-to-all
all2all_utils.py              | All-to-all dispatch for EP
router/                       | Token-to-expert routing logic
runner/                       | MoE execution runners (default, batched)

6 MultiprocExecutor: Subprocess Spawning & Shared-Memory MessageQueue

Figure: Executor hierarchy, EngineCore → Executor → Workers. EngineCore (Scheduler + KV Manager) drives an Executor (MultiprocExecutor or RayDistributedExecutor) that issues collective_rpc() via a shared-memory MessageQueue; Workers 0-3 each own one GPU (GPU 0-3) and a ModelRunner. Broadcast: shm MessageQueue (zero-copy). Response: per-worker response MQ. Health: OS process sentinels.

For single-node multi-GPU deployment, the MultiprocExecutor (in vllm/v1/executor/multiproc_executor.py) spawns one WorkerProc per GPU using Python's multiprocessing. Communication between the scheduler and workers uses shared-memory MessageQueues for zero-copy efficiency.

vllm/v1/executor/multiproc_executor.py : MultiprocExecutor._init_executor
class MultiprocExecutor(Executor):
    supports_pp: bool = True

    def _init_executor(self):
        tp_size, pp_size, pcp_size = self._get_parallel_sizes()
        assert self.world_size == tp_size * pp_size * pcp_size

        # Loopback for single-node communication
        distributed_init_method = get_distributed_init_method(
            get_loopback_ip(), get_open_port()
        )

        # Create shared-memory broadcast MessageQueue
        self.rpc_broadcast_mq = MessageQueue(
            self.world_size, self.local_world_size,
            max_chunk_bytes=max_chunk_bytes,
        )
        scheduler_output_handle = self.rpc_broadcast_mq.export_handle()

        # Spawn one WorkerProc per GPU
        unready_workers = []
        for local_rank in range(self.local_world_size):
            global_rank = global_start_rank + local_rank
            unready_worker = WorkerProc.make_worker_process(
                vllm_config=self.vllm_config,
                local_rank=local_rank,
                rank=global_rank,
                distributed_init_method=distributed_init_method,
                input_shm_handle=scheduler_output_handle,  # shared memory!
                shared_worker_lock=shared_worker_lock,
            )
            unready_workers.append(unready_worker)

        # Wait for all workers to be ready
        self.workers = WorkerProc.wait_for_ready(unready_workers)

RPC via MessageQueue

The collective_rpc method broadcasts method calls to all workers through the shared-memory MessageQueue, then collects responses:

vllm/v1/executor/multiproc_executor.py : collective_rpc
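def collective_rpc(self, method, timeout=None, args=(), kwargs=None,
                   non_block=False, unique_reply_rank=None, ...):
    deadline = None if timeout is None else time.monotonic() + timeout
    # A method is sent by name, or pickled if it is a callable
    send_method = method if isinstance(method, str) else cloudpickle.dumps(method)
    output_rank = unique_reply_rank  # None => every worker replies

    # Broadcast the RPC call to all workers
    self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, output_rank))

    # Collect responses from response MessageQueues
    response_mqs = self.response_mqs
    if output_rank is not None:
        response_mqs = (response_mqs[output_rank],)  # only read from one

    def get_response():
        responses = []
        for mq in response_mqs:
            dequeue_timeout = (
                None if deadline is None else deadline - time.monotonic()
            )
            status, result = mq.dequeue(timeout=dequeue_timeout)
            if status != WorkerProc.ResponseStatus.SUCCESS:
                raise RuntimeError(f"Worker failed: {result}")
            responses.append(result)
        return responses[0] if output_rank is not None else responses

    # non_block=True returns a future wrapping get_response (elided here)
    return get_response()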
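The RPC pattern (broadcast one message to every worker, let only the designated rank reply, and raise if a reply is not SUCCESS) can be modeled with threads and queues. This is a simplified, hypothetical stand-in: one queue.Queue per worker replaces the single shared-memory broadcast ring, and ToyWorker / ToyExecutor are not vLLM classes. Note that after a failure the toy does not drain other workers' replies; in vLLM a worker failure is treated as fatal anyway.

```python
import queue
import threading

SUCCESS, FAILURE = "SUCCESS", "FAILURE"

class ToyWorker:
    def __init__(self, rank):
        self.rank = rank

    def execute(self, x):
        return x * 10 + self.rank

def worker_loop(worker, inbox, outbox):
    while True:
        msg = inbox.get()
        if msg is None:  # shutdown sentinel
            return
        method, args, kwargs, output_rank = msg
        try:
            result, status = getattr(worker, method)(*args, **kwargs), SUCCESS
        except Exception as exc:  # report the failure instead of killing the loop
            result, status = repr(exc), FAILURE
        if output_rank is None or output_rank == worker.rank:  # only the output rank replies
            outbox.put((status, result))

class ToyExecutor:
    def __init__(self, world_size):
        self.inboxes = [queue.Queue() for _ in range(world_size)]
        self.response_qs = [queue.Queue() for _ in range(world_size)]
        self.threads = [
            threading.Thread(target=worker_loop, daemon=True,
                             args=(ToyWorker(r), self.inboxes[r], self.response_qs[r]))
            for r in range(world_size)
        ]
        for t in self.threads:
            t.start()

    def collective_rpc(self, method, args=(), kwargs=None, unique_reply_rank=None, timeout=5.0):
        msg = (method, args, kwargs or {}, unique_reply_rank)
        for inbox in self.inboxes:  # stand-in for one shared-memory broadcast
            inbox.put(msg)
        qs = self.response_qs if unique_reply_rank is None else [self.response_qs[unique_reply_rank]]
        responses = []
        for q in qs:
            status, result = q.get(timeout=timeout)
            if status != SUCCESS:
                raise RuntimeError(f"Worker failed: {result}")
            responses.append(result)
        return responses[0] if unique_reply_rank is not None else responses

    def shutdown(self):
        for inbox in self.inboxes:
            inbox.put(None)
        for t in self.threads:
            t.join()
```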

Worker Health Monitoring

A daemon thread monitors worker process liveness using OS process sentinels:

def monitor_workers():
    sentinels = [h.proc.sentinel for h in workers]
    died = multiprocessing.connection.wait(sentinels)
    # If any worker dies, shut down executor and invoke failure callback
    _self.is_failed = True
    _self.shutdown()
    callback()

7 RayDistributedExecutor: Multi-Node via Ray Compiled DAG

For multi-node deployment, RayDistributedExecutor (in vllm/v1/executor/ray_executor.py) uses Ray actors as GPU workers and a Ray Compiled DAG to drive execution and move intermediate tensors between PP stages.

Worker Creation and Rank Assignment

vllm/v1/executor/ray_executor.py : _init_workers_ray
class RayDistributedExecutor(Executor):
    uses_ray: bool = True
    supports_pp: bool = True

    def _init_workers_ray(self, placement_group, ...):
        # pp_tp_workers[pp_rank][tp_rank] = worker actor
        self.pp_tp_workers: list[list[RayWorkerWrapper]] = []

        # Create Ray actors with GPU resources
        for rank, bundle_id in enumerate(bundle_indices):
            worker = ray.remote(
                num_cpus=0, num_gpus=num_gpus,
                scheduling_strategy=PlacementGroupSchedulingStrategy(
                    placement_group=placement_group,
                    placement_group_bundle_index=bundle_id,
                ),
            )(RayWorkerWrapper).remote(rpc_rank=rank)

        # Sort workers: driver node first, then by IP
        sorted_workers = sorted(worker_metadata, key=sort_by_driver_then_worker_ip)

        # Organize into PP x TP grid
        for pp_rank in range(pp_size):
            self.pp_tp_workers.append([])
            for tp_rank in range(tp_size):
                rank = pp_rank * tp_size + tp_rank
                self.pp_tp_workers[pp_rank].append(self.workers[rank])

Compiled DAG for PP Communication

The compiled DAG defines the data flow between PP stages. Each TP group within a PP stage executes in SPMD fashion (same program, multiple data):

vllm/v1/executor/ray_executor.py : _compiled_ray_dag
def _compiled_ray_dag(self, enable_asyncio):
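    from ray.dag import InputNode, MultiOutputNode

    last_pp_rank = len(self.pp_tp_workers) - 1
    # Channel type for inter-stage tensors, e.g. "auto", "nccl" or "shm"
    channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE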

    with InputNode() as input_data:
        # Example DAG: PP=2, TP=4
        # SchedulerOutput -> [0,1,2,3] -> IntermediateTensors -> [4,5,6,7] -> Output

        outputs = [input_data for _ in self.pp_tp_workers[0]]
        for pp_rank, tp_group in enumerate(self.pp_tp_workers):
            # Each PP stage: TP workers execute in parallel (SPMD)
            outputs = [
                worker.execute_model_ray.bind(outputs[i])
                for i, worker in enumerate(tp_group)
            ]

            if pp_rank < last_pp_rank and channel_type != "shm":
                # NCCL transport for intermediate tensors between PP stages
                outputs = [
                    output.with_tensor_transport(transport=channel_type)
                    for output in outputs
                ]

        forward_dag = MultiOutputNode(outputs)

    return forward_dag.experimental_compile(
        enable_asyncio=enable_asyncio,
        _overlap_gpu_communication=envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM,
    )
PP pipelining with Ray: The max_concurrent_batches property returns the PP size, so up to pp_size batches can be in flight at once, which keeps every stage busy and amortizes the pipeline bubble. Without PP, it returns 2 when async scheduling is enabled and 1 otherwise.
@property
def max_concurrent_batches(self) -> int:
    pp_size = self.parallel_config.pipeline_parallel_size
    return 2 if pp_size <= 1 and self.scheduler_config.async_scheduling else pp_size
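Why keeping several batches in flight matters can be quantified with a staircase schedule. The sketch below assumes equal per-stage time and models one isolated group of batches filling and draining the pipeline, where the idle share of stage-steps works out to (p - 1) / (m + p - 1) for p stages and m batches; with a continuous stream of new batches entering as old ones leave, the stages stay busy in steady state. pipeline_schedule and bubble_fraction are hypothetical helpers.

```python
from fractions import Fraction

def pipeline_schedule(num_stages, num_batches):
    """Staircase schedule: stage s runs batch b at step s + b (None = idle)."""
    steps = num_stages + num_batches - 1
    grid = [[None] * steps for _ in range(num_stages)]
    for s in range(num_stages):
        for b in range(num_batches):
            grid[s][s + b] = b
    return grid

def bubble_fraction(num_stages, num_batches):
    """Idle share of stage-steps while one group of batches fills and drains."""
    grid = pipeline_schedule(num_stages, num_batches)
    idle = sum(cell is None for row in grid for cell in row)
    return Fraction(idle, num_stages * len(grid[0]))

print(bubble_fraction(4, 1), bubble_fraction(4, 3), bubble_fraction(4, 4))  # 3/4 1/2 3/7
```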

8 KV Cache Transfer Interface: KVConnectorBase_V1

Figure: KV cache transfer, RDMA between prefill and decode nodes. Prefill node: the scheduler's request_finished() returns block_ids and engine_id; the worker (GPU) holds the KV cache in VRAM and calls save_kv_layer(). An API proxy relays the metadata. KV blocks move GPU→GPU over RDMA. Decode node: the scheduler calls get_num_new_matched_tokens() and build_connector_meta(); the worker calls start_load_kv() and wait_for_layer_load().

The KV Cache Transfer system enables disaggregated Prefill/Decode (P/D) serving, where a prefill instance computes the KV cache and a separate decode instance retrieves it. The abstract interface is defined in vllm/distributed/kv_transfer/kv_connector/v1/base.py.

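Two-sided architecture: each connector runs in two roles, a scheduler-side instance that decides which tokens to load or which blocks to hold, and a worker-side instance in each GPU worker that moves the KV data: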
Diagram (c): KV Cache Transfer in Disaggregated P/D
  PREFILL INSTANCE (P)                       DECODE INSTANCE (D)

  +-----------------------------+            +-----------------------------+
  | Scheduler                    |            | Scheduler                    |
  |  request_finished()         |            |  get_num_new_matched_tokens()|
  |    => {do_remote_decode,    | ---------> |    => num_external_tokens    |
  |        remote_block_ids,    |  (via API  |  update_state_after_alloc()  |
  |        remote_engine_id,    |   proxy)   |    => allocate blocks        |
  |        remote_host:port}    |            |  build_connector_meta()      |
  +-----------------------------+            |    => NixlConnectorMetadata  |
                                             +-----------------------------+
                                                         |
                                                         v
  +-----------------------------+            +-----------------------------+
  | Worker (GPU)                  |            | Worker (GPU)                  |
  |  KV cache in GPU memory     | <========= |  start_load_kv()             |
  |  save_kv_layer() [optional] |   RDMA     |    => initiate NIXL xfer     |
  |  wait_for_save()            |  GPU-GPU   |  wait_for_layer_load()       |
  |  get_finished()             |            |    => sync per-layer         |
  |    => {send_done_req_ids}   |            |  get_finished()              |
  +-----------------------------+            |    => {recv_done_req_ids}    |
                                             +-----------------------------+

Scheduler-Side Methods

vllm/distributed/kv_transfer/kv_connector/v1/base.py : KVConnectorBase_V1
class KVConnectorBase_V1(ABC):

    @abstractmethod
    def get_num_new_matched_tokens(self, request, num_computed_tokens
    ) -> tuple[int | None, bool]:
        """Get number of tokens that can be loaded from external KV cache
        beyond what is already computed. Returns (count, is_async)."""

    @abstractmethod
    def update_state_after_alloc(self, request, blocks, num_external_tokens):
        """Update connector state after block allocation."""

    @abstractmethod
    def build_connector_meta(self, scheduler_output) -> KVConnectorMetadata:
        """Build metadata for this step. Sent to workers."""

    def request_finished(self, request, block_ids) -> tuple[bool, dict | None]:
        """Called when request finishes. Returns (delay_free, kv_transfer_params).
        If delay_free=True, blocks are held for remote decode to fetch."""
        return False, None

Worker-Side Methods

    @abstractmethod
    def start_load_kv(self, forward_context, **kwargs):
        """Start loading KV from connector to vLLM's paged buffer.
        Called BEFORE forward pass to enable async loading."""

    @abstractmethod
    def wait_for_layer_load(self, layer_name: str):
        """Block until the KV for layer_name has finished loading.
        Called FROM WITHIN attention layer for layer-by-layer pipelining."""

    @abstractmethod
    def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs):
        """Save a layer of KV cache to the connector (maybe async).
        Called from within attention layer during forward pass."""

    @abstractmethod
    def wait_for_save(self):
        """Block until all saves complete. Called when forward context exits."""

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Pre-register KV caches with connector (e.g., for NIXL memory registration)."""

    def get_finished(self, finished_req_ids) -> tuple[set | None, set | None]:
        """Returns (send_done_ids, recv_done_ids) for async transfer tracking."""
Design philosophy: The connector interface separates scheduling decisions (scheduler side) from data movement (worker side). The scheduler uses get_num_new_matched_tokens() to decide whether to skip prefill computation for tokens whose KV can be loaded from an external source. The worker side performs the actual RDMA/GPU transfers asynchronously, overlapping them with the model's forward pass.
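To make the scheduler/worker split concrete, here is a minimal sketch of a toy connector that keeps KV in a Python dict. It does not import vLLM: ToyConnector, save_kv and load_kv are hypothetical names, and the signatures are simplified (token lists instead of Request objects, plain lists instead of tensors). The point is only that the scheduler-side methods return counts and metadata, while the worker-side methods are the only ones touching KV values.

```python
from dataclasses import dataclass, field


@dataclass
class ToyConnector:
    """Hypothetical in-memory connector mirroring the scheduler/worker split."""

    block_size: int = 4
    # Stand-in "external cache": token-prefix tuple -> per-token KV values
    store: dict = field(default_factory=dict)
    # Scheduler-side state: req_id -> number of tokens to load this step
    pending_loads: dict = field(default_factory=dict)

    # ---- Scheduler side: decides, never touches KV data ----
    def get_num_new_matched_tokens(self, token_ids, num_computed_tokens):
        matched = 0
        # Walk full blocks; stop at the first prefix that is not cached
        for end in range(self.block_size, len(token_ids) + 1, self.block_size):
            if tuple(token_ids[:end]) not in self.store:
                break
            matched = end
        return max(matched - num_computed_tokens, 0), False  # (count, is_async)

    def update_state_after_alloc(self, req_id, num_external_tokens):
        if num_external_tokens > 0:
            self.pending_loads[req_id] = num_external_tokens

    def build_connector_meta(self):
        # Hand this step's load plan to the workers and reset it
        meta, self.pending_loads = self.pending_loads, {}
        return meta

    # ---- Worker side: moves data ----
    def save_kv(self, token_ids, kv):
        for end in range(self.block_size, len(token_ids) + 1, self.block_size):
            self.store[tuple(token_ids[:end])] = kv[:end]

    def load_kv(self, token_ids, start, count):
        # KV for tokens [start, start + count) from the matching saved prefix
        return self.store[tuple(token_ids[:start + count])][start:start + count]


conn = ToyConnector(block_size=4)
prompt = list(range(10))
conn.save_kv(prompt, [t * 0.5 for t in prompt])  # an earlier request saved KV
count, is_async = conn.get_num_new_matched_tokens(prompt, num_computed_tokens=0)
conn.update_state_after_alloc("req-1", count)
meta = conn.build_connector_meta()
kv = conn.load_kv(prompt, 0, meta["req-1"])
print(count, is_async, len(prompt) - count)  # 8 False 2
```

Only the two full 4-token blocks match, so 8 tokens are loaded and the remaining 2 are prefilled locally. The real interface adds per-layer hooks and async completion tracking on top of this basic division of labor.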

9 NIXL Connector: High-Performance GPU-to-GPU KV Transfer

The NIXL (NVIDIA Inference Xfer Library) connector (vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py) provides direct GPU-to-GPU RDMA transfers for KV cache data between prefill and decode instances, potentially across nodes.

Architecture: Scheduler + Worker Split

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py : NixlConnector
class NixlConnector(KVConnectorBase_V1, SupportsHMA):

    def __init__(self, vllm_config, role, kv_cache_config):
        self.engine_id = vllm_config.kv_transfer_config.engine_id
        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler = NixlConnectorScheduler(
                vllm_config, self.engine_id, kv_cache_config)
            self.connector_worker = None
        elif role == KVConnectorRole.WORKER:
            self.connector_scheduler = None
            self.connector_worker = NixlConnectorWorker(
                vllm_config, self.engine_id, kv_cache_config)

NixlConnectorWorker: RDMA Agent

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py : NixlConnectorWorker
class NixlConnectorWorker:
    def __init__(self, vllm_config, engine_id, kv_cache_config):
        # Initialize NIXL wrapper (RDMA agent)
        self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)

        # Per-engine remote agent tracking
        self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)

        # TP-aware setup
        self.tp_rank = get_tensor_model_parallel_rank()
        self.world_size = get_tensor_model_parallel_world_size()

        # Transfer tracking
        self._recving_transfers = defaultdict[ReqId, list[TransferHandle]](list)
        self._reqs_to_send: dict[ReqId, float] = {}

        # Background handshake thread
        self._handshake_initiation_executor = ThreadPoolExecutor(max_workers=1)
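The _recving_transfers map pairs each request ID with a list of in-flight transfer handles. The sketch below shows one plausible way such a map can feed the recv_done set returned by get_finished(): a request counts as finished only when all of its transfers (for example, one per remote TP rank) are done. It uses concurrent.futures.Future as a hypothetical stand-in for NIXL transfer handles; RecvTracker and pop_finished are illustrative names, not connector API.

```python
from collections import defaultdict
from concurrent.futures import Future


class RecvTracker:
    """Hypothetical stand-in for per-request receive tracking.

    One request may need several transfers (for example one per remote
    TP rank); it counts as finished only when every transfer is done.
    """

    def __init__(self):
        self._recving_transfers = defaultdict(list)  # req_id -> [Future, ...]

    def add(self, req_id, handle):
        self._recving_transfers[req_id].append(handle)

    def pop_finished(self):
        done = set()
        for req_id, handles in list(self._recving_transfers.items()):
            pending = [h for h in handles if not h.done()]
            if pending:
                self._recving_transfers[req_id] = pending  # still in flight
            else:
                done.add(req_id)
                del self._recving_transfers[req_id]
        return done


tracker = RecvTracker()
f1, f2 = Future(), Future()
tracker.add("req-A", f1)
tracker.add("req-A", f2)
f1.set_result(None)
print(tracker.pop_finished())  # set(): f2 is still in flight
f2.set_result(None)
print(tracker.pop_finished())  # {'req-A'}
```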

Handshake Protocol

Before any transfer, NIXL agents must exchange metadata over a ZMQ-based side channel. The scheduler-side listener serves a NixlHandshakePayload, which carries a compatibility hash and the serialized NixlAgentMetadata:

@dataclass
class NixlAgentMetadata:
    engine_id: str
    agent_metadata: bytes          # NIXL serialized agent info
    kv_caches_base_addr: list[int]  # Base addresses of KV caches
    device_id: int
    num_blocks: int
    block_lens: list[int]           # Byte sizes per block per layer
    kv_cache_layout: str            # "HND" or "NHD"
    block_size: int
    ssm_sizes: tuple[int, int]     # For Mamba hybrid models

@dataclass
class NixlHandshakePayload(KVConnectorHandshakeMetadata):
    compatibility_hash: str         # SHA-256 of model/version config
    agent_metadata_bytes: bytes     # Encoded NixlAgentMetadata
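The compatibility hash lets a peer reject a handshake before any RDMA registration happens. Here is a minimal sketch of that check, assuming the hash is SHA-256 over a canonical serialization of configuration fields. The fields hashed, the use of JSON for encoding, and the names ToyAgentMetadata, make_payload and accept_payload are all hypothetical; the real connector's hash inputs and wire encoding may differ.

```python
import hashlib
import json
from dataclasses import asdict, dataclass


@dataclass
class ToyAgentMetadata:
    # Hypothetical subset of NixlAgentMetadata
    engine_id: str
    num_blocks: int
    block_size: int
    kv_cache_layout: str


def compatibility_hash(config: dict) -> str:
    # Sorted keys give a canonical byte string, so both peers hash the same input
    blob = json.dumps(config, sort_keys=True).encode()
    return hashlib.sha256(blob).hexdigest()


def make_payload(config: dict, meta: ToyAgentMetadata) -> dict:
    return {
        "compatibility_hash": compatibility_hash(config),
        "agent_metadata_bytes": json.dumps(asdict(meta)).encode(),
    }


def accept_payload(local_config: dict, payload: dict) -> ToyAgentMetadata:
    # Refuse peers whose model/version config differs from ours
    if payload["compatibility_hash"] != compatibility_hash(local_config):
        raise ValueError("incompatible peer: refusing handshake")
    return ToyAgentMetadata(**json.loads(payload["agent_metadata_bytes"]))


cfg = {"model": "toy-model", "kv_dtype": "bfloat16", "block_size": 16}
meta = ToyAgentMetadata("prefill-0", 1024, 16, "HND")
print(accept_payload(cfg, make_payload(cfg, meta)) == meta)  # True
```

Comparing a single digest first keeps the failure mode cheap: a mismatched peer is rejected before its agent metadata is decoded or any memory is registered.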

Scheduler-Side: Prefill Completion Flow

nixl_connector.py : NixlConnectorScheduler.request_finished
def request_finished(self, request, block_ids):
    # On prefill side: request completed, blocks ready for remote decode
    params = request.kv_transfer_params
    if not params or not params.get("do_remote_decode"):
        # Not a P/D handoff: free blocks immediately, nothing to send
        return False, None

    delay_free_blocks = any(len(group) > 0 for group in block_ids)
    if delay_free_blocks:
        # Hold blocks for remote decode to fetch, with an abort deadline
        self._reqs_need_send[request.request_id] = (
            time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
        )

    return delay_free_blocks, dict(
        do_remote_prefill=True,
        remote_block_ids=block_ids,
        remote_engine_id=self.engine_id,
        remote_request_id=request.request_id,
        remote_host=self.side_channel_host,
        remote_port=self.side_channel_port,
        tp_size=self.vllm_config.parallel_config.tensor_parallel_size,
    )
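Holding blocks for a remote reader raises an obvious risk: if the decode instance never fetches them (it crashed, or the request was aborted), the blocks would leak. The deadline stored in _reqs_need_send bounds that. Below is a minimal sketch of the bookkeeping, assuming held requests are either released when the decode side reports its read as done or expired once their deadline passes. SendTracker, on_remote_read_done and expire are hypothetical names, and the injectable now argument exists only to make the example testable.

```python
from __future__ import annotations

import time


class SendTracker:
    """Hypothetical prefill-side bookkeeping for blocks held for remote decode."""

    def __init__(self, timeout_s: float):
        self.timeout_s = timeout_s
        self._reqs_need_send: dict[str, float] = {}  # req_id -> deadline

    def hold(self, req_id: str, now: float | None = None) -> None:
        now = time.perf_counter() if now is None else now
        self._reqs_need_send[req_id] = now + self.timeout_s

    def on_remote_read_done(self, req_id: str) -> bool:
        # Decode side finished reading: the blocks can now be freed normally
        return self._reqs_need_send.pop(req_id, None) is not None

    def expire(self, now: float | None = None) -> set[str]:
        # Free blocks whose decode peer never came to fetch them
        now = time.perf_counter() if now is None else now
        expired = {r for r, d in self._reqs_need_send.items() if d <= now}
        for r in expired:
            del self._reqs_need_send[r]
        return expired


t = SendTracker(timeout_s=30.0)
t.hold("a", now=100.0)
t.hold("b", now=100.0)
t.on_remote_read_done("a")   # "a" fetched in time
print(t.expire(now=131.0))   # {'b'}: deadline 130.0 has passed
```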

Decode-Side: Requesting KV from Remote

def get_num_new_matched_tokens(self, request, num_computed_tokens):
    params = request.kv_transfer_params
    if params and params.get("do_remote_prefill"):
        # Remote prefill: claim all prompt tokens from remote
        token_ids = request.prompt_token_ids or []
        actual = self._mamba_prefill_token_count(len(token_ids))
        count = actual - num_computed_tokens
        if count > 0:
            return count, True   # (num_tokens, is_async=True)
    # No remote prefill for this request
    return 0, False
Heterogeneous TP support: NIXL supports different TP sizes between prefill and decode instances. The TpKVTopology class maps local TP ranks to the appropriate remote TP rank(s) via kv_topo.get_target_remote_ranks(remote_tp_size).
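As a rough illustration of the rank-mapping arithmetic, here is a sketch that assumes KV heads are split evenly across TP ranks and that one TP size divides the other. target_remote_ranks is a hypothetical function, not the TpKVTopology implementation; in particular, it ignores cases where KV heads are replicated across ranks (for example, when there are fewer KV heads than TP ranks), which the real topology logic has to handle.

```python
def target_remote_ranks(local_rank: int, local_tp: int, remote_tp: int) -> list[int]:
    """Remote (prefill) TP ranks holding the KV heads a local (decode) rank needs.

    Hypothetical sketch: assumes evenly split KV heads and divisible TP sizes.
    """
    if local_tp >= remote_tp:
        if local_tp % remote_tp:
            raise ValueError("local_tp must be a multiple of remote_tp")
        # Each remote rank's heads are shared by local_tp // remote_tp local ranks
        return [local_rank // (local_tp // remote_tp)]
    if remote_tp % local_tp:
        raise ValueError("remote_tp must be a multiple of local_tp")
    # Each local rank needs heads from `ratio` consecutive remote ranks
    ratio = remote_tp // local_tp
    return list(range(local_rank * ratio, (local_rank + 1) * ratio))


print(target_remote_ranks(3, local_tp=4, remote_tp=2))  # [1]
print(target_remote_ranks(1, local_tp=2, remote_tp=4))  # [2, 3]
```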

NIXL Features Summary

Feature             Detail
------------------  ------------------------------------------------------------------
Transport           UCX (RDMA), with fallback to CPU DRAM staging
KV Layout           Prefers HND for optimal transfer performance
Handshake           ZMQ side channel with SHA-256 compatibility hash
Cross-Layer Blocks  Optional single tensor for all layers (enable_cross_layers_blocks)
Host Buffer         Supports CPU-staged transfer for non-CUDA platforms (TPU, XPU)
Hybrid Models       Supports Mamba SSM state + attention KV co-transfer
HMA                 Implements SupportsHMA for hybrid memory allocation
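One plausible reason for preferring HND is memory contiguity when a reader needs only some of a block's KV heads, as with heterogeneous TP. The pure-Python sketch below counts the contiguous memory runs touched when selecting the first half of the heads in one block under each layout. The sizes (8 KV heads, 16 tokens per block, head dim 4) are arbitrary toy values, and the functions are hypothetical helpers, not connector code. Fewer runs generally means fewer, larger transfer descriptors.

```python
from itertools import product


def offsets(strides, ranges):
    """Sorted flat element offsets touched by selecting `ranges` per dimension."""
    return sorted(
        sum(i * s for i, s in zip(idx, strides)) for idx in product(*ranges)
    )


def contiguous_runs(offs):
    runs = 1
    for a, b in zip(offs, offs[1:]):
        if b != a + 1:
            runs += 1  # a gap starts a new contiguous region
    return runs


H, N, D = 8, 16, 4           # KV heads, tokens per block, head dim (toy sizes)
heads = range(0, H // 2)     # first half of the heads, needed by one reader

# HND block: [head][token][dim] in C order
hnd = offsets((N * D, D, 1), [heads, range(N), range(D)])
# NHD block: [token][head][dim] in C order
nhd = offsets((H * D, D, 1), [range(N), heads, range(D)])

print(contiguous_runs(hnd), contiguous_runs(nhd))  # 1 16
```

In HND the selected heads form one region; in NHD they are interleaved with the other heads for every token, giving one run per token.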

10 LMCache Connector: Multi-Tier KV Offloading

The LMCache connector (vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py) provides multi-tier KV cache offloading -- saving and loading KV cache data to/from external caches (local disk, remote storage, etc.) via the LMCache library.

vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
class LMCacheConnectorV1(KVConnectorBase_V1):

    @classmethod
    def requires_piecewise_for_cudagraph(cls, extra_config):
        """LMCache requires PIECEWISE CUDA graph mode when layerwise
        operations are enabled. wait_for_layer_load and save_kv_layer
        perform actual async synchronization that cannot be captured
        in CUDA graphs."""
        return extra_config.get("use_layerwise", False)

    def __init__(self, vllm_config, role, kv_cache_config):
        # Choose between native and dev LMCache implementations
        use_native = vllm_config.kv_transfer_config.get_from_extra_config(
            "use_native", False
        )
        if use_native:
            from vllm.distributed.kv_transfer.kv_connector.v1 \
                import lmcache_integration
            cls = lmcache_integration.vllm_v1_adapter.LMCacheConnectorV1Impl
        else:
            from lmcache.integration.vllm.vllm_v1_adapter import (
                LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl
            )
            cls = LMCacheConnectorLatestImpl

        self._lmcache_engine = cls(vllm_config, role, self)
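"Multi-tier" offloading means KV entries move between storage levels of different speed and size. The sketch below shows that idea in its simplest form: a small fast tier that spills least-recently-used entries into a larger slow tier and promotes entries back on a hit. TwoTierStore is hypothetical and is not LMCache's design; a real offloading layer would move tensors rather than bytes and key entries by token-prefix chunks.

```python
from __future__ import annotations

from collections import OrderedDict


class TwoTierStore:
    """Hypothetical two-tier KV store: small fast tier (think CPU RAM)
    spilling LRU entries into a larger slow tier (think local disk)."""

    def __init__(self, fast_capacity: int):
        self.fast_capacity = fast_capacity
        self.fast: OrderedDict[str, bytes] = OrderedDict()
        self.slow: dict[str, bytes] = {}

    def put(self, key: str, kv: bytes) -> None:
        self.fast[key] = kv
        self.fast.move_to_end(key)
        # Evict least-recently-used entries from the fast tier to the slow tier
        while len(self.fast) > self.fast_capacity:
            old_key, old_kv = self.fast.popitem(last=False)
            self.slow[old_key] = old_kv

    def get(self, key: str) -> bytes | None:
        if key in self.fast:
            self.fast.move_to_end(key)
            return self.fast[key]
        if key in self.slow:
            kv = self.slow.pop(key)
            self.put(key, kv)  # promote back to the fast tier on a hit
            return kv
        return None


s = TwoTierStore(fast_capacity=2)
for k in ("a", "b", "c"):
    s.put(k, k.upper().encode())
print(list(s.fast), list(s.slow))  # ['b', 'c'] ['a']
```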

Layer-by-Layer Async Pipeline

LMCache supports layerwise load/save: the KV for each layer is loaded just before that layer executes, so the transfer overlaps with the computation of earlier layers:

# Worker-side: called before forward pass
def start_load_kv(self, forward_context, **kwargs):
    self._lmcache_engine.start_load_kv(forward_context, **kwargs)

# Worker-side: called inside each attention layer
def wait_for_layer_load(self, layer_name):
    self._lmcache_engine.wait_for_layer_load(layer_name)

# Worker-side: called inside each attention layer
def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs):
    self._lmcache_engine.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs)

# Worker-side: called after forward pass completes
def wait_for_save(self):
    self._lmcache_engine.wait_for_save()
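The overlap these hooks enable can be simulated with a single background thread. In this sketch, start_load_kv queues every layer's load up front, and wait_for_layer_load blocks only if that particular layer has not arrived yet. LayerwiseLoader, fake_load and forward are hypothetical names, and time.sleep stands in for both transfer and compute. A GPU implementation would more likely rely on CUDA streams and events than on Python threads. With 4 layers, 10 ms loads and 10 ms compute, a fully sequential pass would take roughly 80 ms, while the pipelined one approaches 50 ms (first load plus four computes), ignoring thread overhead.

```python
import time
from concurrent.futures import ThreadPoolExecutor


class LayerwiseLoader:
    """Hypothetical sketch of start_load_kv / wait_for_layer_load overlap."""

    def __init__(self, layer_names, load_fn):
        self.layer_names = list(layer_names)
        self.load_fn = load_fn                  # copies one layer's KV
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._futures = {}
        self.loaded = {}

    def start_load_kv(self):
        # Queue every layer's load up front; the single background thread
        # runs them in layer order while the main thread starts computing
        for name in self.layer_names:
            self._futures[name] = self._pool.submit(self.load_fn, name)

    def wait_for_layer_load(self, name):
        # Blocks only if this layer's copy has not finished yet
        self.loaded[name] = self._futures.pop(name).result()

    def close(self):
        self._pool.shutdown(wait=True)


def fake_load(name, delay_s=0.01):
    time.sleep(delay_s)  # simulated transfer time
    return f"kv:{name}"


def forward(loader, compute_s=0.01):
    executed = []
    loader.start_load_kv()                  # before the forward pass
    for name in loader.layer_names:
        loader.wait_for_layer_load(name)    # inside the attention layer
        time.sleep(compute_s)               # simulated layer compute
        executed.append(name)
    return executed


loader = LayerwiseLoader([f"layer.{i}" for i in range(4)], fake_load)
start = time.perf_counter()
order = forward(loader)
elapsed = time.perf_counter() - start
loader.close()
print(order, f"{elapsed:.3f}s")
```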

Scheduler-Side: Token Matching

def get_num_new_matched_tokens(self, request, num_computed_tokens):
    """Query LMCache for available KV data."""
    return self._lmcache_engine.get_num_new_matched_tokens(
        request, num_computed_tokens
    ), False  # synchronous load
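A common way for an external KV cache to answer "how many tokens match?" is to key fixed-size token chunks by a chained hash, so a chunk can only match if every preceding chunk also matches. The sketch below assumes that scheme, with a chunk size of 4 purely for brevity; chunk_keys and num_new_matched_tokens are hypothetical helpers, and LMCache's actual keying and chunk size may differ. It returns (count, False), mirroring the synchronous-load tuple above.

```python
import hashlib


def chunk_keys(token_ids, chunk_size):
    """Chained hash per full chunk: key i identifies the whole prefix up to chunk i."""
    keys, prev = [], ""
    for start in range(0, len(token_ids) - chunk_size + 1, chunk_size):
        chunk = token_ids[start:start + chunk_size]
        prev = hashlib.sha256((prev + str(chunk)).encode()).hexdigest()
        keys.append(prev)
    return keys


def num_new_matched_tokens(cache, token_ids, num_computed_tokens, chunk_size):
    matched = 0
    for i, key in enumerate(chunk_keys(token_ids, chunk_size)):
        if key not in cache:
            break  # later chunks cannot match once a prefix chunk misses
        matched = (i + 1) * chunk_size
    # Only tokens beyond what vLLM already computed locally are worth loading
    return max(matched - num_computed_tokens, 0), False


prompt = list(range(10))
cache = set(chunk_keys(prompt[:8], 4))  # two chunks stored by an earlier request
print(num_new_matched_tokens(cache, prompt, 0, 4))  # (8, False)
```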

KV Cache Events: Tracking Stored Blocks

LMCache emits KVCacheEvents when blocks are stored, enabling the scheduler to track what is available in the external cache:

class LMCacheKVEvents(KVConnectorKVEvents):
    def __init__(self, num_workers):
        self._aggregator = KVEventAggregator(num_workers)

    def aggregate(self):
        """Aggregate KV events from all workers, retain only common events."""
        common_events = self._aggregator.get_common_events()
        self._aggregator.clear_events()
        self._aggregator.add_events(common_events)
        return self
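The aggregate() step keeps only events that every worker reported. A likely motivation is that, with the KV cache sharded across workers, a block is presumably only usable from the external cache once all shards have been stored. Here is a minimal sketch of that intersection, with plain strings as hypothetical stand-ins for KV cache event objects; common_events is not the KVEventAggregator API.

```python
from collections import Counter


def common_events(per_worker, num_workers):
    """Keep only events reported by every worker (hypothetical helper)."""
    # Count each event at most once per worker
    counts = Counter(e for events in per_worker for e in set(events))
    # Preserve the first worker's ordering for deterministic output
    first = per_worker[0] if per_worker else []
    return [e for e in dict.fromkeys(first) if counts[e] == num_workers]


w0 = ["store:b1", "store:b2", "store:b3"]
w1 = ["store:b2", "store:b1"]  # worker 1 has not stored b3 yet
print(common_events([w0, w1], num_workers=2))  # ['store:b1', 'store:b2']
```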
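LMCache vs NIXL: NIXL primarily moves KV blocks from a prefill instance's GPU memory to a decode instance's GPU memory over RDMA, and its get_num_new_matched_tokens() reports the load as asynchronous (is_async=True). LMCache saves and loads KV to and from external cache tiers (local disk, remote storage, etc.) and, in the code above, reports loads as synchronous (is_async=False).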

Quick Reference: Parallelism Strategies

Strategy       What Is Split                   Communication                                        Key Files
-------------  ------------------------------  ---------------------------------------------------  ---------------------------------------
TP (Tensor)    Weight matrices (columns/rows)  all-reduce (RowParallelLinear), all-gather (logits)  linear.py, parallel_state.py
PP (Pipeline)  Transformer layers (vertical)   P2P send/recv of IntermediateTensors                 parallel_state.py, ray_executor.py
DP (Data)      Requests across replicas        DPCoordinator (ZMQ) for load stats                   coordinator.py
EP (Expert)    MoE experts across GPUs         all-to-all for token routing                         fused_moe/layer.py, all2all_utils.py
KV Transfer    KV cache between P/D instances  NIXL RDMA or LMCache multi-tier                      nixl_connector.py, lmcache_connector.py

Key Source Files

File                                                               Purpose
-----------------------------------------------------------------  -------------------------------------------------------------
vllm/distributed/parallel_state.py                                 All group creation (TP/PP/DP/EP/PCP/DCP), GroupCoordinator
vllm/distributed/communication_op.py                               Collective operation wrappers (all_reduce, all_gather, etc.)
vllm/model_executor/layers/linear.py                               ColumnParallelLinear, RowParallelLinear, QKVParallelLinear
vllm/model_executor/layers/fused_moe/layer.py                      FusedMoE with EP support, expert_map computation
vllm/v1/executor/multiproc_executor.py                             Single-node multi-GPU via subprocesses + shared-memory MQ
vllm/v1/executor/ray_executor.py                                   Multi-node via Ray actors + Compiled DAG for PP
vllm/v1/engine/coordinator.py                                      DPCoordinator for load-balancing across DP replicas
vllm/distributed/kv_transfer/kv_connector/v1/base.py               KVConnectorBase_V1 abstract interface
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py     NIXL GPU-to-GPU RDMA connector
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py  LMCache multi-tier offloading connector

Generated from vLLM source code analysis. All code snippets are from the actual vLLM repository at vllm-project/vllm.