1. Language Primitives: gen(), select(), fork(), join(), image()
SGLang's frontend language defines a small set of powerful primitives in python/sglang/lang/api.py. Each primitive creates an IR node that the StreamExecutor later interprets.
1.1 sgl.gen() -- The Core Generation Primitive
The gen() function accepts all sampling parameters and returns either an SglGen (for free-form generation) or an SglSelect (when choices is provided):
def gen(
name: Optional[str] = None,
max_tokens: Optional[int] = None,
min_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
# ... 15+ more parameters
dtype: Optional[Union[type, str]] = None,
choices: Optional[List[str]] = None,
regex: Optional[str] = None,
json_schema: Optional[str] = None,
):
"""Call the model to generate."""
if choices:
# When choices are provided, gen() becomes a select()
return SglSelect(
name, choices,
0.0 if temperature is None else temperature,
token_length_normalized if choices_method is None else choices_method,
)
# Validate regex if provided
if regex is not None:
re.compile(regex) # Raises re.error if invalid
return SglGen(name, max_tokens, min_tokens, n, stop, ...)
The name parameter serves as the variable binding -- the generated text is stored in state[name], enabling later retrieval via s["answer"]. Structured generation constraints (regex, json_schema) are baked into the IR node and forwarded to the grammar backend at execution time.
1.2 Type-Safe Generation: gen_int(), gen_string()
Convenience wrappers that set the dtype field on the IR node. In the runtime endpoint, dtype is converted to a regex pattern:
REGEX_INT = r"[-+]?[0-9]+[ \n]*"
REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+[ \n]*"
REGEX_BOOL = r"(True|False)"
REGEX_STR = r"\"[\w\d\s]*\""
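A quick stdlib check shows what each pattern accepts (matches is a hypothetical helper; re.fullmatch requires the whole string to satisfy the constraint):

```python
import re

# Patterns copied from the runtime endpoint above.
REGEX_INT = r"[-+]?[0-9]+[ \n]*"
REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+[ \n]*"
REGEX_BOOL = r"(True|False)"

def matches(pattern, text):
    # fullmatch: the entire generated string must satisfy the constraint
    return re.fullmatch(pattern, text) is not None

print(matches(REGEX_INT, "-42\n"))   # trailing newline is allowed
print(matches(REGEX_INT, "3.14"))    # a float is not a valid int
print(matches(REGEX_FLOAT, "3.14"))
print(matches(REGEX_BOOL, "True"))
```

Note that the int and float patterns tolerate trailing whitespace, which lets the model emit a natural newline after the value.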
1.3 sgl.select() -- Constrained Choice
def select(
name: Optional[str] = None,
choices: Optional[List[str]] = None,
temperature: float = 0.0,
choices_method: ChoicesSamplingMethod = token_length_normalized,
):
assert choices is not None
return SglSelect(name, choices, temperature, choices_method)
select() evaluates log-probabilities for each choice string and picks the best one using token_length_normalized scoring by default. The backend computes conditional logprobs by sending each prefix + choice to the server with max_new_tokens=0.
1.4 sgl.image() and sgl.video()
def image(expr: SglExpr):
return SglImage(expr)
def video(path: str, num_frames: int):
return SglVideo(path, num_frames)
These encode image/video data as base64 and insert the chat template's image_token into the text stream. The encoded data is stored in StreamExecutor.images_.
1.5 Role Primitives: system(), user(), assistant()
def _role_common(name: str, expr: Optional[SglExpr] = None):
if expr is None:
return SglExprList([SglRoleBegin(name), SglRoleEnd(name)])
else:
return SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)])
These wrap content between role markers. When used as context managers in ProgramState, they support the with s.user(): syntax pattern.
2. How @sgl.function Decorated Functions Work
The @sgl.function decorator wraps a user function into an SglFunction object, which is the entry point for all SGLang programs.
class SglFunction:
def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
self.func = func
self.num_api_spec_tokens = num_api_spec_tokens
self.bind_arguments = bind_arguments or {}
# Parse arguments -- first arg MUST be 's' (the ProgramState)
argspec = inspect.getfullargspec(func)
assert argspec.args[0] == 's', 'The first argument must be "s"'
self.arg_names = argspec.args[1:]
self.arg_defaults = argspec.defaults if argspec.defaults is not None else []
def run(self, *args, max_new_tokens=128, temperature=1.0, ...):
from sglang.lang.interpreter import run_program
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens, temperature=temperature, ...
)
backend = backend or global_config.default_backend
return run_program(self, backend, args, kwargs, default_sampling_para, stream)
def __call__(self, *args, **kwargs):
# Detect if we are in a tracing scope
tracing_scope = TracingScope.get_current_scope()
if tracing_scope is None:
return self.run(*args, **kwargs) # Normal execution
else:
return self.trace(*args, **kwargs) # Tracing for prefix caching
- @sgl.function creates an SglFunction wrapping the user's Python function.
- Calling .run() (or .__call__) creates a default SglSamplingParams and invokes run_program().
- run_program creates a StreamExecutor and a ProgramState, then calls program.func(state, *args).
- Inside the function, s += expr calls ProgramState.__iadd__, which submits the IR node to the executor.
- The executor's background thread processes each node sequentially.
2.1 Batch Execution: run_batch()
def run_batch(self, batch_kwargs, *, num_threads="auto", ...):
# Auto-detect threads: max(96, cpu_count * 16)
if num_threads == "auto":
num_threads = max(96, multiprocessing.cpu_count() * 16)
# Pre-cache common prefix via tracing (if batch > 1)
if global_config.enable_precache_with_tracing and len(batch_kwargs) > 1:
cache_program(program, backend)
# Submit all programs to ThreadPoolExecutor
with ThreadPoolExecutor(num_threads) as executor:
futures = [executor.submit(run_program, ...) for kwargs in batch_kwargs]
rets = [f.result() for f in futures]
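The fan-out can be reduced to a self-contained sketch; run_program here is a stand-in that echoes its arguments rather than the real interpreter entry point:

```python
from concurrent.futures import ThreadPoolExecutor
import multiprocessing

def run_program(arguments):
    # Stand-in for sglang's run_program: pretend the program returns a dict
    return {"answer": arguments["x"] * 2}

batch_kwargs = [{"x": i} for i in range(4)]
# Same heuristic as run_batch's "auto" mode
num_threads = max(96, multiprocessing.cpu_count() * 16)
with ThreadPoolExecutor(min(num_threads, len(batch_kwargs))) as executor:
    futures = [executor.submit(run_program, kw) for kw in batch_kwargs]
    results = [f.result() for f in futures]
print(results)  # results come back in submission order
```

Because each program blocks on network I/O rather than CPU work, a thread count far above cpu_count is reasonable.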
3. IR Representation: How Language Constructs Become IR Nodes
The intermediate representation in python/sglang/lang/ir.py defines a tree of expression nodes. All nodes inherit from SglExpr and support concatenation via __add__.
3.1 The SglExpr Base Class
class SglExpr:
node_ct = 0 # Global counter for unique node IDs
def __init__(self):
self.node_id = SglExpr.node_ct
self.prev_node = None
self.pid = None
SglExpr.node_ct += 1
def __add__(self, other):
# "Hello " + sgl.gen("x") => SglExprList
if isinstance(other, str):
other = SglConstantText(other)
return self.concatenate_ir(self, other)
def concatenate_ir(self, a, b):
# Flatten nested SglExprLists
if isinstance(a, SglExprList):
if isinstance(b, SglExprList):
return SglExprList(a.expr_list + b.expr_list)
else:
return SglExprList(a.expr_list + [b])
return SglExprList([a, b])
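The flattening rule can be demonstrated with a stripped-down mimic, where Const and ExprList are hypothetical stand-ins for SglConstantText and SglExprList:

```python
class Expr:
    def __add__(self, other):
        if isinstance(other, str):
            other = Const(other)  # promote bare strings, as SglExpr does
        # Flatten nested lists so "a" + gen + "b" stays one flat sequence
        a = self.parts if isinstance(self, ExprList) else [self]
        b = other.parts if isinstance(other, ExprList) else [other]
        return ExprList(a + b)

class Const(Expr):
    def __init__(self, text):
        self.text = text

class ExprList(Expr):
    def __init__(self, parts):
        self.parts = parts

e = Const("Hello ") + "world" + Const("!")
print([p.text for p in e.parts])  # flat list, no nesting
```

Flattening keeps the executor's dispatch loop a simple linear walk instead of a tree traversal.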
3.2 Complete IR Node Hierarchy
| IR Node | Created By | Purpose |
|---|---|---|
SglConstantText | String literals, __add__ | Append constant text to prompt |
SglGen | sgl.gen() | Call model for generation; stores SglSamplingParams |
SglSelect | sgl.select() | Choose from fixed list using log-probabilities |
SglImage | sgl.image() | Embed image data (base64 encoded) |
SglVideo | sgl.video() | Embed video data (frame extraction) |
SglRoleBegin/End | sgl.user()/assistant() | Chat template role markers |
SglExprList | + operator | Sequence of expressions (flattened) |
SglFork | s.fork(n) | Create n parallel execution branches |
SglGetForkItem | Indexing fork result | Access a specific fork branch |
SglVariable | Internal | Reference to a named variable from another executor |
SglConcateAndAppend | join("concate_and_append") | Merge KV caches from forked branches |
SglCommitLazy | Internal | Force the backend to process pending text (KV cache warmup) |
SglSeparateReasoning | sgl.separate_reasoning() | Parse reasoning tokens from output (CoT models) |
3.3 SglSamplingParams: Frontend Sampling Configuration
@dataclasses.dataclass
class SglSamplingParams:
max_new_tokens: int = 128
min_new_tokens: int = 0
n: int = 1
stop: Union[str, List[str]] = ()
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1 # -1 means disable
min_p: float = 0.0
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
ignore_eos: bool = False
json_schema: Optional[str] = None
dtype: Optional[str] = None # for constrained gen
regex: Optional[str] = None # for constrained gen
def to_srt_kwargs(self):
# Convert to server-side format
return {
"max_new_tokens": self.max_new_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
"regex": self.regex,
"json_schema": self.json_schema,
...
}
4. StreamExecutor: How the Interpreter Executes SGLang Programs
The StreamExecutor in python/sglang/lang/interpreter.py is the heart of SGLang's execution engine. It runs a background thread that processes IR nodes from a queue.
4.1 Initialization and Thread Worker
class StreamExecutor:
def __init__(self, backend, arguments, default_sampling_para, ...):
self.sid = uuid.uuid4().hex # Unique session ID
self.backend = backend
self.arguments = arguments
self.default_sampling_para = default_sampling_para
self.variables = {} # name -> value store
self.variable_event = {} # name -> threading.Event
self.meta_info = {} # name -> generation metadata
self.text_ = "" # Full prompt text accumulated
self.messages_ = [] # Chat messages (OpenAI format)
self.images_ = [] # Vision data
self.fork_start_text_pos = None # For fork/join tracking
# Background worker thread
self.queue = queue.Queue()
self.worker = threading.Thread(
target=contextvars.copy_context().run,
args=(self._thread_worker_func,)
)
self.worker.start()
4.2 The Dispatch Loop: _execute()
Every IR node type has a dedicated handler. The _execute method is the central dispatch:
def _execute(self, other):
if isinstance(other, str):
other = SglConstantText(other)
if isinstance(other, SglConstantText):
self._execute_fill(other.value) # Append text
elif isinstance(other, SglGen):
self._execute_gen(other) # Generate via backend
elif isinstance(other, SglSelect):
self._execute_select(other) # Select from choices
elif isinstance(other, SglExprList):
for x in other.expr_list:
self._execute(x) # Recurse into list
elif isinstance(other, SglRoleBegin):
self._execute_role_begin(other) # Apply chat template prefix
elif isinstance(other, SglRoleEnd):
self._execute_role_end(other) # Apply suffix, create message
elif isinstance(other, SglImage):
self._execute_image(other) # Base64 encode, add token
elif isinstance(other, SglConcateAndAppend):
if self.backend.support_concate_and_append:
self._execute_concatenate_and_append_kv_cache(other) # KV reuse
else:
self._execute_concatenate_and_append_text(other) # Text fallback
...
4.3 _execute_gen: The Generation Path
def _execute_gen(self, expr: SglGen):
sampling_params = self._resolve_sampling_params(expr.sampling_params)
name = expr.name
if not self.stream:
comp, meta_info = self.backend.generate(self, sampling_params=sampling_params)
self.text_ += comp # Append generated text
self.variables[name] = comp # Store in variable dict
self.meta_info[name] = meta_info
self.variable_event[name].set() # Signal that variable is ready
else:
generator = self.backend.generate_stream(self, sampling_params=sampling_params)
self.variables[name] = ""
for comp, meta_info in generator:
self.text_ += comp
self.variables[name] += comp
self.stream_var_event[name].set() # Notify streaming consumers
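The variable_event handshake is what makes s["answer"] block until generation completes; a minimal sketch with threading.Event (background_gen and get_var are hypothetical stand-ins for the backend call and ProgramState.get_var):

```python
import threading
import time

variables = {}
events = {"answer": threading.Event()}

def background_gen():
    time.sleep(0.05)               # pretend the backend is generating
    variables["answer"] = "42"
    events["answer"].set()         # signal: the variable is now ready

def get_var(name):
    events[name].wait()            # s["answer"] blocks here until ready
    return variables[name]

worker = threading.Thread(target=background_gen)
worker.start()
print(get_var("answer"))
worker.join()
```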
4.4 Fork/Join: Parallel Execution
def fork(self, size: int = 1, position_ids_offset=None):
# Commit pending text to KV cache first
if size > 1 and str(self.text_):
self.submit(SglCommitLazy())
self.sync()
# Create child executors sharing the same state
exes = [StreamExecutor(self.backend, self.arguments, ...) for _ in range(size)]
for i in range(size):
exes[i].variables = dict(self.variables) # Copy variables
exes[i].text_ = str(self.text_) # Copy prompt text
exes[i].messages_ = list(self.messages_) # Copy chat history
exes[i].fork_start_text_pos = len(self.text_) # Mark where fork began
return exes
Each forked executor gets a shallow copy of the parent's state (variables, text, messages) and its own background worker thread. The fork_start_text_pos records where the shared prefix ends, enabling the concate_and_append join mode to physically merge KV caches for the new text only.
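A toy sketch of the fork bookkeeping and the text-fallback join (Exe is hypothetical; the real concate_and_append mode merges KV caches server-side instead of strings):

```python
class Exe:
    def __init__(self, text="", variables=None):
        self.text_ = text
        self.variables = dict(variables or {})  # shallow copy, as in fork()
        self.fork_start_text_pos = None

    def fork(self, size):
        children = []
        for _ in range(size):
            c = Exe(str(self.text_), self.variables)
            c.fork_start_text_pos = len(self.text_)  # shared prefix ends here
            children.append(c)
        return children

parent = Exe("Shared prompt. ")
a, b = parent.fork(2)
a.text_ += "Branch A."
b.text_ += "Branch B."
# Join (text fallback): append only each child's new suffix to the parent
for c in (a, b):
    parent.text_ += c.text_[c.fork_start_text_pos:]
print(parent.text_)
```

Recording fork_start_text_pos is what lets the join skip re-appending the shared prefix.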
4.5 ProgramState: The User-Facing State Object
class ProgramState:
def __iadd__(self, other): # s += sgl.gen("x")
self.stream_executor.submit(other)
return self
def __getitem__(self, name): # s["answer"]
return self.get_var(name) # Blocks until variable_event is set
def fork(self, size):
stream_executors = self.stream_executor.fork(size)
states = [ProgramState(x) for x in stream_executors]
return ProgramStateGroup(states, self)
# Context manager for chat roles
def user(self, expr=None):
if expr is not None:
role_expr = SglExprList([SglRoleBegin("user"), expr, SglRoleEnd("user")])
self.stream_executor.submit(role_expr)
else:
@contextmanager
def role_scope():
self.stream_executor.submit(SglRoleBegin("user"))
yield
self.stream_executor.submit(SglRoleEnd("user"))
return role_scope()
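The context-manager mechanics can be reduced to a toy version (ToyState is a hypothetical stand-in for ProgramState; it records submissions instead of executing them):

```python
from contextlib import contextmanager

class ToyState:
    def __init__(self):
        self.nodes = []

    def submit(self, node):
        self.nodes.append(node)

    @contextmanager
    def user(self):
        # Everything submitted inside the with-block lands between the markers
        self.submit(("role_begin", "user"))
        yield
        self.submit(("role_end", "user"))

    def __iadd__(self, text):
        self.submit(("text", text))
        return self

s = ToyState()
with s.user():
    s += "What is 2 + 2?"
print(s.nodes)
```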
4.6 Parameter Resolution
The _resolve_sampling_params method merges per-call overrides with defaults:
def _resolve_sampling_params(self, sampling_params):
"""
Base = default_sampling_para from sgl.function.run()
Override = per-gen sampling_params from sgl.gen()
"""
clone = copy.deepcopy(self.default_sampling_para)
for item in ["max_new_tokens", "temperature", "top_p", "regex", ...]:
value = getattr(sampling_params, item, None)
if value is not None:
setattr(clone, item, value)
# Add chat template stop strings
if self.chat_template.stop_str:
clone.stop += self.chat_template.stop_str
return clone
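The override-if-set merge can be sketched with a plain dataclass (Params and resolve are hypothetical stand-ins for SglSamplingParams and _resolve_sampling_params):

```python
import copy
import dataclasses
from typing import Optional

@dataclasses.dataclass
class Params:
    max_new_tokens: Optional[int] = None
    temperature: Optional[float] = None
    regex: Optional[str] = None

def resolve(default: Params, override: Params) -> Params:
    clone = copy.deepcopy(default)
    for f in dataclasses.fields(Params):
        value = getattr(override, f.name)
        if value is not None:  # None means "not set per-call", so keep default
            setattr(clone, f.name, value)
    return clone

default = Params(max_new_tokens=128, temperature=1.0)
merged = resolve(default, Params(temperature=0.0, regex=r"[0-9]+"))
print(merged)
```

Deep-copying the defaults first means a per-gen override can never mutate the shared defaults of other gen() calls.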
5. Runtime Endpoint: concatenate_and_append for KV Cache Reuse
The RuntimeEndpoint in python/sglang/lang/backend/runtime_endpoint.py bridges the frontend language to the SGLang server over HTTP.
5.1 generate(): Sending Requests to the Server
def generate(self, s: StreamExecutor, sampling_params: SglSamplingParams):
self._handle_dtype_to_regex(sampling_params) # Convert dtype -> regex
data = {
"text": s.text_, # Full accumulated text
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
**sampling_params.to_srt_kwargs(), # All sampling params
},
}
self._add_images(s, data) # Attach base64 image data if present
res = http_request(self.base_url + "/generate", json=data)
obj = res.json()
return obj["text"], obj["meta_info"]
Because every generate() call sends the full accumulated prompt (s.text_), the server's RadixAttention cache can automatically match and reuse the KV cache for any shared prefix. This is the key to SGLang's efficiency for multi-turn conversations and fork/join patterns.
5.2 concatenate_and_append: Physical KV Cache Merging
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
"""Merge KV caches from multiple forked branches back into parent."""
res = http_request(
self.base_url + "/concate_and_append_request",
json={"src_rids": src_rids, "dst_rid": dst_rid},
)
self._assert_success(res)
This is called during join("concate_and_append") mode, where the server physically concatenates the KV cache entries from multiple forked branches (src_rids) and appends them to the parent session (dst_rid). The executor-side implementation:
def _execute_concatenate_and_append_kv_cache(self, expr):
self_len = len(self.text_)
# Commit all children's pending operations
for s in expr.states:
s.stream_executor.submit(SglCommitLazy())
for s in expr.states:
s.stream_executor.sync()
self.text_ += s.stream_executor.text_[s.stream_executor.fork_start_text_pos:]
# Ask the server to merge KV caches
src_rids = [state.stream_executor.sid for state in expr.states]
self.backend.concatenate_and_append(src_rids, self.sid)
5.3 select(): Log-Probability Based Selection
def select(self, s, choices, temperature, choices_method):
# Step 1: Cache the prefix (max_new_tokens=0)
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
obj = self._generate_http_request(s, data)
prompt_len = obj["meta_info"]["prompt_tokens"]
# Step 2: Compute logprobs for each choice
data = {
"text": [s.text_ + c for c in choices], # Batch all choices
"sampling_params": {"max_new_tokens": 0},
"return_logprob": True,
}
obj = self._generate_http_request(s, data)
# Step 3: Apply the selection method (e.g., token_length_normalized)
return choices_method(choices=choices, ...)
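Step 3's default scorer can be sketched as follows; this illustrates length-normalized scoring, not the exact ChoicesSamplingMethod implementation:

```python
def token_length_normalized(choices, token_logprobs):
    # Score = mean token logprob, so longer choices aren't penalized
    # merely for having more tokens
    scores = [sum(lps) / len(lps) for lps in token_logprobs]
    best = max(range(len(choices)), key=lambda i: scores[i])
    return choices[best]

choice = token_length_normalized(
    ["Paris", "The city of Paris"],
    [[-0.5], [-0.2, -0.3, -0.1, -0.4]],  # hypothetical per-token logprobs
)
print(choice)
```

Without normalization, summed logprobs would systematically favor the shortest choice.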
5.4 The Runtime Wrapper Class
The Runtime class wraps server lifecycle management. It spawns the server in a subprocess and polls /health_generate until ready:
class Runtime:
def __init__(self, log_level="error", launch_timeout=300.0, *args, **kwargs):
self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
# Find available port
for port in range(self.server_args.port, 40000):
if is_port_available(port): break
# Launch server in subprocess
proc = ctx.Process(target=launch_server, args=(self.server_args,))
proc.start()
# Poll /health_generate until ready
while time.time() - start_time < launch_timeout:
response = session.get(f"{self.url}/health_generate")
if response.status_code == 200: break
self.endpoint = RuntimeEndpoint(self.url)
6. Structured Generation: Grammar Backends
SGLang supports three grammar backends for structured output generation: xgrammar (default), llguidance, and outlines. Each compiles constraints (JSON schema, regex, EBNF) into a token-level mask applied during sampling.
6.1 BaseGrammarBackend: The Backend Interface
class BaseGrammarBackend:
def __init__(self):
self.executor = ThreadPoolExecutor() # Async grammar compilation
self.cache: Dict[Tuple[str,str], BaseGrammarObject] = {}
def get_cached_or_future_value(self, key, require_reasoning):
value = self.cache.get(key)
if value:
copied = value.copy() # Return cached copy (with fresh matcher state)
return copied, True
# Submit async compilation
value = self.executor.submit(self._init_value_dispatch, key, require_reasoning)
return value, False # Returns Future
def _init_value_dispatch(self, key, require_reasoning):
key_type, key_string = key
if key_type == "json": grammar = self.dispatch_json(key_string)
elif key_type == "regex": grammar = self.dispatch_regex(key_string)
elif key_type == "ebnf": grammar = self.dispatch_ebnf(key_string)
elif key_type == "structural_tag": grammar = self.dispatch_structural_tag(key_string)
return grammar
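The cache-or-Future pattern can be sketched in isolation (TinyGrammarCache is hypothetical; real compilation returns a BaseGrammarObject rather than a string):

```python
from concurrent.futures import ThreadPoolExecutor

class TinyGrammarCache:
    def __init__(self):
        self.executor = ThreadPoolExecutor()
        self.cache = {}

    def _compile(self, key):
        kind, spec = key
        return f"compiled-{kind}:{spec}"  # stand-in for grammar compilation

    def get(self, key):
        if key in self.cache:
            return self.cache[key], True  # hit: a ready grammar object
        # Miss: kick off async compilation and hand back a Future
        return self.executor.submit(self._compile, key), False

    def set(self, key, grammar):
        self.cache[key] = grammar

cache = TinyGrammarCache()
value, hit = cache.get(("regex", "[0-9]+"))
grammar = value.result()                  # scheduler polls .done(), then .result()
cache.set(("regex", "[0-9]+"), grammar)
value2, hit2 = cache.get(("regex", "[0-9]+"))
print(hit, hit2, grammar)
```

The caller therefore has to branch on the hit flag: a hit is usable immediately, while a miss goes into a polling queue until the Future resolves.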
6.2 BaseGrammarObject: The Per-Request Grammar State
class BaseGrammarObject:
def accept_token(self, token: int): # Advance grammar state
def rollback(self, k: int): # Undo k tokens
def is_terminated(self): # Grammar accepted
def allocate_vocab_mask(self, V, B, dev): # Create mask tensor
def fill_vocab_mask(self, mask, idx): # Fill mask for position idx
def move_vocab_mask(mask, device): # CPU -> GPU transfer
def apply_vocab_mask(logits, mask): # Mask invalid tokens in logits
def try_jump_forward(self, tokenizer): # Jump-ahead optimization
def copy(self): # Clone for caching
6.3 XGrammar Backend (Default)
XGrammar uses a compiled grammar approach with efficient bitmask operations:
class XGrammarGrammarBackend(BaseGrammarBackend):
def __init__(self, tokenizer, vocab_size, model_eos_token_ids, any_whitespace):
# Create tokenizer info for xgrammar
tokenizer_info = TokenizerInfo.from_huggingface(
tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids
)
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
def dispatch_json(self, key_string):
if key_string == "$$ANY$$":
ctx = self.grammar_compiler.compile_builtin_json_grammar()
else:
ctx = self.grammar_compiler.compile_json_schema(
schema=key_string, any_whitespace=self.any_whitespace
)
matcher = GrammarMatcher(ctx, max_rollback_tokens=200)
return XGrammarGrammar(matcher, self.vocab_size, ctx, ...)
The mask operations use Triton (NVIDIA) or sgl_kernel (AMD) for GPU-side bitmask application:
class XGrammarGrammar(BaseGrammarObject):
def allocate_vocab_mask(self, vocab_size, batch_size, device):
return allocate_token_bitmask(batch_size, vocab_size) # Compact bitmask
def fill_vocab_mask(self, vocab_mask, idx):
self.matcher.fill_next_token_bitmask(vocab_mask, idx) # CPU-side fill
def apply_vocab_mask(self, logits, vocab_mask):
# GPU-accelerated bitmask application
if _is_hip:
apply_token_bitmask_inplace_cuda(logits, vocab_mask)
else:
apply_token_bitmask_inplace_triton(logits, vocab_mask)
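The packed-bitmask idea can be illustrated in pure Python, one bit per vocabulary token; the real kernels operate on int32 tensors on the GPU:

```python
NEG_INF = float("-inf")

def allocate_bitmask(vocab_size):
    # One bit per token, packed into 32-bit words
    return [0] * ((vocab_size + 31) // 32)

def allow(mask, token_id):
    mask[token_id // 32] |= 1 << (token_id % 32)

def apply_bitmask(logits, mask):
    # Tokens whose bit is unset are masked to -inf before sampling
    return [
        x if (mask[i // 32] >> (i % 32)) & 1 else NEG_INF
        for i, x in enumerate(logits)
    ]

mask = allocate_bitmask(8)
for t in (1, 5):  # suppose the grammar currently allows only tokens 1 and 5
    allow(mask, t)
print(apply_bitmask([0.1] * 8, mask))
```

The packed layout is 32x smaller than a bool tensor, which keeps the per-step CPU-to-GPU transfer cheap.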
6.4 LLGuidance Backend
class GuidanceBackend(BaseGrammarBackend):
def __init__(self, tokenizer, any_whitespace, whitespace_pattern):
self.llguidance_tokenizer = from_tokenizer(self.tokenizer, n_vocab)
def dispatch_json(self, key_string):
serialized = LLMatcher.grammar_from_json_schema(
key_string,
defaults={"whitespace_flexible": self.any_whitespace}
)
return GuidanceGrammar(self.llguidance_tokenizer, serialized)
def dispatch_regex(self, key_string):
serialized = grammar_from("regex", key_string)
return self._from_serialized(serialized)
class GuidanceGrammar(BaseGrammarObject):
def try_jump_forward(self, tokenizer):
ff_tokens = self.ll_matcher.compute_ff_tokens() # Fast-forward tokens
if ff_tokens: return ff_tokens, ""
return None
6.5 Outlines Backend
class OutlinesGrammarBackend(BaseGrammarBackend):
def _compile_regex(self, regex):
guide = RegexGuide.from_regex(regex, self.outlines_tokenizer)
return OutlinesGrammar(guide, jump_forward_map=None)
def dispatch_json(self, key_string):
regex = build_regex_from_schema(key_string, self.whitespace_pattern)
return self._compile_regex(regex)
class OutlinesGrammar(BaseGrammarObject):
def fill_vocab_mask(self, vocab_mask, idx):
tokens = torch.tensor(self.guide.get_next_instruction(self.state).tokens)
vocab_mask[idx].fill_(1)
vocab_mask[idx].scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
def apply_vocab_mask(logits, vocab_mask):
logits.masked_fill_(vocab_mask, float("-inf")) # Bool mask approach
6.6 Backend Comparison
| Feature | XGrammar (default) | LLGuidance | Outlines |
|---|---|---|---|
| Mask Format | Compact bitmask (packed bits) | Compact bitmask | Bool tensor (B x V) |
| GPU Application | Triton/sgl_kernel (bitwise) | llguidance.torch apply | torch.masked_fill_ |
| JSON Schema | compile_json_schema() | grammar_from_json_schema() | build_regex_from_schema() |
| Regex | compile_regex() | grammar_from("regex", ...) | RegexGuide.from_regex() |
| EBNF | compile_grammar() | grammar_from("ebnf", ...) | Not supported |
| Structural Tags | compile_structural_tag() | StructTag.to_grammar() | Not supported |
| Jump-Forward | find_jump_forward_string() | compute_ff_tokens() | jump_forward_byte() |
| Rollback | matcher.rollback(k) | ll_matcher.rollback(k) | state = next_state (FSM) |
6.7 GrammarManager: Async Compilation Pipeline
class GrammarManager:
def process_req_with_grammar(self, req):
# Determine grammar key from request params
if req.sampling_params.json_schema: key = ("json", req.sampling_params.json_schema)
elif req.sampling_params.regex: key = ("regex", req.sampling_params.regex)
elif req.sampling_params.ebnf: key = ("ebnf", req.sampling_params.ebnf)
value, cache_hit = self.grammar_backend.get_cached_or_future_value(key, ...)
req.grammar = value # Either BaseGrammarObject or Future
if not cache_hit:
self.grammar_queue.append(req) # Queue for polling
def get_ready_grammar_requests(self):
"""Poll futures in grammar_queue, return ready requests."""
ready_req_idxs = set()
while time.perf_counter() - start < POLL_INTERVAL:
for i, req in enumerate(self.grammar_queue):
if req.grammar.done():
ready_req_idxs.add(i)
# For multi-rank: sync ready sets across all DP/TP ranks
if self.grammar_sync_size > 1:
torch.distributed.all_gather_object(...)
synced_ready = set.intersection(...)
# Resolve futures and cache results
for i in synced_ready:
req.grammar = req.grammar.result() # Future -> GrammarObject
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
return return_reqs
7. How Grammar Mask Generation is Overlapped with GPU Inference
SGLang overlaps the CPU-bound grammar mask computation with the GPU model forward pass. This is achieved through the batch processing pipeline in the scheduler and SamplingBatchInfo.
7.1 The Overlap Strategy
- Before GPU forward: update_regex_vocab_mask() is called on the CPU. For each request in the batch, fill_vocab_mask() populates the bitmask based on the current grammar state.
- During GPU forward: the mask tensor is transferred to the GPU via move_vocab_mask(device) with non_blocking=True, overlapping the copy with the model computation.
- After GPU forward: apply_logits_bias(logits) applies the mask to the logits before sampling. Invalid tokens are set to -inf (outlines) or bitmasked out (xgrammar/llguidance).
- After sampling: accept_token(token_id) advances each grammar's state for the next iteration.
7.2 SamplingBatchInfo: The Mask Pipeline
def update_regex_vocab_mask(self):
if not self.grammars:
self.vocab_mask = None
return
first_grammar = next(g for g in self.grammars if g)
# Step 1: Allocate mask on CPU
self.vocab_mask = first_grammar.allocate_vocab_mask(
vocab_size=self.vocab_size,
batch_size=len(self.temperatures),
device=self.device, # CPU device for xgrammar/llguidance
)
self.apply_mask_func = first_grammar.apply_vocab_mask
# Step 2: Fill mask for each request (CPU work)
for i, grammar in enumerate(self.grammars):
if grammar and not grammar.finished and not grammar.is_terminated():
grammar.fill_vocab_mask(self.vocab_mask, i)
# Step 3: Transfer to GPU (non-blocking for overlap)
self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)
def apply_logits_bias(self, logits):
# Apply penalties first
if self.acc_linear_penalties is not None:
logits.add_(self.acc_linear_penalties)
# Apply grammar mask
if self.vocab_mask is not None:
self.apply_mask_func(logits=logits, vocab_mask=self.vocab_mask)
7.3 Scheduler Integration
From the scheduler code, we see that grammar processing is coordinated with the batch lifecycle:
# In scheduler.py:
# Grammar compilation happens BEFORE the request enters the waiting queue
self.grammar_manager = GrammarManager(self)
# The scheduler notes overlap constraints:
# "We do not support overlap + spec + grammar yet"
need_grammar_sync = (
batch and batch.is_spec_v2
and batch.has_grammar
and batch.forward_mode.is_decode()
)
8. SamplingParams: All Sampling Options
The server-side SamplingParams class in python/sglang/srt/sampling/sampling_params.py defines the full set of sampling options that the engine understands.
class SamplingParams:
def __init__(self,
# Generation control
max_new_tokens: int = 128,
min_new_tokens: int = 0,
n: int = 1,
# Stopping criteria
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
ignore_eos: bool = False,
no_stop_trim: bool = False,
# Temperature and nucleus sampling
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
# Penalty-based sampling
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
repetition_penalty: float = 1.0,
# Structured generation constraints
json_schema: Optional[str] = None,
regex: Optional[str] = None,
ebnf: Optional[str] = None,
structural_tag: Optional[str] = None,
# Output options
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
# Advanced
custom_params: Optional[Dict[str, Any]] = None,
stream_interval: Optional[int] = None,
logit_bias: Optional[Dict[str, float]] = None,
sampling_seed: Optional[int] = None,
)
8.1 Temperature Handling
# Special case: temperature=0 becomes greedy (top_k=1)
if 0 <= self.temperature < _SAMPLING_EPS: # 1e-6
self.temperature = 1.0
self.top_k = 1 # Greedy: argmax
if self.top_k == -1:
self.top_k = TOP_K_ALL # 1 << 30 = entire vocabulary
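A minimal sketch of this normalization (normalize is a hypothetical helper mirroring the two rules above):

```python
_SAMPLING_EPS = 1e-6
TOP_K_ALL = 1 << 30

def normalize(temperature, top_k):
    if 0 <= temperature < _SAMPLING_EPS:
        temperature, top_k = 1.0, 1  # greedy: argmax over the logits
    if top_k == -1:
        top_k = TOP_K_ALL            # disabled -> consider the whole vocabulary
    return temperature, top_k

print(normalize(0.0, -1))  # greedy rewrites top_k before the -1 check
print(normalize(0.7, -1))
```

Note the ordering: temperature=0 sets top_k=1 first, so the "disabled" -1 value never survives a greedy request.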
8.2 Constraint Validation
def verify(self, vocab_size):
# Only ONE grammar constraint allowed
grammars = [self.json_schema, self.regex, self.ebnf]
if sum(x is not None for x in grammars) > 1:
raise ValueError("Only one of regex, json_schema, or ebnf can be set.")
# Range checks
assert 0.0 < self.top_p <= 1.0
assert -2.0 <= self.frequency_penalty <= 2.0
assert 0.0 <= self.repetition_penalty <= 2.0
8.3 Complete Sampling Parameter Table
| Parameter | Type | Default | Description |
|---|---|---|---|
max_new_tokens | int | 128 | Maximum tokens to generate |
min_new_tokens | int | 0 | Minimum tokens before stopping |
temperature | float | 1.0 | Sampling temperature (0=greedy) |
top_p | float | 1.0 | Nucleus sampling threshold |
top_k | int | -1 | Top-k tokens (-1=disabled) |
min_p | float | 0.0 | Minimum probability threshold |
frequency_penalty | float | 0.0 | Penalize frequent tokens [-2, 2] |
presence_penalty | float | 0.0 | Penalize present tokens [-2, 2] |
repetition_penalty | float | 1.0 | Repetition penalty [0, 2] |
stop | str/List[str] | None | Stop strings |
stop_token_ids | List[int] | None | Stop token IDs |
stop_regex | str/List[str] | None | Stop when regex matches |
json_schema | str | None | JSON schema for structured output |
regex | str | None | Regex for constrained generation |
ebnf | str | None | EBNF grammar for constrained generation |
structural_tag | str | None | Structural tag format (function calling) |
logit_bias | Dict[str,float] | None | Per-token logit adjustments |
sampling_seed | int | None | Deterministic sampling seed |
n | int | 1 | Number of completions per prompt |
9. ServerArgs: Key Configuration Parameters
The ServerArgs dataclass in python/sglang/srt/server_args.py defines 200+ configuration parameters. Here are the most important groups:
Model & Tokenizer
- model_path: str
- tokenizer_path: Optional[str]
- tokenizer_mode: "auto"
- skip_tokenizer_init: False
- load_format: "auto"
- trust_remote_code: False
- context_length: Optional[int]
- revision: Optional[str]
Quantization
- dtype: "auto"
- quantization: None (awq, fp8, gptq, ...)
- kv_cache_dtype: "auto"
- enable_fp32_lm_head: False
- Supported methods: awq, fp8, mxfp8, gptq, marlin, bitsandbytes, gguf, w8a8_int8, w8a8_fp8, mxfp4, ...
Memory & Scheduling
- mem_fraction_static: Optional[float]
- max_running_requests: Optional[int]
- max_total_tokens: Optional[int]
- chunked_prefill_size: Optional[int]
- max_prefill_tokens: 16384
- schedule_policy: "fcfs"
- schedule_conservativeness: 1.0
- radix_eviction_policy: "lru"
Parallelism
- tp_size: 1 (tensor parallelism)
- pp_size: 1 (pipeline parallelism)
- dp_size: 1 (data parallelism)
- attn_cp_size: 1
- moe_dp_size: 1
- nnodes: 1
- node_rank: 0
Kernel Backends
- attention_backend: None (triton, flashinfer, fa3, fa4, cutlass_mla, ...)
- sampling_backend: None (flashinfer, pytorch, ascend)
- grammar_backend: None (xgrammar, outlines, llguidance, none)
- mamba_backend: "triton"
- fp8_gemm_runner_backend: "auto"
Speculative Decoding
- speculative_algorithm: None
- speculative_draft_model_path: None
- speculative_num_steps: None
- speculative_eagle_topk: None
- speculative_num_draft_tokens: None
- speculative_accept_threshold_single: 1.0
API & Chat
- api_key: None
- served_model_name: None
- chat_template: None
- reasoning_parser: None
- tool_call_parser: None
- sampling_defaults: "model"
- enable_cache_report: False
LoRA Support
- enable_lora: None
- max_lora_rank: None
- lora_target_modules: None
- lora_paths: None
- max_loaded_loras: None
- max_loras_per_batch: 8
- lora_backend: "csgmv"
9.1 Grammar Backend Choices
GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
9.2 Attention Backend Choices
ATTENTION_BACKEND_CHOICES = [
# Common
"triton", "torch_native", "flex_attention", "nsa",
# NVIDIA specific
"cutlass_mla", "fa3", "fa4", "flashinfer", "flashmla",
"trtllm_mla", "trtllm_mha",
# AMD specific
"aiter", "wave",
# Other platforms
"intel_amx", "ascend", "intel_xpu",
]
10. TokenizerManager: Chat Template Application & Multimodal Preprocessing
The TokenizerManager in python/sglang/srt/managers/tokenizer_manager.py is a dedicated process that handles all tokenization, chat template application, and multimodal data preprocessing before requests reach the scheduler.
10.1 Initialization Pipeline
class TokenizerManager:
def __init__(self, server_args, port_args):
# 1. Parse model config
self.init_model_config()
# 2. Initialize tokenizer and multimodal processor
self.init_tokenizer_and_processor()
# 3. Set up ZMQ IPC channels
self.init_ipc_channels(port_args)
# 4. Initialize request state tracking
self.init_running_status()
# 5. Set up LoRA registry
self.init_lora()
# 6. Set up disaggregation
self.init_disaggregation()
# 7. Start metrics and watchdog
self.init_metric_collector_watchdog()
10.2 Multimodal Processor Initialization
def init_tokenizer_and_processor(self):
    if self.model_config.is_multimodal:
        import_processors("sglang.srt.multimodal.processors")
        _processor = _get_processor_wrapper(server_args)
        # Create async multimodal data processor with concurrency control
        self.mm_processor = get_mm_processor(
            self.model_config.hf_config, server_args, _processor, transport_mode
        )
        self.mm_data_processor = AsyncMMDataProcessor(
            self.mm_processor,
            max_concurrent_calls=self.server_args.mm_max_concurrent_calls,
            timeout_s=self.server_args.mm_per_request_timeout,
        )
        self.tokenizer = get_tokenizer_from_processor(self.processor)
    else:
        self.tokenizer = get_tokenizer(
            server_args.tokenizer_path,
            tokenizer_mode=server_args.tokenizer_mode,
            trust_remote_code=server_args.trust_remote_code,
        )
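The role of AsyncMMDataProcessor -- bounding concurrent multimodal calls and enforcing a per-request timeout -- can be sketched with an asyncio semaphore plus `wait_for`. All names here are illustrative, not SGLang's actual classes:

```python
import asyncio

class BoundedAsyncProcessor:
    """Sketch of an async wrapper that caps in-flight calls and
    applies a per-request timeout."""

    def __init__(self, fn, max_concurrent_calls=4, timeout_s=5.0):
        self.fn = fn
        self.sem = asyncio.Semaphore(max_concurrent_calls)
        self.timeout_s = timeout_s

    async def process(self, item):
        async with self.sem:  # at most max_concurrent_calls in flight
            return await asyncio.wait_for(self.fn(item), self.timeout_s)

async def demo():
    async def work(x):
        await asyncio.sleep(0.01)  # stand-in for slow image/audio decoding
        return x * 2

    proc = BoundedAsyncProcessor(work, max_concurrent_calls=2)
    return await asyncio.gather(*(proc.process(i) for i in range(4)))

results = asyncio.run(demo())
```

The semaphore protects the process from being overwhelmed by many simultaneous image decodes, while the timeout ensures one stuck request cannot hold a slot forever.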
10.3 Request Tokenization: _tokenize_one_request()
async def _tokenize_one_request(self, obj):
    # Handle different input types
    if obj.input_embeds is not None:
        input_embeds = obj.input_embeds
        input_ids = obj.input_ids
    elif obj.input_ids is not None:
        input_ids = obj.input_ids
    else:
        # Tokenize text (uses async batch tokenizer if available)
        input_ids, token_type_ids = await self._tokenize_texts(
            input_text, is_cross_encoder_request
        )

    # Multimodal preprocessing
    if self.mm_processor and obj.contains_mm_input():
        # Normalize image/video/audio data to lists
        if obj.image_data is not None and not isinstance(obj.image_data, list):
            obj.image_data = [obj.image_data]
        # Process multimodal data asynchronously
        mm_inputs = await self.mm_data_processor.process(
            image_data=obj.image_data,
            audio_data=obj.audio_data,
            input_text_or_ids=(input_text or input_ids),
            request_obj=obj,
            max_req_input_len=self.max_req_input_len,
        )
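The precedence in the first branch -- embeddings over pre-tokenized ids over raw text -- can be isolated as a small pure function (hypothetical helper, not SGLang code):

```python
def resolve_inputs(input_embeds=None, input_ids=None, text=None, tokenize=None):
    """Sketch of _tokenize_one_request's input dispatch: embeddings take
    precedence, then pre-tokenized ids, and only raw text is tokenized."""
    if input_embeds is not None:
        return ("embeds", input_ids)
    if input_ids is not None:
        return ("ids", input_ids)
    return ("text", tokenize(text))

kind, ids = resolve_inputs(text="hi", tokenize=lambda t: [ord(c) for c in t])
```

Callers that already hold token ids (e.g. a router that pre-tokenized the prompt) thus skip the tokenizer entirely.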
10.4 Dynamic Batch Tokenizer
SGLang supports batching tokenization requests for efficiency:
# In init_tokenizer_and_processor:
if server_args.enable_dynamic_batch_tokenizer:
    self.async_dynamic_batch_tokenizer = AsyncDynamicbatchTokenizer(
        self.tokenizer,
        max_batch_size=server_args.dynamic_batch_tokenizer_batch_size,
        batch_wait_timeout_s=server_args.dynamic_batch_tokenizer_batch_timeout,
    )

# In _tokenize_texts:
if use_async_tokenizer:
    result = await self.async_dynamic_batch_tokenizer.encode(text, **kwargs)
else:
    encoded = self.tokenizer(tokenizer_input, **kwargs)
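A minimal version of the batching idea: callers await individual futures while a background drainer collects requests for up to the wait timeout and encodes them in one batched call. This is a toy sketch with illustrative names, not the real AsyncDynamicbatchTokenizer:

```python
import asyncio

class TinyBatchTokenizer:
    """Toy dynamic batch tokenizer: coalesces concurrent encode()
    calls into one batched tokenizer invocation."""

    def __init__(self, tokenize_batch, max_batch_size=8, batch_wait_timeout_s=0.005):
        self.tokenize_batch = tokenize_batch
        self.max_batch_size = max_batch_size
        self.timeout = batch_wait_timeout_s
        self.queue = asyncio.Queue()
        self._drainer = None

    async def encode(self, text):
        if self._drainer is None:
            # lazily start the background drainer on the running loop
            self._drainer = asyncio.ensure_future(self._drain())
        fut = asyncio.get_running_loop().create_future()
        await self.queue.put((text, fut))
        return await fut

    async def _drain(self):
        loop = asyncio.get_running_loop()
        while True:
            text, fut = await self.queue.get()
            texts, futs = [text], [fut]
            deadline = loop.time() + self.timeout
            # collect more requests until the batch fills or the window closes
            while len(texts) < self.max_batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    text, fut = await asyncio.wait_for(self.queue.get(), remaining)
                except asyncio.TimeoutError:
                    break
                texts.append(text)
                futs.append(fut)
            for f, ids in zip(futs, self.tokenize_batch(texts)):
                f.set_result(ids)

async def demo():
    # stand-in tokenizer: one "token" per input, its length
    tok = TinyBatchTokenizer(lambda batch: [[len(t)] for t in batch])
    return await asyncio.gather(tok.encode("ab"), tok.encode("cde"))

results = asyncio.run(demo())
```

The trade-off is a small added latency (the wait window) in exchange for amortizing tokenizer overhead across concurrent requests, which pays off under high request rates.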
10.5 Input Format Detection
class InputFormat(Enum):
    SINGLE_STRING = 1        # "Hello world"
    BATCH_STRINGS = 2        # ["Hello", "World"]
    CROSS_ENCODER_PAIRS = 3  # [["query", "document"]]

def _detect_input_format(self, texts, is_cross_encoder):
    if isinstance(texts, str):
        return InputFormat.SINGLE_STRING
    if is_cross_encoder and isinstance(texts[0], list):
        return InputFormat.CROSS_ENCODER_PAIRS
    return InputFormat.BATCH_STRINGS
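Exercising that logic on all three shapes (a standalone copy of the method, minus self):

```python
from enum import Enum

class InputFormat(Enum):
    SINGLE_STRING = 1
    BATCH_STRINGS = 2
    CROSS_ENCODER_PAIRS = 3

def detect_input_format(texts, is_cross_encoder=False):
    if isinstance(texts, str):
        return InputFormat.SINGLE_STRING
    if is_cross_encoder and isinstance(texts[0], list):
        return InputFormat.CROSS_ENCODER_PAIRS
    return InputFormat.BATCH_STRINGS
```

Note that a list of `[query, document]` pairs only counts as cross-encoder input when the flag is set; without it, the same nested list would fall through to BATCH_STRINGS.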
10.6 Request State Management
@dataclasses.dataclass
class ReqState:
    out_list: List[Dict]   # Accumulated output chunks
    finished: bool         # Completion flag
    event: asyncio.Event   # Notification event
    obj: GenerateReqInput  # Original request
    time_stats: APIServerReqTimeStats
    # Incremental state for streaming
    text: str = ""
    output_ids: List[int] = field(default_factory=list)
    last_output_offset: int = 0
    last_text_offset: int = 0
    # Logprob tracking
    input_token_logprobs_val: List[float] = field(default_factory=list)
    output_token_logprobs_val: List[float] = field(default_factory=list)
    input_top_logprobs_val: List[List[float]] = field(default_factory=list)
    ...
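The offset fields enable incremental streaming: the state remembers how much text has already been emitted, and each update yields only the new suffix. A sketch of that mechanism (hypothetical dataclass, modeled on the offset fields above):

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class TinyReqState:
    """Toy streaming state: track what has already been sent."""
    text: str = ""
    output_ids: List[int] = field(default_factory=list)
    last_text_offset: int = 0

    def take_delta(self, full_text: str) -> str:
        """Return only the newly generated suffix since the last call."""
        delta = full_text[self.last_text_offset:]
        self.last_text_offset = len(full_text)
        self.text = full_text
        return delta

state = TinyReqState()
d1 = state.take_delta("Hello")
d2 = state.take_delta("Hello, world")
```

This lets the server ship each detokenized chunk exactly once even though every update carries the full text so far.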
10.7 IO Structures: GenerateReqInput
The GenerateReqInput in io_struct.py defines the complete request format:
@dataclass
class GenerateReqInput(BaseReq):
    text: Optional[Union[List[str], str]] = None
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    input_embeds: Optional[...] = None
    # Multimodal inputs
    image_data: Optional[MultimodalDataInputFormat] = None
    video_data: Optional[MultimodalDataInputFormat] = None
    audio_data: Optional[MultimodalDataInputFormat] = None
    # Sampling configuration
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Logprob options
    return_logprob: Optional[Union[List[bool], bool]] = None
    logprob_start_len: Optional[Union[List[int], int]] = None
    top_logprobs_num: Optional[Union[List[int], int]] = None
    return_text_in_logprobs: bool = False
    # Streaming
    stream: bool = False
    # Session continuity
    session_params: Optional[Union[List[Dict], Dict]] = None
    # LoRA
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
    # Reasoning
    require_reasoning: bool = False
    # DP routing
    routed_dp_rank: Optional[int] = None
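Note the recurring `Union[List[X], X]` pattern: most fields accept either one value for the whole request or a per-request list for batches, and the single-value form is broadcast during normalization. A sketch of that broadcasting rule (hypothetical helper name):

```python
def broadcast_field(value, batch_size):
    """If `value` is a single item (or None), repeat it per request;
    a list must already match the batch size."""
    if not isinstance(value, list):
        return [value] * batch_size
    if len(value) != batch_size:
        raise ValueError(f"expected {batch_size} entries, got {len(value)}")
    return value

params = broadcast_field({"temperature": 0.7}, 3)
```

This is why a client can send one sampling_params dict alongside a batch of three prompts and have it apply to all of them.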
Summary
SGLang's frontend is a layered system:
- API Layer (api.py): User-facing primitives -- gen(), select(), image(), role helpers
- IR Layer (ir.py): Typed expression nodes forming an AST -- SglGen, SglSelect, SglExprList, etc.
- Interpreter Layer (interpreter.py): StreamExecutor processes IR nodes in a background thread, manages variables/events, supports fork/join
- Backend Layer (runtime_endpoint.py): HTTP bridge to the SGLang server, with KV cache reuse via concatenate_and_append
- Grammar Layer (constrained/): Three interchangeable backends (xgrammar, llguidance, outlines) for structured generation with CPU-GPU overlapped mask computation
- Server Layer (tokenizer_manager.py, server_args.py): 200+ configuration parameters, async tokenization, multimodal preprocessing, chat template application
The key innovation is that SGLang programs are ordinary Python functions where s += expr submits IR nodes to a background executor. This enables natural control flow (if/else, loops) while the executor handles asynchronous generation, KV cache management, and fork/join parallelism transparently.
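The `s += expr` mechanics reduce to `__iadd__` pushing IR nodes onto a queue that a worker thread drains. A toy illustration of that submit-and-continue pattern (all names illustrative, not SGLang's actual classes):

```python
import queue
import threading

class ToyExecutor:
    """Background thread that 'executes' submitted IR nodes in order."""

    def __init__(self):
        self.q = queue.Queue()
        self.log = []
        self.worker = threading.Thread(target=self._run, daemon=True)
        self.worker.start()

    def _run(self):
        while True:
            node = self.q.get()
            if node is None:  # sentinel: shut down
                break
            self.log.append(f"exec:{node}")

    def join(self):
        self.q.put(None)
        self.worker.join()

class ToyState:
    def __init__(self):
        self.executor = ToyExecutor()

    def __iadd__(self, expr):
        # `s += expr` submits the node and returns immediately;
        # the caller's Python control flow continues unblocked
        self.executor.q.put(expr)
        return self

s = ToyState()
s += "Q: What is 2+2?"
s += "gen(answer)"
s.executor.join()
```

Because submission is non-blocking, ordinary Python `if`/`for` can interleave with generation while the executor serializes the actual model calls.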