Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
b8c1fd7
Support for Phoenix V1
avnermay Mar 23, 2026
7a968e8
Merge branch 'avner/sglang' into avner/sglang-phnx
avnermay Mar 23, 2026
e632702
Merge branch 'avner/sglang' into avner/sglang-phnx
avnermay Mar 24, 2026
7053b80
FA4 initial implementation by CC
avnermay Mar 28, 2026
5256853
Merge branch 'avner/sglang-fa4' into avner/sglang-phnx-fa4
avnermay Mar 28, 2026
42cea6b
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Mar 28, 2026
eb13cd3
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
Mar 28, 2026
7184e54
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Mar 28, 2026
ab487ac
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Mar 28, 2026
743fb40
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Mar 28, 2026
bfa56fd
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Mar 28, 2026
e701bfe
More logging
avnermay Mar 29, 2026
80f2f76
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Mar 29, 2026
332b1f3
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Mar 29, 2026
37954a6
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Mar 29, 2026
08248b2
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Mar 30, 2026
dbdaa7b
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Mar 31, 2026
ddaff75
Merge branch 'avner/sglang-fa4' into avner/sglang-fa4-phnx
avnermay Apr 1, 2026
f8af8e7
Acceptance rate log and force-jit-speculate
avnermay Apr 10, 2026
4c6997f
Improvements to benchmarking
avnermay Apr 10, 2026
b417d75
NIT: print cache_hits as ints
avnermay Apr 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions bench/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def parse_arguments():
parser.add_argument("--fl", type=int, nargs='+', default=None, help="Fan out list (e.g., --fl 1 3 4 becomes [1, 3, 4])")
parser.add_argument("--flh", type=int, nargs='+', default=None, help="Fan out list (e.g., --flh 1 3 4 becomes [1, 3, 4])")
parser.add_argument("--flm", type=int, nargs='+', default=None, help="Fan out list miss (e.g., --flm 1 3 4 becomes [1, 3, 4])")
parser.add_argument("--backup", type=str, choices=["jit", "fast"], default="jit", help="Backup strategy (jit or fast)")
parser.add_argument("--backup", type=str, choices=["jit", "force-jit", "fast"], default="jit", help="Backup strategy (jit or fast)")

# Memory and batching configuration
parser.add_argument("--block_sz", type=int, default=256, help="KV cache block size (see config.py: kvcache_block_size)")
Expand Down Expand Up @@ -129,7 +129,7 @@ def initialize_wandb(args, run_name):
"gpus": args.gpus,
"speculative_decoding": args.spec,
"async_speculative": getattr(args, 'async', False),
"jit_speculative": args.backup == "jit",
"backup_strategy": args.backup,
"k": args.k if args.spec else None,
"f": args.f,
"fan_out_list": args.flh,
Expand Down Expand Up @@ -172,8 +172,11 @@ def create_llm_kwargs(args, draft_path):
max_num_seqs=args.b,
max_model_len=args.max_model_len,
sampler_x=args.x,
jit_speculate=(args.backup == "jit"),
jit_speculate=(args.backup == "jit" or args.backup == "force-jit"),
force_jit_speculate=(args.backup == "force-jit"),
max_steps=args.max_steps,
communicate_cache_hits=True,
communicate_logits=True,
)

if args.flh is not None:
Expand Down
9 changes: 8 additions & 1 deletion bench/bench_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def load_dataset_token_ids(
return None

dataset_file_path = DATASET_PATHS[dataset_name]
print(f"Loading dataset '{dataset_name}' from: {dataset_file_path}")
if not os.path.exists(dataset_file_path):
print(
f"Warning: Dataset file not found at {dataset_file_path}, falling back to random tokens")
Expand All @@ -172,10 +173,16 @@ def load_dataset_token_ids(
data = json.loads(line.strip())
text: str = data["text"]
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
tokens = tokenizer.apply_chat_template(
result = tokenizer.apply_chat_template(
[{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": text}],
add_generation_prompt=True,
)
text_result = tokenizer.apply_chat_template(
[{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": text}],
add_generation_prompt=True,
tokenize=False,
)
tokens = result.input_ids if hasattr(result, 'input_ids') else result
else:
tokens = tokenizer.encode(text, add_special_tokens=False)

Expand Down
10 changes: 9 additions & 1 deletion bench/bench_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def _required_env(var_name: str, note: str) -> str:
"BENCH_LLAMA_1B",
f"{HF_CACHE_DIR}/models--meta-llama--Llama-3.2-1B-Instruct",
),
"qwen_8b": os.environ.get(
"BENCH_QWEN_8B",
f"{HF_CACHE_DIR}/models--Qwen--Qwen3-8B",
),
"qwen_32b": os.environ.get(
"BENCH_QWEN_32B",
f"{HF_CACHE_DIR}/models--Qwen--Qwen3-32B",
Expand All @@ -62,12 +66,16 @@ def _required_env(var_name: str, note: str) -> str:
),
"eagle3_llama_70b": os.environ.get(
"BENCH_EAGLE3_LLAMA_70B",
"lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-SpecForge",
f"{HF_CACHE_DIR}/models--lmsys--SGLang-EAGLE3-Llama-3.3-70B-Instruct-SpecForge",
),
"eagle3_qwen_32b": os.environ.get(
"BENCH_EAGLE3_QWEN_32B",
"Zhihu-ai/Zhi-Create-Qwen3-32B-Eagle3",
),
"phoenix2_qwen_8b": os.environ.get(
"BENCH_PHOENIX2_QWEN_8B",
"togethercomputer/phnx2-llama-decagon-4layer-v1.0",
),
}


Expand Down
213 changes: 149 additions & 64 deletions bench/run_sglang_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Usage:
python run_sglang_bench.py --llama # SD, Llama 70B
python run_sglang_bench.py --qwen # SD, Qwen 32B
python run_sglang_bench.py --llama --mode ar # autoregressive baseline
python run_sglang_bench.py --llama --mode AR # autoregressive baseline
python run_sglang_bench.py --llama --wandb --name myrun # log to wandb

Set model paths via env vars (BENCH_LLAMA_70B, etc.) or edit bench_paths.py.
Expand All @@ -23,77 +23,37 @@
from bench_paths import MODELS, resolve_snapshot


def get_server_cmd(args):
if args.llama:
target = resolve_snapshot(MODELS["llama_70b"])
draft = resolve_snapshot(MODELS["llama_1b"])
else:
target = resolve_snapshot(MODELS["qwen_32b"])
draft = resolve_snapshot(MODELS["qwen_0.6b"])

cmd = [
sys.executable, "-m", "sglang.launch_server",
"--model-path", target,
"--tp", str(args.tp),
"--mem-fraction-static", str(args.mem_frac),
"--max-running-requests", "1",
"--disable-radix-cache",
"--log-level", "warning",
"--port", str(args.port),
]

if args.mode == "sd":
# Speculative decoding with standalone draft model.
# Default: k=5 (num_steps=4, num_draft_tokens=5).
cmd += [
"--speculative-algorithm", "STANDALONE",
"--speculative-draft-model-path", draft,
"--speculative-num-steps", str(args.num_steps),
"--speculative-eagle-topk", "1",
"--speculative-num-draft-tokens", str(args.num_draft_tokens),
]
# mode == "ar": no speculative flags, just serve the target model.

return cmd, target


def wait_for_server(port, timeout=900, interval=5):
url = f"http://localhost:{port}/health"
deadline = time.time() + timeout
while time.time() < deadline:
try:
if requests.get(url, timeout=2).status_code == 200:
return True
except requests.ConnectionError:
pass
time.sleep(interval)
return False


def kill_server(proc):
if proc.poll() is None:
os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
proc.wait()


def main():
parser = argparse.ArgumentParser(description="Launch SGLang server and benchmark it")
parser.add_argument("--llama", action="store_true", default=True)
parser.add_argument("--qwen", action="store_true")
parser.add_argument("--mode", choices=["ar", "sd"], default="sd",
parser.add_argument("--mode", choices=["AR", "STANDALONE", "ASYNC_STANDALONE", "EAGLE3", "ASYNC_EAGLE3", "PHOENIX", "ASYNC_PHOENIX"], default="STANDALONE",
help="ar = autoregressive, sd = speculative decoding (default)")
parser.add_argument("--tp", type=int, default=4)
parser.add_argument("--port", type=int, default=40010)
parser.add_argument("--mem_frac", type=float, default=0.70)
parser.add_argument("--num_steps", type=int, default=4, help="draft chain depth (k = num_steps + 1)")
parser.add_argument("--num_draft_tokens", type=int, default=5)
parser.add_argument("--mem-frac", type=float, default=0.70)
parser.add_argument("--num-steps", type=int, default=4, help="draft chain depth (k = num_steps + 1)")
parser.add_argument("--context-length", type=int, default=4096)
# Pass-through to eval client
parser.add_argument("--numseqs", type=int, default=128)
parser.add_argument("--output_len", type=int, default=512)
parser.add_argument("--output-len", type=int, default=512)
parser.add_argument("--temp", type=float, default=0.0)
parser.add_argument("--dataset", type=str, choices=["all", "humaneval", "alpaca", "c4", "ultrafeedback", "random", "example"], default="all")
parser.add_argument("--wandb", action="store_true")
parser.add_argument("--group", type=str, default=None)
parser.add_argument("--group", type=str, default="ssd")
parser.add_argument("--name", type=str, default=None)

parser.add_argument("--f", type=int, default=4, help="Async fan out value")
parser.add_argument("--fl", type=int, nargs='+', default=None, help="Fan out list (e.g., --fl 1 3 4 becomes [1, 3, 4])")
parser.add_argument("--flh", type=int, nargs='+', default=None, help="Fan out list (e.g., --flh 1 3 4 becomes [1, 3, 4])")
parser.add_argument("--flm", type=int, nargs='+', default=None, help="Fan out list miss (e.g., --flm 1 3 4 becomes [1, 3, 4])")
parser.add_argument("--jit", action="store_true")
parser.add_argument("--force-jit", action="store_true")
parser.add_argument("--communicate-cache-hits", action="store_true")
parser.add_argument("--verbose", action="store_true")
parser.add_argument("--acceptance-rate-log", type=str, default=None,
help="Path to log acceptance rates (sets ACCEPTANCE_RATE_LOG env var for the server)")

args = parser.parse_args()
if args.qwen:
args.llama = False
Expand All @@ -107,7 +67,12 @@ def main():
capture_output=True)
time.sleep(2)

proc = subprocess.Popen(server_cmd, preexec_fn=os.setsid)
env = os.environ.copy()
if args.acceptance_rate_log:
env["ACCEPTANCE_RATE_LOG"] = args.acceptance_rate_log
print(f"ACCEPTANCE_RATE_LOG={args.acceptance_rate_log}")

proc = subprocess.Popen(server_cmd, preexec_fn=os.setsid, env=env)
try:
print("Waiting for server...")
if not wait_for_server(args.port):
Expand All @@ -122,15 +87,16 @@ def main():
"--numseqs", str(args.numseqs),
"--output_len", str(args.output_len),
"--temp", str(args.temp),
"--all", "--b", "1",
f"--{args.dataset}",
"--b", "1",
"--port", str(args.port),
]
if args.llama:
eval_cmd.append("--llama")
else:
eval_cmd.append("--qwen")
if args.mode == "sd":
eval_cmd += ["--draft", "1" if args.llama else "0.6"]
if is_eagle3(args.mode):
eval_cmd.append("--eagle")
if args.wandb:
eval_cmd += ["--wandb"]
if args.group:
Expand All @@ -145,5 +111,124 @@ def main():
print("Server stopped")


def is_spec(mode):
    """Return True if *mode* uses any speculative-decoding algorithm.

    Fix: use "PHOENIX"/"ASYNC_PHOENIX" — the spellings accepted by the
    --mode argument and tested by is_async()/get_server_cmd(). The earlier
    "PHOENIX2" values never matched a selectable mode, so Phoenix runs were
    launched without any speculative server flags.
    """
    return mode in ["STANDALONE", "ASYNC_STANDALONE", "EAGLE3", "ASYNC_EAGLE3", "PHOENIX", "ASYNC_PHOENIX"]


def is_async(mode):
    """Return True when *mode* selects an asynchronous speculative variant."""
    return mode in ("ASYNC_STANDALONE", "ASYNC_EAGLE3", "ASYNC_PHOENIX")


def is_standalone(mode):
    """Return True when *mode* drafts with a standalone model (sync or async)."""
    return mode in ("STANDALONE", "ASYNC_STANDALONE")

def is_eagle3(mode):
    """Return True when *mode* drafts with an EAGLE3 head (sync or async)."""
    return mode in ("EAGLE3", "ASYNC_EAGLE3")


def is_phoenix(mode):
    """Return True if *mode* selects Phoenix speculative decoding.

    Fix: test "PHOENIX"/"ASYNC_PHOENIX", the spellings accepted by --mode.
    The previous "PHOENIX2" values meant this always returned False, so
    get_server_cmd() rejected every Phoenix run with ValueError.
    """
    return mode in ["PHOENIX", "ASYNC_PHOENIX"]


def get_server_cmd(args):
    """Build the sglang.launch_server argv for the selected model/mode.

    Chooses the target (and, for speculative modes, draft) model paths from
    MODELS based on --llama/--qwen and --mode, then appends the speculative
    decoding flags when the mode requires them.

    Fix: mode "AR" is a valid --mode choice and must launch a plain server,
    but the previous unconditional `else: raise` branches rejected it. The
    raise is now guarded by is_spec(), so only speculative modes with no
    configured draft model are rejected.

    Returns:
        (cmd, target): the server argv list and the resolved target path.

    Raises:
        ValueError: for a speculative mode that has no draft model
            configured for the chosen model family.
    """
    draft = None
    if args.llama:
        target = resolve_snapshot(MODELS["llama_70b"])
        if is_standalone(args.mode):
            draft = resolve_snapshot(MODELS["llama_1b"])
        elif is_eagle3(args.mode):
            draft = resolve_snapshot(MODELS["eagle3_llama_70b"])
        elif is_spec(args.mode):
            # No Phoenix draft is configured for the llama family.
            raise ValueError(f"Unsupported mode for llama: {args.mode}")
    else:
        target = resolve_snapshot(MODELS["qwen_32b"])
        if is_standalone(args.mode):
            draft = resolve_snapshot(MODELS["qwen_0.6b"])
        elif is_eagle3(args.mode):
            draft = resolve_snapshot(MODELS["eagle3_qwen_32b"])
        elif is_phoenix(args.mode):
            # Phoenix pairs the 8B target with its dedicated draft model
            # (overriding the 32B target chosen above).
            target = resolve_snapshot(MODELS["qwen_8b"])
            draft = resolve_snapshot(MODELS["phoenix2_qwen_8b"])
        elif is_spec(args.mode):
            raise ValueError(f"Unsupported mode for qwen: {args.mode}")

    cmd = [
        sys.executable, "-m", "sglang.launch_server",
        "--model-path", target,
        "--tp", str(args.tp),
        "--mem-fraction-static", str(args.mem_frac),
        "--max-running-requests", "1",
        # "--disable-radix-cache",
        "--log-level", "warning",
        "--port", str(args.port),
        "--context-length", str(args.context_length),
    ]

    if is_spec(args.mode):
        # Speculative decoding flags. k = num_steps + 1 draft tokens.
        cmd += [
            "--speculative-algorithm", args.mode,
            "--speculative-draft-model-path", draft,
            "--speculative-num-steps", str(args.num_steps),
            "--speculative-eagle-topk", "1",
            "--speculative-num-draft-tokens", str(args.num_steps + 1),
        ]
        if is_async(args.mode):
            cmd += [
                "--speculative-async-fan-out", str(args.f),
            ]
        # Optional async tuning flags; each defaults to off/None in argparse.
        if args.fl:
            cmd += [
                "--speculative-async-fan-out-list", ",".join(map(str, args.fl)),
            ]
        if args.flh:
            cmd += [
                "--speculative-async-fan-out-list-hit", ",".join(map(str, args.flh)),
            ]
        if args.flm:
            cmd += [
                "--speculative-async-fan-out-list-miss", ",".join(map(str, args.flm)),
            ]
        if args.jit or args.force_jit:
            # force-jit implies jit; both set the base jit-speculate flag.
            cmd += [
                "--speculative-async-jit-speculate",
            ]
        if args.force_jit:
            cmd += [
                "--speculative-async-force-jit-speculate",
            ]
        if args.communicate_cache_hits:
            cmd += [
                "--speculative-async-communicate-cache-hits",
            ]
        if args.verbose:
            cmd += [
                "--speculative-async-verbose",
            ]

    # mode == "AR": no speculative flags, just serve the target model.
    return cmd, target


def wait_for_server(port, timeout=900, interval=5):
    """Poll the server's /health endpoint until it answers 200.

    Args:
        port: local port the server listens on.
        timeout: total seconds to keep polling before giving up.
        interval: seconds to sleep between polls.

    Returns:
        True once a 200 response is received, False if the deadline passes.
    """
    url = f"http://localhost:{port}/health"
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            if requests.get(url, timeout=2).status_code == 200:
                return True
        except requests.RequestException:
            # Fix: catch all transient request failures, not only
            # ConnectionError. A slow-starting server that accepts the TCP
            # connection but is not yet serving raises requests.Timeout,
            # which previously escaped and crashed the wait loop.
            pass
        time.sleep(interval)
    return False


def kill_server(proc):
    """SIGKILL the entire process group of *proc* (launched with os.setsid)
    and reap it; does nothing when the process has already exited."""
    if proc.poll() is not None:
        return
    pgid = os.getpgid(proc.pid)
    os.killpg(pgid, signal.SIGKILL)
    proc.wait()


if __name__ == "__main__":
main()
10 changes: 10 additions & 0 deletions bench/small_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
llama_1b_path = '/scratch/avner/huggingface/hub/models--meta-llama--Llama-3.2-1B-Instruct/snapshots/9213176726f574b556790deb65791e0c5aa438b6'
llama_70b_path = '/scratch/avner/huggingface/hub/models--meta-llama--Llama-3.3-70B-Instruct/snapshots/6f6073b423013f6a7d4d9f39144961bfbfbc386b'
eagle_path = '/scratch/avner/huggingface/hub/models--lmsys--SGLang-EAGLE3-Llama-3.3-70B-Instruct-SpecForge/snapshots/63ebaa6585f96b89685adad8fdfa0da53be6a8fd'
phoenix_path = '/scratch/avner/huggingface/hub/models--togethercomputer--phoenix-Llama-3p2-1B-Instruct-tgt-Llama-3p3-70b-instruct-UNTRAINED/snapshots/3af59d71514388e14d8685f2b684f74e3e311717'
# eagle_path = '/scratch/avner/huggingface/hub/models--yuhuili--EAGLE3-LLaMA3.3-Instruct-70B'
assert os.path.isdir(llama_1b_path)
assert os.path.isdir(llama_70b_path)
Expand All @@ -18,6 +19,7 @@
parser.add_argument("--model", type=str, default=llama_1b_path)
parser.add_argument("--draft", type=str, default=llama_1b_path)
parser.add_argument("--eagle", action="store_true")
parser.add_argument("--phoenix", action="store_true")
parser.add_argument("--k", type=int, default=7)
parser.add_argument("--jit-speculate", action="store_true")
parser.add_argument("--num-gpus", type=int, default=2)
Expand All @@ -36,10 +38,18 @@
args.jit_speculate = True
args.chat_template = True

if args.phoenix:
args.draft = phoenix_path
args.model = llama_70b_path
args.num_gpus = 5
args.jit_speculate = True
args.chat_template = True

llm = LLM(
model=args.model,
draft=args.draft,
use_eagle=args.eagle,
use_phoenix=args.phoenix,
speculate_k=args.k,
speculate=True,
draft_async=True,
Expand Down
Loading