LMCache KV Transfer Pipeline — Internal Architecture

Technical deep dive: slot_mapping.cuda(), CUDA stream synchronization, nsys profiling, and wait_for_save bottleneck analysis.

Exp C (nsys profiling): wait_for_save 115.7 ms mean per prefill step; cudaStreamSynchronize 96.6% of wallclock; vLLM v1 + LMCache v0.3.7.

Note: This page covers the technical internals of LMCache's KV transfer mechanism. For the latest Exp G results, see bottleneck-analysis.html.

1. Root Cause: Why LMCache Is Slower

Section §2 established that LMCache adds +2.73 s per job relative to FCFS under the benchmark conditions (H200, JPS=1, Exp C). This section traces that overhead to a single, quantifiable mechanism: excessive CUDA stream synchronization inside wait_for_save, which blocks the vLLM scheduler main loop on every prefill step.

1.1 The Cost: wait_for_save Blocks the Scheduler

vLLM's KV Connector API requires the scheduler to call wait_for_save before dispatching each step. NVTX profiling of 51 completed jobs (H200, JPS=1, Exp C) captured 1,769 invocations. The distribution is strongly bimodal: a small minority of prefill-step calls account for almost all wallclock time.

Call Type Count % of Calls Mean Median p95 Syncs/call Memcpy/call Source
Prefill-step 159 9% 115.71 ms 73.00 ms 253 ms ~24 ~19 H200, JPS=1, Exp C
Decode-only 1,610 91% 0.01 ms 0.01 ms 0.02 ms 0 0 H200, JPS=1, Exp C
Total 1,769 100% 3,744 (≈159 × 23.5) 2,996 (≈159 × 18.8)
19 memcpy/call = one cudaMemcpyAsync per KV chunk (new tokens since last turn ÷ chunk_size). 24 syncs/call > 19 memcpy/call because ~5 additional cudaStreamSynchronize per step originate from metadata operations (token hash, buffer allocation, store bookkeeping) that do not involve DMA.
Key insight: Decode-only calls complete in ~0.01 ms — essentially zero cost. All measurable overhead is concentrated in the 9% of calls that coincide with a prefill step. The question is why prefill-step calls take ~115 ms.
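The concentration claim can be checked from the table's counts and means. The sketch below is a back-of-the-envelope calculation that assumes each class's total time is simply count × mean; the variable names are illustrative and do not come from the profiling scripts.

```python
# Counts and mean durations from the wait_for_save table (H200, JPS=1, Exp C).
calls = {
    "prefill_step": {"count": 159, "mean_ms": 115.71},
    "decode_only": {"count": 1610, "mean_ms": 0.01},
}

# Approximate the total time per class as count x mean.
total_ms = {name: c["count"] * c["mean_ms"] for name, c in calls.items()}
grand_total_ms = sum(total_ms.values())

total_calls = sum(c["count"] for c in calls.values())
prefill_call_share = calls["prefill_step"]["count"] / total_calls
prefill_time_share = total_ms["prefill_step"] / grand_total_ms

print(f"prefill-step calls: {prefill_call_share:.1%} of calls")
print(f"prefill-step time:  {prefill_time_share:.2%} of wait_for_save time")
```

Under that assumption, roughly 9% of the calls carry more than 99.9% of the time spent inside wait_for_save.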

1.2 Inside wait_for_save: cudaStreamSynchronize

cudaStreamSynchronize(stream) is a CUDA Runtime API call that blocks the calling CPU thread until every operation previously queued on the given stream has completed on the GPU. Internally the call performs a user-to-kernel transition (via ioctl to the NVIDIA kernel driver), checks stream completion status, sleeps if the GPU is still working (using futex or driver poll), and resumes on GPU interrupt. Even when the GPU finishes its DMA in 0.67 ms (measured via NVTX from_gpu range, see §1.3), the round-trip through the kernel — transition + interrupt latency + wakeup — adds ~3–4 ms of fixed overhead per call on PACE nodes.

The Syncs/call and Memcpy/call columns above link directly to the CUDA Runtime totals in the table below (3,744 = 159 × ~24; 2,996 = 159 × ~19).

CUDA Runtime API Call Count Total Time Per-Call Mean % of Wallclock Source
cudaStreamSynchronize 3,744 17,764 ms 4.74 ms 96.6% H200, JPS=1, Exp C
cudaMemcpyAsync (launch) 2,996 24.5 ms 8.2 μs 0.1% H200, JPS=1, Exp C
cudaLaunchKernel 2,961 22.2 ms 7.5 μs 0.1% H200, JPS=1, Exp C
All other CUDA Runtime 3.2% H200, JPS=1, Exp C

The derivation table below shows that the measured 115.71 ms per prefill step is almost entirely explained by 24 sync calls × 4.74 ms each, leaving less than 2% unexplained residual.

Quantity Value Derivation Source
Total sync calls in 159 prefill windows 3,744 measured H200, JPS=1, Exp C
Syncs per prefill step ~23.5 ≈ 24 3,744 ÷ 159 derived
Mean cost per sync call 4.74 ms 17,764 ms ÷ 3,744 H200, JPS=1, Exp C
Predicted wait_for_save (prefill) 113.8 ms 24 × 4.74 ms derived
Measured wait_for_save (prefill mean) 115.71 ms NVTX direct measurement H200, JPS=1, Exp C
Residual (unexplained) +1.91 ms (1.6%) 115.71 − 113.8 derived
Caveat on the 4.74 ms mean: it is a misleading weighted average of two populations: per-chunk syncs (2,962 calls × 0.59 ms = 1,742 ms, 9.5% of wait_for_save time) and per-request syncs (391 calls × 40.97 ms = 16,021 ms, 87% of wait_for_save time). The per-request syncs occur once per request, fall outside all NVTX ranges, and likely originate in batched_from_gpu() or StorageManager.batched_put(). They dominate the total cost.
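The derivation can be reproduced in a few lines. This is a sketch using only the measured totals quoted in the tables; it computes the prediction both with the rounded inputs (24 syncs × 4.74 ms) and with unrounded inputs, because the rounding changes the size of the residual.

```python
# Measured CUDA Runtime totals (H200, JPS=1, Exp C).
total_sync_calls = 3744
total_sync_ms = 17764.0
prefill_steps = 159
measured_wfs_ms = 115.71

syncs_per_step = total_sync_calls / prefill_steps   # about 23.5
mean_sync_ms = total_sync_ms / total_sync_calls     # about 4.74 ms

# Rounded inputs, as used in the derivation table.
predicted_rounded_ms = 24 * 4.74
# Unrounded inputs: algebraically the same as total_sync_ms / prefill_steps.
predicted_exact_ms = syncs_per_step * mean_sync_ms

residual_rounded_ms = measured_wfs_ms - predicted_rounded_ms
residual_exact_ms = measured_wfs_ms - predicted_exact_ms

print(f"rounded:   {predicted_rounded_ms:.1f} ms, residual {residual_rounded_ms:.2f} ms")
print(f"unrounded: {predicted_exact_ms:.1f} ms, residual {residual_exact_ms:.2f} ms")
```

With unrounded inputs the prediction is about 111.7 ms and the residual about 4 ms (roughly 3.4%), so the sub-2% figure depends on rounding 23.5 up to 24 syncs per step.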

Root Cause (corrected)

CUDA Stream Definitions
Stream Created at Role GPU Activity (Exp C)
Default torch.cuda.current_stream() vLLM forward pass 249,655 kernels (FlashAttn, GEMM, Triton)
Store gpu_connector.py:165 LMCache KV offload 2,961 lmc::load_and_reshape + 2,604 D2H memcpy
Load gpu_connector.py:166 LMCache KV onload 50 kernels
Overhead Decomposition (115.71 ms per prefill step)
Key insight: ~100 ms is unavoidable forward-pass time, not extra overhead

wait_for_save() is called at the start of the next prefill step, while the previous step's forward pass is still running. The 100.75 ms GC sync simply waits for GPU compute that would run anyway. It is NOT a GPU idle gap — the GPU is fully busy with forward-pass kernels throughout. The real overhead is the ~15 ms of actual offload work (lmc::load_and_reshape + D2H DMA + store_stream syncs + Python bookkeeping) that runs inside wait_for_save() after the forward pass completes.

Component | Code Location | ~Count/step | Avg per call | ms/step | Nature
Wait for forward pass (GC sync) | vllm_v1_adapter.py:1055 | 1 | 100.75 ms | 100.75 | Unavoidable: GPU busy with forward pass
lmc::load_and_reshape kernel | gpu_connector.py:279 | ~19 | 0.054 ms | ~1.0 | Real offload work (~15 ms total)
D2H DMA (MEMCPY D2H) | gpu_connector.py:279 | ~19 | 0.626 ms | ~11.9 | Real offload work
store_stream.synchronize() | gpu_connector.py:309 | ~19 | 0.59 ms | ~11.2 | Real offload work
Python bookkeeping (hash, alloc, metadata) | cache_engine.py, storage_manager.py | | | ~4.0 | Real offload work
Measured wait_for_save mean (Exp C, H200 JPS=1) | | | | 115.71 ms |
Real overhead = ~15 ms/step × 159 steps ≈ 2,385 ms ≈ measured 2.73 s/job overhead. The 100.75 ms forward-pass wait is not overhead — it would have been paid regardless of LMCache.
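The per-row figures in the decomposition can be recomputed from the counts and per-call averages. The sketch below does that and compares a naive sum of the offload rows against the ~15 ms obtained by subtracting the forward-pass wait from the measured mean. The dictionary layout is illustrative only.

```python
# (count per step, mean ms per call) for the per-chunk rows of the table.
per_chunk = {
    "lmc::load_and_reshape": (19, 0.054),
    "D2H DMA": (19, 0.626),
    "store_stream.synchronize()": (19, 0.59),
}
ms_per_step = {name: n * avg for name, (n, avg) in per_chunk.items()}

measured_wfs_ms = 115.71
forward_wait_ms = 100.75
offload_ms = measured_wfs_ms - forward_wait_ms   # about 15 ms

# Naive sum of the offload rows, including ~4.0 ms of Python bookkeeping.
row_sum_ms = sum(ms_per_step.values()) + 4.0

print(f"offload by subtraction: {offload_ms:.2f} ms")
print(f"offload by row sum:     {row_sum_ms:.1f} ms")
```

The naive row sum (about 28 ms) is larger than the ~15 ms difference, so the rows are evidently not additive. One plausible reading, which is an assumption here rather than a measured result, is that each store_stream.synchronize() is the CPU waiting on the same D2H copy, so those two rows overlap in time.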
Upstream fix: vLLM #37160

Upstream PR #37160 makes wait_for_save() a no-op by deferring all offload work to get_finished(), which runs after model execution. This hides the ~15 ms real offload work behind GPU compute, eliminating the serial overhead entirely. See §1.4 for experimental confirmation that sync-pattern patches alone cannot fix this.
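What "hiding the offload behind GPU compute" buys can be illustrated with an idealized timing model. This is a toy calculation that assumes perfect overlap and no contention between the offload and the forward pass; it does not describe the internals of the upstream PR.

```python
forward_ms = 100.75   # forward-pass time the scheduler waits on anyway
offload_ms = 14.96    # real offload work per prefill step (115.71 - 100.75)
prefill_steps = 159


def step_ms(overlap: bool) -> float:
    """Per-step cost when offload is serialized vs. fully overlapped."""
    if overlap:
        return max(forward_ms, offload_ms)
    return forward_ms + offload_ms


saved_per_step_ms = step_ms(overlap=False) - step_ms(overlap=True)
saved_total_ms = saved_per_step_ms * prefill_steps

print(f"saved per prefill step: {saved_per_step_ms:.2f} ms")
print(f"saved over {prefill_steps} steps: {saved_total_ms:.0f} ms")
```

Because the offload is much shorter than the forward pass, the idealized overlapped step costs only the forward pass, and the full ~15 ms per prefill step is recovered.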

1.3 The Hardware Floor: DMA Is Not the Bottleneck

The GPU-side DMA transfer (MEMCPY D2H) was measured separately via Nsight Systems GPU-activity records across 4 GPU families and 3 JPS values (12 measurements, Exp C). These measurements establish the minimum achievable time for the actual data movement, independent of CPU-side synchronization overhead.

GPU PCIe Gen JPS=1 JPS=6 JPS=10 Mean ± CV Source
H200 Gen 5 0.670 ms 0.673 ms 0.675 ms 0.674 ms ± 0.3% Exp C
H100 Gen 5 0.662 ms 0.683 ms 0.679 ms 0.675 ms ± 0.3% Exp C
A100 Gen 4 1.420 ms 1.425 ms 1.415 ms 1.420 ms ± 3.6% Exp C
L40S Gen 4 1.358 ms 1.408 ms 1.343 ms 1.370 ms ± 3.5% Exp C
Gen5 : Gen4 ratio 2.07× (matches PCIe Gen5 : Gen4 spec bandwidth ratio) 4 GPUs, Exp C

On the H200, GPU-side DMA completes in 0.674 ms. The per-call cudaStreamSynchronize cost is 4.74 ms. Therefore the DMA itself accounts for only 0.674 ÷ 4.74 = 14.2% of each sync call; the remaining 85.8% is fixed kernel-mode overhead. The table below itemises this split.

Quantity H200 H100 Interpretation
GPU DMA duration (hardware) 0.674 ms 0.674 ms PCIe bandwidth floor
cudaStreamSynchronize total cost 4.74 ms 4.74 ms measured NVTX (Exp C)
DMA as fraction of sync cost 14.2% 14.2% hardware is NOT the bottleneck
Kernel-mode overhead (fixed) 85.8% 85.8% ioctl + interrupt + wakeup latency
Cross-validation: The DMA floor is stable to within ±3.6% across JPS=1/6/10 on all four GPUs, confirming that the bottleneck is not memory bandwidth saturation or contention. Upgrading PCIe bandwidth (Gen4 → Gen5) halves the DMA time but provides no relief against the per-call kernel overhead, which remains ~4.74 ms regardless.
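The Gen5 : Gen4 ratio and the DMA share of the sync cost can be recomputed from the per-JPS measurements in the DMA table. This is a sketch that averages the three JPS values per GPU and then averages the two GPUs in each PCIe generation; the original analysis may have aggregated differently.

```python
from statistics import mean

# GPU-side MEMCPY D2H durations in ms at JPS = 1, 6, 10 (Exp C).
dma_ms = {
    "H200": [0.670, 0.673, 0.675],
    "H100": [0.662, 0.683, 0.679],
    "A100": [1.420, 1.425, 1.415],
    "L40S": [1.358, 1.408, 1.343],
}
pcie_gen = {"H200": 5, "H100": 5, "A100": 4, "L40S": 4}

per_gpu_mean = {gpu: mean(values) for gpu, values in dma_ms.items()}
gen5_mean = mean(m for gpu, m in per_gpu_mean.items() if pcie_gen[gpu] == 5)
gen4_mean = mean(m for gpu, m in per_gpu_mean.items() if pcie_gen[gpu] == 4)
gen_ratio = gen4_mean / gen5_mean

# Share of the mean per-call sync cost covered by the DMA itself on H200.
dma_fraction = 0.674 / 4.74

print(f"Gen4 / Gen5 DMA time: {gen_ratio:.2f}x")
print(f"DMA share of sync cost: {dma_fraction:.1%}")
```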

1.4 Optimization Path: Why Sync Patches Failed

Two patches targeting synchronization patterns (Exp D and Exp F) were tested to validate the original (incorrect) root cause hypothesis. Both produced negative results — not because the patches were poorly implemented, but because the ~15 ms real offload work is actual GPU computation, not a sync overhead artifact that can be removed by rearranging sync calls.

Exp D — Coalesce store_stream Sync (Negative Result)

Removed per-chunk store_stream.synchronize() in gpu_connector.py:from_gpu(), retaining only the single sync at the end of batched_from_gpu(). Tested on H200, H100, A100 (same workload as Exp A, JPS=1 and JPS=6, control vs patched sequentially).

GPU Config JPS=1 avg JPS=1 jobs JPS=6 avg JPS=6 jobs Source
H200 control 9.03 s 48 236.0 s 263 Exp D
coalesced 12.40 s (+37%) 66 262.2 s (+11%) 283
H100 control 8.92 s 47 276.0 s 257 Exp D
coalesced 10.48 s (+17%) 55 283.7 s (+3%) 272
A100 control 69.36 s 46 1294.7 s 270 Exp D
coalesced 90.28 s (+30%) 48 1249.0 s (−4%) 264
Negative result: removing per-chunk sync in gpu_connector.from_gpu() alone did NOT reduce per-job latency. At JPS=1, latency increased +17–37% across all GPUs.

Exp F — stream.query() Guard (Negative Result)

Added a store_stream.query() guard before each store_stream.synchronize() in gpu_connector.py, so the sync is skipped when the stream is already idle. Results were also negative: JPS=1 latency increased on both GPUs tested.

GPU Config JPS=1 avg Change Source
H200 control 6.86 s Exp F
stream.query() guard 8.29 s +21%
A100 control 27.88 s Exp F
stream.query() guard 34.85 s +25%
Negative result: skipping store_stream syncs when idle did NOT reduce per-job latency.
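The guard pattern tested in Exp F can be sketched without a GPU. The FakeStream class below is a hypothetical stand-in that mimics only the two methods the guard needs; with PyTorch, torch.cuda.Stream exposes query() and synchronize() with the same meaning. The helper name guarded_sync is illustrative and is not taken from the patch.

```python
class FakeStream:
    """Stand-in for a CUDA stream; counts blocking syncs (illustrative only)."""

    def __init__(self, busy: bool):
        self.busy = busy
        self.sync_calls = 0

    def query(self) -> bool:
        # Same meaning as torch.cuda.Stream.query(): True when all work is done.
        return not self.busy

    def synchronize(self) -> None:
        self.sync_calls += 1
        self.busy = False


def guarded_sync(stream) -> bool:
    """Block only if the stream has pending work; return True if it blocked."""
    if stream.query():
        return False
    stream.synchronize()
    return True


idle, busy = FakeStream(busy=False), FakeStream(busy=True)
blocked_idle = guarded_sync(idle)   # stream already idle: sync skipped
blocked_busy = guarded_sync(busy)   # work pending: sync still happens
```

The guard only saves time when the stream is already idle at the call site. When the copy is still in flight, the blocking sync happens anyway, which is consistent with the negative result above.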
Why both patches (Exp D and Exp F) failed
Experiment What was patched Target Result Why it failed
Exp D Coalesce per-chunk store_stream syncs → 1/step store_stream syncs (~11 ms/step) +17–37% JPS=1 latency The ~15 ms offload work is real GPU work (lmc::load_and_reshape + D2H DMA). Sync-pattern changes cannot remove real computation; they only shift when CPU blocks relative to GPU work completion.
Exp F Skip store_stream sync when stream already idle store_stream syncs (~11 ms/step) +21–25% JPS=1 latency

The correct architectural fix is upstream #37160: make wait_for_save() a no-op, deferring all offload work to get_finished() which runs after model execution. This hides the ~15 ms real offload work behind GPU compute rather than serializing it before the next forward pass.

2. Ablation Experimental Design

2.1 Instrumentation

Scheduler Step Instrumentation (scheduler.py)

Added step_detail trace events capturing per-step metrics:

  • batch_size / total_tokens
  • num_prefills / num_decodes
  • schedule_ms (wall-clock)
  • num_waiting / num_running

KV Save Pipeline (GPU→CPU) (lmcache_connector.py)

Added kv_load_timing events to the LMCache connector:

  • start_load_kv — CPU→GPU initiation
  • wait_for_layer_load — per-layer blocking
  • wait_for_save — GPU→CPU offload sync
  • forward_kv_summary — total KV overhead

2.2 Sync Ablation (4 Modes)

The initial hypothesis (§2) was that torch.cuda.synchronize() in LMCache's KV-onload path blocks the CPU scheduler thread and starves the GPU. To falsify or confirm this, we need to replace the synchronization primitive itself with progressively looser variants and see whether wall-clock latency improves. If it does, the sync primitive was the bottleneck; if it doesn't, the bottleneck is somewhere else. The tool is our_experiments/scripts/patch_sync.py, which runtime-patches gpu_connector.py inside the LMCache installation in the conda env (so vLLM picks up the change on server restart).

The target line in original LMCache (gpu_connector.py) is at the end of batched_to_gpu, after firing the async CPU→GPU KV load on a dedicated load_stream:

# NOTE: wait until both compute and load streams are done
torch.cuda.synchronize()   # ← the line under test

Each mode replaces exactly this call with a different primitive (the rest of LMCache code is untouched). All four modes are tested head-to-head against the same FCFS/Continuum baselines under identical JPS and workload. Here is what each mode does and what it is supposed to test:

D default

What it does: No change — runs LMCache's original torch.cuda.synchronize(). This is a device-wide block: the CPU thread waits until every kernel on every CUDA stream on the current device has completed, not just the load stream.

What it tests: baseline — the "worst-case" sync strategy that blocks on everything.

S stream_only

What it does: replaces torch.cuda.synchronize() with explicit per-stream syncs on only the streams this operation cares about:

self.load_stream.synchronize()
self.compute_stream.synchronize()

What it tests: whether the "device-wide" scope of default sync is over-conservative. If stream-scoped syncs are meaningfully faster, then default was wasting time waiting for unrelated kernels on other streams.

E event_async

What it does: replaces the device sync with a CUDA event that records a specific point on the current stream, then waits for just that point:

self._sync_event = torch.cuda.Event()
self._sync_event.record()
self._sync_event.synchronize()

What it tests: events are the cheapest fine-grained sync primitive CUDA offers — only one condition variable to check, no scan over all device work. If the bottleneck is CPU overhead inside the sync primitive itself, event_async should show it.

Caveat: record() without a stream argument records on the current stream (typically the default compute stream), so this event may not actually capture load_stream activity. If that's true, event_async is effectively incorrect — which is part of what the experiment reveals.

N no_sync ⚠ UNSAFE

What it does: removes all sync calls in the onload path — both the torch.cuda.synchronize() and the subsequent self.load_stream.synchronize():

pass  # sync removed for ceiling measurement

What it tests: the ceiling — how fast could LMCache possibly run if every synchronization primitive in the onload path were free? This is not a correct implementation (the next kernel could read paged KV blocks before the DMA finishes), but as long as it doesn't crash and the outputs are approximately right, it gives an upper bound on "best case if sync cost were zero".

Logical progression of the four modes

The four modes are ordered from most conservative to least: default blocks on the whole device, stream_only blocks only on the two involved streams, event_async blocks only on a single event, and no_sync blocks on nothing. If the sync-hypothesis is correct, each step should give a progressively larger speedup; if it's wrong, the curve will be flat. The no_sync ceiling is especially load-bearing — if even the free-sync case doesn't beat default by more than a few percent, the bottleneck is demonstrably not in any sync primitive. Spoiler: Finding 1 in §5 shows exactly this flat curve.

D default (device-wide sync) → S stream_only (2 stream syncs) → E event_async (1 event wait) → N no_sync (0 syncs, unsafe)
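The difference in scope between the four modes can be made concrete with a toy model in which each stream finishes its queued work at some future time. The stream names and completion times below are invented for illustration. The model captures only what each mode waits for, not the CPU cost of the primitive itself.

```python
# Toy model: time (ms from now) at which each stream finishes its queued work.
# These numbers are made up for illustration.
pending_ms = {"compute": 4.0, "load": 9.0, "store": 15.0}


def wait_ms(mode: str, current_stream: str = "compute") -> float:
    """CPU wait time implied by each ablation mode under the toy model."""
    if mode == "default":        # device-wide: waits for every stream
        return max(pending_ms.values())
    if mode == "stream_only":    # waits for the load and compute streams only
        return max(pending_ms["load"], pending_ms["compute"])
    if mode == "event_async":    # event recorded on the current stream only
        return pending_ms[current_stream]
    if mode == "no_sync":        # unsafe: waits for nothing
        return 0.0
    raise ValueError(f"unknown mode: {mode}")


modes = ("default", "stream_only", "event_async", "no_sync")
waits = {mode: wait_ms(mode) for mode in modes}

# event_async returns while the load stream is still busy, which is the
# correctness caveat noted for that mode.
event_misses_load = waits["event_async"] < pending_ms["load"]
```

In this model the wait shrinks monotonically from default to no_sync, which is the progression the ablation was designed to expose if sync scope were the bottleneck.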

2.3 Transfer Ablation

nocache — isolating framework overhead

LMCache with LMCACHE_MAX_LOCAL_CPU_SIZE=1 (1GB — too small for meaningful KV). Isolates whether overhead comes from actual KV transfers or from LMCache's framework/metadata management.

2.4 Parameters

Parameter | Value
Model | Llama-3.1-8B-Instruct (bf16, ~16GB)
Tool output tokens/turn | Randomized: [1640, 1510, 2455, 1335, 2730, 1930, 775]
Completion tokens | Fixed 20/turn
Tool execution time | 0.5s
Turns per job | 8
Context at turn 8 | ~13,000 tokens
JPS (Jobs per second) | 1, 3, 6, 10
Duration per run | 45s

2.5 GPUs Tested

H200: 141 GB, 28/28 complete
H100: 80 GB, §2 + §5 complete
A100: 80 GB, §2 + §5 complete
L40S: 48 GB, §2 + §5 complete

3. Ablation Results

3.1 Wall-Clock Performance (H200)

Log-scale comparison of all 7 policies. Continuum (green) is the only variant whose curve diverges downward at JPS≥6. The five LMCache variants cluster together above FCFS, confirming that no sync strategy rescues LMCache — see Finding 1 in §5.

[Figure: H200 bottleneck ablation. Avg JCT (log scale, 1s to 1000s) vs JPS (1, 3, 6, 10) for 7 policies: FCFS, LMCache default, LMCache stream, LMCache event, LMCache no_sync, LMCache nocache, Continuum.]
Policy | JPS=1 | JPS=3 | JPS=6 | JPS=10
FCFS (baseline) | 7.3s | 24.5s | 226.6s | 453.1s
LMCache default | 9.4s (+29%) | 25.5s (+4%) | 226.0s (−0%) | 516.4s (+14%)
LMCache stream_only | 8.5s (+16%) | 30.2s (+23%) | 261.9s (+16%) | 597.9s (+32%)
LMCache event_async | 7.9s (+8%) | 29.5s (+20%) | 233.5s (+3%) | 504.0s (+11%)
LMCache no_sync | 7.9s (+9%) | 37.0s (+51%) | 229.6s (+1%) | 460.5s (+2%)
LMCache nocache (1GB) | 8.4s (+15%) | 39.6s (+62%) | 249.4s (+10%) | 570.4s (+26%)
Continuum | 7.7s (+5%) | 23.2s (−5%) | 75.2s (−67%) | 125.8s (−72%)
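The percentages in the results table are relative to FCFS at the same JPS. The sketch below recomputes them for a subset of policies and also computes no_sync relative to default LMCache, since the two baselines give different pictures. Values are copied from the table; the helper name pct_vs is illustrative.

```python
# Mean JCT in seconds on H200, copied from the results table.
jct = {
    "FCFS":            {1: 7.3, 3: 24.5, 6: 226.6, 10: 453.1},
    "LMCache default": {1: 9.4, 3: 25.5, 6: 226.0, 10: 516.4},
    "LMCache no_sync": {1: 7.9, 3: 37.0, 6: 229.6, 10: 460.5},
    "Continuum":       {1: 7.7, 3: 23.2, 6: 75.2, 10: 125.8},
}
JPS_VALUES = (1, 3, 6, 10)


def pct_vs(policy: str, jps: int, baseline: str = "FCFS") -> float:
    """Relative change in mean JCT versus the baseline policy, in percent."""
    return (jct[policy][jps] / jct[baseline][jps] - 1.0) * 100.0


no_sync_vs_fcfs = {jps: round(pct_vs("LMCache no_sync", jps)) for jps in JPS_VALUES}
no_sync_vs_default = {
    jps: round(pct_vs("LMCache no_sync", jps, baseline="LMCache default"))
    for jps in JPS_VALUES
}
continuum_vs_fcfs = {jps: round(pct_vs("Continuum", jps)) for jps in JPS_VALUES}
```

Recomputing from the rounded table values can differ by one point from the printed percentages (for example, no_sync at JPS=1 comes out as +8% rather than +9%), presumably because the table was computed from unrounded data.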

JPS=10 Mean JCT Comparison

FCFS 453.1s (baseline) · LMCache default 516.4s (+14%) · LMCache stream 597.9s (+32%) · LMCache no_sync 460.5s (+2%) · LMCache nocache 570.4s (+26%) · Continuum 125.8s (−72%)

3.2 Scheduler Step Analysis (H200)

What is a scheduler "step"?

A step in vLLM is one full iteration of the engine loop: the scheduler picks a batch of requests, the GPU runs one model.forward() on that batch (which produces exactly one output token per request in the batch), the sampler picks the token, and the loop returns to the scheduler for the next iteration. The table below reports total step counts and per-step means for each policy, aggregated across the 28-run benchmark (4 JPS values × 7 policies).

Policy | Steps (total) | Batch (mean) | Tokens/Step (mean) | Schedule (ms, mean) | Waiting Q (mean)
FCFS | 8,945 | 16.6 | 4,223 | 1.76 | 109.9
LMCache default | 8,168 | 17.6 | 4,378 | 2.85 | 118.2
LMCache stream | 8,689 | 18.0 | 4,732 | 2.70 | 147.5
LMCache event | 8,500 | 16.9 | 4,253 | 2.48 | 112.7
LMCache nosync | 8,339 | 16.6 | 4,127 | 2.54 | 99.1
LMCache nocache | 8,090 | 18.8 | 4,831 | 2.56 | 145.2
Continuum | 6,457 | 22.0 | 2,174 | 2.14 | 27.8

3.3 L40S Results (48GB VRAM) — where the mechanism differs

L40S shows a much more severe LMCache penalty than H200 — average batch size drops from 26.7 (FCFS) to 8.0 (LMCache default), a 3.3× reduction. On H200 the same comparison is 16.6 vs 17.6, basically unchanged. The obvious-sounding explanation — "L40S has less VRAM so LMCache's buffers eat into KV budget" — is wrong. The actual mechanism is a subtler interaction between vLLM's auto-tuned max_num_batched_tokens and LMCache's wait_for_save stall.

L40S (48 GB) H200 (141 GB) Source
max_num_batched_tokens 2,048 8,192 vLLM auto-tuner
Prefill chunks / 13K request 7 (⌈13,193 ÷ 2,048⌉) 2 (⌈13,193 ÷ 8,192⌉) arithmetic
wait_for_save per prefill step 187 ms 115.7 ms L40S: Exp B; H200: Exp C §1.1
Total wfs stall / prefill job 1,309 ms (7 × 187) 231 ms (2 × 115.7) derived
Batch shrink (FCFS ÷ LMCache) 3.3× (26.7 ÷ 8.0) 0.94× (no shrink) derived
Testable prediction — FALSIFIED (Exp E)

Prediction: setting --max-num-batched-tokens=8192 on L40S should shrink the LMCache batch-size penalty.
Result: prediction falsified. Both policies became uniformly slower at mbt=8192.
Root cause as hypothesized: L40S's smaller VRAM → auto-tuner sets max_num_batched_tokens=2048 → a 13K prompt needs 7 chunked prefill steps → 7 wait_for_save stalls → step rate collapses. H200's 8,192-token budget needs only 2 chunks, which is a tolerable stall.
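The chunk counts and per-job stall totals in the comparison table follow from a ceiling division. The sketch below reproduces them from the quoted token budgets and wait_for_save means; the dictionary layout is illustrative only.

```python
import math

prompt_tokens = 13_193

# Token budget per step and mean wait_for_save per prefill step, per GPU.
gpus = {
    "L40S": {"max_num_batched_tokens": 2048, "wfs_ms": 187.0},
    "H200": {"max_num_batched_tokens": 8192, "wfs_ms": 115.7},
}

for name, cfg in gpus.items():
    # A prompt longer than the budget is prefetched in ceil(n / budget) chunks.
    cfg["chunks"] = math.ceil(prompt_tokens / cfg["max_num_batched_tokens"])
    cfg["stall_ms"] = cfg["chunks"] * cfg["wfs_ms"]
    print(f"{name}: {cfg['chunks']} prefill chunks, {cfg['stall_ms']:.0f} ms stall")
```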

4. Deep Dive: KV Transfer Pipeline

4.1 vLLM Step Lifecycle

A vLLM step = one schedule() + one model.forward() + sampling. Each step produces one output token per request in the batch.

schedule()

Decide which requests to batch

build_connector_meta()

Build KV work list — per-request ReqMeta with load_spec + save_spec. Pure decode requests are filtered out entirely (skip_save=True + load_spec=None).

start_load_kv() — 0.2ms avg, 4.1% of steps

CPU→GPU onload (only for new requests with cache hit)

model.forward() — 32 attention layers

  • wait_for_layer_load() — wait for this layer's onload
  • attention + MLP (actual GPU compute)
  • save_kv_layer() — start async GPU→CPU offload

wait_for_save() — see §1 for per-call attribution

Block until ALL async GPU→CPU copies complete. One prefill in a batch of 23 requests blocks all 22 decode requests.

sample → 1 token per request
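The ordering of the connector hooks within one step can be sketched with a recording stub. This is a toy model: the hook names come from the lifecycle above, but the argument lists are simplified assumptions and do not match the real vLLM signatures. Scheduling and sampling are elided.

```python
class RecordingConnector:
    """Toy connector that only records the order of hook calls."""

    def __init__(self):
        self.calls = []

    def __getattr__(self, name):
        # Any hook name (start_load_kv, wait_for_save, ...) becomes a recorder.
        def hook(*args):
            self.calls.append((name, *args))
        return hook


def run_step(connector, num_layers: int = 2) -> None:
    """One simplified engine step, from connector metadata to wait_for_save."""
    connector.build_connector_meta()
    connector.start_load_kv()
    for layer in range(num_layers):      # model.forward(), layer by layer
        connector.wait_for_layer_load(layer)
        # attention + MLP would run here
        connector.save_kv_layer(layer)
    connector.wait_for_save()            # blocks until all saves finish


conn = RecordingConnector()
run_step(conn)
hook_order = [call[0] for call in conn.calls]
```

The point of the sketch is the position of wait_for_save: it is the last hook in the step, so any time it spends blocking is added to the step for every request in the batch.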

4.2 Per-JPS Breakdown: Why High Concurrency Hurts More

JPS | Steps | Batch | Tokens/Step | Prefill % | wfs all-steps mean | wfs prefill-only mean | Source
1 | 1,497 | 5.3 | 431 | 10% | 13.1 ms | 115.7 ms | Exp B / Exp C §1.1
3 | 1,603 | 13.1 | 882 | 16% | 29.9 ms | n/a | Exp B (no Exp C at JPS=3)
6 | 1,893 | 21.6 | 5,789 | 72% | 175.2 ms | 226.0 ms | Exp B / Exp C
10 | 3,170 | 23.3 | 7,174 | 88% | 215.1 ms | 227.1 ms | Exp B / Exp C
Scaling Law

Higher concurrency = more prefill requests per step = more offload transfers = more blocking. At JPS=1, only 10% of steps trigger real saves (13ms). At JPS=10, 88% trigger saves (215ms). This is why LMCache's penalty scales with load.
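If decode-only steps cost roughly nothing, the all-steps mean should be close to the prefill share multiplied by the prefill-only mean. The sketch below checks that against the per-JPS table. The two means come from different experiments (Exp B and Exp C), so only approximate agreement should be expected; JPS=3 is omitted because it has no prefill-only measurement.

```python
# JPS -> (prefill share of steps, all-steps mean ms, prefill-only mean ms).
rows = {
    1: (0.10, 13.1, 115.7),
    6: (0.72, 175.2, 226.0),
    10: (0.88, 215.1, 227.1),
}

# Estimate of the all-steps mean, assuming decode-only steps cost ~0 ms.
estimate_ms = {jps: share * prefill for jps, (share, _, prefill) in rows.items()}

# Ratio of the estimate to the measured all-steps mean.
ratio = {jps: estimate_ms[jps] / rows[jps][1] for jps in rows}

for jps in rows:
    print(f"JPS={jps}: estimate {estimate_ms[jps]:.1f} ms, "
          f"measured {rows[jps][1]} ms, ratio {ratio[jps]:.2f}")
```

The estimates land within roughly 7 to 12 percent of the measured all-steps means, which supports the reading that the growth in mean wait_for_save is driven mainly by the rising share of prefill steps.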

4.3 Full Call Chain

wait_for_save() vllm_v1_adapter.py:1019

lmcache_engine.store() cache_engine.py:176

token_database.process_tokens()

Compute chunk boundaries, hash keys

storage_manager.allocate()

Allocate CPU memory for KV

gpu_connector.batched_from_gpu() dominant cost

  • from_gpu() → lmc_ops.multi_layer_kv_transfer() (CUDA kernel)
  • store_stream.synchronize() — wait for DMA completion

storage_manager.batched_put()

Insert into hash table

4.4 Why the blocking wait exists — upstream rationale

"Block until all the save operations is done. This is called as the forward context exits to ensure that the async saving from save_kv_layer is complete before finishing the forward. This prevents overwrites of paged KV buffer before saving done."

Source: vllm/distributed/kv_transfer/kv_connector/v1/base.py, KVConnectorBase_V1.wait_for_save docstring

That sentence is the entire upstream design discussion. There is no GitHub issue debating the trade-off, no PR proposing a non-blocking alternative, no blog post elaborating on the constraint, and no design doc. The wait-for-save barrier is treated as a self-evident safety requirement rather than an explicit trade-off.

The paged block reuse hazard

Step N — request A

Prefill writes KV for request A into block #47. At step N's forward-context exit, save_kv_layer() fires an async GPU→CPU DMA reading from block #47, queued on store_stream.

(no wait_for_save)

vLLM's engine loop proceeds immediately to step N+1. The store_stream DMA is still running in the background, reading block #47 byte by byte.

Step N+1 — request B hijacks block #47

Scheduler's block allocator sees #47 as free and allocates it to request B's prefill. B's attention kernel writes new KV to block #47. At the same moment the DMA from step N is still reading — result: CPU-side cache receives a mix of A's and B's bytes. The cache is corrupted silently.

The blocking wait_for_save forbids this by freezing the engine loop until every in-flight save has finished reading every dirty block. It is a correctness barrier for the block allocator, not a performance choice — there is no way to skip it without changing the semantics of the free pool.

Potential mitigations (none implemented upstream as of 2026-04)

Mitigation What it changes Cost / drawback Implementation effort
Double buffering Keep a shadow copy of every dirty block while its DMA is in flight. ~2× VRAM cost for dirty blocks. Medium
Reference counting on in-flight save Block allocator tracks a "DMA pending" refcount per block. A block with pending saves is NOT returned to the free pool. Almost zero cost — a few bytes per block. Smallest change (~50 LOC in vLLM core)
Separated offload process Move the entire transfer subsystem into a sidecar process that shares VRAM via CUDA IPC. Largest refactor. But most robust. Very high — RFC vllm#22605
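The reference-counting mitigation can be illustrated with a toy free pool. Everything below is hypothetical: the class and method names are invented for this sketch and do not correspond to vLLM's block allocator. It shows only the invariant that a block with a pending save is never handed out.

```python
class RefcountedBlockPool:
    """Toy free pool that withholds blocks with in-flight saves (hypothetical)."""

    def __init__(self, num_blocks: int):
        self.free = list(range(num_blocks))
        self.pending_saves = {}   # block id -> number of in-flight DMAs
        self.deferred = set()     # released blocks still waiting on a DMA

    def allocate(self) -> int:
        return self.free.pop(0)

    def save_started(self, block: int) -> None:
        self.pending_saves[block] = self.pending_saves.get(block, 0) + 1

    def release(self, block: int) -> None:
        # Return the block to the free pool only if no DMA is reading it.
        if self.pending_saves.get(block, 0) > 0:
            self.deferred.add(block)
        else:
            self.free.append(block)

    def save_finished(self, block: int) -> None:
        self.pending_saves[block] -= 1
        if self.pending_saves[block] == 0 and block in self.deferred:
            self.deferred.remove(block)
            self.free.append(block)


pool = RefcountedBlockPool(num_blocks=2)
a = pool.allocate()        # request A gets block 0
pool.save_started(a)       # async GPU->CPU copy of block 0 begins
pool.release(a)            # A finishes, but the DMA is still reading
b = pool.allocate()        # request B gets block 1, not block 0
pool.save_finished(a)      # DMA done: block 0 becomes reusable
```

In this sketch the engine loop never has to block: the hazard is avoided by delaying reuse of the block rather than by delaying the next step.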

5. Key Findings

Finding 1: The sync hypothesis is FALSIFIED

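Removing torch.cuda.synchronize() entirely (no_sync mode) did not rescue LMCache: at JPS=10 it was still 2% slower than FCFS (460.5 s vs 453.1 s), and at JPS=3 it was 51% slower than FCFS (37.0 s vs 24.5 s) and slower than default LMCache (25.5 s). None of the sync variants (stream, event, nosync) improved on default consistently: all three were slower than default at JPS=3 and JPS=6.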

Finding 2: wait_for_save is the dominant overhead

The distribution of wait_for_save() is bimodal (§1.1): 91% of calls (decode-only) return in ~0.01 ms; the remaining 9% (prefill-step) average 115.7 ms (median 73 ms, p95 253 ms). Per §1.2, of that 115.7 ms: ~100 ms is waiting for the forward pass (unavoidable GPU compute, not LMCache overhead); ~15 ms is actual offload work (lmc::load_and_reshape + D2H DMA + store_stream syncs + Python bookkeeping — the real overhead). 159 steps × 15 ms ≈ 2,385 ms ≈ measured 2.73 s/job overhead. GPU-side DMA hardware floor: 53.6 GB/s (§1.3). The architectural fix (§1.4, upstream #37160) defers offload work to get_finished(), hiding it behind GPU compute.

Finding 3: Overhead is structural, not transfer-related

nocache (1GB CPU cache, almost no actual KV stored) is worse than default LMCache at JPS≥3: +62% vs FCFS at JPS=3 and +26% at JPS=10, against +4% and +14% for default. Even when LMCache has almost nothing to transfer, its connector metadata management and per-step save/load hooks impose a significant tax.

Finding 4: stream_only makes things worse

Replacing torch.cuda.synchronize() with stream.synchronize() was slower than FCFS at every JPS (+16% to +32%) and slower than default LMCache at JPS≥3. The device-level sync may provide unintended ordering guarantees that prevent contention.

Finding 5: Batch size impact is VRAM-dependent
GPU | FCFS Batch | LMCache Batch | Ratio (LMCache ÷ FCFS)
H200 (141GB) | 16.6 | 17.6 | 1.06×
L40S (48GB) | 26.7 | 8.0 | 0.30×
Finding 6: Offload overhead scales with concurrency
JPS | Real save frequency | wait_for_save
1 | 10% | 13ms
3 | 16% | 30ms
6 | 72% | 175ms
10 | 88% | 215ms

Higher concurrency = more prefill requests per step = more offload transfers. At JPS=10, one prefill in a batch of 23 requests blocks all 22 decode requests for 215ms.

6. Future Directions

6.1 Optimize wait_for_save (highest impact)

The save path runs synchronously after every forward pass. Potential optimizations:

  • Lazy save: only save KV when the request is about to be evicted, not after every step
  • Async save pipeline: move save to a background thread/stream, don't block the main loop
  • Selective save: skip saving for short-lived requests unlikely to be reused
6.2 Reduce LMCache Framework Overhead

The nocache experiment shows structural overhead even without transfers. Investigate:

  • Eliminate per-step build_connector_meta cost when no loads/saves are needed
  • Skip save_kv_layer calls entirely when CPU cache is full/disabled
  • Profile Python-level overhead (dict lookups, tensor metadata, etc.)
6.3 VRAM-aware Tuning for Smaller GPUs

On L40S (48GB), the average batch size under LMCache is 3.3× smaller than under FCFS (§3.3). Potential fixes:

  • Dynamic buffer sizing based on available VRAM
  • Reduce gpu-memory-utilization for LMCache to leave headroom for KV cache slots
6.4 Clean L40S wait_for_save Baseline

The L40S nsys-bn run shared a node with the ctrl-sweep fill job (6428519), contaminating the wait_for_save latency measurements (though from_gpu DMA floor remained clean). A dedicated rerun on an isolated L40S node would provide a clean Gen4 wait_for_save baseline for comparison with the clean A100 Gen4 data.

7. Code & Data Files

Code Changes

File Change
vllm/v1/core/sched/trace.py Added step_detail and kv_load_timing trace events
vllm/v1/core/sched/scheduler.py Added per-step timing and batch composition logging
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py Added timing to start_load_kv, wait_for_layer_load, wait_for_save

Experiment Scripts

Script Purpose
our_experiments/scripts/patch_sync.py Sync ablation tool (default/stream/event/nosync)
our_experiments/scripts/analyze_bottleneck.py Analysis: parses traces, produces comparison tables
pace/21_bottleneck_experiments.sbatch Full experiment: 7 variants × JPS sweep
our_experiments/scripts/vllm_with_nvtx.py NVTX monkey-patch wrapper for nsys profiling (§1)
pace/22_nsys_bottleneck.sbatch nsys profile capture: NVTX + CUDA runtime × JPS sweep per GPU

Results Directory

our_experiments/results/20260407_controlled-sweep/
H200/ — 15 JSON (COMPLETE) — §3 source
H100/ — 15 JSON (COMPLETE) — §3 source
A100/ — 15 JSON (COMPLETE, filled 2026-04-12) — §3 source
L40S/ — 15 JSON (COMPLETE, filled 2026-04-11) — §3 source
our_experiments/results/20260407_bottleneck/
H200/ — 28 JSON + trace files (COMPLETE) — §1–§4 source
H100/ — 10/28 (partial; §1 nsys data covers H100 analysis)
L40S/ — 14/28 (partial; §1 nsys data covers L40S analysis)
A100/ — 0/28 (not run; §1 nsys data covers A100 analysis)
our_experiments/results/20260411_nsys-{h200,h100,a100,l40s}/
All 4 GPUs — 6 nsys profiles + sqlite each (COMPLETE) — §1 source