Hands-on verification of the vLLM-Continuum KV cache pinning mechanism
— Functional Test & Metrics-based Verification —
| Item | Value |
|---|---|
| Machine | ssh hclin@100.107.101.84 |
| GPU | 1× NVIDIA RTX 5090 (32 GB VRAM, ~30 GB available) |
| Conda env | contextserve |
| vLLM source | /home/hclin/vllm-continuum/vllm/ (editable install) |
| Base vLLM version | v0.7.x (fork, ~v0.7.2) |
| Model | /home/hclin/shared_models/Meta-Llama-3.1-8B-Instruct (15 GB, local) |
| KV cache blocks | 5,402 blocks × block_size 16 tokens |
| Scheduling policy | --scheduling-policy continuum |
/home/hclin/vllm-continuum/ has been extended with ContextServe (three-way dispatch: FAST/SLOW/UNCERTAIN). The Continuum branch is intact as an elif path and is NOT affected when --scheduling-policy continuum is used.
# Activate environment
source ~/anaconda3/etc/profile.d/conda.sh && conda activate contextserve
# Start vLLM with Continuum scheduling
RUN_OUTPUT_DIR=./e2e_output \
vllm serve /home/hclin/shared_models/Meta-Llama-3.1-8B-Instruct \
--scheduling-policy continuum \
--port 8199 \
--max-model-len 4096 \
--gpu-memory-utilization 0.85
Startup takes ~70 seconds (weight loading + torch.compile). Health check via:
curl http://localhost:8199/health
A Python script sends multi-turn agentic requests via the OpenAI client. Each turn provides job_id and is_last_step through the extra_body parameter, which vLLM's protocol layer unpacks into SamplingParams.extra_args.
# Key: how Continuum metadata is passed
resp = client.chat.completions.create(
    model=MODEL,
    messages=[
        {"role": "system", "content": "Respond with ONLY a ```bash``` block."},
        {"role": "user", "content": "List files with ls."},
    ],
    max_tokens=150,
    extra_body={
        "job_id": "job_alpha",   # ← Identifies the agent job
        "is_last_step": False,   # ← Tells scheduler to pin after completion
    },
)
Calling requests.post() with extra_body as a nested dict does NOT work: the server ignores it. Use the OpenAI Python client, which merges extra_body keys into the top-level payload so vLLM's Pydantic model can parse job_id.
The test runs 2 sequential jobs × 5 turns each. Each turn carries the full conversation history (system prompt + all previous user/assistant/tool messages), simulating a real agent. Simulated tool outputs are injected between turns.
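The history accumulation described above can be sketched as a small helper. This is a minimal illustration, not the test script itself: `build_messages` and the `turns` structure are hypothetical names, and feeding tool output back as a plain user message is an assumption made to keep the sketch simple.

```python
def build_messages(system_prompt, turns, upto):
    """Return the message list sent on turn `upto` (0-based).

    `turns` is a list of dicts with the keys "task", "assistant" and
    "tool_output"; the last two are only needed for completed turns.
    """
    messages = [{"role": "system", "content": system_prompt}]
    for done in turns[:upto]:
        messages.append({"role": "user", "content": done["task"]})
        messages.append({"role": "assistant", "content": done["assistant"]})
        # Assumption: the simulated tool result is injected as a user message.
        messages.append({"role": "user", "content": "Tool output:\n" + done["tool_output"]})
    messages.append({"role": "user", "content": turns[upto]["task"]})
    return messages


turns = [
    {"task": "List files in the project.", "assistant": "ls", "tool_output": "main.py README.md"},
    {"task": "Read main.py.", "assistant": "cat main.py", "tool_output": "print('hi')"},
    {"task": "Search for TODO comments."},
]
third_turn = build_messages("Respond with ONLY a bash block.", turns, 2)
```

Because every turn resends the whole prefix, the prompt grows monotonically, which is what makes the pinned blocks from the previous turn reusable.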
| Turn | User Prompt | LLM Output | Prompt Tokens | Output Tokens | E2E Latency |
|---|---|---|---|---|---|
| 1 | List files in the project. | ```bash ls ``` | 66 | 7 | 319 ms |
| 2 | Read main.py. | ```bash cat main.py ``` | 131 | 9 | 115 ms |
| 3 | Search for TODO comments. | ```bash grep -r 'TODO' . ``` | 279 | 12 | 154 ms |
| 4 | Check git status. | ```bash git status ``` | 397 | 8 | 106 ms |
| 5 | Run pytest. | ```bash pytest ``` | 503 | 7 | 98 ms |
<think>...</think> mode, consuming all tokens before producing the bash block; only 2/10 parsed. Continuum's ToolCallParser regex is designed for non-thinking models.
Filtered server log showing the complete Continuum pipeline for job_alpha:
set_up_pin → pin 2.0s
last func call: ls
last func call: cat
last func call: grep
last func call: git (Turn 5, is_last_step=True)
All 5 turns correctly parsed. Turns 1–4 triggered set_up_pin → return 2.0. Turn 5 (is_last_step=True) skipped pin and freed blocks.
A second test script polls vllm:kv_cache_usage_perc from the Prometheus /metrics endpoint in real-time, before and after each turn, during the simulated tool execution gap.
# 5 turns with accumulated conversation history + simulated tool outputs
# Turn 1-4: is_last_step=False (should pin)
# Turn 5: is_last_step=True (should free)
[Baseline] KV cache usage: 0.000000
Turn 1/5 last=False 319ms prompt=66 output=7
Output: ```bash ls ```
KV usage: 0.000000 → 0.000876 (delta: +0.000876)
Simulating tool execution — KV should stay pinned...
t+0.3s: KV usage = 0.000876
t+0.6s: KV usage = 0.000876
t+0.9s: KV usage = 0.000876
Turn 2/5 last=False 115ms prompt=131 output=9
Output: ```bash cat main.py ```
KV usage: 0.000876 → 0.001577 (delta: +0.000701) ← context grew
Simulating tool execution — KV should stay pinned...
t+0.3s: KV usage = 0.001577
t+0.6s: KV usage = 0.001577
t+0.9s: KV usage = 0.001577
Turn 3/5 last=False 154ms prompt=279 output=12
Output: ```bash grep -r 'TODO' . ```
KV usage: 0.001577 → 0.003330 (delta: +0.001753) ← big jump (grep results injected)
Simulating tool execution — KV should stay pinned...
t+0.3s: KV usage = 0.003330
t+0.6s: KV usage = 0.003330
t+0.9s: KV usage = 0.003330
Turn 4/5 last=False 106ms prompt=397 output=8
Output: ```bash git status ```
KV usage: 0.003330 → 0.004557 (delta: +0.001227) ← context grew again
Simulating tool execution — KV should stay pinned...
t+0.3s: KV usage = 0.004557
t+0.6s: KV usage = 0.004557
t+0.9s: KV usage = 0.004557
Turn 5/5 last=True 98ms prompt=503 output=7
Output: ```bash pytest ```
KV usage: 0.004557 → 0.000000 (delta: -0.004557) ← last step, all freed!
[Final] KV cache usage: 0.000000
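A minimal sketch of how the usage value in the log above could be extracted from the Prometheus text returned by /metrics. The helper name `parse_kv_usage` and the sample payload are illustrative assumptions; the label set on the real metric line may differ, so the pattern accepts any (or no) labels.

```python
import re

METRIC = "vllm:kv_cache_usage_perc"


def parse_kv_usage(metrics_text, metric=METRIC):
    """Return the first sample value of `metric`, or None if it is absent."""
    # Metric name, optional {label="..."} set, whitespace, then the value.
    pattern = re.compile(r"^" + re.escape(metric) + r"(?:\{[^}]*\})?\s+(\S+)")
    for line in metrics_text.splitlines():
        if line.startswith("#"):  # skip HELP / TYPE comment lines
            continue
        match = pattern.match(line)
        if match:
            return float(match.group(1))
    return None


# Hypothetical payload shaped like Prometheus text exposition output.
sample = (
    "# HELP vllm:kv_cache_usage_perc KV-cache usage.\n"
    "# TYPE vllm:kv_cache_usage_perc gauge\n"
    'vllm:kv_cache_usage_perc{engine="0"} 0.000876\n'
)
usage = parse_kv_usage(sample)
```

In the polling script the text would come from an HTTP GET on http://localhost:8199/metrics, sampled before each turn, after it, and at fixed offsets during the simulated tool gap.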
Each turn's input = system + all previous messages + chat template tokens. Tool output from Turn N is included in Turn N+1's input.
| Turn | Content Tokens | Template | Input Total | Output | KV Tokens | ⌈KV/16⌉ | Measured | Tool Result |
|---|---|---|---|---|---|---|---|---|
| 1 | 33 (sys 25 + task 8) | 33 | 66 | 7 | 73 | 5 | 5 | 33 (ls) |
| 2 | 86 (+out₁ 7 + tool₁ 33 + task 13) | 45 | 131 | 9 | 140 | 9 | 9 | 115 (cat main.py) |
| 3 | 222 (+out₂ 9 + tool₂ 115 + task 12) | 57 | 279 | 12 | 291 | 19 | 18 | 82 (grep TODO) |
| 4 | 328 (+out₃ 12 + tool₃ 82 + task 12) | 69 | 397 | 8 | 405 | 26 | 25 | 78 (git status) |
| 5 | 422 (+out₄ 8 + tool₄ 78 + task 8) | 81 | 503 | 7 | 510 | 32 | 0 (freed) | — (last step) |
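The ⌈KV/16⌉ column can be reproduced from the prompt and output token counts alone. This is a small arithmetic sketch; `blocks_needed` is a hypothetical helper, and the token pairs are copied from the table.

```python
BLOCK_SIZE = 16  # tokens per KV cache block, as configured in this test


def blocks_needed(prompt_tokens, output_tokens, block_size=BLOCK_SIZE):
    """Ceiling of (prompt + output) / block_size using integer arithmetic."""
    kv_tokens = prompt_tokens + output_tokens
    return (kv_tokens + block_size - 1) // block_size


# (prompt_tokens, output_tokens) for Turns 1-5
turn_tokens = [(66, 7), (131, 9), (279, 12), (397, 8), (503, 7)]
theoretical = [blocks_needed(p, o) for p, o in turn_tokens]
print(theoretical)  # [5, 9, 19, 26, 32]
```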
Input Total is the API-reported prompt_tokens; Measured is read from /metrics (±1 vs theoretical due to prefix cache dedup).
| Event | Usage | Blocks | Visual (5402 total) |
|---|---|---|---|
| Baseline | 0.000000 | 0 | |
| Turn 1 | 0.000876 | 5 | |
| pin | 0.000876 | 5 | pinned |
| Turn 2 | 0.001577 | 9 | +4 |
| pin | 0.001577 | 9 | pinned |
| Turn 3 | 0.003330 | 18 | +9 |
| pin | 0.003330 | 18 | pinned |
| Turn 4 | 0.004557 | 25 | +7 |
| pin | 0.004557 | 25 | pinned |
| Turn 5 (last) | 0.000000 | 0 | freed |
Usage increases monotonically: 0.000876 → 0.001577 → 0.003330 → 0.004557 (5 → 9 → 18 → 25 blocks). The jump at Turn 3 (+9 blocks) reflects the large grep output injected into context.
After each of Turns 1–4, kv_cache_manager.free() is skipped. Usage stays constant during every tool gap (polled at t+0.3s, t+0.6s, t+0.9s). No blocks are returned to free_block_queue.
Turn 5 (is_last_step=True) immediately drops usage to 0.000000: the _free_blocks() code path correctly bypasses pin and calls kv_cache_manager.free().
| Method | Location | What it proves |
|---|---|---|
| /metrics → vllm:kv_cache_usage_perc | Prometheus endpoint | Usage stays elevated during pin period |
| block_pool.get_num_free_blocks() | vllm/v1/core/block_pool.py:299 | Free count does not increase after pin |
| KVCacheBlock.ref_cnt | vllm/v1/core/kv_cache_utils.py:164 | Reference counter stays > 0 while pinned |
| kv_cache_manager.get_computed_blocks() | vllm/v1/core/kv_cache_manager.py:153 | Returns hit_length > 0 on next turn |
| scheduler_timestamps JSON | vllm/v1/core/estimate_with_func.py:27 | Records pinned_time, unpinned_time, hit_length per job |
| scheduler.pinned_requests | vllm/v1/core/sched/scheduler.py:139 | List of (Request, expiry_time) tuples |
When multiple requests share the same job_id (including None) and the second request arrives before the first finishes, the scheduler crashes. The code assumes history[-1] is a departure entry with a func_call key, but in this case it is an arrival entry.
# Crash site:
request.last_func_call = self.job_to_history[job_id][-1]["func_call"] # KeyError!
# history[-1] is {"arrival_time": ...}, not {"departure_time": ..., "func_call": ...}
Trigger: Concurrent requests with same job_id, or any request with job_id=None.
Status: Not reported on GitHub. Issue #15 is a different bug (VllmModelConfig port argument).
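One defensive way to avoid the KeyError is to search the history backwards for the most recent entry that actually carries a func_call. This is only a sketch of the idea, not a patch from the fork: `last_func_call` is a hypothetical helper, and whether an older func_call is the right value to use for overlapping requests is a separate design question.

```python
def last_func_call(history):
    """Return the newest 'func_call' recorded in a job history, or None."""
    for entry in reversed(history):
        if "func_call" in entry:  # departure entries carry the key
            return entry["func_call"]
    return None  # only arrival entries so far (or empty history)


# Two overlapping requests: the last entries are arrivals, not departures.
history = [
    {"arrival_time": 1.0},
    {"departure_time": 2.0, "func_call": "ls"},
    {"arrival_time": 2.5},
    {"arrival_time": 2.6},
]
recovered = last_func_call(history)
```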
Understanding Continuum's pin mechanism requires understanding vLLM v1's underlying KV cache architecture. Here we document the key design decisions and community discussions.
vLLM v1's block table is append-only — once a block ID is written into a request's block table, it cannot be changed. New blocks can only be appended at the end. To understand why, consider this scenario:
Step 1 — Request A arrives first
System prompt = "You are a helpful assistant." → needs 1 block
Step 2 — Request B arrives with the same system prompt
Prefix cache hit: block hash matches → B also needs block 3's K/V data
(Block hash = hash of token IDs + parent block hash, so same tokens at same position = same hash)
Here's where v0 and v1 diverge:
v0: B shares A's block directly. The scheduler then rewrites A's block table to point to a canonical copy:
v1: B gets its own block. Both co-exist with the same K/V content:
Block 3 returns to free queue, but its hash stays in the prefix cache. Future requests with the same system prompt get a cache hit and reuse block 3 directly — no duplication needed.
Continuum pins blocks after a request finishes. Because block tables are append-only, the pinned blocks' IDs are guaranteed stable — no other process can rewrite them. When the next turn arrives, prefix cache finds the pinned blocks by hash and appends their IDs to the new block table.
Source: vLLM v1 Prefix Caching Design Doc
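The chained block hash mentioned above (token IDs plus the parent block's hash) can be illustrated with a short sketch. The use of SHA-256 and the byte encoding here are assumptions chosen for readability; vLLM's own hash function and key layout may differ.

```python
import hashlib

BLOCK_SIZE = 16


def block_hashes(token_ids, block_size=BLOCK_SIZE):
    """Hash every full block, mixing in the hash of the preceding block."""
    hashes = []
    parent = b""
    num_full = len(token_ids) - len(token_ids) % block_size
    for start in range(0, num_full, block_size):
        chunk = token_ids[start:start + block_size]
        payload = parent + b"|" + ",".join(map(str, chunk)).encode()
        parent = hashlib.sha256(payload).digest()
        hashes.append(parent.hex())
    return hashes


turn_1 = list(range(40))              # 2 full blocks + 8 leftover tokens
turn_2 = list(range(32)) + [999] * 8  # same first 32 tokens, different tail
changed = [7] + list(range(1, 40))    # differs in the very first token
```

Because each hash depends on its parent, `turn_1` and `turn_2` share both full-block hashes, while a change in the first token of `changed` invalidates every later block as well. Only full blocks are hashed in this sketch.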
The BlockPool reserves one block as a permanent placeholder:
# vllm/v1/core/block_pool.py:66-70
# To represent a placeholder block with block_id=0.
# The ref_cnt of null_block is not maintained, needs special care to
# avoid freeing it.
self.null_block = self.free_block_queue.popleft()
self.null_block.is_null = True
Because block tables are append-only and fixed-length, positions that have no real KV data (e.g., prefix cache miss, or tokens outside the sliding window) must be filled with a placeholder rather than left empty:
# Example: sliding window attention, only last 3 blocks active
block_table = [null, null, block_8, block_3, block_12]
↑ outside window, filled with null_block
The null block is marked is_null=True so it is never returned to the free queue and never allocated to any request. This is also why get_usage() subtracts 1 from total blocks:
# vllm/v1/core/block_pool.py:314
total_gpu_blocks = self.num_gpu_blocks - 1 # Subtract 1 to account for null block
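A small sketch of the usage formula implied by that subtraction. It mirrors the description above rather than copying get_usage(); the function name `kv_cache_usage` is hypothetical.

```python
def kv_cache_usage(num_gpu_blocks, num_free_blocks):
    """Fraction of allocatable blocks in use, excluding the null block."""
    total = num_gpu_blocks - 1  # the null block is never allocatable
    if total <= 0:
        return 0.0
    return 1.0 - num_free_blocks / total


idle = kv_cache_usage(5402, 5401)        # nothing allocated
five_blocks = kv_cache_usage(5402, 5396)  # 5 blocks held by one request
```

With 5,402 blocks, an idle pool has 5,401 free blocks, so usage is 0, and holding 5 blocks gives 5 / 5,401.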
Each KVCacheBlock has a ref_cnt. This is the mechanism Continuum exploits for pinning:
| Event | ref_cnt | In free queue? | Allocatable? |
|---|---|---|---|
| Allocated to request | +1 → 1 | No | No |
| Normal free (no pin) | -1 → 0 | Yes | Yes |
| Continuum pin (skip free) | stays 1 | No | No |
| Unpin (TTL expires) | -1 → 0 | Yes | Yes |
Freed blocks are added to the tail of the free queue in reverse order — the last block of a request hashes the most tokens and is least likely to be reused, so it should be evicted first.
Continuum's pin is remarkably simple — it doesn't add any new locking or memory management. It just skips calling kv_cache_manager.free() when a request finishes:
# scheduler.py — _free_blocks()
if policy == CONTINUUM and not request.is_last_step:
    length_of_pin = tool_call_estimator.set_up_pin(request)
    if length_of_pin > 0.01:
        self.pin_request(request, length_of_pin)
        return  # ← kv_cache_manager.free() is SKIPPED
                # ← ref_cnt stays > 0
                # ← blocks stay out of free queue
                # ← GPU memory remains occupied
# Normal path (not pinned):
self.kv_cache_manager.free(request) # ← ref_cnt decremented → blocks return to free queue
When the same job's next turn arrives, get_computed_blocks() finds the still-allocated blocks via prefix cache hash lookup → hit_length > 0 → skip re-prefill.
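The effect of skipping free() can be reproduced with a toy pool. This is a simplified model under stated assumptions: `ToyBlockPool` and `finish_request` are hypothetical names, a plain deque stands in for the free queue, and prefix-cache hashing is left out entirely.

```python
from collections import deque


class ToyBlockPool:
    """Reference-counted block pool with a FIFO free queue."""

    def __init__(self, num_blocks):
        self.ref_cnt = [0] * num_blocks
        self.free = deque(range(num_blocks))

    def allocate(self, n):
        blocks = [self.free.popleft() for _ in range(n)]
        for block_id in blocks:
            self.ref_cnt[block_id] += 1
        return blocks

    def release(self, blocks):
        # Reverse order: a request's tail blocks are queued for eviction first.
        for block_id in reversed(blocks):
            self.ref_cnt[block_id] -= 1
            if self.ref_cnt[block_id] == 0:
                self.free.append(block_id)


def finish_request(pool, blocks, pinned, pin):
    """Pinning means remembering the blocks and NOT releasing them."""
    if pin:
        pinned.append(blocks)
        return
    pool.release(blocks)


pool = ToyBlockPool(8)
pinned = []
held = pool.allocate(5)
finish_request(pool, held, pinned, pin=True)
free_while_pinned = len(pool.free)  # still 3: the 5 blocks are reserved
pool.release(pinned.pop())          # TTL expiry: the unpin is a late release
free_after_unpin = list(pool.free)
```

Nothing in the pinned path touches the pool at all, which is the point: the reservation is a side effect of an operation that did not happen.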
| Topic | Link | Key Insight |
|---|---|---|
| Prefix Caching Design (v1) | design/v1/prefix_caching.md | Official doc explaining append-only constraint, dedup deferral, and reverse-order eviction |
| RFC: Automatic Prefix Caching | #2614 | Original proposal for prefix-aware block management, discusses eviction policy and hash-based lookup |
| Understanding vLLM KV Cache | vLLM Forums | Community Q&A on block pool internals, ref counting, and allocation |
| Hybrid KV Cache Manager | #22688 | Design doc for GPU↔CPU tiered block pool with async transfer — relevant to Continuum's offload scenario |
| RFC: CPU Offload in V1 | #16144 | Discusses hiding transfer latency with pinned memory and MemcpyAsync, round-robin → LRU eviction |
| Prefix-Caching Aware Scheduling | #7883 | Discussion on scheduling that considers cache hit rate — directly related to Continuum's job-level FCFS priority |
| KV Cache Management (DeepWiki) | DeepWiki | Third-party deep dive into block pool, free queue, and hash-based caching architecture |
vLLM's KV cache pin does not call cudaMalloc, cudaFree, cudaMemcpy, or any CUDA memory management API. The entire mechanism is pure Python bookkeeping on top of a pre-allocated GPU buffer.
The only CUDA memory allocation happens once at server startup:
# vllm/v1/worker/gpu_model_runner.py:3367 — called once at init
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device="cuda")
This allocates 10.55 GiB as a single contiguous tensor on GPU. This memory is never freed until server shutdown. It is logically partitioned into 5,402 blocks:
GPU Memory (10.55 GiB, allocated once via torch.zeros):
Each block = 2 MiB = 16 tokens × 32 layers × 8 KV heads × 128 dim × 2 bytes × 2 (K and V). "Allocate" and "free" only move pointers in Python, with zero GPU ops.
Each block stores 16 tokens' Key and Value tensors across all 32 layers. The GPU tensor shape per layer (FlashAttention backend):
# vllm/attention/backends/flash_attn.py:77
shape = (2, num_blocks, block_size, num_kv_heads, head_size)
↑ ↑ ↑ ↑ ↑
K/V 5402 16 8 128
One block (block_id = N) in one layer:
K tensor kv_cache[0, N, :, :, :]
V tensor kv_cache[1, N, :, :, :]
Per layer: 32 KiB (K) + 32 KiB (V) = 64 KiB. Across 32 layers: 64 KiB × 32 = 2 MiB per block.
FlashAttention kernel uses block_id to index directly into this tensor — no copying, no indirection beyond the block table lookup.
Each KVCacheBlock has a ref_cnt (reference count) that tracks how many requests are using it. This is needed because prefix caching allows multiple requests to share the same block — if two requests have the same system prompt, their first blocks contain identical K/V data, so vLLM lets them share instead of duplicating.
| ref_cnt | State | Meaning |
|---|---|---|
| 0 | Free (in free queue) | No one uses this block. Can be evicted or re-allocated. |
| 1 | In use (1 request) | Exactly one request owns it. Normal case. |
| ≥ 2 | Shared (prefix cache) | Multiple requests share this block (same prefix). Block is freed only when ALL requests release it. |
Continuum's pin exploits this: by keeping ref_cnt = 1 (skipping free()), the block stays out of the free queue — exactly the same as if a request were still using it.
The FreeKVCacheBlockQueue is a doubly-linked list instead of Python's deque. The reason: prefix cache hit requires removing a block from the middle.
When a new request shares a prefix with a cached block, that block (which is free and evictable) must be pulled out of the free queue immediately:
# vllm/v1/core/block_pool.py:236-250 — touch() on prefix cache hit
def touch(self, blocks):
    for block in blocks:
        if block.ref_cnt == 0:                     # block is in free queue
            self.free_block_queue.remove(block)    # ← remove from MIDDLE! O(1)
        block.ref_cnt += 1
| Operation | Python deque | Array / List | Doubly-linked list |
|---|---|---|---|
| popleft (allocate) | O(1) | O(n) | O(1) |
| append (free) | O(1) | O(1) | O(1) |
| remove from middle (cache hit) | O(n) | O(n) | O(1) |
The remove() operation is just 2 pointer assignments:
# vllm/v1/core/kv_cache_utils.py:334-340
block.prev_free_block.next_free_block = block.next_free_block
block.next_free_block.prev_free_block = block.prev_free_block
block.prev_free_block = block.next_free_block = None
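A self-contained sketch of such a free list, with sentinel head and tail nodes so that remove() never needs a null check. Class and attribute names (`Block`, `FreeList`, `prev`, `next`) are illustrative and shorter than the ones in vLLM.

```python
class Block:
    def __init__(self, block_id):
        self.block_id = block_id
        self.prev = None
        self.next = None


class FreeList:
    """Doubly-linked free list with O(1) popleft, append and remove."""

    def __init__(self, blocks):
        self.head = Block(-1)  # sentinel, never handed out
        self.tail = Block(-1)  # sentinel, never handed out
        self.head.next = self.tail
        self.tail.prev = self.head
        self.size = 0
        for block in blocks:
            self.append(block)

    def append(self, block):
        last = self.tail.prev
        last.next = block
        block.prev = last
        block.next = self.tail
        self.tail.prev = block
        self.size += 1

    def remove(self, block):
        # Works anywhere in the list because the neighbours always exist.
        block.prev.next = block.next
        block.next.prev = block.prev
        block.prev = block.next = None
        self.size -= 1

    def popleft(self):
        if self.size == 0:
            raise IndexError("no free blocks")
        block = self.head.next
        self.remove(block)
        return block

    def ids(self):
        out, node = [], self.head.next
        while node is not self.tail:
            out.append(node.block_id)
            node = node.next
        return out
```

Removing a block from the middle (the prefix-cache-hit case) costs the same as popping from the front, which is the property a deque cannot offer.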
The list pointers prev_free_block / next_free_block are stored directly on the KVCacheBlock dataclass (not separate Node objects), avoiding Python object allocation overhead. fake_head and fake_tail sentinel nodes eliminate null-check branching.
Each layer's KV cache is a contiguous torch.zeros tensor on GPU. block_id is simply an index into the second dimension, so the GPU address is computed by a single multiplication, with no hash table or pointer chasing:
# Tensor shape per layer (FlashAttention)
kv_cache.shape = (2, 5402, 16, 8, 128)
↑ ↑
K/V block_id ← this dimension IS the block_id
# GPU address for block N:
base = kv_cache.data_ptr() # GPU start address of the tensor
block_size_bytes = 16 × 8 × 128 × 2 = 32,768 bytes # per block per K or V
key_block_N = base + 0 × total_kv_size + N × 32,768
value_block_N = base + 1 × total_kv_size + N × 32,768
Example: block_table = [1, 2, 3, 4, 5] for one layer's K cache
Address = base + block_id × 32,768. One multiplication — no lookup table, no indirection.
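The same offset arithmetic in runnable form, assuming a contiguous per-layer tensor of shape (2, num_blocks, block_size, num_kv_heads, head_size) with 2-byte elements. `block_byte_offsets` is a hypothetical helper; it returns offsets relative to the tensor's base address.

```python
def block_byte_offsets(block_id, num_blocks=5402, block_size=16,
                       num_kv_heads=8, head_size=128, dtype_size=2):
    """Byte offsets of block `block_id` in the K half and the V half."""
    per_block = block_size * num_kv_heads * head_size * dtype_size  # 32,768
    total_kv_size = num_blocks * per_block  # size of the whole K (or V) half
    key_offset = 0 * total_kv_size + block_id * per_block
    value_offset = 1 * total_kv_size + block_id * per_block
    return key_offset, value_offset


key_off, value_off = block_byte_offsets(3)
```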
The CUDA kernel does exactly this pointer arithmetic (csrc/cache_kernels.cu:306):
key_dst = key_cache + block_idx * block_stride + block_offset * page_stride;
// ^^^^^^^^^^^^^^^^^^^^^^^^
// block_idx × 32,768 = jump directly to that block
Block allocation/free is pure Python, but the actual K/V data read/write happens in two CUDA kernels:
Write path (reshape_and_cache_flash): vllm/attention/backends/flash_attn.py:677, csrc/cache_kernels.cu:281
Runs on every forward step for new tokens. Converts slot_mapping[token] → block_idx + offset, then copies K/V into the GPU tensor.
Read path (flash_attn_varlen_func): vllm/attention/backends/flash_attn.py:770,805
FlashAttention receives block_table=[1,2,3,4,5], gathers K/V from scattered blocks → Q×Kᵀ/√d → softmax → ×V.
// csrc/cache_kernels.cu:291-309 — reshape_and_cache_flash_kernel
slot_idx = slot_mapping[token_idx]; // e.g. 35
block_idx = slot_idx / block_size; // 35 / 16 = block 2
block_offset = slot_idx % block_size; // 35 % 16 = position 3
// Pointer arithmetic → destination in GPU tensor
key_dst = key_cache + block_idx * block_stride + block_offset * page_stride;
value_dst = value_cache + block_idx * block_stride + block_offset * page_stride;
// Vectorized copy (8 bf16 elements per instruction)
vectorize_copy(key_src → key_dst);
vectorize_copy(value_src → value_dst);
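The index math of the kernel excerpt, plus the reverse direction (logical token position to slot via the block table), can be checked in plain Python. Both helper names are hypothetical; the block table [8, 3] is an arbitrary non-contiguous example.

```python
def slot_to_block(slot_idx, block_size=16):
    """Split a flat slot index into (block_idx, block_offset)."""
    return divmod(slot_idx, block_size)


def slots_for_request(block_table, num_tokens, block_size=16):
    """Flat slot index for each logical token position of one request."""
    return [
        block_table[pos // block_size] * block_size + pos % block_size
        for pos in range(num_tokens)
    ]


print(slot_to_block(35))               # (2, 3)
slots = slots_for_request([8, 3], 18)  # 18 tokens spread over blocks 8 and 3
```

Tokens 0 to 15 land in block 8 (slots 128 to 143) and tokens 16 and 17 land in block 3 (slots 48 and 49), even though the attention computation sees them as one contiguous sequence.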
# vllm/attention/backends/flash_attn.py:770 (prefill) / :805 (decode)
flash_attn_varlen_func(
    q=query,
    k=key_cache,                   # entire GPU tensor
    v=value_cache,                 # entire GPU tensor
    block_table=[1, 2, 3, 4, 5],   # which blocks belong to this request
    ...
)
# FlashAttention: for each query → look up block_table
# → index into key_cache/value_cache → compute attention
A pinned block cannot be handed to another request (ref_cnt > 0). When the next turn arrives, FlashAttention reads the same block_ids → same GPU addresses → old K/V data is intact.
| Operation | API | What it does | CUDA call? |
|---|---|---|---|
| Init pool | torch.zeros(..., device="cuda") | Allocate 10.55 GiB contiguous GPU buffer | Yes (once) |
| Allocate block | FreeKVCacheBlockQueue.popleft_n() | Remove from doubly-linked list, ref_cnt += 1 | No |
| Free block | FreeKVCacheBlockQueue.append_n() | ref_cnt -= 1, append back to linked list | No |
| Pin (Continuum) | (nothing) | Skip free() → ref_cnt stays > 0 → block stays out of free queue | No |
| Unpin (TTL expires) | kv_cache_manager.free() | ref_cnt -= 1, block returns to free queue | No |
| Use block (forward pass) | FlashAttention kernel | Read/write the GPU tensor at block_id × block_size offset | Yes (every step) |
The core data structure is a doubly-linked list of KVCacheBlock objects. All allocate/free operations are O(1) pointer swaps:
# vllm/v1/core/kv_cache_utils.py:158-175 — each block is a node
@dataclass
class KVCacheBlock:
    block_id: int                                      # index into the GPU tensor
    ref_cnt: int = 0                                   # > 0 = in use, 0 = free/evictable
    _block_hash: Optional = None                       # for prefix cache lookup
    prev_free_block: Optional["KVCacheBlock"] = None   # doubly-linked list
    next_free_block: Optional["KVCacheBlock"] = None   # doubly-linked list
    is_null: bool = False                              # placeholder block
# Allocate: pop from front of free list
block = free_block_queue.popleft() # O(1) pointer swap
block.ref_cnt += 1
# Free: append to tail of free list
block.ref_cnt -= 1
free_block_queue.append(block) # O(1) pointer swap
# Pin: do nothing (skip the free above)
Fixed-size blocks from a pre-allocated pool. No cudaMalloc/cudaFree churn, no CUDA memory fragmentation.
Allocate/free/pin/unpin are pure CPU operations (Python linked list). The GPU only sees block_id indices in the block table — it never knows blocks were "pinned".
Continuum doesn't need any special memory API. It just skips one function call (free()). The block's ref_cnt stays above zero, preventing the free queue from reclaiming it.
Tracing Turn 1 (66 tokens → 5 blocks) step by step:
① Initial state — FreeKVCacheBlockQueue (5401 free blocks)
ref_cnt = 0 for all blocks | num_free_blocks = 5401
② popleft_n(5) — pop 5 blocks from front
Free queue after pop:
Popped (ref_cnt 0 → 1):
10 pointer swaps, 5 int increments. Zero CUDA calls. | num_free_blocks = 5396
③ Build block table → GPU tensor mapping
block_table = [1, 2, 3, 4, 5] ← indices into GPU tensor
GPU tensor (10.55 GiB, allocated once, never freed):
FlashAttention reads/writes each layer's tensor at the offset selected by block_id (32 KiB per K or V per layer, 2 MiB per block across all 32 layers). GPU memory is never malloc'd or free'd.
④ /metrics result
usage = 1 − (5396 / 5401) = 5 / 5401 ≈ 0.000926
Total Python ops: 10 pointer swaps + 5 int increments + 1 dict insert | Total CUDA ops: 0
Everything described above is PagedAttention — the core contribution of the original vLLM paper. The naming maps directly to OS virtual memory concepts:
Just as OS paging lets processes use non-contiguous physical memory via a page table, PagedAttention lets requests use non-contiguous GPU memory via a block table:
Logical token positions (contiguous to attention computation):
Physical GPU memory (non-contiguous — blocks can be anywhere):
block_table = [1, 2, 3, 4, 5] ← indirection table (like page table)
Continuum operates entirely within PagedAttention's framework. It adds one behavior:
Stock vLLM: Request done → free() → ref_cnt → 0 → back to free queue. Next request can overwrite.
Continuum: Request done → skip free() → ref_cnt stays 1 → blocks reserved. Next turn of same job reuses them.
How does vLLM decide it can allocate 5402 blocks? This section traces the math.
# vllm/v1/worker/gpu_worker.py:176-177
requested_memory = init_snapshot.total_memory × gpu_memory_utilization
# vllm/v1/worker/gpu_worker.py:279-280 — after profile run
non_kv_cache_memory = weights + peak_activation + non_torch_memory
available_kv = requested_memory − non_kv_cache_memory
num_gpu_blocks = available_kv // cache_block_size
requested_memory uses total_memory (not free), so it is a fixed value regardless of other processes. However, if init_free < requested_memory, vLLM refuses to start. Also, peak_activation is measured by running one real forward pass, not calculated analytically.
Each block stores KV cache for block_size tokens across all attention layers:
# vllm/worker/cache_engine.py:121-145 — get_cache_block_size()
cache_block_size = dtype_size × num_layers × block_size × (key_entry + value_entry)
# where:
key_entry = num_kv_heads × head_size
value_entry = num_kv_heads × head_size # same as key (except MLA)
| Parameter | Value | Source |
|---|---|---|
| num_attention_layers | 32 | config.json |
| num_kv_heads | 8 | GQA (32 Q heads ÷ 4 Q heads per KV head) |
| head_size | 128 | hidden_size 4096 / 32 heads |
| block_size | 16 | vLLM default (--block-size) |
| dtype_size | 2 bytes | bfloat16 |
key_entry = 8 × 128 = 1024
value_entry = 8 × 128 = 1024
cache_block_size = 2 × 32 × 16 × (1024 + 1024)
= 2 × 32 × 16 × 2048
= 2,097,152 bytes
= 2 MiB per block
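The same calculation as a function, useful for checking other models or block sizes. This is a sketch of the formula quoted above, not a copy of get_cache_block_size(); it assumes key and value entries have the same size (no MLA).

```python
def cache_block_size(num_layers, num_kv_heads, head_size,
                     block_size=16, dtype_size=2):
    """Bytes needed to store K and V for one block across all layers."""
    key_entry = num_kv_heads * head_size
    value_entry = num_kv_heads * head_size  # same as key (no MLA assumed)
    return dtype_size * num_layers * block_size * (key_entry + value_entry)


llama_8b_block = cache_block_size(num_layers=32, num_kv_heads=8, head_size=128)
print(llama_8b_block)  # 2097152
```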
These are exact values reported by vLLM's startup log (vllm/v1/worker/gpu_worker.py:391), not estimates:
# vLLM startup log (vllm/v1/worker/gpu_worker.py:391)
Free memory on device (28.66/31.34 GiB) on startup.
Desired GPU memory utilization is (0.85, 26.64 GiB).
Actual usage is 14.99 GiB for weight,
1.01 GiB for peak activation,
0.09 GiB for non-torch memory,
and -0.41 GiB for CUDAGraph memory.
Available KV cache memory: 10.55 GiB
num_gpu_blocks: 5402
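The logged figures can be cross-checked with a few lines of arithmetic. The byte count used here is the "kv cache memory in use" value this report quotes (11,328,937,984 bytes); the function name is hypothetical.

```python
GIB = 1024 ** 3


def num_gpu_blocks(available_kv_bytes, cache_block_bytes):
    """Floor division, as in num_gpu_blocks = available_kv // cache_block_size."""
    return available_kv_bytes // cache_block_bytes


requested_gib = 31.34 * 0.85        # total_memory × gpu_memory_utilization
available_kv_bytes = 11_328_937_984  # exact value behind "10.55 GiB"
blocks = num_gpu_blocks(available_kv_bytes, 2_097_152)
print(round(requested_gib, 2), round(available_kv_bytes / GIB, 2), blocks)
```

Working from the exact byte count matters: the rounded 10.55 GiB figure divides to 5,401.6 blocks, which would floor to 5,401 rather than the logged 5,402.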
| Component | Size | Source |
|---|---|---|
| total_memory | 31.34 GiB | RTX 5090 |
| × gpu_memory_utilization | × 0.85 | --gpu-memory-utilization |
| = requested_memory | 26.64 GiB | Upper bound for all vLLM usage (total × 0.85) |
| init_free_memory | 28.66 GiB | Must be ≥ requested_memory, otherwise vLLM refuses to start |
vLLM runs a profile, measures all non-KV memory, and subtracts from requested_memory:
| Component | Size | Source |
|---|---|---|
| Model weights | 14.99 GiB | vLLM profile run |
| Peak activation | 1.01 GiB | vLLM profile run (FlashAttention + torch.compile) |
| Non-torch memory | 0.09 GiB | CUDA kernels, driver overhead |
| CUDAGraph memory | −0.41 GiB | Graph capture freed memory (negative = reclaimed) |
| Redundancy buffer | 0.15 GiB | Hardcoded 150 MiB safety margin |
| = non-KV total | 16.09 GiB | weights + peak activation + non-torch memory (14.99 + 1.01 + 0.09); the CUDAGraph and buffer rows do not enter this sum |
available_kv = requested_memory − non_kv_total
= 26.64 − 16.09
= 10.55 GiB
num_gpu_blocks = available_kv // cache_block_size = 5,402 blocks ✓
Using the rounded figure, 10.55 GiB = 10,803.2 MiB, and 10,803.2 MiB ÷ 2 MiB = 5,401.6, which would floor to 5,401. The exact byte count resolves this: 11,328,937,984 // 2,097,152 = 5,402 ✓
kv cache memory in use = 11,328,937,984 bytes = 10.55 GiB, matching the startup log.
Activations are the intermediate tensors produced during a forward pass. Unlike model weights (permanent) and KV cache (persistent across steps), activations are transient: they are created during computation and discarded immediately after.
The two largest activations in each Transformer layer:
Attention scores: num_heads × S × S × dtype_size
With FlashAttention, this is never materialized as a full S×S matrix; it is computed in tiles, so actual memory ≈ 0.
MLP intermediate: max_batch_tokens × intermediate_size × 2 × dtype_size
Llama 8B: 8192 × 14336 × 2 × 2 bytes = 448 MiB
Only one layer's activation is alive at a time (inference has no backprop), so this is the peak, not cumulative.
| Region | Size | Lifetime | Configurable? |
|---|---|---|---|
| Other processes (e.g. Desktop) | 2.68 GiB | Permanent | Kill other GPU processes |
| Model weights | 14.99 GiB | Permanent | Choose model / quantization |
| Peak activation | 1.01 GiB | Transient (per forward pass) | --max-num-batched-tokens |
| Non-torch memory | 0.09 GiB | Permanent | n/a (CUDAGraph's −0.41 GiB is reported separately and is outside this sum) |
| KV cache blocks | 10.55 GiB | Persistent (across steps) | --gpu-memory-utilization, --block-size |
| Outside vLLM's budget (15% of total; includes the 2.68 GiB of other processes) | 4.70 GiB | n/a | --gpu-memory-utilization |
| Total GPU | 31.34 GiB | | |
Continuum pins blocks to keep them allocated. Every pinned block is one fewer block available for new requests. With 5,402 total blocks:
| Scenario | Tokens pinned | Blocks pinned | % capacity locked |
|---|---|---|---|
| 1 short agent job | ~100 | ~7 | 0.1% |
| 10 concurrent jobs, 2K ctx each | 20,000 | 1,250 | 23% |
| 64 concurrent jobs (SWE-bench) | 128,000+ | 8,000+ | >100% → eviction needed |
This is why Continuum's pin TTL and eviction policy matter — aggressive pinning without eviction would exhaust the KV cache budget and stall all new requests.
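The rows of the table follow from a short calculation. `pinned_share` is a hypothetical helper; it assumes every job pins its full context and that blocks are not shared between jobs.

```python
def pinned_share(num_jobs, tokens_per_job, total_blocks=5402, block_size=16):
    """Return (blocks pinned, fraction of the pool locked)."""
    blocks_per_job = -(-tokens_per_job // block_size)  # ceiling division
    blocks = num_jobs * blocks_per_job
    return blocks, blocks / total_blocks


for jobs, tokens in [(1, 100), (10, 2000), (64, 2000)]:
    blocks, share = pinned_share(jobs, tokens)
    print(f"{jobs:>2} jobs × {tokens} tokens → {blocks} blocks ({share:.1%})")
```

A share above 100% means the pins cannot all be honored at once, so some pinned blocks must be released before new requests can run.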
| Parameter | Value |
|---|---|
| Arrival pattern | Poisson process (matching paper methodology) |
| JPS (Jobs Per Second) | 2.0 / 8.0 |
| Duration | 90-120s |
| Turns per job | 8 |
| Context per job | ~2,900 tokens (accumulated, realistic tool outputs) |
| Tool execution time | Variable: 50-200ms (fast tools) + 2-5s (pytest) |
| Metric | Average job duration (same as paper's primary metric) |
JPS = 2.0:
| Metric | FCFS | Continuum | Diff |
|---|---|---|---|
| Jobs completed | 206 | 231 | |
| Peak concurrent | ~15 | ~22 | |
| KV cache pressure | ~75% | ||
| Avg duration | 6.65s | 6.97s | +4.8% |
| Median | 6.63s | 6.96s | +5.0% |
JPS = 8.0:
| Metric | FCFS | Continuum | Diff |
|---|---|---|---|
| Jobs completed | 756 | 704 | |
| Peak concurrent | ~156 | ~111 | |
| KV cache pressure | ~95%+ | ||
| Avg duration | 14.10s | 12.47s | -11.6% |
| Median | 14.33s | 12.61s | -12.0% |
| P90 | 17.02s | 14.37s | -15.6% |
| P95 | 17.54s | 14.63s | -16.6% |
| Turn | Prompt Tokens | FCFS | Continuum | Diff |
|---|---|---|---|---|
| 1 | 92 | 1,263 ms | 1,051 ms | -16.8% |
| 2 | 469 | 735 ms | 519 ms | -29.4% |
| 3 | 811 | 574 ms | 479 ms | -16.6% |
| 4 | 1,343 | 1,005 ms | 841 ms | -16.3% |
| 5 | 1,655 | 749 ms | 652 ms | -12.9% |
| 6 | 2,241 | 641 ms | 539 ms | -15.9% |
| 7 | 2,665 | 3,732 ms | 3,212 ms | -13.9% |
| 8 | 2,915 | 1,236 ms | 1,034 ms | -16.3% |
Low load (JPS 2.0): Free queue has ~1,350 spare blocks. Freed blocks sit safely because no one grabs them. Prefix cache hit rate ≈ pin hit rate. Continuum's overhead > its benefit.
High load (JPS 8.0): Free queue nearly empty. Freed blocks get immediately grabbed by new jobs → prefix cache miss → full re-prefill. Continuum's pin prevents this by keeping blocks out of the free queue entirely.
This section traces exactly how each scheduling strategy is implemented in the vllm-continuum codebase. All three share the same scheduler loop (vllm/v1/core/sched/scheduler.py), but differ in request ordering, preemption policy, and post-completion block handling.
The four supported policies and their queue implementations are registered in a single factory:
# vllm/v1/core/sched/request_queue.py:17-22
class SchedulingPolicy(Enum):
    FCFS = "fcfs"
    PRIORITY = "priority"
    CONTINUUM = "continuum"
    CONTEXT_SERVE = "context_serve"
# vllm/v1/core/sched/request_queue.py:342-353
def create_request_queue(policy):
    if policy == SchedulingPolicy.FCFS:
        return FCFSRequestQueue()        # deque[Request], pure FIFO
    elif policy == SchedulingPolicy.CONTINUUM:
        return ContinuumRequestQueue()   # deque + job_id-level FCFS
    elif policy == SchedulingPolicy.CONTEXT_SERVE:
        return ContinuumRequestQueue()   # reuses Continuum queue logic
CLI: vllm serve MODEL --scheduling-policy fcfs|continuum|context_serve (vllm/engine/arg_utils.py:875, default: fcfs)
FCFS is the simplest policy: pure FIFO with no job awareness.
# vllm/v1/core/sched/request_queue.py:85-140
class FCFSRequestQueue(deque[Request], RequestQueue):
    def add_request(self, request):
        self.append(request)        # → right end (newest)
    def pop_request(self):
        return self.popleft()       # ← left end (oldest)
    def peek_request(self):
        return self[0]              # oldest, without removing
    def prepend_request(self, request):
        self.appendleft(request)    # preempted → back to head
# vllm/v1/core/sched/scheduler.py:470-472 — schedule() waiting loop
if self.policy == SchedulingPolicy.FCFS:
    request = self.waiting.peek_request()   # oldest arrival
# vllm/v1/core/sched/scheduler.py:394-396
else:  # FCFS
    preempted_req = self.running.pop()      # evict last-added running request
# vllm/v1/core/sched/scheduler.py:1398-1399
self.kv_cache_manager.free(request)         # ref_cnt → 0, blocks → free queue
del self.requests[request.request_id]
| Aspect | FCFS behavior |
|---|---|
| Request order | Strict arrival time (deque FIFO) |
| Preempt victim | Last request in running list (pop()) |
| Block handling | Always free() → blocks return to free queue immediately |
| Job awareness | None — treats every request independently |
Continuum adds three mechanisms on top of the base scheduler: (1) job-ID level FCFS ordering with pinned-job priority, (2) smart preemption that protects last-step requests, (3) TTL-based KV cache pinning after request completion.
Two-tier priority: pinned jobs first, then job-ID level FCFS (not request-level).
# vllm/v1/core/sched/request_queue.py:221-274
class ContinuumRequestQueue(deque[Request], RequestQueue):
    def __init__(self):
        super().__init__()
        self.job_id_first_entry_time: dict[str, float] = {}   # job → first arrival

    def peek_request(self, pinned_requests, kv_cache_manager, connector):
        pinned_job_ids = {req.job_id for req, _ in pinned_requests}
        # Tier 1: find the waiting request whose job has pinned blocks
        # (among pinned jobs, pick the one with earliest first-entry)
        earliest, earliest_time = None, float("inf")
        for request in self:
            if request.job_id in pinned_job_ids:
                entry = self.job_id_first_entry_time.get(request.job_id, float("inf"))
                if entry < earliest_time:
                    earliest, earliest_time = request, entry
        if earliest is not None:
            return earliest   # ← pinned-job request goes first!
        # Tier 2: job-ID level FCFS, i.e. the job that arrived earliest overall
        for request in self:
            entry = self.job_id_first_entry_time.get(request.job_id, float("inf"))
            if entry < earliest_time:
                earliest, earliest_time = request, entry
        return earliest
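The two-tier rule can be exercised in isolation with plain tuples. This is a standalone sketch of the ordering logic only: `peek`, the (request_id, job_id) tuples and the timestamps are illustrative and do not use the scheduler's real types.

```python
def peek(waiting, first_entry_time, pinned_job_ids):
    """Pick the next request: pinned jobs first, then job-level FCFS.

    `waiting` is a list of (request_id, job_id) tuples.
    """
    def earliest(candidates):
        # Order by the job's FIRST arrival, not by this request's arrival.
        return min(candidates, key=lambda req: first_entry_time[req[1]], default=None)

    pinned = [req for req in waiting if req[1] in pinned_job_ids]
    return earliest(pinned) or earliest(waiting)


waiting = [("r1", "A"), ("r2", "B"), ("r3", "C")]
first_entry_time = {"A": 1.0, "B": 2.0, "C": 3.0}

with_pin = peek(waiting, first_entry_time, pinned_job_ids={"C"})
without_pin = peek(waiting, first_entry_time, pinned_job_ids=set())
```

Job C arrived last, yet its request is served first while its blocks are pinned; once nothing is pinned, the order falls back to the oldest job.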
# vllm/v1/core/sched/scheduler.py:197-236
def pop_running_request_based_on_last_step(self, request):
    # Priority 1: evict non-last-step request with latest job entry time
    for req in self.running:
        if not req.is_last_step and job_entry_time > latest:
            latest_request = req
    # Priority 2: if all are last-step, evict the latest entry regardless
    if latest_request is None:
        for req in self.running:
            if job_entry_time > latest:
                latest_request = req
    # Edge case: if running ≤ 1, steal from pinned_requests instead
    if len(self.running) <= 1:
        # Pop the pinned request with the largest end_time
        return latest_pinned, True  # is_unpin=True
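The victim-selection priority can be sketched as a pure function. The names below (`Running`, `pick_victim`) are hypothetical, and the sketch deliberately leaves out the edge case of stealing from pinned requests when at most one request is running.

```python
from dataclasses import dataclass


@dataclass
class Running:
    """Hypothetical stand-in for a running request."""
    job_id: str
    is_last_step: bool
    job_entry_time: float


def pick_victim(running):
    """Prefer the newest non-last-step request; otherwise the newest overall."""
    if not running:
        return None
    non_last = [r for r in running if not r.is_last_step]
    pool = non_last or running
    return max(pool, key=lambda r: r.job_entry_time)


running = [
    Running("job_a", is_last_step=False, job_entry_time=1.0),
    Running("job_b", is_last_step=True, job_entry_time=5.0),
    Running("job_c", is_last_step=False, job_entry_time=3.0),
]
# job_b is the newest job, but it is on its last step, so job_c is evicted.
print(pick_victim(running).job_id)
```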
Last-step requests (is_last_step=True) are about to free their blocks permanently, so evicting one throws away its entire computation just before it finishes. Non-last-step requests will be pinned after completion anyway, so evicting them is cheaper: they get priority when they re-enter the waiting queue.

The complete pin/unpin lifecycle is managed by three methods in the scheduler:
# vllm/v1/core/sched/scheduler.py:239-258
## 1. Pin: skip free(), store (request, end_time) tuple
def pin_request(self, request, length_of_pin):
    self.pinned_requests.append((request, time.time() + length_of_pin))

## 2. Unpin: remove from list, call free() to release blocks
def unpin_request(self, request, end_time):
    self.pinned_requests.remove((request, end_time))
    self.kv_cache_manager.free(request)  # ref_cnt → 0

## 3. Regular cleanup: called at the start of every schedule() step
def unpin_requests_regular(self):
    waiting_job_ids = [req.job_id for req in self.waiting]
    for request, end_time in self.pinned_requests:
        # Only unpin if: (a) TTL expired AND (b) no waiting request for this job
        if request.job_id not in waiting_job_ids and time.time() >= end_time:
            self.unpin_request(request, end_time)
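To see the two unpin conditions interact, here is a toy version of the pin list with an injectable clock. `PinTable` and its methods are hypothetical names for illustration; only the rule "unpin when the TTL has expired and no request of the same job is waiting" comes from the excerpt above.

```python
class PinTable:
    """Toy pin list holding (job_id, end_time) pairs."""

    def __init__(self):
        self.pinned = []
        self.freed = []

    def pin(self, job_id, ttl, now):
        self.pinned.append((job_id, now + ttl))

    def sweep(self, waiting_job_ids, now):
        # Iterate over a copy because entries are removed during the loop.
        for job_id, end_time in list(self.pinned):
            if job_id not in waiting_job_ids and now >= end_time:
                self.pinned.remove((job_id, end_time))
                self.freed.append(job_id)


table = PinTable()
table.pin("job_a", ttl=2.0, now=0.0)
table.pin("job_b", ttl=2.0, now=0.0)

table.sweep(waiting_job_ids=set(), now=1.0)      # TTL not reached: nothing freed
table.sweep(waiting_job_ids={"job_b"}, now=2.5)  # job_b is waiting, so it stays pinned
print(table.freed, [job for job, _ in table.pinned])
```

Passing `now` explicitly instead of calling `time.time()` keeps the sketch deterministic, which is convenient when checking TTL edge cases.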
# vllm/v1/core/estimate_with_func.py:151-160
FIXED_THRESHOLD_CONTINUUM = 2.0  # seconds

def set_up_pin(self, request):
    if request.this_func_call is None:
        return 0  # can't detect tool → don't pin
    exec_time = self.get_func_call_exec_time(request.this_func_call)
    if exec_time > FIXED_THRESHOLD_CONTINUUM:
        return 0  # slow tool (>2s) → don't pin
    return FIXED_THRESHOLD_CONTINUUM  # fast tool → pin for 2.0s
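A standalone restatement of the threshold rule makes its boundary behavior easy to check. `pin_ttl` and `FIXED_THRESHOLD_S` are hypothetical names; the logic mirrors the excerpt above, with the estimated execution time passed in directly.

```python
FIXED_THRESHOLD_S = 2.0  # mirrors FIXED_THRESHOLD_CONTINUUM above


def pin_ttl(exec_time_s):
    """Return the pin TTL in seconds; 0.0 means 'do not pin'."""
    if exec_time_s is None:  # tool could not be detected
        return 0.0
    if exec_time_s > FIXED_THRESHOLD_S:
        return 0.0
    return FIXED_THRESHOLD_S


for t in (None, 0.1, 2.0, 2.01):
    print(t, "->", pin_ttl(t))
```

Two properties follow from the rule as excerpted: a tool estimated at exactly 2.0 s is still pinned because the comparison is strict, and the TTL does not scale with the estimate, so a 0.1 s tool reserves blocks for the same 2.0 s as a 1.9 s tool.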
# vllm/v1/core/sched/scheduler.py:1356-1399
def _free_blocks(self, request):
    # Step 1: unpin any sibling pinned request with same job_id
    for req, end_time in self.pinned_requests:
        if req.job_id == request.job_id:
            self.unpin_request(req, end_time)
    # Step 2: Continuum decision
    if self.policy == SchedulingPolicy.CONTINUUM and not request.is_last_step:
        length_of_pin = self.tool_call_estimator.set_up_pin(request)
        if length_of_pin > 0.01:
            self.pin_request(request, length_of_pin)  # ← SKIP free()
            del self.requests[request.request_id]
            return
    # Default: free blocks (FCFS path, or is_last_step=True)
    self.kv_cache_manager.free(request)
    del self.requests[request.request_id]
| Aspect | Continuum |
|---|---|
| Request order | Pinned-job requests first, then job-ID level FCFS |
| Preempt victim | Newest non-last-step job → newest last-step → steal from pinned |
| Block handling | is_last_step=False + fast tool → pin (skip free, TTL=2s); otherwise → free() |
| Job awareness | Full: tracks job_id, is_last_step, this_func_call, per-job history |
ContextServe extends Continuum's binary pin/no-pin decision into a three-way dispatch: FAST (strict pin), SLOW (proactive free), UNCERTAIN (fallback to Continuum TTL). It shares the same ContinuumRequestQueue and preemption logic.
# vllm/v1/core/sched/scheduler.py:1364-1388
if self.policy == SchedulingPolicy.CONTEXT_SERVE and not request.is_last_step:
    speed, pin_duration = self.tool_call_estimator.set_up_pin_contextserve(request)
    if speed == ToolSpeed.FAST:  # ls, cat, pwd, git status ...
        self.pin_request(request, pin_duration)  # lock in VRAM
    elif speed == ToolSpeed.SLOW:  # pytest, pip install, docker build ...
        self.kv_cache_manager.free(request)  # proactive free at t=0
    else:  # ToolSpeed.UNCERTAIN
        if pin_duration > 0.01:  # fallback to Continuum TTL
            self.pin_request(request, pin_duration)
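The three-way dispatch can be tried in isolation with a rule-based classifier. This is a sketch under assumptions: `FAST_COMMANDS`, `SLOW_PREFIXES`, `classify`, and `dispatch` are hypothetical and use only a few example commands, and an UNCERTAIN request with no usable fallback TTL is assumed to fall through to a normal free.

```python
from enum import Enum


class ToolSpeed(Enum):
    FAST = "fast"
    SLOW = "slow"
    UNCERTAIN = "uncertain"


# Hypothetical rule tables for illustration only.
FAST_COMMANDS = {"ls", "cat", "pwd"}
SLOW_PREFIXES = ("make", "pip install", "docker build")


def classify(command):
    """Rule-based layer: slow prefixes first, then known fast commands."""
    command = command.strip()
    if command.startswith(SLOW_PREFIXES):
        return ToolSpeed.SLOW
    first_word = command.split()[0] if command else ""
    if first_word in FAST_COMMANDS:
        return ToolSpeed.FAST
    return ToolSpeed.UNCERTAIN


def dispatch(command, fallback_ttl):
    """Return the action taken on the finished request's KV blocks."""
    speed = classify(command)
    if speed is ToolSpeed.FAST:
        return "pin"
    if speed is ToolSpeed.SLOW:
        return "free"
    # UNCERTAIN: pin only if the fallback TTL is meaningful.
    return "pin" if fallback_ttl > 0.01 else "free"


for cmd in ("ls -la", "pip install numpy", "python run.py"):
    print(cmd, "->", classify(cmd).name, dispatch(cmd, fallback_ttl=2.0))
```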
The ContextAwarePredictor (vllm/v1/core/context_predictor.py) uses a three-layer classification strategy:
| Layer | Method | Example |
|---|---|---|
| 1. Rule-based | Hardcoded sets: DEFINITELY_FAST (ls, cat, pwd), DEFINITELY_SLOW_PREFIXES (make, pip install, docker build) | ls → FAST 3.0s |
| 2a. Bayesian | Normal-Inverse-Gamma conjugate prior + Thompson Sampling (vllm/v1/core/bayesian_predictor.py) | Posterior quantile → TTL |
| 2b. XGBoost | 25-feature extraction (pipes, redirects, loops, etc.) → classification | Fallback if Bayesian fails |
| 3. Fallback | Returns UNCERTAIN + Continuum CDF TTL | — |
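For readers unfamiliar with layer 2a, the following sketch shows the textbook Normal-Inverse-Gamma conjugate update and a Thompson-style draw from the posterior. It is not taken from bayesian_predictor.py: the function names, the prior values, and the choice to model raw seconds (rather than, say, log-time) are all assumptions made for illustration.

```python
import math
import random


def nig_update(m0, kappa0, alpha0, beta0, xs):
    """Conjugate Normal-Inverse-Gamma update for observations xs."""
    n = len(xs)
    if n == 0:
        return m0, kappa0, alpha0, beta0
    mean = sum(xs) / n
    sq_dev = sum((x - mean) ** 2 for x in xs)
    kappa_n = kappa0 + n
    m_n = (kappa0 * m0 + n * mean) / kappa_n
    alpha_n = alpha0 + n / 2
    beta_n = beta0 + 0.5 * sq_dev + kappa0 * n * (mean - m0) ** 2 / (2 * kappa_n)
    return m_n, kappa_n, alpha_n, beta_n


def thompson_sample(m, kappa, alpha, beta, rng):
    """Draw one plausible mean execution time from the posterior."""
    # sigma^2 ~ InvGamma(alpha, beta) is equivalent to
    # 1/sigma^2 ~ Gamma(shape=alpha, scale=1/beta).
    precision = rng.gammavariate(alpha, 1.0 / beta)
    sigma2 = 1.0 / precision
    return rng.gauss(m, math.sqrt(sigma2 / kappa))


# Hypothetical prior (mean 1 s) updated with three observed tool times.
posterior = nig_update(1.0, 1.0, 2.0, 1.0, [0.1, 0.2, 0.3])
print(posterior)
print(thompson_sample(*posterior, rng=random.Random(0)))
```

Because a Gaussian draw can be negative, a real predictor would need to clamp the sample or work in log-space before turning it into a TTL.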
LMCache is an orthogonal strategy to Continuum. Instead of pinning blocks in GPU VRAM (which reduces capacity for other jobs), LMCache offloads evicted KV cache blocks to CPU DRAM or a remote server, then retrieves them when the same job returns. It integrates via the KVConnectorBase_V1 connector interface.
# vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py:23-27
class LMCacheConnectorV1(KVConnectorBase_V1):
    def __init__(self, vllm_config, role):
        self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self)
        # Delegates to lmcache.integration.vllm.vllm_v1_adapter
Two operational roles, split across scheduler and worker processes:
| Role | Location | Methods |
|---|---|---|
| SCHEDULER | vllm/v1/core/sched/scheduler.py:84-93 | get_num_new_matched_tokens(): query how many tokens are cached externally; update_state_after_alloc(): track block↔token mapping; build_connector_meta(): encode store/load plan for workers |
| WORKER | vllm/v1/worker/kv_connector_model_runner_mixin.py | start_load_kv(): async load from CPU before forward; wait_for_layer_load(layer): per-layer sync point; save_kv_layer(layer, tensor): async save to CPU after compute; wait_for_save(): ensure save completes |
## Store (after each forward step)
Attention layer → save_kv_layer(layer_name, kv_tensor, attn_metadata)
→ LMCache chunks KV by 256 tokens (configurable)
→ Stores to CPU DRAM (LMCACHE_LOCAL_CPU=True)
or remote server (LMCACHE_REMOTE_URL=lm://host:port)
## Retrieve (on new request for same prefix)
Scheduler → connector.get_num_new_matched_tokens(request)
→ LMCache matches prompt prefix hash against stored chunks
→ Returns count of externally cached tokens
Worker → connector.start_load_kv(forward_context)
→ Per-layer: wait_for_layer_load(layer_name)
→ Loaded KV appears in GPU paged buffer
→ Attention skips recomputation for loaded tokens
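The chunked prefix matching in the retrieve path can be illustrated with a hash chain over fixed-size chunks. This is a hypothetical sketch, not LMCache's actual hashing or storage code: `chunk_keys`, `num_matched_tokens`, and the use of SHA-256 over `repr()` are assumptions; only the 256-token chunk size and the prefix-match behavior come from the flow above.

```python
import hashlib

CHUNK_SIZE = 256  # tokens per chunk, matching LMCACHE_CHUNK_SIZE


def chunk_keys(token_ids, chunk_size=CHUNK_SIZE):
    """Hash each full chunk together with the hash of everything before it."""
    keys, prev = [], b""
    full = len(token_ids) - len(token_ids) % chunk_size
    for start in range(0, full, chunk_size):
        chunk = token_ids[start:start + chunk_size]
        digest = hashlib.sha256(prev + repr(chunk).encode()).digest()
        keys.append(digest)
        prev = digest
    return keys


def num_matched_tokens(token_ids, store, chunk_size=CHUNK_SIZE):
    """Count leading tokens whose chunks are all present in the store."""
    matched = 0
    for key in chunk_keys(token_ids, chunk_size):
        if key not in store:
            break
        matched += chunk_size
    return matched


turn1 = list(range(600))                 # 2 full chunks + 88 leftover tokens
store = set(chunk_keys(turn1))           # "save" after turn 1
turn2 = turn1 + list(range(1000, 1300))  # same prefix plus 300 new tokens
print(num_matched_tokens(turn2, store))
```

Only whole chunks are matched in this sketch, so the 88 leftover tokens from turn 1 are recomputed in turn 2 even though they were seen before. Smaller chunks reduce that waste at the cost of more keys to hash and look up.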
# CLI: pass via JSON to --kv-transfer-config
vllm serve MODEL --kv-transfer-config '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_both"}'
# Environment variables (set before launch)
export LMCACHE_USE_EXPERIMENTAL="True"
export LMCACHE_LOCAL_CPU="True" # enable CPU DRAM backend
export LMCACHE_MAX_LOCAL_CPU_SIZE="5.0" # 5 GB CPU memory limit
export LMCACHE_CHUNK_SIZE="256" # tokens per chunk
| | FCFS | Continuum | LMCache |
|---|---|---|---|
| Strategy | No KV retention — evicted blocks immediately available | Pin in GPU VRAM — blocks reserved for returning job | Offload to CPU DRAM — blocks freed on GPU, cached on CPU |
| GPU capacity impact | None (100% available) | Reduces available blocks (pinned blocks are locked) | None (blocks freed on GPU, stored on CPU) |
| Resume cost | Full re-prefill (recompute all tokens) | Zero (KV still in GPU, same addresses) | PCIe transfer (CPU→GPU, ~12 GB/s) |
| Where it shines | Low load; prefix caching handles reuse | Small GPU, high load, fast tools | Large GPU, high load, abundant CPU DRAM |
| Key files | request_queue.py:85 | request_queue.py:221, estimate_with_func.py:151, scheduler.py:239 | lmcache_connector.py, kv_transfer.py, attention/layer.py:445 |
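The "Resume cost" row can be made concrete with back-of-the-envelope arithmetic for the model used in this test. The sketch assumes the public Llama-3.1-8B attention shape (32 layers, 8 KV heads, head dimension 128), a 16-bit KV cache, and that the ~12 GB/s PCIe figure from the table is actually sustained; the function names are hypothetical.

```python
NUM_LAYERS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
BYTES_PER_VALUE = 2      # assumes an fp16/bf16 KV cache
PCIE_BYTES_PER_S = 12e9  # the ~12 GB/s figure from the table


def kv_bytes(num_tokens):
    """Bytes of KV cache for num_tokens (keys and values, all layers)."""
    per_token = 2 * NUM_LAYERS * NUM_KV_HEADS * HEAD_DIM * BYTES_PER_VALUE
    return num_tokens * per_token


def reload_seconds(num_tokens):
    """Idealized CPU→GPU transfer time, ignoring latency and overlap."""
    return kv_bytes(num_tokens) / PCIE_BYTES_PER_S


print(kv_bytes(1))                 # bytes per token
print(kv_bytes(16) / 2**20)        # MiB per 16-token block
print(round(reload_seconds(4096) * 1000, 1), "ms for a 4096-token context")
```

Under these assumptions a full 4096-token context is roughly 512 MiB and reloads in tens of milliseconds, which is why offloading can compete with pinning when recomputing the same prefix would take longer.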
How each strategy handles the KV cache during tool-call gaps under high memory pressure (KV cache >90%), first for a short tool call and then for a long one.
Short tool execution:
FCFS: KV freed immediately after inference. Under high load, blocks are evicted by other requests → full recompute on return.
request finished → _free_blocks() → kv_cache_manager.free() → blocks → free queue
→ blocks immediately available for any new request (no pin)
Continuum: KV pinned in VRAM (ref_cnt stays > 0). Tool returns before TTL expires → cache hit, zero recompute.
request finished → _free_blocks():
if not is_last_step + fast tool → pin_request(2.0s) → blocks RESERVED (skip free)
→ next turn → pinned-job gets scheduling priority → cache hit
LMCache: KV offloaded to CPU DRAM asynchronously, GPU freed. On tool return, reload from DRAM (PCIe transfer).
forward → per-layer: save_kv_layer() → async copy to CPU DRAM
→ wait_for_save() → free() → GPU blocks released
→ next turn → connector.get_num_new_matched_tokens() → skip prefill for cached portion
→ start_load_kv() → per-layer wait_for_layer_load()
Long tool execution:
FCFS: KV freed immediately. Blocks evicted → full recompute on return.
request finished → _free_blocks() → kv_cache_manager.free() → blocks → free queue
→ long tool: blocks overwritten by others, full recompute on return
Continuum: KV pinned, but TTL expires before tool returns → blocks unpinned and evicted → must recompute.
pin_request(2.0s) → TTL counting...
→ unpin_requests_regular() → time.time() >= end_time → unpin + free()
→ blocks evicted before tool returns; degrades to FCFS behavior
LMCache: KV safe in DRAM regardless of tool duration. On return, reload from DRAM, so no recompute is needed.
per-layer: save_kv_layer() → async GPU→CPU copy → free()
→ long tool: KV safe on CPU, GPU fully available for others
→ next turn → start_load_kv() → reload from DRAM (skip recompute)
When the KV cache is nearly full (~95%) and a new request arrives that needs blocks, each strategy must decide what to evict. The walkthrough below uses 3 concurrent jobs and only 8 free blocks.
FCFS — Evict last-added running request
① Initial state: 8 free blocks, need 12 → must evict 4 blocks
② FCFS evicts Job A's old blocks (freed, no pin), but A needs 12 blocks — still not enough. Evict Job C (last-added to running).
③ Result: Job A gets 12 blocks, but must recompute all tokens (its old blocks were freed). Job C preempted → sent back to waiting queue head, also needs full recompute later.
Continuum — Pin protects, evict newest non-last-step
① Initial state: Job A's blocks are PINNED (ref_cnt > 0). Same 8 free blocks.
② Job A gets pinned-job priority. Unpin old A blocks (they'll be reused as prefix cache hit). Need 4 more → evict Job C (newest non-last-step job).
③ Result: Job A reuses 4 pinned blocks as cache hit + allocates 8 new blocks. Only 8 blocks need computation (not 12). Job C preempted.
LMCache — Blocks in DRAM, GPU fully available
① Initial state: Job A's KV was offloaded to DRAM — its GPU blocks were already freed. 8 free blocks + A's 4 freed blocks = 12 available.
② Job A returns. Reload 4 blocks from DRAM (PCIe transfer), allocate 8 new blocks. No eviction needed — 12 free blocks available!
Three policies were tested across three GPUs at varying load levels. Key findings:
| GPU | VRAM | ~KV Blocks | Best Strategy at High Load | Key Observation |
|---|---|---|---|---|
| RTX 6000 | 24 GB | ~3,000 | Continuum | 71% faster at JPS=1.0; 26% faster at JPS=12.0 — pin prevents eviction on small VRAM |
| A100 | 40 GB | ~9,000 | LMCache | 39-65% faster than FCFS at JPS=10+. Continuum provides no benefit — sometimes slower |
| L40S | 48 GB | ~13,000 | LMCache | 39-60% faster at JPS=10+. Same pattern — large GPU favors DRAM offload over VRAM pin |
End-to-end trace from ab_benchmark.py request through every vLLM layer to HTTP response. Shows how job_id / is_last_step propagate and where each strategy diverges.
The key challenge is propagating job_id and is_last_step from the client's extra_body through the standard OpenAI-compatible API, past vLLM's multi-layer engine architecture, all the way into the scheduler's Request object — without modifying the OpenAI protocol.
Source locations for the request path, in order: ab_benchmark.py:329; vllm/entrypoints/openai/api_server.py:688; protocol.py:635-638; protocol.py:730-737; serving_chat.py:300 → async_llm.py:289 → processor.py:424; vllm/v1/request.py:146-152; scheduler.py:1276-1290.
The scheduler runs in a loop: pick the next request from the waiting queue → allocate KV cache blocks → if there are not enough blocks, preempt someone → send to GPU. The strategies differ in who gets picked and who gets evicted.
Source locations for the scheduling loop, in order: scheduler.py:271; kv_cache_manager.py:192; gpu_model_runner.py:2064; kv_connector_model_runner_mixin.py:44,112,117; gpu_model_runner.py:2118.
When the model finishes generating (EOS or max_tokens), the scheduler must decide what to do with the KV cache blocks. This is where the three strategies fundamentally differ, all in one function: _free_blocks().
Source locations for completion and response, in order: scheduler.py:1042; scheduler.py:1356-1399; estimate_with_func.py:151 set_up_pin(); context_predictor.py predict(); core.py:766 output_queue → ZMQ socket; async_llm.py:427 _run_output_handler() → per-request queue; serving_chat.py → ChatCompletionResponse (detokenize, format as OpenAI JSON); api_server.py:703 → JSON HTTP 200; ab_benchmark.py:335.
| Stage | Variable Form | File |
|---|---|---|
| ① | extra_body={"job_id": "job_0042", "is_last_step": False} | ab_benchmark.py:329 |
| ③ | ChatCompletionRequest.job_id = "job_0042" | protocol.py:635 |
| ④ | SamplingParams.extra_args = {"job_id": .., "is_last_step": ..} | protocol.py:730 |
| ⑦ | Request.job_id = "job_0042", Request.is_last_step = False | request.py:146 |
| ⑧ | waiting.add_request(request) + job history recorded | scheduler.py:1287 |
| ⑭ | is_last_step=False + fast tool → pin_request(2.0s) | scheduler.py:1391 |
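The first stages of this table can be mimicked with plain dictionaries to show why the metadata must arrive as top-level JSON keys. `build_payload` and `to_extra_args` are hypothetical helpers, not vLLM or OpenAI client functions; they only model the merge and extraction steps the table describes.

```python
def build_payload(model, messages, extra_body=None):
    """Mimic the client: extra_body keys are merged into the top-level body."""
    payload = {"model": model, "messages": messages}
    payload.update(extra_body or {})
    return payload


def to_extra_args(payload):
    """Mimic the server: copy the Continuum fields into extra_args if present."""
    extra_args = {}
    for key in ("job_id", "is_last_step"):
        if key in payload:
            extra_args[key] = payload[key]
    return extra_args


messages = [{"role": "user", "content": "List files with ls."}]
meta = {"job_id": "job_0042", "is_last_step": False}

merged = build_payload("MODEL", messages, extra_body=meta)
nested = {"model": "MODEL", "messages": messages, "extra_body": meta}

print(to_extra_args(merged))  # both fields reach the scheduler side
print(to_extra_args(nested))  # nested dict: nothing is picked up
```

The second case models a raw HTTP POST that sends `extra_body` as a nested object: the fields are never seen at the top level, so the request is scheduled without job metadata.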
Three key scripts drive all experiments: the policy benchmark, KV pin verifier, and JPS sweep plotter.
ab_benchmark.py (547 lines) simulates a realistic multi-turn agentic workload with Poisson-distributed job arrivals. It runs one policy per invocation; to compare policies, run it several times with different --policy values.
# CLI usage
python3 ab_benchmark.py --policy continuum --port 8199 --jps 8.0 --duration 90 --turns 8
# Core: Poisson arrival process (line 381)
def poisson_arrival(jps, duration_s, base_url, model_id, n_turns):
    gap = random.expovariate(jps)  # exponential inter-arrival time
    time.sleep(gap)
    # spawn job as thread with job_id = f"job_{count:04d}"
# Core: how Continuum metadata is passed (line 329)
resp = client.chat.completions.create(
    model=model_id,
    messages=messages,  # accumulated context (grows each turn)
    max_tokens=150,
    extra_body={
        "job_id": job_id,
        "is_last_step": (i == n_turns - 1),  # True only on final turn
    },
)
# Tool execution time per turn (line 276-286)
TOOL_TIME_DISTRIBUTIONS = [
    (0.05, 0.15),  # Turn 1: find/ls (fast)
    (0.05, 0.10),  # Turn 2: cat (fast)
    (0.05, 0.10),  # Turn 3: cat (fast)
    (0.08, 0.20),  # Turn 4: grep (fast)
    (2.0, 5.0),    # Turn 5: pytest (SLOW, long-tailed)
    (0.05, 0.10),  # Turn 6: cat (fast)
    (0.10, 0.30),  # Turn 7: patch (fast)
]
# Output JSON schema (line 516-538)
{
    "policy": "continuum",
    "jps": 8.0,
    "avg_duration_s": 12.465,
    "p90_duration_s": 14.371,
    "p95_duration_s": 14.629,
    "per_turn_avg_latency_ms": {"1": 1051, "2": 519, ...},
    "per_turn_avg_prompt_tokens": {"1": 92, "2": 469, ...},
    "job_durations": [6.185, 6.607, ...]
}
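The Poisson arrival process can also be generated as a list of offsets without sleeping, which is handy for checking a load level before a long run. `arrival_times` is a hypothetical helper; it uses the same `random.expovariate(jps)` inter-arrival gap as the excerpt above.

```python
import random


def arrival_times(jps, duration_s, seed=None):
    """Return job start offsets (seconds) for a Poisson process with rate jps."""
    rng = random.Random(seed)
    t, times = 0.0, []
    while True:
        t += rng.expovariate(jps)  # exponential gap with mean 1/jps seconds
        if t >= duration_s:
            return times
        times.append(t)


times = arrival_times(jps=8.0, duration_s=90.0, seed=0)
print(len(times), "jobs; expected about", int(8.0 * 90.0))
print("first gaps:", [round(b - a, 3) for a, b in zip(times, times[1:4])])
```

Seeding the generator makes a run reproducible, which helps when comparing policies on the same arrival pattern.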
The KV pin verifier (102 lines) checks Continuum's pin mechanism by polling vllm:kv_cache_usage_perc from the Prometheus /metrics endpoint during tool execution gaps.
# Parse Prometheus metric (line 36-41)
def get_kv_usage():
    r = req.get(f"{BASE}/metrics", timeout=5)
    for line in r.text.split("\n"):
        if line.startswith("vllm:kv_cache_usage_perc{"):
            return float(line.split()[-1])
# Verification: poll every 0.3s during 1s tool gap (line 84-95)
if not is_last:
print(" Simulating tool execution (1s) — KV should stay pinned...")
for t in [0.3, 0.6, 0.9]:
time.sleep(0.3)
u = get_kv_usage()
print(f" t+{t:.1f}s: KV usage = {u:.6f}")
# Expected: u ≈ usage_post (constant = pinned)
# Failure: u drops → blocks evicted, pin broken
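The parsing and the pass/fail rule can be tested offline against a captured metrics snippet. This is a sketch: `parse_gauge`, `pin_held`, the tolerance, and the sample text (including its label value) are hypothetical, and the check assumes a single active job so that no other request changes the usage during the gap.

```python
def parse_gauge(metrics_text, name):
    """Return the first sample value for metric `name`, or None if absent."""
    for line in metrics_text.splitlines():
        if line.startswith("#"):
            continue  # skip HELP / TYPE comment lines
        if line.startswith(name + "{") or line.startswith(name + " "):
            # Assumes no trailing timestamp, so the value is the last field.
            return float(line.split()[-1])
    return None


def pin_held(usage_post, samples, tolerance=1e-6):
    """True if usage never dropped below the post-inference level."""
    return all(s >= usage_post - tolerance for s in samples)


sample = """# HELP vllm:kv_cache_usage_perc KV-cache usage.
# TYPE vllm:kv_cache_usage_perc gauge
vllm:kv_cache_usage_perc{model_name="llama"} 0.017401
"""

usage = parse_gauge(sample, "vllm:kv_cache_usage_perc")
print(usage, pin_held(usage, [usage, usage, usage]))
```

Returning `None` for a missing metric, rather than falling off the end of the loop silently, makes it easier to distinguish "server not exporting the gauge" from "usage is zero".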
plot_jps_sweep.py (153 lines) reads JSON results from 3 GPUs × 3 policies and generates publication-quality comparison figures.
# GPU configurations (line 13-17)
GPUS = {
    "rtx_6000": {"name": "RTX 6000 (24 GB)", "jps": [1.0, 2.0, 5.0, 8.0, 12.0]},
    "A100": {"name": "A100 (40 GB)", "jps": [2.0, 5.0, 10.0, 15.0, 25.0, 40.0]},
    "L40S": {"name": "L40S (48 GB)", "jps": [2.0, 5.0, 10.0, 15.0, 25.0]},
}
# Speedup formula (line 124)
speedup = (fcfs["avg_duration_s"] - other["avg_duration_s"]) / fcfs["avg_duration_s"] * 100
# Generates 3 figure types:
# 1. jps_sweep_{gpu}.png — per-GPU: avg duration vs JPS for 3 policies
# 2. jps_sweep_combined.png — 3-subplot side-by-side comparison
# 3. jps_sweep_speedup.png — % improvement over FCFS baseline
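A worked example of the speedup formula helps when reading the percentages. The durations below are made up for illustration and are not benchmark results; `speedup_pct` is a hypothetical wrapper around the formula above.

```python
def speedup_pct(fcfs_avg_s, other_avg_s):
    """Percent reduction in average job duration relative to FCFS."""
    return (fcfs_avg_s - other_avg_s) / fcfs_avg_s * 100


# Hypothetical durations in seconds.
print(round(speedup_pct(20.0, 12.0), 1))  # policy finishes jobs faster
print(round(speedup_pct(20.0, 25.0), 1))  # policy is slower: negative value
```

The metric is a reduction in duration, not a throughput ratio: a value of 50 means jobs take half as long, and values can never exceed 100.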
| Component | File | Lines |
|---|---|---|
| Policy enum + queue factory | vllm/v1/core/sched/request_queue.py | 17-353 |
| FCFSRequestQueue | vllm/v1/core/sched/request_queue.py | 85-140 |
| ContinuumRequestQueue | vllm/v1/core/sched/request_queue.py | 221-340 |
| Scheduler core | vllm/v1/core/sched/scheduler.py | 116-130, 197-264, 370-418, 1356-1399 |
| ToolCallEstimator (TTL) | vllm/v1/core/estimate_with_func.py | 95-228 |
| ContextAwarePredictor | vllm/v1/core/context_predictor.py | 43-268 |
| BayesianPredictor | vllm/v1/core/bayesian_predictor.py | 39-150 |
| LMCache connector | vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py | 1-168 |
| Connector factory | vllm/distributed/kv_transfer/kv_connector/factory.py | 95-98 |
| KV transfer config | vllm/config/kv_transfer.py | 1-112 |
| API protocol (job_id extraction) | vllm/entrypoints/openai/protocol.py | 294-296, 364-384 |
| Request dataclass | vllm/v1/request.py | 23-46, 136-175 |
| CLI argument | vllm/engine/arg_utils.py | 875-876 |
| Policy benchmark script | our_experiments/scripts/ab_benchmark.py | 547 |
| JPS sweep plotter | our_experiments/scripts/plot_jps_sweep.py | 153 |
The tool detection → time estimation → KV pin → priority scheduling pipeline works correctly on the happy path. Metrics confirm blocks are not freed during the pin window.
ToolCallParser relies on a ```bash\n...\n``` regex. Thinking models (Qwen3) break this entirely: the parse rate drops to 2/10, versus 10/10 with the expected format. The system is designed for the Llama-family output format.
The repo uses a fixed 2-second threshold (FIXED_THRESHOLD_CONTINUUM = 2.0), not the CDF-based estimation described in the paper. If historical exec time > 2s → don't pin. Otherwise → pin exactly 2s.
No defensive handling for job_id=None, concurrent same-job requests, or thinking models. The code works on the author's exact setup (Llama 70B + mini-swe-agent + sequential turns) but breaks on deviations.
A/B comparison (FCFS vs Continuum latency under load), concurrent multi-job pressure test, scheduler_timestamps hit_length analysis, and SWE-bench end-to-end pass rate.