SGLang RadixCache & Memory Management

Deep Dive from Actual Source Code
Source: sglang/python/sglang/srt/mem_cache/ | Analysis of radix_cache.py, memory_pool.py, allocator.py, hiradix_cache.py, session_aware_cache.py, radix_attention.py

1. Architectural Overview

SGLang's RadixCache is the signature innovation of the SGLang inference engine. It uses a radix tree (compressed trie) to store and reuse KV cache entries across requests that share common token prefixes. This enables massive speedups for multi-turn conversations, few-shot prompting, and shared system prompts.

Three-Layer Memory Architecture

SGLang organizes GPU memory for KV cache into three conceptual layers:

  • ReqToTokenPool -- maps each request slot to its token positions
  • TokenToKVPoolAllocator -- manages free/allocated indices into the KV buffers
  • KVCache (e.g., MHATokenToKVPool) -- holds the actual K/V tensor buffers on GPU

On top of this, the RadixCache radix tree tracks which token sequences have their KV cache already computed and stored at specific indices. When a new request arrives with a prefix that matches an existing cache entry, SGLang skips the prefill computation for that prefix entirely.
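The three-layer indirection can be sketched in pure Python, with lists standing in for the torch tensors. This is an illustrative model of the lookup chain (request slot -> KV indices -> KV buffers), not SGLang code:

```python
# Toy model of the three-layer indirection (lists in place of GPU tensors).
NUM_REQ_SLOTS, MAX_CTX, NUM_TOKENS = 4, 8, 16

# Layer 1: request slot -> token locations in the KV pool
req_to_token = [[0] * MAX_CTX for _ in range(NUM_REQ_SLOTS)]

# Layer 2: the allocator hands out indices into the KV buffers
free_indices = list(range(1, NUM_TOKENS + 1))  # slot 0 reserved for padding

def alloc(n):
    got, rest = free_indices[:n], free_indices[n:]
    free_indices[:] = rest
    return got

# Layer 3: the "KV buffers" themselves, one entry per token index
k_buffer = [None] * (NUM_TOKENS + 1)

# A request with 3 tokens: allocate KV slots, record them in its row
req_pool_idx = 0
locs = alloc(3)
req_to_token[req_pool_idx][:3] = locs
for i, loc in enumerate(locs):
    k_buffer[loc] = f"K(token {i})"   # the forward pass would write real K here
```

A lookup for token position 1 of request 0 thus goes req_to_token[0][1] -> allocator index -> k_buffer entry, which is exactly the indirection the diagram below depicts.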

Diagram A: SGLang Cache Architecture Overview
  Incoming request [1,2,3,4,5,6,7,8]
      -> RadixCache (radix tree): match_prefix() returns cached KV indices + last_node
      -> ReqToTokenPool: req_to_token[size, max_context_len] (torch.int32, on device),
         maps req_pool_idx -> token locations
      -> TokenToKVPoolAllocator: free_pages (torch.int64 tensor); alloc(need_size) -> indices,
         free(indices) -> return to pool
      -> MHATokenToKVPool: per-layer k_buffer[layer] / v_buffer[layer] of shape
         [size + page_size, num_kv_heads, head_dim] -- the actual GPU tensor storage
         in HBM, indexed by allocator indices

2. TreeNode -- The Radix Tree Node

radix_cache.py : TreeNode (line 121)

Every node in the radix tree is a TreeNode instance. The tree is a compressed trie where each edge stores a sequence of token IDs (not just a single token). This compression means a path like [1,2,3] is stored as a single node with key = RadixKey([1,2,3]) rather than three separate nodes.

class TreeNode:
    counter = 0

    def __init__(self, id=None, priority=0):
        self.children    = defaultdict(TreeNode)  # child_key -> TreeNode
        self.parent      = None                  # parent TreeNode
        self.key         = None                  # RadixKey (list of token_ids)
        self.value       = None                  # torch.Tensor of KV cache indices
        self.lock_ref    = 0                     # active request count (prevents eviction)
        self.pin_expiry  = 0.0                   # absolute expiry time for pinning
        self.pin_ttl     = 0                     # TTL in seconds for refresh-on-hit
        self.last_access_time = time.monotonic()  # for LRU eviction
        self.creation_time    = time.monotonic()  # for FIFO eviction
        self.hit_count   = 0                     # for LFU eviction
        self.host_ref_counter = 0               # protects host KV from eviction
        self.host_value  = None                  # CPU-side KV indices (HiRadixCache)
        self.hash_value  = None                  # SHA256 per-page hashes
        self.priority    = priority              # for priority-aware eviction

    @property
    def evicted(self):
        return self.value is None  # GPU value gone, but node may still have host_value

    @property
    def backuped(self):
        return self.host_value is not None  # has CPU backup (for HiRadixCache)

RadixKey -- The Namespace-Aware Key

Keys are wrapped in a RadixKey object which carries both the token_ids list and an optional extra_key string. The extra_key allows isolation of KV cache for different LoRA adapters or cache salts -- entries that share identical token IDs but have different extra_key values never share prefix nodes.

class RadixKey:
    def __init__(self, token_ids: List[int], extra_key: Optional[str] = None,
                 is_bigram: bool = False):
        self.token_ids = token_ids
        self.extra_key = extra_key   # e.g. lora_id, cache_salt
        self.is_bigram = is_bigram   # for EAGLE speculative decoding

Child Key Construction

When looking up children, the get_child_key() function constructs the dictionary key. For page_size=1, it is the first token ID. For page_size>1, it is a tuple of the first page_size tokens. If extra_key is set, it becomes a tuple (extra_key, plain_key) to ensure namespace isolation:

def get_child_key(key: RadixKey, page_size: int = 1):
    if page_size == 1:
        plain_key = key.token_ids[0]
    else:
        plain_key = tuple(key.token_ids[:page_size])
    if key.extra_key is None:
        return plain_key
    else:
        return (key.extra_key, plain_key)
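Pairing this function with a minimal RadixKey stub (a runnable sketch of the classes shown above, keeping only the fields get_child_key reads) makes the namespace isolation concrete:

```python
from typing import List, Optional

class RadixKey:
    # Minimal stub carrying only the fields get_child_key needs.
    def __init__(self, token_ids: List[int], extra_key: Optional[str] = None):
        self.token_ids = token_ids
        self.extra_key = extra_key

def get_child_key(key: RadixKey, page_size: int = 1):
    if page_size == 1:
        plain_key = key.token_ids[0]
    else:
        plain_key = tuple(key.token_ids[:page_size])
    if key.extra_key is None:
        return plain_key
    return (key.extra_key, plain_key)

# page_size=1: the first token is the dict key
k1 = get_child_key(RadixKey([1, 2, 3]))                      # 1
# page_size=4: the first page of tokens, as a tuple
k2 = get_child_key(RadixKey([1, 2, 3, 4, 5]), page_size=4)   # (1, 2, 3, 4)
# extra_key namespaces the entry, e.g. per LoRA adapter
k3 = get_child_key(RadixKey([1, 2, 3], extra_key="lora_A"))  # ("lora_A", 1)
```

Two requests with identical token IDs but different extra_key values land under different child keys, so they can never share a prefix node.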
Diagram B: Radix Tree Structure with Example Branches
  ROOT (key=[], lock_ref=1)
  |-- key=[1, 2]   value=[idx5, idx6]   ref=0   (internal node)
  |   |-- key=[3]  value=[idx7]         ref=1   LOCKED (active request)
  |   `-- key=[4, 5]  value=[idx8, idx9]  ref=0
  |       `-- key=[6, 7]  value=[idxA, idxB]  ref=0   EVICTABLE LEAF
  `-- key=[8, 9, 10, 11, 12]  value=[idxC..idxG]  ref=0

  Each edge represents a compressed token sequence; value is a tensor of KV
  cache allocator indices; node splitting occurs on partial match. A node with
  lock_ref > 0 has an active request; an evictable leaf has lock_ref=0 and no
  active children. Stored paths: [1,2,3] | [1,2,4,5,6,7] | [8,9,10,11,12].

3. RadixCache Initialization

radix_cache.py : RadixCache.__init__ (line 289)

The RadixCache class is initialized with a CacheInitParams object. It sets up the eviction strategy, the key matching function (optimized for page_size=1 vs paged), and initializes the root node:

class RadixCache(BasePrefixCache):
    def __init__(self, params: CacheInitParams):
        self.disable = params.disable
        self.req_to_token_pool = params.req_to_token_pool
        self.token_to_kv_pool_allocator = params.token_to_kv_pool_allocator
        self.page_size = params.page_size

        # Select key matching: token-by-token for page_size=1, paged otherwise
        if self.page_size == 1:
            self.key_match_fn = _key_match_page_size1
        else:
            self.key_match_fn = partial(_key_match_paged, page_size=self.page_size)

        # Select eviction strategy (7 options!)
        if self.eviction_policy == "lru":
            self.eviction_strategy = LRUStrategy()
        elif self.eviction_policy == "lfu":
            self.eviction_strategy = LFUStrategy()
        # ... also: fifo, mru, filo, priority, slru

        self.evictable_leaves = set()
        self.reset()

    def reset(self):
        self.root_node = TreeNode(priority=-sys.maxsize)
        self.root_node.key = RadixKey(token_ids=[], extra_key=None)
        self.root_node.value = []
        self.root_node.lock_ref = 1  # root is always locked
        self.evictable_size_ = 0
        self.protected_size_ = 0
Key Design Insight

The root node always has lock_ref = 1, making it permanently immune to eviction. Every other node's lock_ref is incremented when an active request traverses through it (via inc_lock_ref) and decremented when the request finishes (via dec_lock_ref).

4. match_prefix() -- Finding the Longest Cached Prefix

radix_cache.py : match_prefix (line 378) / _match_prefix_helper (line 671)

When a new request arrives, SGLang calls match_prefix() to find the longest prefix of the request's token IDs that already has KV cache stored in the tree. This is the core operation that enables prefix reuse.

The Public API

def match_prefix(self, params: MatchPrefixParams) -> MatchResult:
    key = params.key
    key, _ = self.maybe_bigram_convert(key)  # EAGLE bigram conversion if needed

    if self.disable or len(key) == 0:
        return empty_match_result()

    # For paged mode, truncate to page-aligned length
    if self.page_size != 1:
        page_aligned_len = len(key) // self.page_size * self.page_size
        key = key[:page_aligned_len]

    value, last_node = self._match_prefix_helper(self.root_node, key)
    if value:
        value = torch.cat(value)  # concatenate KV indices from matched nodes
    return MatchResult(device_indices=value, last_device_node=last_node, ...)

The Internal Tree Walk

The helper walks down the tree following matching children. At each node, it compares the node's key against the remaining query key using key_match_fn. If a match ends inside a node's key (partial match), it splits the node to create a precise boundary:

def _match_prefix_helper(self, node, key):
    access_time = time.monotonic()
    node.last_access_time = access_time  # update for LRU

    child_key = self.get_child_key_fn(key)
    value = []

    while len(key) > 0 and child_key in node.children.keys():
        child = node.children[child_key]
        child.last_access_time = access_time

        prefix_len = self.key_match_fn(child.key, key)  # how many tokens match?

        if prefix_len < len(child.key):
            # PARTIAL MATCH: split the node
            new_node = self._split_node(child.key, child, prefix_len)
            value.append(new_node.value)
            node = new_node
            break
        else:
            # FULL MATCH: consume this node and continue
            value.append(child.value)
            node = child
            key = key[prefix_len:]
            if len(key):
                child_key = self.get_child_key_fn(key)

    return value, node  # list of KV index tensors, last matched node

Node Splitting (_split_node)

When a match ends at position split_len inside a node, the node is split into two: a new parent holding the matched prefix, and the existing node retaining the unmatched suffix:

def _split_node(self, key, child, split_len):
    # Create new_node -> child
    new_node = TreeNode(priority=child.priority)
    new_node.children = {self.get_child_key_fn(key[split_len:]): child}
    new_node.parent = child.parent
    new_node.lock_ref = child.lock_ref  # inherit lock count
    new_node.key = child.key[:split_len]
    new_node.value = child.value[:split_len].clone()

    child.parent = new_node
    child.key = child.key[split_len:]
    child.value = child.value[split_len:].clone()

    new_node.parent.children[self.get_child_key_fn(key)] = new_node
    return new_node
Diagram C: Cache Hit/Miss Flow for match_prefix()
  Request tokens [1, 2, 3, 13, 14] -> match_prefix() walks the tree:
    match node [1,2]  (prefix_len=2, remaining [3,13,14])
    match node [3]    (prefix_len=1, remaining [13,14])
    no child for 13   -> STOP, end of cached prefix
  MatchResult: indices=[idx5,idx6,idx7] (3 tokens cached), last_node=node[3];
  only [13,14] still need prefill. On a full cache miss, the whole request
  is prefilled.

  Node splitting example: query [1,2] hits a node with key=[1,2,3].
  prefix_len=2 < len(node.key)=3, so parent -> [1,2,3] becomes
  parent -> new[1,2] -> old[3], with new_node.value = old.value[:2].clone()
  and old.value = old.value[2:].clone().

  Key matching: _key_match_page_size1 compares token by token and returns the
  first mismatch position; _key_match_paged compares in page_size chunks and
  returns a page-aligned mismatch position. Every cached token is one fewer
  token to prefill, saving attention computation that grows with context length.

5. insert() -- Adding New KV Cache Entries

radix_cache.py : insert (line 450) / _insert_helper (line 727)

After a request completes its prefill computation, the new KV cache entries are inserted into the tree via insert(). The insert operation walks the tree much like match_prefix, but it also creates new nodes for the unmatched suffix and deduplicates against existing entries.

def _insert_helper(self, node, key, value, priority=0, chunked=False):
    access_time = time.monotonic()
    node.last_access_time = access_time
    node.priority = max(node.priority, priority)  # propagate higher priority

    if len(key) == 0:
        return 0

    child_key = self.get_child_key_fn(key)
    total_prefix_length = 0

    # Walk down existing matching nodes
    while len(key) > 0 and child_key in node.children.keys():
        node = node.children[child_key]
        node.last_access_time = access_time
        prefix_len = self.key_match_fn(node.key, key)
        total_prefix_length += prefix_len
        key = key[prefix_len:]     # consume matched portion
        value = value[prefix_len:] # skip already-cached KV indices

        if prefix_len < len(node.key):
            # Partial match: split and stop
            new_node = self._split_node(node.key, node, prefix_len)
            self._inc_hit_count(new_node, chunked)
            node = new_node
        else:
            self._inc_hit_count(node, chunked)

        if len(key):
            child_key = self.get_child_key_fn(key)

    # Create new node for unmatched suffix
    if len(key):
        new_node = TreeNode(priority=priority)
        new_node.parent = node
        new_node.key = key
        new_node.value = value.clone()
        node.children[child_key] = new_node
        self.evictable_size_ += len(key)
        self._update_leaf_status(node)
        self._update_leaf_status(new_node)

    return total_prefix_length  # how much was already in the tree

cache_finished_req -- The Full Lifecycle

When a request finishes, cache_finished_req() orchestrates the insertion and deduplication:

def cache_finished_req(self, req, is_insert=True):
    token_ids = (req.origin_input_ids + req.output_ids)[:kv_committed_len]
    kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :len(token_ids)]

    if is_insert:
        result = self.insert(InsertParams(key=radix_key, value=values, ...))
        new_prefix_len = result.prefix_len
        # Free DUPLICATES: indices that were already in the tree
        self.token_to_kv_pool_allocator.free(
            kv_indices[req.cache_protected_len : new_prefix_len]
        )

    # Free unaligned tail (page_size > 1 can leave a partial page)
    self.token_to_kv_pool_allocator.free(kv_indices[len(keys):])

    # Release lock on the tree path
    self.dec_lock_ref(req.last_node)
Deduplication

The return value total_prefix_length from _insert_helper tells the caller how many KV indices at the beginning of the inserted sequence were already in the tree. Those duplicates are freed back to the allocator, avoiding memory waste.
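The accounting can be traced with concrete (toy) numbers; the slice boundaries below mirror the cache_finished_req excerpt above:

```python
# Worked example of deduplication accounting (toy numbers, not SGLang code).
# A finished request produced 8 tokens; its KV indices from req_to_token are:
kv_indices = [40, 41, 42, 43, 44, 45, 46, 47]
cache_protected_len = 2   # first 2 indices were matched at admission (already locked)
new_prefix_len = 5        # _insert_helper found 5 tokens already in the tree

# Indices in [2, 5) duplicate entries the tree already owns -> freed to allocator
duplicates = kv_indices[cache_protected_len:new_prefix_len]   # [42, 43, 44]

# Indices in [5, 8) are genuinely new and are now owned by the tree
newly_cached = kv_indices[new_prefix_len:]                    # [45, 46, 47]
```

Without this free, every request whose prefix was concurrently inserted by another request would leak three KV slots in this example; with it, the tree keeps exactly one copy of each cached token.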

6. Eviction Policies

evict_policy.py (all strategies) | radix_cache.py : evict (line 586)

SGLang supports seven different eviction strategies, all implementing the EvictionStrategy abstract class. Each strategy defines a get_priority() method; lower priority values are evicted first (using a min-heap).

Strategy   Priority Key                               Description
LRU        node.last_access_time                      Least Recently Used -- oldest access evicted first
LFU        (node.hit_count, node.last_access_time)    Least Frequently Used -- fewest hits first, then LRU as tiebreaker
FIFO       node.creation_time                         First In First Out -- oldest creation evicted first
MRU        -node.last_access_time                     Most Recently Used -- newest access evicted first
FILO       -node.creation_time                        First In Last Out -- newest creation evicted first
Priority   (node.priority, node.last_access_time)     Priority-aware -- lower priority evicted first, then LRU
SLRU       (segment, node.last_access_time)           Segmented LRU -- probationary items (hit_count < threshold) evicted before protected
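A minimal sketch shows how two of these priority keys produce different heap orders (illustrative code, not the evict_policy.py source):

```python
import heapq
from dataclasses import dataclass

@dataclass
class Node:
    name: str
    last_access_time: float = 0.0
    hit_count: int = 0

def lru_priority(node):
    return node.last_access_time

def lfu_priority(node):
    # Fewest hits first; LRU breaks ties, matching the table above.
    return (node.hit_count, node.last_access_time)

nodes = [
    Node("a", last_access_time=3.0, hit_count=9),
    Node("b", last_access_time=1.0, hit_count=5),
    Node("c", last_access_time=2.0, hit_count=1),
]

def eviction_order(priority_fn):
    heap = [(priority_fn(n), n.name) for n in nodes]
    heapq.heapify(heap)
    return [heapq.heappop(heap)[1] for _ in range(len(heap))]

lru_order = eviction_order(lru_priority)   # oldest access first: b, c, a
lfu_order = eviction_order(lfu_priority)   # fewest hits first:   c, b, a
```

Because every strategy reduces to "smaller priority tuple evicts first", the evict() loop below is policy-agnostic: swapping strategies changes only get_priority().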

The Eviction Process

def evict(self, params: EvictParams) -> EvictResult:
    num_tokens = params.num_tokens
    leaves = list(self.evictable_leaves)  # pre-tracked set of evictable leaves

    # Build min-heap using strategy priorities
    eviction_heap = [
        (self.eviction_strategy.get_priority(node), node)
        for node in leaves
    ]
    heapq.heapify(eviction_heap)

    num_evicted = 0
    while num_evicted < num_tokens and len(eviction_heap):
        _priority, x = heapq.heappop(eviction_heap)

        # Free the KV cache indices back to allocator
        self.token_to_kv_pool_allocator.free(x.value)
        num_evicted += len(x.value)

        # Remove leaf from tree
        self._delete_leaf(x)

        # If parent becomes a childless unlocked node, add to heap
        if len(x.parent.children) == 0 and x.parent.lock_ref == 0:
            new_priority = self.eviction_strategy.get_priority(x.parent)
            heapq.heappush(eviction_heap, (new_priority, x.parent))

    return EvictResult(num_tokens_evicted=num_evicted)

Evictable Leaf Tracking

Rather than scanning the entire tree, SGLang maintains an evictable_leaves set that is updated incrementally. A node is an evictable leaf when:

  1. It is not evicted (value is not None)
  2. It has lock_ref == 0 (no active requests using it)
  3. All its children are evicted (it is effectively a leaf in the non-evicted subtree)
def _update_leaf_status(self, node):
    if node.evicted or node.lock_ref > 0:
        # Remove from evictable set if present
        if node in self.evictable_leaves:
            self.evictable_leaves.remove(node)
        return

    # Check if any child is NOT evicted
    for child in node.children.values():
        if not child.evicted:
            if node in self.evictable_leaves:
                self.evictable_leaves.remove(node)
            return

    # All children evicted (or no children) -> this is an evictable leaf
    if node not in self.evictable_leaves:
        self.evictable_leaves.add(node)

7. Reference Counting -- Protecting Active Requests

radix_cache.py : inc_lock_ref (line 615) / dec_lock_ref (line 630)

The reference counting system prevents eviction of cache entries that are being actively used by in-flight requests. When a request matches a prefix and starts processing, inc_lock_ref() is called on the terminal node, and it propagates upward to the root:

def inc_lock_ref(self, node) -> IncLockRefResult:
    delta = 0
    while node != self.root_node:
        if node.lock_ref == 0:
            # Transitioning from evictable to protected
            self.evictable_size_ -= len(node.key)
            self.protected_size_ += len(node.key)
            delta -= len(node.key)
        node.lock_ref += 1
        self._update_leaf_status(node)
        node = node.parent
    return IncLockRefResult(delta=delta)

def dec_lock_ref(self, node, ...) -> DecLockRefResult:
    delta = 0
    while node != self.root_node:
        if node.lock_ref == 1:
            # Transitioning from protected to evictable
            self.evictable_size_ += len(node.key)
            self.protected_size_ -= len(node.key)
            delta += len(node.key)
        node.lock_ref -= 1
        self._update_leaf_status(node)
        node = node.parent
    return DecLockRefResult(delta=delta)
Why Walk to Root?

Locking propagates from the terminal matched node all the way up to the root because eviction works bottom-up (leaf nodes first). If a child is locked, the parent must also be protected -- otherwise evicting the parent would leave the child's KV cache in an inconsistent state. The sizes evictable_size_ and protected_size_ are tracked incrementally so the scheduler can make memory-aware decisions without scanning the tree.
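A toy chain of nodes demonstrates the upward propagation and the incremental size accounting (a sketch with the same field names as the source, stripped to the essentials):

```python
# Toy model of lock_ref propagation (illustrative, not SGLang code).
class Node:
    def __init__(self, key_len, parent=None):
        self.key_len = key_len    # len(node.key) in the real tree
        self.parent = parent
        self.lock_ref = 0

root = Node(0)
root.lock_ref = 1                 # root is always locked
mid = Node(3, parent=root)        # edge of 3 tokens
leaf = Node(2, parent=mid)        # edge of 2 tokens

evictable_size, protected_size = 5, 0   # all 5 cached tokens start evictable

def inc_lock_ref(node):
    global evictable_size, protected_size
    while node is not root:
        if node.lock_ref == 0:    # evictable -> protected transition
            evictable_size -= node.key_len
            protected_size += node.key_len
        node.lock_ref += 1
        node = node.parent

inc_lock_ref(leaf)   # locks leaf AND mid on the way up, but stops at the root
```

After the call, all 5 tokens on the path are protected, and the scheduler can read evictable_size/protected_size in O(1) instead of walking the tree.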

8. Memory Pools: ReqToTokenPool & MHATokenToKVPool

memory_pool.py : ReqToTokenPool (line 123) / MHATokenToKVPool (line 733)

ReqToTokenPool -- Request to Token Mapping

This pool maps each request slot to the token positions it occupies in the KV cache. It is a 2D tensor of shape [size, max_context_len] stored as torch.int32 on the GPU device:

class ReqToTokenPool:
    def __init__(self, size, max_context_len, device, enable_memory_saver):
        self.req_to_token = torch.zeros(
            (size, max_context_len), dtype=torch.int32, device=device
        )
        self.free_slots = list(range(size))  # Python list of available slot indices

    def alloc(self, reqs: list[Req]) -> Optional[List[int]]:
        # Reuse existing slots for chunked prefill (continuing across chunks)
        reusing = [i for i, r in enumerate(reqs) if r.req_pool_idx is not None]
        need_size = len(reqs) - len(reusing)
        if need_size > len(self.free_slots):
            return None  # not enough request slots
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        # Assign req_pool_idx to each request
        offset = 0
        for r in reqs:
            if r.req_pool_idx is None:
                r.req_pool_idx = select_index[offset]
                offset += 1
        return [r.req_pool_idx for r in reqs]

    def free(self, req):
        self.free_slots.append(req.req_pool_idx)
        req.req_pool_idx = None
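The slot-reuse behavior for chunked prefill can be exercised with plain-Python stand-ins (Req here is a minimal stub, not the SGLang Req class):

```python
# Sketch of ReqToTokenPool.alloc's slot reuse across prefill chunks.
class Req:
    def __init__(self):
        self.req_pool_idx = None

free_slots = [0, 1, 2, 3]

def alloc(reqs):
    reusing = [r for r in reqs if r.req_pool_idx is not None]
    need_size = len(reqs) - len(reusing)
    if need_size > len(free_slots):
        return None                       # not enough request slots
    select, rest = free_slots[:need_size], free_slots[need_size:]
    free_slots[:] = rest
    it = iter(select)
    for r in reqs:
        if r.req_pool_idx is None:        # only fresh requests get new slots
            r.req_pool_idx = next(it)
    return [r.req_pool_idx for r in reqs]

a, b = Req(), Req()
first = alloc([a, b])    # both need fresh slots -> [0, 1]
second = alloc([a])      # next chunk of a: keeps slot 0, nothing new allocated
```

This is why a request's req_to_token row stays stable across prefill chunks: the slot index survives until free() resets req_pool_idx to None.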

MHATokenToKVPool -- The Actual KV Tensor Storage

This is where the GPU memory physically lives. For Multi-Head Attention (MHA), it stores separate K and V buffers per layer:

class MHATokenToKVPool(KVCache):
    def _create_buffers(self):
        # [size + page_size, head_num, head_dim] for each layer
        # The padded slot 0 is used for dummy outputs from padded tokens
        self.k_buffer = [
            torch.zeros(
                (self.size + self.page_size, self.head_num, self.head_dim),
                dtype=self.store_dtype, device=self.device
            )
            for _ in range(self.layer_num)
        ]
        self.v_buffer = [
            torch.zeros(
                (self.size + self.page_size, self.head_num, self.v_head_dim),
                dtype=self.store_dtype, device=self.device
            )
            for _ in range(self.layer_num)
        ]

    def set_kv_buffer(self, layer, loc, cache_k, cache_v, ...):
        # Called during forward pass to store computed K/V into cache
        # Uses optimized store_cache JIT kernel when possible
        layer_id = layer.layer_id
        _set_kv_buffer_impl(
            cache_k, cache_v,
            self.k_buffer[layer_id - self.start_layer],
            self.v_buffer[layer_id - self.start_layer],
            loc, ...
        )
Diagram D: Memory Pool Layout
  ReqToTokenPool  (req_to_token: torch.int32[size, max_context_len])

    slot   pos 0   pos 1   pos 2   pos 3   pos 4
      0      42      43     108     109     255
      1      42      43      44      77      78
      2                (free slot)

  Rows 0 and 1 share indices 42,43 -- the same prefix from the RadixCache. The
  indices point into the per-layer k_buffer/v_buffer tensors of MHATokenToKVPool
  on the GPU (k_buffer[L]: [size+page_size, H, D], v_buffer[L]: [size+page_size,
  H, V_D], same indexing), so a shared prefix is stored only once.

  TokenToKVPoolAllocator:
    free_pages = [5, 6, 10, 11, 15, 16, ...]
    alloc(3)       -> [5, 6, 10];  free_pages = [11, 15, 16, ...]
    free([42, 43]) -> free_pages = [42, 43, 11, 15, 16, ...]
    release_pages holds previously freed indices, unsorted; merge_and_sort_free()
    merges and sorts both, and the need_sort flag controls when the merge happens.

9. TokenToKVPoolAllocator -- Allocation & Deallocation

allocator.py : TokenToKVPoolAllocator (line 117) / PagedTokenToKVPoolAllocator (line 356)

SGLang provides two allocator variants. The base TokenToKVPoolAllocator operates at token granularity (page_size=1), while PagedTokenToKVPoolAllocator operates at page granularity for higher efficiency with paged attention.

TokenToKVPoolAllocator (page_size=1)

class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    def __init__(self, size, dtype, device, kvcache, need_sort):
        super().__init__(size, 1, dtype, device, kvcache, need_sort)
        self.clear()

    def clear(self):
        # Slot 0 is reserved as a padding dummy slot
        self.free_pages = torch.arange(1, self.size + 1,
                                        dtype=torch.int64, device=self.device)
        self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)

    def alloc(self, need_size):
        # If sorted mode and insufficient free pages, merge with release_pages
        if self.need_sort and need_size > len(self.free_pages):
            self.merge_and_sort_free()

        if need_size > len(self.free_pages):
            return None  # OOM

        select_index = self.free_pages[:need_size]
        self.free_pages = self.free_pages[need_size:]
        return select_index

    def free(self, free_index):
        if free_index.numel() == 0:
            return
        if self.is_not_in_free_group:
            if self.need_sort:
                # Deferred: add to release_pages (will be merged later)
                self.release_pages = torch.cat((self.release_pages, free_index))
            else:
                # Immediate: prepend to free_pages
                self.free_pages = torch.cat((free_index, self.free_pages))
        else:
            # Batched free: accumulate in group, apply at free_group_end()
            self.free_group.append(free_index)

PagedTokenToKVPoolAllocator -- Triton-Accelerated Page Allocation

For paged operation, the allocator uses Triton kernels for batch allocation during extend and decode phases. The alloc_extend_kernel handles three parts: filling the old partial page, allocating full new pages, and filling a new partial page at the end.

class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    def alloc_extend(self, prefix_lens, ..., seq_lens, ..., last_loc, extend_num_tokens):
        out_indices = torch.empty((extend_num_tokens,), dtype=torch.int64, device=...)

        # Triton kernel computes per-request page allocation in parallel on GPU
        alloc_extend_kernel[(bs,)](
            prefix_lens, seq_lens, last_loc,
            self.free_pages, out_indices,
            next_power_of_2(bs), self.page_size
        )
        num_new_pages = get_num_new_pages(seq_lens=..., page_size=..., prefix_lens=...)
        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def free(self, free_index):
        # Convert token indices back to page indices
        free_page_indices = torch.unique(free_index // self.page_size)
        self.free_pages = torch.cat((free_page_indices, self.free_pages))
Free Group Batching

The allocator supports batched free via free_group_begin()/free_group_end(). Between these calls, all free() operations accumulate in a list and are applied as a single torch.cat at the end, reducing the number of tensor concatenation operations.
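The pattern can be sketched with lists in place of torch tensors (method names follow the source, but this is an illustrative model):

```python
# Sketch of the free-group pattern: defer frees, apply them in one batch.
class Allocator:
    def __init__(self, free_pages):
        self.free_pages = list(free_pages)
        self.is_not_in_free_group = True
        self.free_group = []

    def free(self, idxs):
        if self.is_not_in_free_group:
            self.free_pages = idxs + self.free_pages   # immediate prepend
        else:
            self.free_group.append(idxs)               # deferred

    def free_group_begin(self):
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        self.is_not_in_free_group = True
        merged = [i for batch in self.free_group for i in batch]
        if merged:
            self.free(merged)                          # one concatenation

alloc = Allocator([9, 10])
alloc.free_group_begin()
alloc.free([1, 2])
alloc.free([3])
alloc.free_group_end()    # free_pages -> [1, 2, 3, 9, 10]
```

With torch tensors, each immediate free is a torch.cat allocating a new tensor; batching N frees into one cat turns N concatenations into one.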

Sorted vs Unsorted Mode

The need_sort flag controls whether freed indices go to release_pages (deferred, sorted at merge time) or directly to free_pages. Sorted allocation is important for some attention backends that benefit from locality of KV cache indices in GPU memory.
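The deferred-merge behavior can be modeled as follows (lists in place of tensors; a hypothetical stand-in for the allocator.py logic, which merges only when alloc runs short):

```python
# Sketch of sorted mode: out-of-order frees accumulate in release_pages
# and are folded into a sorted free_pages lazily, at allocation time.
free_pages = [20, 21, 22]
release_pages = [7, 3, 15]      # freed indices, unsorted, not yet allocatable

def merge_and_sort_free():
    global free_pages, release_pages
    free_pages = sorted(free_pages + release_pages)
    release_pages = []

def alloc(need_size):
    global free_pages
    if need_size > len(free_pages):
        merge_and_sort_free()   # only merge when we actually run short
    if need_size > len(free_pages):
        return None             # genuine OOM
    got, free_pages = free_pages[:need_size], free_pages[need_size:]
    return got

got = alloc(5)   # forces a merge; hands out [3, 7, 15, 20, 21], leaving [22]
```

Deferring the sort keeps the common free() path cheap while still handing out low, contiguous-ish indices when the backend benefits from locality.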

10. HiRadixCache -- Three-Tier Hierarchical Caching

hiradix_cache.py : HiRadixCache (line 56)

The HiRadixCache extends RadixCache with a three-tier memory hierarchy: GPU (device) -> CPU (host) -> Storage (disk/remote). When GPU memory is exhausted, KV cache entries are demoted to CPU memory; when CPU memory is full, entries can be further spilled to persistent storage.

Initialization -- Setting Up the Hierarchy

class HiRadixCache(RadixCache):
    def __init__(self, params, server_args):
        self.kv_cache = params.token_to_kv_pool_allocator.get_kvcache()

        # Create CPU-side KV pool matching the GPU KV cache type
        if isinstance(self.kv_cache, MHATokenToKVPool):
            self.token_to_kv_pool_host = MHATokenToKVPoolHost(
                self.kv_cache, server_args.hicache_ratio,
                server_args.hicache_size, self.page_size, ...
            )
        elif isinstance(self.kv_cache, MLATokenToKVPool):
            self.token_to_kv_pool_host = MLATokenToKVPoolHost(...)

        # Cache controller manages async GPU<->CPU<->Storage transfers
        self.cache_controller = HiCacheController(
            params.token_to_kv_pool_allocator,
            self.token_to_kv_pool_host,
            self.page_size, self.tp_group,
            write_policy=server_args.hicache_write_policy,    # write_back | write_through
            storage_backend=server_args.hicache_storage_backend,
            ...
        )

        # Write-through: backup after hit_count reaches threshold
        self.write_through_threshold = (
            1 if server_args.hicache_write_policy == "write_through" else 2
        )
        self.load_back_threshold = 10  # minimum tokens to justify a load-back
        self.evictable_host_leaves = set()  # host-side evictable tracking

Eviction with Host Backup

The eviction process in HiRadixCache is significantly more complex than in the base RadixCache. Instead of simply freeing GPU memory, it can demote entries to CPU:

def evict(self, params) -> EvictResult:
    # Build eviction heap from evictable leaves (same as base)
    while num_evicted < num_tokens and len(eviction_heap):
        _priority, x = heapq.heappop(eviction_heap)

        if x.lock_ref > 0:
            continue

        if self._is_pinned(x):
            # Pinned node: demote to host, keep in tree
            if x.backuped:
                num_evicted += self._evict_backuped(x)  # just free GPU indices
            else:
                written = self.write_backup(x, write_back=True)
                ...

        elif not x.backuped:
            if self.cache_controller.write_policy == "write_back":
                # Write-back: backup to host before freeing GPU
                num_evicted += self.write_backup(x, write_back=True)
            else:
                # Write-through: already backed up, just free GPU
                num_evicted += self._evict_regular(x)
        else:
            num_evicted += self._evict_backuped(x)

def _evict_backuped(self, node):
    # GPU -> CPU demotion: free device indices but keep node in tree
    num_evicted = self.cache_controller.evict_device(node.value)
    node.value = None   # mark as evicted from GPU
    # node.host_value still has the CPU indices
    self._update_leaf_status(node)
    self._update_host_leaf_status(node)
    return num_evicted

Load-Back: CPU -> GPU Restoration

When a match_prefix() hits a node whose GPU data was evicted but still exists on CPU, load_back() restores it:

def load_back(self, node, mem_quota=None):
    # Walk up to find the nearest non-evicted ancestor
    nodes_to_load = []
    while node.evicted:
        assert node.backuped  # must have host backup
        nodes_to_load.insert(0, node)
        node = node.parent
    ancestor_node = node

    # Protect the ancestor from eviction during the load
    self.inc_lock_ref(ancestor_node)

    # Concatenate all host indices for batch transfer
    host_indices = torch.cat([n.host_value for n in nodes_to_load])

    # Skip if too small (< 10 tokens) or exceeds memory quota
    if len(host_indices) < self.load_back_threshold:
        return None

    # Allocate GPU space and initiate async DMA transfer
    device_indices = self.cache_controller.load(host_indices=host_indices, ...)
    if device_indices is None:
        # Not enough GPU memory, try evicting first
        self.evict(EvictParams(num_tokens=len(host_indices)))
        device_indices = self.cache_controller.load(...)

    # Restore node values from device_indices
    offset = 0
    for node in nodes_to_load:
        node.value = device_indices[offset : offset + len(node.host_value)].clone()
        offset += len(node.host_value)
    self.evictable_size_ += len(device_indices)

Write-Through: Hit Count Triggered Backup

With the write-through policy, nodes are proactively backed up to CPU after reaching a hit count threshold:

def _inc_hit_count(self, node, chunked=False):
    if self.cache_controller.write_policy == "write_back" or chunked:
        return
    node.hit_count += 1
    if not node.backuped:
        if node.hit_count >= self.write_through_threshold:
            self.write_backup(node)  # async GPU->CPU copy
Diagram E: Three-Tier HiRadixCache Architecture
  TIER 1: GPU (device) -- MHATokenToKVPool / MLATokenToKVPool; fastest, most limited.
    Node states: active (lock_ref > 0, PROTECTED from eviction); evictable
    (lock_ref == 0, value != None, can be evicted by LRU/LFU); backed up
    (value != None and host_value != None, SAFE to free on GPU); evicted
    (value == None, GPU freed, host_value still available).
    evict() / write_backup() demote entries via async DMA to CPU pinned memory.

  TIER 2: CPU (host) -- MHATokenToKVPoolHost / MLATokenToKVPoolHost; large, medium speed.
    Nodes hold host_value indices; host_ref_counter protects entries during
    async operations; the evictable_host_leaves set tracks CPU-evictable nodes.
    load_back() promotes entries back to GPU; write_backup_storage() spills
    asynchronously to disk/remote via the HiCacheController.

  TIER 3: Storage (persistent) -- optional disk, NFS, or nixl remote; unlimited, slowest.
    Blocks are keyed by SHA256 hashes and prefetched via hash_value matching,
    with prefetch policies best_effort | wait | timeout.

  Write policies: write_through (backup once hit_count reaches a threshold),
  write_back (backup only on eviction), write_through_selective.

Host Eviction

When even CPU memory is exhausted, evict_host() removes entries from CPU, fully deleting the tree node:

def evict_host(self, num_tokens):
    # Uses same eviction strategy but on evictable_host_leaves
    leaves = list(self.evictable_host_leaves)
    eviction_heap = [(self.eviction_strategy.get_priority(n), n) for n in leaves]
    heapq.heapify(eviction_heap)

    num_evicted = 0
    while num_evicted < num_tokens and eviction_heap:
        _priority, x = heapq.heappop(eviction_heap)
        if not x.evicted:   # only evict nodes already gone from GPU
            continue
        if x.host_ref_counter > 0:  # protected by ongoing prefetch/backup/pin
            continue

        num_evicted += self.cache_controller.evict_host(x.host_value)
        # Fully remove node from tree
        x.parent.children.pop(self.get_child_key_fn(x.key))

11. SessionAwareCache -- Streaming Session KV Management

session_aware_cache.py : SessionAwareCache (line 112)

The SessionAwareCache is a decorator (wrapper) around any BasePrefixCache implementation that adds support for streaming sessions. In streaming mode, KV cache is held across multiple request turns within the same session, avoiding re-computation between turns of a multi-turn conversation.

class SessionAwareCache(BasePrefixCache):
    def __init__(self, inner: BasePrefixCache):
        self.inner = inner                           # wrapped cache (RadixCache, HiRadixCache, etc.)
        self.slots: Dict[str, SessionSlot] = {}       # session_id -> slot

SessionSlot -- Per-Session KV State

@dataclass
class SessionSlot:
    virtual_node: _VirtualNode = field(default_factory=_VirtualNode)
    req_pool_idx: Optional[int] = None     # allocated request pool slot
    kv_committed_len: int = 0               # how many tokens have KV cached
    kv_allocated_len: int = 0               # how many tokens are allocated
    last_node: Any = None                   # radix tree node (for dec_lock_ref on close)
    cache_protected_len: int = 0            # overlap with radix tree

The key insight: for streaming sessions, match_prefix() restores KV state directly from the SessionSlot instead of walking the radix tree. The _VirtualNode sentinel makes inc_lock_ref/dec_lock_ref no-ops for session-managed KV, since the session itself keeps the KV alive:

def match_prefix(self, params) -> MatchResult:
    req = params.req
    if not _is_streaming(req):
        return self.inner.match_prefix(params)  # pass through for non-streaming

    slot = self.slots.get(req.session.session_id)
    if slot is None or slot.req_pool_idx is None:
        return self.inner.match_prefix(params)  # first turn: use radix tree

    # Subsequent turns: restore from session slot
    slot.restore_to_req(req)
    prefix_len = min(req.kv_committed_len, max(len(params.key.token_ids) - 1, 0))
    device_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :prefix_len]
    return MatchResult(device_indices=device_indices, last_device_node=slot.virtual_node, ...)

12. RadixAttention Layer -- Integrating Cache in Forward Pass

layers/radix_attention.py : RadixAttention (line 47)

The RadixAttention module is the attention layer used in SGLang model implementations. Despite its name, it does not directly reference the radix cache tree -- instead, it delegates to the attention backend, which reads from the KV cache buffers at the indices provided by the cache system.

class RadixAttention(nn.Module):
    def __init__(self,
        num_heads, head_dim, scaling, num_kv_heads, layer_id,
        logit_cap=0.0, v_head_dim=-1, sliding_window_size=-1,
        is_cross_attention=False, pos_encoding_mode="NONE", ...
    ):
        self.tp_q_head_num = num_heads
        self.tp_k_head_num = num_kv_heads
        self.head_dim = head_dim
        self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
        self.layer_id = layer_id
        self.sliding_window_size = sliding_window_size or -1

    def forward(self, q, k, v, forward_batch: ForwardBatch,
                save_kv_cache=True, **kwargs):
        if k is not None:
            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

        # For extend (prefill) mode with torch.compile context:
        if forward_batch.forward_mode.is_extend() and get_forward_context():
            output = torch.empty_like(q) if ... else ...
            unified_attention_with_output(q, k, v, output, save_kv_cache, self.layer_id)
            return output
        else:
            # Delegate to attention backend (FlashAttention, Triton, etc.)
            return forward_batch.attn_backend.forward(
                q, k, v, self, forward_batch, save_kv_cache, **kwargs
            )

How Cache Integration Works in Practice

The integration flow between the cache and the attention layer is orchestrated by the scheduler and attention backend:

  1. Scheduling: The scheduler calls match_prefix() on the radix cache to find cached KV indices for each request. These indices are written into req_to_token_pool.
  2. Forward Pass: During the forward pass, the attention backend reads K/V from MHATokenToKVPool.get_kv_buffer(layer_id) at the indices stored in req_to_token_pool.
  3. KV Storage: After computing attention for new tokens, set_kv_buffer() writes the new K/V into the pool. The save_kv_cache flag controls whether this happens (set to False for the last layer in speculative decoding verification).
  4. Cache Update: When the request finishes, cache_finished_req() inserts the full token sequence into the radix tree, deduplicating against existing entries and freeing unused KV indices.
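The four steps above can be sketched end to end with toy stand-ins for the real pools. Everything here (`toy_radix`, `free_indices`, `req_to_token`) is illustrative data, not the SGLang API; it only shows how indices flow from match to insert.

```python
# Hedged sketch of the scheduler -> forward -> cache-update flow.
toy_radix = {(1, 2, 3): [10, 11, 12]}   # cached prefix -> KV indices (toy tree)
free_indices = [13, 14, 15, 16]          # toy allocator free list
req_to_token = {}                        # toy ReqToTokenPool: req idx -> KV indices

# 1. Scheduling: match the prefix, write cached indices into req_to_token.
tokens = [1, 2, 3, 4, 5]
cached = toy_radix.get(tuple(tokens[:3]), [])
req_to_token[0] = list(cached)

# 2./3. Forward pass: allocate indices for the uncached suffix; the real
#       code writes new K/V there via set_kv_buffer().
need = len(tokens) - len(cached)
new_idx = [free_indices.pop(0) for _ in range(need)]
req_to_token[0].extend(new_idx)

# 4. Cache update: on finish, insert the full sequence into the tree so
#    future requests can reuse all five tokens' KV.
toy_radix[tuple(tokens)] = list(req_to_token[0])

print(req_to_token[0])   # [10, 11, 12, 13, 14]
```

The real system does the same bookkeeping with GPU tensors and a compressed trie instead of Python dicts and lists.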
# unified_attention_with_output -- the torch.compile-compatible entry point
@register_custom_op(mutates_args=["output"])
@register_split_op()
def unified_attention_with_output(query, key, value, output, save_kv_cache, layer_id, ...):
    context = get_forward_context()
    forward_batch = context.forward_batch
    attention_layer = context.attention_layers[layer_id]

    # The attn_backend.forward() handles:
    # 1. Writing K/V to cache (set_kv_buffer) for new tokens
    # 2. Reading K/V from cache for cached prefix tokens
    # 3. Computing attention output
    ret = forward_batch.attn_backend.forward(
        query, key, value, attention_layer, forward_batch, save_kv_cache, **kwargs
    )
    output.view(ret.shape).copy_(ret)
The Cache-Attention Contract

The RadixAttention layer itself is cache-unaware -- it simply passes data to the attention backend. The magic happens in the attention backend (e.g., FlashInfer, Triton), which knows about the KV cache layout because the ForwardBatch carries all the index information (from req_to_token_pool) and the MHATokenToKVPool provides the buffer handles. This clean separation means the radix cache can be swapped, extended, or disabled without touching the attention computation code.
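The contract can be sketched as a minimal interface. This is a toy illustration of the separation of concerns, with invented names (`ToyPool`, `ToyBackend`): the layer above never touches indices; only the backend combines the index mapping with the buffer handles.

```python
# Hedged sketch of the cache-attention contract (illustrative names).
from typing import Dict, List, Protocol

class ToyPool:
    """Stands in for MHATokenToKVPool: owns the buffers, indexed by allocator indices."""
    def __init__(self) -> None:
        self.k_buffer: Dict[int, float] = {}   # idx -> K value (scalar for the toy)

    def get(self, indices: List[int]) -> List[float]:
        return [self.k_buffer[i] for i in indices]

class AttnBackend(Protocol):
    """Stands in for the attention backend interface the layer delegates to."""
    def forward(self, q: float, indices: List[int], pool: ToyPool) -> float: ...

class ToyBackend:
    # Only the backend knows the cache layout; it reads K at the indices
    # that the scheduler wrote into req_to_token for this request.
    def forward(self, q: float, indices: List[int], pool: ToyPool) -> float:
        ks = pool.get(indices)
        return sum(k * q for k in ks)   # stand-in for the real attention math

pool = ToyPool()
pool.k_buffer = {10: 1.0, 11: 2.0}
print(ToyBackend().forward(0.5, [10, 11], pool))   # 1.5
```

Swapping `ToyBackend` for another implementation leaves the layer and the cache untouched, which is exactly the decoupling the real design achieves.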

Summary: Source File Map

| Component    | File                   | Key Classes/Functions                               | Role                                           |
|--------------|------------------------|-----------------------------------------------------|------------------------------------------------|
| Radix Tree   | radix_cache.py         | TreeNode, RadixKey, RadixCache                      | Token -> KV index mapping via compressed trie  |
| Memory Pool  | memory_pool.py         | ReqToTokenPool, MHATokenToKVPool                    | Request slot management, GPU KV tensor storage |
| Allocator    | allocator.py           | TokenToKVPoolAllocator, PagedTokenToKVPoolAllocator | Free-list management of KV cache indices       |
| Eviction     | evict_policy.py        | LRUStrategy, LFUStrategy, SLRUStrategy, ...         | 7 pluggable eviction policies                  |
| Hierarchical | hiradix_cache.py       | HiRadixCache                                        | GPU -> CPU -> Storage three-tier caching       |
| Sessions     | session_aware_cache.py | SessionAwareCache, SessionSlot                      | Multi-turn streaming session KV persistence    |
| Chunk Cache  | chunk_cache.py         | ChunkCache                                          | Fallback when RadixCache is disabled           |
| Attention    | radix_attention.py     | RadixAttention                                      | nn.Module layer delegating to attn backend     |