vLLM V1 Internals (v0.19.x) — Source Code Deep Dive

KV Cache Management

Data structures, APIs, important files, and the role of the KV Cache Manager in vLLM’s scheduler.

Topics: BlockPool · FreeKVCacheBlockQueue · KVCacheCoordinator · CPU Offload
vllm/v1/core/{block_pool, kv_cache_manager, kv_cache_coordinator, kv_cache_utils}.py · vllm/v1/simple_kv_offload/manager.py

1. Why KV Cache Management Is Hard

LLM serving lives and dies by KV cache management.

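Core design principles:
  • Blocks, not tokens. KV memory is handed out in fixed-size blocks of block_size tokens.
  • Merkle-chained hashes. A full block's hash covers its own tokens and its parent block's hash, so equal hashes imply an equal prefix.
  • Single free queue, mixed contents. Empty blocks and cached-but-unreferenced blocks share one LRU free queue.
  • O(1) doubly-linked list. A block can be pulled out of the middle of the free queue when a prefix hit claims it.
  • Reusable across devices. The same BlockPool class manages the GPU pool and the optional CPU offload pool (§6).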
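A minimal sketch of Merkle-chained block hashing, assuming SHA-256 over a pickled `(parent_hash, token_ids)` pair. The function names are hypothetical; vLLM's real block hash also folds in extra keys (for example LoRA or multimodal inputs) and supports other hash functions.

```python
from __future__ import annotations

import hashlib
import pickle


def hash_block(parent_hash: bytes | None, token_ids: tuple[int, ...]) -> bytes:
    """Hash one full block; including the parent hash chains it to the whole prefix."""
    return hashlib.sha256(pickle.dumps((parent_hash, token_ids))).digest()


def hash_request_blocks(token_ids: list[int], block_size: int) -> list[bytes]:
    """One hash per *full* block; a trailing partial block gets no hash yet."""
    hashes: list[bytes] = []
    parent: bytes | None = None
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        parent = hash_block(parent, tuple(token_ids[start:start + block_size]))
        hashes.append(parent)
    return hashes
```

Two prompts that share their first eight tokens get identical first two hashes (with block_size=4). Change token 0 and every later hash changes too, which is what lets a prefix lookup stop at the first miss.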

Big picture: the four layers

KVCacheManager: scheduler-facing API (allocate_slots, free, get_computed_blocks, get_block_ids).
KVCacheCoordinator: coordinates multiple KV cache groups.
SingleTypeKVCacheManager: manages a single group; translates request tokens to block lists.
BlockPool: the actual pool (free queue + hash table).

Sections 2–4 dissect each layer bottom up.

2. Core Data Structures

Four classes form the foundation.

2.1 KVCacheBlock

The atom: one record per physical KV block.

from dataclasses import dataclass
from typing import NewType

# Stand-in for vLLM's hash-key type (assumed bytes here)
BlockHashWithGroupId = NewType("BlockHashWithGroupId", bytes)

@dataclass(slots=True)
class KVCacheBlock:
    block_id: int                                  # 0 .. num_gpu_blocks-1 (null_block = 0)
    ref_cnt: int = 0                              # >0 means "someone is using me, don't evict"
    _block_hash: BlockHashWithGroupId | None = None  # hash key when block is full and cached

    # Doubly-linked-list pointers (manipulated only by FreeKVCacheBlockQueue)
    prev_free_block: "KVCacheBlock | None" = None
    next_free_block: "KVCacheBlock | None" = None

    is_null: bool = False                         # placeholder block; never cached, never freed

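Lifecycle of a block:

  1. Born at pool init: ref_cnt=0, no hash, sitting in the free queue.
  2. Allocated via get_new_blocks(): popped from the queue head, ref_cnt becomes 1.
  3. Cached once full: cache_full_blocks() sets its hash and inserts it into the hash index.
  4. Released: free_blocks() decrements ref_cnt; at 0 the block rejoins the free-queue tail with its hash intact.
  5. Re-hit or evicted: a prefix hit touch()es it back out of the queue, or get_new_blocks() reuses it and drops the hash.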

2.2 FreeKVCacheBlockQueue

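A custom doubly-linked list threaded through each block's prev_free_block / next_free_block pointers.

Queue ordering: head = LRU (evicted first); tail = most recently released.

  • popleft() / popleft_n(): allocation pops the oldest candidates from the head.
  • append() / append_n(): released blocks join the tail.
  • remove(block): O(1) mid-queue removal, used when a prefix hit claims a free block.
  • num_free_blocks: plain counter of queued blocks. ⚠ See §8.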
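The operations above can be sketched with sentinel-bounded doubly-linked nodes. This is a simplified stand-in, not vLLM's class: `Block`, `FreeBlockQueue`, and the `ids()` helper are hypothetical, and only the method names and the head/tail policy mirror the description.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(eq=False)  # identity equality: field-by-field comparison would walk the links
class Block:
    block_id: int
    prev_free: Block | None = None
    next_free: Block | None = None


class FreeBlockQueue:
    """Doubly-linked LRU list: head = evicted first, tail = most recently released."""

    def __init__(self, blocks: list[Block]) -> None:
        # Sentinel head/tail nodes remove the edge-case checks for empty lists.
        self._head = Block(-1)
        self._tail = Block(-1)
        self._head.next_free = self._tail
        self._tail.prev_free = self._head
        self.num_free_blocks = 0  # plain counter, updated on every link/unlink
        for b in blocks:
            self.append(b)

    def append(self, block: Block) -> None:
        last = self._tail.prev_free
        assert last is not None
        last.next_free = block
        block.prev_free = last
        block.next_free = self._tail
        self._tail.prev_free = block
        self.num_free_blocks += 1

    def remove(self, block: Block) -> None:
        """O(1) unlink from anywhere in the list (e.g. on a prefix hit)."""
        assert block.prev_free is not None and block.next_free is not None
        block.prev_free.next_free = block.next_free
        block.next_free.prev_free = block.prev_free
        block.prev_free = block.next_free = None
        self.num_free_blocks -= 1

    def popleft(self) -> Block:
        first = self._head.next_free
        if first is None or first is self._tail:
            raise ValueError("no free blocks")
        self.remove(first)
        return first

    def popleft_n(self, n: int) -> list[Block]:
        if n > self.num_free_blocks:
            raise ValueError(f"cannot pop {n} blocks")
        return [self.popleft() for _ in range(n)]

    def ids(self) -> list[int]:
        """Queue contents head to tail (for inspection only)."""
        out, cur = [], self._head.next_free
        while cur is not None and cur is not self._tail:
            out.append(cur.block_id)
            cur = cur.next_free
        return out
```

Because every block carries its own pointers, `remove` never searches the list; that is the property `touch()` relies on.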

2.3 BlockHashToBlockMap

The prefix cache's hash index.

class BlockHashToBlockMap:
    # Stored on BlockPool as .cached_block_hash_to_block
    _cache: dict[BlockHashWithGroupId, KVCacheBlock | dict[int, KVCacheBlock]]

    def get_one_block(self, key) -> KVCacheBlock | None: ...  # lookup
    def insert(self, key, block): ...                       # cache_full_blocks
    def pop(self, key, block_id) -> KVCacheBlock | None: ... # _maybe_evict_cached_block
    def __len__(self) -> int: return len(self._cache)         # distinct cached hashes, incl. blocks still in use
Key observation: a block sits in both the free queue and the hash map when it has ref_cnt=0 but its hash is still set. A hashed block with ref_cnt>0 is in the hash map only.
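The `KVCacheBlock | dict[int, KVCacheBlock]` value type suggests a fast path: a hash normally maps straight to one block, and only when two blocks hold identical content (for example, two requests filled the same prefix before either was cached) does the entry widen to a dict keyed by block_id. A minimal sketch of that reading, with hypothetical class names; it is not vLLM's source.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(eq=False)
class Block:
    block_id: int


class HashToBlockMap:
    """hash -> block, or hash -> {block_id: block} when duplicates exist."""

    def __init__(self) -> None:
        self._cache: dict[bytes, Block | dict[int, Block]] = {}

    def get_one_block(self, key: bytes) -> Block | None:
        entry = self._cache.get(key)
        if entry is None:
            return None
        if isinstance(entry, Block):
            return entry
        return next(iter(entry.values()))  # any copy serves a prefix hit

    def insert(self, key: bytes, block: Block) -> None:
        entry = self._cache.get(key)
        if entry is None:
            self._cache[key] = block  # common case: store the block directly
        elif isinstance(entry, Block):
            self._cache[key] = {entry.block_id: entry, block.block_id: block}
        else:
            entry[block.block_id] = block

    def pop(self, key: bytes, block_id: int) -> Block | None:
        entry = self._cache.get(key)
        if entry is None:
            return None
        if isinstance(entry, Block):
            if entry.block_id != block_id:
                return None  # a different block owns this hash; leave it
            del self._cache[key]
            return entry
        block = entry.pop(block_id, None)
        if not entry:
            del self._cache[key]  # last copy gone: drop the key entirely
        return block

    def __len__(self) -> int:
        return len(self._cache)  # number of distinct hashes, not blocks
```

Note that `len()` counts hashes, so duplicates and in-use hashed blocks make it a poor stand-in for "evictable cached blocks".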

2.4 BlockPool (the container)

Defined in block_pool.py:129.

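from vllm.distributed.kv_events import KVCacheEvent
from vllm.v1.core.kv_cache_utils import FreeKVCacheBlockQueue, KVCacheBlock

class BlockPool:
    def __init__(self, num_gpu_blocks: int, enable_caching: bool, hash_block_size: int,
                 **kwargs):  # remaining constructor args elided in this excerpt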
        self.num_gpu_blocks = num_gpu_blocks
        self.blocks: list[KVCacheBlock] = [KVCacheBlock(i) for i in range(num_gpu_blocks)]
        self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)
        self.cached_block_hash_to_block = BlockHashToBlockMap()

        # null_block (block_id=0) is a sentinel; never cached, never freed
        self.null_block = self.free_block_queue.popleft()
        self.null_block.is_null = True
        self.enable_caching = enable_caching
        self.kv_event_queue: list[KVCacheEvent] = []  # for distributed KV event broadcast

2.5 Visualizing the three states of a block

A 32-block pool at a moment in time:

N A A A A A A A C C C C C C C C C C C C C C C C C E E E E E E E

  • N null (1): sentinel, block_id=0.
  • A active (7): ref_cnt>0, in use.
  • C cached (17): ref_cnt==0 AND has hash. Sits in BOTH the free queue and the hash index.
  • E empty (7): ref_cnt==0, no hash. Truly unused.

get_num_free_blocks() returns 24 (17 cached + 7 empty). Among free blocks, only the cached ones can satisfy a prefix hit; active blocks that already carry a hash can serve hits too.
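The legend's arithmetic can be checked with a small classifier. A minimal sketch using a hypothetical stand-in `Block` (not vLLM's KVCacheBlock): it rebuilds the 32-block snapshot and shows that the free count is cached + empty, with no way to tell the two apart from the count alone.

```python
from __future__ import annotations

from collections import Counter
from dataclasses import dataclass


@dataclass(eq=False)
class Block:
    block_id: int
    ref_cnt: int = 0
    block_hash: bytes | None = None
    is_null: bool = False


def classify(block: Block) -> str:
    if block.is_null:
        return "null"
    if block.ref_cnt > 0:
        return "active"  # may or may not carry a hash
    return "cached" if block.block_hash is not None else "empty"


def build_example_pool() -> list[Block]:
    """Rebuild the snapshot above: 1 null, 7 active, 17 cached, 7 empty."""
    pool = [Block(0, is_null=True)]
    pool += [Block(i, ref_cnt=1) for i in range(1, 8)]
    pool += [Block(i, block_hash=i.to_bytes(4, "little")) for i in range(8, 25)]
    pool += [Block(i) for i in range(25, 32)]
    return pool


pool = build_example_pool()
counts = Counter(classify(b) for b in pool)
# The free queue holds every non-null block with ref_cnt == 0.
num_free = sum(1 for b in pool if b.ref_cnt == 0 and not b.is_null)
```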

3. BlockPool’s Public API

Eight methods cover everything the layers above need; one of them, _maybe_evict_cached_block, is private and runs inside get_new_blocks().

Method | Purpose | State change
get_cached_block(hash, group_ids) | Prefix-cache lookup | Read-only
get_new_blocks(n) | Allocate n blocks | popleft_n; may evict cached hashes
touch(blocks) | Increment ref_cnt on prefix hit | Remove from free queue if ref_cnt=0; ref_cnt += 1
cache_full_blocks(...) | Mark full blocks cache-eligible | Set hash; insert into hash map
free_blocks(ordered_blocks) | Release on finish / preempt | ref_cnt -= 1; at 0, append to free queue; hash stays
_maybe_evict_cached_block(block) | Evict hash on reuse | Pop from hash map; reset hash
evict_blocks(block_ids) | Explicit eviction | Same as above, per block
get_num_free_blocks() | Statistics | Returns counter; ⚠ see §8

3.1 Annotated: get_new_blocks()

The primary allocation path.

def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
    if num_blocks > self.get_num_free_blocks():
        raise ValueError(f"Cannot get {num_blocks} free blocks")

    ret = self.free_block_queue.popleft_n(num_blocks)  # head of LRU → oldest

    for block in ret:
        if self.enable_caching:
            self._maybe_evict_cached_block(block)  # drop hash if it had one
        assert block.ref_cnt == 0
        block.ref_cnt += 1                          # take ownership (caching on or off)
    return ret

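Eviction happens inline at allocation time: there is no background evictor, and apart from explicit evict_blocks() calls a cached block loses its hash only when get_new_blocks() pops it for reuse.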

3.2 Annotated: touch()

def touch(self, blocks: Sequence[KVCacheBlock]) -> None:
    for block in blocks:
        # ref_cnt=0 means sitting in free queue as eviction candidate
        if block.ref_cnt == 0 and not block.is_null:
            self.free_block_queue.remove(block)  # O(1) mid-queue removal
        block.ref_cnt += 1

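This is where the doubly-linked list pays off: a prefix hit can claim a block from anywhere in the free queue, not just the head, so removal must be O(1) at arbitrary positions. A plain deque would make each such remove O(n).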
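Putting the pieces together, the lifecycle from §2.1 can be walked end to end on a toy pool. A minimal sketch with hypothetical names (`MiniBlockPool`, `cache_full_block`); it swaps the hand-rolled doubly-linked list for `collections.OrderedDict`, which gives the same O(1) pop-from-head, append-to-tail, and keyed removal, and it ignores KV cache groups, the null block, and the enable_caching switch.

```python
from __future__ import annotations

from collections import OrderedDict
from dataclasses import dataclass


@dataclass(eq=False)  # identity semantics, like a physical block
class Block:
    block_id: int
    ref_cnt: int = 0
    block_hash: bytes | None = None


class MiniBlockPool:
    """Toy pool: an OrderedDict as the LRU free queue, a dict as the hash index."""

    def __init__(self, num_blocks: int) -> None:
        self.blocks = [Block(i) for i in range(num_blocks)]
        # Insertion order is queue order: first key = head (evicted first).
        self.free: OrderedDict[int, Block] = OrderedDict((b.block_id, b) for b in self.blocks)
        self.hash_to_block: dict[bytes, Block] = {}

    def get_num_free_blocks(self) -> int:
        return len(self.free)  # empty AND cached-but-evictable blocks

    def get_cached_block(self, block_hash: bytes) -> Block | None:
        return self.hash_to_block.get(block_hash)

    def get_new_blocks(self, n: int) -> list[Block]:
        if n > self.get_num_free_blocks():
            raise ValueError(f"Cannot get {n} free blocks")
        ret = [self.free.popitem(last=False)[1] for _ in range(n)]  # oldest first
        for block in ret:
            if block.block_hash is not None:
                # Inline eviction: a stale hash is dropped only when its block is reused.
                if self.hash_to_block.get(block.block_hash) is block:
                    del self.hash_to_block[block.block_hash]
                block.block_hash = None
            assert block.ref_cnt == 0
            block.ref_cnt += 1
        return ret

    def cache_full_block(self, block: Block, block_hash: bytes) -> None:
        if block.block_hash is None:
            block.block_hash = block_hash
            self.hash_to_block.setdefault(block_hash, block)  # keep the first copy

    def touch(self, blocks: list[Block]) -> None:
        for block in blocks:
            if block.ref_cnt == 0:
                del self.free[block.block_id]  # O(1) removal from mid-queue
            block.ref_cnt += 1

    def free_blocks(self, ordered_blocks: list[Block]) -> None:
        for block in ordered_blocks:
            block.ref_cnt -= 1
            if block.ref_cnt == 0:
                self.free[block.block_id] = block  # tail = most recent; hash kept
```

Because released blocks join the tail in the order given, the order callers pass in `ordered_blocks` becomes the eviction order for those blocks.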

4. The Three Layers Above BlockPool

Each layer adds one concern.

4.1 SingleTypeKVCacheManager

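Manages KV blocks for one cache group: each request's block list, allocation of new blocks from the BlockPool, and hash assignment as blocks fill.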
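For a full-attention group, the core translation from tokens to blocks is ceiling division. A minimal sketch under that assumption: `num_new_blocks_needed` is a hypothetical helper, sliding-window and mamba groups follow different rules, and prefix-hit blocks and lookahead tokens are ignored.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division on ints, without going through floats."""
    return -(-a // b)


def num_new_blocks_needed(num_tokens: int, block_size: int, num_allocated: int) -> int:
    """Blocks to allocate so that `num_tokens` total tokens have KV slots."""
    return max(0, cdiv(num_tokens, block_size) - num_allocated)
```

With block_size=16, a 33-token prefill needs 3 blocks; decoding up to token 48 fits in those 3, and token 49 opens a fourth.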

4.2 KVCacheCoordinator

Coordinates multiple groups (full + sliding + mamba).

from __future__ import annotations  # annotations below name types defined elsewhere

class KVCacheCoordinator:
    managers: list[SingleTypeKVCacheManager]   # one per group
    block_pool: BlockPool                        # one pool shared by all groups

    def allocate_slots(self, request, num_new_tokens) -> list[list[KVCacheBlock]] | None:
        # Fans out across all groups; returns None if any group runs out
        ...

    def find_longest_cache_hit(self, hashes, max_len) -> tuple[list[list[KVCacheBlock]], int]:
        # Used by both GPU-side lookup and CPU-offload connector
        ...
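For the simplest case, one full-attention group, the lookup reduces to a left-to-right scan over the request's chained block hashes. A minimal sketch with hypothetical arguments: a plain `dict` mapping hash to block id stands in for the BlockPool's hash index. In hybrid models the real coordinator must also settle on a hit length that every group can serve, and sliding-window groups scan differently.

```python
from __future__ import annotations

from collections.abc import Hashable


def find_longest_cache_hit(
    block_hashes: list[Hashable],
    max_cache_hit_length: int,
    block_size: int,
    cached: dict[Hashable, int],
) -> tuple[list[int], int]:
    """Return (hit block ids, hit length in tokens) for one full-attention group."""
    max_blocks = max_cache_hit_length // block_size  # only whole blocks can hit
    hit: list[int] = []
    for h in block_hashes[:max_blocks]:
        block_id = cached.get(h)
        if block_id is None:
            # Stop at the first miss: a later block may still be cached, but its KV
            # is unusable without the missing block in front of it.
            break
        hit.append(block_id)
    return hit, len(hit) * block_size
```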

4.3 KVCacheManager

The Scheduler's only entry point.

  • get_computed_blocks(req): prefix lookup before prefill.
  • allocate_slots(req, num_tokens): allocate KV; None → OOM → preempt.
  • free(req): release the block chain on finish.
  • get_block_ids(req_id): per-group block-id list.
  • new_step_starts(): per-step bookkeeping reset.
  • block_pool: exposes the underlying BlockPool.

5. Important Files

All paths relative to vllm/.

File | ~LOC | What it defines | Called by
v1/core/kv_cache_utils.py | 1700+ | KVCacheBlock, FreeKVCacheBlockQueue, hash helpers, BlockHash types | block_pool.py, single_type manager
v1/core/block_pool.py | 510 | BlockPool, BlockHashToBlockMap | Coordinator, SingleType, CPU offload
v1/core/single_type_kv_cache_manager.py | 800+ | SingleTypeKVCacheManager (per-group block allocation, hash assignment) | Coordinator
v1/core/kv_cache_coordinator.py | 400+ | KVCacheCoordinator, factory get_kv_cache_coordinator(...) | KVCacheManager, CPU offload scheduler
v1/core/kv_cache_manager.py | 600+ | KVCacheManager (scheduler-facing API) | Scheduler
v1/core/kv_cache_metrics.py | ~150 | KVCacheMetricsCollector (residency histograms, allocation events) | BlockPool (optional injection)
v1/kv_cache_interface.py | ~300 | KVCacheConfig, KVCacheGroupSpec, group specs (full / sliding / mamba) | Engine init, factories
v1/simple_kv_offload/manager.py | 700+ | SimpleCPUOffloadScheduler (wires a 2nd BlockPool against pinned host memory) | SimpleCPUOffloadConnector
v1/simple_kv_offload/worker.py | ~600 | SimpleCPUOffloadWorker (CUDA streams, GPU↔CPU DMA) | SimpleCPUOffloadConnector
distributed/kv_transfer/kv_connector/v1/simple_cpu_offload_connector.py | ~250 | SimpleCPUOffloadConnector (thin wrapper over scheduler/worker halves) | Scheduler.connector
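Read order for newcomers: the first five rows top to bottom (kv_cache_utils.py up to kv_cache_manager.py), which climbs the four layers from §1 bottom up; then the offload files.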

6. GPU & CPU: The Same BlockPool, Twice

With CPU offload enabled, vLLM instantiates the same BlockPool class twice: once for GPU memory, once for pinned host memory.

6.1 Where each pool is constructed

GPU pool (always)

vllm/v1/core/kv_cache_manager.py
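  • Built in KVCacheManager.__init__ via the get_kv_cache_coordinator(...) factory
  • Block IDs index KV tensors allocated with torch.zeros(...) on CUDA
  • Num blocks derived from the --gpu-memory-utilization budget after memory profiling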
  • Used for active prefill / decode

CPU pool (optional)

vllm/v1/simple_kv_offload/manager.py:110
  • Built in SimpleCPUOffloadScheduler init
  • Backed by pinned host memory
  • Num blocks from cpu_bytes_to_use_per_rank
  • Holds evicted/deferred GPU blocks

6.2 The code that creates the CPU pool

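# vllm/v1/simple_kv_offload/manager.py:110
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_coordinator import KVCacheCoordinator, get_kv_cache_coordinator

class SimpleCPUOffloadScheduler:
    def __init__(self, vllm_config, kv_cache_config, cpu_capacity_bytes, lazy_offload=False):
        # (self.block_size and self.enable_kv_cache_events are assumed set earlier; elided)
        self.cpu_kv_cache_config = self._derive_cpu_config(kv_cache_config, cpu_capacity_bytes)
        self.num_cpu_blocks = self.cpu_kv_cache_config.num_blocks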

        # ← Same factory, different config
        self.cpu_coordinator: KVCacheCoordinator = get_kv_cache_coordinator(
            kv_cache_config=self.cpu_kv_cache_config,
            max_model_len=vllm_config.model_config.max_model_len,
            enable_caching=True,                        # prefix hash stays on
            enable_kv_cache_events=self.enable_kv_cache_events,
            hash_block_size=self.block_size,
        )
        self.cpu_block_pool: BlockPool = self.cpu_coordinator.block_pool

        # GPU side reference, bound by the connector after Scheduler builds kv_cache_manager
        self._gpu_block_pool: BlockPool | None = None

6.3 What differs, what doesn’t

Aspect | GPU BlockPool | CPU BlockPool
Python class | vllm.v1.core.block_pool.BlockPool | same
Free queue + hash index | yes | same
ref_cnt, touch, free, get_new_blocks | same API, same semantics | same
kv_cache_config | from GPU memory profiling | _derive_cpu_config(...) (derived)
Where KV tensors live | CUDA device memory | pinned host memory
Who allocates the tensors | GPU worker (GPUModelRunner.initialize_kv_cache) | SimpleCPUOffloadWorker via cudaHostAlloc
Who builds the pool | KVCacheManager on Scheduler start | SimpleCPUOffloadScheduler on connector init
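Practical implication: because it keeps its own hash index (enable_caching=True), the CPU pool answers prefix lookups on its own. A prefix that was evicted from GPU can still hit on CPU.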
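Because both pools keep their own hash index, one request's prefix can be served partly from GPU and partly from CPU. A minimal sketch of that split, assuming the step 3 / step 4 ordering in §7.1 (GPU lookup first, then the CPU pool is asked only for blocks past the GPU hit). `two_tier_lookup` is hypothetical, and plain dicts stand in for the two hash indexes; the real connector call (get_num_new_matched_tokens) also handles asynchronous loads.

```python
from __future__ import annotations


def longest_prefix(hashes: list[bytes], index: dict[bytes, int], start: int = 0) -> int:
    """Number of consecutive hits in `index`, scanning from block `start`."""
    n = start
    while n < len(hashes) and hashes[n] in index:
        n += 1
    return n - start


def two_tier_lookup(
    hashes: list[bytes], gpu_index: dict[bytes, int], cpu_index: dict[bytes, int], block_size: int
) -> tuple[int, int]:
    """Return (GPU-hit tokens, extra CPU-hit tokens) for one request."""
    gpu_blocks = longest_prefix(hashes, gpu_index)
    # The CPU pool only needs to cover what the GPU is missing, starting right after it.
    cpu_blocks = longest_prefix(hashes, cpu_index, start=gpu_blocks)
    return gpu_blocks * block_size, cpu_blocks * block_size
```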

7. The Scheduler’s Role, Step by Step

Every scheduling step follows a predictable KV-touch sequence.

7.1 Interaction points per step

  1. self.kv_cache_manager.new_step_starts(): reset per-step bookkeeping at the top of schedule().
  2. self.kv_cache_manager.block_pool.get_num_free_blocks(): read for tracing; the scheduler itself doesn't gate on it.
  3. self.kv_cache_manager.get_computed_blocks(request): prefix lookup on admission (waiting → running).
  4. self.connector.get_num_new_matched_tokens(request, num_computed): connector-layer prefix lookup on the CPU pool for tokens beyond the local hit.
  5. self.kv_cache_manager.allocate_slots(request, num_new_tokens, ...): allocate KV; None → preempt (for a running request) or stop admitting (for a waiting one).
  6. self.kv_cache_manager.free(preempted_req): return the victim's blocks to the free queue. Hashes are NOT evicted, so a resumed victim can re-hit its own prefix.
  7. self.kv_cache_manager.get_block_ids(request_id): produce block IDs for the worker and the offload connector.
  8. self.kv_cache_manager.free(finished_req): release a finished request's chain once the step's model output is processed (update_from_output, not schedule() itself); may be deferred by pin TTL.

7.2 Example: a prefill step with a CPU offload hit

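# Inside Scheduler.schedule(), simplified (not the literal source)

def schedule(self):
    self.kv_cache_manager.new_step_starts()                                    # 1
    scheduler_output = []                 # stand-in for the real SchedulerOutput

    if sched_trace.enabled():                                                  # 2 (tracing only)
        free_blocks = self.kv_cache_manager.block_pool.get_num_free_blocks()

    # Running requests: an allocation failure preempts from the tail of self.running.
    req_index = 0
    while req_index < len(self.running):
        request = self.running[req_index]
        new_blocks = self.kv_cache_manager.allocate_slots(request, ...)        # 5
        while new_blocks is None:
            victim = self.running.pop()   # lowest priority; real code re-queues it as waiting
            self.kv_cache_manager.free(victim)                                 # 6
            if victim is request:
                break
            new_blocks = self.kv_cache_manager.allocate_slots(request, ...)    # retry
        if new_blocks is None:
            break                         # the request had to preempt itself
        block_ids = self.kv_cache_manager.get_block_ids(request.request_id)   # 7
        scheduler_output.append((request, block_ids))
        req_index += 1

    # Waiting requests: local (GPU) lookup, then external (CPU) lookup, then allocate.
    for request in list(self.waiting):
        computed_blocks, num_local = self.kv_cache_manager.get_computed_blocks(request)  # 3

        num_external = 0
        if self.connector is not None:
            num_external, is_async = self.connector.get_num_new_matched_tokens(   # 4
                request, num_local
            )

        new_blocks = self.kv_cache_manager.allocate_slots(request, ...)        # 5
        if new_blocks is None:
            break                         # no preemption for admissions; keep waiting

        block_ids = self.kv_cache_manager.get_block_ids(request.request_id)   # 7
        scheduler_output.append((request, block_ids))

    return scheduler_output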

7.3 Tracing hooks

The agent-kvcache fork emits a step_snapshot event per schedule() call.

# vllm/v1/core/sched/scheduler.py:484 (simplified)
sched_trace.step_snapshot(
    step=self._sched_step_counter,
    running=self.running, waiting=self.waiting,
    free_blocks=self.kv_cache_manager.block_pool.get_num_free_blocks(),  # ⚠ see §8
    total_blocks=self.cache_config.num_gpu_blocks,
    cpu_free_blocks=cbp.get_num_free_blocks() if cbp is not None else None,  # cbp: CPU BlockPool, if offload is on
    cpu_total_blocks=getattr(sm, "num_cpu_blocks", None),
    num_pinned=len(self.pinned_requests),
    pinned_blocks=pinned_blocks,
    ...
)

8. Pitfall: the Semantics of get_num_free_blocks()

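It returns the free queue's length: every non-null block with ref_cnt == 0, whether it is empty or still holds cached content.

Concrete misreading (true story): a usage figure of 19 blocks, derived from the free count, covered only the ref_cnt>0 blocks. The actual cached content sits at ref_cnt=0 and is counted as "free".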

8.1 The right way to count CPU DRAM usage

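pool = cpu_block_pool
total = pool.num_gpu_blocks - 1                  # exclude the null sentinel
free_count = pool.get_num_free_blocks()          # ref_cnt == 0: empty + cached
# Blocks with usable cached content that are also evictable (ref_cnt == 0, hash set):
cpu_cached_blocks = sum(1 for b in pool.blocks
                        if b.ref_cnt == 0 and not b.is_null and b.block_hash is not None)
# (len(pool.cached_block_hash_to_block) counts distinct hashes, incl. in-use blocks.)

# Three distinct numbers the viewer should show (O(num_blocks) scan; tracing only):
empty_blocks  = free_count - cpu_cached_blocks   # truly empty
cached_blocks = cpu_cached_blocks                # has content, evictable
ref_blocks    = total - free_count               # ref_cnt > 0 (in-flight, active)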

The trace schema currently emits only free/total. Adding cpu_cached_blocks fixes this.

8.2 Why the API is designed this way

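The counter is correct for allocation but misleading for reporting. get_new_blocks() can hand out any ref_cnt == 0 block and evicts its hash inline, so counting cached blocks as "free" is exactly what allocation needs. For capacity reporting, though, those blocks still hold reusable KV. The word "free" is overloaded.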

9. One-Page Recap

Data structures

  • KVCacheBlock: one record per physical block
  • FreeKVCacheBlockQueue: O(1) doubly-linked LRU list
  • BlockHashToBlockMap: hash → block index
  • BlockPool: owns all three; exposes full API

APIs that matter

  • get_cached_block(hash, group_ids) — prefix lookup
  • get_new_blocks(n) — allocate, may evict
  • touch(blocks) — increment ref on prefix hit
  • free_blocks(blocks) — decrement ref
  • get_num_free_blocks(): ⚠ free-queue length

Files

  • v1/core/kv_cache_utils.py — atom types
  • v1/core/block_pool.py — the pool
  • v1/core/kv_cache_coordinator.py — group coordinator
  • v1/core/kv_cache_manager.py — scheduler-facing API
  • v1/simple_kv_offload/manager.py — CPU offload pool builder

Role in scheduler

  1. new_step_starts() — per-step reset
  2. get_computed_blocks() — GPU prefix lookup
  3. Connector — CPU prefix lookup (optional)
  4. allocate_slots() — allocate for new tokens
  5. free() — release on preempt / finish
  6. get_block_ids() — produce scheduler output

GPU vs CPU pool

Same class, different config. Both participate in prefix lookup.

Gotcha

get_num_free_blocks() counts cached-but-evictable blocks as 'free', so total minus free can under-count the DRAM holding cached KV by 400×.