diff --git a/bench/bench.py b/bench/bench.py
index b80f21955..5e013f099 100644
--- a/bench/bench.py
+++ b/bench/bench.py
@@ -37,7 +37,7 @@ def parse_arguments():
     parser.add_argument("--fl", type=int, nargs='+', default=None, help="Fan out list (e.g., --fl 1 3 4 becomes [1, 3, 4])")
     parser.add_argument("--flh", type=int, nargs='+', default=None, help="Fan out list (e.g., --flh 1 3 4 becomes [1, 3, 4])")
     parser.add_argument("--flm", type=int, nargs='+', default=None, help="Fan out list miss (e.g., --flm 1 3 4 becomes [1, 3, 4])")
-    parser.add_argument("--backup", type=str, choices=["jit", "fast"], default="jit", help="Backup strategy (jit or fast)")
+    parser.add_argument("--backup", type=str, choices=["jit", "force-jit", "fast"], default="jit", help="Backup strategy (jit, force-jit, or fast)")
 
     # Memory and batching configuration
     parser.add_argument("--block_sz", type=int, default=256, help="KV cache block size (see config.py: kvcache_block_size)")
@@ -129,7 +129,7 @@ def initialize_wandb(args, run_name):
         "gpus": args.gpus,
         "speculative_decoding": args.spec,
         "async_speculative": getattr(args, 'async', False),
-        "jit_speculative": args.backup == "jit",
+        "backup_strategy": args.backup,
         "k": args.k if args.spec else None,
         "f": args.f,
         "fan_out_list": args.flh,
@@ -172,8 +172,11 @@ def create_llm_kwargs(args, draft_path):
         max_num_seqs=args.b,
         max_model_len=args.max_model_len,
         sampler_x=args.x,
-        jit_speculate=(args.backup == "jit"),
+        jit_speculate=(args.backup == "jit" or args.backup == "force-jit"),
+        force_jit_speculate=(args.backup == "force-jit"),
         max_steps=args.max_steps,
+        communicate_cache_hits=True,
+        communicate_logits=True,
     )
 
     if args.flh is not None:
diff --git a/bench/bench_helpers.py b/bench/bench_helpers.py
index 4079cf3a6..17153ab2a 100644
--- a/bench/bench_helpers.py
+++ b/bench/bench_helpers.py
@@ -157,6 +157,7 @@ def load_dataset_token_ids(
         return None
 
     dataset_file_path = DATASET_PATHS[dataset_name]
+    print(f"Loading dataset '{dataset_name}' from: {dataset_file_path}")
     if not os.path.exists(dataset_file_path):
         print(
             f"Warning: Dataset file not found at {dataset_file_path}, falling back to random tokens")
@@ -172,10 +173,16 @@ def load_dataset_token_ids(
                 data = json.loads(line.strip())
                 text: str = data["text"]
                 if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
-                    tokens = tokenizer.apply_chat_template(
+                    result = tokenizer.apply_chat_template(
                         [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": text}],
                         add_generation_prompt=True,
                     )
+                    text_result = tokenizer.apply_chat_template(
+                        [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": text}],
+                        add_generation_prompt=True,
+                        tokenize=False,
+                    )
+                    tokens = result.input_ids if hasattr(result, 'input_ids') else result
                 else:
                     tokens = tokenizer.encode(text, add_special_tokens=False)
diff --git a/bench/bench_paths.py b/bench/bench_paths.py
index 5e2e5ec6a..c4dd72a48 100644
--- a/bench/bench_paths.py
+++ b/bench/bench_paths.py
@@ -52,6 +52,10 @@ def _required_env(var_name: str, note: str) -> str:
         "BENCH_LLAMA_1B",
         f"{HF_CACHE_DIR}/models--meta-llama--Llama-3.2-1B-Instruct",
     ),
+    "qwen_8b": os.environ.get(
+        "BENCH_QWEN_8B",
+        f"{HF_CACHE_DIR}/models--Qwen--Qwen3-8B",
+    ),
     "qwen_32b": os.environ.get(
         "BENCH_QWEN_32B",
         f"{HF_CACHE_DIR}/models--Qwen--Qwen3-32B",
@@ -62,12 +66,16 @@ def _required_env(var_name: str, note: str) -> str:
     ),
     "eagle3_llama_70b": os.environ.get(
         "BENCH_EAGLE3_LLAMA_70B",
-        "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-SpecForge",
+
f"{HF_CACHE_DIR}/models--lmsys--SGLang-EAGLE3-Llama-3.3-70B-Instruct-SpecForge", ), "eagle3_qwen_32b": os.environ.get( "BENCH_EAGLE3_QWEN_32B", "Zhihu-ai/Zhi-Create-Qwen3-32B-Eagle3", ), + "phoenix2_qwen_8b": os.environ.get( + "BENCH_PHOENIX2_QWEN_8B", + "togethercomputer/phnx2-llama-decagon-4layer-v1.0", + ), } diff --git a/bench/run_sglang_bench.py b/bench/run_sglang_bench.py index 2949f8be7..c76a7b2c6 100644 --- a/bench/run_sglang_bench.py +++ b/bench/run_sglang_bench.py @@ -6,7 +6,7 @@ Usage: python run_sglang_bench.py --llama # SD, Llama 70B python run_sglang_bench.py --qwen # SD, Qwen 32B - python run_sglang_bench.py --llama --mode ar # autoregressive baseline + python run_sglang_bench.py --llama --mode AR # autoregressive baseline python run_sglang_bench.py --llama --wandb --name myrun # log to wandb Set model paths via env vars (BENCH_LLAMA_70B, etc.) or edit bench_paths.py. @@ -23,77 +23,37 @@ from bench_paths import MODELS, resolve_snapshot -def get_server_cmd(args): - if args.llama: - target = resolve_snapshot(MODELS["llama_70b"]) - draft = resolve_snapshot(MODELS["llama_1b"]) - else: - target = resolve_snapshot(MODELS["qwen_32b"]) - draft = resolve_snapshot(MODELS["qwen_0.6b"]) - - cmd = [ - sys.executable, "-m", "sglang.launch_server", - "--model-path", target, - "--tp", str(args.tp), - "--mem-fraction-static", str(args.mem_frac), - "--max-running-requests", "1", - "--disable-radix-cache", - "--log-level", "warning", - "--port", str(args.port), - ] - - if args.mode == "sd": - # Speculative decoding with standalone draft model. - # Default: k=5 (num_steps=4, num_draft_tokens=5). - cmd += [ - "--speculative-algorithm", "STANDALONE", - "--speculative-draft-model-path", draft, - "--speculative-num-steps", str(args.num_steps), - "--speculative-eagle-topk", "1", - "--speculative-num-draft-tokens", str(args.num_draft_tokens), - ] - # mode == "ar": no speculative flags, just serve the target model. 
-
-    return cmd, target
-
-
-def wait_for_server(port, timeout=900, interval=5):
-    url = f"http://localhost:{port}/health"
-    deadline = time.time() + timeout
-    while time.time() < deadline:
-        try:
-            if requests.get(url, timeout=2).status_code == 200:
-                return True
-        except requests.ConnectionError:
-            pass
-        time.sleep(interval)
-    return False
-
-
-def kill_server(proc):
-    if proc.poll() is None:
-        os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
-        proc.wait()
-
-
 def main():
     parser = argparse.ArgumentParser(description="Launch SGLang server and benchmark it")
     parser.add_argument("--llama", action="store_true", default=True)
     parser.add_argument("--qwen", action="store_true")
-    parser.add_argument("--mode", choices=["ar", "sd"], default="sd",
-                        help="ar = autoregressive, sd = speculative decoding (default)")
+    parser.add_argument("--mode", choices=["AR", "STANDALONE", "ASYNC_STANDALONE", "EAGLE3", "ASYNC_EAGLE3", "PHOENIX", "ASYNC_PHOENIX"], default="STANDALONE",
+                        help="AR = autoregressive baseline; all other modes select a speculative decoding algorithm (default: STANDALONE)")
     parser.add_argument("--tp", type=int, default=4)
     parser.add_argument("--port", type=int, default=40010)
-    parser.add_argument("--mem_frac", type=float, default=0.70)
-    parser.add_argument("--num_steps", type=int, default=4, help="draft chain depth (k = num_steps + 1)")
-    parser.add_argument("--num_draft_tokens", type=int, default=5)
+    parser.add_argument("--mem-frac", type=float, default=0.70)
+    parser.add_argument("--num-steps", type=int, default=4, help="draft chain depth (k = num_steps + 1)")
+    parser.add_argument("--context-length", type=int, default=4096)
 
     # Pass-through to eval client
     parser.add_argument("--numseqs", type=int, default=128)
-    parser.add_argument("--output_len", type=int, default=512)
+    parser.add_argument("--output-len", type=int, default=512)
     parser.add_argument("--temp", type=float, default=0.0)
+    parser.add_argument("--dataset", type=str, choices=["all", "humaneval", "alpaca", "c4", "ultrafeedback", "random", "example"], default="all")
     parser.add_argument("--wandb", action="store_true")
-    parser.add_argument("--group", type=str, default=None)
+    parser.add_argument("--group", type=str, default="ssd")
     parser.add_argument("--name", type=str, default=None)
+
+    parser.add_argument("--f", type=int, default=4, help="Async fan out value")
+    parser.add_argument("--fl", type=int, nargs='+', default=None, help="Fan out list (e.g., --fl 1 3 4 becomes [1, 3, 4])")
+    parser.add_argument("--flh", type=int, nargs='+', default=None, help="Fan out list hit (e.g., --flh 1 3 4 becomes [1, 3, 4])")
+    parser.add_argument("--flm", type=int, nargs='+', default=None, help="Fan out list miss (e.g., --flm 1 3 4 becomes [1, 3, 4])")
+    parser.add_argument("--jit", action="store_true")
+    parser.add_argument("--force-jit", action="store_true")
+    parser.add_argument("--communicate-cache-hits", action="store_true")
+    parser.add_argument("--verbose", action="store_true")
+    parser.add_argument("--acceptance-rate-log", type=str, default=None,
+                        help="Path to log acceptance rates (sets ACCEPTANCE_RATE_LOG env var for the server)")
+
     args = parser.parse_args()
     if args.qwen:
         args.llama = False
@@ -107,7 +67,12 @@ def main():
                        capture_output=True)
         time.sleep(2)
 
-    proc = subprocess.Popen(server_cmd, preexec_fn=os.setsid)
+    env = os.environ.copy()
+    if args.acceptance_rate_log:
+        env["ACCEPTANCE_RATE_LOG"] = args.acceptance_rate_log
+        print(f"ACCEPTANCE_RATE_LOG={args.acceptance_rate_log}")
+
+    proc = subprocess.Popen(server_cmd, preexec_fn=os.setsid, env=env)
 
     try:
         print("Waiting for server...")
         if not wait_for_server(args.port):
@@ -122,15 +87,16 @@ def main():
         "--numseqs", str(args.numseqs),
         "--output_len", str(args.output_len),
         "--temp", str(args.temp),
-        "--all", "--b", "1",
+        f"--{args.dataset}",
+        "--b", "1",
         "--port", str(args.port),
     ]
     if args.llama:
         eval_cmd.append("--llama")
     else:
         eval_cmd.append("--qwen")
-    if args.mode == "sd":
-        eval_cmd += ["--draft", "1" if args.llama else "0.6"]
+    if is_eagle3(args.mode):
+        eval_cmd.append("--eagle")
     if args.wandb:
         eval_cmd += ["--wandb"]
     if args.group:
@@ -145,5 +111,124 @@ def main():
         print("Server stopped")
 
 
+def is_spec(mode):
+    return mode in ["STANDALONE", "ASYNC_STANDALONE", "EAGLE3", "ASYNC_EAGLE3", "PHOENIX", "ASYNC_PHOENIX"]
+
+
+def is_async(mode):
+    return mode in ["ASYNC_STANDALONE", "ASYNC_EAGLE3", "ASYNC_PHOENIX"]
+
+
+def is_standalone(mode):
+    return mode in ["STANDALONE", "ASYNC_STANDALONE"]
+
+
+def is_eagle3(mode):
+    return mode in ["EAGLE3", "ASYNC_EAGLE3"]
+
+
+def is_phoenix(mode):
+    return mode in ["PHOENIX", "ASYNC_PHOENIX"]
+
+
+def get_server_cmd(args):
+    if args.llama:
+        target = resolve_snapshot(MODELS["llama_70b"])
+        if is_standalone(args.mode):
+            draft = resolve_snapshot(MODELS["llama_1b"])
+        elif is_eagle3(args.mode):
+            draft = resolve_snapshot(MODELS["eagle3_llama_70b"])
+        else:
+            raise ValueError(f"Unsupported mode for llama: {args.mode}")
+    else:
+        target = resolve_snapshot(MODELS["qwen_32b"])
+        if is_standalone(args.mode):
+            draft = resolve_snapshot(MODELS["qwen_0.6b"])
+        elif is_eagle3(args.mode):
+            draft = resolve_snapshot(MODELS["eagle3_qwen_32b"])
+        elif is_phoenix(args.mode):
+            target = resolve_snapshot(MODELS["qwen_8b"])
+            draft = resolve_snapshot(MODELS["phoenix2_qwen_8b"])
+        else:
+            raise ValueError(f"Unsupported mode for qwen: {args.mode}")
+
+    cmd = [
+        sys.executable, "-m", "sglang.launch_server",
+        "--model-path", target,
+        "--tp", str(args.tp),
+        "--mem-fraction-static", str(args.mem_frac),
+        "--max-running-requests", "1",
+        # "--disable-radix-cache",
+        "--log-level", "warning",
+        "--port", str(args.port),
+        "--context-length", str(args.context_length),
+    ]
+
+    if is_spec(args.mode):
+        # Speculative decoding with the selected draft model.
+        # Default: k=5 (num_steps=4, num_draft_tokens=num_steps+1=5).
+        cmd += [
+            "--speculative-algorithm", args.mode,
+            "--speculative-draft-model-path", draft,
+            "--speculative-num-steps", str(args.num_steps),
+            "--speculative-eagle-topk", "1",
+            "--speculative-num-draft-tokens", str(args.num_steps + 1),
+        ]
+        if is_async(args.mode):
+            cmd += [
+                "--speculative-async-fan-out", str(args.f),
+            ]
+            if args.fl:
+                cmd += [
+                    "--speculative-async-fan-out-list", ",".join(map(str, args.fl)),
+                ]
+            if args.flh:
+                cmd += [
+                    "--speculative-async-fan-out-list-hit", ",".join(map(str, args.flh)),
+                ]
+            if args.flm:
+                cmd += [
+                    "--speculative-async-fan-out-list-miss", ",".join(map(str, args.flm)),
+                ]
+            if args.jit or args.force_jit:
+                cmd += [
+                    "--speculative-async-jit-speculate",
+                ]
+            if args.force_jit:
+                cmd += [
+                    "--speculative-async-force-jit-speculate",
+                ]
+            if args.communicate_cache_hits:
+                cmd += [
+                    "--speculative-async-communicate-cache-hits",
+                ]
+            if args.verbose:
+                cmd += [
+                    "--speculative-async-verbose",
+                ]
+
+    # mode == "AR": no speculative flags, just serve the target model.
+ return cmd, target + + +def wait_for_server(port, timeout=900, interval=5): + url = f"http://localhost:{port}/health" + deadline = time.time() + timeout + while time.time() < deadline: + try: + if requests.get(url, timeout=2).status_code == 200: + return True + except requests.ConnectionError: + pass + time.sleep(interval) + return False + + +def kill_server(proc): + if proc.poll() is None: + os.killpg(os.getpgid(proc.pid), signal.SIGKILL) + proc.wait() + + if __name__ == "__main__": main() diff --git a/bench/small_test.py b/bench/small_test.py index 8131faf8b..4efb136ee 100644 --- a/bench/small_test.py +++ b/bench/small_test.py @@ -9,6 +9,7 @@ llama_1b_path = '/scratch/avner/huggingface/hub/models--meta-llama--Llama-3.2-1B-Instruct/snapshots/9213176726f574b556790deb65791e0c5aa438b6' llama_70b_path = '/scratch/avner/huggingface/hub/models--meta-llama--Llama-3.3-70B-Instruct/snapshots/6f6073b423013f6a7d4d9f39144961bfbfbc386b' eagle_path = '/scratch/avner/huggingface/hub/models--lmsys--SGLang-EAGLE3-Llama-3.3-70B-Instruct-SpecForge/snapshots/63ebaa6585f96b89685adad8fdfa0da53be6a8fd' + phoenix_path = '/scratch/avner/huggingface/hub/models--togethercomputer--phoenix-Llama-3p2-1B-Instruct-tgt-Llama-3p3-70b-instruct-UNTRAINED/snapshots/3af59d71514388e14d8685f2b684f74e3e311717' # eagle_path = '/scratch/avner/huggingface/hub/models--yuhuili--EAGLE3-LLaMA3.3-Instruct-70B' assert os.path.isdir(llama_1b_path) assert os.path.isdir(llama_70b_path) @@ -18,6 +19,7 @@ parser.add_argument("--model", type=str, default=llama_1b_path) parser.add_argument("--draft", type=str, default=llama_1b_path) parser.add_argument("--eagle", action="store_true") + parser.add_argument("--phoenix", action="store_true") parser.add_argument("--k", type=int, default=7) parser.add_argument("--jit-speculate", action="store_true") parser.add_argument("--num-gpus", type=int, default=2) @@ -36,10 +38,18 @@ args.jit_speculate = True args.chat_template = True + if args.phoenix: + args.draft = phoenix_path + args.model = llama_70b_path + args.num_gpus = 5 + args.jit_speculate = True + args.chat_template = True + llm = LLM( model=args.model, draft=args.draft, use_eagle=args.eagle, + use_phoenix=args.phoenix, speculate_k=args.k, speculate=True, draft_async=True, diff --git a/ssd/config.py b/ssd/config.py index c031746cc..558802943 100644 --- a/ssd/config.py +++ b/ssd/config.py @@ -33,13 +33,15 @@ class Config: fan_out_list_miss: list[int] | None = None sampler_x: float | None = None jit_speculate: bool = False + force_jit_speculate: bool = False async_nccl_port: int | None = None async_nccl_host: str = "127.0.0.1" communicate_logits: bool = False communicate_cache_hits: bool = False - # eagle3 + # eagle3 / phoenix use_eagle: bool = False + use_phoenix: bool = False eagle_layers: list[int] | None = None d_model_target: int | None = None tokenizer_path: str | None = None @@ -53,6 +55,10 @@ class Config: def max_blocks(self): return (self.max_model_len + self.kvcache_block_size - 1) // self.kvcache_block_size + @property + def use_eagle_or_phoenix(self): + return self.use_eagle or self.use_phoenix + def __post_init__(self): model = self.model assert os.path.isdir(model) @@ -79,12 +85,17 @@ def __post_init__(self): if self.fan_out_list is None: self.fan_out_list = [self.async_fan_out] * (self.speculate_k + 1) self.MQ_LEN = sum(self.fan_out_list) - if self.fan_out_list_miss is None: - self.fan_out_list_miss = self.fan_out_list + if not self.jit_speculate: + print(f'[Config] Setting fan_out_list_miss to [sum(fan_out_list)] + [0] * 
speculate_k because jit_speculate is False', flush=True)
+            self.fan_out_list_miss = [sum(self.fan_out_list)] + [0] * self.speculate_k
+        elif self.fan_out_list_miss is None:
+            # If jit speculating and no miss list was given, default to the same fan_out_list as for hits.
+            self.fan_out_list_miss = self.fan_out_list
+
         assert sum(self.fan_out_list_miss) == sum(self.fan_out_list), "ERROR in Config: fan_out_list_miss must be the same as fan_out_list"
 
-        if self.use_eagle:
-            if self.eagle_layers is None:
+        if self.use_eagle_or_phoenix:
+            if self.use_eagle and self.eagle_layers is None:
                 L = self.hf_config.num_hidden_layers
                 # self.eagle_layers = [3, L//2, L-3]
                 self.eagle_layers = [2, L//2, L-3] # [2, 16, 29] outputs, ie. [3, L//2+1, L-2] inputs
diff --git a/ssd/engine/draft_runner.py b/ssd/engine/draft_runner.py
index afb1af0e8..5882b5fc7 100644
--- a/ssd/engine/draft_runner.py
+++ b/ssd/engine/draft_runner.py
@@ -33,8 +33,8 @@ def create_draft_config(cls, cfg: Config) -> Config:
             cfg,
             model=cfg.draft,
             gpu_memory_utilization = (0.75 if not cfg.draft_async else 0.8), # REMAINING SPACE if not draft_async
-            tokenizer_path=cfg.model if cfg.use_eagle else None,
-            d_model_target=cfg.hf_config.hidden_size if cfg.use_eagle and cfg.hf_config else None,
+            tokenizer_path=cfg.model if cfg.use_eagle_or_phoenix else None,
+            d_model_target=cfg.hf_config.hidden_size if cfg.use_eagle_or_phoenix and cfg.hf_config else None,
         )
         return draft_cfg
 
@@ -49,15 +49,16 @@ def __init__(self, draft_cfg: Config, rank: int = 0, init_q = None):
         self.target_rank = 0
         self.communicate_logits = self.config.communicate_logits
         self.communicate_cache_hits = self.config.communicate_cache_hits
-
-        if self.config.use_eagle:
-            assert self.config.jit_speculate, \
-                "EAGLE requires jit_speculate=True (cache misses need draft activations)"
 
         if self.is_draft and self.draft_async:
             self._reset_tree_cache_tensors()
             self._init_prealloc_buffers()
             self._draft_step_times = []
+            self._acceptance_lengths = []
+            self._cache_hits = []
+            self._acceptance_rate_log_path = os.environ.get("ACCEPTANCE_RATE_LOG", None)
+            if self._acceptance_rate_log_path:
+                print(f'[{_ts()}] DraftRunner will log acceptance rate to: {self._acceptance_rate_log_path}', flush=True)
 
             print(f'[{_ts()}] DraftRunner set up, starting draft_loop', flush=True)
             self.draft_loop()
@@ -67,8 +68,8 @@ def draft_async_prefill(self):
         if self.config.verbose:
             print(f'[{_ts()}] [draft_async_prefill] DRAFT ASYNC PREFILL STARTING', flush=True)
 
-        prefill_request = PrefillRequest.receive(self.async_pg, self.target_rank, self.device, metadata_buffer=self._prefill_metadata, tokenizer=self.tokenizer)
-        total_new_tokens, batch_size, max_blocks, use_eagle, eagle_act_dim = prefill_request.metadata.tolist()
+        prefill_request = PrefillRequest.receive(self.async_pg, self.target_rank, self.device, metadata_buffer=self._prefill_metadata)
+        total_new_tokens, batch_size, max_blocks, use_eagle_or_phoenix, eagle_phoenix_act_dim = prefill_request.metadata.tolist()
         input_ids = prefill_request.input_ids
         num_tokens = prefill_request.num_tokens
         draft_block_table = prefill_request.draft_block_table
@@ -87,12 +88,16 @@ def draft_async_prefill(self):
 
         prefill_ctxt = self.prepare_prefill_ctxt(num_tokens, draft_block_table)
 
-        if use_eagle:
-            assert eagle_act_dim == 3 * self.config.d_model_target, (
-                f"EAGLE activation dimension {eagle_act_dim} does not match expected dimension 3 * {self.config.d_model_target}"
+        if self.config.use_eagle:
+            assert eagle_phoenix_act_dim == 3 * self.config.d_model_target, (
+                f"EAGLE activation dimension 
{eagle_phoenix_act_dim} does not match expected dimension 3 * {self.config.d_model_target}" + ) + elif self.config.use_phoenix: + assert eagle_phoenix_act_dim == self.config.d_model_target, ( + f"PHOENIX activation dimension {eagle_phoenix_act_dim} does not match expected dimension {self.config.d_model_target}" ) if self.config.verbose: - print(f'[{_ts()}] [draft_async_prefill] METADATA: total_new_tokens={total_new_tokens}, batch_size={batch_size}, max_blocks={max_blocks}, use_eagle={use_eagle}, eagle_act_dim={eagle_act_dim}', flush=True) + print(f'[{_ts()}] [draft_async_prefill] METADATA: total_new_tokens={total_new_tokens}, batch_size={batch_size}, max_blocks={max_blocks}, use_eagle_or_phoenix={use_eagle_or_phoenix}, eagle_phoenix_act_dim={eagle_phoenix_act_dim}', flush=True) # 5) set up context exactly like prepare_prefill() does: @@ -108,10 +113,7 @@ def draft_async_prefill(self): # 6) run the draft model in prefill mode positions = prefill_ctxt["positions"] - if self.config.use_eagle: - self.run_model(input_ids, positions, is_prefill=True, last_only=True, hidden_states=eagle_acts) - else: - self.run_model(input_ids, positions, is_prefill=True, last_only=True, hidden_states=eagle_acts) + self.run_model(input_ids, positions, is_prefill=True, last_only=True, hidden_states=eagle_acts) if self.config.verbose: print(f'[{_ts()}] [draft_async_prefill] DRAFT ASYNC PREFILL DONE', flush=True) @@ -155,11 +157,9 @@ def jit_speculate( draft_block_tables: torch.Tensor, target_recovery_activations: torch.Tensor = None, ): - input_ids = request_keys[:, -1] - pos_offset = -1 if self.config.use_eagle else 0 - positions = num_tokens - 1 + pos_offset # want to write rec token at post N-1 since [0, ..., N-2] filled by prefill - context_lens = num_tokens + pos_offset # N+1 + positions = num_tokens - 1 + context_lens = num_tokens # Calculate slot mapping vectorized block_idx = positions // self.block_size pos_in_block = positions % self.block_size @@ -168,13 +168,16 @@ def jit_speculate( hidden_states = None spec_activations = None - - if self.config.use_eagle: + + if self.config.use_eagle_or_phoenix: assert target_recovery_activations is not None - hidden_states = self.model.fc(target_recovery_activations.to(self.model.fc.weight.dtype)) + if self.config.use_eagle: + hidden_states = self.model.fc(target_recovery_activations.to(self.model.fc.weight.dtype)) + else: + hidden_states = target_recovery_activations spec_activations = torch.empty( input_ids.shape[0], self.config.speculate_k, - self.hf_config.hidden_size, + self.hidden_states_dim, dtype=self.hf_config.torch_dtype, device=self.device) for i in range(self.config.speculate_k): # we're going to glue after this anyways, and by sending the spec request target has verified we have K more slots left in our last page @@ -186,10 +189,13 @@ def jit_speculate( is_jit=True, ) - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: logits, prenorm = self.run_model(input_ids, positions, is_prefill=False, last_only=True, hidden_states=hidden_states) - spec_activations[:, i] = prenorm - hidden_states = prenorm + if self.config.use_eagle: + spec_activations[:, i] = prenorm + hidden_states = prenorm + else: + spec_activations[:, i] = hidden_states else: logits = self.run_model(input_ids, positions, is_prefill=False, last_only=True) @@ -218,33 +224,32 @@ def hit_cache(self, request_keys, B, K, num_tokens, temperatures, draft_block_ta # Init miss slots with valid random logits so token IDs are in-vocab (fixes B>1 crash) out_logits = torch.empty(B, K, V, 
dtype=self.hf_config.torch_dtype, device=self.device).uniform_() out_tokens = out_logits.argmax(dim=-1) - cache_hits = torch.zeros(B, dtype=torch.int64, device=self.device) + cache_hits = torch.zeros(B, dtype=torch.bool, device=self.device) assert request_keys.shape == (B, 3), f"ERROR in hit_cache: request_keys should be (B, 3), got {request_keys.shape}" - - hidden_size = self.hf_config.hidden_size + out_activations = torch.empty( - B, K, hidden_size, + B, K, self.hidden_states_dim, dtype=self.hf_config.torch_dtype, device=self.device - ) if self.config.use_eagle else None - + ) if self.config.use_eagle_or_phoenix else None + # Statistics ttl += int(B) - + if self.config.verbose: print(f"[{_ts()}] [hit_cache] Request keys: {request_keys}", flush=True) for i in range(B): rec_token = request_keys[i, 2].item() rec_text = self.tokenizer.decode([rec_token]) print(f"[{_ts()}] Req {i}: token={rec_token} ('{rec_text}')", flush=True) - + if self.tree_cache_keys.numel() > 0: # Vectorized membership against tensor cache eq = (request_keys.unsqueeze(1) == self.tree_cache_keys.unsqueeze(0)) # [B,T,3] match = torch.all(eq, dim=2) # [B,T] cache_hits = match.any(dim=1) # [B] ttl_hit += int(cache_hits.sum().item()) - + if self.config.verbose: print(f"[{_ts()}] [hit_cache] Cache hits: {cache_hits.sum().item()}/{B}", flush=True) print(f"[{_ts()}] [hit_cache] Cache: {self.tree_cache_keys.shape[0]} entries", flush=True) @@ -263,9 +268,9 @@ def hit_cache(self, request_keys, B, K, num_tokens, temperatures, draft_block_ta rec_text = self.tokenizer.decode([rec_token]) hit_marker = "[HIT]" if i in hit_indices else "" print(f"[{_ts()}] [{i}]: key=({seq_id}, {k_idx}, {rec_token}) -> value=('{rec_text}') {hit_marker}", flush=True) - + # Fill hits - if (cache_hits.any() and not self.config.jit_speculate) or (cache_hits.all() and self.config.jit_speculate): + if not self.config.force_jit_speculate and ((cache_hits.any() and not self.config.jit_speculate) or (cache_hits.all() and self.config.jit_speculate)): # print(f'[hit_cache] got all cache hits, using cached logits and tokens', flush=True) # [B], arbitrary if no match but masked out idx = match.float().argmax(dim=1).to(torch.int64) @@ -274,7 +279,7 @@ def hit_cache(self, request_keys, B, K, num_tokens, temperatures, draft_block_ta out_tokens[sel] = self.tree_cache_tokens[idx[sel]] # logits [T,K+1,V] out_logits[sel] = self.tree_cache_logits[idx[sel]] - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: out_activations[sel] = self.tree_cache_activations[idx[sel]] elif self.config.jit_speculate: # print(f'[hit_cache] found a cache miss, running jit speculate', flush=True) @@ -289,7 +294,7 @@ def hit_cache(self, request_keys, B, K, num_tokens, temperatures, draft_block_ta draft_block_tables, target_recovery_activations ) # write into out_logits, out_tokens - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: out_activations = jit_acts elif self.config.jit_speculate: # Cache is empty (first iteration), must JIT all @@ -304,9 +309,9 @@ def hit_cache(self, request_keys, B, K, num_tokens, temperatures, draft_block_ta draft_block_tables, target_recovery_activations ) - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: out_activations = jit_acts - + rec_toks = request_keys[:, 2] if self.config.verbose: @@ -345,9 +350,18 @@ def _service_spec_request(self): out_tokens, out_logits, glue_decode_input_ids, cache_hits, out_activations = self.hit_cache( cache_keys, B, K, num_tokens, temperatures, draft_block_tables, 
target_recovery_activations) + if self._acceptance_rate_log_path: + # Collect per-step metrics for logging. + # cache_keys[:, 1] is last_spec_step_accepted_len - 1 from the target; + # first request has -1 (forced miss). + for i in range(B): + accept_len = cache_keys[i, 1].item() + 1 + self._acceptance_lengths.append(accept_len) + self._cache_hits.append(int(cache_hits[i].item())) + speculation_response = SpeculationResponse( speculations=out_tokens.reshape(-1).to(torch.int64), - cache_hits=cache_hits.reshape(-1) if self.communicate_cache_hits else None, + cache_hits=cache_hits.reshape(-1).to(torch.int64) if self.communicate_cache_hits else None, logits_q=out_logits[:, :K, :].contiguous() if self.communicate_logits else None, ) if BRIEF_LOG: @@ -422,8 +436,7 @@ def prepare_prefill_ctxt( def prepare_glue_decode_ctxt(self, num_tokens, input_ids, dbt, B): K = self.config.speculate_k - pos_offset = -1 if self.config.use_eagle else 0 - positions_start = (num_tokens - 1 + pos_offset).unsqueeze(-1) + positions_start = (num_tokens - 1).unsqueeze(-1) positions_grid = positions_start + self._arange_kp1 # Calculate block indices and offsets for ALL positions @@ -441,7 +454,7 @@ def prepare_glue_decode_ctxt(self, num_tokens, input_ids, dbt, B): positions_flat = positions_grid.reshape(-1).to(torch.int64) slot_map_flat = slot_map_grid.reshape(-1).to(torch.int32) - context_lens = (num_tokens + pos_offset + K).to(torch.int32) + context_lens = (num_tokens + K).to(torch.int32) seqlen_q = torch.full((B,), K + 1, dtype=torch.int32, device=self.device) cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device=self.device) cu_seqlens_q[1:] = torch.cumsum(seqlen_q, dim=0) @@ -514,9 +527,8 @@ def _construct_tree_decode_args(self, partial_tree_decode_args, rec_flat, dbt): seq_ids = partial_tree_decode_args["seq_ids"] seq_ids_expanded = seq_ids[b_flat] - pos_offset = -1 if self.config.use_eagle else 0 - positions = (partial_tree_decode_args["num_tokens"][b_flat] - 1 + pos_offset) + (K + 1) + fkp1_flat - rope_positions = (partial_tree_decode_args["num_tokens"][b_flat] - 1 + pos_offset) + j_idx_flat + 1 + positions = (partial_tree_decode_args["num_tokens"][b_flat] - 1) + (K + 1) + fkp1_flat + rope_positions = (partial_tree_decode_args["num_tokens"][b_flat] - 1) + j_idx_flat + 1 temperatures = partial_tree_decode_args["temperatures"][b_flat] tree_decode_args = { @@ -541,9 +553,8 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): dbt = partial_tree_decode_args["dbt"] cache_hits = partial_tree_decode_args["cache_hits"] cache_hits_list = cache_hits.tolist() - pos_offset = -1 if self.config.use_eagle else 0 - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: B = partial_tree_decode_args["num_tokens"].shape[0] extend_counts = partial_tree_decode_args.get("extend_counts") if extend_counts is None: @@ -552,8 +563,8 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): extend_token_ids_batch = partial_tree_decode_args.get("extend_token_ids") target_acts = partial_tree_decode_args["target_recovery_activations"] prev_acts = partial_tree_decode_args["previous_activations"] - hidden_size = self.hf_config.hidden_size - fc_dtype = self.model.fc.weight.dtype + hidden_size = self.hidden_states_dim + fc_dtype = self.model.fc.weight.dtype if self.config.use_eagle else self.hf_config.torch_dtype gd_view = glue_decode_input_ids.view(B, K + 1) rec_tok_ids = gd_view[:, 0] @@ -598,7 +609,10 @@ def _build_tree_batch(self, partial_tree_decode_args, 
glue_decode_input_ids): fused_ids[is_rec] = rec_tok_ids[batch_idx[is_rec]] # Single batched fc call - fused_hs[is_target_conditioned] = self.model.fc(tc_acts) + if self.config.use_eagle: + fused_hs[is_target_conditioned] = self.model.fc(tc_acts) + elif self.config.use_phoenix: + fused_hs[is_target_conditioned] = tc_acts # Spec tokens: ids from spec_tok_ids, hs from prev_acts (self-conditioned, no fc) spec_j = local_off[is_spec] - n_ext_per_tok[is_spec] - 1 # 0..K-1 @@ -628,8 +642,8 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): N_pre = _pre_b_flat.shape[0] _pre_metadata_ints = (B, K, self.config.async_fan_out, N_pre) _pre_seq_ids_expanded = partial_tree_decode_args["seq_ids"][_pre_b_flat] - _pre_positions = (partial_tree_decode_args["num_tokens"][_pre_b_flat] - 1 + pos_offset) + (K + 1) + _pre_fkp1_flat - _pre_rope_positions = (partial_tree_decode_args["num_tokens"][_pre_b_flat] - 1 + pos_offset) + _pre_j_idx_flat + 1 + _pre_positions = (partial_tree_decode_args["num_tokens"][_pre_b_flat] - 1) + (K + 1) + _pre_fkp1_flat + _pre_rope_positions = (partial_tree_decode_args["num_tokens"][_pre_b_flat] - 1) + _pre_j_idx_flat + 1 _pre_temperatures = partial_tree_decode_args["temperatures"][_pre_b_flat] # --- Run glue decode forward --- @@ -643,7 +657,7 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): ) glue_prenorm = None - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: fused_hs_flat = glue_decode_ctxt["hidden_states"] glue_decode_logits_flat, glue_prenorm = self.run_model( glue_decode_ctxt["input_ids"], glue_decode_ctxt["positions"], @@ -662,7 +676,7 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): reset_context() # --- Extract K+1 logits/prenorms at rec+spec positions --- - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: # Packed layout: rec at cu_seqlens_q[b] + n_ext[b], spec follows cu_q = glue_decode_ctxt["cu_seqlens_q"] rec_offsets = cu_q[:-1].long() + extend_counts.long() # [B] @@ -679,6 +693,7 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): # --- Build tree hidden states from K+1 prenorms --- tree_hidden_states = None if glue_prenorm is not None: + assert self.config.use_eagle_or_phoenix, "ERROR in _build_tree_batch: use_eagle_or_phoenix must be True when glue_prenorm is not None." 
# Vectorized: for each (b, depth), repeat prenorm by fan_out[depth] # fan_out_t[depth] for hits, fan_out_t_miss[depth] for misses fan_hit = self.config.fan_out_t # [K+1] @@ -690,12 +705,20 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): fan_miss.unsqueeze(0).expand(B, K + 1), ) # [B, K+1] reps_flat = per_batch_fan.reshape(-1) # [B*(K+1)] - prenorms_flat = glue_prenorm_kp1.reshape(B * (K + 1), -1) # [B*(K+1), d] - tree_hidden_states = torch.repeat_interleave(prenorms_flat, reps_flat, dim=0) + + if self.config.use_eagle: + prenorms_flat = glue_prenorm_kp1.reshape(B * (K + 1), -1) # [B*(K+1), d] + tree_hidden_states = torch.repeat_interleave(prenorms_flat, reps_flat, dim=0) + else: + assert self.config.use_phoenix + # Phoenix conditions on target activations, not prenorms + target_acts_expanded = target_acts.unsqueeze(1).expand(B, K + 1, -1) # [B, K+1, target_dim] + acts_flat = target_acts_expanded.reshape(B * (K + 1), -1) # [B*(K+1), target_dim] + tree_hidden_states = torch.repeat_interleave(acts_flat, reps_flat, dim=0) # --- Fork tokens from K+1 logits --- # Need [B, K+1] input_ids for forking (rec + spec tokens) - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: gd_for_fork = gd_view # [B, K+1] already computed above else: gd_for_fork = glue_decode_input_ids.reshape(B, K + 1) @@ -719,6 +742,7 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): "seq_ids_expanded": _pre_seq_ids_expanded, "cache_hits": cache_hits, "cache_hits_list": cache_hits_list, + "target_recovery_activations": partial_tree_decode_args["target_recovery_activations"], } tree_decode_args["hidden_states"] = tree_hidden_states return tree_decode_args @@ -743,7 +767,7 @@ def _compute_step_positions_and_slot_maps(self, initial_positions, initial_rope_ return step_positions, step_rope_positions, step_context_lens, step_slot_maps - def _decode_tree_step(self, depth, current_input_ids, step_rope_positions, step_slot_maps, step_context_lens, dbt, payload, spec_tokens, spec_logits, spec_activations): + def _decode_tree_step(self, depth, current_input_ids, step_rope_positions, step_slot_maps, step_context_lens, dbt, payload, spec_tokens, spec_logits, spec_activations, target_recovery_activations): """Execute a single tree decode step.""" # Use precomputed values for this step set_context( @@ -754,11 +778,15 @@ def _decode_tree_step(self, depth, current_input_ids, step_rope_positions, step_ ) hidden_states = payload.get("hidden_states") - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: logits, prenorm = self.run_model(current_input_ids, step_rope_positions[depth], is_prefill=False, last_only=False, tree_decode_step=depth, cache_hits=payload["cache_hits"], hidden_states=hidden_states) assert spec_activations is not None - spec_activations[:, depth] = prenorm - payload["hidden_states"] = prenorm + if self.config.use_eagle: + spec_activations[:, depth] = prenorm + payload["hidden_states"] = prenorm + else: + spec_activations[:, depth] = target_recovery_activations + payload["hidden_states"] = target_recovery_activations else: logits = self.run_model(current_input_ids, step_rope_positions[depth], is_prefill=False, last_only=False, tree_decode_step=depth, cache_hits=payload["cache_hits"]) @@ -785,9 +813,9 @@ def _decode_tree(self, payload): spec_logits = torch.empty( N, K, V, dtype=self.hf_config.torch_dtype, device=self.device) spec_activations = torch.empty( - N, K, self.hf_config.hidden_size, + N, K, self.hidden_states_dim, 
dtype=self.hf_config.torch_dtype, device=self.device - ) if self.config.use_eagle else None + ) if self.config.use_eagle_or_phoenix else None # Precompute all positions, context_lens, and slot_maps for all K steps # PERFORMANCE: no .clone() needed — these are not modified in-place @@ -795,6 +823,7 @@ def _decode_tree(self, payload): initial_rope_positions = payload["rope_positions"] # [N] current_input_ids = payload["input_ids"] # [N], the forked tokens dbt = payload["block_tables"] # [B, M] - constant across steps + target_recovery_activations = payload["target_recovery_activations"] # Use compiled function for batch-size independent computations _, step_rope_positions, step_context_lens, step_slot_maps = self._compute_step_positions_and_slot_maps( @@ -810,7 +839,7 @@ def _decode_tree(self, payload): _st = time.perf_counter() current_input_ids = self._decode_tree_step( depth, current_input_ids, step_rope_positions, step_slot_maps, - step_context_lens, dbt, payload, spec_tokens, spec_logits, spec_activations + step_context_lens, dbt, payload, spec_tokens, spec_logits, spec_activations, target_recovery_activations, ) if _prof or PROFILE_DRAFT: torch.cuda.synchronize() @@ -957,6 +986,20 @@ def _draft_loop_inner(self): if self._draft_step_times: avg_ms = sum(self._draft_step_times) * 1000 / len(self._draft_step_times) print(f"[{_ts()}] [metrics] Avg draft step time (ms): {avg_ms:.2f}", flush=True) + if self._acceptance_rate_log_path and self._acceptance_lengths: + import json + avg_acc = sum(self._acceptance_lengths) / len(self._acceptance_lengths) + hit_rate = sum(self._cache_hits) / len(self._cache_hits) if self._cache_hits else 0 + print(f"[{_ts()}] [metrics] Avg acceptance length: {avg_acc:.2f} ({len(self._acceptance_lengths)} steps)", flush=True) + print(f"[{_ts()}] [metrics] Cache hit rate: {hit_rate:.2%} ({sum(self._cache_hits)}/{len(self._cache_hits)})", flush=True) + print(f"[{_ts()}] [metrics] All acceptance lengths: {self._acceptance_lengths}", flush=True) + print(f"[{_ts()}] [metrics] All cache hits: {self._cache_hits}", flush=True) + print(f"[{_ts()}] [metrics] Logging acceptance lengths and cache hits to: {self._acceptance_rate_log_path}", flush=True) + with open(self._acceptance_rate_log_path, "w") as f: + json.dump({ + "acceptance_lengths": self._acceptance_lengths, + "cache_hits": self._cache_hits, + }, f) self.exit() break diff --git a/ssd/engine/helpers/cudagraph_helpers.py b/ssd/engine/helpers/cudagraph_helpers.py index b2d41887d..60d322491 100644 --- a/ssd/engine/helpers/cudagraph_helpers.py +++ b/ssd/engine/helpers/cudagraph_helpers.py @@ -314,14 +314,17 @@ def capture_cudagraph(model_runner): is_jit = (model_runner.config.speculate and model_runner.config.draft_async and model_runner.is_draft) # Eagle models need special handling during CUDA graph capture - is_eagle_draft = config.use_eagle and model_runner.is_draft - is_eagle_target = config.use_eagle and not model_runner.is_draft + is_eagle_or_phoenix_draft = config.use_eagle_or_phoenix and model_runner.is_draft + is_eagle_or_phoenix_target = config.use_eagle_or_phoenix and not model_runner.is_draft hidden_states = None - if is_eagle_draft: - # Use hidden_size (d_model_draft) so CG captures the pass-through branch in Eagle3DraftForCausalLM.forward() - # All callers project target acts via fc() BEFORE passing to CG - hidden_states = torch.zeros(max_bs, hf_config.hidden_size, - dtype=hf_config.torch_dtype, device=input_ids.device) + if is_eagle_or_phoenix_draft: + # Note: For Eagle3, all callers project target acts via 
fc() BEFORE passing to CG + hidden_states = torch.zeros( + max_bs, + model_runner.hidden_states_dim, + dtype=hf_config.torch_dtype, + device=input_ids.device, + ) total_graphs = len(graph_bs_list) print(f'[capture_cudagraph] Starting capture of {total_graphs} graphs, bs list: {graph_bs_list[:5]}...{graph_bs_list[-3:]} max_bs={max_bs}', flush=True) @@ -330,10 +333,10 @@ def capture_cudagraph(model_runner): graph = torch.cuda.CUDAGraph() set_context( False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs], is_jit=is_jit) - if is_eagle_draft: + if is_eagle_or_phoenix_draft: outputs[:bs] = model_runner.model( input_ids[:bs], positions[:bs], hidden_states[:bs]) # warmup - elif is_eagle_target: + elif is_eagle_or_phoenix_target: out, _ = model_runner.model( input_ids[:bs], positions[:bs]) # warmup outputs[:bs] = out @@ -341,10 +344,10 @@ def capture_cudagraph(model_runner): outputs[:bs] = model_runner.model( input_ids[:bs], positions[:bs]) # warmup with torch.cuda.graph(graph, graph_pool): - if is_eagle_draft: + if is_eagle_or_phoenix_draft: outputs[:bs] = model_runner.model( input_ids[:bs], positions[:bs], hidden_states[:bs]) # capture - elif is_eagle_target: + elif is_eagle_or_phoenix_target: out, _ = model_runner.model( input_ids[:bs], positions[:bs]) # capture outputs[:bs] = out @@ -379,7 +382,7 @@ def capture_verify_cudagraph(model_runner): max_bs = min(model_runner.config.max_num_seqs, 512) k_plus_1 = model_runner.config.speculate_k + 1 - is_eagle_target = config.use_eagle and not model_runner.is_draft + is_eagle_or_phoenix_target = config.use_eagle_or_phoenix and not model_runner.is_draft # For verify, we need to handle k+1 tokens per sequence, and use cu_seqlens_q and max_seqlen_q input_ids = torch.zeros(max_bs * k_plus_1, dtype=torch.int64) @@ -391,12 +394,14 @@ def capture_verify_cudagraph(model_runner): outputs = torch.zeros(max_bs * k_plus_1, hf_config.hidden_size) cu_seqlens_q = torch.zeros(max_bs + 1, dtype=torch.int32) - # Eagle target: also capture eagle_acts from model forward + # Eagle/Phoenix target: also capture activations from model forward eagle_acts = None - if is_eagle_target: - # eagle_acts has shape [num_tokens, 3 * hidden_size] for 3 layers - eagle_acts = torch.zeros(max_bs * k_plus_1, 3 * hf_config.hidden_size, - dtype=hf_config.torch_dtype) + if is_eagle_or_phoenix_target: + eagle_acts = torch.zeros( + max_bs * k_plus_1, + model_runner.eagle_acts_dim, + dtype=hf_config.torch_dtype, + ) base = [1, 2, 4, 8] dynamic = list(range(16, max_bs+1, 16)) @@ -517,6 +522,7 @@ def run_glue_decode_cudagraph(model_runner, input_ids, positions, last_only, gra outputs = graph_vars["outputs"][:orig_flat] logits = model_runner.model.compute_logits(outputs, last_only) + assert logits.dim() == 2, "ERROR in run_glue_decode_cudagraph: logits must be 2D" if "eagle_hidden_states" in graph_vars: return logits, outputs return logits @@ -541,9 +547,14 @@ def capture_glue_decode_cudagraph(model_runner): outputs = torch.empty(max_flat, hf_config.hidden_size, device=model_runner.device) cu_seqlens_q = torch.zeros(max_bs + 1, dtype=torch.int32, device=model_runner.device) - eagle_hs = None - if config.use_eagle and model_runner.is_draft: - eagle_hs = torch.zeros(max_flat, hf_config.hidden_size, dtype=hf_config.torch_dtype, device=model_runner.device) + eagle_hidden_states = None + if config.use_eagle_or_phoenix and model_runner.is_draft: + eagle_hidden_states = torch.zeros( + max_flat, + model_runner.hidden_states_dim, + dtype=hf_config.torch_dtype, + 
device=model_runner.device, + ) graph_bs_list = [1] for bs in [2, 4, 8] + list(range(16, max_bs + 1, 16)): @@ -577,14 +588,14 @@ def capture_glue_decode_cudagraph(model_runner): block_tables=block_tables[:bs], ) - if eagle_hs is not None: - outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat], eagle_hs[:flat]) + if eagle_hidden_states is not None: + outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat], eagle_hidden_states[:flat]) else: outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat]) with torch.cuda.graph(graph, graph_pool): - if eagle_hs is not None: - outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat], eagle_hs[:flat]) + if eagle_hidden_states is not None: + outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat], eagle_hidden_states[:flat]) else: outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat]) @@ -603,8 +614,8 @@ def capture_glue_decode_cudagraph(model_runner): cu_seqlens_q=cu_seqlens_q, outputs=outputs, ) - if eagle_hs is not None: - graph_vars["eagle_hidden_states"] = eagle_hs + if eagle_hidden_states is not None: + graph_vars["eagle_hidden_states"] = eagle_hidden_states return graph_vars, graph_pool, graphs, graph_bs_list @@ -639,9 +650,13 @@ def capture_fi_tree_decode_cudagraph(model_runner): graph_pool = None fi_hidden_states = None - if config.use_eagle and model_runner.is_draft: - fi_hidden_states = torch.zeros(max_flat_batch_size, hf_config.hidden_size, - dtype=hf_config.torch_dtype, device=model_runner.device) + if config.use_eagle_or_phoenix and model_runner.is_draft: + fi_hidden_states = torch.zeros( + max_flat_batch_size, + model_runner.hidden_states_dim, + dtype=hf_config.torch_dtype, + device=model_runner.device, + ) # Pre-allocate tree_cu_seqlens_q per batch size bucket (constant values, used by FA4) tree_cu_seqlens_q_dict = {} diff --git a/ssd/engine/llm_engine.py b/ssd/engine/llm_engine.py index b14564eec..ca42417c3 100644 --- a/ssd/engine/llm_engine.py +++ b/ssd/engine/llm_engine.py @@ -312,8 +312,8 @@ def create_inference_step(self, config: Config) -> InferenceStep: draft_dtype=config.draft_hf_config.torch_dtype, kvcache_block_size=config.kvcache_block_size, max_model_len=config.max_model_len, - eagle=config.use_eagle, - eagle_act_dim=3 * config.hf_config.hidden_size if config.use_eagle else 0, + eagle=config.use_eagle_or_phoenix, + eagle_act_dim=self.model_runner.eagle_acts_dim if config.use_eagle_or_phoenix else 0, communicate_logits=config.communicate_logits, communicate_cache_hits=config.communicate_cache_hits, async_pg=self.model_runner.async_pg, @@ -342,7 +342,7 @@ def create_inference_step(self, config: Config) -> InferenceStep: scheduler=self.scheduler, speculator=speculator, verifier=verifier, - eagle=config.use_eagle, + eagle=config.use_eagle_or_phoenix, tokenizer=self.tokenizer, async_spec=config.draft_async, ) diff --git a/ssd/engine/model_runner.py b/ssd/engine/model_runner.py index 6c1223d5a..a175863a6 100644 --- a/ssd/engine/model_runner.py +++ b/ssd/engine/model_runner.py @@ -13,6 +13,7 @@ from ssd.models.qwen3 import Qwen3ForCausalLM from ssd.models.llama3 import LlamaForCausalLM from ssd.models.eagle3_draft_llama3 import Eagle3DraftForCausalLM +from ssd.models.phoenix_draft_llama3 import PhoenixLlamaForCausalLM from ssd.layers.sampler import Sampler from ssd.utils.context import set_context, reset_context, get_context from ssd.utils.loader import load_model @@ -74,6 +75,7 @@ def __init__(self, config: Config, rank: int, event: Event | 
list[Event], is_dra
         self.world_size = config.num_gpus if should_use_dist else 1
         self.rank = rank
         self.use_eagle = config.use_eagle
+        self.use_phoenix = config.use_phoenix
 
         if config.draft_async:
             self.draft_rank = config.num_gpus - 1
@@ -119,7 +121,7 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event], is_dra
             assert num_tp_gpus == 1, "ERROR in ModelRunner: draft should have tp_size=1"
             self.tp_pg = None # every rank is given an object from self.tp_pg, even tho draft doesnt participate it gets GROUP_NON_MEMBER object != None back, so we can't assert None here, we
 
-        print(f'[model_runner] about to setup and warmup model and cudagraphs, is use_eagle={self.use_eagle}', flush=True)
+        print(f'[model_runner] about to setup and warmup model and cudagraphs, use_eagle={self.use_eagle}, use_phoenix={self.use_phoenix}', flush=True)
         model_type = self.setup_and_warmup_model_and_cudagraphs(config, self.hf_config, init_q, is_draft)
         if self.verbose:
             print(f'-----CAPTURED {model_type}CUDAGRAPH----', flush=True)
@@ -172,6 +174,9 @@ def setup_and_warmup_model_and_cudagraphs(self, config: Config, hf_config: AutoC
         if config.use_eagle and is_draft:
             print(f'[EAGLE3] Loading Eagle3DraftForCausalLM as model_class', flush=True)
             model_class = Eagle3DraftForCausalLM
+        elif config.use_phoenix and is_draft:
+            print(f'[PHOENIX] Loading PhoenixLlamaForCausalLM as model_class', flush=True)
+            model_class = PhoenixLlamaForCausalLM
         elif hf_config.model_type == 'llama':
             model_class = LlamaForCausalLM
         elif hf_config.model_type == 'qwen3':
@@ -191,11 +196,12 @@ def setup_and_warmup_model_and_cudagraphs(self, config: Config, hf_config: AutoC
             tp_size=self.num_tp_gpus,
         )
 
-        if config.use_eagle:
-            kwargs['use_eagle'] = True
+        if config.use_eagle_or_phoenix:
+            kwargs['use_eagle'] = config.use_eagle
+            kwargs['use_phoenix'] = config.use_phoenix
             kwargs['eagle_layers'] = self.config.eagle_layers
-
-        if model_class == Eagle3DraftForCausalLM:
+
+        if model_class in [Eagle3DraftForCausalLM, PhoenixLlamaForCausalLM]:
             kwargs['d_model_target'] = config.d_model_target
             kwargs['debug_mode'] = config.debug_mode
 
@@ -262,7 +268,7 @@ def setup_and_warmup_model_and_cudagraphs(self, config: Config, hf_config: AutoC
         self.graph_pools["decode"] = decode_graph_pool
         self.graphs["decode"] = decode_graphs
         self.graph_bs_list["decode"] = decode_graph_bs_list
-        if self.config.speculate and not (self.is_draft and self.config.use_eagle): # verify CG: target always, non-EAGLE draft for fan-out; EAGLE draft uses glue_decode CG instead
+        if self.config.speculate and not (self.is_draft and self.config.use_eagle_or_phoenix): # verify CG: target always, non-EAGLE/Phoenix draft for fan-out; EAGLE/Phoenix draft uses glue_decode CG instead
             verify_graph_vars, verify_graph_pool, verify_graphs, verify_graph_bs_list = capture_verify_cudagraph(self)
             self.graph_vars["verify"] = verify_graph_vars
             self.graph_pools["verify"] = verify_graph_pool
@@ -274,7 +280,7 @@ def setup_and_warmup_model_and_cudagraphs(self, config: Config, hf_config: AutoC
         self.graph_pools["fi_tree_decode"] = fi_tree_decode_graph_pool
         self.graphs["fi_tree_decode"] = fi_tree_decode_graphs
         self.graph_bs_list["fi_tree_decode"] = fi_tree_decode_graph_bs_list
-        if self.config.speculate and self.is_draft and self.config.draft_async and self.config.use_eagle:
+        if self.config.speculate and self.is_draft and self.config.draft_async and self.config.use_eagle_or_phoenix:
             glue_gv, glue_pool, glue_graphs, glue_bs_list = capture_glue_decode_cudagraph(self)
             self.graph_vars["glue_decode"] = glue_gv
self.graph_pools["glue_decode"] = glue_pool @@ -440,10 +446,15 @@ def warmup_model(self): seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)] hidden_states = None - if self.config.use_eagle and self.is_draft: + if self.config.use_eagle_or_phoenix and self.is_draft: num_tokens = num_seqs * max_model_len d_model_target = self.config.d_model_target or 4096 - hidden_states = torch.zeros(num_tokens, 3 * d_model_target, dtype=self.hf_config.torch_dtype, device=self.device) + if self.config.use_eagle: + hidden_states = torch.zeros(num_tokens, 3 * d_model_target, dtype=self.hf_config.torch_dtype, device=self.device) + elif self.config.use_phoenix: + hidden_states = torch.zeros(num_tokens, d_model_target, dtype=self.hf_config.torch_dtype, device=self.device) + else: + raise ValueError(f"Unsupported model type: {self.config.use_eagle_or_phoenix}") self.run(seqs, True, hidden_states=hidden_states) torch.cuda.empty_cache() @@ -581,6 +592,21 @@ def eager_tree_decode_plan(self, input_ids, positions, step, cache_hits): device=self.device, ) + @property + def hidden_states_dim(self): + # The dimension of the hidden states that are concatenated with the draft tokens embeddings + # as the input to the Eagle/Phoenix draft model. + assert self.config.use_eagle_or_phoenix and self.is_draft + return self.config.hf_config.hidden_size if self.config.use_eagle else self.config.d_model_target + + @property + def eagle_acts_dim(self): + assert self.config.use_eagle_or_phoenix and not self.is_draft + if self.config.eagle_layers: + return len(self.config.eagle_layers) * self.config.hf_config.hidden_size + else: + return self.config.hf_config.hidden_size + @torch.inference_mode() def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool, last_only: bool = True, tree_decode_step: int = -1, cache_hits: torch.Tensor | None = None, hidden_states: torch.Tensor | None = None): is_tree_decode = self.is_draft and self.config.draft_async and tree_decode_step >= 0 @@ -593,10 +619,10 @@ def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill if is_tree_decode: self.eager_tree_decode_plan(input_ids, positions, tree_decode_step, cache_hits) - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: if self.is_draft: assert hidden_states is not None, "hidden_states required for EAGLE draft" - assert isinstance(self.model, Eagle3DraftForCausalLM) + assert isinstance(self.model, Eagle3DraftForCausalLM) or isinstance(self.model, PhoenixLlamaForCausalLM) prenorm = self.model(input_ids, positions, hidden_states) logits = self.model.compute_logits(prenorm, last_only) return logits, prenorm # return prenorm as conditioning vector for next iteration @@ -646,7 +672,7 @@ def run( # Handle EAGLE returning (logits, conditioning_vector for next iter) conditioning = None - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: logits, conditioning = self.run_model( input_ids, positions, is_prefill, last_only, hidden_states=hidden_states) else: @@ -655,7 +681,7 @@ def run( if _pt: torch.cuda.synchronize() _r2 = time.perf_counter() - print(f"[PROFILE target_run] prepare_decode={(_r1-_r0)*1000:.2f}ms run_model={(_r2-_r1)*1000:.2f}ms eagle={self.config.use_eagle} n_ids={input_ids.shape[0]}", flush=True) + print(f"[PROFILE target_run] prepare_decode={(_r1-_r0)*1000:.2f}ms run_model={(_r2-_r1)*1000:.2f}ms eagle={self.config.use_eagle}, phoenix={self.config.use_phoenix}, n_ids={input_ids.shape[0]}", flush=True) if last_only: token_ids = self.sampler(logits, 
temperatures).tolist() if self.rank == 0 else None @@ -668,5 +694,3 @@ def run( if conditioning is not None: return logits, conditioning return logits - - diff --git a/ssd/engine/speculator_async.py b/ssd/engine/speculator_async.py index a5e3abc87..f61d1212d 100644 --- a/ssd/engine/speculator_async.py +++ b/ssd/engine/speculator_async.py @@ -75,18 +75,17 @@ def _prepare_prefill_request(self, seqs: list[Sequence], verify_result: VerifyRe eagle_acts = verify_result.eagle_acts input_id_list = [seq.token_ids for seq in seqs] - # EAGLE token-conditioning shift: token at position j gets conditioning - # from target act at position j-1. Skip first token per seq and drop - # last eagle_act per seq so they align correctly. + # EAGLE/Phoenix token-conditioning shift: we duplicate the first target activation for each sequence. + # [t0, h0], [t1, h0], [t2, h1], [t3, h2], ... if eagle_acts is not None: sliced = [] offset = 0 for ids in input_id_list: seq_len = len(ids) + sliced.append(eagle_acts[offset:offset + 1]) sliced.append(eagle_acts[offset:offset + seq_len - 1]) offset += seq_len eagle_acts = torch.cat(sliced, dim=0) - input_id_list = [ids[1:] for ids in input_id_list] max_blocks = (self.max_model_len + self.kvcache_block_size - 1) // self.kvcache_block_size input_ids_flat = [] diff --git a/ssd/layers/linear.py b/ssd/layers/linear.py index b25824172..d605caaa5 100755 --- a/ssd/layers/linear.py +++ b/ssd/layers/linear.py @@ -89,6 +89,9 @@ def __init__( def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): param_data = param.data + if param_data.dim() == 1: # bias — no sharding needed + param_data.copy_(loaded_weight) + return shard_size = param_data.size(self.tp_dim) start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size) @@ -115,6 +118,9 @@ def __init__( def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int): param_data = param.data + if param_data.dim() == 1: # bias — no sharding needed + param_data.copy_(loaded_weight) + return shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size shard_size = self.output_sizes[loaded_shard_id] // self.tp_size param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size) @@ -147,6 +153,9 @@ def __init__( def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str): param_data = param.data + if param_data.dim() == 1: # bias — no sharding needed + param_data.copy_(loaded_weight) + return assert loaded_shard_id in ["q", "k", "v"] if loaded_shard_id == "q": shard_size = self.num_heads * self.head_size @@ -187,6 +196,9 @@ def __init__( def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): param_data = param.data + if param_data.dim() == 1: # bias — no sharding needed + param_data.copy_(loaded_weight) + return shard_size = param_data.size(self.tp_dim) start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size) diff --git a/ssd/models/eagle3_draft_llama3.py b/ssd/models/eagle3_draft_llama3.py index a74dd413f..71c19a1b9 100644 --- a/ssd/models/eagle3_draft_llama3.py +++ b/ssd/models/eagle3_draft_llama3.py @@ -219,6 +219,7 @@ def __init__( draft: bool = False, speculate: bool = False, use_eagle: bool = False, + use_phoenix: bool = False, eagle_layers: list[int] | None = None, d_model_target: int = 4096, spec_k: int = 1, @@ -233,6 +234,7 @@ def __init__( assert draft, "ERROR in Eagle3DraftForLlama3: draft must be True" 
assert use_eagle, "ERROR in Eagle3DraftForLlama3: config.use_eagle must be True" assert eagle_layers is not None, "ERROR in Eagle3DraftForLlama3: eagle_layers must be set" + assert not use_phoenix, "ERROR in Eagle3DraftForLlama3: config.use_phoenix must be False" # this will be the draft that does tree decode, just needs a modified fwd pass that takes in hidden states and uses fc and dicts to sample, etc self.config = config diff --git a/ssd/models/llama3.py b/ssd/models/llama3.py index a9934ad5d..091df664e 100755 --- a/ssd/models/llama3.py +++ b/ssd/models/llama3.py @@ -210,6 +210,7 @@ def __init__( async_fan_out: int = 1, draft_async: bool = False, use_eagle: bool = False, + use_phoenix: bool = False, eagle_layers: list[int] | None = None, tp_group: dist.ProcessGroup | None = None, tp_size: int = 1, @@ -221,8 +222,9 @@ def __init__( self.async_fan_out = async_fan_out self.draft_async = draft_async self.use_eagle = use_eagle + self.use_phoenix = use_phoenix self.eagle_layers = eagle_layers - print(f'[LlamaModel] use_eagle={use_eagle}, eagle_layers={eagle_layers}', flush=True) + print(f'[LlamaModel] use_eagle={use_eagle}, use_phoenix={use_phoenix}, eagle_layers={eagle_layers}', flush=True) self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -249,24 +251,33 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, + hidden_states: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - hidden_states = self.embed_tokens(input_ids) # torch.Size([4096, 2560]) always through residual stream + if hidden_states is None: + hidden_states = self.embed_tokens(input_ids) residual = None # Collect activations if use_eagle - collected_acts = [] if self.use_eagle else None + collected_acts = [] if not self.draft and (self.use_eagle or self.use_phoenix) else None for layer_idx, layer in enumerate(self.layers): - if collected_acts is not None and layer_idx in self.eagle_layers: + if collected_acts is not None and self.eagle_layers is not None and layer_idx in self.eagle_layers: current_act = hidden_states if residual is None else hidden_states + residual collected_acts.append(current_act) hidden_states, residual = layer(positions, hidden_states, residual) - hidden_states, _ = self.norm(hidden_states, residual) - - if collected_acts: - eagle_acts = torch.cat(collected_acts, dim=-1) + + if not self.draft and self.use_phoenix: + assert self.eagle_layers is None, "ERROR in LlamaModel: use_phoenix and eagle_layers are not compatible" + collected_acts.append(hidden_states) + + if collected_acts is not None: + if len(collected_acts) > 1: + eagle_acts = torch.cat(collected_acts, dim=-1) + else: + assert len(collected_acts) == 1 + eagle_acts = collected_acts[0] print(f'[LlamaModel] eagle_acts shape={eagle_acts.shape}', flush=True) return hidden_states, eagle_acts else: @@ -284,9 +295,11 @@ class LlamaForCausalLM(nn.Module): def __init__( self, - config: LlamaConfig, draft: bool = False, + config: LlamaConfig, + draft: bool = False, speculate: bool = False, use_eagle: bool = False, + use_phoenix: bool = False, eagle_layers: list[int] | None = None, spec_k: int = 1, async_fan_out: int = 1, @@ -301,6 +314,7 @@ def __init__( self.async_fan_out = async_fan_out self.draft_async = draft_async self.use_eagle = use_eagle + self.use_phoenix = use_phoenix self.eagle_layers = eagle_layers self.tp_group = tp_group self.tp_size = tp_size @@ -310,7 +324,19 @@ def __init__( print(f'Starting LlamaForCausalLM init, draft={draft}, speculate={speculate}, 
spec_k={spec_k}') print(f'[LlamaForCausalLM] use_eagle={use_eagle}, eagle_layers={eagle_layers}', flush=True) - self.model = LlamaModel(config, draft, speculate, spec_k, async_fan_out, draft_async, use_eagle=use_eagle, eagle_layers=eagle_layers, tp_group=tp_group, tp_size=self.tp_size) + self.model = LlamaModel( + config, + draft, + speculate, + spec_k, + async_fan_out, + draft_async, + use_eagle=use_eagle, + use_phoenix=use_phoenix, + eagle_layers=eagle_layers, + tp_group=tp_group, + tp_size=self.tp_size, + ) self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, diff --git a/ssd/models/phoenix_draft_llama3.py b/ssd/models/phoenix_draft_llama3.py new file mode 100644 index 000000000..2b25401cc --- /dev/null +++ b/ssd/models/phoenix_draft_llama3.py @@ -0,0 +1,74 @@ +import torch +import torch.distributed as dist +from transformers import LlamaConfig + +from ssd.layers.linear import RowParallelLinear +from ssd.models.llama3 import LlamaForCausalLM + + +class PhoenixLlamaForCausalLM(LlamaForCausalLM): + def __init__( + self, + config: LlamaConfig, + draft: bool = True, + speculate: bool = True, + use_eagle: bool = False, + use_phoenix: bool = True, + eagle_layers: list[int] | None = None, + d_model_target: int = 4096, + spec_k: int = 1, + async_fan_out: int = 1, + draft_async: bool = False, + tp_group: dist.ProcessGroup | None = None, + tp_size: int = 1, + debug_mode: bool = False, + ) -> None: + assert draft, "ERROR in PhoenixLlamaForCausalLM: draft must be True" + assert use_phoenix, "ERROR in PhoenixLlamaForCausalLM: config.use_phoenix must be True" + assert not use_eagle, "ERROR in PhoenixLlamaForCausalLM: config.use_eagle must be False" + super().__init__( + config, + draft=True, + speculate=True, + use_eagle=False, + use_phoenix=True, + eagle_layers=None, + spec_k=spec_k, + async_fan_out=async_fan_out, + draft_async=draft_async, + tp_group=tp_group, + tp_size=tp_size, + ) + self.d_model_target = d_model_target + self.debug_mode = debug_mode + self.eh_proj = RowParallelLinear( + self.d_model_target + config.hidden_size, + config.hidden_size, + bias=True, + tp_group=tp_group, + tp_size=tp_size, + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + input_embeds = self.model.embed_tokens(input_ids) + hidden_states = torch.cat((input_embeds, hidden_states), dim=-1) + hidden_states = self.eh_proj(hidden_states.to(self.eh_proj.weight.dtype)) + out = self.model(input_ids, positions, hidden_states) + return out + + def compute_logits( + self, + hidden_states: torch.Tensor, + last_only: bool = True, + ) -> torch.Tensor: + logits = self.lm_head(hidden_states, last_only=last_only) + + if logits.dim() == 3: + logits = logits.view(-1, logits.shape[-1]) + + return logits
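
Usage note on the acceptance-rate logging added above: `DraftRunner._draft_loop_inner` dumps a JSON object of the form `{"acceptance_lengths": [...], "cache_hits": [...]}` to the path given via `--acceptance-rate-log`. The sketch below is a minimal offline reader for that file; the script name and log path are hypothetical, only the JSON schema comes from this diff.

# summarize_acceptance_log.py -- minimal sketch; assumes the JSON schema
# written by DraftRunner._draft_loop_inner in this diff.
import json
import sys


def summarize(path: str) -> None:
    with open(path) as f:
        log = json.load(f)
    lengths = log["acceptance_lengths"]  # per-step accepted length (cache_keys[:, 1] + 1)
    hits = log["cache_hits"]             # per-step 0/1 tree-cache hit flags
    if not lengths:
        print("empty log")
        return
    # Recompute the same summary stats the draft loop prints on exit.
    print(f"steps: {len(lengths)}")
    print(f"avg acceptance length: {sum(lengths) / len(lengths):.2f}")
    if hits:
        print(f"cache hit rate: {sum(hits) / len(hits):.2%} ({sum(hits)}/{len(hits)})")


if __name__ == "__main__":
    summarize(sys.argv[1])

For example, after a run launched with
    python run_sglang_bench.py --qwen --mode ASYNC_PHOENIX --acceptance-rate-log /tmp/acc.json
the log can be summarized with
    python summarize_acceptance_log.py /tmp/acc.json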