SGLang Scheduler 介绍

KV Cache 是如何被管理的

Table of Contents

本文聚焦于 SGLang 的 Scheduler、RadixCache 以及整体的数据流动,不介绍模型加载和 AttentionBackend (e.g. FlashAttention)。Bruce 的文章 [1] 提供了代码走读骨架和诸多启发,非常感谢。

本文使用的 SGLang 的 commit ID 是 d88ac9bc9a6528942166ed9c61fba53cacd0b87c (Oct 17, 2025),release 版本大概是 v0.5.4 (+/- 0.0.1)

总览

推理引擎向用户暴露一个 OpenAI 兼容的 HTTP 服务,然后将用户的输入给到 decode-only 模型,执行前向传播(forward)后,模型会给出一段相应的输出,最终通过 HTTP 服务返回给用户。

推理引擎需要根据 TP/PP/DP/MP 等参数,将模型的切片按需分布式加载到 GPU 上。引擎还需要将多个请求打包成一个 batch 以提升效率。比如串行处理请求导致效率下降,又比如 PP (Pipeline Parallelism) 方式下,不合理的 batch 会导致气泡(bubble)产生,导致 GPU 空转,进而影响效率。下图是一种比较原始的 PP 策略,每个 GPU 都有比较长时间的空闲。

5QQtNl

管理 KV Cache 也是引擎的核心能力,KV Cache 可以显著降低 GPU 计算量,但是显存容量不是无限的,在显存容量紧张的情况下,引擎决定哪些保留,哪些释放(通常是 LRU 策略)。

9eq4Nt

SGLang 核心组件是:

  • OpenAI-compatible HTTP Service: 负责接收用户的请求。
  • SGLang Runtime (SRT) 的核心是 Engine:
  • TokenizerManager 位于主进程,负责将文本切分为 tokens,启动多个 Scheduler 子进程和一个 DetokenizerManager 子进程,进程间通讯使用 ZeroMQ(应该是 abstract UDS)。
  • Scheduler 是引擎的大脑,负责把请求队列中的请求组合为一个 batch,负责在显存容量不足时撤回部分请求等等。
  • TpModelWorker 是很薄的一层封装,透传了 Scheduler 提供的 batch 信息给 ModelRunner
  • ModelRunner 真正的加载模型权重,初始化 PyTorch 的 process group,group 之间使用 nccl 通讯,管理 KV Cache 显存和请求 tokens 等等
  • DetokenizerManager 负责将模型生成的 tokens 转换为文字。
b6Owsf

数据流动方式是

  • 用户通过 HTTP 服务提交 prompt。
  • TokenizerManager 将用户的 prompt 分割为 tokens 后,通过 ZeroMQ 发送给 Scheduler
  • Scheduler 将请求添加到等待队列,并规划下一轮的 prefill 或者 decode batch,将 batch 发送到模型执行前向传播,最后将生成的预测 tokens 发送给 DetokenizerManager
  • DetokenizerManager 将 tokens 转换为文字,发送给 TokenizerManager
  • TokenizerManager 将结果返回给用户。

核心类

下图展示了 SGLang 在调度层面与 KV Cache 相关的核心类之间的交互关系,本质上后面的流程介绍都是在维护下面这几个组件的关系。

MnlAS9

Req

Req 代表一个用户请求。

TokenizerManager 将用户的文本切分转换为 tokens,用户的一个请求最终会转换为一个 token ID 数组。上图 Req 是由多个 token ID 组成的,token ID 可以被简单的理解为一个指向 token tensor 的指针。

请求需要执行一次 prefill 计算后,才能继续执行 decode 阶段,简单理解 prefill 阶段是模型理解用户的 prompt,decode 阶段是生成结果。

origin_input_ids 是用户的原始输入,output_ids 是大模型生成的 tokens。在 decode 阶段之前,output_ids 为空,每执行一次 decode,就会追加一个 token ID 到 output_ids

TokenToKVPoolAllocator

TokenToKVPoolAllocator (简称 Allocator) 负责分配和管理 KV Cache,但不操作显存。

KV Cache 是在 GPU 中的一块连续的空间,系统初始化的时候就被创建好了,参见 MHATokenToKVPool

Allocator 负责为每个 KV Cache 创建一个对应索引(index)。当需要申请 KV Cache 的时候,Allocator 返回的是 KV Cache 索引(KV indices),参见 TokenToKVPoolAllocator.alloc()

ReqToTokenPool

ReqToTokenPool 负责存储请求和 KV Cache 的映射关系:一个 token 对应一个 KV Cache。不过需要注意的是,不同的请求即使 token 一样,大部分情况下不能共享 KV Cache,除非它们的前缀是一样的。

每个请求能持有 KV Cache 数量由 max_context_len 字段决定,该字段读取的是模型配置(即模型支持的最大上下文长度)。

能容纳的请求数量由 size 字段决定。ModelRunner 根据当前剩余显存(total_gpu_memory)和单个 KV Cache 尺寸(cell_size)计算在在悲观情况下(即完全没有可共享 prefix 的情况下)能容纳最多的 KV Cache 数量(ModelRunner.max_total_num_tokens = total_gpu_memory / cell_size),然后再在按照最大上下文长度平均分配给每个请求(size = ModelRunner.max_total_num_tokens / max_context_len),参见 ModelRunner.init_memory_pool()

ReqToTokenPool 还有一个 free_slots 字段,它的作用类似于一个 bitmap,用来标记哪些 slots 可供新请求使用,如果 free_slots 满了则表明当前系统已经无法再接收新的请求了,参见 ReqToTokenPool.alloc()

以 req 0 为例,prefill 阶段需要用户原始输入,而用户原始输入又根据是否命中 KV Cache 分为了 prefix 和 extend 两个部分。前面我们说了,不同的请求间,如果他们前缀是匹配的,那么相同的 token 的 KV Cache 是可以共享的,常见的是在使用 prompt 的时候,通常有一段相同的系统提示词。这部分可共享的 KV Cache 被算在 prefix (绿色) 中,而另一部分就是 extend (蓝色),prefill 计算只会计算 extend 的部分,紫色部分则是 decode 阶段生成的 token 的对应 KV Cache 索引。

RadixCache

RadixCache 利用了 radix tree(一种前缀树,如下图所示,来源 [2])判断一个请求是否有可以匹配到的 prefix,输入的是请求(token IDs),如果匹配则返回匹配到的前缀 tokens 对应的 kv indices,匹配到的最后一个节点等,参见 RadixCache.match_prefix()

KXT5rM

总结下,Req 存储每个请求的 token IDs,TokenToKVPoolAllocator 负责管理和分配 KV indices,RadixCache 维护前缀 tokens 和 KV indices 的关系,方便 ReqToTokenPoolReqToTokenPool 存储每个请求 token 对应的 KV index。

ScheduleBatch

ScheduleBatch 是调度的基本单位,也就是请求会被放到一个 batch 下同时执行前向传播。它是由调度器(Scheduler)从等待队列中,按照请求优先级筛选出来的。

ScheduleBatch 负责将上述几个结构体串起来:

  • reqs 字段保存了当前 batch 的全部请求。
  • req_to_token_pool 字段指向了 ReqToTokenPool 实例。
  • token_to_kv_pool_allocator 字段指向了 TokenToKVPoolAllocator 实例。
  • tree_cache 字段指向了 RadixCache(p.s. 有好几种 cache 类型,本文只介绍 RadixCache)。
  • forward_mode 字段表示当前 batch 将要要执行什么操作,类型是 ForwardMode,大致上分为了 EXTENDDECODE 两种类型,prefill 算是 EXTEND 的一种特殊情况。
  • ……

那么它是如何跟这些字段互动的呢?举一个为请求(requests)申请 KV Cache 空间的例子,它会调用 token_to_kv_pool_allocator 申请 KV indices,然后将结果保存到 req_to_token_pool

Scheduler

Scheduler 是 SGLang 核心组件,也是本文的核心,负责用户请求的调度工作,比如决定当前要执行 prefill 计算还是 decode 计算,比如决定当前执行哪些用户(通常是匹配前缀越长的请求,越早执行)等等。

waiting_queue 字段是一个 Req 数组,用来保存还没有被打包到 batch 的请求。

Scheduler 内部有三个 ScheduleBatch,它们分别是 running_batchcur_batchlast_batchrunning_batch 只包含进入 decode 阶段的请求,cur_batch 表示当前正在执行前向传播的 batch,last_batch 表示上一轮正在执行前向传播的 batch。

为什么要有三个 ScheduleBatch?因为 SGLang 在策略上优先执行 prefill,直到没办法执行时才会执行 decode。我能想到的“没办法”有两种情况:(1)TokenizerManager 没传输新的用户请求了;(2)当前显存容量已经不足以再支撑新的 prefill 了。

再抽象一下,SGLang 使用的策略是 prefill 优先,因为 prefill 只需执行一轮,每次 prefill 完请求一定会进入 decode 阶段,然后它们最终都会被加入到 running_batch 中了。这样本质上是把 decode batch 搞得足够大时,才执行 decode 计算。为什么要这样?配合 LLM 的结果,我猜测原因可能是优化 TTFT (Time-To-First-Token),优先 prefill 整体上就能更快进入 decode 阶段。

说回 cur_batch,当“没办法”发生的时候,它的 forward_mode 就是 DECODE,当 cur_batch 运行结束的时候,cur_batch 会被赋值给 last_batch,那么 last_batch 也是 DECODE 了。

Scheduler 运行流程

Scheduler 本质上是一个无限循环,首先从 TokenizerManager 获取新的用户请求,下一步将用户请求放到等待队列(waiting queue)中,下一步是从请求队列中生成一个 batch 执行 prefill/decode 操作,下一步是执行 batch,即把请求扔到模型中执行前向传播,最后处理 batch 的结果,比如将结果返回给 DetokenizerManager

NKIPcF

初始化

Scheduler 初始化会把它锁依赖的全部类初始化,比如 TpModelWorker,然后 TpModelWorker 会再初始化 ModelRunner……是时候再搬出来这张图了。

9eq4Nt

普通模式下,无限循环是由 Scheduler.event_loop_normal() 触发的。

获取用户请求

TokenizerManager 会将用户的 prompt 文本转换为 tokens,通过 ZeroMQ 发送给 Scheduler

Scheduler 调用 Scheduler.recv_requests() 接收请求,返回的是请求数组(list[Req])。接收模式是不阻塞的(non-blocking),所以不一定会有请求进来(返回空数组)。

入队 Waiting Queue

忽略中间没什么营养的步骤,核心是调用 Scheduler._add_request_to_queue() 将用户请求(Req)追加到 waiting_queue 中。

获取下一个 Batch

ITt9CI

尝试把请求添加到 Running Batch

last_batch 是已经执行过前向传播的 batch (参见上图,run batch 结束后的 batch 会被赋值给 last_batch),分两种情况:

  • 如果上一轮是 EXTEND 模式,那么 prefill 运行一轮后,进入 decode 阶段,所以需要将 last_batch 的用户请求合并到 running_batch
  • 如果上一轮是 DECODE 模式,请求已经存在于 running_batch,就不需要做什么了。

尝试获取新的 Prefill Batch

Dclkpo

请求的剩余空间也取决于两个方面:

  • PP 最大 micro batch 容量(pp_max_micro_batch_size),启动 SGLang 服务时,可通过参数指定,默认值是 max_running_requests / pp_size。为什么要有 micro batch?因为如果 batch size 比较大,stage 之间会产生较大的 bubble。
  • ReqToTokenPool 的剩余可分配容量(req_to_token_pool_available),即 free_slots 的剩余容量。

剩余空间是根据 min((pp_max_micro_batch_size - running_bs), req_to_token_pool_available),其中 running_bsrunning_batch 的尺寸。这里使用 running_batch 中的请求数量作为当前系统中的请求数量,因为新 prefill batch 还没生成(本步骤不就是正在生成嘛),last_batch 中能合并的也都已经合并到 running_batch 中了。

进入到这里就说明至少有一个请求的空间可用

根据策略对 waiting_queue 中的请求排序,典型策略包括 CacheAgnosticPolicy.FSFC (缓存不感知、先来先服务)、CacheAwarePolicy.LPM (缓存感知、longest prefix match),其中 LPM 能够提高 KV Cache 的命中率。

waiting_queue 逐步出队优先级最高的请求,初始化请求(核心是匹配前缀并更新请求),交给 PrefillAdder 判断当前是否能够再增加一个新的请求。如果可以的话,用户请求会暂时存放在 PrefillAdder.can_run_list 字段中。最终 can_run_list 的请求就会组成下一轮 prefill 计算的 batch。

PrefillAdder 应用的限制是来自多方面的:

  • 一个 prefill batch 最多 token 数量(rem_input_tokens 字段),来源于参数 --max-prefill-tokens,默认值是 16384。注意这并不包括命中前缀的 token 长度。
  • 一个 chunk prefill 最多 token 数量(rem_chunk_tokens 字段),来源于参数 --chunked-prefill-size。Chunked prefill 是一个降低 PP bubble 的优化手段,通常比 --max-prefill-tokens 小。
  • 当前 KV Cache 剩余容量(rem_total_tokens 字段),是 (token_to_kv_pool_allocater_available + tree_cache_evictable) - rem_total_token_offset,其中:
  • token_to_kv_pool_allocater_available 表示当前 allocator 的可用容量。
  • tree_cache_evictable 表示 radix cache 中可以驱除的(比如已经没有引用的前缀)。
  • rem_total_token_offset 表示目前已有的 token 数量,它包括 running_batch 中需要 decode 的请求数量,这个数字不是准确值,是根据最大上下文长度的预测值,以及当前 cur_run_list 中已经添加的请求所占用的 token 数量,extend 部分是精确的,decode 部分同上也是预测的。

上述任何一个不满足要求就会跳出添加请求的循环,此时 can_run_list 中的请求就会传给一个新的 ScheduleBatch,作为下一轮 prefill batch。

调用 ScheduleBatch.prepare_for_extend() 执行 extend 的准备工作:

  • 设置前向传播模式为 ForwardMode.EXTEND
  • 根据当前 batch 的请求去更新字段,其中的部分字段会被传给模型,其中比如将用户输入的 token IDs 转换为设备上的 tensors 等。
  • 调用 token_to_kv_pool_allocater.alloc() 为 extend tokens(参见上面“核心类”下面的图)申请 KV indices,然后写入到 req_to_token_pool

最后返回 prefill batch (p.s. 如果在上面的检查中出现了空间不足的情况,该函数会返回 None)。

返回一个 Batch

根据上面的结果,分为了 prefill batch 可执行、decode batch 可执行以及没有任何 batch 可执行三种情况。

Prefill batch 需要的显存等空间申请,已经在上一步处理完毕了,因此直接返回 prefill batch 即可。

同样的,没有任何 batch 可执行时,直接返回 None 即可。

Decode batch 在返回前需要调用 Scheduler.update_running_batch(),完成显存容量的检查、KV Cache 的申请以及对 decode batch 字段的更新。

每执行一次 decode 计算,每个请求都会生成且仅会生成 1 个 token,所以一次 decode 计算需要的空间就是当前 batch 中全部请求的数量。如果显存容量不足,就需要回撤(retract)请求,直到有充足显存。所谓回撤请求立即释放该请求所占用的显存(比如 KV Cache 等),然后将请求重新加到 waiting_queue 中。

调用 ScheduleBatch.prepare_for_decode 执行 decode 的准备工作:

  • 设置前向传播模式为 ForwardMode.DECODE
  • 将上一步生成的 token (Req.output_ids) 赋值为 Req.input_ids,然后将 Req.output_ids 赋值为 None,这是因为 decode 阶段只需要上一步的输入即可计算,相反的,prefill 阶段需要用户的完整输入。
  • token_to_kv_pool_allocater 申请 KV indices。

最后返回 decode batch。

运行 Batch

核心就是把 batch 送入到 TpModelWorker.forward_batch_generation(),然后再进一步送入到 ModelRunner.forward() 执行前向传播。

处理 Batch 运行结果

Scheduler 会分别处理 decode 结果(Scheduler.process_batch_result_decode())和 extend 结果(Scheduler.process_batch_result_prefill())。

处理 decode 结果本质就是遍历当前 batch 的全部用户请求:

  • 将新生成的 token 追加到 Req.output_ids,如果还没结束的话,供下一轮 decode 计算使用。
  • 检查计算是否已经结束,比如超过模型上下文长度限制,又比如产生了 EOS 字符,这些都作为结束的标志。
  • 如果当前计算已经完成时,会调用 RadixCache.cache_finished_req() 进行资源的清理和释放。

RadixCache.cache_finished_req() 这个函数需要单独拿出来再说一下。它的名字比较奇怪,明明是资源释放,为什么要叫 cache 呢?这也是 [1] 中提出的问题。该函数的 signature 如下所示,is_insert 就是问题所在了。当请求自然结束的时候(通常不会面临显存压力),会使用 is_insert=True 将结果全部缓存到 RadixCache 中,当然没有匹配到的结点会被标记为可驱逐的(evictable),这跟上面的情况就呼应上了。还有一种可能是,当我们面临显存压力的时候,我们需要立即释放 KV Cache,通常 is_insert=False

def cache_finished_req(self, req: Req, is_insert: bool = True)

不过 is_insert 的值不会影响 req_to_token_pool 的释放,仅影响 KV Cache 的行为。

References

  1. https://zhuanlan.zhihu.com/p/17186885141
  2. https://ivanzz1001.github.io/records/post/data-structure/2018/11/18/ds-radix-tree
All rights reserved
Except where otherwise noted, content on this page is copyrighted.