Exp: Execution Note

Hands-on verification of the vLLM-Continuum KV cache pinning mechanism
— Functional Test & Metrics-based Verification —

RTX 5090 32GB · Llama 3.1 8B · 2026-04-02 · Pin/TTL Verified

1. Environment

Item | Value
Machine | ssh hclin@100.107.101.84
GPU | 1× NVIDIA RTX 5090 (32 GB VRAM, ~30 GB available)
Conda env | contextserve
vLLM source | /home/hclin/vllm-continuum/vllm/ (editable install)
Base vLLM version | v0.7.x (fork, ~v0.7.2)
Model | /home/hclin/shared_models/Meta-Llama-3.1-8B-Instruct (15 GB, local)
KV cache blocks | 5,402 blocks × block_size 16 tokens
Scheduling policy | --scheduling-policy continuum
Note: The remote repo at /home/hclin/vllm-continuum/ has been extended with ContextServe (three-way dispatch: FAST/SLOW/UNCERTAIN). The Continuum branch is intact as an elif path — it is NOT affected when --scheduling-policy continuum is used.

2. Server Launch

# Activate environment
source ~/anaconda3/etc/profile.d/conda.sh && conda activate contextserve

# Start vLLM with Continuum scheduling
RUN_OUTPUT_DIR=./e2e_output \
vllm serve /home/hclin/shared_models/Meta-Llama-3.1-8B-Instruct \
  --scheduling-policy continuum \
  --port 8199 \
  --max-model-len 4096 \
  --gpu-memory-utilization 0.85

Startup takes ~70 seconds (weight loading + torch.compile). Health check via:

curl http://localhost:8199/health
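
Because startup takes over a minute, a test script may want to block until the health endpoint answers. The following is a minimal sketch of such a wait loop, not part of the original test scripts; the function names, the 180-second limit, and the 2-second polling interval are illustrative assumptions.

```python
import time
import urllib.error
import urllib.request


def probe_health(url, timeout=2.0):
    """Return True if GET `url` answers with HTTP 200."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.status == 200
    except (urllib.error.URLError, OSError):
        return False


def wait_until_ready(url, max_wait_s=180.0, interval_s=2.0,
                     probe=probe_health, sleep=time.sleep,
                     clock=time.monotonic):
    """Poll `probe(url)` until it succeeds or `max_wait_s` elapses."""
    deadline = clock() + max_wait_s
    while clock() < deadline:
        if probe(url):
            return True
        sleep(interval_s)
    return False


# Typical use before sending the first request:
# wait_until_ready("http://localhost:8199/health")
```

The probe, sleep, and clock are injectable so the loop can be exercised without a running server.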

3. Test Execution

3.1 Test Design

A Python script sends multi-turn agentic requests via the OpenAI client. Each turn provides job_id and is_last_step through the extra_body parameter, which vLLM's protocol layer unpacks into SamplingParams.extra_args.

# Key: how Continuum metadata is passed
resp = client.chat.completions.create(
    model=MODEL,
    messages=[
        {"role": "system", "content": "Respond with ONLY a ```bash``` block."},
        {"role": "user",   "content": "List files with ls."},
    ],
    max_tokens=150,
    extra_body={
        "job_id":       "job_alpha",   # ← Identifies the agent job
        "is_last_step": False,          # ← Tells scheduler to pin after completion
    },
)
Pitfall: Using requests.post() directly with extra_body as a nested dict does NOT work — the server ignores it. You must use the OpenAI Python client, which merges extra_body keys into the top-level payload so vLLM's Pydantic model can parse job_id.
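
To make the pitfall concrete, the sketch below builds the JSON body the way the note describes the OpenAI client doing it: the extra_body keys are merged into the top level instead of being nested. The helper name build_chat_payload is hypothetical, and whether a hand-built body like this is accepted by this particular fork is an assumption based on the description above, not something the note tested.

```python
import json


def build_chat_payload(model, messages, max_tokens, extra_body=None):
    """Build a chat-completions JSON body with extra keys at the top level."""
    payload = {"model": model, "messages": messages, "max_tokens": max_tokens}
    # Merge rather than nest: job_id / is_last_step must sit next to "model",
    # not under an "extra_body" key.
    payload.update(extra_body or {})
    return payload


payload = build_chat_payload(
    model="/home/hclin/shared_models/Meta-Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "List files with ls."}],
    max_tokens=150,
    extra_body={"job_id": "job_alpha", "is_last_step": False},
)
body = json.dumps(payload)  # what would go on the wire
print(sorted(payload))
```

Nesting the same dict under an "extra_body" key would produce the payload that the server ignores.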

3.2 Prompts & LLM Output

2 sequential jobs × 5 turns each. Each turn carries the full conversation history (system prompt + all previous user/assistant/tool messages), simulating a real agent. Simulated tool outputs are injected between turns.

Turn | User Prompt | LLM Output | Prompt Tokens | Output Tokens | E2E Latency
1 | List files in the project. | ```bash ls ``` | 66 | 7 | 319 ms
2 | Read main.py. | ```bash cat main.py ``` | 131 | 9 | 115 ms
3 | Search for TODO comments. | ```bash grep -r 'TODO' . ``` | 279 | 12 | 154 ms
4 | Check git status. | ```bash git status ``` | 397 | 8 | 106 ms
5 | Run pytest. | ```bash pytest ``` | 503 | 7 | 98 ms
Context grows: Prompt tokens increase with each turn (66 → 131 → 279 → 397 → 503) because the full conversation history accumulates — including simulated tool outputs (e.g. file listings, source code, grep results, git status). This is realistic — a real agent's KV cache grows as the dialogue progresses.
Model matters: Llama 3.1 8B outputs clean bash blocks (10/10 parsed). Qwen3 8B uses <think>...</think> mode, consuming all tokens before producing the bash block — only 2/10 parsed. Continuum's ToolCallParser regex is designed for non-thinking models.

3.3 Server Log — Tool Detection & Pin Chain

Filtered server log showing the complete Continuum pipeline for job_alpha:

03:24:11   Request job id arriving: job_alpha — Turn 1, first time seen
03:24:11   Request job id finishing: job_alpha
03:24:11   Extracted func_call: ls → set_up_pin → pin 2.0s
03:24:12   Request job id arriving: job_alpha — last func call: ls
03:24:12   Request job id finishing: job_alpha
03:24:12   Extracted func_call: cat → pin 2.0s
03:24:13   Request job id arriving: job_alpha — last func call: cat
03:24:13   Extracted func_call: grep → pin 2.0s
03:24:14   Request arriving — last func call: grep
03:24:14   Extracted func_call: git → pin 2.0s
03:24:15   Request arriving — last func call: git   (Turn 5, is_last_step=True)
03:24:15   Extracted func_call: pytest → is_last_step=True → no pin, free KV cache

All 5 turns correctly parsed. Turns 1–4 triggered set_up_pin → return 2.0. Turn 5 (is_last_step=True) skipped pin and freed blocks.
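
The "Extracted func_call" step in the log can be pictured with the sketch below. It is not Continuum's ToolCallParser; the pattern and the function name extract_func_call are hypothetical, chosen only to show why an output that never reaches its bash block (as with a thinking model that runs out of tokens) yields no function call.

```python
import re

FENCE = "`" * 3  # three backticks, built indirectly to keep this listing tidy

# Hypothetical pattern: the first bash-fenced block in the model output.
BASH_BLOCK = re.compile(
    re.escape(FENCE) + r"bash\s*(.*?)" + re.escape(FENCE), re.DOTALL
)


def extract_func_call(text):
    """Return the first word of the first bash block, or None if absent."""
    match = BASH_BLOCK.search(text)
    if match is None:
        return None
    words = match.group(1).split()
    return words[0] if words else None


print(extract_func_call(FENCE + "bash\ngrep -r 'TODO' .\n" + FENCE))  # grep
print(extract_func_call("<think>still reasoning about the task"))      # None
```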

4. Pin Verification via /metrics

A second test script polls vllm:kv_cache_usage_perc from the Prometheus /metrics endpoint before and after each turn, and three times (t+0.3s, t+0.6s, t+0.9s) during each simulated tool-execution gap.
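
The polling script itself is not reproduced in this note. As a rough sketch of its core step, the function below pulls one gauge value out of Prometheus text-format output; parse_gauge is a hypothetical helper, and the sample text (including its label names) is made up for illustration.

```python
def parse_gauge(metrics_text, name="vllm:kv_cache_usage_perc"):
    """Return the first sample of gauge `name` in Prometheus text, or None."""
    for line in metrics_text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):      # skip HELP/TYPE comments
            continue
        metric = line.split("{", 1)[0].split(" ", 1)[0]
        if metric != name:
            continue
        # The value follows the label set if there is one, else the name.
        tail = line.rsplit("}", 1)[-1] if "}" in line else line[len(metric):]
        return float(tail.split()[0])
    return None


# Made-up sample in the Prometheus exposition format (labels are illustrative).
SAMPLE = """\
# HELP vllm:kv_cache_usage_perc KV-cache usage. 1 means 100 percent usage.
# TYPE vllm:kv_cache_usage_perc gauge
vllm:kv_cache_usage_perc{engine="0",model_name="llama"} 0.000876
"""
print(parse_gauge(SAMPLE))  # 0.000876
```

In a real run the text would come from an HTTP GET on http://localhost:8199/metrics.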

4.1 Raw Results (5 Turns, Accumulated Context)

# 5 turns with accumulated conversation history + simulated tool outputs
# Turn 1-4: is_last_step=False (should pin)
# Turn 5: is_last_step=True (should free)

[Baseline] KV cache usage: 0.000000

Turn 1/5  last=False  319ms  prompt=66  output=7
  Output: ```bash ls ```
  KV usage: 0.000000 → 0.000876  (delta: +0.000876)
  Simulating tool execution — KV should stay pinned...
    t+0.3s: KV usage = 0.000876
    t+0.6s: KV usage = 0.000876
    t+0.9s: KV usage = 0.000876

Turn 2/5  last=False  115ms  prompt=131 output=9
  Output: ```bash cat main.py ```
  KV usage: 0.000876 → 0.001577  (delta: +0.000701)  ← context grew
  Simulating tool execution — KV should stay pinned...
    t+0.3s: KV usage = 0.001577
    t+0.6s: KV usage = 0.001577
    t+0.9s: KV usage = 0.001577

Turn 3/5  last=False  154ms  prompt=279 output=12
  Output: ```bash grep -r 'TODO' . ```
  KV usage: 0.001577 → 0.003330  (delta: +0.001753)  ← big jump (grep results injected)
  Simulating tool execution — KV should stay pinned...
    t+0.3s: KV usage = 0.003330
    t+0.6s: KV usage = 0.003330
    t+0.9s: KV usage = 0.003330

Turn 4/5  last=False  106ms  prompt=397 output=8
  Output: ```bash git status ```
  KV usage: 0.003330 → 0.004557  (delta: +0.001227)  ← context grew again
  Simulating tool execution — KV should stay pinned...
    t+0.3s: KV usage = 0.004557
    t+0.6s: KV usage = 0.004557
    t+0.9s: KV usage = 0.004557

Turn 5/5  last=True    98ms  prompt=503 output=7
  Output: ```bash pytest ```
  KV usage: 0.004557 → 0.000000  (delta: -0.004557)  ← last step, all freed!

[Final] KV cache usage: 0.000000

4.2 Token Accounting & Block Calculation

Each turn's input = system + all previous messages + chat template tokens. Tool output from Turn N is included in Turn N+1's input.

Turn | Content Tokens | Template | Input Total | Output | KV Tokens | ⌈KV/16⌉ | Measured | Tool Result
1 | 33 (sys 25 + task 8) | 33 | 66 | 7 | 73 | 5 | 5 | 33 (ls)
2 | 86 (+out₁ 7 + tool₁ 33 + task 13) | 45 | 131 | 9 | 140 | 9 | 9 | 115 (cat main.py)
3 | 222 (+out₂ 9 + tool₂ 115 + task 12) | 57 | 279 | 12 | 291 | 19 | 18 | 82 (grep TODO)
4 | 328 (+out₃ 12 + tool₃ 82 + task 12) | 69 | 397 | 8 | 405 | 26 | 25 | 78 (git status)
5 | 422 (+out₄ 8 + tool₄ 78 + task 8) | 81 | 503 | 7 | 510 | 32 | 0 (freed) | (last step)
How to read:
  • Content = cumulative raw text tokens (system + tasks + outputs + tool results)
  • Template = Llama chat template overhead (header/footer per message, +12 tok per 3 new messages)
  • Input Total = Content + Template (= vLLM's prompt_tokens)
  • Tool Result = tokens returned by tool execution, included in next turn's Content
  • KV Tokens = Input + Output (what the KV cache stores for this turn)
  • Measured = blocks from /metrics (±1 vs theoretical due to prefix cache dedup)
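
The ⌈KV/16⌉ column can be reproduced from the prompt and output token counts alone. The short sketch below does that arithmetic; the variable and function names are illustrative.

```python
import math

BLOCK_SIZE = 16
# (prompt_tokens, output_tokens) per turn, taken from the table above
TURNS = [(66, 7), (131, 9), (279, 12), (397, 8), (503, 7)]


def blocks_needed(prompt_tokens, output_tokens, block_size=BLOCK_SIZE):
    """Blocks required to hold the KV entries of prompt plus output."""
    return math.ceil((prompt_tokens + output_tokens) / block_size)


theoretical = [blocks_needed(p, o) for p, o in TURNS]
print(theoretical)  # [5, 9, 19, 26, 32]
```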

4.3 KV Cache Usage Over Time

Event | Usage | Blocks | Note (5402 blocks total)
Baseline | 0.000000 | 0 |
Turn 1 | 0.000876 | 5 |
  pin | 0.000876 | 5 | pinned
Turn 2 | 0.001577 | 9 | +4
  pin | 0.001577 | 9 | pinned
Turn 3 | 0.003330 | 18 | +9
  pin | 0.003330 | 18 | pinned
Turn 4 | 0.004557 | 25 | +7
  pin | 0.004557 | 25 | pinned
Turn 5 (last) | 0.000000 | 0 | freed

4.4 Interpretation

KV cache grows with context

Usage increases monotonically: 0.000876 → 0.001577 → 0.003330 → 0.004557 (5 → 9 → 18 → 25 blocks). The jump at Turn 3 (+9 blocks) reflects the large grep output injected into context.

Pin holds during tool execution

After each of Turns 1–4, kv_cache_manager.free() is skipped. Usage stays constant during every tool gap (polled at t+0.3s, t+0.6s, t+0.9s). No blocks are returned to free_block_queue.

Last-step release works

Turn 5 (is_last_step=True) immediately drops usage to 0.000000: the _free_blocks() code path bypasses the pin and calls kv_cache_manager.free().

4.5 System Functions for Verification

Method | Location | What it proves
/metrics vllm:kv_cache_usage_perc | Prometheus endpoint | Usage stays elevated during pin period
block_pool.get_num_free_blocks() | vllm/v1/core/block_pool.py:299 | Free count does not increase after pin
KVCacheBlock.ref_cnt | vllm/v1/core/kv_cache_utils.py:164 | Reference counter stays > 0 while pinned
kv_cache_manager.get_computed_blocks() | vllm/v1/core/kv_cache_manager.py:153 | Returns hit_length > 0 on next turn
scheduler_timestamps JSON | vllm/v1/core/estimate_with_func.py:27 | Records pinned_time, unpinned_time, hit_length per job
scheduler.pinned_requests | vllm/v1/core/sched/scheduler.py:139 | List of (Request, expiry_time) tuples

5. Bugs Found

KeyError: 'func_call' — vllm/v1/core/estimate_with_func.py:169

The crash occurs when multiple requests share the same job_id (including None) and the second request arrives before the first finishes. The code assumes history[-1] is a departure entry with a func_call key, but in that case it is an arrival entry.

# Crash site:
request.last_func_call = self.job_to_history[job_id][-1]["func_call"]  # KeyError!
# history[-1] is {"arrival_time": ...}, not {"departure_time": ..., "func_call": ...}

Trigger: Concurrent requests with same job_id, or any request with job_id=None.
Status: Not reported on GitHub. Issue #15 is a different bug (VllmModelConfig port argument).
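
One possible defensive lookup is sketched below. It is not an upstream fix: it assumes only the entry layout shown in the crash-site comment (arrival entries without func_call, departure entries with it), and the helper name last_func_call is hypothetical. Whether falling back to an older departure is the right semantics for concurrent requests is a separate design question.

```python
def last_func_call(history):
    """Return the most recent recorded func_call, or None if there is none."""
    for entry in reversed(history):
        if "func_call" in entry:      # only departure entries carry this key
            return entry["func_call"]
    return None


history = [
    {"arrival_time": 1.0},
    {"departure_time": 2.0, "func_call": "ls"},
    {"arrival_time": 2.1},            # concurrent request, same job_id
    {"arrival_time": 2.2},
]
print(last_func_call(history))                    # ls
print(last_func_call([{"arrival_time": 1.0}]))    # None
```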

6. vLLM Block Pool Design — Why Continuum's Pin Works

Understanding Continuum's pin mechanism requires understanding vLLM v1's underlying KV cache architecture. Here we document the key design decisions and community discussions.

6.1 Append-Only Block Table

vLLM v1's block table is append-only — once a block ID is written into a request's block table, it cannot be changed. New blocks can only be appended at the end. To understand why, consider this scenario:

Scenario: Two requests with the same system prompt

Step 1 — Request A arrives first

System prompt = "You are a helpful assistant." → needs 1 block

A: blk 3 allocated fresh, stores K/V for system prompt

Step 2 — Request B arrives with the same system prompt

Prefix cache hit: block hash matches → B also needs block 3's K/V data

(Block hash = hash of token IDs + parent block hash, so same tokens at same position = same hash)

Here's where v0 and v1 diverge:

v0: Mutable block table

B shares A's block directly. Then scheduler rewrites A's block table to point to a canonical copy:

A: blk 3→1 rewritten!
B: blk 1
Problem: GPU's FlashAttention kernel may be reading A's block table at this exact moment. CPU rewrites it mid-read → inconsistent data → wrong attention output.

v1: Append-only

B gets its own block. Both co-exist with the same K/V content:

A: blk 3 untouched
B: blk 9 same content, different block
GPU reads block tables safely — no entry is ever modified. Duplicate blocks waste a little memory but are cleaned up when requests finish.

After Request A finishes

Block 3 returns to free queue, but its hash stays in the prefix cache. Future requests with the same system prompt get a cache hit and reuse block 3 directly — no duplication needed.

Free queue: blk 3 hash cached blk 7 ···
Request C arrives: blk 3 ← prefix cache hit! reuse block 3, no re-computation

Why this matters for Continuum

Continuum pins blocks after a request finishes. Because block tables are append-only, the pinned blocks' IDs are guaranteed stable — no other process can rewrite them. When the next turn arrives, prefix cache finds the pinned blocks by hash and appends their IDs to the new block table.

Source: vLLM v1 Prefix Caching Design Doc
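
The chained block hash mentioned in the scenario (token IDs plus parent block hash) can be sketched as follows. SHA-256 and the byte encoding are stand-ins chosen for illustration; the hash function vLLM actually uses may differ.

```python
import hashlib

BLOCK_SIZE = 16


def block_hashes(token_ids, block_size=BLOCK_SIZE):
    """Hash each full block of tokens, chaining in the parent block's hash."""
    hashes = []
    parent = b""
    full = len(token_ids) - len(token_ids) % block_size  # ignore partial tail
    for start in range(0, full, block_size):
        chunk = tuple(token_ids[start:start + block_size])
        digest = hashlib.sha256(parent + repr(chunk).encode()).digest()
        hashes.append(digest)
        parent = digest
    return hashes


a = list(range(32))
b = list(range(16)) + list(range(100, 116))      # same first block as a
c = list(range(100, 116)) + list(range(16, 32))  # same second-block tokens as a
print(block_hashes(a)[0] == block_hashes(b)[0])  # True
print(block_hashes(a)[1] == block_hashes(c)[1])  # False
```

The second comparison is False even though the token IDs match, because the parent hashes differ: a block is only reusable when the whole prefix before it is identical.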

6.2 Null Block (block_id=0)

The BlockPool reserves one block as a permanent placeholder:

# vllm/v1/core/block_pool.py:66-70
# To represent a placeholder block with block_id=0.
# The ref_cnt of null_block is not maintained, needs special care to
# avoid freeing it.
self.null_block = self.free_block_queue.popleft()
self.null_block.is_null = True

Because block tables are append-only and fixed-length, positions that have no real KV data (e.g., prefix cache miss, or tokens outside the sliding window) must be filled with a placeholder rather than left empty:

# Example: sliding window attention, only last 3 blocks active
block_table = [null, null, block_8, block_3, block_12]
               ↑ outside window, filled with null_block

The null block is marked is_null=True so it is never returned to the free queue and never allocated to any request. This is also why get_usage() subtracts 1 from total blocks:

# vllm/v1/core/block_pool.py:314
total_gpu_blocks = self.num_gpu_blocks - 1  # Subtract 1 to account for null block
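
Put together, the usage figure is the allocated share of the usable (non-null) blocks. The function below is a simplified sketch of that calculation, not the body of get_usage() itself; its name and signature are illustrative.

```python
def kv_cache_usage(num_gpu_blocks, num_free_blocks):
    """Fraction of usable blocks currently allocated (null block excluded)."""
    usable = num_gpu_blocks - 1   # block 0 is the null block
    if usable <= 0:
        return 0.0
    return 1.0 - num_free_blocks / usable


print(kv_cache_usage(5402, 5401))  # 0.0
```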

6.3 Reference Counting & Free Queue

Each KVCacheBlock has a ref_cnt. This is the mechanism Continuum exploits for pinning:

Event | ref_cnt | In free queue? | Allocatable?
Allocated to request | +1 → 1 | No | No
Normal free (no pin) | -1 → 0 | Yes | Yes
Continuum pin (skip free) | stays 1 | No | No
Unpin (TTL expires) | -1 → 0 | Yes | Yes

Freed blocks are added to the tail of the free queue in reverse order — the last block of a request hashes the most tokens and is least likely to be reused, so it should be evicted first.
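
The reverse-order rule is easy to see on a toy queue. The sketch below uses collections.deque purely for illustration (the real free queue is a custom linked list); allocation and eviction take from the left.

```python
from collections import deque


def free_request_blocks(free_queue, request_blocks):
    """Append a finished request's blocks to the queue tail, last block first."""
    free_queue.extend(reversed(request_blocks))


free_queue = deque([7, 8])
free_request_blocks(free_queue, [1, 2, 3])  # request used blocks 1 -> 2 -> 3
print(list(free_queue))  # [7, 8, 3, 2, 1]
```

Block 3, which hashes the longest prefix, now sits ahead of blocks 2 and 1 and will be reallocated before them, so the more widely shareable prefix blocks survive longest.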

6.4 How Continuum Exploits This Design

Continuum's pin is remarkably simple — it doesn't add any new locking or memory management. It just skips calling kv_cache_manager.free() when a request finishes:

# scheduler.py — _free_blocks()
if policy == CONTINUUM and not request.is_last_step:
    length_of_pin = tool_call_estimator.set_up_pin(request)
    if length_of_pin > 0.01:
        self.pin_request(request, length_of_pin)
        return   # ← kv_cache_manager.free() is SKIPPED
                 # ← ref_cnt stays > 0
                 # ← blocks stay out of free queue
                 # ← GPU memory remains occupied

# Normal path (not pinned):
self.kv_cache_manager.free(request)  # ← ref_cnt decremented → blocks return to free queue

When the same job's next turn arrives, get_computed_blocks() finds the still-allocated blocks via prefix cache hash lookup → hit_length > 0 → skip re-prefill.
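
The pin and TTL expiry can be simulated end to end with a toy pool. This is a sketch under stated assumptions, not the scheduler's code: PinningPool and its methods are hypothetical, a heap stands in for the pinned_requests list, and the 2.0-second default mirrors the pin length seen in the server log.

```python
import heapq
import itertools


class PinningPool:
    """Toy block pool: finished requests may keep blocks until a TTL expires."""

    def __init__(self, num_blocks):
        self.free_blocks = list(range(1, num_blocks))  # block 0 = null block
        self.ref_cnt = {}
        self._pinned = []                # heap of (expiry, seq, blocks)
        self._seq = itertools.count()    # tie-breaker for equal expiries

    def allocate(self, n):
        blocks = [self.free_blocks.pop(0) for _ in range(n)]
        for b in blocks:
            self.ref_cnt[b] = 1
        return blocks

    def free(self, blocks):
        for b in reversed(blocks):
            self.ref_cnt[b] -= 1
            if self.ref_cnt[b] == 0:
                self.free_blocks.append(b)

    def finish(self, blocks, now, is_last_step, pin_seconds=2.0):
        if not is_last_step and pin_seconds > 0.01:
            entry = (now + pin_seconds, next(self._seq), blocks)
            heapq.heappush(self._pinned, entry)
            return                       # free() skipped: ref_cnt stays 1
        self.free(blocks)

    def unpin_expired(self, now):
        while self._pinned and self._pinned[0][0] <= now:
            _, _, blocks = heapq.heappop(self._pinned)
            self.free(blocks)


pool = PinningPool(11)                   # 10 usable blocks
blocks = pool.allocate(5)
pool.finish(blocks, now=0.0, is_last_step=False)
print(len(pool.free_blocks))             # 5  (still pinned)
pool.unpin_expired(now=2.0)
print(len(pool.free_blocks))             # 10 (TTL expired, blocks returned)
```

The sketch leaves out the reuse path: in the real system the next turn finds the pinned blocks through the prefix-cache hash lookup before the TTL fires.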

6.5 Community Discussions & References

Topic | Link | Key Insight
Prefix Caching Design (v1) | design/v1/prefix_caching.md | Official doc explaining append-only constraint, dedup deferral, and reverse-order eviction
RFC: Automatic Prefix Caching | #2614 | Original proposal for prefix-aware block management, discusses eviction policy and hash-based lookup
Understanding vLLM KV Cache | vLLM Forums | Community Q&A on block pool internals, ref counting, and allocation
Hybrid KV Cache Manager | #22688 | Design doc for GPU↔CPU tiered block pool with async transfer; relevant to Continuum's offload scenario
RFC: CPU Offload in V1 | #16144 | Discusses hiding transfer latency with pinned memory and MemcpyAsync, round-robin → LRU eviction
Prefix-Caching Aware Scheduling | #7883 | Discussion on scheduling that considers cache hit rate; directly related to Continuum's job-level FCFS priority
KV Cache Management (DeepWiki) | DeepWiki | Third-party deep dive into block pool, free queue, and hash-based caching architecture

7. How Pin Works at the Memory Level

7.1 No CUDA Memory API Involved

vLLM's KV cache pin does not call cudaMalloc, cudaFree, cudaMemcpy, or any CUDA memory management API. The entire mechanism is pure Python bookkeeping on top of a pre-allocated GPU buffer.

7.2 One-Time GPU Allocation at Startup

The only CUDA memory allocation happens once at server startup:

# vllm/v1/worker/gpu_model_runner.py:3367 — called once at init
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device="cuda")

This allocates 10.55 GiB as a single contiguous tensor on GPU. This memory is never freed until server shutdown. It is logically partitioned into 5,402 blocks:

GPU Memory (10.55 GiB, allocated once via torch.zeros):

blk 0 (null) | blk 1 | blk 2 | blk 3 | blk 4 | ··· | blk 5401

Each block = 2 MiB = 16 tokens × 32 layers × 8 KV heads × 128 dim × 2 bytes × 2 (K and V). "Allocate" and "free" only move pointers in Python, with zero GPU ops.

7.3 What's Inside a Block?

Each block stores 16 tokens' Key and Value tensors across all 32 layers. The GPU tensor shape per layer (FlashAttention backend):

# vllm/attention/backends/flash_attn.py:77
shape = (2, num_blocks, block_size, num_kv_heads, head_size)
        ↑      ↑          ↑          ↑            ↑
       K/V   5402        16          8           128

One block (block_id = N) in one layer:

K tensor   kv_cache[0, N, :, :, :]

[16, 8, 128] bf16 = 32 KiB

V tensor   kv_cache[1, N, :, :, :]

[16, 8, 128] bf16 = 32 KiB

Per layer: 32 + 32 = 64 KiB. Across 32 layers: 64 KiB × 32 = 2 MiB per block.

FlashAttention kernel uses block_id to index directly into this tensor — no copying, no indirection beyond the block table lookup.

7.4 ref_cnt — Who Owns a Block?

Each KVCacheBlock has a ref_cnt (reference count) that tracks how many requests are using it. This is needed because prefix caching allows multiple requests to share the same block — if two requests have the same system prompt, their first blocks contain identical K/V data, so vLLM lets them share instead of duplicating.

ref_cnt | State | Meaning
0 | Free (in free queue) | No one uses this block. Can be evicted or re-allocated.
1 | In use (1 request) | Exactly one request owns it. Normal case.
≥ 2 | Shared (prefix cache) | Multiple requests share this block (same prefix). Block is freed only when ALL requests release it.

Continuum's pin exploits this: by keeping ref_cnt = 1 (skipping free()), the block stays out of the free queue — exactly the same as if a request were still using it.

7.5 Why a Doubly-Linked List?

The FreeKVCacheBlockQueue is a doubly-linked list instead of Python's deque. The reason: prefix cache hit requires removing a block from the middle.

When a new request shares a prefix with a cached block, that block (which is free and evictable) must be pulled out of the free queue immediately:

# vllm/v1/core/block_pool.py:236-250 — touch() on prefix cache hit
def touch(self, blocks):
    for block in blocks:
        if block.ref_cnt == 0:                    # block is in free queue
            self.free_block_queue.remove(block)  # ← remove from MIDDLE! O(1)
        block.ref_cnt += 1

Operation | Python deque | Array / List | Doubly-linked list
popleft (allocate) | O(1) | O(n) | O(1)
append (free) | O(1) | O(1) | O(1)
remove from middle (cache hit) | O(n) | O(n) | O(1)

The remove() operation is just 2 pointer assignments:

# vllm/v1/core/kv_cache_utils.py:334-340
block.prev_free_block.next_free_block = block.next_free_block
block.next_free_block.prev_free_block = block.prev_free_block
block.prev_free_block = block.next_free_block = None
Performance trick: vLLM embeds prev_free_block / next_free_block directly on the KVCacheBlock dataclass (not separate Node objects), avoiding Python object allocation overhead. fake_head and fake_tail sentinel nodes eliminate null-check branching.
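
A stripped-down version of such a queue fits in a few dozen lines. The sketch below is not vLLM's FreeKVCacheBlockQueue: the class and attribute names are simplified, but it uses the same two ideas, links stored on the block object itself and sentinel head/tail nodes.

```python
class Block:
    def __init__(self, block_id):
        self.block_id = block_id
        self.ref_cnt = 0
        self.prev = None   # links live on the block, no separate node object
        self.next = None


class FreeQueue:
    """Doubly-linked free list with sentinels and O(1) middle removal."""

    def __init__(self, blocks):
        self.head, self.tail = Block(-1), Block(-1)   # sentinels
        self.head.next, self.tail.prev = self.tail, self.head
        self.size = 0
        for block in blocks:
            self.append(block)

    def append(self, block):
        last = self.tail.prev
        last.next, block.prev = block, last
        block.next, self.tail.prev = self.tail, block
        self.size += 1

    def remove(self, block):
        # Sentinels guarantee prev/next exist, so no None checks are needed.
        block.prev.next = block.next
        block.next.prev = block.prev
        block.prev = block.next = None
        self.size -= 1

    def popleft(self):
        if self.size == 0:
            raise IndexError("no free blocks")
        first = self.head.next
        self.remove(first)
        return first

    def ids(self):
        out, node = [], self.head.next
        while node is not self.tail:
            out.append(node.block_id)
            node = node.next
        return out


blocks = [Block(i) for i in range(1, 6)]
queue = FreeQueue(blocks)
queue.remove(blocks[2])          # prefix-cache hit on block 3
print(queue.ids())               # [1, 2, 4, 5]
```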

7.6 block_id → GPU VRAM Address Mapping

Each layer's KV cache is a contiguous torch.zeros tensor on GPU. block_id is simply an index into the second dimension — the GPU address is computed by a single multiplication, with no hash table or pointer chasing:

# Tensor shape per layer (FlashAttention)
kv_cache.shape = (2, 5402, 16, 8, 128)
                  ↑    ↑
                 K/V  block_id ← this dimension IS the block_id

# GPU address for block N:
base = kv_cache.data_ptr()                      # starting GPU address of the tensor
block_size_bytes = 16 × 8 × 128 × 2 = 32,768 bytes  # per block per K or V

key_block_N   = base + 0 × total_kv_size + N × 32,768
value_block_N = base + 1 × total_kv_size + N × 32,768

Example: block_table = [1, 2, 3, 4, 5] for one layer's K cache

null +0 | blk 1 +32K | blk 2 +64K | blk 3 +96K | blk 4 +128K | blk 5 +160K | free +192K | ···

Address = base + block_id × 32,768. One multiplication — no lookup table, no indirection.

The CUDA kernel does exactly this pointer arithmetic (csrc/cache_kernels.cu:306):

key_dst = key_cache + block_idx * block_stride + block_offset * page_stride;
//                    ^^^^^^^^^^^^^^^^^^^^^^^^
//                    block_idx × 32,768 = jump directly to that block
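
The same offsets can be checked in plain Python. The sketch below assumes the contiguous (2, num_blocks, block_size, num_kv_heads, head_size) bf16 layout described above; the constant and function names are illustrative.

```python
NUM_BLOCKS, BLOCK_SIZE, NUM_KV_HEADS, HEAD_SIZE, DTYPE_BYTES = 5402, 16, 8, 128, 2

BLOCK_BYTES = BLOCK_SIZE * NUM_KV_HEADS * HEAD_SIZE * DTYPE_BYTES  # 32,768
PLANE_BYTES = NUM_BLOCKS * BLOCK_BYTES       # size of the whole K (or V) plane


def block_byte_offset(block_id, is_value=False):
    """Byte offset of a block from the layer tensor's base address."""
    if not 0 <= block_id < NUM_BLOCKS:
        raise ValueError("block_id out of range")
    return (PLANE_BYTES if is_value else 0) + block_id * BLOCK_BYTES


print(block_byte_offset(5))        # 163840  (160 KiB into the K plane)
print(block_byte_offset(5, True))  # 177176576
```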

7.7 Block Read/Write: The CUDA Kernels

Block allocation/free is pure Python, but the actual K/V data read/write happens in two CUDA kernels:

Write — Store K/V into blocks

vllm/attention/backends/flash_attn.py:677
csrc/cache_kernels.cu:281

Every forward step for new tokens. Converts slot_mapping[token] → block_idx + offset, copies K/V into GPU tensor.

Read — Attention over cached K/V

vllm/attention/backends/flash_attn.py:770,805

FlashAttention receives block_table=[1,2,3,4,5], gathers K/V from scattered blocks → Q×K^T/√d → softmax → ×V.

Write kernel: slot → block address

// csrc/cache_kernels.cu:291-309 — reshape_and_cache_flash_kernel

slot_idx     = slot_mapping[token_idx];         // e.g. 35
block_idx    = slot_idx / block_size;           // 35 / 16 = block 2
block_offset = slot_idx % block_size;           // 35 % 16 = position 3

// Pointer arithmetic → destination in GPU tensor
key_dst   = key_cache   + block_idx * block_stride + block_offset * page_stride;
value_dst = value_cache + block_idx * block_stride + block_offset * page_stride;

// Vectorized copy (8 bf16 elements per instruction)
vectorize_copy(key_src   → key_dst);
vectorize_copy(value_src → value_dst);
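
The slot arithmetic is simple enough to mirror in Python. decode_slot below repeats the kernel's division and modulo; slot_for_position shows one way a slot index could be derived from a block table, which is an assumption about how slot_mapping is built rather than something taken from the source.

```python
BLOCK_SIZE = 16


def slot_for_position(block_table, pos, block_size=BLOCK_SIZE):
    """Physical slot index for the token at logical position `pos`."""
    return block_table[pos // block_size] * block_size + pos % block_size


def decode_slot(slot_idx, block_size=BLOCK_SIZE):
    """Step performed by the write kernel: slot -> (block_idx, block_offset)."""
    return divmod(slot_idx, block_size)


print(decode_slot(35))                              # (2, 3)
print(slot_for_position([1, 2, 3, 4, 5], pos=19))   # 35
```

Logical position 19 falls in the second entry of the block table (block 2) at offset 3, which is the slot 35 used in the kernel example.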

Read path: block_table → FlashAttention

# vllm/attention/backends/flash_attn.py:770 (prefill) / :805 (decode)

flash_attn_varlen_func(
    q=query,
    k=key_cache,                  # entire GPU tensor
    v=value_cache,                # entire GPU tensor
    block_table=[1, 2, 3, 4, 5],   # which blocks belong to this request
    ...
)
# FlashAttention: for each query → look up block_table
# → index into key_cache/value_cache → compute attention

Why pin preserves data

K/V data is never erased or zeroed; it stays at the same GPU address. Pin simply prevents the block from being re-allocated to another request (by keeping ref_cnt > 0). When the next turn arrives, FlashAttention reads the same block_ids → same GPU addresses → old K/V data is intact.

7.8 Block Lifecycle: Python Data Structures Only

Operation | API | What it does | CUDA call?
Init pool | torch.zeros(..., device="cuda") | Allocate 10.55 GiB contiguous GPU buffer | Yes (once)
Allocate block | FreeKVCacheBlockQueue.popleft_n() | Remove from doubly-linked list, ref_cnt += 1 | No
Free block | FreeKVCacheBlockQueue.append_n() | ref_cnt -= 1, append back to linked list | No
Pin (Continuum) | (nothing) | Skip free() → ref_cnt stays > 0 → block stays out of free queue | No
Unpin (TTL expires) | kv_cache_manager.free() | ref_cnt -= 1, block returns to free queue | No
Use block (forward pass) | FlashAttention kernel | Read/write the GPU tensor at block_id × block_size offset | Yes (every step)

7.9 The FreeKVCacheBlockQueue

The core data structure is a doubly-linked list of KVCacheBlock objects. All allocate/free operations are O(1) pointer swaps:

# vllm/v1/core/kv_cache_utils.py:158-175 (simplified: each block is a node;
# defaults added here so the excerpt parses on its own)
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class KVCacheBlock:
    block_id: int                                     # index into the GPU tensor
    ref_cnt: int = 0                                  # > 0 = in use, 0 = free/evictable
    _block_hash: Optional[Any] = None                 # for prefix cache lookup
    prev_free_block: Optional["KVCacheBlock"] = None  # doubly-linked list
    next_free_block: Optional["KVCacheBlock"] = None  # doubly-linked list
    is_null: bool = False                             # placeholder block

# Allocate: pop from front of free list
block = free_block_queue.popleft()   # O(1) pointer swap
block.ref_cnt += 1

# Free: append to tail of free list
block.ref_cnt -= 1
free_block_queue.append(block)       # O(1) pointer swap

# Pin: do nothing (skip the free above)

7.10 Why This Design?

No fragmentation

Fixed-size blocks from a pre-allocated pool. No cudaMalloc/cudaFree churn, no CUDA memory fragmentation.

Zero GPU overhead for scheduling

Allocate/free/pin/unpin are pure CPU operations (Python linked list). The GPU only sees block_id indices in the block table — it never knows blocks were "pinned".

Pin is trivially simple

Continuum doesn't need any special memory API. It just skips one function call (free()). The block's ref_cnt stays above zero, preventing the free queue from reclaiming it.

7.11 Concrete Example: Allocating 5 Blocks

Tracing Turn 1 (66 tokens → 5 blocks) step by step:

① Initial state — FreeKVCacheBlockQueue (5401 free blocks)

fake_head blk 1 blk 2 blk 3 blk 4 blk 5 blk 6 ··· fake_tail

ref_cnt = 0 for all blocks  |  num_free_blocks = 5401

② popleft_n(5) — pop 5 blocks from front

Free queue after pop:

fake_head blk 6 blk 7 ··· fake_tail

Popped (ref_cnt 0 → 1):

blk 1 blk 2 blk 3 blk 4 blk 5

10 pointer swaps, 5 int increments. Zero CUDA calls.  |  num_free_blocks = 5396

③ Build block table → GPU tensor mapping

block_table = [1, 2, 3, 4, 5] ← indices into GPU tensor

GPU tensor (10.55 GiB, allocated once, never freed):

null | blk 1 | blk 2 | blk 3 | blk 4 | blk 5 | free | free | ··· (×5394)

FlashAttention reads/writes at block_id × 2 MiB offset. GPU memory is never malloc'd or free'd.

④ /metrics result

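usage = 1 − (5396 / 5401) = 5 / 5401 ≈ 0.000926 (computed from the block counts; the value actually polled after Turn 1 in Section 4.1 was 0.000876)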

Total Python ops: 10 pointer swaps + 5 int increments + 1 dict insert  |  Total CUDA ops: 0

7.12 Relationship to PagedAttention

Everything described above is PagedAttention — the core contribution of the original vLLM paper. The naming maps directly to OS virtual memory concepts:

OS virtual memory | PagedAttention (vLLM)
Page Frame (4 KB) | KVCacheBlock (2 MiB = 16 tokens)
Page Table (VA → PA) | block_table [1, 2, 3, 4, 5]
Physical Memory (pre-alloc) | torch.zeros(10.55 GiB, cuda)
Free Page List | FreeKVCacheBlockQueue
ref_count (page sharing) | block.ref_cnt (prefix cache)

Just as OS paging lets processes use non-contiguous physical memory via a page table, PagedAttention lets requests use non-contiguous GPU memory via a block table:

Logical token positions (contiguous to attention computation):

tok 0–15 → block 1 | tok 16–31 → block 2 | tok 32–47 → block 3 | tok 48–65 → block 4+5

Physical GPU memory (non-contiguous — blocks can be anywhere):

null | blk 1 | blk 2 | blk 3 | blk 4 | blk 5 | free | free | ···

block_table = [1, 2, 3, 4, 5]   ← indirection table (like page table)

Where Continuum Fits

Continuum operates entirely within PagedAttention's framework. It adds one behavior:

Normal vLLM

blk 1 | blk 2 | blk 3 | blk 4 | blk 5

Request done → free() → ref_cnt → 0 → back to free queue. Next request can overwrite.

Continuum

blk 1 | blk 2 | blk 3 | blk 4 | blk 5

Request done → skip free() → ref_cnt stays 1 → blocks reserved. Next turn of same job reuses them.

Analogy: Continuum = "don't return the pages to the OS after the process pauses, because it will resume soon and need them back."

8. KV Cache Budget Estimation

How does vLLM decide it can allocate 5402 blocks? This section traces the math.

8.1 The Formula

# vllm/v1/worker/gpu_worker.py:176-177
requested_memory = init_snapshot.total_memory × gpu_memory_utilization

# vllm/v1/worker/gpu_worker.py:279-280 — after profile run
non_kv_cache_memory = weights + peak_activation + non_torch_memory
available_kv = requested_memory − non_kv_cache_memory
num_gpu_blocks = available_kv // cache_block_size
Key: requested_memory uses total_memory (not free), so it's a fixed value regardless of other processes. However, if init_free < requested_memory, vLLM refuses to start. Also, peak_activation is measured by running one real forward pass, not calculated analytically.

8.2 Per-Block Size Calculation

Each block stores KV cache for block_size tokens across all attention layers:

# vllm/worker/cache_engine.py:121-145 — get_cache_block_size()

cache_block_size = dtype_size × num_layers × block_size × (key_entry + value_entry)

# where:
key_entry   = num_kv_heads × head_size
value_entry = num_kv_heads × head_size   # same as key (except MLA)

Llama 3.1 8B concrete values:

Parameter | Value | Source
num_attention_layers | 32 | config.json
num_kv_heads | 8 | GQA (32 Q heads in 8 groups of 4)
head_size | 128 | hidden_size 4096 / 32 heads
block_size | 16 | vLLM default (--block-size)
dtype_size | 2 bytes | bfloat16

key_entry   = 8 × 128 = 1024
value_entry = 8 × 128 = 1024

cache_block_size = 2 × 32 × 16 × (1024 + 1024)
                 = 2 × 32 × 16 × 2048
                 = 2,097,152 bytes
                 = 2 MiB per block

8.3 Memory Budget Breakdown (vLLM Measured)

These are exact values reported by vLLM's startup log (vllm/v1/worker/gpu_worker.py:391), not estimates:

# vLLM startup log (vllm/v1/worker/gpu_worker.py:391)
Free memory on device (28.66/31.34 GiB) on startup.
Desired GPU memory utilization is (0.85, 26.64 GiB).
Actual usage is 14.99 GiB for weight,
                1.01 GiB for peak activation,
                0.09 GiB for non-torch memory,
            and -0.41 GiB for CUDAGraph memory.
Available KV cache memory: 10.55 GiB
num_gpu_blocks: 5402

Step 1: How much can vLLM use?

Component | Size | Source
total_memory | 31.34 GiB | RTX 5090
× gpu_memory_utilization | × 0.85 | --gpu-memory-utilization
= requested_memory | 26.64 GiB | Upper bound for all vLLM usage (total × 0.85)
init_free_memory | 28.66 GiB | Must be ≥ requested_memory, otherwise vLLM refuses to start

Step 2: Subtract non-KV usage → remaining = KV cache

vLLM runs a profile, measures all non-KV memory, and subtracts from requested_memory:

Component | Size | Source
Model weights | 14.99 GiB | vLLM profile run
Peak activation | 1.01 GiB | vLLM profile run (FlashAttention + torch.compile)
Non-torch memory | 0.09 GiB | CUDA kernels, driver overhead
CUDAGraph memory | −0.41 GiB | Graph capture freed memory (negative = reclaimed)
Redundancy buffer | 0.15 GiB | Hardcoded 150 MiB safety margin
= non-KV total | 16.09 GiB | Sum of above + internal overhead

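available_kv = requested_memory − non_kv_total
             = 26.64 − 16.09
             = 10.55 GiB

num_gpu_blocks = 10.55 GiB ÷ 2 MiB ≈ 5,402 blocks  ✓
Verification: vLLM reports KV cache memory = 11,328,937,984 bytes (10.55 GiB). Floor division by the block size gives 11,328,937,984 // 2,097,152 = 5,402 blocks, matching the startup log.
Note: the rounded figure gives 10.55 GiB = 10,803.2 MiB, and 10,803.2 ÷ 2 = 5,401.6. This only appears to round up to 5,402 because 10.55 is itself a rounded value; the exact byte count divides to 5,402 with a small remainder.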
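
The per-block size and the block count can be recomputed directly from the figures in this section. The sketch below does so with hypothetical helper names; the defaults are the Llama 3.1 8B values listed in 8.2, and the byte count is the KV cache memory figure vLLM reported.

```python
def cache_block_size(num_layers=32, num_kv_heads=8, head_size=128,
                     block_size=16, dtype_size=2):
    """Bytes needed to store K and V for one block across all layers."""
    key_entry = value_entry = num_kv_heads * head_size
    return dtype_size * num_layers * block_size * (key_entry + value_entry)


def num_gpu_blocks(available_kv_bytes, block_bytes):
    """Whole blocks that fit into the KV cache budget (floor division)."""
    return available_kv_bytes // block_bytes


block_bytes = cache_block_size()
print(block_bytes)                                   # 2097152
print(num_gpu_blocks(11_328_937_984, block_bytes))   # 5402
```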

8.4 What Is Activation Memory?

Activations are the intermediate tensors produced during a forward pass. Unlike model weights (permanent) and KV cache (persistent across steps), activations are transient — created during computation and discarded immediately after.

The two largest activations in each Transformer layer:

Attention scores

num_heads × S × S × dtype_size

With FlashAttention, this is never materialized as a full S×S matrix — computed in tiles, so actual memory ≈ 0.

FFN up + gate projection (dominant)

max_batch_tokens × intermediate_size × 2 × dtype_size

Llama 8B: 8192 × 14336 × 2 × 2 bytes = 448 MiB

Only one layer's activation is alive at a time (inference has no backprop), so this is the peak, not cumulative.
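
The FFN estimate is a single product, shown here as a small function. The defaults (8192 batched tokens, intermediate size 14336, 2-byte dtype) are the values from the example above; the function name is illustrative, and the result is an analytical estimate, not the 1.01 GiB peak that vLLM measured in its profile run.

```python
def ffn_activation_bytes(max_batch_tokens=8192, intermediate_size=14336,
                         dtype_size=2):
    """Bytes held by the up and gate projection outputs of one layer."""
    # Factor 2: up and gate projections are both alive at the same time.
    return max_batch_tokens * intermediate_size * 2 * dtype_size


print(ffn_activation_bytes() / 1024**2)  # 448.0
```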

8.5 GPU Memory Map

Region | Size | Lifetime | Configurable?
Other processes (e.g. Desktop) | 2.68 GiB | Permanent | Kill other GPU processes
Model weights | 14.99 GiB | Permanent | Choose model / quantization
Peak activation | 1.01 GiB | Transient (per forward pass) | --max-num-batched-tokens
Non-torch + CUDAGraph | −0.32 GiB | Permanent |
KV cache blocks | 10.55 GiB | Persistent (across steps) | --gpu-memory-utilization, --block-size
Reserved (15% of init_free) | 4.30 GiB | | --gpu-memory-utilization
Total GPU | 31.34 GiB | |

8.6 Why This Matters for Continuum

Continuum pins blocks to keep them allocated. Every pinned block is one fewer block available for new requests. With 5,402 total blocks:

Scenario | Tokens pinned | Blocks pinned | % capacity locked
1 short agent job | ~100 | ~7 | 0.1%
10 concurrent jobs, 2K ctx each | 20,000 | 1,250 | 23%
64 concurrent jobs (SWE-bench) | 128,000+ | 8,000+ | >100% → eviction needed

This is why Continuum's pin TTL and eviction policy matter — aggressive pinning without eviction would exhaust the KV cache budget and stall all new requests.
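
The rows of the capacity table follow from a two-line calculation, sketched below with a hypothetical helper. It assumes every job pins its full context and that jobs share no blocks, which is the worst case.

```python
import math


def pinned_fraction(num_jobs, ctx_tokens, total_blocks=5402, block_size=16):
    """Blocks pinned by `num_jobs` jobs and the share of the pool they lock."""
    blocks = num_jobs * math.ceil(ctx_tokens / block_size)
    return blocks, blocks / total_blocks


for jobs, ctx in [(1, 100), (10, 2000), (64, 2000)]:
    blocks, frac = pinned_fraction(jobs, ctx)
    print(f"{jobs:>2} jobs x {ctx} tokens -> {blocks} blocks ({frac:.1%})")
```

A fraction above 100% means the pins cannot all be honored at once, so some must expire or be evicted before new requests can run.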

9. Policy Benchmark: FCFS vs Continuum vs LMCache

9.1 Experiment Design

Parameter | Value
Arrival pattern | Poisson process (matching paper methodology)
JPS (Jobs Per Second) | 2.0 / 8.0
Duration | 90-120s
Turns per job | 8
Context per job | ~2,900 tokens (accumulated, realistic tool outputs)
Tool execution time | Variable: 50-200ms (fast tools) + 2-5s (pytest)
Metric | Average job duration (same as paper's primary metric)

9.2 Low Load: JPS=2.0 (no significant difference)

| Metric | FCFS | Continuum | Diff |
|---|---|---|---|
| Jobs completed | 206 | 231 | |
| Peak concurrent | ~15 | ~22 | |
| KV cache pressure | ~75% | ~75% | |
| Avg duration | 6.65s | 6.97s | +4.8% |
| Median | 6.63s | 6.96s | +5.0% |

Why no improvement: At 75% KV cache utilization, vLLM's prefix caching alone is sufficient: freed blocks stay in the free queue long enough for the same job to reclaim them via hash lookup. Continuum's pin overhead (ToolCallEstimator, ContinuumRequestQueue O(n) scan) adds cost without benefit.

9.3 High Load: JPS=8.0 (significant improvement)

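| Metric | FCFS | Continuum | Diff |
|---|---|---|---|
| Jobs completed | 756 | 704 | |
| Peak concurrent | ~156 | ~111 | |
| KV cache pressure | ~95%+ | ~95%+ | |
| Avg duration | 14.10s | 12.47s | -11.6% |
| Median | 14.33s | 12.61s | -12.0% |
| P90 | 17.02s | 14.37s | -15.6% |
| P95 | 17.54s | 14.63s | -16.6% |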

Per-turn latency comparison

| Turn | Prompt Tokens | FCFS | Continuum | Diff |
|---|---|---|---|---|
| 1 | 92 | 1,263 ms | 1,051 ms | -16.8% |
| 2 | 469 | 735 ms | 519 ms | -29.4% |
| 3 | 811 | 574 ms | 479 ms | -16.6% |
| 4 | 1,343 | 1,005 ms | 841 ms | -16.3% |
| 5 | 1,655 | 749 ms | 652 ms | -12.9% |
| 6 | 2,241 | 641 ms | 539 ms | -15.9% |
| 7 | 2,665 | 3,732 ms | 3,212 ms | -13.9% |
| 8 | 2,915 | 1,236 ms | 1,034 ms | -16.3% |

Key finding: Turn 2 shows the largest improvement (-29.4%); this is the first turn that benefits from KV cache pinning. Under high load, FCFS's freed blocks get immediately reclaimed by other jobs, forcing re-prefill. Continuum's pin keeps blocks out of the free queue entirely.

9.4 Why Load Matters

Low load (JPS=2, ~75%)

Free queue has ~1,350 spare blocks. Freed blocks sit safely — no one grabs them. Prefix cache hit rate ≈ pin hit rate. Continuum's overhead > its benefit.

High load (JPS=8, ~95%)

Free queue nearly empty. Freed blocks get immediately grabbed by new jobs → prefix cache miss → full re-prefill. Continuum's pin prevents this by keeping blocks out of the free queue entirely.
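The difference between the two regimes can be illustrated with a toy model of the free queue. This is a simplified sketch, not vLLM's allocator: it assumes freed blocks are appended to the tail of a FIFO free queue, new allocations consume blocks from the head, and a returning job gets a prefix-cache hit only for blocks still in the queue. The allocation count during the tool gap (400) is an illustrative number, not a measurement.

```python
from collections import deque


def blocks_surviving(spare_blocks: int, freed_blocks: int, allocated_during_gap: int) -> int:
    """How many of a job's freed blocks are still reclaimable after the tool gap."""
    free_queue = deque(("spare", i) for i in range(spare_blocks))
    # The finishing job's blocks go to the tail, behind the existing spare blocks.
    free_queue.extend(("job", i) for i in range(freed_blocks))
    # Other requests take blocks from the head while the tool is running.
    for _ in range(min(allocated_during_gap, len(free_queue))):
        free_queue.popleft()
    return sum(1 for owner, _ in free_queue if owner == "job")


low_load = blocks_surviving(spare_blocks=1350, freed_blocks=30, allocated_during_gap=400)
high_load = blocks_surviving(spare_blocks=50, freed_blocks=30, allocated_during_gap=400)
print(f"low load: {low_load}/30 blocks survive, high load: {high_load}/30 survive")
```

With ~1,350 spare blocks ahead of them the job's blocks are never reached; with a nearly empty queue they are consumed almost immediately, which is the case pinning is meant to prevent.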

10. Code Architecture: FCFS / Continuum / LMCache

This section traces exactly how each scheduling strategy is implemented in the vllm-continuum codebase. All three share the same scheduler loop (vllm/v1/core/sched/scheduler.py), but differ in request ordering, preemption policy, and post-completion block handling.

10.1 Policy Enum & Queue Factory

The four supported policies and their queue implementations are registered in a single factory:

# vllm/v1/core/sched/request_queue.py:17-22
class SchedulingPolicy(Enum):
    FCFS          = "fcfs"
    PRIORITY      = "priority"
    CONTINUUM     = "continuum"
    CONTEXT_SERVE = "context_serve"

# vllm/v1/core/sched/request_queue.py:342-353
def create_request_queue(policy):
    if policy == SchedulingPolicy.FCFS:
        return FCFSRequestQueue()        # deque[Request], pure FIFO
    elif policy == SchedulingPolicy.CONTINUUM:
        return ContinuumRequestQueue()   # deque + job_id-level FCFS
    elif policy == SchedulingPolicy.CONTEXT_SERVE:
        return ContinuumRequestQueue()   # reuses Continuum queue logic

CLI usage: vllm serve MODEL --scheduling-policy fcfs|continuum|context_serve (vllm/engine/arg_utils.py:875, default: fcfs)

10.2 FCFS — First-Come-First-Served

The simplest policy — pure FIFO with no job awareness.

Request Queue

# vllm/v1/core/sched/request_queue.py:85-140
class FCFSRequestQueue(deque[Request], RequestQueue):
    def add_request(self, request):
        self.append(request)           # → right end (newest)

    def pop_request(self):
        return self.popleft()          # ← left end (oldest)

    def peek_request(self):
        return self[0]                # oldest, without removing

    def prepend_request(self, request):
        self.appendleft(request)       # preempted → back to head

Scheduling dispatch

# vllm/v1/core/sched/scheduler.py:470-472 — schedule() waiting loop
if self.policy == SchedulingPolicy.FCFS:
    request = self.waiting.peek_request()   # oldest arrival

Preemption

# vllm/v1/core/sched/scheduler.py:394-396
else:  # FCFS
    preempted_req = self.running.pop()  # evict last-added running request

Post-completion

# vllm/v1/core/sched/scheduler.py:1398-1399
self.kv_cache_manager.free(request)  # ref_cnt → 0, blocks → free queue
del self.requests[request.request_id]

FCFS Summary

| Aspect | Behavior |
|---|---|
| Request order | Strict arrival time (deque FIFO) |
| Preempt victim | Last request in running list (pop()) |
| Block handling | Always free() → blocks return to free queue immediately |
| Job awareness | None: treats every request independently |

10.3 Continuum — Pin-based KV Cache Reservation

Continuum adds three mechanisms on top of the base scheduler: (1) job-ID level FCFS ordering with pinned-job priority, (2) smart preemption that protects last-step requests, (3) TTL-based KV cache pinning after request completion.

① Request Queue: ContinuumRequestQueue

Two-tier priority: pinned jobs first, then job-ID level FCFS (not request-level).

# vllm/v1/core/sched/request_queue.py:221-274 (condensed)
class ContinuumRequestQueue(deque[Request], RequestQueue):
    def __init__(self):
        self.job_id_first_entry_time: dict[str, float] = {}  # job → first arrival

    def peek_request(self, pinned_requests, kv_cache_manager, connector):
        pinned_job_ids = {req.job_id for req, _ in pinned_requests}

        # Tier 1: find the waiting request whose job has pinned blocks
        #         (among pinned jobs, pick the one with earliest first-entry)
        earliest, earliest_time = None, float("inf")
        for request in self:
            if request.job_id in pinned_job_ids:
                entry = self.job_id_first_entry_time.get(request.job_id)
                if entry < earliest_time:
                    earliest, earliest_time = request, entry

        if earliest:
            return earliest   # ← pinned-job request goes first!

        # Tier 2: job-ID level FCFS, i.e. the job that arrived earliest overall
        for request in self:
            entry = self.job_id_first_entry_time.get(request.job_id)
            if entry < earliest_time:
                earliest, earliest_time = request, entry
        return earliest
Key difference from FCFS: FCFS picks the oldest request. Continuum picks the request whose job arrived earliest — so a Turn-5 request of an old job beats a Turn-1 request of a new job. Pinned jobs (those with KV cache still reserved in VRAM) get absolute priority to minimize wasted pin time.
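The two-tier rule can be exercised in isolation. The following is a standalone sketch of the selection logic only, not the repo's class: `Req` and `pick_next` are hypothetical names, and the queue is a plain list.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Set


@dataclass(frozen=True)
class Req:
    request_id: str
    job_id: str


def pick_next(waiting: List[Req],
              first_entry_time: Dict[str, float],
              pinned_job_ids: Set[str]) -> Optional[Req]:
    """Pinned-job requests first; otherwise the request whose job arrived earliest."""
    pinned = [r for r in waiting if r.job_id in pinned_job_ids]
    candidates = pinned or waiting
    return min(candidates, key=lambda r: first_entry_time[r.job_id], default=None)


waiting = [Req("new-turn1", "job_new"), Req("old-turn5", "job_old"), Req("mid-turn2", "job_mid")]
entry = {"job_old": 1.0, "job_mid": 2.0, "job_new": 3.0}

print(pick_next(waiting, entry, pinned_job_ids=set()).request_id)
print(pick_next(waiting, entry, pinned_job_ids={"job_mid"}).request_id)
```

Without pins, the Turn-5 request of the oldest job wins even though it was enqueued after the new job's Turn-1 request; once job_mid holds a pin, its request jumps ahead of both.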

② Preemption: Protect Last-Step Requests

# vllm/v1/core/sched/scheduler.py:197-236
def pop_running_request_based_on_last_step(self, request):
    # Priority 1: evict non-last-step request with latest job entry time
    for req in self.running:
        if not req.is_last_step and job_entry_time > latest:
            latest_request = req

    # Priority 2: if all are last-step, evict the latest entry regardless
    if latest_request is None:
        for req in self.running:
            if job_entry_time > latest:
                latest_request = req

    # Edge case: if running ≤ 1, steal from pinned_requests instead
    if len(self.running) <= 1:
        # Pop the pinned request with the largest end_time
        return latest_pinned, True  # is_unpin=True
Rationale: Last-step requests (is_last_step=True) are about to free their blocks permanently. Evicting them wastes their entire computation. Non-last-step requests will be pinned after completion anyway, so evicting them is cheaper — they'll get priority when they re-enter the waiting queue.
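A compact way to see the victim ordering is the standalone sketch below. It covers only the first two priorities (newest non-last-step job, then newest job overall) and leaves out the pinned-request steal; `Running` and `pick_victim` are hypothetical names, not the scheduler's API.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class Running:
    request_id: str
    job_entry_time: float
    is_last_step: bool


def pick_victim(running: List[Running]) -> Optional[Running]:
    """Prefer the newest job that is not on its last step; fall back to the newest job."""
    non_last = [r for r in running if not r.is_last_step]
    pool = non_last or running
    return max(pool, key=lambda r: r.job_entry_time, default=None)


running = [
    Running("a", job_entry_time=1.0, is_last_step=False),
    Running("b", job_entry_time=2.0, is_last_step=False),
    Running("c", job_entry_time=3.0, is_last_step=True),
]
print(pick_victim(running).request_id)
```

Request "c" belongs to the newest job but is spared because it is on its last step, so "b" is preempted instead.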

③ Pin Lifecycle: TTL-based Reservation

The complete pin/unpin lifecycle is managed by three methods in the scheduler:

# vllm/v1/core/sched/scheduler.py:239-258

## 1. Pin: skip free(), store (request, end_time) tuple
def pin_request(self, request, length_of_pin):
    self.pinned_requests.append((request, time.time() + length_of_pin))

## 2. Unpin: remove from list, call free() to release blocks
def unpin_request(self, request, end_time):
    self.pinned_requests.remove((request, end_time))
    self.kv_cache_manager.free(request)  # ref_cnt → 0

## 3. Regular cleanup: called at the start of every schedule() step
def unpin_requests_regular(self):
    waiting_job_ids = {req.job_id for req in self.waiting}
    # Iterate over a snapshot: unpin_request() removes entries from the list
    for request, end_time in list(self.pinned_requests):
        # Only unpin if: (a) TTL expired AND (b) no waiting request for this job
        if request.job_id not in waiting_job_ids and time.time() >= end_time:
            self.unpin_request(request, end_time)
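The TTL rule is easier to test with an injectable clock. The sketch below is a simplified stand-in for the three methods above, not the scheduler code: `PinTable` is a hypothetical class that tracks one expiry time per job and omits the actual block freeing.

```python
import time
from typing import Callable, Dict, List, Set


class PinTable:
    """Tracks pin expiry per job; the clock is injectable so tests need no sleeping."""

    def __init__(self, clock: Callable[[], float] = time.monotonic) -> None:
        self._clock = clock
        self._end_time: Dict[str, float] = {}

    def pin(self, job_id: str, ttl_s: float) -> None:
        self._end_time[job_id] = self._clock() + ttl_s

    def is_pinned(self, job_id: str) -> bool:
        return job_id in self._end_time

    def expire(self, waiting_job_ids: Set[str]) -> List[str]:
        """Unpin jobs whose TTL has passed, unless they already have a request waiting."""
        now = self._clock()
        expired = [job for job, end in self._end_time.items()
                   if end <= now and job not in waiting_job_ids]
        for job in expired:
            del self._end_time[job]
        return expired


now = [0.0]                        # fake clock value, advanced manually
table = PinTable(clock=lambda: now[0])
table.pin("job_a", 2.0)
table.pin("job_b", 2.0)

now[0] = 1.0
print(table.expire(waiting_job_ids=set()))       # nothing has expired yet
now[0] = 2.5
print(table.expire(waiting_job_ids={"job_b"}))   # job_b is protected by its waiting request
```

After the second call job_a is released while job_b stays pinned past its TTL, mirroring condition (b) in the cleanup method.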

④ TTL Decision: ToolCallEstimator

# vllm/v1/core/estimate_with_func.py:151-160
FIXED_THRESHOLD_CONTINUUM = 2.0   # seconds

def set_up_pin(self, request):
    if request.this_func_call is None:
        return 0                        # can't detect tool → don't pin

    exec_time = self.get_func_call_exec_time(request.this_func_call)

    if exec_time > FIXED_THRESHOLD_CONTINUUM:
        return 0                        # slow tool (>2s) → don't pin

    return FIXED_THRESHOLD_CONTINUUM    # fast tool → pin for 2.0s
Implementation vs Paper: The paper describes a CDF-based TTL estimator that uses the historical execution time distribution to compute optimal pin durations. The implementation uses a fixed 2-second threshold: if avg exec time ≤ 2s → pin for exactly 2s; otherwise → don't pin at all. This is a significant simplification.

⑤ Post-completion: _free_blocks() Dispatch

# vllm/v1/core/sched/scheduler.py:1356-1399
def _free_blocks(self, request):
    # Step 1: unpin any sibling pinned request with same job_id
    # (iterate over a snapshot: unpin_request() removes entries from the list)
    for req, end_time in list(self.pinned_requests):
        if req.job_id == request.job_id:
            self.unpin_request(req, end_time)

    # Step 2: Continuum decision
    if self.policy == SchedulingPolicy.CONTINUUM and not request.is_last_step:
        length_of_pin = self.tool_call_estimator.set_up_pin(request)
        if length_of_pin > 0.01:
            self.pin_request(request, length_of_pin)   # ← SKIP free()
            del self.requests[request.request_id]
            return

    # Default: free blocks (FCFS path, or is_last_step=True)
    self.kv_cache_manager.free(request)
    del self.requests[request.request_id]

Continuum Summary

| Aspect | Behavior |
|---|---|
| Request order | Pinned-job requests first, then job-ID level FCFS |
| Preempt victim | Newest non-last-step job → newest last-step → steal from pinned |
| Block handling | is_last_step=False + fast tool → pin (skip free, TTL=2s); otherwise → free() |
| Job awareness | Full: tracks job_id, is_last_step, this_func_call, per-job history |

10.4 ContextServe — Three-Way Dispatch Extension

ContextServe extends Continuum's binary pin/no-pin decision into a three-way dispatch: FAST (strict pin), SLOW (proactive free), UNCERTAIN (fallback to Continuum TTL). It shares the same ContinuumRequestQueue and preemption logic.

# vllm/v1/core/sched/scheduler.py:1364-1388
if self.policy == SchedulingPolicy.CONTEXT_SERVE and not request.is_last_step:
    speed, pin_duration = self.tool_call_estimator.set_up_pin_contextserve(request)

    if speed == ToolSpeed.FAST:           # ls, cat, pwd, git status ...
        self.pin_request(request, pin_duration)   # lock in VRAM

    elif speed == ToolSpeed.SLOW:         # pytest, pip install, docker build ...
        self.kv_cache_manager.free(request)     # proactive free at t=0

    else:                                # ToolSpeed.UNCERTAIN
        if pin_duration > 0.01:             # fallback to Continuum TTL
            self.pin_request(request, pin_duration)

Three-Layer Prediction Stack

The ContextAwarePredictor (vllm/v1/core/context_predictor.py) uses a three-layer classification strategy:

| Layer | Method | Example |
|---|---|---|
| 1. Rule-based | Hardcoded sets: DEFINITELY_FAST (ls, cat, pwd), DEFINITELY_SLOW_PREFIXES (make, pip install, docker build) | ls → FAST 3.0s |
| 2a. Bayesian | Normal-Inverse-Gamma conjugate prior + Thompson Sampling (vllm/v1/core/bayesian_predictor.py) | Posterior quantile → TTL |
| 2b. XGBoost | 25-feature extraction (pipes, redirects, loops, etc.) → classification | Fallback if Bayesian fails |
| 3. Fallback | Returns UNCERTAIN + Continuum CDF TTL | |
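For readers unfamiliar with layer 2a, the sketch below shows the textbook Normal-Inverse-Gamma conjugate update and one Thompson draw for a tool's execution time. It is not taken from bayesian_predictor.py: the class and function names, the prior values, and the sample timings are all assumptions chosen for illustration, and how the repo turns the posterior into a TTL is not shown in the source.

```python
import random
from dataclasses import dataclass
from typing import Sequence, Tuple


@dataclass(frozen=True)
class NIG:
    """Normal-Inverse-Gamma parameters over (mean, variance) of execution time."""
    mu: float
    kappa: float
    alpha: float
    beta: float


def update(prior: NIG, samples: Sequence[float]) -> NIG:
    """Standard conjugate update after observing `samples`."""
    n = len(samples)
    if n == 0:
        return prior
    mean = sum(samples) / n
    sq_dev = sum((x - mean) ** 2 for x in samples)
    kappa_n = prior.kappa + n
    mu_n = (prior.kappa * prior.mu + n * mean) / kappa_n
    alpha_n = prior.alpha + n / 2
    beta_n = (prior.beta + 0.5 * sq_dev
              + prior.kappa * n * (mean - prior.mu) ** 2 / (2 * kappa_n))
    return NIG(mu_n, kappa_n, alpha_n, beta_n)


def thompson_sample(post: NIG, rng: random.Random) -> Tuple[float, float]:
    """Draw one plausible (mean, variance) pair from the posterior."""
    precision = rng.gammavariate(post.alpha, 1.0 / post.beta)  # Gamma(shape, scale)
    variance = 1.0 / precision
    mean = rng.gauss(post.mu, (variance / post.kappa) ** 0.5)
    return mean, variance


prior = NIG(mu=1.0, kappa=1.0, alpha=2.0, beta=1.0)   # illustrative prior
posterior = update(prior, [0.1, 0.2, 0.3])            # three fast runs, in seconds
print(posterior)
print(thompson_sample(posterior, random.Random(0)))
```

Three short observations pull the posterior mean from 1.0 s down to 0.4 s; sampling from the posterior instead of using its mean keeps some exploration while the history is still small.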

10.5 LMCache — CPU DRAM Offload

LMCache is an orthogonal strategy to Continuum. Instead of pinning blocks in GPU VRAM (which reduces capacity for other jobs), LMCache offloads evicted KV cache blocks to CPU DRAM or a remote server, then retrieves them when the same job returns. It integrates via the KVConnectorBase_V1 connector interface.

Connector Architecture

# vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py:23-27
class LMCacheConnectorV1(KVConnectorBase_V1):
    def __init__(self, vllm_config, role):
        self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self)
        # Delegates to lmcache.integration.vllm.vllm_v1_adapter

Two operational roles, split across scheduler and worker processes:

| Role | Location | Methods |
|---|---|---|
| SCHEDULER | vllm/v1/core/sched/scheduler.py:84-93 | get_num_new_matched_tokens(): query how many tokens are cached externally; update_state_after_alloc(): track block↔token mapping; build_connector_meta(): encode store/load plan for workers |
| WORKER | vllm/v1/worker/kv_connector_model_runner_mixin.py | start_load_kv(): async load from CPU before forward; wait_for_layer_load(layer): per-layer sync point; save_kv_layer(layer, tensor): async save to CPU after compute; wait_for_save(): ensure save completes |

Data Flow: Store & Retrieve

## Store (after each forward step)
Attention layer → save_kv_layer(layer_name, kv_tensor, attn_metadata)
    → LMCache chunks KV by 256 tokens (configurable)
    → Stores to CPU DRAM (LMCACHE_LOCAL_CPU=True)
       or remote server (LMCACHE_REMOTE_URL=lm://host:port)

## Retrieve (on new request for same prefix)
Scheduler → connector.get_num_new_matched_tokens(request)
    → LMCache matches prompt prefix hash against stored chunks
    → Returns count of externally cached tokens

Worker → connector.start_load_kv(forward_context)
    → Per-layer: wait_for_layer_load(layer_name)
    → Loaded KV appears in GPU paged buffer
    → Attention skips recomputation for loaded tokens
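The retrieve step boils down to counting how many leading chunks of a prompt are already stored. The sketch below illustrates that idea with a chained SHA-256 over 256-token chunks. It is not LMCache's implementation: the key scheme, the function names, and the plain `set` used as a store are assumptions for illustration only.

```python
import hashlib
from typing import List, Sequence, Set


def chunk_keys(token_ids: Sequence[int], chunk_size: int = 256) -> List[str]:
    """One key per full chunk; each key depends on every token before it."""
    keys = []
    running = hashlib.sha256()
    full_len = len(token_ids) - len(token_ids) % chunk_size
    for start in range(0, full_len, chunk_size):
        chunk = token_ids[start:start + chunk_size]
        running.update(",".join(map(str, chunk)).encode() + b";")
        keys.append(running.copy().hexdigest())
    return keys


def matched_tokens(token_ids: Sequence[int], store: Set[str], chunk_size: int = 256) -> int:
    """Number of leading tokens whose chunks are all present in the store."""
    matched = 0
    for key in chunk_keys(token_ids, chunk_size):
        if key not in store:
            break
        matched += chunk_size
    return matched


turn1 = list(range(600))                       # 2 full chunks + 88 leftover tokens
store = set(chunk_keys(turn1))                 # "save" after the first turn
turn2 = turn1 + list(range(1000, 1300))        # same prefix plus new tool output

print(matched_tokens(turn2, store))            # tokens that can skip prefill
```

Only whole chunks are matched, so the 88 trailing tokens of the first turn are recomputed together with the new tool output.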

Configuration

# CLI: pass via JSON to --kv-transfer-config
vllm serve MODEL --kv-transfer-config '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_both"}'

# Environment variables (set before launch)
export LMCACHE_USE_EXPERIMENTAL="True"
export LMCACHE_LOCAL_CPU="True"              # enable CPU DRAM backend
export LMCACHE_MAX_LOCAL_CPU_SIZE="5.0"     # 5 GB CPU memory limit
export LMCACHE_CHUNK_SIZE="256"             # tokens per chunk

10.6 Side-by-Side Comparison

| | FCFS | Continuum | LMCache |
|---|---|---|---|
| Strategy | No KV retention: evicted blocks immediately available | Pin in GPU VRAM: blocks reserved for returning job | Offload to CPU DRAM: blocks freed on GPU, cached on CPU |
| GPU capacity impact | None (100% available) | Reduces available blocks (pinned blocks are locked) | None (blocks freed on GPU, stored on CPU) |
| Resume cost | Full re-prefill (recompute all tokens) | Zero (KV still in GPU, same addresses) | PCIe transfer (CPU→GPU, ~12 GB/s) |
| Where it shines | Low load; prefix caching handles reuse | Small GPU, high load, fast tools | Large GPU, high load, abundant CPU DRAM |
| Key files | request_queue.py:85 | request_queue.py:221, estimate_with_func.py:151, scheduler.py:239 | lmcache_connector.py, kv_transfer.py, attention/layer.py:445 |
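The LMCache resume cost can be put in rough numbers. The sketch below is a bandwidth-only estimate: it assumes 128 KiB of KV per token (32 layers × 8 KV heads × 128 dims × K and V × 2 bytes for Llama 3.1 8B) and the ~12 GB/s PCIe figure from the table, and it ignores per-transfer launch latency and any overlap with compute.

```python
KV_BYTES_PER_TOKEN = 2 * 32 * 8 * 128 * 2   # K+V, layers, KV heads, head dim, fp16
PCIE_BYTES_PER_S = 12e9                     # ~12 GB/s, from the comparison table


def reload_ms(num_tokens: int) -> float:
    """Time to copy the KV cache for `num_tokens` from CPU to GPU, bandwidth only."""
    return num_tokens * KV_BYTES_PER_TOKEN / PCIE_BYTES_PER_S * 1000.0


for tokens in (469, 1343, 2900):
    print(f"{tokens:>5} tokens: {reload_ms(tokens):.1f} ms")
```

For a ~2,900-token context the copy takes on the order of tens of milliseconds under these assumptions, which is the cost LMCache pays in place of a re-prefill.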

10.7 Request Lifecycle per Strategy

How each strategy handles KV cache during tool-call gaps under high memory pressure (KV cache >90%), shown for a short and a long tool call.

Short Tool Call (< TTL)

1. FCFS: recompute on return (Slowest)

KV freed immediately after inference. Under high load, blocks evicted by other requests → full recompute on return.

request finished → _free_blocks() → kv_cache_manager.free() → blocks → free queue
  → blocks immediately available for any new request (no pin)

Events: tool started → tool finished
VRAM: LLM inference (KV allocated) → tool executing, KV freed → evicted (blocks taken by others) → recompute ⚠ (re-prefill) → LLM inference (resume)
DRAM: unused

2. Continuum: VRAM pin with TTL (★ Fastest)

KV pinned in VRAM (ref_cnt stays > 0). Tool returns before TTL expires → cache hit, zero recompute.

request finished → _free_blocks():
  if not is_last_step + fast tool → pin_request(2.0s) → blocks RESERVED (skip free)
  → next turn → pinned-job gets scheduling priority → cache hit

Events: tool started → tool finished
VRAM: LLM inference (KV allocated) → tool executing, KV pinned ✓ (blocks in VRAM, not evictable) → LLM inference (cache hit)
DRAM: unused

3. LMCache: DRAM offload + reactive reload (Medium)

KV offloaded to CPU DRAM asynchronously, GPU freed. On tool return, reload from DRAM (PCIe transfer).

forward → per-layer: save_kv_layer() → async copy to CPU DRAM → wait_for_save() → free() → GPU blocks released
  → next turn → connector.get_num_new_matched_tokens() → skip prefill for cached portion → start_load_kv() → per-layer wait_for_layer_load()

Events: tool started → tool finished
VRAM: LLM inference (KV allocated) → offload ↓ (GPU→CPU) → GPU memory freed (available for other requests) → reload ↑ (CPU→GPU) → LLM inference (restored)
DRAM: KV data held in CPU pinned memory

Long Tool Call (> TTL)

1. FCFS: recompute on return (Slowest)

KV freed immediately. Blocks evicted → full recompute on return.

request finished → _free_blocks() → kv_cache_manager.free() → blocks → free queue
  → long tool: blocks overwritten by others, full recompute on return

Events: tool started → tool finished
VRAM: LLM inference (KV allocated) → tool executing, KV freed → evicted (blocks taken by others) → recompute ⚠ (re-prefill) → LLM inference (resume)
DRAM: unused

2. Continuum: VRAM pin with TTL (Slowest)

KV pinned, but TTL expires before tool returns → blocks unpinned and evicted → must recompute.

pin_request(2.0s) → TTL counting...
  → unpin_requests_regular() → time.time() >= end_time → unpin + free() → blocks evicted before tool returns; degrades to FCFS behavior

Events: tool started → TTL expired → tool finished
VRAM: LLM inference (KV allocated) → KV pinned ✓ (TTL active) → unpinned → evicted (blocks taken by others) → recompute ⚠ (re-prefill) → LLM inference (resume)
DRAM: unused

3. LMCache: DRAM offload + reactive reload (★ Fastest)

KV safe in DRAM regardless of tool duration. On return, reload from DRAM; no recompute needed.

per-layer: save_kv_layer() → async GPU→CPU copy → free() → long tool: KV safe on CPU, GPU fully available for others
  → next turn → start_load_kv() → reload from DRAM (skip recompute)

Events: tool started → tool finished
VRAM: LLM inference (KV allocated) → offload ↓ (GPU→CPU) → GPU memory freed (available for other requests) → reload ↑ (CPU→GPU) → LLM inference (restored)
DRAM: KV data held in CPU pinned memory

10.8 Eviction Behavior Under Memory Pressure

When KV cache is nearly full (~95%) and a new request arrives that needs blocks, how does each strategy decide what to evict? The walkthrough below uses 3 concurrent jobs and only 8 free blocks.

Setup: Job A (turn 3, waiting) needs 12 blocks. Only 8 free blocks remain. Job B (turn 2, running) and Job C (turn 1, running) occupy blocks. Job A has pinned blocks from its previous turn.

FCFS — Evict last-added running request

① Initial state: 8 free blocks, need 12 → must evict 4 blocks

VRAM: A₁ A₂ A₃ A₄ | B₁ B₂ B₃ | C₁ C₂ C₃ C₄ C₅ | 8 free

② FCFS evicts Job A's old blocks (freed, no pin), but A needs 12 blocks, so that is still not enough. Evict Job C (last-added to running).

VRAM: A₁ A₂ A₃ A₄ (old, reclaimed) | B₁ B₂ B₃ | C₁ C₂ C₃ C₄ C₅ (evicted) | 8 free

③ Result: Job A gets 12 blocks, but must recompute all tokens (its old blocks were freed). Job C preempted → sent back to waiting queue head, also needs full recompute later.

VRAM: A₁ A₂ A₃ A₄ A₅ A₆ A₇ A₈ A₉ A₁₀ A₁₁ A₁₂ | B₁ B₂ B₃ | 5 free

Cost: Job A: full recompute (12 blocks). Job C: preempted, will need full recompute when re-scheduled. Total wasted work: 2 full prefills.

Continuum — Pin protects, evict newest non-last-step

① Initial state: Job A's blocks are PINNED (ref_cnt > 0). Same 8 free blocks.

VRAM: A₁🔒 A₂🔒 A₃🔒 A₄🔒 | B₁ B₂ B₃ | C₁ C₂ C₃ C₄ C₅ | 8 free

② Job A gets pinned-job priority. Unpin old A blocks (they'll be reused as prefix cache hit). Need 4 more → evict Job C (newest non-last-step job).

VRAM: A₁✓ A₂✓ A₃✓ A₄✓ | B₁ B₂ B₃ | C₁ C₂ C₃ C₄ C₅ (evicted) | 8 free

③ Result: Job A reuses 4 pinned blocks as cache hit + allocates 8 new blocks. Only 8 blocks need computation (not 12). Job C preempted.

VRAM: A₁✓ A₂✓ A₃✓ A₄✓ A₅ A₆ A₇ A₈ A₉ A₁₀ A₁₁ A₁₂ | B₁ B₂ B₃ | 5 free

Savings: Job A: 4 blocks cached (33% saved), only 8 blocks recomputed. Total: 1 partial prefill + 1 preemption.

LMCache — Blocks in DRAM, GPU fully available

① Initial state: Job A's KV was offloaded to DRAM, so its GPU blocks were already freed. 8 free blocks + A's 4 freed blocks = 12 available.

VRAM: B₁ B₂ B₃ | C₁ C₂ C₃ C₄ C₅ | 12 free
DRAM: A₁ A₂ A₃ A₄ (offloaded earlier)

② Job A returns. Reload 4 blocks from DRAM (PCIe transfer), allocate 8 new blocks. No eviction needed: 12 free blocks available!

VRAM: B₁ B₂ B₃ | C₁ C₂ C₃ C₄ C₅ | A₁↑ A₂↑ A₃↑ A₄↑ A₅ A₆ A₇ A₈ A₉ A₁₀ A₁₁ A₁₂

Savings: Job A: 4 blocks reloaded from DRAM (skip recompute), 8 new blocks computed. No preemptions. All jobs continue running. Cost: PCIe transfer latency (~0.5ms per block).

10.9 Multi-GPU JPS Sweep Results

Three policies tested across three GPUs at varying load levels. Key findings:

| GPU | VRAM | ~KV Blocks | Best Strategy at High Load | Key Observation |
|---|---|---|---|---|
| RTX 6000 | 24 GB | ~3,000 | Continuum | 71% faster at JPS=1.0; 26% faster at JPS=12.0. Pin prevents eviction on small VRAM |
| A100 | 40 GB | ~9,000 | LMCache | 39-65% faster than FCFS at JPS=10+. Continuum provides no benefit and is sometimes slower |
| L40S | 48 GB | ~13,000 | LMCache | 39-60% faster at JPS=10+. Same pattern: large GPU favors DRAM offload over VRAM pin |

Takeaway: Continuum's VRAM pin strategy trades capacity for speed, which pays off when VRAM is scarce (<32 GB) and tools are fast. On larger GPUs (≥40 GB), the capacity cost of pinning outweighs its benefit; LMCache's CPU DRAM offload provides the same resume-speedup without reducing GPU capacity.

10.10 Full Call Graph: Experiment → vLLM Internals → Response

End-to-end trace from ab_benchmark.py request through every vLLM layer to HTTP response. Shows how job_id / is_last_step propagate and where each strategy diverges.

Phase 1 — Request Path (inbound)

The key challenge is propagating job_id and is_last_step from the client's extra_body through the standard OpenAI-compatible API, past vLLM's multi-layer engine architecture, all the way into the scheduler's Request object — without modifying the OpenAI protocol.

Experiment script (ab_benchmark.py:329)
  client.chat.completions.create(extra_body={"job_id": "job_0042", "is_last_step": False})
  The OpenAI Python SDK merges extra_body keys into the top-level JSON payload, so vLLM's Pydantic model sees job_id as a regular field.

HTTP endpoint, FastAPI (vllm/entrypoints/openai/api_server.py:688)
  create_chat_completion(request: ChatCompletionRequest)

Protocol parsing (protocol.py:635-638)
  ChatCompletionRequest.job_id = "job_0042"   ChatCompletionRequest.is_last_step = False
  Because the base model uses extra="allow", unrecognized keys (like job_id) are accepted without breaking vanilla OpenAI requests.

Pack into SamplingParams.extra_args (protocol.py:730-737)
  to_sampling_params() → extra_args = {"job_id": .., "is_last_step": ..}

Pass-through layers (serving_chat.py:300 → async_llm.py:289 → processor.py:424)
  These 3 layers (serving handler → AsyncLLM → processor) simply forward sampling_params without touching extra_args.

Request object creation (vllm/v1/request.py:146-152)
  Request.from_engine_core_request() → Request.job_id, Request.is_last_step
  This is where the dict becomes typed attributes. From this point on, the scheduler can use request.job_id directly.

Scheduler.add_request() (scheduler.py:1276-1290)
  tool_call_estimator.request_arrives(request) ← records per-job history
  waiting.add_request(request) ← into FCFS deque or ContinuumRequestQueue

Phase 2 — Scheduling & Execution

The scheduler runs in a loop: pick next request from waiting queue → allocate KV cache blocks → if not enough blocks, preempt someone → send to GPU. Each strategy differs in who gets picked and who gets evicted.

Scheduler.schedule() (scheduler.py:271)
  unpin_requests_regular() ← first check: any pinned request whose TTL expired? unpin and free
  ├ FCFS: waiting.peek_request(): oldest request, no job awareness
  ├ Continuum: waiting.peek_request(pinned..): pinned-job first, then job-ID FCFS
  └ LMCache: connector.get_num_new_matched_tokens(): check CPU cache for this request

KVCacheManager.allocate_slots() (kv_cache_manager.py:192)
  Try to allocate GPU blocks. If free pool exhausted → preemption:
  ├ FCFS: running.pop(): blindly evict last-added
  ├ Continuum: pop_running_request_based_on_last_step(): protect is_last_step
  └ LMCache: rarely needed (offload freed GPU blocks)

GPU forward pass (gpu_model_runner.py:2064)
  model_output = self.model(input_ids, positions, ...)
  The actual LLM computation. K/V tensors written into pre-allocated blocks via FlashAttention.
  └ LMCache: start_load_kv() before → save_kv_layer() per layer → wait_for_save() after
    kv_connector_model_runner_mixin.py:44,112,117
    Hooks into forward: load from CPU before, save to CPU after each layer.

Sample tokens (gpu_model_runner.py:2118)
  sampler_output = self._sample(logits)

Phase 3 — Completion & Pin Decision

When the model finishes generating (EOS or max_tokens), the scheduler must decide what to do with the KV cache blocks. This is where the three strategies fundamentally differ — all in one function: _free_blocks().

Scheduler.update_from_output() (scheduler.py:1042)
  if stopped → _free_request() → _free_blocks()

_free_blocks(): THE STRATEGY DIVERGENCE POINT (scheduler.py:1356-1399)
  ├ FCFS: kv_cache_manager.free(request)
      Always free immediately. Under high load, blocks grabbed by other jobs within ms.

  ├ Continuum: tool_call_estimator.request_finished() ← detokenize, regex-parse tool name
      estimate_with_func.py:151 set_up_pin():
        If tool is fast (avg ≤ 2s) → pin_request(2.0s) ← skip free! Blocks stay in VRAM
        If tool is slow (avg > 2s) → free() ← not worth holding that long

  └ ContextServe: context_predictor.py predict()
      Three-layer prediction: rules → Bayesian → XGBoost → fallback
        FAST → pin_request(ttl)  |  SLOW → free()  |  UNCERTAIN → pin(fallback_ttl)

Phase 4 — Response Path (outbound)

Output pipeline
  core.py:766 output_queue → ZMQ socket
  async_llm.py:427 _run_output_handler() → per-request queue
  serving_chat.py → ChatCompletionResponse ← detokenize, format as OpenAI JSON
  api_server.py:703 → JSON HTTP 200

Back to experiment script (ab_benchmark.py:335)
  result.turns.append({latency_ms, prompt_tokens, ...})
  time.sleep(sample_tool_time(turn)) ← simulate tool execution (50ms–5s), then loop back to the experiment-script request at the start of Phase 1 with accumulated context

Variable Propagation: extra_body → pin decision

| Stage | Variable Form | File |
|---|---|---|
| 1 | extra_body={"job_id": "job_0042", "is_last_step": False} | ab_benchmark.py:329 |
| 2 | ChatCompletionRequest.job_id = "job_0042" | protocol.py:635 |
| 3 | SamplingParams.extra_args = {"job_id": .., "is_last_step": ..} | protocol.py:730 |
| 4 | Request.job_id = "job_0042", Request.is_last_step = False | request.py:146 |
| 5 | waiting.add_request(request) + job history recorded | scheduler.py:1287 |
| 6 | is_last_step=False + fast tool → pin_request(2.0s) | scheduler.py:1391 |

10.11 Experiment Scripts — Source Code Walkthrough

Three key scripts drive all experiments: the policy benchmark, KV pin verifier, and JPS sweep plotter.

① ab_benchmark.py — Poisson-Arrival Policy Benchmark

547 lines. Simulates realistic multi-turn agentic workload with Poisson-distributed job arrivals. Runs one policy per invocation; compare by running multiple times with different --policy.

# CLI usage
python3 ab_benchmark.py --policy continuum --port 8199 --jps 8.0 --duration 90 --turns 8

# Core: Poisson arrival process (line 381)
def poisson_arrival(jps, duration_s, base_url, model_id, n_turns):
    gap = random.expovariate(jps)   # exponential inter-arrival time
    time.sleep(gap)
    # spawn job as thread with job_id = f"job_{count:04d}"

# Core: how Continuum metadata is passed (line 329)
resp = client.chat.completions.create(
    model=model_id,
    messages=messages,       # accumulated context (grows each turn)
    max_tokens=150,
    extra_body={
        "job_id": job_id,
        "is_last_step": (i == n_turns - 1),  # True only on final turn
    },
)

# Tool execution time per turn (line 276-286)
TOOL_TIME_DISTRIBUTIONS = [
    (0.05, 0.15),   # Turn 1: find/ls — fast
    (0.05, 0.10),   # Turn 2: cat — fast
    (0.05, 0.10),   # Turn 3: cat — fast
    (0.08, 0.20),   # Turn 4: grep — fast
    (2.0,  5.0),    # Turn 5: pytest — SLOW (long-tailed)
    (0.05, 0.10),   # Turn 6: cat — fast
    (0.10, 0.30),   # Turn 7: patch — fast
]

# Output JSON schema (line 516-538)
{
  "policy": "continuum",
  "jps": 8.0,
  "avg_duration_s": 12.465,
  "p90_duration_s": 14.371,
  "p95_duration_s": 14.629,
  "per_turn_avg_latency_ms": {"1": 1051, "2": 519, ...},
  "per_turn_avg_prompt_tokens": {"1": 92, "2": 469, ...},
  "job_durations": [6.185, 6.607, ...]
}
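The arrival excerpt above is abbreviated, so here is a self-contained sketch of how a Poisson arrival schedule can be generated. It only produces timestamps; `poisson_arrival_times` and its `seed` parameter are illustrative and not part of ab_benchmark.py, which sleeps between arrivals and spawns a thread per job.

```python
import random
from typing import List, Optional


def poisson_arrival_times(jps: float, duration_s: float, seed: Optional[int] = None) -> List[float]:
    """Arrival timestamps (seconds from start) with exponential inter-arrival gaps."""
    rng = random.Random(seed)
    now = 0.0
    arrivals: List[float] = []
    while True:
        now += rng.expovariate(jps)      # mean gap is 1 / jps seconds
        if now >= duration_s:
            return arrivals
        arrivals.append(now)


arrivals = poisson_arrival_times(jps=8.0, duration_s=90.0, seed=0)
print(f"{len(arrivals)} jobs scheduled, first at {arrivals[0]:.3f}s, last at {arrivals[-1]:.3f}s")
```

At JPS=8.0 over 90 s the expected number of arrivals is 720; the completed-job counts in Section 9.3 (756 and 704) are in that range.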

② test_pin_verify.py — KV Cache Pin Verification

102 lines. Verifies Continuum's pin mechanism by polling vllm:kv_cache_usage_perc from the Prometheus /metrics endpoint during tool execution gaps.

# Parse Prometheus metric (line 36-41)
def get_kv_usage():
    # `req` is assumed to be the requests library (e.g. `import requests as req`)
    r = req.get(f"{BASE}/metrics", timeout=5)
    for line in r.text.split("\n"):
        if line.startswith("vllm:kv_cache_usage_perc{"):
            return float(line.split()[-1])

# Verification: poll every 0.3s during 1s tool gap (line 84-95)
if not is_last:
    print("  Simulating tool execution (1s) — KV should stay pinned...")
    for t in [0.3, 0.6, 0.9]:
        time.sleep(0.3)
        u = get_kv_usage()
        print(f"    t+{t:.1f}s: KV usage = {u:.6f}")
        # Expected: u ≈ usage_post (constant = pinned)
        # Failure:  u drops → blocks evicted, pin broken
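The parsing and the pass/fail rule can be tested without a running server. In the sketch below the metrics text is a made-up sample in Prometheus exposition format (the label and values are not real measurements), and `parse_kv_usage` and `pin_held` are hypothetical helpers rather than functions from test_pin_verify.py.

```python
from typing import Iterable, Optional

SAMPLE_METRICS = """\
# HELP vllm:kv_cache_usage_perc KV-cache usage. 1 means 100 percent usage.
# TYPE vllm:kv_cache_usage_perc gauge
vllm:kv_cache_usage_perc{model_name="llama"} 0.012958
"""


def parse_kv_usage(metrics_text: str) -> Optional[float]:
    """Return the gauge value, or None if the metric line is missing."""
    for line in metrics_text.splitlines():
        if line.startswith("vllm:kv_cache_usage_perc{"):
            return float(line.rsplit(" ", 1)[-1])
    return None


def pin_held(usage_after_request: float, samples: Iterable[float], tolerance: float = 1e-6) -> bool:
    """True if usage never dropped below the post-request level during the tool gap."""
    return all(u >= usage_after_request - tolerance for u in samples)


usage_post = parse_kv_usage(SAMPLE_METRICS)
print(usage_post)
print(pin_held(usage_post, [0.012958, 0.012958, 0.012958]))   # constant: pin held
print(pin_held(usage_post, [0.012958, 0.004000, 0.004000]))   # drop: blocks were freed
```

On a busy server other requests also move the gauge, so this check is only meaningful when the test job is the sole source of load.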

③ plot_jps_sweep.py — Multi-GPU Comparison Plots

153 lines. Reads JSON results from 3 GPUs × 3 policies, generates publication-quality comparison figures.

# GPU configurations (line 13-17)
GPUS = {
    "rtx_6000": {"name": "RTX 6000 (24 GB)", "jps": [1.0, 2.0, 5.0, 8.0, 12.0]},
    "A100":     {"name": "A100 (40 GB)",      "jps": [2.0, 5.0, 10.0, 15.0, 25.0, 40.0]},
    "L40S":     {"name": "L40S (48 GB)",      "jps": [2.0, 5.0, 10.0, 15.0, 25.0]},
}

# Speedup formula (line 124)
speedup = (fcfs["avg_duration_s"] - other["avg_duration_s"]) / fcfs["avg_duration_s"] * 100

# Generates 3 figure types:
# 1. jps_sweep_{gpu}.png   — per-GPU: avg duration vs JPS for 3 policies
# 2. jps_sweep_combined.png — 3-subplot side-by-side comparison
# 3. jps_sweep_speedup.png  — % improvement over FCFS baseline
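Applied to the averages reported in Sections 9.2 and 9.3, the speedup formula reproduces the Diff columns with the sign flipped, since the tables report Continuum minus FCFS while the plot reports improvement over FCFS. A minimal sketch (the helper name is illustrative):

```python
def speedup_pct(fcfs_avg_s: float, other_avg_s: float) -> float:
    """Percent improvement over the FCFS baseline; negative means slower than FCFS."""
    return (fcfs_avg_s - other_avg_s) / fcfs_avg_s * 100


high_load = speedup_pct(14.10, 12.47)   # JPS=8.0 averages from Section 9.3
low_load = speedup_pct(6.65, 6.97)      # JPS=2.0 averages from Section 9.2
print(f"JPS=8.0: {high_load:+.1f}%  JPS=2.0: {low_load:+.1f}%")
```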

10.12 Key Files Reference

| Component | File | Lines |
|---|---|---|
| Policy enum + queue factory | vllm/v1/core/sched/request_queue.py | 17-353 |
| FCFSRequestQueue | vllm/v1/core/sched/request_queue.py | 85-140 |
| ContinuumRequestQueue | vllm/v1/core/sched/request_queue.py | 221-340 |
| Scheduler core | vllm/v1/core/sched/scheduler.py | 116-130, 197-264, 370-418, 1356-1399 |
| ToolCallEstimator (TTL) | vllm/v1/core/estimate_with_func.py | 95-228 |
| ContextAwarePredictor | vllm/v1/core/context_predictor.py | 43-268 |
| BayesianPredictor | vllm/v1/core/bayesian_predictor.py | 39-150 |
| LMCache connector | vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py | 1-168 |
| Connector factory | vllm/distributed/kv_transfer/kv_connector/factory.py | 95-98 |
| KV transfer config | vllm/config/kv_transfer.py | 1-112 |
| API protocol (job_id extraction) | vllm/entrypoints/openai/protocol.py | 294-296, 364-384 |
| Request dataclass | vllm/v1/request.py | 23-46, 136-175 |
| CLI argument | vllm/engine/arg_utils.py | 875-876 |
| Policy benchmark script | our_experiments/scripts/ab_benchmark.py | 547 |
| JPS sweep plotter | our_experiments/scripts/plot_jps_sweep.py | 153 |

11. Observations & Takeaways

Continuum's core mechanism is sound

The tool detection → time estimation → KV pin → priority scheduling pipeline works correctly on the happy path. Metrics confirm blocks are not freed during the pin window.

Tightly coupled to model output format

ToolCallParser relies on ```bash\n...\n``` regex. Thinking models (Qwen3) break this entirely — 2/10 vs 10/10 parse rate. The system is designed for Llama-family output format.

Pin duration is simplistic

The repo uses a fixed 2-second threshold (FIXED_THRESHOLD_CONTINUUM = 2.0), not the CDF-based estimation described in the paper. If historical exec time > 2s → don't pin. Otherwise → pin exactly 2s.

Research-grade code quality

No defensive handling for job_id=None, concurrent same-job requests, or thinking models. The code works on the author's exact setup (Llama 70B + mini-swe-agent + sequential turns) but breaks on deviations.

Not yet tested

scheduler_timestamps hit_length analysis and SWE-bench end-to-end pass rate. The A/B comparison (FCFS vs Continuum latency under load) and the concurrent multi-job pressure test are covered by the Section 9 benchmark.