Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
a8be14c
Changes for SGLang support
avnermay Mar 18, 2026
1b2af07
Small test script
avnermay Mar 18, 2026
b9aceb5
Changes
avnermay Mar 18, 2026
fb9546a
Runner helpers
avnermay Mar 18, 2026
e8f7292
Updates to small test, assert in loader.py
avnermay Mar 18, 2026
af8c8ac
Changes
avnermay Mar 18, 2026
ff11967
Refactor of runner_helpers for all send/receive commands to use same …
avnermay Mar 19, 2026
6795127
Switch some torch.empty calls back to torch.zeros for correctness
avnermay Mar 19, 2026
04439b1
Add PrefillRequest and SpeculationRequest objects in runner_helpers.py
avnermay Mar 19, 2026
a3d6cf0
NIT bug fix
avnermay Mar 20, 2026
0b8a6e5
Further refactor of PrefillRequest, SpeculationRequest, SpeculationRe…
avnermay Mar 20, 2026
6a36a14
Improvements to logging
avnermay Mar 21, 2026
b8c1fd7
Support for Phoenix V1
avnermay Mar 23, 2026
4c127df
dist_utils needed for cross-node support
avnermay Mar 23, 2026
7a968e8
Merge branch 'avner/sglang' into avner/sglang-phnx
avnermay Mar 23, 2026
82ca79c
Fix bugs in how recovery_activations and eagle_activations are set an…
avnermay Mar 23, 2026
e632702
Merge branch 'avner/sglang' into avner/sglang-phnx
avnermay Mar 24, 2026
7053b80
FA4 initial implementation by CC
avnermay Mar 28, 2026
66b8b7b
FA4 support
avnermay Mar 28, 2026
65301a3
Add tests and tree_mask.py so that FA4 works
avnermay Mar 28, 2026
5256853
Merge branch 'avner/sglang-fa4' into avner/sglang-phnx-fa4
avnermay Mar 28, 2026
fc1130d
Remove debug loading of Eagle activations
avnermay Mar 28, 2026
aa50214
Merge branch 'avner/sglang' into avner/sglang-fa4
avnermay Mar 28, 2026
42cea6b
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Mar 28, 2026
d1c9215
Update pyproject.toml to reflect flash-attn 4 dependency, and no more…
Mar 28, 2026
eb13cd3
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
Mar 28, 2026
2463748
Fix FA4 import
avnermay Mar 28, 2026
7184e54
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Mar 28, 2026
d86d0fb
Add logging statement once draft process is waiting for target proces…
avnermay Mar 28, 2026
ab487ac
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Mar 28, 2026
1425f32
Trust remote code fix
avnermay Mar 28, 2026
743fb40
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Mar 28, 2026
cb51158
Add logging for draft model warmup
avnermay Mar 28, 2026
bfa56fd
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Mar 28, 2026
e701bfe
More logging
avnermay Mar 29, 2026
bfcb931
Switch all attention calls to use FA4
avnermay Mar 29, 2026
80f2f76
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Mar 29, 2026
cce45eb
Add tests for attention fa4
avnermay Mar 29, 2026
332b1f3
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Mar 29, 2026
080c4a3
Upgrade transformers, pin FA4
avnermay Mar 29, 2026
37954a6
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Mar 29, 2026
eb5e612
DUMP_TENSORS=false fix
avnermay Mar 30, 2026
08248b2
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Mar 30, 2026
ff59fdf
Switch from ssh to https git dependency in pyproject.toml
avnermay Mar 31, 2026
dbdaa7b
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Mar 31, 2026
107602a
Higher timeouts, clearer target <-> draft waiting messages, remove re…
avnermay Apr 1, 2026
0105932
Merge branch 'avner/sglang' into avner/sglang-fa4
avnermay Apr 1, 2026
ddaff75
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Apr 1, 2026
f8af8e7
Acceptance rate log and force-jit-speculate
avnermay Apr 10, 2026
4c6997f
Improvements to benchmarking
avnermay Apr 10, 2026
b417d75
NIT: print cache_hits as ints
avnermay Apr 10, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions bench/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def parse_arguments():
parser.add_argument("--fl", type=int, nargs='+', default=None, help="Fan out list (e.g., --fl 1 3 4 becomes [1, 3, 4])")
parser.add_argument("--flh", type=int, nargs='+', default=None, help="Fan out list (e.g., --flh 1 3 4 becomes [1, 3, 4])")
parser.add_argument("--flm", type=int, nargs='+', default=None, help="Fan out list miss (e.g., --flm 1 3 4 becomes [1, 3, 4])")
parser.add_argument("--backup", type=str, choices=["jit", "fast"], default="jit", help="Backup strategy (jit or fast)")
parser.add_argument("--backup", type=str, choices=["jit", "force-jit", "fast"], default="jit", help="Backup strategy (jit or fast)")

# Memory and batching configuration
parser.add_argument("--block_sz", type=int, default=256, help="KV cache block size (see config.py: kvcache_block_size)")
Expand Down Expand Up @@ -129,7 +129,7 @@ def initialize_wandb(args, run_name):
"gpus": args.gpus,
"speculative_decoding": args.spec,
"async_speculative": getattr(args, 'async', False),
"jit_speculative": args.backup == "jit",
"backup_strategy": args.backup,
"k": args.k if args.spec else None,
"f": args.f,
"fan_out_list": args.flh,
Expand Down Expand Up @@ -172,8 +172,11 @@ def create_llm_kwargs(args, draft_path):
max_num_seqs=args.b,
max_model_len=args.max_model_len,
sampler_x=args.x,
jit_speculate=(args.backup == "jit"),
jit_speculate=(args.backup == "jit" or args.backup == "force-jit"),
force_jit_speculate=(args.backup == "force-jit"),
max_steps=args.max_steps,
communicate_cache_hits=True,
communicate_logits=True,
)

if args.flh is not None:
Expand Down
9 changes: 8 additions & 1 deletion bench/bench_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def load_dataset_token_ids(
return None

dataset_file_path = DATASET_PATHS[dataset_name]
print(f"Loading dataset '{dataset_name}' from: {dataset_file_path}")
if not os.path.exists(dataset_file_path):
print(
f"Warning: Dataset file not found at {dataset_file_path}, falling back to random tokens")
Expand All @@ -172,10 +173,16 @@ def load_dataset_token_ids(
data = json.loads(line.strip())
text: str = data["text"]
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
tokens = tokenizer.apply_chat_template(
result = tokenizer.apply_chat_template(
[{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": text}],
add_generation_prompt=True,
)
text_result = tokenizer.apply_chat_template(
[{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": text}],
add_generation_prompt=True,
tokenize=False,
)
tokens = result.input_ids if hasattr(result, 'input_ids') else result
else:
tokens = tokenizer.encode(text, add_special_tokens=False)

Expand Down
10 changes: 9 additions & 1 deletion bench/bench_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def _required_env(var_name: str, note: str) -> str:
"BENCH_LLAMA_1B",
f"{HF_CACHE_DIR}/models--meta-llama--Llama-3.2-1B-Instruct",
),
"qwen_8b": os.environ.get(
"BENCH_QWEN_8B",
f"{HF_CACHE_DIR}/models--Qwen--Qwen3-8B",
),
"qwen_32b": os.environ.get(
"BENCH_QWEN_32B",
f"{HF_CACHE_DIR}/models--Qwen--Qwen3-32B",
Expand All @@ -62,12 +66,16 @@ def _required_env(var_name: str, note: str) -> str:
),
"eagle3_llama_70b": os.environ.get(
"BENCH_EAGLE3_LLAMA_70B",
"lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-SpecForge",
f"{HF_CACHE_DIR}/models--lmsys--SGLang-EAGLE3-Llama-3.3-70B-Instruct-SpecForge",
),
"eagle3_qwen_32b": os.environ.get(
"BENCH_EAGLE3_QWEN_32B",
"Zhihu-ai/Zhi-Create-Qwen3-32B-Eagle3",
),
"phoenix2_qwen_8b": os.environ.get(
"BENCH_PHOENIX2_QWEN_8B",
"togethercomputer/phnx2-llama-decagon-4layer-v1.0",
),
}


Expand Down
213 changes: 149 additions & 64 deletions bench/run_sglang_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Usage:
python run_sglang_bench.py --llama # SD, Llama 70B
python run_sglang_bench.py --qwen # SD, Qwen 32B
python run_sglang_bench.py --llama --mode ar # autoregressive baseline
python run_sglang_bench.py --llama --mode AR # autoregressive baseline
python run_sglang_bench.py --llama --wandb --name myrun # log to wandb

Set model paths via env vars (BENCH_LLAMA_70B, etc.) or edit bench_paths.py.
Expand All @@ -23,77 +23,37 @@
from bench_paths import MODELS, resolve_snapshot


def get_server_cmd(args):
if args.llama:
target = resolve_snapshot(MODELS["llama_70b"])
draft = resolve_snapshot(MODELS["llama_1b"])
else:
target = resolve_snapshot(MODELS["qwen_32b"])
draft = resolve_snapshot(MODELS["qwen_0.6b"])

cmd = [
sys.executable, "-m", "sglang.launch_server",
"--model-path", target,
"--tp", str(args.tp),
"--mem-fraction-static", str(args.mem_frac),
"--max-running-requests", "1",
"--disable-radix-cache",
"--log-level", "warning",
"--port", str(args.port),
]

if args.mode == "sd":
# Speculative decoding with standalone draft model.
# Default: k=5 (num_steps=4, num_draft_tokens=5).
cmd += [
"--speculative-algorithm", "STANDALONE",
"--speculative-draft-model-path", draft,
"--speculative-num-steps", str(args.num_steps),
"--speculative-eagle-topk", "1",
"--speculative-num-draft-tokens", str(args.num_draft_tokens),
]
# mode == "ar": no speculative flags, just serve the target model.

return cmd, target


def wait_for_server(port, timeout=900, interval=5):
url = f"http://localhost:{port}/health"
deadline = time.time() + timeout
while time.time() < deadline:
try:
if requests.get(url, timeout=2).status_code == 200:
return True
except requests.ConnectionError:
pass
time.sleep(interval)
return False


def kill_server(proc):
if proc.poll() is None:
os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
proc.wait()


def main():
parser = argparse.ArgumentParser(description="Launch SGLang server and benchmark it")
parser.add_argument("--llama", action="store_true", default=True)
parser.add_argument("--qwen", action="store_true")
parser.add_argument("--mode", choices=["ar", "sd"], default="sd",
parser.add_argument("--mode", choices=["AR", "STANDALONE", "ASYNC_STANDALONE", "EAGLE3", "ASYNC_EAGLE3", "PHOENIX", "ASYNC_PHOENIX"], default="STANDALONE",
help="ar = autoregressive, sd = speculative decoding (default)")
parser.add_argument("--tp", type=int, default=4)
parser.add_argument("--port", type=int, default=40010)
parser.add_argument("--mem_frac", type=float, default=0.70)
parser.add_argument("--num_steps", type=int, default=4, help="draft chain depth (k = num_steps + 1)")
parser.add_argument("--num_draft_tokens", type=int, default=5)
parser.add_argument("--mem-frac", type=float, default=0.70)
parser.add_argument("--num-steps", type=int, default=4, help="draft chain depth (k = num_steps + 1)")
parser.add_argument("--context-length", type=int, default=4096)
# Pass-through to eval client
parser.add_argument("--numseqs", type=int, default=128)
parser.add_argument("--output_len", type=int, default=512)
parser.add_argument("--output-len", type=int, default=512)
parser.add_argument("--temp", type=float, default=0.0)
parser.add_argument("--dataset", type=str, choices=["all", "humaneval", "alpaca", "c4", "ultrafeedback", "random", "example"], default="all")
parser.add_argument("--wandb", action="store_true")
parser.add_argument("--group", type=str, default=None)
parser.add_argument("--group", type=str, default="ssd")
parser.add_argument("--name", type=str, default=None)

parser.add_argument("--f", type=int, default=4, help="Async fan out value")
parser.add_argument("--fl", type=int, nargs='+', default=None, help="Fan out list (e.g., --fl 1 3 4 becomes [1, 3, 4])")
parser.add_argument("--flh", type=int, nargs='+', default=None, help="Fan out list (e.g., --flh 1 3 4 becomes [1, 3, 4])")
parser.add_argument("--flm", type=int, nargs='+', default=None, help="Fan out list miss (e.g., --flm 1 3 4 becomes [1, 3, 4])")
parser.add_argument("--jit", action="store_true")
parser.add_argument("--force-jit", action="store_true")
parser.add_argument("--communicate-cache-hits", action="store_true")
parser.add_argument("--verbose", action="store_true")
parser.add_argument("--acceptance-rate-log", type=str, default=None,
help="Path to log acceptance rates (sets ACCEPTANCE_RATE_LOG env var for the server)")

args = parser.parse_args()
if args.qwen:
args.llama = False
Expand All @@ -107,7 +67,12 @@ def main():
capture_output=True)
time.sleep(2)

proc = subprocess.Popen(server_cmd, preexec_fn=os.setsid)
env = os.environ.copy()
if args.acceptance_rate_log:
env["ACCEPTANCE_RATE_LOG"] = args.acceptance_rate_log
print(f"ACCEPTANCE_RATE_LOG={args.acceptance_rate_log}")

proc = subprocess.Popen(server_cmd, preexec_fn=os.setsid, env=env)
try:
print("Waiting for server...")
if not wait_for_server(args.port):
Expand All @@ -122,15 +87,16 @@ def main():
"--numseqs", str(args.numseqs),
"--output_len", str(args.output_len),
"--temp", str(args.temp),
"--all", "--b", "1",
f"--{args.dataset}",
"--b", "1",
"--port", str(args.port),
]
if args.llama:
eval_cmd.append("--llama")
else:
eval_cmd.append("--qwen")
if args.mode == "sd":
eval_cmd += ["--draft", "1" if args.llama else "0.6"]
if is_eagle3(args.mode):
eval_cmd.append("--eagle")
if args.wandb:
eval_cmd += ["--wandb"]
if args.group:
Expand All @@ -145,5 +111,124 @@ def main():
print("Server stopped")


def is_spec(mode):
    """Return True if *mode* uses speculative decoding (anything but plain AR).

    Bug fix: the argparse ``--mode`` choices expose "PHOENIX"/"ASYNC_PHOENIX",
    but this predicate only matched the "PHOENIX2" spellings, so valid phoenix
    runs silently launched without speculative flags. Accept both spellings.
    """
    return mode in {
        "STANDALONE", "ASYNC_STANDALONE",
        "EAGLE3", "ASYNC_EAGLE3",
        "PHOENIX", "ASYNC_PHOENIX",    # spellings accepted by the --mode CLI choices
        "PHOENIX2", "ASYNC_PHOENIX2",  # legacy spellings previously matched here
    }


def is_async(mode):
    """Return True when *mode* runs the draft model asynchronously."""
    async_modes = ("ASYNC_STANDALONE", "ASYNC_EAGLE3", "ASYNC_PHOENIX")
    return mode in async_modes


def is_standalone(mode):
    """Return True for modes that speculate with a separate standalone draft model."""
    return mode == "STANDALONE" or mode == "ASYNC_STANDALONE"

def is_eagle3(mode):
    """Return True for EAGLE3-based speculation modes (sync or async)."""
    return mode == "EAGLE3" or mode == "ASYNC_EAGLE3"


def is_phoenix(mode):
    """Return True for Phoenix speculation modes (sync or async).

    Bug fix: the ``--mode`` CLI choices are "PHOENIX"/"ASYNC_PHOENIX", but this
    predicate only matched the "PHOENIX2" spellings, so a valid phoenix mode
    never selected the phoenix target/draft checkpoints. Accept both.
    """
    return mode in {"PHOENIX", "ASYNC_PHOENIX", "PHOENIX2", "ASYNC_PHOENIX2"}


def get_server_cmd(args):
    """Build the ``sglang.launch_server`` argv for the requested configuration.

    Returns ``(cmd, target)`` where ``cmd`` is the argv list and ``target`` is
    the resolved target-model path (also reported to the eval client).

    Bug fix: ``--mode AR`` (the autoregressive baseline advertised in the
    module usage text and in the trailing comment below) previously fell
    through to the ``raise ValueError`` branches, because neither
    model-selection chain handled a non-speculative mode. AR now launches the
    target model with no draft and no speculative flags. Speculative flags are
    gated on ``mode != "AR"`` directly so this function is correct regardless
    of the PHOENIX/PHOENIX2 spelling accepted by the sibling predicates.

    Raises:
        ValueError: for a speculative mode with no checkpoint for the chosen
            model family (e.g. phoenix with ``--llama``).
    """
    draft = None  # AR baseline: target only, no draft model.
    if args.llama:
        target = resolve_snapshot(MODELS["llama_70b"])
        if args.mode == "AR":
            pass
        elif is_standalone(args.mode):
            draft = resolve_snapshot(MODELS["llama_1b"])
        elif is_eagle3(args.mode):
            draft = resolve_snapshot(MODELS["eagle3_llama_70b"])
        else:
            raise ValueError(f"Unsupported mode for llama: {args.mode}")
    else:
        target = resolve_snapshot(MODELS["qwen_32b"])
        if args.mode == "AR":
            pass
        elif is_standalone(args.mode):
            draft = resolve_snapshot(MODELS["qwen_0.6b"])
        elif is_eagle3(args.mode):
            draft = resolve_snapshot(MODELS["eagle3_qwen_32b"])
        elif is_phoenix(args.mode):
            # Phoenix benchmarks run against the smaller 8B target.
            target = resolve_snapshot(MODELS["qwen_8b"])
            draft = resolve_snapshot(MODELS["phoenix2_qwen_8b"])
        else:
            raise ValueError(f"Unsupported mode for qwen: {args.mode}")

    cmd = [
        sys.executable, "-m", "sglang.launch_server",
        "--model-path", target,
        "--tp", str(args.tp),
        "--mem-fraction-static", str(args.mem_frac),
        "--max-running-requests", "1",
        # "--disable-radix-cache",  # left enabled on purpose; kept for easy toggling
        "--log-level", "warning",
        "--port", str(args.port),
        "--context-length", str(args.context_length),
    ]

    if args.mode != "AR":
        # Speculative decoding: k = num_steps + 1 draft tokens per step.
        cmd += [
            "--speculative-algorithm", args.mode,
            "--speculative-draft-model-path", draft,
            "--speculative-num-steps", str(args.num_steps),
            "--speculative-eagle-topk", "1",
            "--speculative-num-draft-tokens", str(args.num_steps + 1),
        ]
        if is_async(args.mode):
            cmd += [
                "--speculative-async-fan-out", str(args.f),
            ]
        if args.fl:
            cmd += [
                "--speculative-async-fan-out-list", ",".join(map(str, args.fl)),
            ]
        if args.flh:
            cmd += [
                "--speculative-async-fan-out-list-hit", ",".join(map(str, args.flh)),
            ]
        if args.flm:
            cmd += [
                "--speculative-async-fan-out-list-miss", ",".join(map(str, args.flm)),
            ]
        if args.jit or args.force_jit:
            cmd += [
                "--speculative-async-jit-speculate",
            ]
        if args.force_jit:
            cmd += [
                "--speculative-async-force-jit-speculate",
            ]
        if args.communicate_cache_hits:
            cmd += [
                "--speculative-async-communicate-cache-hits",
            ]
        if args.verbose:
            cmd += [
                "--speculative-async-verbose",
            ]

    # mode == "AR": no speculative flags, just serve the target model.
    return cmd, target


def wait_for_server(port, timeout=900, interval=5):
    """Poll the server's /health endpoint until it answers 200 or *timeout* elapses.

    Args:
        port: local port the server listens on.
        timeout: total seconds to keep polling before giving up.
        interval: seconds to sleep between polls.

    Returns:
        True once the server reports healthy, False on timeout.
    """
    url = f"http://localhost:{port}/health"
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            if requests.get(url, timeout=2).status_code == 200:
                return True
        except requests.RequestException:
            # Bug fix: catching only ConnectionError let ReadTimeout (raised by
            # the per-request timeout=2 and NOT a ConnectionError subclass)
            # escape and crash the wait loop during slow server warmup.
            # RequestException covers all requests failures; keep retrying.
            pass
        time.sleep(interval)
    return False


def kill_server(proc):
    """Hard-kill the server's whole process group if it is still running.

    The server is launched with ``preexec_fn=os.setsid``, so SIGKILL-ing its
    process group takes down any workers it spawned as well.
    """
    still_running = proc.poll() is None
    if still_running:
        pgid = os.getpgid(proc.pid)
        os.killpg(pgid, signal.SIGKILL)
        proc.wait()


# Script entry point: launch the server, run the eval client, then tear down.
if __name__ == "__main__":
    main()
Loading