vLLM V1 Internals (v0.19.x) - Source Code Deep Dive

The KV Connector & CPU Offload Backend

How vLLM V1 plugs in remote KV stores, disaggregated prefill hosts, and DRAM offload tiers through a single interface split between scheduler and worker -- plus a deep dive into the modern OffloadingConnector and its vllm/v1/kv_offload backend: the connector/spec layer split, a pluggable LRU/ARC eviction policy, reuse-frequency gating, deferred offload, and the explicit-DMA KV transfer path.

vllm/distributed/kv_transfer/kv_connector/v1/{base,offloading_connector}.py + offloading/{scheduler,worker}.py
vllm/v1/kv_offload/{abstract,spec,factory,reuse_manager}.py + cpu/{spec,manager,policies}.py + worker/{worker,cpu_gpu}.py

1. Why a Connector Exists

vLLM's V1 scheduler owns a single GPU block pool and treats it as the single source of truth for the KV cache. Real deployments need more: DRAM offload to hold working sets larger than GPU memory, remote stores that share prefixes across replicas, and prefill/decode disaggregation that hands KV from one machine to another. Rather than merging every backend into the scheduler, V1 exposes a narrow interface -- the KV Connector -- that lets any external KV source participate in the scheduling loop without the scheduler knowing how it works internally.

Key design choice: the scheduler never calls CPU offload or any specific backend directly. It calls self.connector.get_num_new_matched_tokens() and self.connector.update_state_after_alloc(). Whether that means pinned DRAM, RDMA to another GPU, S3, or Mooncake is entirely the connector's problem.

1.1 The Scheduler / Worker Split

Every connector has two roles. The scheduler-side role is invoked inside the scheduler process, runs synchronously during schedule(), and owns planning: answering "do you have this prefix?" and "here are the GPU blocks I reserved, please plan the transfer". The worker-side role lives in each model worker, executes in the forward-pass path, and owns actual I/O: launching async memcpys, driving network transports, polling completions.

SCHEDULER

Scheduler-side role

Co-located with the Python scheduler process. Pure planning. No tensors, no CUDA calls.

  • get_num_new_matched_tokens()
  • update_state_after_alloc()
  • build_connector_meta()
  • update_connector_output()
  • request_finished()
  • take_events()
WORKER

Worker-side role

Co-located with each model worker / GPU process. Owns the CUDA streams, network handles, pinned CPU tensors, and actual memcpys.

  • register_kv_caches()
  • bind_connector_metadata()
  • start_load_kv() / save_kv_layer()
  • get_finished()
  • handle_preemptions()
  • build_connector_worker_meta()

The factory (KVConnectorFactory.create_connector) is called twice -- once with role=SCHEDULER in the scheduler process, once with role=WORKER in every worker process. Both sides see the same connector class, but the constructor branches and only instantiates the half that is needed.

# vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
spec = OffloadingSpecFactory.create_spec(vllm_config, kv_cache_config)

if role == KVConnectorRole.SCHEDULER:
    self.connector_scheduler = OffloadingConnectorScheduler(spec)
elif role == KVConnectorRole.WORKER:
    self.connector_worker = OffloadingConnectorWorker(spec)

2. The KVConnectorBase_V1 Contract

All connectors subclass KVConnectorBase_V1 in vllm/distributed/kv_transfer/kv_connector/v1/base.py. The contract is small but load-bearing -- the scheduler's entire extensibility story hinges on a small set of scheduler-side and worker-side methods. The six scheduler-side and seven worker-side entries below are the ones this document relies on.

2.1 Scheduler-side Methods

get_num_new_matched_tokens(req, n) → (int|None, bool)

Look up how many additional prefix tokens the connector already has for this request, beyond the n already matched locally on GPU. Returns None to request "skip me this step, ask again later". Second element: is_async -- if True, this request will wait for a remote load to finish before running.

update_state_after_alloc(req, blocks, n_ext)

The scheduler just reserved blocks on GPU and n_ext of the tokens were claimed by you. Now the connector records the pairing: "remote block X must land in GPU block Y."

build_connector_meta(scheduler_output) → KVConnectorMetadata

End-of-step hook. The connector bundles everything the worker needs to know this step -- blocks to load, blocks to store, preemption flush flags -- into a single metadata dataclass attached to scheduler_output.

update_connector_output(output)

Worker returned completion events for prior steps. The scheduler-side manager accepts "these stores are done", "these loads arrived", and does the bookkeeping -- promoting CPU blocks into the prefix-cache hash table, releasing reference counts, etc.

request_finished(req, block_ids) → (bool, params?)

Called exactly once when a request terminates. Returning True tells the scheduler "don't free these GPU blocks yet, I'm still using them for an outbound transfer." The blocks will be released later by the worker's get_finished().

take_events() → Iterable[KVCacheEvent]

Optional. Surfaces KV cache events (BlockStored / BlockRemoved) for the observability subsystem, letting external dashboards see what a non-GPU tier is doing.
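
The call order of the scheduler-side methods can be sketched without any vLLM imports. The toy class below is a hypothetical stand-in: ToySchedulerConnector, ToyMetadata, the string block keys, and the simplified signatures are all illustrative and are not vLLM names. It only mirrors the sequence described above: hit lookup, post-allocation pairing, then end-of-step metadata.

```python
from dataclasses import dataclass, field


@dataclass
class ToyMetadata:
    # (external_block_key, gpu_block_id) pairs the worker should load this step
    loads: list = field(default_factory=list)


class ToySchedulerConnector:
    """Hypothetical scheduler-side connector backed by an in-memory key set."""

    def __init__(self, block_size, stored_keys):
        self.block_size = block_size
        self.stored_keys = set(stored_keys)   # blocks the external tier holds
        self._plan = {}                       # req_id -> (first_block, hit_keys)
        self._pending_loads = []              # filled by update_state_after_alloc

    def get_num_new_matched_tokens(self, req_id, block_keys, num_local_tokens):
        # Count the leading run of external hits beyond the local GPU hits.
        first = num_local_tokens // self.block_size
        hits = 0
        for key in block_keys[first:]:
            if key not in self.stored_keys:
                break
            hits += 1
        self._plan[req_id] = (first, block_keys[first:first + hits])
        # Second element: True means the load is asynchronous.
        return hits * self.block_size, hits > 0

    def update_state_after_alloc(self, req_id, gpu_block_ids, num_ext_tokens):
        first, keys = self._plan.pop(req_id)
        if num_ext_tokens != len(keys) * self.block_size:
            raise ValueError("scheduler and connector disagree on hit length")
        # Record "external block X must land in GPU block Y".
        for offset, key in enumerate(keys):
            self._pending_loads.append((key, gpu_block_ids[first + offset]))

    def build_connector_meta(self):
        meta = ToyMetadata(loads=self._pending_loads)
        self._pending_loads = []              # metadata is per-step
        return meta


conn = ToySchedulerConnector(block_size=16, stored_keys={"a", "b", "c"})
ext, is_async = conn.get_num_new_matched_tokens(
    "r1", ["a", "b", "c", "d"], num_local_tokens=16)
conn.update_state_after_alloc("r1", gpu_block_ids=[10, 11, 12, 13], num_ext_tokens=ext)
meta = conn.build_connector_meta()
```

In this run block "a" is already on GPU, "b" and "c" are external hits, and "d" is a miss, so the metadata pairs "b" and "c" with the second and third GPU blocks.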

2.2 Worker-side Methods

register_kv_caches(kv_caches)

Called once at startup with the per-layer GPU KV tensors. The worker inspects stride / layout, allocates a mirrored pinned CPU pool, pins it via cudaHostRegister, and creates low-priority CUDA streams.

bind_connector_metadata(meta)

At the start of every forward pass the worker receives the per-step KVConnectorMetadata that the scheduler produced. Pending load / store event indices are enqueued here.

start_load_kv() / wait_for_layer_load()

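Hooks for connectors that overlap per-layer KV loads with compute (NIXL, Mooncake). A block-granular DRAM backend does not need the per-layer wait: OffloadingConnector makes wait_for_layer_load a no-op and uses start_load_kv only to submit its transfers, with completion reported through get_finished() (§6).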

save_kv_layer(layer_name, kv) / wait_for_save()

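Symmetric hooks for store-side overlap. A DRAM offload backend can skip per-layer saves: OffloadingConnector makes save_kv_layer a no-op and uses wait_for_save only to queue store jobs, which are submitted at the start of the next step on a dedicated stream so they run in parallel with compute (§11).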

get_finished(req_ids) → (sent, recv)

The workhorse. Called after model execution. Returns which requests finished async sends / receives this step. OffloadingConnector uses this moment to poll CUDA events and report completed transfers.

handle_preemptions(meta)

Before preempted blocks can be reused by another request, every in-flight transfer touching them must complete. This method is a barrier.

build_connector_worker_meta() → KVConnectorWorkerMetadata

Reverse direction: the worker reports to the scheduler which event indices completed this step. The scheduler-side manager aggregates across TP/PP ranks to decide "this store is fully done."

SupportsHMA (mixin)

Hybrid Memory Allocator opt-in. Connectors declare they can cope with per-kv-group block lists via request_finished_all_groups(). The factory enforces that HMA-enabled deployments only load compatible connectors.

3. How the Scheduler Drives a Connector

The scheduler in vllm/v1/core/sched/scheduler.py and the connector meet at five points per step: four scheduler-side calls plus the worker's get_finished(). The line numbers below come from a local 0.19.0 checkout (Agentic_KVCache_management) and may shift in other releases.

Startup: connector built by factory
  KVConnectorFactory.create_connector
  Once per role at startup, the factory builds the connector. OffloadingConnector.__init__ creates the OffloadingSpec and instantiates only its scheduler-side or worker-side half. The worker half then calls register_kv_caches to canonicalize the GPU KV layout.

Per schedule() tick:

  1. [scheduler] get_num_new_matched_tokens (sched/scheduler.py:690-715)
     For every WAITING request, after kv_cache_manager.get_computed_blocks() finds local GPU hits, the scheduler asks the connector for extra external hits. load_kv_async flips the request into WAITING_FOR_REMOTE_KVS if the connector returns True.
     Then: GPU blocks reserved via kv_cache_manager.allocate_slots()

  2. [scheduler] update_state_after_alloc (sched/scheduler.py:843-857)
     Passes the freshly-allocated GPU blocks to the connector so it can record the transfer plan. Also ticks the prefix-cache-hit stat counter for observability.
     Then: end of schedule() tick

  3. [scheduler] build_connector_meta (sched/scheduler.py:1027)
     Attached to scheduler_output.kv_connector_metadata. This flows through IPC to every worker process.
     Then: IPC → worker forward pass

  4. [worker] worker.get_finished() (offloading/worker.py)
     Worker submits async CUDA copies on dedicated per-transfer streams, then non-blockingly polls event completion and reports finished sends / receives back to the scheduler.
     Then: ModelRunnerOutput.kv_connector_output → scheduler

  5. [scheduler] update_connector_output (sched/scheduler.py:2206-2218)
     Completion arrives back at the scheduler. Finished loads release GPU/CPU refs; finished stores promote CPU blocks into the prefix cache map.

Figure 1: The five scheduler ↔ connector touchpoints per step, tagged by the side that runs them.

Where does the worker learn the KV layout? At worker startup, register_kv_caches probes each attention backend's get_kv_cache_shape and sorts tensor strides to discover which physical dimension is num_blocks. This is what lets a single offload backend serve FlashAttention, Mamba, and cross-layer layouts without any layout knowledge baked into the scheduler -- the worker hands the spec layer a uniform CanonicalKVCaches view. Details in §12.
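
The stride-sorting idea can be illustrated in pure Python. This is a minimal sketch under stated assumptions, not the vLLM routine: the helper names, the example shape, and the assumption that num_blocks is the outermost physical dimension of this made-up layout are all hypothetical.

```python
def contiguous_strides(shape):
    """Row-major element strides for a shape, e.g. (2, 3, 4) -> (12, 4, 1)."""
    strides, acc = [], 1
    for size in reversed(shape):
        strides.append(acc)
        acc *= size
    return tuple(reversed(strides))


def physical_dim_order(strides):
    """Logical dimension indices, ordered from outermost (largest stride) to innermost."""
    return sorted(range(len(strides)), key=lambda d: strides[d], reverse=True)


# Hypothetical layout: memory is laid out as (num_blocks, 2, block_size, heads),
# but the backend exposes a logical view of (2, num_blocks, block_size, heads).
physical_shape = (1024, 2, 16, 8)
phys = contiguous_strides(physical_shape)

# Strides of the logical view: the first two dimensions are swapped.
logical_strides = (phys[1], phys[0], phys[2], phys[3])

order = physical_dim_order(logical_strides)
num_blocks_dim = order[0]   # logical index of the outermost physical dimension
```

Here the logical view lists the K/V dimension first, but sorting by stride shows that logical dimension 1 is physically outermost, which is where a block-granular copy would index.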
# vllm/v1/core/sched/scheduler.py  --  hit check, line 690
if self.connector is not None:
    ext_tokens, load_kv_async = (
        self.connector.get_num_new_matched_tokens(
            request, num_new_local_computed_tokens
        )
    )
    if ext_tokens is None:
        # Connector couldn't decide yet -- requeue and move on.
        request_queue.pop_request()
        step_skipped_waiting.prepend_request(request)
        continue
    request.num_external_computed_tokens = ext_tokens
    num_external_computed_tokens = ext_tokens

# ... allocate_slots() reserves GPU blocks ...

# vllm/v1/core/sched/scheduler.py  --  post-alloc callback, line 843
if self.connector is not None:
    self.connector.update_state_after_alloc(
        request,
        self.kv_cache_manager.get_blocks(request_id),
        num_external_computed_tokens,
    )

4. Connectors Registered in vLLM 0.19.0

The list below reflects the actual register_connector(...) calls at the bottom of vllm/distributed/kv_transfer/kv_connector/factory.py (lines 146-222). "HMA" means the connector implements the SupportsHMA mixin and works with the hybrid memory allocator.

ExampleConnector (example_connector.py)
    Purpose: Reference implementation for writing new connectors. Saves KV to local disk files.
    Backend: Disk
    Tags: demo

ExampleHiddenStatesConnector
    Purpose: Transfers hidden states between instances -- useful for prefill-decode split of embeddings, not KV.
    Backend: Network
    Tags: disagg

P2pNcclConnector (p2p/p2p_nccl_connector.py)
    Purpose: Direct GPU-to-GPU KV transfer over NCCL for prefill-decode disaggregation on a single node or tightly-coupled cluster.
    Backend: NCCL
    Tags: disagg, remote

LMCacheConnectorV1 (lmcache_connector.py)
    Purpose: Integration with the LMCache KV-caching server. Multi-level (GPU → CPU → disk), cross-engine prefix sharing.
    Backend: CPU + Disk + Remote
    Tags: dram, remote

LMCacheMPConnector (lmcache_mp_connector.py)
    Purpose: Multi-process variant of LMCache that talks to a sidecar via ZMQ to avoid blocking the scheduler.
    Backend: CPU + Sidecar
    Tags: dram, async

NixlConnector (nixl_connector.py)
    Purpose: Uses NVIDIA NIXL for zero-copy RDMA-style KV migration across machines. Supports HMA, cross-layer blocks, Mamba-aware routing.
    Backend: RDMA / NIXL
    Tags: remote, HMA, disagg

MultiConnector (multi_connector.py)
    Purpose: Meta-connector. Loads from the first sub-connector that reports a hit, saves to all of them. Used to stack DRAM + remote + P2P tiers.
    Backend: Composite
    Tags: wrapper

MoRIIOConnector (moriio/moriio_connector.py)
    Purpose: MoRI I/O transport -- experimental cross-node KV movement, similar scope to NIXL.
    Backend: MoRI I/O
    Tags: remote

OffloadingConnector (offloading_connector.py)  ← Part 2 focus
    Purpose: The modern DRAM offload path. A thin connector layer over a pluggable OffloadingSpec backend (vllm/v1/kv_offload/): LRU/ARC eviction policy, optional reuse-frequency gating, deferred offload. Extensible to new tiers via the spec factory.
    Backend: CPU DRAM
    Tags: dram, async

DecodeBenchConnector (decode_bench_connector.py)
    Purpose: Synthetic benchmarking tool. Emulates a prefill-decode split by filling GPU KV with dummy tensors so you can stress-test decode under large ISL without a real prefill cluster.
    Backend: Synthetic
    Tags: bench

MooncakeConnector (mooncake/mooncake_connector.py)
    Purpose: Integration with the Mooncake KV transfer engine for large-scale disaggregated deployments. Uses ZMQ + HTTPX for control plane plus Mooncake transports for data.
    Backend: Mooncake
    Tags: remote, disagg

FlexKVConnectorV1 (flexkv_connector.py)
    Purpose: FlexKV distributed multi-tier KV store -- CPU / SSD / remote memory. Exposes tiered cache hits through a single connector.
    Backend: CPU + SSD + Remote
    Tags: dram, remote

SimpleCPUOffloadConnector (simple_cpu_offload_connector.py)
    Purpose: Lighter single-file pinned-DRAM tier. Bakes its offload logic into one module and piggy-backs on the GPU BlockPool LRU for its cache map. Still registered, but superseded by OffloadingConnector for extensible work.
    Backend: Pinned DRAM
    Tags: dram

Registry is lazy: KVConnectorFactory.register_connector stores a loader closure; the module is only imported when the factory instantiates the class. This keeps vLLM from pulling in Mooncake, NIXL, FlexKV, or LMCache at import time for deployments that don't use them.
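
The lazy-registration pattern can be sketched with importlib alone. The class below is a hypothetical stand-in, not KVConnectorFactory: LazyRegistry and its method names are illustrative, and two standard-library classes play the role of connector classes so the example runs anywhere.

```python
import importlib


class LazyRegistry:
    """Hypothetical lazy registry: stores loader closures, imports on first use."""

    def __init__(self):
        self._loaders = {}

    def register(self, name, module_path, class_name):
        if name in self._loaders:
            raise ValueError(f"{name!r} is already registered")

        def loader():
            # The import happens here, not at registration time.
            module = importlib.import_module(module_path)
            return getattr(module, class_name)

        self._loaders[name] = loader

    def create(self, name, *args, **kwargs):
        if name not in self._loaders:
            raise ValueError(f"unknown connector {name!r}")
        cls = self._loaders[name]()
        return cls(*args, **kwargs)


registry = LazyRegistry()
# Standard-library classes stand in for connector classes.
registry.register("ordered", "collections", "OrderedDict")
registry.register("fraction", "fractions", "Fraction")

half = registry.create("fraction", 1, 2)
```

Registering a name costs one closure; the module behind it is only imported by the create call that actually needs it.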

5. The Modern CPU Offload Stack -- Two Layers

OffloadingConnector is the production DRAM-offload path in vLLM V1. Unlike the older single-file SimpleCPUOffloadConnector, it is split into two cleanly separated layers connected by one seam -- the OffloadingSpec. The connector layer knows the KVConnectorBase_V1 contract but nothing about pinned memory; the spec/backend layer knows pinned memory and CUDA but nothing about the scheduler. This is what makes a new tier (disk, remote DRAM, a smarter eviction policy) a drop-in rather than a fork.

CONNECTOR LAYER

kv_connector/v1/offloading_connector.py + offloading/

Implements KVConnectorBase_V1. Pure plumbing into the scheduler/worker loop. Holds no offload policy of its own.

  • OffloadingConnector — facade
  • OffloadingConnectorScheduler
  • OffloadingConnectorWorker
SPEC / BACKEND LAYER

vllm/v1/kv_offload/

The actual offload logic. Selected by OffloadingSpecFactory. Knows pinned DRAM, eviction policy, and CUDA copies.

  • CPUOffloadingSpec — the seam
  • CPUOffloadingManager + CachePolicy
  • CpuGpuOffloadingHandlers
The seam: OffloadingConnector.__init__ calls OffloadingSpecFactory.create_spec(vllm_config, kv_cache_config). The spec name comes from kv_connector_extra_config["spec_name"] and defaults to "CPUOffloadingSpec". The factory also accepts a spec_module_path for out-of-tree specs -- so a brand-new offload tier needs zero changes to the connector or the scheduler.

5.1 File Layout

# Connector layer -- the KVConnectorBase_V1 implementation
vllm/distributed/kv_transfer/kv_connector/v1/
    offloading_connector.py     # OffloadingConnector (thin facade, ~171 lines)
    offloading/
        scheduler.py            # OffloadingConnectorScheduler
        worker.py               # OffloadingConnectorWorker
        common.py               # OffloadingConnectorMetadata
        metrics.py              # Prometheus / stats

# Spec / backend layer -- the actual offload engine
vllm/v1/kv_offload/
    abstract.py                 # OffloadingManager ABC, LoadStoreSpec
    spec.py                     # OffloadingSpec ABC, CanonicalKVCaches
    factory.py                  # OffloadingSpecFactory
    mediums.py                  # GPULoadStoreSpec / CPULoadStoreSpec
    reuse_manager.py            # FilterReusedOffloadingManager (decorator)
    cpu/
        spec.py                 # CPUOffloadingSpec
        manager.py              # CPUOffloadingManager
        policies/
            abstract.py         # CachePolicy ABC, BlockStatus
            lru.py              # LRUCachePolicy
            arc.py              # ARCCachePolicy
    worker/
        worker.py               # OffloadingWorker, OffloadingHandler ABC
        cpu_gpu.py              # CpuGpuOffloadingHandlers + SingleDirectionOffloadingHandler

The rest of Part 2 walks this stack top-down: the facade (§6), the scheduler-side bookkeeping manager (§7), its pluggable eviction policy (§8), the reuse filter (§9), then the per-step load path (§10) and store path (§11), and finally the worker that turns plans into CUDA copies (§12).

6. OffloadingConnector -- the Thin Facade

offloading_connector.py is only ~171 lines and almost every method is a one-line delegation. The connector is built twice (once per role); the constructor instantiates exactly one half and forwards each base-class method to it.

# vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
class OffloadingConnector(KVConnectorBase_V1):
    @property
    def prefer_cross_layer_blocks(self) -> bool:
        return True          # ask vLLM for a single cross-layer KV tensor when possible

    def __init__(self, vllm_config, role, kv_cache_config=None):
        super().__init__(vllm_config, role, kv_cache_config)
        spec = OffloadingSpecFactory.create_spec(vllm_config, kv_cache_config)

        self.connector_scheduler = None
        self.connector_worker = None
        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler = OffloadingConnectorScheduler(spec)
        elif role == KVConnectorRole.WORKER:
            self.connector_worker = OffloadingConnectorWorker(spec)

The interesting wiring is which base-class hook maps to which inner method:

Base hook | Delegates to | Notes
get_num_new_matched_tokens | scheduler.get_num_new_matched_tokens | DRAM prefix hit lookup.
update_state_after_alloc | scheduler.update_state_after_alloc | Plans a CPU→GPU load.
build_connector_meta | scheduler.build_connector_meta | Packs loads + stores + flushes.
start_load_kv | worker.start_kv_transfers | Submits load jobs (+ deferred stores).
wait_for_save | worker.prepare_store_kv | Queues stores -- but does NOT submit yet (§11).
save_kv_layer / wait_for_layer_load | no-op | Per-layer overlap unused; transfers are block-granular.
get_finished | worker.get_finished | Polls CUDA events for done transfers.

prefer_cross_layer_blocks = True tells vLLM's KV cache allocator that this connector would rather receive one big (num_blocks, num_layers×page) tensor than one tensor per layer. That collapses an N-layer offload into a single contiguous DMA region -- see register_cross_layers_kv_cache in §12.

7. CPUOffloadingManager -- the Block Ledger

The scheduler-side brain is CPUOffloadingManager (vllm/v1/kv_offload/cpu/manager.py, ~208 lines). It runs entirely in the scheduler process, touches no tensors, and answers exactly one question: which offloaded blocks exist, and where. It owns the CPU block pool, ref-counting, and event emission -- but delegates the actual "which block to keep / evict" decision to a pluggable CachePolicy (§8).

7.1 BlockStatus -- the per-block record

# vllm/v1/kv_offload/cpu/policies/abstract.py
class BlockStatus(ctypes.Structure):
    _fields_ = [("ref_cnt", ctypes.c_int32), ("block_id", ctypes.c_int64)]

    # ref_cnt == -1  -> block allocated but data not yet written (NOT ready)
    # ref_cnt ==  0  -> data present, idle, eligible for eviction
    # ref_cnt >  0  -> N in-flight loads are reading this block; pinned
    @property
    def is_ready(self) -> bool:
        return self.ref_cnt >= 0

block_id is the index of the physical slot in the pinned-DRAM buffer. ref_cnt does triple duty: a sentinel for "store in progress", an eviction guard, and an in-flight-read counter. Using a ctypes.Structure keeps the record small and fixed-size -- the two fields hold 12 bytes of data, which native alignment pads to 16 bytes on typical 64-bit platforms -- and the manager may hold tens of thousands of these.
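
The record layout can be checked directly with ctypes. The structure below repeats the two fields from the excerpt; the variable names and the sample values are illustrative only, and the size printed depends on the platform's alignment rules.

```python
import ctypes


class BlockStatus(ctypes.Structure):
    # Same two fields as the excerpt above.
    _fields_ = [("ref_cnt", ctypes.c_int32), ("block_id", ctypes.c_int64)]

    @property
    def is_ready(self) -> bool:
        return self.ref_cnt >= 0


# Size of one record, including any alignment padding the platform adds
# between the 4-byte ref_cnt and the 8-byte block_id.
record_size = ctypes.sizeof(BlockStatus)
block_id_offset = BlockStatus.block_id.offset

pending = BlockStatus(ref_cnt=-1, block_id=7)   # store in progress
stored = BlockStatus(ref_cnt=0, block_id=7)     # data present, idle

print(record_size, block_id_offset, pending.is_ready, stored.is_ready)
```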

7.2 The OffloadingManager primitives

lookup(block_hashes) → int|None

Counts the leading run of block hashes that are present AND is_ready. Stops at the first miss. This is the DRAM prefix-hit length.

prepare_load(block_hashes) → LoadStoreSpec

ref_cnt += 1 on every block (pins them against eviction during the in-flight copy), returns a CPULoadStoreSpec of physical block IDs.

prepare_store(block_hashes) → PrepareStoreOutput|None

Filters out already-stored hashes, evicts via the policy if the pool is full, allocates fresh slots, inserts them as NOT-ready. Returns the slots to write + the list of evicted hashes. None if eviction cannot free enough.

touch(block_hashes)

Marks blocks as recently used -- feeds the policy's recency / frequency state. Called for hit blocks even when no transfer is needed (GPU prefix cache covered them).

complete_load / complete_store

complete_load does ref_cnt -= 1 (unpins). complete_store flips ref_cnt from -1 to 0 (the block becomes loadable) or, on failure, frees the slot.

take_events() → Iterable[OffloadingEvent]

Drains BlockStored / BlockRemoved events for the KV-event observability stream, so external dashboards see the DRAM tier filling and evicting.

The ref_cnt lifecycle: prepare_store inserts a block at -1 → the worker DMAs GPU→CPU → complete_store sets 0 (now is_ready, evictable). A later prepare_load bumps it to 1 so eviction can't pull the block out from under an in-flight copy; complete_load returns it to 0. Eviction only ever considers ref_cnt == 0 blocks.
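
The lifecycle can be traced with a toy ledger. This is a hypothetical sketch (ToyLedger is not a vLLM class and tracks only the counter, not slots or hashes); it exists to show the -1 → 0 → N → 0 sequence, here with two overlapping loads of the same block.

```python
class ToyLedger:
    """Hypothetical ref_cnt ledger mirroring the -1 / 0 / >0 states above."""

    def __init__(self):
        self.ref_cnt = {}   # block_hash -> ref_cnt

    def prepare_store(self, h):
        self.ref_cnt[h] = -1            # slot reserved, data not yet written

    def complete_store(self, h, success=True):
        if success:
            self.ref_cnt[h] = 0         # ready and evictable
        else:
            del self.ref_cnt[h]         # failed store frees the slot

    def prepare_load(self, h):
        if self.ref_cnt.get(h, -1) < 0:
            raise KeyError(f"block {h!r} is not ready")
        self.ref_cnt[h] += 1            # pin against eviction

    def complete_load(self, h):
        self.ref_cnt[h] -= 1            # unpin

    def evictable(self):
        return [h for h, c in self.ref_cnt.items() if c == 0]


ledger = ToyLedger()
trace = []
for step in (ledger.prepare_store, ledger.complete_store,
             ledger.prepare_load, ledger.prepare_load,
             ledger.complete_load, ledger.complete_load):
    step("h1")
    trace.append(ledger.ref_cnt["h1"])
```

The trace ends back at 0, the only state in which the block is a candidate for eviction.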

8. Pluggable Eviction -- LRU vs ARC

When the CPU pool is full and a new block needs a slot, something already stored has to go. Choosing what goes is the eviction policy. CPUOffloadingManager does not make that choice itself -- it delegates to a CachePolicy object, picked at construction from a two-entry registry by the eviction_policy config string.

# vllm/v1/kv_offload/cpu/manager.py
_CACHE_POLICIES = {"lru": LRUCachePolicy, "arc": ARCCachePolicy}

# vllm/v1/kv_offload/cpu/policies/abstract.py -- the interface both implement
class CachePolicy(ABC):                                    # signatures condensed
    def get(self, block_hash) -> BlockStatus | None: ...   # find a stored block
    def insert(self, block_hash, block): ...               # register a new block
    def remove(self, block_hash): ...                      # drop a block
    def touch(self, block_hashes): ...                     # mark as recently used
    def evict(self, n, protected) -> list | None: ...      # choose n victims, or None

The two policies differ in both the data structure they keep and the eviction decision they make. The next two subsections explain each one as a caching algorithm first, then show its actual vLLM source.

8.1 LRU -- evict the least recently used

LRU (Least Recently Used) is the simplest useful policy. The rule: when you must drop something, drop whatever has gone untouched the longest. The bet behind it is that recently used data is likely to be used again soon. It needs only one list ordered by last-use time -- use a block and it moves to the "most recent" end; evict, and you take from the "oldest" end.

LRU's well-known weakness is that it sees only recency. If one request streams a long run of blocks that will each be used exactly once -- a long, never-repeated context -- every one of those blocks is pushed to the recent end, shoving the genuinely hot, reusable blocks off the old end. After the scan, the cache is full of things that will never be touched again. This is cache pollution, and it is exactly what ARC sets out to fix.

# vllm/v1/kv_offload/cpu/policies/lru.py  (~46 lines, core shown)
class LRUCachePolicy(CachePolicy):
    def __init__(self, cache_capacity):
        self.blocks: OrderedDict[BlockHash, BlockStatus] = OrderedDict()

    def touch(self, block_hashes):
        for block_hash in reversed(list(block_hashes)):
            if block_hash in self.blocks:
                self.blocks.move_to_end(block_hash)   # -> most-recent end

    def evict(self, n, protected):
        candidates = []
        for block_hash, block in self.blocks.items():   # walk oldest -> newest
            if block.ref_cnt == 0 and block_hash not in protected:
                candidates.append((block_hash, block))
                if len(candidates) == n:
                    break
        if len(candidates) < n:
            return None                              # cannot satisfy -> no change
        for block_hash, _ in candidates:
            del self.blocks[block_hash]
        return candidates

vLLM's LRUCachePolicy is one OrderedDict -- an insertion-ordered map that doubles as the recency list. touch is just move_to_end; evict walks from the oldest end and collects the first n droppable blocks. The ordering logic is textbook LRU; the only additions are the ref_cnt / protected skip and the all-or-nothing return, both covered in §8.3.
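
The scan-pollution weakness is easy to reproduce with a plain LRU. The snippet below is a generic illustration, not vLLM code: run_lru and the key names are hypothetical, and the cache holds only keys.

```python
from collections import OrderedDict


def run_lru(capacity, accesses):
    """Replay accesses through a plain LRU; return surviving keys, oldest first."""
    cache = OrderedDict()
    for key in accesses:
        if key in cache:
            cache.move_to_end(key)          # refresh recency
        else:
            if len(cache) >= capacity:
                cache.popitem(last=False)   # evict the least recently used
            cache[key] = True
    return list(cache)


hot = ["hot0", "hot1"]
scan = [f"scan{i}" for i in range(4)]

# Hot blocks are reused three times, then a one-off scan streams through.
survivors = run_lru(capacity=4, accesses=hot * 3 + scan)
print(survivors)
```

Despite six prior hits on the hot keys, the four single-use scan keys are all that remain once the scan ends.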

8.2 ARC -- adaptive recency + frequency

ARC (Adaptive Replacement Cache) exists to fix LRU's blind spot. It rests on two ideas.

Idea 1 -- track recency and frequency separately. ARC splits blocks across two real lists. T1 holds blocks seen only once (recent but unproven); T2 holds blocks seen twice or more (proven to recur). A one-off scan only ever lands blocks in T1 -- it can never push into T2 -- so the genuinely hot blocks sitting in T2 are shielded from scan pollution.

Idea 2 -- learn from your own mistakes via ghost lists. ARC also keeps two ghost lists, B1 and B2. They store only the hashes of recently evicted blocks, not their data, so they cost almost nothing. B1 remembers what was just evicted from T1, B2 from T2. They act as an after-the-fact audit: if a block you just evicted from T1 (now sitting in B1) is requested again, that is a signal -- "I dropped a recency block too soon; I gave recency too little room." ARC responds by nudging an internal target, target_t1_size, larger. A hit in B2 nudges it the other way. That target governs how much capacity T1 and T2 each get, and it drifts with the workload on its own -- this is the "Adaptive" in the name.

# vllm/v1/kv_offload/cpu/policies/arc.py  (~156 lines, core shown / condensed)
class ARCCachePolicy(CachePolicy):
    def __init__(self, cache_capacity):
        self.target_t1_size = 0.0
        self.t1 = OrderedDict()   # seen once  -- recency
        self.t2 = OrderedDict()   # seen 2+    -- frequency
        self.b1 = OrderedDict()   # ghost: hashes recently evicted from T1
        self.b2 = OrderedDict()   # ghost: hashes recently evicted from T2

    def touch(self, block_hashes):
        for block_hash in reversed(list(block_hashes)):
            if block_hash in self.t1:        # proven recurring: promote T1 -> T2
                self.t2[block_hash] = self.t1.pop(block_hash)
            elif block_hash in self.t2:      # already frequent: refresh
                self.t2.move_to_end(block_hash)
            elif block_hash in self.b1:      # ghost hit: evicted from T1 too soon
                self.target_t1_size += max(1, len(self.b2) / len(self.b1))
            elif block_hash in self.b2:      # ghost hit: evicted from T2 too soon
                self.target_t1_size -= max(1, len(self.b1) / len(self.b2))

    def evict(self, n, protected):
        # For each victim: if T1 is over its adaptive target, take from T1,
        # otherwise from T2. Gather all n first; mutate only if all n exist.
        ...
        if from_t1: del self.t1[h]; self.b1[h] = None   # evicted hash -> ghost B1
        else:       del self.t2[h]; self.b2[h] = None   # evicted hash -> ghost B2

So touch does double duty: for a live block it promotes T1→T2 or refreshes T2; for a ghost-list hit it does no caching at all -- it only adjusts target_t1_size. And evict consults that target to decide which list to drain, then files each evicted hash into the matching ghost list so the next ghost hit can keep tuning. Recency and frequency are balanced by a number that the workload itself moves.
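
The target adjustment is plain arithmetic and can be worked through in isolation. The function below is a sketch, not the vLLM method: adapt_target and its parameters are hypothetical, and the clamp to [0, capacity] follows classic ARC; the condensed excerpt above does not show whether vLLM's version clamps.

```python
def adapt_target(target, len_b1, len_b2, hit_list, capacity):
    """Return the new T1 target after a ghost hit in 'b1' or 'b2'."""
    if hit_list == "b1":
        # Recency was under-provisioned: grow T1's share.
        delta = max(1, len_b2 / len_b1)
        return min(capacity, target + delta)   # clamp, as in classic ARC
    if hit_list == "b2":
        # Frequency was under-provisioned: shrink T1's share.
        delta = max(1, len_b1 / len_b2)
        return max(0, target - delta)
    raise ValueError("hit_list must be 'b1' or 'b2'")


# Ghost lists of size 2 (B1) and 6 (B2), starting from a target of 0.
after_b1 = adapt_target(0.0, len_b1=2, len_b2=6, hit_list="b1", capacity=100)
after_b2 = adapt_target(after_b1, len_b1=2, len_b2=6, hit_list="b2", capacity=100)
```

With B2 three times larger than B1, a B1 hit moves the target by 6/2 = 3, while a B2 hit moves it back by only max(1, 2/6) = 1: a hit in the smaller ghost list produces the larger step.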

8.3 Two rules vLLM adds on top

Both policies' evict obey two constraints that come from vLLM, not from the classic algorithms. First, a block is skipped if its ref_cnt != 0 (an in-flight load is reading it, or a store is still writing it) or if it is in the protected set (the current call is about to use it) -- you may never evict a block someone still needs. Second, evict(n) is atomic: it gathers all n candidates before touching anything, and returns None -- changing nothing -- if it cannot find n. That prevents a half-done eviction from corrupting the lists when the pool genuinely cannot free enough.

Research hook: That upstream ships both LRU and ARC is itself an admission that LRU is not optimal for KV reuse. A new policy -- SLO-aware, request-priority-aware, or workload-adaptive -- is a single class implementing the 5-method CachePolicy ABC plus one line in _CACHE_POLICIES. No connector, scheduler, or worker change.
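
To make the "single class" claim concrete, here is a minimal sketch of a third policy. Everything in it is hypothetical: Policy and Block are local stand-ins for the CachePolicy ABC and BlockStatus, and FIFOPolicy is an illustrative policy, not something vLLM ships. It honours the two vLLM rules from §8.3: skip pinned or protected blocks, and evict all-or-nothing.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class Block:                      # stand-in for BlockStatus
    block_id: int
    ref_cnt: int = 0


class Policy(ABC):                # stand-in mirroring the five methods in §8
    @abstractmethod
    def get(self, block_hash): ...
    @abstractmethod
    def insert(self, block_hash, block): ...
    @abstractmethod
    def remove(self, block_hash): ...
    @abstractmethod
    def touch(self, block_hashes): ...
    @abstractmethod
    def evict(self, n, protected): ...


class FIFOPolicy(Policy):
    """Hypothetical policy: evict in insertion order, ignoring touches."""

    def __init__(self):
        self.blocks = {}          # dicts preserve insertion order

    def get(self, block_hash):
        return self.blocks.get(block_hash)

    def insert(self, block_hash, block):
        self.blocks[block_hash] = block

    def remove(self, block_hash):
        del self.blocks[block_hash]

    def touch(self, block_hashes):
        pass                      # FIFO keeps no recency state

    def evict(self, n, protected):
        victims = [(h, b) for h, b in self.blocks.items()
                   if b.ref_cnt == 0 and h not in protected][:n]
        if len(victims) < n:
            return None           # all-or-nothing: leave state untouched
        for h, _ in victims:
            del self.blocks[h]
        return victims


policy = FIFOPolicy()
for i, h in enumerate(["a", "b", "c"]):
    policy.insert(h, Block(block_id=i))
policy.get("a").ref_cnt = 1                    # "a" is pinned by an in-flight load
evicted = policy.evict(1, protected={"b"})     # skips pinned "a" and protected "b"
too_many = policy.evict(5, protected=set())    # cannot be satisfied
```

The second call asks for more victims than exist, so it returns None and leaves "a" and "b" in place.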

9. FilterReusedOffloadingManager -- Reuse Gating

Offloading a block costs a GPU→CPU PCIe copy and a CPU slot. If that block is never read back, the cost was pure waste. FilterReusedOffloadingManager (reuse_manager.py, ~120 lines) is an optional decorator around any OffloadingManager that refuses to store a block until it has been seen enough times.

Intercept lookup() -- count sightings

Every block hash passed to lookup is recorded in an internal LRU OrderedDict[BlockHash, int] counter, capped at max_tracker_size (default 64,000). Then the call is delegated to the backing manager unchanged.

Intercept prepare_store() -- filter below threshold

Before delegating, it keeps only hashes whose count has reached store_threshold. Blocks below threshold are dropped before the backing manager runs, so they never consume CPU offload capacity.

It is wired in by CPUOffloadingSpec.get_manager(): when extra_config["store_threshold"] >= 2, the plain CPUOffloadingManager is wrapped in the filter; otherwise the filter is skipped entirely (threshold 0 or 1 = no gating). All other primitives -- prepare_load, touch, complete_*, take_events -- pass straight through.
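
The two interception points can be sketched as a standalone counter. This is an assumption-laden illustration, not the FilterReusedOffloadingManager source: ReuseFilter and its method names are hypothetical, and it omits the delegation to a backing manager.

```python
from collections import OrderedDict


class ReuseFilter:
    """Hypothetical sketch of reuse gating; not the vLLM class."""

    def __init__(self, store_threshold=2, max_tracker_size=64_000):
        self.store_threshold = store_threshold
        self.max_tracker_size = max_tracker_size
        self.counts = OrderedDict()      # block_hash -> sightings, LRU-ordered

    def record_lookup(self, block_hashes):
        for h in block_hashes:
            self.counts[h] = self.counts.get(h, 0) + 1
            self.counts.move_to_end(h)
            if len(self.counts) > self.max_tracker_size:
                self.counts.popitem(last=False)   # forget the stalest hash

    def filter_for_store(self, block_hashes):
        # Keep only hashes that have been seen often enough to be worth storing.
        return [h for h in block_hashes
                if self.counts.get(h, 0) >= self.store_threshold]


gate = ReuseFilter(store_threshold=2, max_tracker_size=3)
gate.record_lookup(["p0", "p1"])                  # first sighting of both
first_pass = gate.filter_for_store(["p0", "p1"])
gate.record_lookup(["p0", "x", "y"])              # p0 seen twice; cap pushes p1 out
second_pass = gate.filter_for_store(["p0", "p1", "x"])
```

Only "p0" clears the threshold on the second pass; "p1" has been forgotten by the capped tracker, so its count restarts from zero.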

10. The Load Path -- CPU → GPU

A load is planned across two scheduler-side calls in offloading/scheduler.py. First a fact about block sizes: the offloaded block can be larger than a GPU block. offloaded_block_size = gpu_block_size × block_size_factor, so one CPU block mirrors block_size_factor contiguous GPU blocks. All hashing arithmetic in this section steps in units of the larger offloaded block.

10.1 get_num_new_matched_tokens -- the DRAM hit lookup

# offloading/scheduler.py  --  get_num_new_matched_tokens (condensed)
block_hashes = self._get_block_hashes(request)
self.manager.touch(block_hashes)                  # feed recency/frequency even on a miss

start_block_idx = num_computed_tokens // self.offloaded_block_size
hits = self.manager.lookup(
    self._get_block_hashes(request, start_idx=start_block_idx))

if hits is None:        return None, False    # "ask me again later"
if hits == 0:           return 0, False
# delay if any hit block is already being loaded by another request
if any(h in self._blocks_being_loaded for h in block_hashes):
    return None, False
return num_hit_tokens, True            # True == load happens asynchronously

The second return value is always True on a real hit -- the load is asynchronous, so the request enters WAITING_FOR_REMOTE_KVS and only runs once the copy lands. The _blocks_being_loaded set (populated only when GPU prefix caching is on) prevents two requests from racing to load the same block twice.

10.2 update_state_after_alloc -- pairing CPU blocks to GPU slots

Once the scheduler reserves GPU blocks, this call builds the transfer plan. manager.prepare_load(block_hashes) pins the CPU source blocks and returns a CPULoadStoreSpec (physical CPU block IDs). The GPU destination is a GPULoadStoreSpec built from the freshly-allocated block IDs. The (src, dst) pair -- a TransferSpec -- is parked in _reqs_to_load until build_connector_meta ships it.

Because an offloaded block is block_size_factor GPU blocks wide, GPULoadStoreSpec also carries group_sizes and block_indices -- the first GPU block of a load may be misaligned to the offloaded boundary (some sub-blocks were already on GPU), and block_indices tells the worker how many leading sub-blocks to skip.
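
The block-size arithmetic from this section can be worked through with concrete numbers. The helper below is a hypothetical sketch: offload_geometry and the returned names are illustrative, and the "leading sub-blocks" value is only the arithmetic behind the misalignment described above, not the exact contents of block_indices.

```python
def offload_geometry(gpu_block_size, block_size_factor, num_computed_tokens):
    """Return (offloaded_block_size, start_block_idx, leading_gpu_blocks)."""
    offloaded_block_size = gpu_block_size * block_size_factor
    # First offloaded block that is not fully covered by computed tokens.
    start_block_idx = num_computed_tokens // offloaded_block_size
    # Tokens already computed inside that partially covered offloaded block.
    tokens_into_block = num_computed_tokens % offloaded_block_size
    # Whole GPU sub-blocks of that offloaded block which are already on GPU.
    leading_gpu_blocks = tokens_into_block // gpu_block_size
    return offloaded_block_size, start_block_idx, leading_gpu_blocks


# 16-token GPU blocks, 4 GPU blocks per offloaded block, 224 tokens on GPU.
geom = offload_geometry(gpu_block_size=16, block_size_factor=4,
                        num_computed_tokens=224)
print(geom)
```

With these numbers an offloaded block spans 64 tokens, the lookup starts at offloaded block 3, and two of that block's four GPU sub-blocks are already resident.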

11. The Store Path & Deferred Offload

Stores are planned by _get_reqs_to_store, called once per step inside build_connector_meta. It walks both new and cached requests via yield_req_data(scheduler_output) and offloads only blocks that have newly become full.

11.1 The per-request cursor

# offloading/scheduler.py  --  _get_reqs_to_store (condensed)
num_blocks       = total_tokens // self.offloaded_block_size
start_block_idx  = self._next_stored_block_idx.get(req_id, 0)   # cursor
num_new_blocks   = num_blocks - start_block_idx
if num_new_blocks <= 0:
    continue                                # nothing new to offload

new_block_hashes = self._get_block_hashes(req, start_block_idx, num_blocks)
store_output     = self.manager.prepare_store(new_block_hashes)
if store_output is None:                     # CPU pool full, eviction failed
    continue
self._next_stored_block_idx[req_id] = num_blocks   # advance cursor

_next_stored_block_idx is a monotonic cursor per request. Each step offloads only the blocks between the cursor and the current full-block count, then advances. A block is never offloaded twice, and the scan is O(new blocks) not O(request length).
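A toy model of the cursor makes the "never twice" property easy to see. Only the name _next_stored_block_idx comes from the source; the StoreCursor class, its plan method, and the pool_has_room flag (standing in for prepare_store returning None) are hypothetical.

```python
class StoreCursor:
    """Toy per-request store cursor (illustrative, not vLLM code)."""

    def __init__(self, offloaded_block_size: int) -> None:
        self.offloaded_block_size = offloaded_block_size
        self._next_stored_block_idx: dict[str, int] = {}

    def plan(self, req_id: str, total_tokens: int,
             pool_has_room: bool = True) -> list[int]:
        """Return the block indices to offload this step (possibly empty)."""
        num_blocks = total_tokens // self.offloaded_block_size
        start = self._next_stored_block_idx.get(req_id, 0)
        if num_blocks - start <= 0:
            return []                      # no newly completed block
        if not pool_has_room:
            return []                      # store refused: cursor stays put
        self._next_stored_block_idx[req_id] = num_blocks
        return list(range(start, num_blocks))


cursor = StoreCursor(offloaded_block_size=64)
step1 = cursor.plan("req-0", total_tokens=130)                       # blocks 0 and 1
step2 = cursor.plan("req-0", total_tokens=150)                       # nothing new
step3 = cursor.plan("req-0", total_tokens=200, pool_has_room=False)  # refused
step4 = cursor.plan("req-0", total_tokens=200)                       # block 2, retried
```

Because the cursor advances only after a successful plan, a block that could not be stored in one step is picked up again on the next.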

11.2 Why stores are deferred one step

The worker does NOT submit a store the moment it is planned. prepare_store_kv (called from wait_for_save) only appends the job to _unsubmitted_store_jobs. The actual worker.transfer_async call happens at the start of the next step, inside start_kv_transfers / handle_preemptions. The source comment is explicit:

"defer the store to the beginning of the next engine step, so that offloading starts AFTER transfers related to token sampling, thereby avoiding delays to token generation due to offloading."

A store moves data that has already been computed; a sampling transfer is on the critical path of the next token. Deferring the store by one step keeps the offload DMA from contending with latency-sensitive work.

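The one-step deferral described in 11.2 can be modeled as a simple queue. This is a sketch under stated assumptions: only the name _unsubmitted_store_jobs comes from the source, while the DeferredStoreWorker class, its methods, and the submitted list (standing in for transfer_async calls) are hypothetical.

```python
class DeferredStoreWorker:
    """Toy model of deferred store submission (illustrative, not vLLM code)."""

    def __init__(self) -> None:
        self._unsubmitted_store_jobs: list[str] = []
        self.submitted: list[str] = []     # stands in for transfer_async calls

    def prepare_store(self, job: str) -> None:
        # End of step N: only queue the job, do not start any copy.
        self._unsubmitted_store_jobs.append(job)

    def start_step(self) -> None:
        # Start of step N+1: flush everything queued during step N.
        self.submitted.extend(self._unsubmitted_store_jobs)
        self._unsubmitted_store_jobs.clear()


worker = DeferredStoreWorker()
worker.prepare_store("store req-0 blocks 0..1")
submitted_during_step_n = list(worker.submitted)     # still empty
worker.start_step()
submitted_at_step_n_plus_1 = list(worker.submitted)  # the job is now issued
```

The queue costs one step of extra latency for the offload, which is acceptable because nothing on the token-generation path waits for a store.
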
11.3 Completion accounting

When a transfer's CUDA event fires, worker.get_finished reports the request in finished_sending (for a store) or finished_recving (for a load). Back on the scheduler, update_connector_output calls manager.complete_store (the block becomes loadable) or manager.complete_load (the block is unpinned). A request that finishes generation while a store is still in flight is held in _finished_reqs_waiting_for_store -- its GPU blocks are not freed until the outbound copy lands, exactly the contract request_finished() promises in §2.

12. The Worker -- Canonical Tensors, Streams, DMA

OffloadingConnectorWorker turns the declarative TransferSpecs into CUDA work. It owns an OffloadingWorker that holds a registry of OffloadingHandlers keyed by (src_medium, dst_medium) -- ("GPU","CPU") and ("CPU","GPU") each map to one direction handler.

12.1 register_kv_caches -- canonicalizing layouts

The hardest worker code is layout-wrangling. Different attention backends store KV with different physical layouts -- FlashAttention uses (2, num_blocks, ...), others put num_blocks first, Mamba is different again. register_kv_caches probes each backend's get_kv_cache_shape, sorts the strides to find which physical dimension is num_blocks, and rebuilds every layer as a canonical (num_blocks, page_size_bytes) int8 tensor. FlashAttention's leading-2 layout is split with unbind into two tensors (K and V). The result is a CanonicalKVCaches -- a uniform block-major view regardless of backend.
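The stride-sorting idea can be shown without any tensors at all. The sketch below is a simplified assumption about the approach, not register_kv_caches itself: the helper names, the example shapes, and the 2-byte element size are all hypothetical.

```python
from math import prod


def contiguous_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Element strides of a row-major (contiguous) tensor of this shape."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def canonical_layout(shape, strides, num_blocks_dim: int, itemsize: int):
    """Return (num_tensors, (num_blocks, page_size_bytes)).

    Dimensions are ordered by stride, largest first, to recover the physical
    order. Anything physically outside num_blocks becomes separate tensors;
    everything inside it is flattened into one page of bytes.
    """
    order = sorted(range(len(shape)), key=lambda d: strides[d], reverse=True)
    pos = order.index(num_blocks_dim)
    num_tensors = prod(shape[d] for d in order[:pos])
    page_size_bytes = prod(shape[d] for d in order[pos + 1:]) * itemsize
    return num_tensors, (shape[num_blocks_dim], page_size_bytes)


# Assumed example shapes: 100 blocks, 16 tokens, 8 heads, head_dim 128, fp16.
kv_first = (2, 100, 16, 8, 128)       # leading-2 layout, num_blocks at dim 1
blocks_first = (100, 2, 16, 8, 128)   # num_blocks at dim 0
kv_first_result = canonical_layout(kv_first, contiguous_strides(kv_first), 1, 2)
blocks_first_result = canonical_layout(blocks_first,
                                       contiguous_strides(blocks_first), 0, 2)
```

In the leading-2 case the result is two tensors with half-sized pages (K and V separately); in the blocks-first case it is one tensor whose page holds both.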

12.2 transfer_async -- streams and event chaining

Each transfer gets its own CUDA stream from a pool and is wrapped in start/end torch.Events. To preserve submission order without serializing onto the compute stream, each new transfer's stream calls wait_event on the previous transfer's end event. GPU→CPU transfers also call wait_stream on the current compute stream, so an offload only starts after the model has written the KV.

# worker/cpu_gpu.py  --  the actual copy, per KV tensor
with torch.cuda.stream(stream):
    start_event.record(stream)
    for src_tensor, dst_tensor, block_size_in_bytes in ...:
        ops.swap_blocks(src_tensor, dst_tensor,
                        block_size_in_bytes, src_to_dst_tensor)
    end_event.record(stream)

# csrc/cache_kernels.cu  --  swap_blocks lowers to a per-block memcpy loop
for (size_t i = 0; i < num_blocks; i++)
    cudaMemcpyAsync(dst_ptr + dst_off, src_ptr + src_off,
                    block_size_in_bytes, memcpy_type, stream);

Does the KV cache pass through VRAM? Yes. The destination of a CPU→GPU load is the live KV cache allocation in HBM: swap_blocks issues a per-block cudaMemcpyAsync(HostToDevice), so the bytes physically land in VRAM and the attention kernels read them from HBM. The CPU buffer is pinned (for true async DMA) but NOT memory-mapped; this is an explicit copy, not UVA zero-copy. vLLM's UVAOffloader (for model weights) does map host memory and lets the SMs read it over PCIe; the KV path deliberately does not. The reason is reuse: a KV block is re-read by attention many times per decode step, so one PCIe copy plus many ~3 TB/s HBM reads beats N ~64 GB/s PCIe reads. Weights are streamed once per layer, so zero-copy wins there. vLLM picks the mechanism per data-reuse profile.
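The reuse argument reduces to a break-even calculation. The sketch below uses the approximate bandwidths quoted above (~64 GB/s PCIe, ~3 TB/s HBM) and a deliberately idealized model that ignores launch latency, caching, and overlap; the function names are hypothetical.

```python
PCIE_BYTES_PER_S = 64e9   # ~64 GB/s, figure quoted in the text
HBM_BYTES_PER_S = 3e12    # ~3 TB/s, figure quoted in the text


def explicit_copy_seconds(nbytes: int, reads: int) -> float:
    """One PCIe copy into HBM, then every read is served from HBM."""
    return nbytes / PCIE_BYTES_PER_S + reads * nbytes / HBM_BYTES_PER_S


def zero_copy_seconds(nbytes: int, reads: int) -> float:
    """Every read crosses PCIe; nothing is staged in HBM."""
    return reads * nbytes / PCIE_BYTES_PER_S


block = 1 << 20  # 1 MiB block, an arbitrary size for illustration

# Read once (the weight-streaming profile): zero-copy is cheaper.
read_once = (explicit_copy_seconds(block, 1), zero_copy_seconds(block, 1))
# Read twice or more (the KV profile): the explicit copy is already cheaper.
read_twice = (explicit_copy_seconds(block, 2), zero_copy_seconds(block, 2))
```

Under this model the crossover sits just above one read, so any data that is read more than once favors the explicit copy.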

12.3 get_finished -- polling completion

get_finished drains worker.get_finished(), which non-blockingly queries each pending transfer's end event. A finished load returns its request in finished_recving; a store decrements that request's pending-job count and only reports finished_sending once the count hits zero. Transfer size and elapsed time are recorded into OffloadingConnectorStats for Prometheus.
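The pending-job count for stores can be sketched as a small counter. This is an illustrative model only; the StoreJobTracker class and its method names are hypothetical and do not mirror the worker's actual data structures.

```python
from collections import Counter


class StoreJobTracker:
    """Toy per-request pending-store counter (illustrative, not vLLM code)."""

    def __init__(self) -> None:
        self._pending_jobs: Counter[str] = Counter()

    def job_submitted(self, req_id: str) -> None:
        self._pending_jobs[req_id] += 1

    def job_completed(self, req_id: str) -> bool:
        """Return True only when the request's last store job has landed."""
        self._pending_jobs[req_id] -= 1
        if self._pending_jobs[req_id] == 0:
            del self._pending_jobs[req_id]
            return True
        return False


tracker = StoreJobTracker()
tracker.job_submitted("req-0")
tracker.job_submitted("req-0")                 # two stores in flight
first_done = tracker.job_completed("req-0")    # one still pending
second_done = tracker.job_completed("req-0")   # now finished_sending may fire
```

Reporting only on the final completion means the scheduler never frees a finished request's GPU blocks while another of its stores is still copying.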

13. Configuration & File Map

13.1 Configuration surface

Everything is driven through kv_transfer_config.kv_connector_extra_config:

Key | Default | Effect
spec_name | "CPUOffloadingSpec" | Which OffloadingSpec the factory builds.
spec_module_path | - | Import path for an out-of-tree spec.
cpu_bytes_to_use | required | DRAM budget; divided by block bytes to size the CPU pool.
block_size | = GPU block | Offloaded block size; sets block_size_factor (must be a multiple).
eviction_policy | "lru" | lru or arc -- selects the CachePolicy.
store_threshold | 0 | >=2 enables the reuse filter (§9); 0/1 = no gating.
max_tracker_size | 64000 | LRU counter cap for the reuse filter.
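A configuration using the keys above might look like the sketch below. The extra-config keys come from the table; the outer kv_connector and kv_role fields, the 16-token GPU block, and the 8 MiB per-block footprint are assumptions for illustration, so check them against your vLLM version and model.

```python
# Hypothetical example values; only the extra-config keys are from the table.
kv_transfer_config = {
    "kv_connector": "OffloadingConnector",
    "kv_role": "kv_both",
    "kv_connector_extra_config": {
        "cpu_bytes_to_use": 64 * 1024**3,   # 64 GiB DRAM budget (required)
        "block_size": 64,                   # offloaded block size in tokens
        "eviction_policy": "arc",           # "lru" is the default
        "store_threshold": 2,               # >= 2 enables the reuse filter
    },
}

extra = kv_transfer_config["kv_connector_extra_config"]

GPU_BLOCK_SIZE = 16                        # assumed GPU block size in tokens
BYTES_PER_OFFLOADED_BLOCK = 8 * 1024**2    # assumed; depends on the model

# block_size must be a multiple of the GPU block size.
assert extra["block_size"] % GPU_BLOCK_SIZE == 0
block_size_factor = extra["block_size"] // GPU_BLOCK_SIZE

# The DRAM budget divided by the bytes per block sizes the CPU pool.
num_cpu_blocks = extra["cpu_bytes_to_use"] // BYTES_PER_OFFLOADED_BLOCK
```

With these assumed numbers the pool holds 8192 offloaded blocks, each mirroring four GPU blocks.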

13.2 Complete file map

File | Layer | Key class | Responsibility
offloading_connector.py | conn | OffloadingConnector | Facade; delegates to scheduler/worker halves.
offloading/scheduler.py | conn | OffloadingConnectorScheduler | Load/store planning, cursors, completion accounting.
offloading/worker.py | conn | OffloadingConnectorWorker | Canonicalizes KV layouts, drives the OffloadingWorker.
kv_offload/factory.py | spec | OffloadingSpecFactory | Lazy registry of OffloadingSpec implementations.
kv_offload/cpu/spec.py | spec | CPUOffloadingSpec | Builds the manager + handlers; the layer seam.
kv_offload/cpu/manager.py | spec | CPUOffloadingManager | CPU block pool, ref-counting, events.
kv_offload/cpu/policies/ | spec | LRUCachePolicy / ARCCachePolicy | Pluggable eviction; the research extension point.
kv_offload/reuse_manager.py | spec | FilterReusedOffloadingManager | Optional reuse-frequency admission filter.
kv_offload/worker/cpu_gpu.py | spec | CpuGpuOffloadingHandlers | Per-direction handlers; streams, events, swap_blocks.

13.3 OffloadingConnector vs the older SimpleCPUOffloadConnector

SimpleCPUOffloadConnector is still registered in the factory, but it is the lighter single-file design: it piggy-backs on the GPU BlockPool LRU, bakes in its own offload logic, and is not extensible. OffloadingConnector is the modern path -- a connector/spec split, a pluggable eviction policy, an optional reuse filter, and a factory seam for entirely new tiers. For new work, build on OffloadingConnector and vllm/v1/kv_offload/.

Bottom line: The modern offload stack is two layers around one seam. The connector layer is pure plumbing into KVConnectorBase_V1; the vllm/v1/kv_offload/ backend owns the real decisions -- and exposes three independent extension points: a new OffloadingSpec (a whole new tier), a new CachePolicy (smarter eviction), and the reuse filter. KV data still transits VRAM by explicit DMA, because KV blocks are reused too heavily for zero-copy to pay off.