1. Architectural Overview
SGLang's RadixCache is the signature innovation of the SGLang inference engine. It uses a radix tree (compressed trie) to store and reuse KV cache entries across requests that share common token prefixes. This enables massive speedups for multi-turn conversations, few-shot prompting, and shared system prompts.
SGLang organizes GPU memory for KV cache into three conceptual layers:
- ReqToTokenPool -- maps each request slot to its token positions
- TokenToKVPoolAllocator -- manages free/allocated indices into the KV buffers
- KVCache (e.g., MHATokenToKVPool) -- holds the actual K/V tensor buffers on GPU
On top of this, the RadixCache radix tree tracks which token sequences have their KV cache already computed and stored at specific indices. When a new request arrives with a prefix that matches an existing cache entry, SGLang skips the prefill computation for that prefix entirely.
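The prefix-reuse idea can be illustrated with a minimal sketch (plain dict-of-nodes trie, not SGLang's actual implementation): a cache hit tells us how many prefill tokens can be skipped.

```python
# Minimal sketch of prefix reuse: a trie over token IDs, one token per edge
# (SGLang's real tree compresses multi-token edges; see TreeNode below).

class _Node:
    def __init__(self):
        self.children = {}  # token_id -> _Node

def insert(root, tokens):
    node = root
    for t in tokens:
        node = node.children.setdefault(t, _Node())

def longest_cached_prefix(root, tokens):
    """Return how many leading tokens already have cached KV."""
    node, length = root, 0
    for t in tokens:
        if t not in node.children:
            break
        node = node.children[t]
        length += 1
    return length

root = _Node()
insert(root, [1, 2, 3, 4])                      # e.g. a shared system prompt
hit = longest_cached_prefix(root, [1, 2, 3, 9])  # new request, shared prefix
print(hit)  # 3 -> prefill for the first 3 tokens is skipped
```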
2. TreeNode -- The Radix Tree Node
radix_cache.py : TreeNode (line 121)
Every node in the radix tree is a TreeNode instance. The tree is a compressed trie where each edge stores a sequence of token IDs (not just a single token). This compression means a path like [1,2,3] is stored as a single node with key = RadixKey([1,2,3]) rather than as three separate nodes.
class TreeNode:
counter = 0
def __init__(self, id=None, priority=0):
self.children = defaultdict(TreeNode) # child_key -> TreeNode
self.parent = None # parent TreeNode
self.key = None # RadixKey (list of token_ids)
self.value = None # torch.Tensor of KV cache indices
self.lock_ref = 0 # active request count (prevents eviction)
self.pin_expiry = 0.0 # absolute expiry time for pinning
self.pin_ttl = 0 # TTL in seconds for refresh-on-hit
self.last_access_time = time.monotonic() # for LRU eviction
self.creation_time = time.monotonic() # for FIFO eviction
self.hit_count = 0 # for LFU eviction
self.host_ref_counter = 0 # protects host KV from eviction
self.host_value = None # CPU-side KV indices (HiRadixCache)
self.hash_value = None # SHA256 per-page hashes
self.priority = priority # for priority-aware eviction
@property
def evicted(self):
return self.value is None # GPU value gone, but node may still have host_value
@property
def backuped(self):
return self.host_value is not None # has CPU backup (for HiRadixCache)
RadixKey -- The Namespace-Aware Key
Keys are wrapped in a RadixKey object which carries both the token_ids list and an optional extra_key string. The extra_key allows isolation of KV cache for different LoRA adapters or cache salts -- entries that share identical token IDs but have different extra_key values never share prefix nodes.
class RadixKey:
def __init__(self, token_ids: List[int], extra_key: Optional[str] = None,
is_bigram: bool = False):
self.token_ids = token_ids
self.extra_key = extra_key # e.g. lora_id, cache_salt
self.is_bigram = is_bigram # for EAGLE speculative decoding
Child Key Construction
When looking up children, the get_child_key() function constructs the dictionary key. For page_size=1, it is the first token ID. For page_size>1, it is a tuple of the first page_size tokens. If extra_key is set, it becomes a tuple (extra_key, plain_key) to ensure namespace isolation:
def get_child_key(key: RadixKey, page_size: int = 1):
if page_size == 1:
plain_key = key.token_ids[0]
else:
plain_key = tuple(key.token_ids[:page_size])
if key.extra_key is None:
return plain_key
else:
return (key.extra_key, plain_key)
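The namespace isolation can be exercised directly. The snippet below restates RadixKey and get_child_key as a self-contained sketch (EAGLE's is_bigram flag omitted) and shows that identical tokens under different LoRA adapters produce distinct child keys:

```python
from typing import List, Optional

class RadixKey:
    """Simplified version of SGLang's RadixKey (is_bigram omitted)."""
    def __init__(self, token_ids: List[int], extra_key: Optional[str] = None):
        self.token_ids = token_ids
        self.extra_key = extra_key  # e.g. lora_id, cache_salt

def get_child_key(key: RadixKey, page_size: int = 1):
    if page_size == 1:
        plain_key = key.token_ids[0]
    else:
        plain_key = tuple(key.token_ids[:page_size])
    # extra_key wraps the plain key so different namespaces never collide
    return plain_key if key.extra_key is None else (key.extra_key, plain_key)

base = get_child_key(RadixKey([7, 8, 9]))
lora = get_child_key(RadixKey([7, 8, 9], extra_key="lora_A"))
paged = get_child_key(RadixKey([7, 8, 9], extra_key="lora_A"), page_size=2)
print(base)   # 7
print(lora)   # ('lora_A', 7)
print(paged)  # ('lora_A', (7, 8))
```

Because `base != lora`, requests using the adapter can never match a prefix cached without it, even though the token IDs are identical.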
3. RadixCache Initialization
radix_cache.py : RadixCache.__init__ (line 289)
The RadixCache class is initialized with a CacheInitParams object. It sets up the eviction strategy, the key matching function (optimized for page_size=1 vs paged), and initializes the root node:
class RadixCache(BasePrefixCache):
def __init__(self, params: CacheInitParams):
self.disable = params.disable
self.req_to_token_pool = params.req_to_token_pool
self.token_to_kv_pool_allocator = params.token_to_kv_pool_allocator
self.page_size = params.page_size
# Select key matching: token-by-token for page_size=1, paged otherwise
if self.page_size == 1:
self.key_match_fn = _key_match_page_size1
else:
self.key_match_fn = partial(_key_match_paged, page_size=self.page_size)
# Select eviction strategy (7 options!)
if self.eviction_policy == "lru":
self.eviction_strategy = LRUStrategy()
elif self.eviction_policy == "lfu":
self.eviction_strategy = LFUStrategy()
# ... also: fifo, mru, filo, priority, slru
self.evictable_leaves = set()
self.reset()
def reset(self):
self.root_node = TreeNode(priority=-sys.maxsize)
self.root_node.key = RadixKey(token_ids=[], extra_key=None)
self.root_node.value = []
self.root_node.lock_ref = 1 # root is always locked
self.evictable_size_ = 0
self.protected_size_ = 0
The root node always has lock_ref = 1, making it permanently immune to eviction. Every other node's lock_ref is incremented when an active request traverses through it (via inc_lock_ref) and decremented when the request finishes (via dec_lock_ref).
4. match_prefix() -- Finding the Longest Cached Prefix
radix_cache.py : match_prefix (line 378) / _match_prefix_helper (line 671)
When a new request arrives, SGLang calls match_prefix() to find the longest prefix of the request's token IDs that already has KV cache stored in the tree. This is the core operation that enables prefix reuse.
The Public API
def match_prefix(self, params: MatchPrefixParams) -> MatchResult:
key = params.key
key, _ = self.maybe_bigram_convert(key) # EAGLE bigram conversion if needed
if self.disable or len(key) == 0:
return empty_match_result()
# For paged mode, truncate to page-aligned length
if self.page_size != 1:
page_aligned_len = len(key) // self.page_size * self.page_size
key = key[:page_aligned_len]
value, last_node = self._match_prefix_helper(self.root_node, key)
if value:
value = torch.cat(value) # concatenate KV indices from matched nodes
return MatchResult(device_indices=value, last_device_node=last_node, ...)
The Internal Tree Walk
The helper walks down the tree following matching children. At each node, it compares the node's key against the remaining query key using key_match_fn. If a match ends inside a node's key (partial match), it splits the node to create a precise boundary:
def _match_prefix_helper(self, node, key):
access_time = time.monotonic()
node.last_access_time = access_time # update for LRU
child_key = self.get_child_key_fn(key)
value = []
while len(key) > 0 and child_key in node.children.keys():
child = node.children[child_key]
child.last_access_time = access_time
prefix_len = self.key_match_fn(child.key, key) # how many tokens match?
if prefix_len < len(child.key):
# PARTIAL MATCH: split the node
new_node = self._split_node(child.key, child, prefix_len)
value.append(new_node.value)
node = new_node
break
else:
# FULL MATCH: consume this node and continue
value.append(child.value)
node = child
key = key[prefix_len:]
if len(key):
child_key = self.get_child_key_fn(key)
return value, node # list of KV index tensors, last matched node
Node Splitting (_split_node)
When a match ends at position split_len inside a node, the node is split into two: a new parent holding the matched prefix, and the existing node retaining the unmatched suffix:
def _split_node(self, key, child, split_len):
# Create new_node -> child
new_node = TreeNode(priority=child.priority)
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
new_node.parent = child.parent
new_node.lock_ref = child.lock_ref # inherit lock count
new_node.key = child.key[:split_len]
new_node.value = child.value[:split_len].clone()
child.parent = new_node
child.key = child.key[split_len:]
child.value = child.value[split_len:].clone()
new_node.parent.children[self.get_child_key_fn(key)] = new_node
return new_node
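The key/value arithmetic of the split can be shown on plain Python lists (a sketch of the slicing only, without the tree pointers and lock inheritance above):

```python
def split_node(key, value, split_len):
    """Split a (key, value) pair at split_len, mirroring _split_node's
    slicing: the matched prefix becomes the new parent, the existing
    node keeps the unmatched suffix."""
    new_parent = (key[:split_len], value[:split_len])
    trimmed_child = (key[split_len:], value[split_len:])
    return new_parent, trimmed_child

# A node holding tokens [1,2,3,4] matched only on [1,2] (split_len=2):
parent, child = split_node([1, 2, 3, 4], [10, 11, 12, 13], 2)
print(parent)  # ([1, 2], [10, 11])   <- matched prefix, returned to caller
print(child)   # ([3, 4], [12, 13])   <- unmatched suffix stays below it
```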
5. insert() -- Adding New KV Cache Entries
radix_cache.py : insert (line 450) / _insert_helper (line 727)
After a request completes its prefill computation, the new KV cache entries are inserted into the tree via insert(). The insert operation walks the tree much like match_prefix, but it also creates new nodes for the unmatched suffix and deduplicates against existing entries.
def _insert_helper(self, node, key, value, priority=0, chunked=False):
access_time = time.monotonic()
node.last_access_time = access_time
node.priority = max(node.priority, priority) # propagate higher priority
if len(key) == 0:
return 0
child_key = self.get_child_key_fn(key)
total_prefix_length = 0
# Walk down existing matching nodes
while len(key) > 0 and child_key in node.children.keys():
node = node.children[child_key]
node.last_access_time = access_time
prefix_len = self.key_match_fn(node.key, key)
total_prefix_length += prefix_len
key = key[prefix_len:] # consume matched portion
value = value[prefix_len:] # skip already-cached KV indices
if prefix_len < len(node.key):
# Partial match: split and stop
new_node = self._split_node(node.key, node, prefix_len)
self._inc_hit_count(new_node, chunked)
node = new_node
else:
self._inc_hit_count(node, chunked)
if len(key):
child_key = self.get_child_key_fn(key)
# Create new node for unmatched suffix
if len(key):
new_node = TreeNode(priority=priority)
new_node.parent = node
new_node.key = key
new_node.value = value.clone()
node.children[child_key] = new_node
self.evictable_size_ += len(key)
self._update_leaf_status(node)
self._update_leaf_status(new_node)
return total_prefix_length # how much was already in the tree
cache_finished_req -- The Full Lifecycle
When a request finishes, cache_finished_req() orchestrates the insertion and deduplication:
def cache_finished_req(self, req, is_insert=True):
token_ids = (req.origin_input_ids + req.output_ids)[:kv_committed_len]
kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :len(token_ids)]
if is_insert:
result = self.insert(InsertParams(key=radix_key, value=values, ...))
new_prefix_len = result.prefix_len
# Free DUPLICATES: indices that were already in the tree
self.token_to_kv_pool_allocator.free(
kv_indices[req.cache_protected_len : new_prefix_len]
)
# Free unaligned tail (page_size > 1 can leave a partial page)
self.token_to_kv_pool_allocator.free(kv_indices[len(keys):])
# Release lock on the tree path
self.dec_lock_ref(req.last_node)
The return value total_prefix_length from _insert_helper tells the caller how many KV indices at the beginning of the inserted sequence were already in the tree. Those duplicates are freed back to the allocator, avoiding memory waste.
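The dedup accounting can be worked through with concrete (hypothetical) numbers: a 10-token request where insert() reports that 6 leading tokens were already in the tree, 4 of which were protected when the request was admitted.

```python
# Worked example of the index ranges freed in cache_finished_req
# (numbers are illustrative, not from SGLang).
num_tokens = 10            # len(token_ids) for the finished request
cache_protected_len = 4    # prefix reused at admission (never recomputed)
new_prefix_len = 6         # insert() result: tokens already in the tree

# kv_indices[4:6] were recomputed by this request but duplicate entries
# already stored in the tree, so exactly that slice is freed.
duplicates_freed = new_prefix_len - cache_protected_len
# The remaining tokens become new tree nodes and stay allocated.
newly_cached = num_tokens - new_prefix_len
print(duplicates_freed)  # 2
print(newly_cached)      # 4
```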
6. Eviction Policies
evict_policy.py (all strategies) | radix_cache.py : evict (line 586)
SGLang supports seven different eviction strategies, all implementing the EvictionStrategy abstract class. Each strategy defines a get_priority() method; lower priority values are evicted first (using a min-heap).
| Strategy | Priority Key | Description |
|---|---|---|
| LRU | node.last_access_time | Least Recently Used -- oldest access evicted first |
| LFU | (node.hit_count, node.last_access_time) | Least Frequently Used -- fewest hits first, then LRU as tiebreaker |
| FIFO | node.creation_time | First In First Out -- oldest creation evicted first |
| MRU | -node.last_access_time | Most Recently Used -- newest access evicted first |
| FILO | -node.creation_time | First In Last Out -- newest creation evicted first |
| Priority | (node.priority, node.last_access_time) | Priority-aware -- lower priority evicted first, then LRU |
| SLRU | (segment, node.last_access_time) | Segmented LRU -- probationary items (hit_count < threshold) evicted before protected |
The Eviction Process
def evict(self, params: EvictParams) -> EvictResult:
num_tokens = params.num_tokens
leaves = list(self.evictable_leaves) # pre-tracked set of evictable leaves
# Build min-heap using strategy priorities
eviction_heap = [
(self.eviction_strategy.get_priority(node), node)
for node in leaves
]
heapq.heapify(eviction_heap)
num_evicted = 0
while num_evicted < num_tokens and len(eviction_heap):
_priority, x = heapq.heappop(eviction_heap)
# Free the KV cache indices back to allocator
self.token_to_kv_pool_allocator.free(x.value)
num_evicted += len(x.value)
# Remove leaf from tree
self._delete_leaf(x)
# If parent becomes a childless unlocked node, add to heap
if len(x.parent.children) == 0 and x.parent.lock_ref == 0:
new_priority = self.eviction_strategy.get_priority(x.parent)
heapq.heappush(eviction_heap, (new_priority, x.parent))
return EvictResult(num_tokens_evicted=num_evicted)
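The heap-based loop can be run end to end with an LRU priority and plain Python objects standing in for tree leaves (a sketch, without the tree surgery and parent promotion):

```python
import heapq

class Leaf:
    def __init__(self, name, last_access_time, num_tokens):
        self.name = name
        self.last_access_time = last_access_time
        self.num_tokens = num_tokens
    def __lt__(self, other):
        # Tie-breaker so heap entries with equal priorities stay comparable
        return self.name < other.name

def evict(leaves, num_tokens):
    # LRU: priority key is last_access_time, smallest (oldest) popped first
    heap = [(leaf.last_access_time, leaf) for leaf in leaves]
    heapq.heapify(heap)
    evicted, freed = [], 0
    while freed < num_tokens and heap:
        _priority, leaf = heapq.heappop(heap)
        evicted.append(leaf.name)
        freed += leaf.num_tokens   # stand-in for freeing leaf.value
    return evicted, freed

leaves = [Leaf("a", 3.0, 4), Leaf("b", 1.0, 4), Leaf("c", 2.0, 4)]
print(evict(leaves, 6))  # (['b', 'c'], 8) -- oldest accesses go first
```

Swapping the priority expression (e.g. `leaf.creation_time` for FIFO, or a `(hit_count, last_access_time)` tuple for LFU) is all that distinguishes the strategies.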
Evictable Leaf Tracking
Rather than scanning the entire tree, SGLang maintains an evictable_leaves set that is updated incrementally. A node is an evictable leaf when:
- It is not evicted (value is not None)
- It has lock_ref == 0 (no active requests using it)
- All its children are evicted (it is effectively a leaf in the non-evicted subtree)
def _update_leaf_status(self, node):
if node.evicted or node.lock_ref > 0:
# Remove from evictable set if present
if node in self.evictable_leaves:
self.evictable_leaves.remove(node)
return
# Check if any child is NOT evicted
for child in node.children.values():
if not child.evicted:
if node in self.evictable_leaves:
self.evictable_leaves.remove(node)
return
# All children evicted (or no children) -> this is an evictable leaf
if node not in self.evictable_leaves:
self.evictable_leaves.add(node)
7. Reference Counting -- Protecting Active Requests
radix_cache.py : inc_lock_ref (line 615) / dec_lock_ref (line 630)
The reference counting system prevents eviction of cache entries that are being actively used by in-flight requests. When a request matches a prefix and starts processing, inc_lock_ref() is called on the terminal node, and it propagates upward to the root:
def inc_lock_ref(self, node) -> IncLockRefResult:
delta = 0
while node != self.root_node:
if node.lock_ref == 0:
# Transitioning from evictable to protected
self.evictable_size_ -= len(node.key)
self.protected_size_ += len(node.key)
delta -= len(node.key)
node.lock_ref += 1
self._update_leaf_status(node)
node = node.parent
return IncLockRefResult(delta=delta)
def dec_lock_ref(self, node, ...) -> DecLockRefResult:
delta = 0
while node != self.root_node:
if node.lock_ref == 1:
# Transitioning from protected to evictable
self.evictable_size_ += len(node.key)
self.protected_size_ -= len(node.key)
delta += len(node.key)
node.lock_ref -= 1
self._update_leaf_status(node)
node = node.parent
return DecLockRefResult(delta=delta)
Locking propagates from the terminal matched node all the way up to the root because eviction works bottom-up (leaf nodes first). If a child is locked, the parent must also be protected -- otherwise evicting the parent would leave the child's KV cache in an inconsistent state. The sizes evictable_size_ and protected_size_ are tracked incrementally so the scheduler can make memory-aware decisions without scanning the tree.
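The upward propagation and the evictable/protected bookkeeping can be sketched on a toy two-node chain (simplified objects, not SGLang's classes):

```python
class Node:
    def __init__(self, key_len, parent=None):
        self.key_len = key_len      # number of tokens on this node's edge
        self.parent = parent
        self.lock_ref = 0

class Tree:
    def __init__(self, root):
        self.root = root
        self.evictable = 0          # mirrors evictable_size_
        self.protected = 0          # mirrors protected_size_
    def inc_lock_ref(self, node):
        # Walk from the matched node up to (but excluding) the root
        while node is not self.root:
            if node.lock_ref == 0:  # first lock: move tokens to protected
                self.evictable -= node.key_len
                self.protected += node.key_len
            node.lock_ref += 1
            node = node.parent

root = Node(0)
a = Node(3, parent=root)            # 3 cached tokens
b = Node(5, parent=a)               # 5 more cached tokens
tree = Tree(root)
tree.evictable = 8                  # all 8 tokens start evictable
tree.inc_lock_ref(b)                # locking b also locks its ancestor a
print(tree.evictable, tree.protected, a.lock_ref, b.lock_ref)  # 0 8 1 1
```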
8. Memory Pools: ReqToTokenPool & MHATokenToKVPool
memory_pool.py : ReqToTokenPool (line 123) / MHATokenToKVPool (line 733)
ReqToTokenPool -- Request to Token Mapping
This pool maps each request slot to the token positions it occupies in the KV cache. It is a 2D tensor of shape [size, max_context_len] stored as torch.int32 on the GPU device:
class ReqToTokenPool:
def __init__(self, size, max_context_len, device, enable_memory_saver):
self.req_to_token = torch.zeros(
(size, max_context_len), dtype=torch.int32, device=device
)
self.free_slots = list(range(size)) # Python list of available slot indices
def alloc(self, reqs: list[Req]) -> Optional[List[int]]:
# Reuse existing slots for chunked prefill (continuing across chunks)
reusing = [i for i, r in enumerate(reqs) if r.req_pool_idx is not None]
need_size = len(reqs) - len(reusing)
if need_size > len(self.free_slots):
return None # not enough request slots
select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:]
# Assign req_pool_idx to each request
offset = 0
for r in reqs:
if r.req_pool_idx is None:
r.req_pool_idx = select_index[offset]
offset += 1
return [r.req_pool_idx for r in reqs]
def free(self, req):
self.free_slots.append(req.req_pool_idx)
req.req_pool_idx = None
MHATokenToKVPool -- The Actual KV Tensor Storage
This is where the GPU memory physically lives. For Multi-Head Attention (MHA), it stores separate K and V buffers per layer:
class MHATokenToKVPool(KVCache):
def _create_buffers(self):
# [size + page_size, head_num, head_dim] for each layer
# The padded slot 0 is used for dummy outputs from padded tokens
self.k_buffer = [
torch.zeros(
(self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype, device=self.device
)
for _ in range(self.layer_num)
]
self.v_buffer = [
torch.zeros(
(self.size + self.page_size, self.head_num, self.v_head_dim),
dtype=self.store_dtype, device=self.device
)
for _ in range(self.layer_num)
]
def set_kv_buffer(self, layer, loc, cache_k, cache_v, ...):
# Called during forward pass to store computed K/V into cache
# Uses optimized store_cache JIT kernel when possible
_set_kv_buffer_impl(
cache_k, cache_v,
self.k_buffer[layer_id - self.start_layer],
self.v_buffer[layer_id - self.start_layer],
loc, ...
)
9. TokenToKVPoolAllocator -- Allocation & Deallocation
allocator.py : TokenToKVPoolAllocator (line 117) / PagedTokenToKVPoolAllocator (line 356)
SGLang provides two allocator variants. The base TokenToKVPoolAllocator operates at token granularity (page_size=1), while PagedTokenToKVPoolAllocator operates at page granularity for higher efficiency with paged attention.
TokenToKVPoolAllocator (page_size=1)
class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
def __init__(self, size, dtype, device, kvcache, need_sort):
super().__init__(size, 1, dtype, device, kvcache, need_sort)
self.clear()
def clear(self):
# Slot 0 is reserved as a padding dummy slot
self.free_pages = torch.arange(1, self.size + 1,
dtype=torch.int64, device=self.device)
self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)
def alloc(self, need_size):
# If sorted mode and insufficient free pages, merge with release_pages
if self.need_sort and need_size > len(self.free_pages):
self.merge_and_sort_free()
if need_size > len(self.free_pages):
return None # OOM
select_index = self.free_pages[:need_size]
self.free_pages = self.free_pages[need_size:]
return select_index
def free(self, free_index):
if free_index.numel() == 0:
return
if self.is_not_in_free_group:
if self.need_sort:
# Deferred: add to release_pages (will be merged later)
self.release_pages = torch.cat((self.release_pages, free_index))
else:
# Immediate: prepend to free_pages
self.free_pages = torch.cat((free_index, self.free_pages))
else:
# Batched free: accumulate in group, apply at free_group_end()
self.free_group.append(free_index)
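The free-list discipline above can be sketched with Python lists in place of torch tensors (slot 0 stays reserved for padding, and freed indices are prepended immediately, as in the non-sorted path):

```python
# Free-list sketch of alloc()/free(); real code uses torch tensors.
free_pages = list(range(1, 9))  # slots 1..8; slot 0 reserved for padding

def alloc(need_size):
    global free_pages
    if need_size > len(free_pages):
        return None                 # OOM: caller must evict or retry
    picked = free_pages[:need_size]
    free_pages = free_pages[need_size:]
    return picked

def free(indices):
    global free_pages
    # Immediate (unsorted) free: prepend so slots are reused soon
    free_pages = list(indices) + free_pages

a = alloc(3)
print(a)                # [1, 2, 3]
free(a)
print(free_pages[:3])   # [1, 2, 3] -- back at the front of the list
```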
PagedTokenToKVPoolAllocator -- Triton-Accelerated Page Allocation
For paged operation, the allocator uses Triton kernels for batch allocation during extend and decode phases. The alloc_extend_kernel handles three parts: filling the old partial page, allocating full new pages, and filling a new partial page at the end.
class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
def alloc_extend(self, prefix_lens, ..., seq_lens, ..., last_loc, extend_num_tokens):
out_indices = torch.empty((extend_num_tokens,), dtype=torch.int64, device=...)
# Triton kernel computes per-request page allocation in parallel on GPU
alloc_extend_kernel[(bs,)](
prefix_lens, seq_lens, last_loc,
self.free_pages, out_indices,
next_power_of_2(bs), self.page_size
)
num_new_pages = get_num_new_pages(seq_lens=..., page_size=..., prefix_lens=...)
self.free_pages = self.free_pages[num_new_pages:]
return out_indices
def free(self, free_index):
# Convert token indices back to page indices
free_page_indices = torch.unique(free_index // self.page_size)
self.free_pages = torch.cat((free_page_indices, self.free_pages))
The allocator supports batched free via free_group_begin()/free_group_end(). Between these calls, all free() operations accumulate in a list and are applied as a single torch.cat at the end, reducing the number of tensor concatenation operations.
The need_sort flag controls whether freed indices go to release_pages (deferred, sorted at merge time) or directly to free_pages. Sorted allocation is important for some attention backends that benefit from locality of KV cache indices in GPU memory.
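The batched-free pattern can be condensed to a few lines (lists standing in for tensors and a single merge standing in for the one torch.cat at free_group_end()):

```python
# Sketch of free_group_begin()/free_group_end() batching.
free_pages = [5, 6]
free_group = []                     # free_group_begin(): start accumulating

for batch in ([1, 2], [3], [4]):
    free_group.append(batch)        # each free() only appends, no concat

# free_group_end(): one merge applies all deferred frees at once
merged = [idx for batch in free_group for idx in batch]
free_pages = merged + free_pages
print(free_pages)  # [1, 2, 3, 4, 5, 6]
```

With tensors, this turns N concatenations (one per free) into a single torch.cat, which matters when many requests finish in the same scheduler step.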
10. HiRadixCache -- Three-Tier Hierarchical Caching
hiradix_cache.py : HiRadixCache (line 56)
The HiRadixCache extends RadixCache with a three-tier memory hierarchy: GPU (device) -> CPU (host) -> Storage (disk/remote). When GPU memory is exhausted, KV cache entries are demoted to CPU memory; when CPU memory is full, entries can be further spilled to persistent storage.
Initialization -- Setting Up the Hierarchy
class HiRadixCache(RadixCache):
def __init__(self, params, server_args):
self.kv_cache = params.token_to_kv_pool_allocator.get_kvcache()
# Create CPU-side KV pool matching the GPU KV cache type
if isinstance(self.kv_cache, MHATokenToKVPool):
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
self.kv_cache, server_args.hicache_ratio,
server_args.hicache_size, self.page_size, ...
)
elif isinstance(self.kv_cache, MLATokenToKVPool):
self.token_to_kv_pool_host = MLATokenToKVPoolHost(...)
# Cache controller manages async GPU<->CPU<->Storage transfers
self.cache_controller = HiCacheController(
params.token_to_kv_pool_allocator,
self.token_to_kv_pool_host,
self.page_size, self.tp_group,
write_policy=server_args.hicache_write_policy, # write_back | write_through
storage_backend=server_args.hicache_storage_backend,
...
)
# Write-through: backup after hit_count reaches threshold
self.write_through_threshold = (
1 if server_args.hicache_write_policy == "write_through" else 2
)
self.load_back_threshold = 10 # minimum tokens to justify a load-back
self.evictable_host_leaves = set() # host-side evictable tracking
Eviction with Host Backup
The eviction process in HiRadixCache is significantly more complex than in the base RadixCache. Instead of simply freeing GPU memory, it can demote entries to CPU:
def evict(self, params) -> EvictResult:
# Build eviction heap from evictable leaves (same as base)
while num_evicted < num_tokens and len(eviction_heap):
_priority, x = heapq.heappop(eviction_heap)
if x.lock_ref > 0:
continue
if self._is_pinned(x):
# Pinned node: demote to host, keep in tree
if x.backuped:
num_evicted += self._evict_backuped(x) # just free GPU indices
else:
written = self.write_backup(x, write_back=True)
...
elif not x.backuped:
if self.cache_controller.write_policy == "write_back":
# Write-back: backup to host before freeing GPU
num_evicted += self.write_backup(x, write_back=True)
else:
# Write-through: already backed up, just free GPU
num_evicted += self._evict_regular(x)
else:
num_evicted += self._evict_backuped(x)
def _evict_backuped(self, node):
# GPU -> CPU demotion: free device indices but keep node in tree
num_evicted = self.cache_controller.evict_device(node.value)
node.value = None # mark as evicted from GPU
# node.host_value still has the CPU indices
self._update_leaf_status(node)
self._update_host_leaf_status(node)
return num_evicted
Load-Back: CPU -> GPU Restoration
When a match_prefix() hits a node whose GPU data was evicted but still exists on CPU, load_back() restores it:
def load_back(self, node, mem_quota=None):
# Walk up to find the nearest non-evicted ancestor
last_hit_node = node
nodes_to_load = []
while node.evicted:
assert node.backuped # must have host backup
nodes_to_load.insert(0, node)
node = node.parent
    # Protect the nearest non-evicted ancestor from eviction during load
    self.inc_lock_ref(node)  # after the loop, node is that ancestor
# Concatenate all host indices for batch transfer
host_indices = torch.cat([n.host_value for n in nodes_to_load])
# Skip if too small (< 10 tokens) or exceeds memory quota
if len(host_indices) < self.load_back_threshold:
return None
# Allocate GPU space and initiate async DMA transfer
device_indices = self.cache_controller.load(host_indices=host_indices, ...)
if device_indices is None:
# Not enough GPU memory, try evicting first
self.evict(EvictParams(num_tokens=len(host_indices)))
device_indices = self.cache_controller.load(...)
# Restore node values from device_indices
offset = 0
for node in nodes_to_load:
node.value = device_indices[offset : offset + len(node.host_value)].clone()
offset += len(node.host_value)
self.evictable_size_ += len(device_indices)
Write-Through: Hit Count Triggered Backup
With the write-through policy, nodes are proactively backed up to CPU after reaching a hit count threshold:
def _inc_hit_count(self, node, chunked=False):
if self.cache_controller.write_policy == "write_back" or chunked:
return
node.hit_count += 1
if not node.backuped:
if node.hit_count >= self.write_through_threshold:
self.write_backup(node) # async GPU->CPU copy
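The threshold logic can be traced with plain state (a sketch; the real write_backup() is an async GPU-to-CPU copy via the cache controller):

```python
# Hit-count-triggered backup, assuming write_through_threshold = 2
# (the value used above when the policy is not plain "write_through").
write_through_threshold = 2
backups = []  # stand-in for entries handed to write_backup()

class Node:
    def __init__(self):
        self.hit_count = 0
        self.backuped = False

def inc_hit_count(node):
    node.hit_count += 1
    if not node.backuped and node.hit_count >= write_through_threshold:
        backups.append(node)        # would kick off the async host copy
        node.backuped = True

node = Node()
inc_hit_count(node)
print(len(backups))  # 0 -- first hit, still below threshold
inc_hit_count(node)
print(len(backups))  # 1 -- threshold reached, backup triggered once
inc_hit_count(node)
print(len(backups))  # 1 -- already backed up, no duplicate work
```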
Host Eviction
When even CPU memory is exhausted, evict_host() removes entries from CPU, fully deleting the tree node:
def evict_host(self, num_tokens):
# Uses same eviction strategy but on evictable_host_leaves
leaves = list(self.evictable_host_leaves)
eviction_heap = [(self.eviction_strategy.get_priority(n), n) for n in leaves]
heapq.heapify(eviction_heap)
while num_evicted < num_tokens and eviction_heap:
_priority, x = heapq.heappop(eviction_heap)
if not x.evicted: # only evict nodes already gone from GPU
continue
if x.host_ref_counter > 0: # protected by ongoing prefetch/backup/pin
continue
num_evicted += self.cache_controller.evict_host(x.host_value)
# Fully remove node from tree
x.parent.children.pop(self.get_child_key_fn(x.key))
11. SessionAwareCache -- Streaming Session KV Management
session_aware_cache.py : SessionAwareCache (line 112)
The SessionAwareCache is a decorator (wrapper) around any BasePrefixCache implementation that adds support for streaming sessions. In streaming mode, KV cache is held across multiple request turns within the same session, avoiding re-computation between turns of a multi-turn conversation.
class SessionAwareCache(BasePrefixCache):
def __init__(self, inner: BasePrefixCache):
self.inner = inner # wrapped cache (RadixCache, HiRadixCache, etc.)
self.slots: Dict[str, SessionSlot] = {} # session_id -> slot
SessionSlot -- Per-Session KV State
@dataclass
class SessionSlot:
virtual_node: _VirtualNode = field(default_factory=_VirtualNode)
req_pool_idx: Optional[int] = None # allocated request pool slot
kv_committed_len: int = 0 # how many tokens have KV cached
kv_allocated_len: int = 0 # how many tokens are allocated
last_node: Any = None # radix tree node (for dec_lock_ref on close)
cache_protected_len: int = 0 # overlap with radix tree
The key insight is that for streaming sessions, match_prefix() directly restores KV state from the SessionSlot instead of walking the radix tree. The _VirtualNode sentinel allows inc_lock_ref/dec_lock_ref to be no-ops for session-managed KV, since the session itself holds the KV alive:
def match_prefix(self, params) -> MatchResult:
req = params.req
if not _is_streaming(req):
return self.inner.match_prefix(params) # pass through for non-streaming
slot = self.slots.get(req.session.session_id)
if slot is None or slot.req_pool_idx is None:
return self.inner.match_prefix(params) # first turn: use radix tree
# Subsequent turns: restore from session slot
slot.restore_to_req(req)
prefix_len = min(req.kv_committed_len, max(len(params.key.token_ids) - 1, 0))
device_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :prefix_len]
return MatchResult(device_indices=device_indices, last_device_node=slot.virtual_node, ...)
12. RadixAttention Layer -- Integrating Cache in Forward Pass
layers/radix_attention.py : RadixAttention (line 47)
The RadixAttention module is the attention layer used in SGLang model implementations. Despite its name, it does not directly reference the radix cache tree -- instead, it delegates to the attention backend, which reads from the KV cache buffers at the indices provided by the cache system.
class RadixAttention(nn.Module):
def __init__(self,
num_heads, head_dim, scaling, num_kv_heads, layer_id,
logit_cap=0.0, v_head_dim=-1, sliding_window_size=-1,
is_cross_attention=False, pos_encoding_mode="NONE", ...
):
self.tp_q_head_num = num_heads
self.tp_k_head_num = num_kv_heads
self.head_dim = head_dim
self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
self.layer_id = layer_id
self.sliding_window_size = sliding_window_size or -1
def forward(self, q, k, v, forward_batch: ForwardBatch,
save_kv_cache=True, **kwargs):
if k is not None:
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
# For extend (prefill) mode with torch.compile context:
if forward_batch.forward_mode.is_extend() and get_forward_context():
output = torch.empty_like(q) if ... else ...
unified_attention_with_output(q, k, v, output, save_kv_cache, self.layer_id)
return output
else:
# Delegate to attention backend (FlashAttention, Triton, etc.)
return forward_batch.attn_backend.forward(
q, k, v, self, forward_batch, save_kv_cache, **kwargs
)
How Cache Integration Works in Practice
The integration flow between the cache and the attention layer is orchestrated by the scheduler and attention backend:
- Scheduling: The scheduler calls match_prefix() on the radix cache to find cached KV indices for each request. These indices are written into req_to_token_pool.
- Forward Pass: During the forward pass, the attention backend reads K/V from MHATokenToKVPool.get_kv_buffer(layer_id) at the indices stored in req_to_token_pool.
- KV Storage: After computing attention for new tokens, set_kv_buffer() writes the new K/V into the pool. The save_kv_cache flag controls whether this happens (set to False for the last layer in speculative decoding verification).
- Cache Update: When the request finishes, cache_finished_req() inserts the full token sequence into the radix tree, deduplicating against existing entries and freeing unused KV indices.
# unified_attention_with_output -- the torch.compile-compatible entry point
@register_custom_op(mutates_args=["output"])
@register_split_op()
def unified_attention_with_output(query, key, value, output, save_kv_cache, layer_id, ...):
context = get_forward_context()
forward_batch = context.forward_batch
attention_layer = context.attention_layers[layer_id]
# The attn_backend.forward() handles:
# 1. Writing K/V to cache (set_kv_buffer) for new tokens
# 2. Reading K/V from cache for cached prefix tokens
# 3. Computing attention output
ret = forward_batch.attn_backend.forward(
query, key, value, attention_layer, forward_batch, save_kv_cache, **kwargs
)
output.view(ret.shape).copy_(ret)
The RadixAttention layer itself is cache-unaware -- it simply passes data to the attention backend. The magic happens in the attention backend (e.g., FlashInfer, Triton), which knows about the KV cache layout because the ForwardBatch carries all the index information (from req_to_token_pool) and the MHATokenToKVPool provides the buffer handles. This clean separation means the radix cache can be swapped, extended, or disabled without touching the attention computation code.
Summary: Source File Map
| Component | File | Key Classes/Functions | Role |
|---|---|---|---|
| Radix Tree | radix_cache.py | TreeNode, RadixKey, RadixCache | Token->KV index mapping via compressed trie |
| Memory Pool | memory_pool.py | ReqToTokenPool, MHATokenToKVPool | Request slot management, GPU KV tensor storage |
| Allocator | allocator.py | TokenToKVPoolAllocator, PagedTokenToKVPoolAllocator | Free-list management of KV cache indices |
| Eviction | evict_policy.py | LRUStrategy, LFUStrategy, SLRUStrategy, ... | 7 pluggable eviction policies |
| Hierarchical | hiradix_cache.py | HiRadixCache | GPU->CPU->Storage three-tier caching |
| Sessions | session_aware_cache.py | SessionAwareCache, SessionSlot | Multi-turn streaming session KV persistence |
| Chunk Cache | chunk_cache.py | ChunkCache | Fallback when RadixCache is disabled |
| Attention | radix_attention.py | RadixAttention | nn.Module layer delegating to attn backend |