From b8c1fd75498da2c7be0d78078fdc2c1102ca6f96 Mon Sep 17 00:00:00 2001 From: Avner May Date: Sun, 22 Mar 2026 18:16:40 -0700 Subject: [PATCH 1/6] Support for Phoenix V1 --- bench/small_test.py | 10 ++ ssd/config.py | 19 +++- ssd/engine/draft_runner.py | 131 +++++++++++++----------- ssd/engine/helpers/cudagraph_helpers.py | 73 +++++++------ ssd/engine/llm_engine.py | 6 +- ssd/engine/model_runner.py | 54 +++++++--- ssd/engine/speculator_async.py | 7 +- ssd/layers/linear.py | 12 +++ ssd/models/eagle3_draft_llama3.py | 2 + ssd/models/llama3.py | 46 +++++++-- ssd/models/phoenix_draft_llama3.py | 74 +++++++++++++ 11 files changed, 310 insertions(+), 124 deletions(-) create mode 100644 ssd/models/phoenix_draft_llama3.py diff --git a/bench/small_test.py b/bench/small_test.py index 337665c6a..a59f23406 100644 --- a/bench/small_test.py +++ b/bench/small_test.py @@ -9,6 +9,7 @@ llama_1b_path = '/scratch/avner/huggingface/hub/models--meta-llama--Llama-3.2-1B-Instruct/snapshots/9213176726f574b556790deb65791e0c5aa438b6' llama_70b_path = '/scratch/avner/huggingface/hub/models--meta-llama--Llama-3.3-70B-Instruct/snapshots/6f6073b423013f6a7d4d9f39144961bfbfbc386b' eagle_path = '/scratch/avner/huggingface/hub/models--lmsys--SGLang-EAGLE3-Llama-3.3-70B-Instruct-SpecForge/snapshots/63ebaa6585f96b89685adad8fdfa0da53be6a8fd' + phoenix_path = '/scratch/avner/huggingface/hub/models--togethercomputer--phoenix-Llama-3p2-1B-Instruct-tgt-Llama-3p3-70b-instruct-UNTRAINED/snapshots/3af59d71514388e14d8685f2b684f74e3e311717' # eagle_path = '/scratch/avner/huggingface/hub/models--yuhuili--EAGLE3-LLaMA3.3-Instruct-70B' assert os.path.isdir(llama_1b_path) assert os.path.isdir(llama_70b_path) @@ -18,6 +19,7 @@ parser.add_argument("--model", type=str, default=llama_1b_path) parser.add_argument("--draft", type=str, default=llama_1b_path) parser.add_argument("--eagle", action="store_true") + parser.add_argument("--phoenix", action="store_true") parser.add_argument("--k", type=int, default=6) 
parser.add_argument("--jit-speculate", action="store_true") parser.add_argument("--num-gpus", type=int, default=2) @@ -34,10 +36,18 @@ args.jit_speculate = True args.chat_template = True + if args.phoenix: + args.draft = phoenix_path + args.model = llama_70b_path + args.num_gpus = 5 + args.jit_speculate = True + args.chat_template = True + llm = LLM( model=args.model, draft=args.draft, use_eagle=args.eagle, + use_phoenix=args.phoenix, speculate_k=args.k, speculate=True, draft_async=True, diff --git a/ssd/config.py b/ssd/config.py index c031746cc..5d1c7ea63 100644 --- a/ssd/config.py +++ b/ssd/config.py @@ -38,8 +38,9 @@ class Config: communicate_logits: bool = False communicate_cache_hits: bool = False - # eagle3 + # eagle3 / phoenix use_eagle: bool = False + use_phoenix: bool = False eagle_layers: list[int] | None = None d_model_target: int | None = None tokenizer_path: str | None = None @@ -53,6 +54,10 @@ class Config: def max_blocks(self): return (self.max_model_len + self.kvcache_block_size - 1) // self.kvcache_block_size + @property + def use_eagle_or_phoenix(self): + return self.use_eagle or self.use_phoenix + def __post_init__(self): model = self.model assert os.path.isdir(model) @@ -79,12 +84,16 @@ def __post_init__(self): if self.fan_out_list is None: self.fan_out_list = [self.async_fan_out] * (self.speculate_k + 1) self.MQ_LEN = sum(self.fan_out_list) - if self.fan_out_list_miss is None: - self.fan_out_list_miss = self.fan_out_list + if not self.jit_speculate: + print(f'[Config] Setting fan_out_list_miss to [sum(fan_out_list)] + [0] * speculate_k because jit_speculate is False', flush=True) + self.fan_out_list_miss = [sum(self.fan_out_list)] + [0] * self.speculate_k + elif self.fan_out_list_miss is None: + self.fan_out_list_miss = self.fan_out_list + assert sum(self.fan_out_list_miss) == sum(self.fan_out_list), "ERROR in Config: fan_out_list_miss must be the same as fan_out_list" - if self.use_eagle: - if self.eagle_layers is None: + if 
self.use_eagle_or_phoenix: + if self.use_eagle and self.eagle_layers is None: L = self.hf_config.num_hidden_layers # self.eagle_layers = [3, L//2, L-3] self.eagle_layers = [2, L//2, L-3] # [2, 16, 29] outputs, ie. [3, L//2+1, L-2] inputs diff --git a/ssd/engine/draft_runner.py b/ssd/engine/draft_runner.py index 32a82fb1d..8b37a5928 100644 --- a/ssd/engine/draft_runner.py +++ b/ssd/engine/draft_runner.py @@ -33,8 +33,8 @@ def create_draft_config(cls, cfg: Config) -> Config: cfg, model=cfg.draft, gpu_memory_utilization = (0.75 if not cfg.draft_async else 0.8), # REMAINING SPACE if not draft_async - tokenizer_path=cfg.model if cfg.use_eagle else None, - d_model_target=cfg.hf_config.hidden_size if cfg.use_eagle and cfg.hf_config else None, + tokenizer_path=cfg.model if cfg.use_eagle_or_phoenix else None, + d_model_target=cfg.hf_config.hidden_size if cfg.use_eagle_or_phoenix and cfg.hf_config else None, ) return draft_cfg @@ -49,10 +49,6 @@ def __init__(self, draft_cfg: Config, rank: int = 0, init_q = None): self.target_rank = 0 self.communicate_logits = self.config.communicate_logits self.communicate_cache_hits = self.config.communicate_cache_hits - - if self.config.use_eagle: - assert self.config.jit_speculate, \ - "EAGLE requires jit_speculate=True (cache misses need draft activations)" if self.is_draft and self.draft_async: self._reset_tree_cache_tensors() @@ -68,7 +64,7 @@ def draft_async_prefill(self): print(f'[{_ts()}] [draft_async_prefill] DRAFT ASYNC PREFILL STARTING', flush=True) prefill_request = PrefillRequest.receive(self.async_pg, self.target_rank, self.device, metadata_buffer=self._prefill_metadata) - total_new_tokens, batch_size, max_blocks, use_eagle, eagle_act_dim = prefill_request.metadata.tolist() + total_new_tokens, batch_size, max_blocks, use_eagle_or_phoenix, eagle_phoenix_act_dim = prefill_request.metadata.tolist() input_ids = prefill_request.input_ids num_tokens = prefill_request.num_tokens draft_block_table = prefill_request.draft_block_table 
@@ -87,12 +83,16 @@ def draft_async_prefill(self): prefill_ctxt = self.prepare_prefill_ctxt(num_tokens, draft_block_table) - if use_eagle: - assert eagle_act_dim == 3 * self.config.d_model_target, ( - f"EAGLE activation dimension {eagle_act_dim} does not match expected dimension 3 * {self.config.d_model_target}" + if self.config.use_eagle: + assert eagle_phoenix_act_dim == 3 * self.config.d_model_target, ( + f"EAGLE activation dimension {eagle_phoenix_act_dim} does not match expected dimension 3 * {self.config.d_model_target}" + ) + elif self.config.use_phoenix: + assert eagle_phoenix_act_dim == self.config.d_model_target, ( + f"PHOENIX activation dimension {eagle_phoenix_act_dim} does not match expected dimension {self.config.d_model_target}" ) if self.config.verbose: - print(f'[{_ts()}] [draft_async_prefill] METADATA: total_new_tokens={total_new_tokens}, batch_size={batch_size}, max_blocks={max_blocks}, use_eagle={use_eagle}, eagle_act_dim={eagle_act_dim}', flush=True) + print(f'[{_ts()}] [draft_async_prefill] METADATA: total_new_tokens={total_new_tokens}, batch_size={batch_size}, max_blocks={max_blocks}, use_eagle_or_phoenix={use_eagle_or_phoenix}, eagle_phoenix_act_dim={eagle_phoenix_act_dim}', flush=True) # 5) set up context exactly like prepare_prefill() does: @@ -108,10 +108,7 @@ def draft_async_prefill(self): # 6) run the draft model in prefill mode positions = prefill_ctxt["positions"] - if self.config.use_eagle: - self.run_model(input_ids, positions, is_prefill=True, last_only=True, hidden_states=eagle_acts) - else: - self.run_model(input_ids, positions, is_prefill=True, last_only=True, hidden_states=eagle_acts) + self.run_model(input_ids, positions, is_prefill=True, last_only=True, hidden_states=eagle_acts) if self.config.verbose: print(f'[{_ts()}] [draft_async_prefill] DRAFT ASYNC PREFILL DONE', flush=True) @@ -155,11 +152,9 @@ def jit_speculate( draft_block_tables: torch.Tensor, target_recovery_activations: torch.Tensor = None, ): - input_ids = 
request_keys[:, -1] - pos_offset = -1 if self.config.use_eagle else 0 - positions = num_tokens - 1 + pos_offset # want to write rec token at post N-1 since [0, ..., N-2] filled by prefill - context_lens = num_tokens + pos_offset # N+1 + positions = num_tokens - 1 + context_lens = num_tokens # Calculate slot mapping vectorized block_idx = positions // self.block_size pos_in_block = positions % self.block_size @@ -168,13 +163,16 @@ def jit_speculate( hidden_states = None spec_activations = None - - if self.config.use_eagle: + + if self.config.use_eagle_or_phoenix: assert target_recovery_activations is not None - hidden_states = self.model.fc(target_recovery_activations.to(self.model.fc.weight.dtype)) + if self.config.use_eagle: + hidden_states = self.model.fc(target_recovery_activations.to(self.model.fc.weight.dtype)) + else: + hidden_states = target_recovery_activations spec_activations = torch.empty( input_ids.shape[0], self.config.speculate_k, - self.hf_config.hidden_size, + self.hidden_states_dim, dtype=self.hf_config.torch_dtype, device=self.device) for i in range(self.config.speculate_k): # we're going to glue after this anyways, and by sending the spec request target has verified we have K more slots left in our last page @@ -186,10 +184,13 @@ def jit_speculate( is_jit=True, ) - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: logits, prenorm = self.run_model(input_ids, positions, is_prefill=False, last_only=True, hidden_states=hidden_states) - spec_activations[:, i] = prenorm - hidden_states = prenorm + if self.config.use_eagle: + spec_activations[:, i] = prenorm + hidden_states = prenorm + else: + spec_activations[:, i] = hidden_states else: logits = self.run_model(input_ids, positions, is_prefill=False, last_only=True) @@ -221,12 +222,11 @@ def hit_cache(self, request_keys, B, K, num_tokens, temperatures, draft_block_ta cache_hits = torch.zeros(B, dtype=torch.int64, device=self.device) assert request_keys.shape == (B, 3), f"ERROR in 
hit_cache: request_keys should be (B, 3), got {request_keys.shape}" - - hidden_size = self.hf_config.hidden_size + out_activations = torch.empty( - B, K, hidden_size, + B, K, self.hidden_states_dim, dtype=self.hf_config.torch_dtype, device=self.device - ) if self.config.use_eagle else None + ) if self.config.use_eagle_or_phoenix else None # Statistics ttl += int(B) @@ -274,7 +274,7 @@ def hit_cache(self, request_keys, B, K, num_tokens, temperatures, draft_block_ta out_tokens[sel] = self.tree_cache_tokens[idx[sel]] # logits [T,K+1,V] out_logits[sel] = self.tree_cache_logits[idx[sel]] - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: out_activations[sel] = self.tree_cache_activations[idx[sel]] elif self.config.jit_speculate: # print(f'[hit_cache] found a cache miss, running jit speculate', flush=True) @@ -289,7 +289,7 @@ def hit_cache(self, request_keys, B, K, num_tokens, temperatures, draft_block_ta draft_block_tables, target_recovery_activations ) # write into out_logits, out_tokens - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: out_activations = jit_acts elif self.config.jit_speculate: # Cache is empty (first iteration), must JIT all @@ -304,7 +304,7 @@ def hit_cache(self, request_keys, B, K, num_tokens, temperatures, draft_block_ta draft_block_tables, target_recovery_activations ) - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: out_activations = jit_acts rec_toks = request_keys[:, 2] @@ -415,8 +415,7 @@ def prepare_prefill_ctxt( def prepare_glue_decode_ctxt(self, num_tokens, input_ids, dbt, B): K = self.config.speculate_k - pos_offset = -1 if self.config.use_eagle else 0 - positions_start = (num_tokens - 1 + pos_offset).unsqueeze(-1) + positions_start = (num_tokens - 1).unsqueeze(-1) positions_grid = positions_start + self._arange_kp1 # Calculate block indices and offsets for ALL positions @@ -434,7 +433,7 @@ def prepare_glue_decode_ctxt(self, num_tokens, input_ids, dbt, B): positions_flat = 
positions_grid.reshape(-1).to(torch.int64) slot_map_flat = slot_map_grid.reshape(-1).to(torch.int32) - context_lens = (num_tokens + pos_offset + K).to(torch.int32) + context_lens = (num_tokens + K).to(torch.int32) seqlen_q = torch.full((B,), K + 1, dtype=torch.int32, device=self.device) cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device=self.device) cu_seqlens_q[1:] = torch.cumsum(seqlen_q, dim=0) @@ -507,9 +506,8 @@ def _construct_tree_decode_args(self, partial_tree_decode_args, rec_flat, dbt): seq_ids = partial_tree_decode_args["seq_ids"] seq_ids_expanded = seq_ids[b_flat] - pos_offset = -1 if self.config.use_eagle else 0 - positions = (partial_tree_decode_args["num_tokens"][b_flat] - 1 + pos_offset) + (K + 1) + fkp1_flat - rope_positions = (partial_tree_decode_args["num_tokens"][b_flat] - 1 + pos_offset) + j_idx_flat + 1 + positions = (partial_tree_decode_args["num_tokens"][b_flat] - 1) + (K + 1) + fkp1_flat + rope_positions = (partial_tree_decode_args["num_tokens"][b_flat] - 1) + j_idx_flat + 1 temperatures = partial_tree_decode_args["temperatures"][b_flat] tree_decode_args = { @@ -534,9 +532,8 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): dbt = partial_tree_decode_args["dbt"] cache_hits = partial_tree_decode_args["cache_hits"] cache_hits_list = cache_hits.tolist() - pos_offset = -1 if self.config.use_eagle else 0 - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: B = partial_tree_decode_args["num_tokens"].shape[0] extend_counts = partial_tree_decode_args.get("extend_counts") if extend_counts is None: @@ -545,8 +542,8 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): extend_token_ids_batch = partial_tree_decode_args.get("extend_token_ids") target_acts = partial_tree_decode_args["target_recovery_activations"] prev_acts = partial_tree_decode_args["previous_activations"] - hidden_size = self.hf_config.hidden_size - fc_dtype = self.model.fc.weight.dtype + hidden_size = 
self.hidden_states_dim + fc_dtype = self.model.fc.weight.dtype if self.config.use_eagle else self.hf_config.torch_dtype gd_view = glue_decode_input_ids.view(B, K + 1) rec_tok_ids = gd_view[:, 0] @@ -591,7 +588,10 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): fused_ids[is_rec] = rec_tok_ids[batch_idx[is_rec]] # Single batched fc call - fused_hs[is_target_conditioned] = self.model.fc(tc_acts) + if self.config.use_eagle: + fused_hs[is_target_conditioned] = self.model.fc(tc_acts) + elif self.config.use_phoenix: + fused_hs[is_target_conditioned] = tc_acts # Spec tokens: ids from spec_tok_ids, hs from prev_acts (self-conditioned, no fc) spec_j = local_off[is_spec] - n_ext_per_tok[is_spec] - 1 # 0..K-1 @@ -621,8 +621,8 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): N_pre = _pre_b_flat.shape[0] _pre_metadata_ints = (B, K, self.config.async_fan_out, N_pre) _pre_seq_ids_expanded = partial_tree_decode_args["seq_ids"][_pre_b_flat] - _pre_positions = (partial_tree_decode_args["num_tokens"][_pre_b_flat] - 1 + pos_offset) + (K + 1) + _pre_fkp1_flat - _pre_rope_positions = (partial_tree_decode_args["num_tokens"][_pre_b_flat] - 1 + pos_offset) + _pre_j_idx_flat + 1 + _pre_positions = (partial_tree_decode_args["num_tokens"][_pre_b_flat] - 1) + (K + 1) + _pre_fkp1_flat + _pre_rope_positions = (partial_tree_decode_args["num_tokens"][_pre_b_flat] - 1) + _pre_j_idx_flat + 1 _pre_temperatures = partial_tree_decode_args["temperatures"][_pre_b_flat] # --- Run glue decode forward --- @@ -636,7 +636,7 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): ) glue_prenorm = None - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: fused_hs_flat = glue_decode_ctxt["hidden_states"] glue_decode_logits_flat, glue_prenorm = self.run_model( glue_decode_ctxt["input_ids"], glue_decode_ctxt["positions"], @@ -655,7 +655,7 @@ def _build_tree_batch(self, partial_tree_decode_args, 
glue_decode_input_ids): reset_context() # --- Extract K+1 logits/prenorms at rec+spec positions --- - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: # Packed layout: rec at cu_seqlens_q[b] + n_ext[b], spec follows cu_q = glue_decode_ctxt["cu_seqlens_q"] rec_offsets = cu_q[:-1].long() + extend_counts.long() # [B] @@ -672,6 +672,7 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): # --- Build tree hidden states from K+1 prenorms --- tree_hidden_states = None if glue_prenorm is not None: + assert self.config.use_eagle_or_phoenix, "ERROR in _build_tree_batch: use_eagle_or_phoenix must be True when glue_prenorm is not None." # Vectorized: for each (b, depth), repeat prenorm by fan_out[depth] # fan_out_t[depth] for hits, fan_out_t_miss[depth] for misses fan_hit = self.config.fan_out_t # [K+1] @@ -683,12 +684,20 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): fan_miss.unsqueeze(0).expand(B, K + 1), ) # [B, K+1] reps_flat = per_batch_fan.reshape(-1) # [B*(K+1)] - prenorms_flat = glue_prenorm_kp1.reshape(B * (K + 1), -1) # [B*(K+1), d] - tree_hidden_states = torch.repeat_interleave(prenorms_flat, reps_flat, dim=0) + + if self.config.use_eagle: + prenorms_flat = glue_prenorm_kp1.reshape(B * (K + 1), -1) # [B*(K+1), d] + tree_hidden_states = torch.repeat_interleave(prenorms_flat, reps_flat, dim=0) + else: + assert self.config.use_phoenix + # Phoenix conditions on target activations, not prenorms + target_acts_expanded = target_acts.unsqueeze(1).expand(B, K + 1, -1) # [B, K+1, target_dim] + acts_flat = target_acts_expanded.reshape(B * (K + 1), -1) # [B*(K+1), target_dim] + tree_hidden_states = torch.repeat_interleave(acts_flat, reps_flat, dim=0) # --- Fork tokens from K+1 logits --- # Need [B, K+1] input_ids for forking (rec + spec tokens) - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: gd_for_fork = gd_view # [B, K+1] already computed above else: gd_for_fork = 
glue_decode_input_ids.reshape(B, K + 1) @@ -712,6 +721,7 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): "seq_ids_expanded": _pre_seq_ids_expanded, "cache_hits": cache_hits, "cache_hits_list": cache_hits_list, + "target_recovery_activations": partial_tree_decode_args["target_recovery_activations"], } tree_decode_args["hidden_states"] = tree_hidden_states return tree_decode_args @@ -736,7 +746,7 @@ def _compute_step_positions_and_slot_maps(self, initial_positions, initial_rope_ return step_positions, step_rope_positions, step_context_lens, step_slot_maps - def _decode_tree_step(self, depth, current_input_ids, step_rope_positions, step_slot_maps, step_context_lens, dbt, payload, spec_tokens, spec_logits, spec_activations): + def _decode_tree_step(self, depth, current_input_ids, step_rope_positions, step_slot_maps, step_context_lens, dbt, payload, spec_tokens, spec_logits, spec_activations, target_recovery_activations): """Execute a single tree decode step.""" # Use precomputed values for this step set_context( @@ -747,11 +757,15 @@ def _decode_tree_step(self, depth, current_input_ids, step_rope_positions, step_ ) hidden_states = payload.get("hidden_states") - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: logits, prenorm = self.run_model(current_input_ids, step_rope_positions[depth], is_prefill=False, last_only=False, tree_decode_step=depth, cache_hits=payload["cache_hits"], hidden_states=hidden_states) assert spec_activations is not None - spec_activations[:, depth] = prenorm - payload["hidden_states"] = prenorm + if self.config.use_eagle: + spec_activations[:, depth] = prenorm + payload["hidden_states"] = prenorm + else: + spec_activations[:, depth] = target_recovery_activations + payload["hidden_states"] = target_recovery_activations else: logits = self.run_model(current_input_ids, step_rope_positions[depth], is_prefill=False, last_only=False, tree_decode_step=depth, cache_hits=payload["cache_hits"]) @@ -778,9 
+792,9 @@ def _decode_tree(self, payload): spec_logits = torch.empty( N, K, V, dtype=self.hf_config.torch_dtype, device=self.device) spec_activations = torch.empty( - N, K, self.hf_config.hidden_size, + N, K, self.hidden_states_dim, dtype=self.hf_config.torch_dtype, device=self.device - ) if self.config.use_eagle else None + ) if self.config.use_eagle_or_phoenix else None # Precompute all positions, context_lens, and slot_maps for all K steps # PERFORMANCE: no .clone() needed — these are not modified in-place @@ -788,6 +802,7 @@ def _decode_tree(self, payload): initial_rope_positions = payload["rope_positions"] # [N] current_input_ids = payload["input_ids"] # [N], the forked tokens dbt = payload["block_tables"] # [B, M] - constant across steps + target_recovery_activations = payload["target_recovery_activations"] # Use compiled function for batch-size independent computations _, step_rope_positions, step_context_lens, step_slot_maps = self._compute_step_positions_and_slot_maps( @@ -803,7 +818,7 @@ def _decode_tree(self, payload): _st = time.perf_counter() current_input_ids = self._decode_tree_step( depth, current_input_ids, step_rope_positions, step_slot_maps, - step_context_lens, dbt, payload, spec_tokens, spec_logits, spec_activations + step_context_lens, dbt, payload, spec_tokens, spec_logits, spec_activations, target_recovery_activations, ) if _prof or PROFILE_DRAFT: torch.cuda.synchronize() diff --git a/ssd/engine/helpers/cudagraph_helpers.py b/ssd/engine/helpers/cudagraph_helpers.py index 6c38eeddf..cbcd0104c 100644 --- a/ssd/engine/helpers/cudagraph_helpers.py +++ b/ssd/engine/helpers/cudagraph_helpers.py @@ -482,14 +482,17 @@ def capture_cudagraph(model_runner): is_jit = (model_runner.config.speculate and model_runner.config.draft_async and model_runner.is_draft) # Eagle models need special handling during CUDA graph capture - is_eagle_draft = config.use_eagle and model_runner.is_draft - is_eagle_target = config.use_eagle and not model_runner.is_draft + 
is_eagle_or_phoenix_draft = config.use_eagle_or_phoenix and model_runner.is_draft + is_eagle_or_phoenix_target = config.use_eagle_or_phoenix and not model_runner.is_draft hidden_states = None - if is_eagle_draft: - # Use hidden_size (d_model_draft) so CG captures the pass-through branch in Eagle3DraftForCausalLM.forward() - # All callers project target acts via fc() BEFORE passing to CG - hidden_states = torch.zeros(max_bs, hf_config.hidden_size, - dtype=hf_config.torch_dtype, device=input_ids.device) + if is_eagle_or_phoenix_draft: + # Note: For Eagle3, all callers project target acts via fc() BEFORE passing to CG + hidden_states = torch.zeros( + max_bs, + model_runner.hidden_states_dim, + dtype=hf_config.torch_dtype, + device=input_ids.device, + ) total_graphs = len(graph_bs_list) print(f'[capture_cudagraph] Starting capture of {total_graphs} graphs, bs list: {graph_bs_list[:5]}...{graph_bs_list[-3:]} max_bs={max_bs}', flush=True) @@ -498,10 +501,10 @@ def capture_cudagraph(model_runner): graph = torch.cuda.CUDAGraph() set_context( False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs], is_jit=is_jit) - if is_eagle_draft: + if is_eagle_or_phoenix_draft: outputs[:bs] = model_runner.model( input_ids[:bs], positions[:bs], hidden_states[:bs]) # warmup - elif is_eagle_target: + elif is_eagle_or_phoenix_target: out, _ = model_runner.model( input_ids[:bs], positions[:bs]) # warmup outputs[:bs] = out @@ -509,10 +512,10 @@ def capture_cudagraph(model_runner): outputs[:bs] = model_runner.model( input_ids[:bs], positions[:bs]) # warmup with torch.cuda.graph(graph, graph_pool): - if is_eagle_draft: + if is_eagle_or_phoenix_draft: outputs[:bs] = model_runner.model( input_ids[:bs], positions[:bs], hidden_states[:bs]) # capture - elif is_eagle_target: + elif is_eagle_or_phoenix_target: out, _ = model_runner.model( input_ids[:bs], positions[:bs]) # capture outputs[:bs] = out @@ -547,7 +550,7 @@ def 
capture_verify_cudagraph(model_runner): max_bs = min(model_runner.config.max_num_seqs, 512) k_plus_1 = model_runner.config.speculate_k + 1 - is_eagle_target = config.use_eagle and not model_runner.is_draft + is_eagle_or_phoenix_target = config.use_eagle_or_phoenix and not model_runner.is_draft # For verify, we need to handle k+1 tokens per sequence, and use cu_seqlens_q and max_seqlen_q input_ids = torch.zeros(max_bs * k_plus_1, dtype=torch.int64) @@ -559,12 +562,14 @@ def capture_verify_cudagraph(model_runner): outputs = torch.zeros(max_bs * k_plus_1, hf_config.hidden_size) cu_seqlens_q = torch.zeros(max_bs + 1, dtype=torch.int32) - # Eagle target: also capture eagle_acts from model forward + # Eagle/Phoenix target: also capture activations from model forward eagle_acts = None - if is_eagle_target: - # eagle_acts has shape [num_tokens, 3 * hidden_size] for 3 layers - eagle_acts = torch.zeros(max_bs * k_plus_1, 3 * hf_config.hidden_size, - dtype=hf_config.torch_dtype) + if is_eagle_or_phoenix_target: + eagle_acts = torch.zeros( + max_bs * k_plus_1, + model_runner.eagle_acts_dim, + dtype=hf_config.torch_dtype, + ) base = [1, 2, 4, 8] dynamic = list(range(16, max_bs+1, 16)) @@ -685,6 +690,7 @@ def run_glue_decode_cudagraph(model_runner, input_ids, positions, last_only, gra outputs = graph_vars["outputs"][:orig_flat] logits = model_runner.model.compute_logits(outputs, last_only) + assert logits.dim() == 2, "ERROR in run_glue_decode_cudagraph: logits must be 2D" if "eagle_hidden_states" in graph_vars: return logits, outputs return logits @@ -709,9 +715,14 @@ def capture_glue_decode_cudagraph(model_runner): outputs = torch.empty(max_flat, hf_config.hidden_size, device=model_runner.device) cu_seqlens_q = torch.zeros(max_bs + 1, dtype=torch.int32, device=model_runner.device) - eagle_hs = None - if config.use_eagle and model_runner.is_draft: - eagle_hs = torch.zeros(max_flat, hf_config.hidden_size, dtype=hf_config.torch_dtype, device=model_runner.device) + 
eagle_hidden_states = None + if config.use_eagle_or_phoenix and model_runner.is_draft: + eagle_hidden_states = torch.zeros( + max_flat, + model_runner.hidden_states_dim, + dtype=hf_config.torch_dtype, + device=model_runner.device, + ) graph_bs_list = [1] for bs in [2, 4, 8] + list(range(16, max_bs + 1, 16)): @@ -745,14 +756,14 @@ def capture_glue_decode_cudagraph(model_runner): block_tables=block_tables[:bs], ) - if eagle_hs is not None: - outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat], eagle_hs[:flat]) + if eagle_hidden_states is not None: + outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat], eagle_hidden_states[:flat]) else: outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat]) with torch.cuda.graph(graph, graph_pool): - if eagle_hs is not None: - outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat], eagle_hs[:flat]) + if eagle_hidden_states is not None: + outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat], eagle_hidden_states[:flat]) else: outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat]) @@ -771,8 +782,8 @@ def capture_glue_decode_cudagraph(model_runner): cu_seqlens_q=cu_seqlens_q, outputs=outputs, ) - if eagle_hs is not None: - graph_vars["eagle_hidden_states"] = eagle_hs + if eagle_hidden_states is not None: + graph_vars["eagle_hidden_states"] = eagle_hidden_states return graph_vars, graph_pool, graphs, graph_bs_list @@ -813,9 +824,13 @@ def capture_fi_tree_decode_cudagraph(model_runner): # All callers project target acts via fc() BEFORE passing to CG # MUST be outside the for-loop so all graphs share the same tensor fi_hidden_states = None - if config.use_eagle and model_runner.is_draft: - fi_hidden_states = torch.zeros(max_flat_batch_size, hf_config.hidden_size, - dtype=hf_config.torch_dtype, device=model_runner.device) + if config.use_eagle_or_phoenix and model_runner.is_draft: + fi_hidden_states = torch.zeros( + max_flat_batch_size, + 
model_runner.hidden_states_dim, + dtype=hf_config.torch_dtype, + device=model_runner.device, + ) print(f'[cuda_graph_helpers.capture_fi_tree_decode_cudagraph] About to capture FI cudagraphs for bs={graph_bs_list}', flush=True) diff --git a/ssd/engine/llm_engine.py b/ssd/engine/llm_engine.py index e99c6484e..093298975 100644 --- a/ssd/engine/llm_engine.py +++ b/ssd/engine/llm_engine.py @@ -298,8 +298,8 @@ def create_inference_step(self, config: Config) -> InferenceStep: draft_dtype=config.draft_hf_config.torch_dtype, kvcache_block_size=config.kvcache_block_size, max_model_len=config.max_model_len, - eagle=config.use_eagle, - eagle_act_dim=3 * config.hf_config.hidden_size if config.use_eagle else 0, + eagle=config.use_eagle_or_phoenix, + eagle_act_dim=self.model_runner.eagle_acts_dim if config.use_eagle_or_phoenix else 0, communicate_logits=config.communicate_logits, communicate_cache_hits=config.communicate_cache_hits, async_pg=self.model_runner.async_pg, @@ -328,7 +328,7 @@ def create_inference_step(self, config: Config) -> InferenceStep: scheduler=self.scheduler, speculator=speculator, verifier=verifier, - eagle=config.use_eagle, + eagle=config.use_eagle_or_phoenix, tokenizer=self.tokenizer, async_spec=config.draft_async, ) diff --git a/ssd/engine/model_runner.py b/ssd/engine/model_runner.py index b94552219..8747eb576 100644 --- a/ssd/engine/model_runner.py +++ b/ssd/engine/model_runner.py @@ -14,6 +14,7 @@ from ssd.models.qwen3 import Qwen3ForCausalLM from ssd.models.llama3 import LlamaForCausalLM from ssd.models.eagle3_draft_llama3 import Eagle3DraftForCausalLM +from ssd.models.phoenix_draft_llama3 import PhoenixLlamaForCausalLM from ssd.layers.sampler import Sampler from ssd.utils.context import set_context, reset_context, get_context from ssd.utils.loader import load_model @@ -76,6 +77,7 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event], is_dra self.world_size = config.num_gpus if should_use_dist else 1 self.rank = rank self.use_eagle 
= config.use_eagle + self.use_phoenix = config.use_phoenix if config.draft_async: self.draft_rank = config.num_gpus - 1 @@ -125,7 +127,7 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event], is_dra assert num_tp_gpus == 1, "ERROR in ModelRunner: draft should have tp_size=1" self.tp_pg = None # every rank is given an object from self.tp_pg, even tho draft doesnt participate it gets GROUP_NON_MEMBER object != None back, so we can't assert None here, we - print(f'[model_runner] about to setup and warmup model and cudagraphs, is use_eagle={self.use_eagle}', flush=True) + print(f'[model_runner] about to setup and warmup model and cudagraphs, is use_eagle={self.use_eagle}, is use_phoenix={self.use_phoenix}', flush=True) model_type = self.setup_and_warmup_model_and_cudagraphs(config, self.hf_config, init_q, is_draft) if self.verbose: print(f'-----CAPTURED {model_type}CUDAGRAPH----', flush=True) @@ -228,6 +230,9 @@ def setup_and_warmup_model_and_cudagraphs(self, config: Config, hf_config: AutoC if config.use_eagle and is_draft: print(f'[EAGLE3] Loading Eagle3DraftForCausalLM as model_class', flush=True) model_class = Eagle3DraftForCausalLM + elif config.use_phoenix and is_draft: + print(f'[PHOENIX] Loading PhoenixDraftForCausalLM as model_class', flush=True) + model_class = PhoenixLlamaForCausalLM elif hf_config.model_type == 'llama': model_class = LlamaForCausalLM elif hf_config.model_type == 'qwen3': @@ -247,11 +252,12 @@ def setup_and_warmup_model_and_cudagraphs(self, config: Config, hf_config: AutoC tp_size=self.num_tp_gpus, ) - if config.use_eagle: - kwargs['use_eagle'] = True + if config.use_eagle_or_phoenix: + kwargs['use_eagle'] = config.use_eagle + kwargs['use_phoenix'] = config.use_phoenix kwargs['eagle_layers'] = self.config.eagle_layers - - if model_class == Eagle3DraftForCausalLM: + + if model_class in [Eagle3DraftForCausalLM, PhoenixLlamaForCausalLM]: kwargs['d_model_target'] = config.d_model_target kwargs['debug_mode'] = 
config.debug_mode @@ -307,7 +313,7 @@ def setup_and_warmup_model_and_cudagraphs(self, config: Config, hf_config: AutoC self.graph_pools["decode"] = decode_graph_pool self.graphs["decode"] = decode_graphs self.graph_bs_list["decode"] = decode_graph_bs_list - if self.config.speculate and not (self.is_draft and self.config.use_eagle): # verify CG: target always, non-EAGLE draft for fan-out; EAGLE draft uses glue_decode CG instead + if self.config.speculate and not (self.is_draft and self.config.use_eagle_or_phoenix): # verify CG: target always, non-EAGLE draft for fan-out; EAGLE draft uses glue_decode CG instead verify_graph_vars, verify_graph_pool, verify_graphs, verify_graph_bs_list = capture_verify_cudagraph(self) self.graph_vars["verify"] = verify_graph_vars self.graph_pools["verify"] = verify_graph_pool @@ -319,7 +325,7 @@ def setup_and_warmup_model_and_cudagraphs(self, config: Config, hf_config: AutoC self.graph_pools["fi_tree_decode"] = fi_tree_decode_graph_pool self.graphs["fi_tree_decode"] = fi_tree_decode_graphs self.graph_bs_list["fi_tree_decode"] = fi_tree_decode_graph_bs_list - if self.config.speculate and self.is_draft and self.config.draft_async and self.config.use_eagle: + if self.config.speculate and self.is_draft and self.config.draft_async and self.config.use_eagle_or_phoenix: glue_gv, glue_pool, glue_graphs, glue_bs_list = capture_glue_decode_cudagraph(self) self.graph_vars["glue_decode"] = glue_gv self.graph_pools["glue_decode"] = glue_pool @@ -484,10 +490,15 @@ def warmup_model(self): seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)] hidden_states = None - if self.config.use_eagle and self.is_draft: + if self.config.use_eagle_or_phoenix and self.is_draft: num_tokens = num_seqs * max_model_len d_model_target = self.config.d_model_target or 4096 - hidden_states = torch.zeros(num_tokens, 3 * d_model_target, dtype=self.hf_config.torch_dtype, device=self.device) + if self.config.use_eagle: + hidden_states = torch.zeros(num_tokens, 3 * 
d_model_target, dtype=self.hf_config.torch_dtype, device=self.device) + elif self.config.use_phoenix: + hidden_states = torch.zeros(num_tokens, d_model_target, dtype=self.hf_config.torch_dtype, device=self.device) + else: + raise ValueError(f"Unsupported model type: {self.config.use_eagle_or_phoenix}") self.run(seqs, True, hidden_states=hidden_states) torch.cuda.empty_cache() @@ -643,6 +654,21 @@ def eager_tree_decode_plan(self, input_ids, positions, step, cache_hits): kv_data_type=self.hf_config.torch_dtype, ) + @property + def hidden_states_dim(self): + # The dimension of the hidden states that are concatenated with the draft tokens embeddings + # as the input to the Eagle/Phoenix draft model. + assert self.config.use_eagle_or_phoenix and self.is_draft + return self.config.hf_config.hidden_size if self.config.use_eagle else self.config.d_model_target + + @property + def eagle_acts_dim(self): + assert self.config.use_eagle_or_phoenix and not self.is_draft + if self.config.eagle_layers: + return len(self.config.eagle_layers) * self.config.hf_config.hidden_size + else: + return self.config.hf_config.hidden_size + @torch.inference_mode() def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool, last_only: bool = True, tree_decode_step: int = -1, cache_hits: torch.Tensor | None = None, hidden_states: torch.Tensor | None = None): is_tree_decode = self.is_draft and self.config.draft_async and tree_decode_step >= 0 @@ -655,10 +681,10 @@ def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill if is_tree_decode: self.eager_tree_decode_plan(input_ids, positions, tree_decode_step, cache_hits) - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: if self.is_draft: assert hidden_states is not None, "hidden_states required for EAGLE draft" - assert isinstance(self.model, Eagle3DraftForCausalLM) + assert isinstance(self.model, Eagle3DraftForCausalLM) or isinstance(self.model, PhoenixLlamaForCausalLM) prenorm = 
self.model(input_ids, positions, hidden_states) logits = self.model.compute_logits(prenorm, last_only) return logits, prenorm # return prenorm as conditioning vector for next iteration @@ -708,7 +734,7 @@ def run( # Handle EAGLE returning (logits, conditioning_vector for next iter) conditioning = None - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: logits, conditioning = self.run_model( input_ids, positions, is_prefill, last_only, hidden_states=hidden_states) else: @@ -717,7 +743,7 @@ def run( if _pt: torch.cuda.synchronize() _r2 = time.perf_counter() - print(f"[PROFILE target_run] prepare_decode={(_r1-_r0)*1000:.2f}ms run_model={(_r2-_r1)*1000:.2f}ms eagle={self.config.use_eagle} n_ids={input_ids.shape[0]}", flush=True) + print(f"[PROFILE target_run] prepare_decode={(_r1-_r0)*1000:.2f}ms run_model={(_r2-_r1)*1000:.2f}ms eagle={self.config.use_eagle}, phoenix={self.config.use_phoenix}, n_ids={input_ids.shape[0]}", flush=True) if last_only: token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None @@ -730,5 +756,3 @@ def run( if conditioning is not None: return logits, conditioning return logits - - diff --git a/ssd/engine/speculator_async.py b/ssd/engine/speculator_async.py index a5e3abc87..f61d1212d 100644 --- a/ssd/engine/speculator_async.py +++ b/ssd/engine/speculator_async.py @@ -75,18 +75,17 @@ def _prepare_prefill_request(self, seqs: list[Sequence], verify_result: VerifyRe eagle_acts = verify_result.eagle_acts input_id_list = [seq.token_ids for seq in seqs] - # EAGLE token-conditioning shift: token at position j gets conditioning - # from target act at position j-1. Skip first token per seq and drop - # last eagle_act per seq so they align correctly. + # EAGLE/Phoenix token-conditioning shift: we duplicate the first target activation for each sequence. + # [t0, h0], [t1, h0], [t2, h1], [t3, h2], ... 
if eagle_acts is not None: sliced = [] offset = 0 for ids in input_id_list: seq_len = len(ids) + sliced.append(eagle_acts[offset:offset + 1]) sliced.append(eagle_acts[offset:offset + seq_len - 1]) offset += seq_len eagle_acts = torch.cat(sliced, dim=0) - input_id_list = [ids[1:] for ids in input_id_list] max_blocks = (self.max_model_len + self.kvcache_block_size - 1) // self.kvcache_block_size input_ids_flat = [] diff --git a/ssd/layers/linear.py b/ssd/layers/linear.py index b25824172..d605caaa5 100755 --- a/ssd/layers/linear.py +++ b/ssd/layers/linear.py @@ -89,6 +89,9 @@ def __init__( def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): param_data = param.data + if param_data.dim() == 1: # bias — no sharding needed + param_data.copy_(loaded_weight) + return shard_size = param_data.size(self.tp_dim) start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size) @@ -115,6 +118,9 @@ def __init__( def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int): param_data = param.data + if param_data.dim() == 1: # bias — no sharding needed + param_data.copy_(loaded_weight) + return shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size shard_size = self.output_sizes[loaded_shard_id] // self.tp_size param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size) @@ -147,6 +153,9 @@ def __init__( def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str): param_data = param.data + if param_data.dim() == 1: # bias — no sharding needed + param_data.copy_(loaded_weight) + return assert loaded_shard_id in ["q", "k", "v"] if loaded_shard_id == "q": shard_size = self.num_heads * self.head_size @@ -187,6 +196,9 @@ def __init__( def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): param_data = param.data + if param_data.dim() == 1: # bias — no sharding needed + param_data.copy_(loaded_weight) + 
return shard_size = param_data.size(self.tp_dim) start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size) diff --git a/ssd/models/eagle3_draft_llama3.py b/ssd/models/eagle3_draft_llama3.py index a74dd413f..71c19a1b9 100644 --- a/ssd/models/eagle3_draft_llama3.py +++ b/ssd/models/eagle3_draft_llama3.py @@ -219,6 +219,7 @@ def __init__( draft: bool = False, speculate: bool = False, use_eagle: bool = False, + use_phoenix: bool = False, eagle_layers: list[int] | None = None, d_model_target: int = 4096, spec_k: int = 1, @@ -233,6 +234,7 @@ def __init__( assert draft, "ERROR in Eagle3DraftForLlama3: draft must be True" assert use_eagle, "ERROR in Eagle3DraftForLlama3: config.use_eagle must be True" assert eagle_layers is not None, "ERROR in Eagle3DraftForLlama3: eagle_layers must be set" + assert not use_phoenix, "ERROR in Eagle3DraftForLlama3: config.use_phoenix must be False" # this will be the draft that does tree decode, just needs a modified fwd pass that takes in hidden states and uses fc and dicts to sample, etc self.config = config diff --git a/ssd/models/llama3.py b/ssd/models/llama3.py index a9934ad5d..091df664e 100755 --- a/ssd/models/llama3.py +++ b/ssd/models/llama3.py @@ -210,6 +210,7 @@ def __init__( async_fan_out: int = 1, draft_async: bool = False, use_eagle: bool = False, + use_phoenix: bool = False, eagle_layers: list[int] | None = None, tp_group: dist.ProcessGroup | None = None, tp_size: int = 1, @@ -221,8 +222,9 @@ def __init__( self.async_fan_out = async_fan_out self.draft_async = draft_async self.use_eagle = use_eagle + self.use_phoenix = use_phoenix self.eagle_layers = eagle_layers - print(f'[LlamaModel] use_eagle={use_eagle}, eagle_layers={eagle_layers}', flush=True) + print(f'[LlamaModel] use_eagle={use_eagle}, use_phoenix={use_phoenix}, eagle_layers={eagle_layers}', flush=True) self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -249,24 +251,33 @@ def 
forward( self, input_ids: torch.Tensor, positions: torch.Tensor, + hidden_states: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - hidden_states = self.embed_tokens(input_ids) # torch.Size([4096, 2560]) always through residual stream + if hidden_states is None: + hidden_states = self.embed_tokens(input_ids) residual = None # Collect activations if use_eagle - collected_acts = [] if self.use_eagle else None + collected_acts = [] if not self.draft and (self.use_eagle or self.use_phoenix) else None for layer_idx, layer in enumerate(self.layers): - if collected_acts is not None and layer_idx in self.eagle_layers: + if collected_acts is not None and self.eagle_layers is not None and layer_idx in self.eagle_layers: current_act = hidden_states if residual is None else hidden_states + residual collected_acts.append(current_act) hidden_states, residual = layer(positions, hidden_states, residual) - hidden_states, _ = self.norm(hidden_states, residual) - - if collected_acts: - eagle_acts = torch.cat(collected_acts, dim=-1) + + if not self.draft and self.use_phoenix: + assert self.eagle_layers is None, "ERROR in LlamaModel: use_phoenix and eagle_layers are not compatible" + collected_acts.append(hidden_states) + + if collected_acts is not None: + if len(collected_acts) > 1: + eagle_acts = torch.cat(collected_acts, dim=-1) + else: + assert len(collected_acts) == 1 + eagle_acts = collected_acts[0] print(f'[LlamaModel] eagle_acts shape={eagle_acts.shape}', flush=True) return hidden_states, eagle_acts else: @@ -284,9 +295,11 @@ class LlamaForCausalLM(nn.Module): def __init__( self, - config: LlamaConfig, draft: bool = False, + config: LlamaConfig, + draft: bool = False, speculate: bool = False, use_eagle: bool = False, + use_phoenix: bool = False, eagle_layers: list[int] | None = None, spec_k: int = 1, async_fan_out: int = 1, @@ -301,6 +314,7 @@ def __init__( self.async_fan_out = async_fan_out self.draft_async = draft_async self.use_eagle = 
use_eagle + self.use_phoenix = use_phoenix self.eagle_layers = eagle_layers self.tp_group = tp_group self.tp_size = tp_size @@ -310,7 +324,19 @@ def __init__( print(f'Starting LlamaForCausalLM init, draft={draft}, speculate={speculate}, spec_k={spec_k}') print(f'[LlamaForCausalLM] use_eagle={use_eagle}, eagle_layers={eagle_layers}', flush=True) - self.model = LlamaModel(config, draft, speculate, spec_k, async_fan_out, draft_async, use_eagle=use_eagle, eagle_layers=eagle_layers, tp_group=tp_group, tp_size=self.tp_size) + self.model = LlamaModel( + config, + draft, + speculate, + spec_k, + async_fan_out, + draft_async, + use_eagle=use_eagle, + use_phoenix=use_phoenix, + eagle_layers=eagle_layers, + tp_group=tp_group, + tp_size=self.tp_size, + ) self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, diff --git a/ssd/models/phoenix_draft_llama3.py b/ssd/models/phoenix_draft_llama3.py new file mode 100644 index 000000000..2b25401cc --- /dev/null +++ b/ssd/models/phoenix_draft_llama3.py @@ -0,0 +1,74 @@ +import torch +import torch.distributed as dist +from transformers import LlamaConfig + +from ssd.layers.linear import RowParallelLinear +from ssd.models.llama3 import LlamaForCausalLM + + +class PhoenixLlamaForCausalLM(LlamaForCausalLM): + def __init__( + self, + config: LlamaConfig, + draft: bool = True, + speculate: bool = True, + use_eagle: bool = False, + use_phoenix: bool = True, + eagle_layers: list[int] | None = None, + d_model_target: int = 4096, + spec_k: int = 1, + async_fan_out: int = 1, + draft_async: bool = False, + tp_group: dist.ProcessGroup | None = None, + tp_size: int = 1, + debug_mode: bool = False, + ) -> None: + assert draft, "ERROR in PhoenixLlamaForCausalLM: draft must be True" + assert use_phoenix, "ERROR in PhoenixLlamaForCausalLM: config.use_phoenix must be True" + assert not use_eagle, "ERROR in PhoenixLlamaForCausalLM: config.use_eagle must be False" + super().__init__( + config, + draft=True, + speculate=True, + 
use_eagle=False, + use_phoenix=True, + eagle_layers=None, + spec_k=spec_k, + async_fan_out=async_fan_out, + draft_async=draft_async, + tp_group=tp_group, + tp_size=tp_size, + ) + self.d_model_target = d_model_target + self.debug_mode = debug_mode + self.eh_proj = RowParallelLinear( + self.d_model_target + config.hidden_size, + config.hidden_size, + bias=True, + tp_group=tp_group, + tp_size=tp_size, + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + input_embeds = self.model.embed_tokens(input_ids) + hidden_states = torch.cat((input_embeds, hidden_states), dim=-1) + hidden_states = self.eh_proj(hidden_states.to(self.eh_proj.weight.dtype)) + out = self.model(input_ids, positions, hidden_states) + return out + + def compute_logits( + self, + hidden_states: torch.Tensor, + last_only: bool = True, + ) -> torch.Tensor: + logits = self.lm_head(hidden_states, last_only=last_only) + + if logits.dim() == 3: + logits = logits.view(-1, logits.shape[-1]) + + return logits From 7053b808b3f6fcdb2eb8b2e8a4f68b8ebffc0c4d Mon Sep 17 00:00:00 2001 From: Avner May Date: Sat, 28 Mar 2026 06:54:38 -0700 Subject: [PATCH 2/6] FA4 initial implementation by CC --- ssd/engine/helpers/cudagraph_helpers.py | 280 ++++-------------------- ssd/engine/helpers/runner_helpers.py | 2 + ssd/engine/model_runner.py | 119 ++-------- ssd/layers/attention.py | 36 +-- ssd/layers/tree_mask.py | 100 +++++++++ ssd/utils/context.py | 6 +- tests/test_fa4_tree_decode.py | 201 +++++++++++++++++ tests/test_score_mod_basic.py | 155 +++++++++++++ tests/test_tree_mask_correctness.py | 164 ++++++++++++++ 9 files changed, 711 insertions(+), 352 deletions(-) create mode 100644 ssd/layers/tree_mask.py create mode 100644 tests/test_fa4_tree_decode.py create mode 100644 tests/test_score_mod_basic.py create mode 100644 tests/test_tree_mask_correctness.py diff --git a/ssd/engine/helpers/cudagraph_helpers.py 
b/ssd/engine/helpers/cudagraph_helpers.py index cbcd0104c..0fc1529ec 100644 --- a/ssd/engine/helpers/cudagraph_helpers.py +++ b/ssd/engine/helpers/cudagraph_helpers.py @@ -1,7 +1,6 @@ import os import math import torch -import numpy as np from ssd.utils.context import set_context, get_context, reset_context from time import perf_counter @@ -122,9 +121,6 @@ def run_decode_cudagraph(model_runner, input_ids, positions, last_only, graph_va return logits -cache = {} - -_plan_event = None # Lazy-init CUDA event for plan() sync PROFILE = os.environ.get("SSD_PROFILE", "0") == "1" PROFILE_DRAFT = os.environ.get("SSD_PROFILE_DRAFT", "0") == "1" _draft_events = [] # [(step, label, start_event, end_event), ...] @@ -149,30 +145,23 @@ def flush_draft_profile(): @torch.inference_mode() def run_fi_tree_decode_cudagraph(model_runner, input_ids, positions, last_only, graph_vars, step, cache_hits, hidden_states=None): - # bs != len(input_ids, positions) now in multi-query seting, also need step-dependent mask context = get_context() - assert context.cu_seqlens_q is None, "ERROR in run_fi_tree_decode_cudagraph: cu_seqlens_q should be set to None so we don't take FA path" - K, F = model_runner.config.speculate_k, model_runner.config.async_fan_out - # MQ_LEN = F * (K+1) MQ_LEN = sum(model_runner.config.fan_out_list) orig_flat = input_ids.size(0) assert orig_flat % MQ_LEN == 0, f"ERROR in run_fi_tree_decode_cudagraph: flat_batch_size should be divisible by MQ_LEN, got {orig_flat} and {MQ_LEN}" orig_B = orig_flat // MQ_LEN - # Pick CUDA graph and wrapper bucket + # Pick CUDA graph bucket wrapper_bs = next( x for x in model_runner.graph_bs_list["fi_tree_decode"] if x >= orig_B) graph = model_runner.graphs["fi_tree_decode"][wrapper_bs] - wrapper = model_runner.prefill_wrappers[wrapper_bs] # Prepare padded inputs/context if needed if wrapper_bs > orig_B: - # print(f'PADDING--') pad_B = wrapper_bs - orig_B pad_flat = pad_B * MQ_LEN - # Pad queries (ids/rope positions) pad_ids = torch.zeros( 
pad_flat, dtype=input_ids.dtype, device=input_ids.device) pad_pos = torch.zeros( @@ -180,13 +169,11 @@ def run_fi_tree_decode_cudagraph(model_runner, input_ids, positions, last_only, input_ids = torch.cat([input_ids, pad_ids], dim=0) positions = torch.cat([positions, pad_pos], dim=0) - # Pad slot_mapping with -1 to skip KV writes for padded queries slot_map = torch.cat( [context.slot_mapping, torch.full((pad_flat,), -1, dtype=context.slot_mapping.dtype, device=context.slot_mapping.device)] ) - # Pad block_tables/context_lens by repeating the last real row bt = context.block_tables cl = context.context_lens pad_bt = bt[orig_B - 1:orig_B].expand(pad_B, -1).contiguous() @@ -194,19 +181,23 @@ def run_fi_tree_decode_cudagraph(model_runner, input_ids, positions, last_only, bt = torch.cat([bt, pad_bt], dim=0) cl = torch.cat([cl, pad_cl], dim=0) - # Set padded context for this replay set_context(is_prefill=False, slot_mapping=slot_map, - context_lens=cl, block_tables=bt) + context_lens=cl, block_tables=bt, + tree_cu_seqlens_q=graph_vars["tree_cu_seqlens_q"][wrapper_bs], + tree_mask_bias=graph_vars["tree_mask_bias"]) block_tables = bt context_lens = cl - flat_batch_size = input_ids.size(0) # == wrapper_bs * MQ_LEN + flat_batch_size = input_ids.size(0) B = wrapper_bs else: block_tables = context.block_tables context_lens = context.context_lens flat_batch_size = orig_flat B = orig_B + # Set tree decode metadata on context for FA4 + context.tree_cu_seqlens_q = graph_vars["tree_cu_seqlens_q"][wrapper_bs] + context.tree_mask_bias = graph_vars["tree_mask_bias"] if PROFILE: torch.cuda.synchronize() @@ -214,185 +205,26 @@ def run_fi_tree_decode_cudagraph(model_runner, input_ids, positions, last_only, end_time = torch.cuda.Event(enable_timing=True) start_time.record() - # in the case where we pad, we'll need cache_hits.shape[0] to match the padded batch size - if cache_hits.shape[0] < B: - cache_hits = torch.cat([cache_hits, torch.zeros(B - cache_hits.shape[0], 
device=cache_hits.device)]) - - # PERFORMANCE: Step 0 -- precompute KV page metadata on CPU for all K steps. - # CPU tensors let plan() skip its internal .to("cpu") GPU->CPU syncs. - # For B<=8, CPU slicing also avoids GPU boolean indexing. - if step == 0: - cache["cu_seqlens_q_cpu"] = torch.arange(B + 1, dtype=torch.int32) * MQ_LEN - context_lens_list = context_lens.tolist() - cache["block_tables"] = block_tables - block_size = model_runner.block_size - cache["precomputed_kv"] = [] - cache["plan_cpu_args"] = [] - - if B <= 8: - # PERFORMANCE: CPU-only kv_indices via slicing (no GPU boolean indexing) - for s in range(K): - step_cls = [int(cl) + s * MQ_LEN for cl in context_lens_list] - step_counts = [(cl + block_size - 1) // block_size for cl in step_cls] - if B == 1: - kv_indices_s = block_tables[0, :step_counts[0]] - else: - kv_indices_s = torch.cat([block_tables[b, :step_counts[b]] for b in range(B)]) - cache["precomputed_kv"].append(kv_indices_s) - kv_indptr_cpu = torch.zeros(B + 1, dtype=torch.int32) - kv_indptr_cpu[1:] = torch.tensor(step_counts, dtype=torch.int32).cumsum(0) - kv_lpl_cpu = torch.tensor( - [cl % block_size if cl % block_size != 0 else block_size for cl in step_cls], - dtype=torch.int32) - cache["plan_cpu_args"].append((kv_indptr_cpu, kv_lpl_cpu)) - else: - # Large batch: GPU boolean indexing for kv_indices, CPU tensors for plan args - bt_upcast = torch.arange(block_tables.size(1), device=block_tables.device)[None, :] - step_offsets = torch.arange(K + 2, device=context_lens.device) * MQ_LEN - all_step_cls = context_lens.unsqueeze(1) + step_offsets.unsqueeze(0) - all_counts = (all_step_cls + block_size - 1) // block_size - all_masks = bt_upcast.unsqueeze(1) < all_counts.unsqueeze(2) - for s in range(K): - cache["precomputed_kv"].append(block_tables[all_masks[:, s, :]]) - step_cls = [int(cl) + s * MQ_LEN for cl in context_lens_list] - step_counts = [(cl + block_size - 1) // block_size for cl in step_cls] - kv_indptr_cpu = torch.zeros(B + 1, 
dtype=torch.int32) - kv_indptr_cpu[1:] = torch.tensor(step_counts, dtype=torch.int32).cumsum(0) - kv_lpl_cpu = torch.tensor( - [cl % block_size if cl % block_size != 0 else block_size for cl in step_cls], - dtype=torch.int32) - cache["plan_cpu_args"].append((kv_indptr_cpu, kv_lpl_cpu)) - - # CPU mask precompute: build all K packed masks using numpy at step 0. - # Eliminates per-step get_custom_mask (GPU) + segment_packbits + GPU->CPU syncs. - cache_hits_list = cache_hits[:B].tolist() - - if "glue_hit_np" not in cache: - _fol = model_runner.config.fan_out_list - _fol_miss = model_runner.config.fan_out_list_miss - _tril = np.tril(np.ones((K + 1, K + 1), dtype=np.uint8)) - cache["glue_hit_np"] = np.repeat(_tril, _fol, axis=0) - cache["glue_miss_np"] = np.repeat(_tril, _fol_miss, axis=0) - - _glue_hit = cache["glue_hit_np"] - _glue_miss = cache["glue_miss_np"] - _rows_np = np.arange(MQ_LEN) - - cache["cpu_packed_masks"] = [] - cache["cpu_packed_indptrs"] = [] - - for s in range(K): - ttl_added_s = (s + 1) * MQ_LEN + (K + 1) - packed_segs = [] - seg_packed_sizes = [] - - for b in range(B): - cols_b = int(context_lens_list[b]) + s * MQ_LEN - prefix_len_b = cols_b - ttl_added_s - - mask_b = np.zeros((MQ_LEN, cols_b), dtype=np.uint8) - mask_b[:, :prefix_len_b] = 1 - glue = _glue_hit if int(cache_hits_list[b]) == 1 else _glue_miss - mask_b[:, prefix_len_b:prefix_len_b + K + 1] = glue - diag_start = prefix_len_b + K + 1 - for blk in range(s + 1): - mask_b[_rows_np, diag_start + blk * MQ_LEN + _rows_np] = 1 - - packed = np.packbits(mask_b.ravel(), bitorder='little') - packed_segs.append(packed) - seg_packed_sizes.append(len(packed)) - - full_packed = np.concatenate(packed_segs) if B > 1 else packed_segs[0] - indptr = np.zeros(B + 1, dtype=np.int32) - indptr[1:] = np.cumsum(seg_packed_sizes) - - cache["cpu_packed_masks"].append( - torch.from_numpy(full_packed.copy()).to(model_runner.device, non_blocking=True)) - cache["cpu_packed_indptrs"].append( - 
torch.from_numpy(indptr.copy()).to(model_runner.device, non_blocking=True)) - - # Pre-transfer KV metadata to GPU (eliminates per-step pageable H2D transfers) - cache["qo_indptr_gpu"] = cache["cu_seqlens_q_cpu"].to(model_runner.device, non_blocking=True) - cache["kv_indptr_gpu"] = [] - cache["kv_lpl_gpu"] = [] - cache["kv_lens_gpu"] = [] - for s in range(K): - ki, kl = cache["plan_cpu_args"][s] - cache["kv_indptr_gpu"].append(ki.to(model_runner.device, non_blocking=True)) - cache["kv_lpl_gpu"].append(kl.to(model_runner.device, non_blocking=True)) - kv_lens = ((ki[1:] - ki[:-1] - 1) * model_runner.block_size + kl).to(torch.int32) - cache["kv_lens_gpu"].append(kv_lens.to(model_runner.device, non_blocking=True)) - - if PROFILE: - end_time.record() - torch.cuda.synchronize() - precompute_time = start_time.elapsed_time(end_time) - start_time.record() - - # Use precomputed CPU-packed masks (built at step 0) - if PROFILE_DRAFT: - _ev_mask0 = torch.cuda.Event(enable_timing=True); _ev_mask0.record() - - kv_indices = cache["precomputed_kv"][step] - kv_indptr_cpu, kv_lpl_cpu = cache["plan_cpu_args"][step] - qo_indptr_cpu = cache["cu_seqlens_q_cpu"] - - packed_mask = cache["cpu_packed_masks"][step] - packed_indptr = cache["cpu_packed_indptrs"][step] - wrapper._custom_mask_buf[:len(packed_mask)].copy_(packed_mask, non_blocking=True) - wrapper._mask_indptr_buf.copy_(packed_indptr, non_blocking=True) - - # GPU-to-GPU copies from pre-transferred tensors (no pageable H2D) - wrapper._qo_indptr_buf.copy_(cache["qo_indptr_gpu"], non_blocking=True) - wrapper._paged_kv_indptr_buf.copy_(cache["kv_indptr_gpu"][step], non_blocking=True) - wrapper._paged_kv_last_page_len_buf.copy_(cache["kv_lpl_gpu"][step], non_blocking=True) - wrapper._paged_kv_indices_buf[:len(kv_indices)].copy_(kv_indices, non_blocking=True) - - total_num_rows = int(qo_indptr_cpu[-1].item()) - wrapper._kv_lens_buffer[:len(kv_indptr_cpu) - 1].copy_(cache["kv_lens_gpu"][step], non_blocking=True) - - # Event-based sync: 
only wait for this stream's copies, not all CUDA streams. - global _plan_event - if _plan_event is None: - _plan_event = torch.cuda.Event() - _plan_event.record() - _plan_event.synchronize() - - if PROFILE_DRAFT: - _ev_plan0 = torch.cuda.Event(enable_timing=True); _ev_plan0.record() - - plan_args = [ - wrapper._float_workspace_buffer, wrapper._int_workspace_buffer, - wrapper._pin_memory_int_workspace_buffer, - qo_indptr_cpu, kv_indptr_cpu, cache["kv_lens_gpu"][step], - wrapper._max_total_num_rows or total_num_rows, - B, model_runner.hf_config.num_attention_heads, - model_runner.hf_config.num_key_value_heads, - model_runner.block_size, wrapper.is_cuda_graph_enabled, - model_runner.hf_config.head_dim, model_runner.hf_config.head_dim, - False, -1, - ] - if wrapper._backend == "fa2": - plan_args.extend([-1, False, 0]) # fixed_split_size, disable_split_kv, num_colocated_ctas - wrapper._plan_info = wrapper._cached_module.plan(*plan_args) - - if PROFILE_DRAFT: - _ev_plan1 = torch.cuda.Event(enable_timing=True); _ev_plan1.record() - - if PROFILE: - end_time.record() - torch.cuda.synchronize() - plan_time = start_time.elapsed_time(end_time) - start_time.record() + # Build tree mask bias for this step and copy into pre-allocated buffer + from ssd.layers.tree_mask import build_tree_mask_bias + K = model_runner.config.speculate_k + mask_bias = build_tree_mask_bias( + context_lens, step=step, K=K, MQ_LEN=MQ_LEN, + fan_out_list=model_runner.config.fan_out_list, + fan_out_list_miss=model_runner.config.fan_out_list_miss, + cache_hits=cache_hits, + max_kv_stride=model_runner.config.max_model_len, + device=model_runner.device, + ) + graph_vars["tree_mask_bias"][:len(mask_bias)] = mask_bias - # Copy inputs/context into graph buffers for padded size + # Copy inputs/context into graph buffers graph_vars["input_ids"][:flat_batch_size] = input_ids graph_vars["positions"][:flat_batch_size] = positions graph_vars["slot_mapping"][:flat_batch_size] = get_context().slot_mapping 
graph_vars["context_lens"][:B] = context_lens if hidden_states is not None and "hidden_states" in graph_vars: if hidden_states.shape[0] < flat_batch_size: - # Pad hidden_states to match padded batch pad_n = flat_batch_size - hidden_states.shape[0] hidden_states = torch.cat([hidden_states, torch.zeros(pad_n, hidden_states.shape[1], dtype=hidden_states.dtype, device=hidden_states.device)]) graph_vars["hidden_states"][:flat_batch_size] = hidden_states @@ -412,8 +244,6 @@ def run_fi_tree_decode_cudagraph(model_runner, input_ids, positions, last_only, if PROFILE_DRAFT: _ev_replay1 = torch.cuda.Event(enable_timing=True); _ev_replay1.record() - _draft_events.append((step, "mask+buf", _ev_mask0, _ev_plan0)) - _draft_events.append((step, "plan", _ev_plan0, _ev_plan1)) _draft_events.append((step, "replay", _ev_replay0, _ev_replay1)) if PROFILE: @@ -421,14 +251,12 @@ def run_fi_tree_decode_cudagraph(model_runner, input_ids, positions, last_only, torch.cuda.synchronize() replay_time = start_time.elapsed_time(end_time) - # Extract logits from graph_vars instead of computing them separately logits_all = graph_vars["logits"][:flat_batch_size] if PROFILE: - print(f"[cuda_graph_helpers.run_fi_tree_decode_cudagraph] step {step}: precompute={precompute_time:.3f}ms, plan={plan_time:.3f}ms, buffer={buffer_prep_time:.3f}ms, replay={replay_time:.3f}ms", flush=True) + print(f"[cuda_graph_helpers.run_fi_tree_decode_cudagraph] step {step}: buffer={buffer_prep_time:.3f}ms, replay={replay_time:.3f}ms", flush=True) logits_out = logits_all[:orig_flat] - # EAGLE draft: also return prenorm (outputs) for self-conditioning if "hidden_states" in graph_vars: prenorm = graph_vars["outputs"][:orig_flat] return logits_out, prenorm @@ -793,8 +621,6 @@ def capture_fi_tree_decode_cudagraph(model_runner): config = model_runner.config hf_config = config.hf_config max_bs = min(model_runner.config.max_num_seqs, 512) - K, F = model_runner.config.speculate_k, model_runner.config.async_fan_out - # MQ_LEN = F * 
(K+1) MQ_LEN = sum(model_runner.config.fan_out_list) max_flat_batch_size = max_bs * MQ_LEN @@ -803,12 +629,11 @@ def capture_fi_tree_decode_cudagraph(model_runner): input_ids = torch.zeros(max_flat_batch_size, dtype=torch.int64, device=model_runner.device) positions = torch.zeros(max_flat_batch_size, dtype=torch.int64, device=model_runner.device) slot_mapping = torch.zeros(max_flat_batch_size, dtype=torch.int32, device=model_runner.device) - context_lens = torch.full((max_bs,), config.max_model_len, dtype=torch.int32, device=model_runner.device) # make sure these are consistent with our dummy example + context_lens = torch.full((max_bs,), config.max_model_len, dtype=torch.int32, device=model_runner.device) block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32, device=model_runner.device) outputs = torch.empty(max_flat_batch_size, hf_config.hidden_size, device=model_runner.device) logits = torch.empty(max_flat_batch_size, hf_config.vocab_size, device=model_runner.device) - # Create graph_bs_list to match what will be used in cudagraph_helpers.py graph_bs_list = [1] for bs in [2, 4, 8] + list(range(16, max_bs + 1, 16)): if bs <= max_bs: @@ -820,9 +645,6 @@ def capture_fi_tree_decode_cudagraph(model_runner): graphs = {} graph_pool = None - # Eagle draft needs hidden_states for forward (d_model_draft, NOT 3*d_model_target) - # All callers project target acts via fc() BEFORE passing to CG - # MUST be outside the for-loop so all graphs share the same tensor fi_hidden_states = None if config.use_eagle_or_phoenix and model_runner.is_draft: fi_hidden_states = torch.zeros( @@ -832,52 +654,30 @@ def capture_fi_tree_decode_cudagraph(model_runner): device=model_runner.device, ) - print(f'[cuda_graph_helpers.capture_fi_tree_decode_cudagraph] About to capture FI cudagraphs for bs={graph_bs_list}', flush=True) + # Pre-allocate tree_cu_seqlens_q per batch size bucket (constant values, used by FA4) + tree_cu_seqlens_q_dict = {} + for bs in graph_bs_list: + 
tree_cu_seqlens_q_dict[bs] = torch.arange( + bs + 1, dtype=torch.int32, device=model_runner.device) * MQ_LEN - for bs in reversed(graph_bs_list): - graph = torch.cuda.CUDAGraph() + # Pre-allocate tree mask bias at max size (shared across all batch sizes, updated before replay) + tree_mask_bias = torch.zeros( + max_flat_batch_size * config.max_model_len, + dtype=torch.float32, device=model_runner.device) - # Build a self-consistent fake plan for capture: - # - q_len = MQ_LEN for each request - # - k_len = max_model_len for each request (use maximum context length) + print(f'[cuda_graph_helpers.capture_fi_tree_decode_cudagraph] About to capture FA4 tree decode cudagraphs for bs={graph_bs_list}', flush=True) - cu_seqlens_q = torch.arange( - bs + 1, dtype=torch.int32, device=model_runner.device) * MQ_LEN - # Use max_num_blocks pages per request for maximum context length - kv_indptr = torch.arange( - bs + 1, dtype=torch.int32, device=model_runner.device) * max_num_blocks - kv_indices = torch.zeros(int( - kv_indptr[-1].item()), dtype=torch.int32, device=model_runner.device) # page ids (dummy) - # Last page length for max model len context - last_page_len = config.max_model_len % model_runner.block_size - if last_page_len == 0: - last_page_len = model_runner.block_size - kv_last_page_len = torch.full( - (bs,), last_page_len, dtype=torch.int32, device=model_runner.device) - custom_mask = torch.ones(bs * MQ_LEN * config.max_model_len, - dtype=torch.bool, device=model_runner.device) - - # Set the fi_tensors buffers with our fake data - model_runner.prefill_wrappers[bs].plan( - cu_seqlens_q, - kv_indptr, - kv_indices, - kv_last_page_len, - hf_config.num_attention_heads, - hf_config.num_key_value_heads, - hf_config.head_dim, - model_runner.block_size, - custom_mask=custom_mask, - q_data_type=hf_config.torch_dtype, - kv_data_type=hf_config.torch_dtype, - ) + for bs in reversed(graph_bs_list): + graph = torch.cuda.CUDAGraph() - # Set minimal context needed for run + # Set 
context with FA4 metadata set_context( is_prefill=False, slot_mapping=slot_mapping[:bs * MQ_LEN], context_lens=context_lens[:bs], - block_tables=block_tables[:bs] + block_tables=block_tables[:bs], + tree_cu_seqlens_q=tree_cu_seqlens_q_dict[bs], + tree_mask_bias=tree_mask_bias, ) # Warmup run @@ -913,6 +713,8 @@ def capture_fi_tree_decode_cudagraph(model_runner): context_lens=context_lens, outputs=outputs, logits=logits, + tree_cu_seqlens_q=tree_cu_seqlens_q_dict, + tree_mask_bias=tree_mask_bias, ) if fi_hidden_states is not None: graph_vars["hidden_states"] = fi_hidden_states diff --git a/ssd/engine/helpers/runner_helpers.py b/ssd/engine/helpers/runner_helpers.py index 46ed89489..ed567b36b 100644 --- a/ssd/engine/helpers/runner_helpers.py +++ b/ssd/engine/helpers/runner_helpers.py @@ -27,6 +27,8 @@ def _dump_ts(): print(f"[{_ts()}] BANANA: Dumping tensors to {DUMP_TENSORS_DIR}") os.makedirs(DUMP_TENSORS_DIR, exist_ok=True) DUMP_TENSORS = True +else: + DUMP_TENSORS = False def list_to_str(lst: list[float] | list[list[float]], num_decimals: int = 4) -> str: assert len(lst) > 0 diff --git a/ssd/engine/model_runner.py b/ssd/engine/model_runner.py index 8747eb576..b46b90325 100644 --- a/ssd/engine/model_runner.py +++ b/ssd/engine/model_runner.py @@ -8,7 +8,6 @@ from multiprocessing.shared_memory import SharedMemory from transformers import AutoTokenizer, AutoConfig import os -import flashinfer from ssd.config import Config from ssd.engine.sequence import Sequence from ssd.models.qwen3 import Qwen3ForCausalLM @@ -36,7 +35,6 @@ capture_fi_tree_decode_cudagraph, capture_glue_decode_cudagraph, ) -from ssd.engine.helpers.mask_helpers import get_custom_mask NCCL_LOG = os.environ.get("SSD_NCCL_LOG", "0") == "1" @@ -100,11 +98,7 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event], is_dra self.device = torch.device(f'cuda:{self.rank}') self._cmd = torch.empty(1, dtype=torch.int64, device=self.device) - - # cudagraph logic for FlashInfer kernels, need diff 
wrapper for each batch size we make a graph for - if is_draft and config.draft_async: - self._init_flashinfer_wrappers() - + if self.verbose: print(f'INSIDE MODEL RUNNER INIT, DRAFT={is_draft}', flush=True) self.tp_pg = None @@ -169,56 +163,6 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event], is_dra if self.verbose: print(f'-----{model_type}MODEL RUNNER INITIALIZED----', flush=True) - def _init_flashinfer_wrappers(self): - """Initialize FlashInfer wrappers for draft async mode.""" - self.workspace_buffer = torch.zeros( - 768 * 1024 * 1024, dtype=torch.uint8, device=f"cuda:{self.rank}") - - if self.config.enforce_eager: - self.only_prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD") - else: - max_bs = min(self.config.max_num_seqs, 512) - max_num_blocks = (self.config.max_model_len + self.block_size - 1) // self.block_size - - # FlashInfer kernel tensors - # pages_for_max_len = (self.config.max_model_len + self.block_size - 1) // self.block_size - last_page_len_max_len = self.config.max_model_len % self.block_size - last_page_len_max_len = self.block_size if last_page_len_max_len == 0 else last_page_len_max_len - MQ_LEN = self.config.async_fan_out * (self.config.speculate_k + 1) - - cu_seqlens_q = torch.empty(max_bs + 1, dtype=torch.int32, device=self.device) - kv_indptr = torch.empty(max_bs + 1, dtype=torch.int32, device=self.device) - kv_indices = torch.empty(max_bs * max_num_blocks, dtype=torch.int32, device=self.device) - kv_last_page_len = torch.empty(max_bs, dtype=torch.int32, device=self.device) - custom_mask_buf = torch.empty(max_bs * MQ_LEN * self.config.max_model_len, dtype=torch.uint8, device=self.device) - mask_indptr_buf = torch.empty(max_bs + 1, dtype=torch.int32, device=self.device) - - # Create graph_bs_list to match what will be used in cudagraph_helpers.py - graph_bs_list = [1] - for bs in [2, 4, 8] + list(range(16, max_bs + 1, 16)): - if bs <= max_bs: - graph_bs_list.append(bs) - 
if max_bs not in graph_bs_list: - graph_bs_list.append(max_bs) - graph_bs_list.sort() - - # Create a dict of wrappers, one for each bs we will touch in cudagraph_helpers.py - self.prefill_wrappers = {} - print(f'[model_runner about to wrapper.init()] graph_bs_list={graph_bs_list}', flush=True) - for bs in graph_bs_list: - self.prefill_wrappers[bs] = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - self.workspace_buffer, "NHD", - use_cuda_graph=True, - qo_indptr_buf=cu_seqlens_q[:bs + 1], - paged_kv_indptr_buf=kv_indptr[:bs + 1], - paged_kv_indices_buf=kv_indices[:bs * max_num_blocks], - paged_kv_last_page_len_buf=kv_last_page_len[:bs], - custom_mask_buf=custom_mask_buf[:bs * MQ_LEN * self.config.max_model_len], - mask_indptr_buf=mask_indptr_buf[:bs + 1], - ) - print(f'wrapper backend is {self.prefill_wrappers[bs]._backend}', flush=True) - - def setup_and_warmup_model_and_cudagraphs(self, config: Config, hf_config: AutoConfig, init_q=None, is_draft=False): # cudagraphs self.graph_vars = {} @@ -554,15 +498,20 @@ def allocate_kv_cache(self): ) print(f"allocate_kv_cache(): kv_cache shape = {self.kv_cache.shape}", flush=True) + # Create tree_score_mod once (shared across all attention layers) + tree_score_mod = None + if self.is_draft and self.draft_async: + from ssd.layers.tree_mask import create_tree_score_mod + tree_score_mod = create_tree_score_mod(config.max_model_len) + layer_id = 0 for module in self.model.modules(): if hasattr(module, "k_cache") and hasattr(module, "v_cache"): module.k_cache = self.kv_cache[0, layer_id] module.v_cache = self.kv_cache[1, layer_id] - if self.is_draft and self.draft_async and not self.enforce_eager: - module.prefill_wrappers = self.prefill_wrappers - elif self.is_draft and self.draft_async and self.enforce_eager: - module.only_prefill_wrapper = self.only_prefill_wrapper # this will make it not None so it can be used on fwd + if self.is_draft and self.draft_async: + module.max_seqlen_k = config.max_model_len + module.tree_score_mod 
= tree_score_mod layer_id += 1 @@ -613,45 +562,21 @@ def prepare_sample(self, seqs: list[Sequence]): return temperatures def eager_tree_decode_plan(self, input_ids, positions, step, cache_hits): - """Plan FlashInfer for tree decode in eager mode""" + """Set up context metadata for FA4 tree decode in eager mode.""" assert self.is_draft and self.config.draft_async, "ERROR in eager_tree_decode_plan: not a draft async model" + from ssd.layers.tree_mask import build_tree_mask_bias context = get_context() - - K, F = self.config.speculate_k, self.config.async_fan_out - # MQ_LEN = F * (K+1) + K = self.config.speculate_k MQ_LEN = self.config.MQ_LEN - flat_batch_size = input_ids.size(0) - B = flat_batch_size // MQ_LEN # [N] tokens = B * sum(fan_out_list) - - # Convert block_tables to FlashInfer format - block_tables = context.block_tables # [B, M] - context_lens = context.context_lens # [B] - - counts = (context_lens + self.block_size - 1) // self.block_size # [B] - kv_indptr = torch.cat([torch.tensor([0], device=block_tables.device), - counts.cumsum(dim=0)]).to(torch.int32) - mask = torch.arange(block_tables.size(1), device=block_tables.device)[None, :] < counts[:, None] - kv_indices = block_tables[mask] # flattened page ids - - # Last-page actual token count per request - kv_last_page_len = (context_lens % self.block_size) - kv_last_page_len[kv_last_page_len == 0] = self.block_size - kv_last_page_len = kv_last_page_len.to(torch.int32) - cu_seqlens_q = torch.arange(B + 1, device=self.device, dtype=torch.int32) * MQ_LEN # assumes same MQ_LEN across batch dimension - custom_mask = get_custom_mask(self.config, context_lens, step, K, F, B, device=self.device, cache_hits=cache_hits) - - self.only_prefill_wrapper.plan( - cu_seqlens_q, - kv_indptr, - kv_indices, - kv_last_page_len, - self.hf_config.num_attention_heads, - self.hf_config.num_key_value_heads, - self.hf_config.head_dim, - self.block_size, - custom_mask=custom_mask, - q_data_type=self.hf_config.torch_dtype, - 
kv_data_type=self.hf_config.torch_dtype, + B = input_ids.size(0) // MQ_LEN + context.tree_cu_seqlens_q = torch.arange(B + 1, device=self.device, dtype=torch.int32) * MQ_LEN + context.tree_mask_bias = build_tree_mask_bias( + context.context_lens, step=step, K=K, MQ_LEN=MQ_LEN, + fan_out_list=self.config.fan_out_list, + fan_out_list_miss=self.config.fan_out_list_miss, + cache_hits=cache_hits, + max_kv_stride=self.config.max_model_len, + device=self.device, ) @property diff --git a/ssd/layers/attention.py b/ssd/layers/attention.py index ed5ec7b3a..7d2b9cec1 100644 --- a/ssd/layers/attention.py +++ b/ssd/layers/attention.py @@ -4,6 +4,8 @@ import triton.language as tl from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache +from flash_attn.cute.interface import flash_attn_varlen_func as fa4_varlen_func +from ssd.layers.tree_mask import create_tree_score_mod from ssd.utils.context import get_context @@ -65,10 +67,10 @@ def __init__( self.speculate = speculate self.draft_async = draft_async self.use_eagle = use_eagle - self.prefill_wrappers = {} self.F = F # async_fan_out self.K = K # speculate_k - self.only_prefill_wrapper = None + self.max_seqlen_k = 0 # set during KV cache allocation to config.max_model_len + self.tree_score_mod = None # set during KV cache allocation def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): o: torch.Tensor @@ -111,18 +113,24 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): ) elif tree_decode: - if self.only_prefill_wrapper is not None: - prefill_wrapper = self.only_prefill_wrapper - else: - mq_len = self.F * (self.K+1) - bs = q.shape[0] // mq_len - wrapper_bs = None - for available_bs in sorted(self.prefill_wrappers.keys()): - if available_bs >= bs: - wrapper_bs = available_bs - break - prefill_wrapper = self.prefill_wrappers[wrapper_bs] - o = prefill_wrapper.run(q, (self.k_cache, self.v_cache)) + score_mod_kwargs = {} + if self.tree_score_mod is not None and 
context.tree_mask_bias is not None: + score_mod_kwargs["score_mod"] = self.tree_score_mod + score_mod_kwargs["aux_tensors"] = [context.tree_mask_bias] + o, _ = fa4_varlen_func( + q, + self.k_cache, + self.v_cache, + cu_seqlens_q=context.tree_cu_seqlens_q, + cu_seqlens_k=None, + max_seqlen_q=self.F * (self.K + 1), + max_seqlen_k=self.max_seqlen_k, + seqused_k=context.context_lens, + page_table=context.block_tables, + softmax_scale=self.scale, + causal=False, + **score_mod_kwargs, + ) else: # single query decode q = q.unsqueeze(1) o = flash_attn_with_kvcache(q, k_cache, v_cache, diff --git a/ssd/layers/tree_mask.py b/ssd/layers/tree_mask.py new file mode 100644 index 000000000..d44a7ec14 --- /dev/null +++ b/ssd/layers/tree_mask.py @@ -0,0 +1,100 @@ +"""Tree decode mask for FA4 via score_mod + aux_tensors. + +The tree mask is stored as a dense float32 bias tensor of shape +(max_total_q, max_kv_stride), flattened to 1D. Unmasked positions have +value 0.0; masked positions have a large negative value (-1e6). + +score_mod adds the bias to each attention score, effectively masking out +positions where the bias is -1e6. +""" + +import torch +import numpy as np +import cutlass +import cutlass.cute as cute + +# Large negative value used to mask attention scores. +_MASK_VAL = -1.0e6 + + +def create_tree_score_mod(max_kv_stride: int): + """Return a @cute.jit score_mod that reads a mask bias from aux_tensors[0]. + + The aux_tensor is a 1D float32 tensor indexed by: + (offset_q + q_idx) * max_kv_stride + kv_idx + + where offset_q comes from seqlen_info for varlen sequences. 
+ """ + + @cute.jit + def tree_score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + mask_bias = aux_tensors[0] + dtype = mask_bias.element_type + global_q = seqlen_info.offset_q + q_idx + flat_idx = global_q * max_kv_stride + kv_idx + idx_frag = cute.make_rmem_tensor(1, cutlass.Int32) + idx_frag.store(flat_idx) + val_frag = cute.make_rmem_tensor(1, dtype) + val_frag[0] = mask_bias[idx_frag[0]] + bias = (val_frag.load()).to(cutlass.Float32) + return tSrS_ssa + bias + + return tree_score_mod + + +def build_tree_mask_bias( + context_lens: torch.Tensor, + step: int, + K: int, + MQ_LEN: int, + fan_out_list: list[int], + fan_out_list_miss: list[int], + cache_hits: torch.Tensor, + max_kv_stride: int, + device: torch.device, +) -> torch.Tensor: + """Build the dense mask bias tensor for one tree decode step. + + Returns a 1D float32 tensor of shape (B * MQ_LEN * max_kv_stride,) + with 0.0 for attend and _MASK_VAL for masked positions. + """ + B = context_lens.shape[0] + context_lens_list = context_lens.tolist() + cache_hits_list = cache_hits[:B].tolist() + + # Pre-compute glue patterns + tril = np.tril(np.ones((K + 1, K + 1), dtype=np.float32)) + fol = np.array(fan_out_list) + fol_miss = np.array(fan_out_list_miss) + glue_hit = np.repeat(tril, fol, axis=0) # (MQ_LEN, K+1) + glue_miss = np.repeat(tril, fol_miss, axis=0) + + ttl_added = (step + 1) * MQ_LEN + (K + 1) + rows = np.arange(MQ_LEN) + + # Build mask as numpy, then convert + bias = np.full((B * MQ_LEN, max_kv_stride), _MASK_VAL, dtype=np.float32) + + for b in range(B): + cols_b = int(context_lens_list[b]) + prefix_len_b = cols_b - ttl_added + row_offset = b * MQ_LEN + + # Prefix: attend to all + if prefix_len_b > 0: + bias[row_offset:row_offset + MQ_LEN, :prefix_len_b] = 0.0 + + # Glue pattern + glue = glue_hit if int(cache_hits_list[b]) == 1 else glue_miss + glue_start = prefix_len_b + glue_bias = np.where(glue > 0, 0.0, _MASK_VAL).astype(np.float32) + bias[row_offset:row_offset + MQ_LEN, 
glue_start:glue_start + K + 1] = glue_bias + + # Diagonal blocks + diag_start = prefix_len_b + K + 1 + for blk in range(step + 1): + col_indices = diag_start + blk * MQ_LEN + rows + valid = col_indices < max_kv_stride + bias[row_offset + rows[valid], col_indices[valid]] = 0.0 + + return torch.from_numpy(bias.reshape(-1)).to(device, non_blocking=True) diff --git a/ssd/utils/context.py b/ssd/utils/context.py index 91c744a27..cccb3459c 100644 --- a/ssd/utils/context.py +++ b/ssd/utils/context.py @@ -13,15 +13,17 @@ class Context: slot_mapping: torch.Tensor | None = None context_lens: torch.Tensor | None = None block_tables: torch.Tensor | None = None + tree_cu_seqlens_q: torch.Tensor | None = None + tree_mask_bias: torch.Tensor | None = None _CONTEXT = Context() def get_context(): return _CONTEXT -def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None, is_jit=False): +def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None, is_jit=False, tree_cu_seqlens_q=None, tree_mask_bias=None): global _CONTEXT - _CONTEXT = Context(is_prefill, is_jit, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables) + _CONTEXT = Context(is_prefill, is_jit, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables, tree_cu_seqlens_q, tree_mask_bias) def reset_context(): global _CONTEXT diff --git a/tests/test_fa4_tree_decode.py b/tests/test_fa4_tree_decode.py new file mode 100644 index 000000000..19102ad75 --- /dev/null +++ b/tests/test_fa4_tree_decode.py @@ -0,0 +1,201 @@ +"""Tests for FA4 flash_attn_varlen_func with paged KV cache (tree decode replacement).""" + +import pytest +import torch +from flash_attn.cute.interface import flash_attn_varlen_func as fa4_varlen_func +from ssd.layers.attention import Attention 
+from ssd.utils.context import set_context, reset_context + + +DEVICE = "cuda" +DTYPE = torch.bfloat16 + + +# --------------------------------------------------------------------------- +# FA4 varlen + page_table: basic correctness +# --------------------------------------------------------------------------- + +class TestFA4VarlenPageTable: + """Test flash_attn_varlen_func with page_table at various page sizes.""" + + @pytest.fixture(autouse=True) + def setup(self): + torch.manual_seed(42) + self.B = 2 + self.MQ_LEN = 6 + self.num_heads = 4 + self.num_kv_heads = 2 + self.head_dim = 128 + self.num_pages = 200 + self.max_pages_per_seq = 20 + + def _run(self, page_size, kv_lens): + total_q = self.B * self.MQ_LEN + q = torch.randn(total_q, self.num_heads, self.head_dim, dtype=DTYPE, device=DEVICE) + k_cache = torch.randn(self.num_pages, page_size, self.num_kv_heads, self.head_dim, dtype=DTYPE, device=DEVICE) + v_cache = torch.randn(self.num_pages, page_size, self.num_kv_heads, self.head_dim, dtype=DTYPE, device=DEVICE) + cu_seqlens_q = torch.arange(self.B + 1, dtype=torch.int32, device=DEVICE) * self.MQ_LEN + + page_table = torch.zeros(self.B, self.max_pages_per_seq, dtype=torch.int32, device=DEVICE) + for b in range(self.B): + n_pages = (kv_lens[b] + page_size - 1) // page_size + page_table[b, :n_pages] = torch.arange(n_pages, dtype=torch.int32, device=DEVICE) + b * 50 + + seqused_k = torch.tensor(kv_lens, dtype=torch.int32, device=DEVICE) + + out, lse = fa4_varlen_func( + q, k_cache, v_cache, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=None, + max_seqlen_q=self.MQ_LEN, + max_seqlen_k=max(kv_lens), + seqused_k=seqused_k, + page_table=page_table, + softmax_scale=self.head_dim ** -0.5, + causal=False, + ) + return out, lse + + @pytest.mark.parametrize("page_size", [1, 16, 128]) + def test_output_shape(self, page_size): + out, _ = self._run(page_size, kv_lens=[10, 5]) + assert out.shape == (self.B * self.MQ_LEN, self.num_heads, self.head_dim) + + 
@pytest.mark.parametrize("page_size", [1, 16, 128]) + def test_no_nan_inf(self, page_size): + out, _ = self._run(page_size, kv_lens=[10, 5]) + assert not torch.isnan(out).any(), "Output contains NaN" + assert not torch.isinf(out).any(), "Output contains Inf" + + @pytest.mark.parametrize("page_size", [1, 16, 128]) + def test_lse_returned_none_by_default(self, page_size): + _, lse = self._run(page_size, kv_lens=[10, 5]) + assert lse is None, "LSE should be None when return_lse=False (default)" + + def test_variable_kv_lengths(self): + """Sequences with very different KV lengths should both produce valid output.""" + self.max_pages_per_seq = 60 # accommodate kv_len=50 + out, _ = self._run(page_size=1, kv_lens=[50, 3]) + assert not torch.isnan(out).any() + # Check that the two sequences produce different outputs (they have different KV) + out_seq0 = out[:self.MQ_LEN] + out_seq1 = out[self.MQ_LEN:] + assert not torch.allclose(out_seq0, out_seq1), "Different KV should produce different outputs" + + def test_deterministic(self): + """Same inputs should produce same outputs.""" + out1, _ = self._run(page_size=1, kv_lens=[10, 5]) + torch.manual_seed(42) # reset seed to get same random inputs + out2, _ = self._run(page_size=1, kv_lens=[10, 5]) + assert torch.allclose(out1, out2), "Same inputs should produce identical outputs" + + def test_batch_size_1(self): + """Single-sequence batch should work.""" + self.B = 1 + out, _ = self._run(page_size=1, kv_lens=[10]) + assert out.shape == (self.MQ_LEN, self.num_heads, self.head_dim) + assert not torch.isnan(out).any() + + +# --------------------------------------------------------------------------- +# Attention layer integration: tree decode path +# --------------------------------------------------------------------------- + +class TestAttentionTreeDecode: + """Test the Attention module's tree_decode path end-to-end with FA4.""" + + @pytest.fixture(autouse=True) + def setup(self): + torch.manual_seed(42) + self.num_heads = 8 + 
self.num_kv_heads = 2 + self.head_dim = 128 + self.scale = self.head_dim ** -0.5 + self.F_fan = 2 + self.K_spec = 2 + self.MQ_LEN = self.F_fan * (self.K_spec + 1) + self.page_size = 1 + self.num_pages = 200 + self.max_pages_per_seq = 50 + self.max_model_len = 50 + yield + reset_context() + + def _make_attn(self): + attn = Attention( + num_heads=self.num_heads, head_dim=self.head_dim, scale=self.scale, + num_kv_heads=self.num_kv_heads, draft=True, speculate=True, + draft_async=True, use_eagle=False, F=self.F_fan, K=self.K_spec, + ) + attn.k_cache = torch.randn( + self.num_pages, self.page_size, self.num_kv_heads, self.head_dim, + dtype=DTYPE, device=DEVICE) + attn.v_cache = torch.randn( + self.num_pages, self.page_size, self.num_kv_heads, self.head_dim, + dtype=DTYPE, device=DEVICE) + attn.max_seqlen_k = self.max_model_len + return attn + + def _run(self, attn, B, context_lens_list): + total_tokens = B * self.MQ_LEN + q = torch.randn(total_tokens, self.num_heads * self.head_dim, dtype=DTYPE, device=DEVICE) + k = torch.randn(total_tokens, self.num_kv_heads * self.head_dim, dtype=DTYPE, device=DEVICE) + v = torch.randn(total_tokens, self.num_kv_heads * self.head_dim, dtype=DTYPE, device=DEVICE) + + context_lens = torch.tensor(context_lens_list, dtype=torch.int32, device=DEVICE) + slot_mapping = torch.arange(total_tokens, dtype=torch.int32, device=DEVICE) + + block_tables = torch.zeros(B, self.max_pages_per_seq, dtype=torch.int32, device=DEVICE) + for b in range(B): + n_pages = context_lens_list[b] # page_size=1, so pages == tokens + block_tables[b, :n_pages] = torch.arange(n_pages, dtype=torch.int32, device=DEVICE) + b * 50 + + cu_seqlens_q = torch.arange(B + 1, dtype=torch.int32, device=DEVICE) * self.MQ_LEN + + set_context( + is_prefill=False, + slot_mapping=slot_mapping, + context_lens=context_lens, + block_tables=block_tables, + tree_cu_seqlens_q=cu_seqlens_q, + ) + + with torch.inference_mode(): + out = attn(q, k, v) + return out + + def test_output_shape(self): 
+ attn = self._make_attn() + out = self._run(attn, B=2, context_lens_list=[20, 15]) + expected = (2 * self.MQ_LEN, self.num_heads * self.head_dim) + assert out.shape == expected, f"Expected {expected}, got {out.shape}" + + def test_no_nan_inf(self): + attn = self._make_attn() + out = self._run(attn, B=2, context_lens_list=[20, 15]) + assert not torch.isnan(out).any(), "Output contains NaN" + assert not torch.isinf(out).any(), "Output contains Inf" + + def test_single_sequence(self): + attn = self._make_attn() + out = self._run(attn, B=1, context_lens_list=[30]) + expected = (self.MQ_LEN, self.num_heads * self.head_dim) + assert out.shape == expected + + def test_different_context_lens(self): + """Sequences with different context lengths should produce different outputs.""" + attn = self._make_attn() + out = self._run(attn, B=2, context_lens_list=[40, 10]) + out_seq0 = out[:self.MQ_LEN] + out_seq1 = out[self.MQ_LEN:] + assert not torch.allclose(out_seq0, out_seq1) + + def test_non_tree_decode_paths_unaffected(self): + """Verify that non-tree-decode paths still use the original kernels.""" + attn = Attention( + num_heads=self.num_heads, head_dim=self.head_dim, scale=self.scale, + num_kv_heads=self.num_kv_heads, draft=False, speculate=False, + draft_async=False, use_eagle=False, + ) + # This attention module should NOT take the tree_decode path + assert not (attn.speculate and attn.draft and attn.draft_async) diff --git a/tests/test_score_mod_basic.py b/tests/test_score_mod_basic.py new file mode 100644 index 000000000..e7ea7cdfe --- /dev/null +++ b/tests/test_score_mod_basic.py @@ -0,0 +1,155 @@ +"""Test that score_mod with aux_tensors works with FA4 varlen + page_table.""" + +import torch +import pytest +from flash_attn.cute.interface import flash_attn_varlen_func +from ssd.layers.tree_mask import create_tree_score_mod, build_tree_mask_bias + +DEVICE = "cuda" +DTYPE = torch.bfloat16 + + +class TestScoreModBasic: + """Verify score_mod compiles and runs with FA4 
varlen + page_table.""" + + @pytest.fixture(autouse=True) + def setup(self): + torch.manual_seed(42) + self.B = 2 + self.MQ_LEN = 6 + self.num_heads = 4 + self.num_kv_heads = 2 + self.head_dim = 128 + self.num_pages = 200 + self.max_pages_per_seq = 50 + self.page_size = 1 + + def _make_inputs(self, kv_lens): + total_q = self.B * self.MQ_LEN + q = torch.randn(total_q, self.num_heads, self.head_dim, dtype=DTYPE, device=DEVICE) + k_cache = torch.randn(self.num_pages, self.page_size, self.num_kv_heads, self.head_dim, dtype=DTYPE, device=DEVICE) + v_cache = torch.randn(self.num_pages, self.page_size, self.num_kv_heads, self.head_dim, dtype=DTYPE, device=DEVICE) + cu_seqlens_q = torch.arange(self.B + 1, dtype=torch.int32, device=DEVICE) * self.MQ_LEN + page_table = torch.zeros(self.B, self.max_pages_per_seq, dtype=torch.int32, device=DEVICE) + for b in range(self.B): + n = kv_lens[b] + page_table[b, :n] = torch.arange(n, dtype=torch.int32, device=DEVICE) + b * 50 + seqused_k = torch.tensor(kv_lens, dtype=torch.int32, device=DEVICE) + return q, k_cache, v_cache, cu_seqlens_q, page_table, seqused_k + + def test_zero_bias_matches_no_scoremod(self): + """A score_mod that adds zero should produce identical output.""" + kv_lens = [10, 5] + max_kv_stride = 50 + q, k, v, cu, pt, sk = self._make_inputs(kv_lens) + + out_base, _ = flash_attn_varlen_func( + q, k, v, cu_seqlens_q=cu, cu_seqlens_k=None, + max_seqlen_q=self.MQ_LEN, max_seqlen_k=max(kv_lens), + seqused_k=sk, page_table=pt, + softmax_scale=self.head_dim ** -0.5, causal=False, + ) + + score_mod = create_tree_score_mod(max_kv_stride) + # All-zero bias = no masking + bias = torch.zeros(self.B * self.MQ_LEN * max_kv_stride, dtype=torch.float32, device=DEVICE) + + out_mod, _ = flash_attn_varlen_func( + q, k, v, cu_seqlens_q=cu, cu_seqlens_k=None, + max_seqlen_q=self.MQ_LEN, max_seqlen_k=max(kv_lens), + seqused_k=sk, page_table=pt, + softmax_scale=self.head_dim ** -0.5, causal=False, + score_mod=score_mod, aux_tensors=[bias], 
+ ) + + assert torch.allclose(out_base, out_mod, atol=1e-2), \ + f"Zero bias should match base, max diff: {(out_base - out_mod).abs().max().item()}" + + def test_full_mask_produces_uniform_attention(self): + """Masking all but one KV position should concentrate attention there.""" + kv_lens = [10, 5] + max_kv_stride = 50 + q, k, v, cu, pt, sk = self._make_inputs(kv_lens) + + score_mod = create_tree_score_mod(max_kv_stride) + # Mask everything except KV position 0 for all queries + bias = torch.full((self.B * self.MQ_LEN * max_kv_stride,), -1e6, dtype=torch.float32, device=DEVICE) + for b in range(self.B): + for qi in range(self.MQ_LEN): + flat_idx = (b * self.MQ_LEN + qi) * max_kv_stride + 0 # only attend to kv_idx=0 + bias[flat_idx] = 0.0 + + out, _ = flash_attn_varlen_func( + q, k, v, cu_seqlens_q=cu, cu_seqlens_k=None, + max_seqlen_q=self.MQ_LEN, max_seqlen_k=max(kv_lens), + seqused_k=sk, page_table=pt, + softmax_scale=self.head_dim ** -0.5, causal=False, + score_mod=score_mod, aux_tensors=[bias], + ) + + assert not torch.isnan(out).any(), "Masked output has NaN" + assert not torch.isinf(out).any(), "Masked output has Inf" + + +class TestTreeMaskBuild: + """Test build_tree_mask_bias produces correct mask structure.""" + + def test_prefix_unmasked(self): + """All prefix positions should have bias=0 (attend).""" + B, K, MQ_LEN = 1, 2, 6 + fol = [2, 2, 2] + context_lens = torch.tensor([20], dtype=torch.int32) # prefix = 20 - (1*6 + 3) = 11 + cache_hits = torch.tensor([1]) + max_kv_stride = 50 + + bias = build_tree_mask_bias( + context_lens, step=0, K=K, MQ_LEN=MQ_LEN, + fan_out_list=fol, fan_out_list_miss=fol, + cache_hits=cache_hits, max_kv_stride=max_kv_stride, + device="cpu", + ) + bias_2d = bias.reshape(MQ_LEN, max_kv_stride) + prefix_len = 20 - (1 * MQ_LEN + K + 1) + # All prefix columns should be 0.0 (unmasked) + assert (bias_2d[:, :prefix_len] == 0.0).all(), "Prefix should be unmasked" + + def test_masked_positions_negative(self): + """Positions beyond the 
valid KV should be masked (large negative).""" + B, K, MQ_LEN = 1, 2, 6 + fol = [2, 2, 2] + context_lens = torch.tensor([20], dtype=torch.int32) + cache_hits = torch.tensor([1]) + max_kv_stride = 50 + + bias = build_tree_mask_bias( + context_lens, step=0, K=K, MQ_LEN=MQ_LEN, + fan_out_list=fol, fan_out_list_miss=fol, + cache_hits=cache_hits, max_kv_stride=max_kv_stride, + device="cpu", + ) + bias_2d = bias.reshape(MQ_LEN, max_kv_stride) + # Beyond context_lens should be masked + assert (bias_2d[:, 20:] < -1e5).all(), "Beyond context_lens should be masked" + + def test_diagonal_pattern(self): + """At step 0, each query should attend to its own diagonal position.""" + B, K, MQ_LEN = 1, 2, 6 + fol = [2, 2, 2] + # context_lens at step 0 needs to be at least ttl_added = 1*MQ_LEN + K+1 = 9 + context_lens = torch.tensor([15], dtype=torch.int32) + cache_hits = torch.tensor([1]) + max_kv_stride = 50 + + bias = build_tree_mask_bias( + context_lens, step=0, K=K, MQ_LEN=MQ_LEN, + fan_out_list=fol, fan_out_list_miss=fol, + cache_hits=cache_hits, max_kv_stride=max_kv_stride, + device="cpu", + ) + bias_2d = bias.reshape(MQ_LEN, max_kv_stride) + prefix_len = 15 - (1 * MQ_LEN + K + 1) # = 6 + diag_start = prefix_len + K + 1 # = 9 + # At step 0, block 0: bias_2d[q, diag_start + q] should be 0.0 + for q in range(MQ_LEN): + col = diag_start + q + assert bias_2d[q, col].item() == 0.0, f"Diagonal at q={q}, col={col} should be unmasked" diff --git a/tests/test_tree_mask_correctness.py b/tests/test_tree_mask_correctness.py new file mode 100644 index 000000000..0f8750c50 --- /dev/null +++ b/tests/test_tree_mask_correctness.py @@ -0,0 +1,164 @@ +"""Correctness tests: verify FA4 tree mask matches the original flashinfer mask logic.""" + +import torch +import numpy as np +import pytest +from flash_attn.cute.interface import flash_attn_varlen_func +from ssd.layers.tree_mask import create_tree_score_mod, build_tree_mask_bias +from ssd.engine.helpers.mask_helpers import get_custom_mask + +DEVICE 
= "cuda" +DTYPE = torch.bfloat16 + + +class FakeConfig: + """Minimal config for get_custom_mask.""" + def __init__(self, K, fan_out_list, fan_out_list_miss, max_model_len): + self.speculate_k = K + self.fan_out_list = fan_out_list + self.fan_out_list_miss = fan_out_list_miss + self.max_model_len = max_model_len + + +class TestTreeMaskMatchesOriginal: + """Verify that build_tree_mask_bias produces masks equivalent to get_custom_mask.""" + + @pytest.fixture(autouse=True) + def setup(self): + self.K = 2 + self.F = 2 + self.fan_out_list = [2, 2, 2] # F=2, K+1=3 groups + self.fan_out_list_miss = [2, 2, 2] + self.MQ_LEN = sum(self.fan_out_list) # = 6 + + def _compare_masks(self, B, context_lens_list, step, cache_hits_list): + """Compare old (get_custom_mask) vs new (build_tree_mask_bias) for one step.""" + context_lens = torch.tensor(context_lens_list, dtype=torch.int32, device=DEVICE) + cache_hits = torch.tensor(cache_hits_list, dtype=torch.float32, device=DEVICE) + max_model_len = 100 + + config = FakeConfig(self.K, self.fan_out_list, self.fan_out_list_miss, max_model_len) + + # Old mask: 1D bool tensor, concatenation of per-seq (MQ_LEN x kv_len) masks + old_mask = get_custom_mask( + config, context_lens, step, self.K, self.F, B, + device=DEVICE, cache_hits=cache_hits, + ) + + # New mask bias: (B * MQ_LEN * max_model_len,) float32 + new_bias = build_tree_mask_bias( + context_lens, step=step, K=self.K, MQ_LEN=self.MQ_LEN, + fan_out_list=self.fan_out_list, + fan_out_list_miss=self.fan_out_list_miss, + cache_hits=cache_hits, + max_kv_stride=max_model_len, + device=DEVICE, + ) + new_bias_2d = new_bias.reshape(B * self.MQ_LEN, max_model_len) + + # Extract per-batch masks from old format and compare + old_offset = 0 + for b in range(B): + kv_len = context_lens_list[b] + old_mask_b = old_mask[old_offset:old_offset + self.MQ_LEN * kv_len].reshape(self.MQ_LEN, kv_len) + new_mask_b = new_bias_2d[b * self.MQ_LEN:(b + 1) * self.MQ_LEN, :kv_len] + + # Old: True = attend, False = 
mask + # New: 0.0 = attend, -1e6 = mask + new_attend = (new_mask_b == 0.0) + old_attend = old_mask_b.bool() + + mismatches = (new_attend != old_attend).sum().item() + assert mismatches == 0, ( + f"Mask mismatch at batch={b}, step={step}: {mismatches} positions differ\n" + f" old attend count: {old_attend.sum().item()}, new attend count: {new_attend.sum().item()}\n" + f" context_len={kv_len}, cache_hit={cache_hits_list[b]}" + ) + old_offset += self.MQ_LEN * kv_len + + @pytest.mark.parametrize("step", [0, 1]) + def test_single_seq_cache_hit(self, step): + # context_lens must be >= ttl_added = (step+1)*MQ_LEN + K+1 + cl = 30 + step * self.MQ_LEN + self._compare_masks(B=1, context_lens_list=[cl], step=step, cache_hits_list=[1]) + + @pytest.mark.parametrize("step", [0, 1]) + def test_single_seq_cache_miss(self, step): + cl = 30 + step * self.MQ_LEN + self._compare_masks(B=1, context_lens_list=[cl], step=step, cache_hits_list=[0]) + + @pytest.mark.parametrize("step", [0, 1]) + def test_multi_seq_mixed_hits(self, step): + base = 25 + step * self.MQ_LEN + self._compare_masks( + B=3, + context_lens_list=[base, base + 10, base + 5], + step=step, + cache_hits_list=[1, 0, 1], + ) + + def test_step_2(self): + cl = 40 + 2 * self.MQ_LEN + self._compare_masks(B=2, context_lens_list=[cl, cl - 5], step=2, cache_hits_list=[1, 0]) + + +class TestFA4WithTreeMask: + """End-to-end: verify FA4 attention with tree mask produces valid, masked output.""" + + @pytest.fixture(autouse=True) + def setup(self): + torch.manual_seed(42) + self.B = 2 + self.K = 2 + self.MQ_LEN = 6 + self.num_heads = 4 + self.num_kv_heads = 2 + self.head_dim = 128 + self.num_pages = 200 + self.page_size = 1 + self.max_pages_per_seq = 50 + self.max_kv_stride = 50 + self.fan_out_list = [2, 2, 2] + self.fan_out_list_miss = [2, 2, 2] + + def test_masked_vs_unmasked_differ(self): + """Masked attention should produce different output than unmasked.""" + kv_lens = [20, 15] + total_q = self.B * self.MQ_LEN + q = 
torch.randn(total_q, self.num_heads, self.head_dim, dtype=DTYPE, device=DEVICE) + k = torch.randn(self.num_pages, self.page_size, self.num_kv_heads, self.head_dim, dtype=DTYPE, device=DEVICE) + v = torch.randn(self.num_pages, self.page_size, self.num_kv_heads, self.head_dim, dtype=DTYPE, device=DEVICE) + cu = torch.arange(self.B + 1, dtype=torch.int32, device=DEVICE) * self.MQ_LEN + pt = torch.zeros(self.B, self.max_pages_per_seq, dtype=torch.int32, device=DEVICE) + for b in range(self.B): + pt[b, :kv_lens[b]] = torch.arange(kv_lens[b], dtype=torch.int32, device=DEVICE) + b * 50 + sk = torch.tensor(kv_lens, dtype=torch.int32, device=DEVICE) + + # Unmasked (causal=False, no score_mod) + out_unmasked, _ = flash_attn_varlen_func( + q, k, v, cu_seqlens_q=cu, cu_seqlens_k=None, + max_seqlen_q=self.MQ_LEN, max_seqlen_k=max(kv_lens), + seqused_k=sk, page_table=pt, + softmax_scale=self.head_dim ** -0.5, causal=False, + ) + + # Masked + score_mod = create_tree_score_mod(self.max_kv_stride) + context_lens = torch.tensor(kv_lens, dtype=torch.int32) + cache_hits = torch.tensor([1, 1]) + mask_bias = build_tree_mask_bias( + context_lens, step=0, K=self.K, MQ_LEN=self.MQ_LEN, + fan_out_list=self.fan_out_list, fan_out_list_miss=self.fan_out_list_miss, + cache_hits=cache_hits, max_kv_stride=self.max_kv_stride, device=DEVICE, + ) + out_masked, _ = flash_attn_varlen_func( + q, k, v, cu_seqlens_q=cu, cu_seqlens_k=None, + max_seqlen_q=self.MQ_LEN, max_seqlen_k=max(kv_lens), + seqused_k=sk, page_table=pt, + softmax_scale=self.head_dim ** -0.5, causal=False, + score_mod=score_mod, aux_tensors=[mask_bias], + ) + + assert not torch.isnan(out_masked).any(), "Masked output has NaN" + assert not torch.allclose(out_masked, out_unmasked, atol=1e-2), \ + "Masked and unmasked should produce different outputs" From e701bfe5a9095522a54d7306adb6af60029f6dad Mon Sep 17 00:00:00 2001 From: Avner May Date: Sun, 29 Mar 2026 05:27:02 -0700 Subject: [PATCH 3/6] More logging --- ssd/engine/model_runner.py 
| 1 + 1 file changed, 1 insertion(+) diff --git a/ssd/engine/model_runner.py b/ssd/engine/model_runner.py index 531b234ec..25ac7b9de 100644 --- a/ssd/engine/model_runner.py +++ b/ssd/engine/model_runner.py @@ -232,6 +232,7 @@ def setup_and_warmup_model_and_cudagraphs(self, config: Config, hf_config: AutoC self.async_pg = init_custom_process_group( backend="nccl", store=store, world_size=2, rank=1, group_name="async_spec") + print('[model_runner] NCCL process group formed, now receiving kv_cache_size...', flush=True) # Cross-node: receive kv_cache_size from target so draft # allocates the same number of KV cache blocks. kv_buf = torch.empty(1, dtype=torch.int64, device=self.device) From f8af8e7619fa746676fc9741dcf8ab0cab435782 Mon Sep 17 00:00:00 2001 From: Avner May Date: Fri, 10 Apr 2026 10:43:30 -0700 Subject: [PATCH 4/6] Acceptance rate log and force-jit-speculate --- ssd/config.py | 2 ++ ssd/engine/draft_runner.py | 46 ++++++++++++++++++++++++++++++-------- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/ssd/config.py b/ssd/config.py index 5d1c7ea63..558802943 100644 --- a/ssd/config.py +++ b/ssd/config.py @@ -33,6 +33,7 @@ class Config: fan_out_list_miss: list[int] | None = None sampler_x: float | None = None jit_speculate: bool = False + force_jit_speculate: bool = False async_nccl_port: int | None = None async_nccl_host: str = "127.0.0.1" communicate_logits: bool = False @@ -88,6 +89,7 @@ def __post_init__(self): print(f'[Config] Setting fan_out_list_miss to [sum(fan_out_list)] + [0] * speculate_k because jit_speculate is False', flush=True) self.fan_out_list_miss = [sum(self.fan_out_list)] + [0] * self.speculate_k elif self.fan_out_list_miss is None: + # If you are jit speculating, always use the same fan_out_list for misses as for hits. 
self.fan_out_list_miss = self.fan_out_list assert sum(self.fan_out_list_miss) == sum(self.fan_out_list), "ERROR in Config: fan_out_list_miss must be the same as fan_out_list" diff --git a/ssd/engine/draft_runner.py b/ssd/engine/draft_runner.py index 0765ecee9..c8799be38 100644 --- a/ssd/engine/draft_runner.py +++ b/ssd/engine/draft_runner.py @@ -54,6 +54,11 @@ def __init__(self, draft_cfg: Config, rank: int = 0, init_q = None): self._reset_tree_cache_tensors() self._init_prealloc_buffers() self._draft_step_times = [] + self._acceptance_lengths = [] + self._cache_hits = [] + self._acceptance_rate_log_path = os.environ.get("ACCEPTANCE_RATE_LOG", None) + if self._acceptance_rate_log_path: + print(f'[{_ts()}] DraftRunner will log acceptance rate to: {self._acceptance_rate_log_path}', flush=True) print(f'[{_ts()}] DraftRunner set up, starting draft_loop', flush=True) self.draft_loop() @@ -219,7 +224,7 @@ def hit_cache(self, request_keys, B, K, num_tokens, temperatures, draft_block_ta # Init miss slots with valid random logits so token IDs are in-vocab (fixes B>1 crash) out_logits = torch.empty(B, K, V, dtype=self.hf_config.torch_dtype, device=self.device).uniform_() out_tokens = out_logits.argmax(dim=-1) - cache_hits = torch.zeros(B, dtype=torch.int64, device=self.device) + cache_hits = torch.zeros(B, dtype=torch.bool, device=self.device) assert request_keys.shape == (B, 3), f"ERROR in hit_cache: request_keys should be (B, 3), got {request_keys.shape}" @@ -227,24 +232,24 @@ def hit_cache(self, request_keys, B, K, num_tokens, temperatures, draft_block_ta B, K, self.hidden_states_dim, dtype=self.hf_config.torch_dtype, device=self.device ) if self.config.use_eagle_or_phoenix else None - + # Statistics ttl += int(B) - + if self.config.verbose: print(f"[{_ts()}] [hit_cache] Request keys: {request_keys}", flush=True) for i in range(B): rec_token = request_keys[i, 2].item() rec_text = self.tokenizer.decode([rec_token]) print(f"[{_ts()}] Req {i}: token={rec_token} 
('{rec_text}')", flush=True) - + if self.tree_cache_keys.numel() > 0: # Vectorized membership against tensor cache eq = (request_keys.unsqueeze(1) == self.tree_cache_keys.unsqueeze(0)) # [B,T,3] match = torch.all(eq, dim=2) # [B,T] cache_hits = match.any(dim=1) # [B] ttl_hit += int(cache_hits.sum().item()) - + if self.config.verbose: print(f"[{_ts()}] [hit_cache] Cache hits: {cache_hits.sum().item()}/{B}", flush=True) print(f"[{_ts()}] [hit_cache] Cache: {self.tree_cache_keys.shape[0]} entries", flush=True) @@ -263,9 +268,9 @@ def hit_cache(self, request_keys, B, K, num_tokens, temperatures, draft_block_ta rec_text = self.tokenizer.decode([rec_token]) hit_marker = "[HIT]" if i in hit_indices else "" print(f"[{_ts()}] [{i}]: key=({seq_id}, {k_idx}, {rec_token}) -> value=('{rec_text}') {hit_marker}", flush=True) - + # Fill hits - if (cache_hits.any() and not self.config.jit_speculate) or (cache_hits.all() and self.config.jit_speculate): + if not self.config.force_jit_speculate and ((cache_hits.any() and not self.config.jit_speculate) or (cache_hits.all() and self.config.jit_speculate)): # print(f'[hit_cache] got all cache hits, using cached logits and tokens', flush=True) # [B], arbitrary if no match but masked out idx = match.float().argmax(dim=1).to(torch.int64) @@ -306,7 +311,7 @@ def hit_cache(self, request_keys, B, K, num_tokens, temperatures, draft_block_ta ) if self.config.use_eagle_or_phoenix: out_activations = jit_acts - + rec_toks = request_keys[:, 2] if self.config.verbose: @@ -345,9 +350,18 @@ def _service_spec_request(self): out_tokens, out_logits, glue_decode_input_ids, cache_hits, out_activations = self.hit_cache( cache_keys, B, K, num_tokens, temperatures, draft_block_tables, target_recovery_activations) + if self._acceptance_rate_log_path: + # Collect per-step metrics for logging. + # cache_keys[:, 1] is last_spec_step_accepted_len - 1 from the target; + # first request has -1 (forced miss). 
+ for i in range(B): + accept_len = cache_keys[i, 1].item() + 1 + self._acceptance_lengths.append(accept_len) + self._cache_hits.append(cache_hits[i].item()) + speculation_response = SpeculationResponse( speculations=out_tokens.reshape(-1).to(torch.int64), - cache_hits=cache_hits.reshape(-1) if self.communicate_cache_hits else None, + cache_hits=cache_hits.reshape(-1).to(torch.int64) if self.communicate_cache_hits else None, logits_q=out_logits[:, :K, :].contiguous() if self.communicate_logits else None, ) if BRIEF_LOG: @@ -972,6 +986,20 @@ def _draft_loop_inner(self): if self._draft_step_times: avg_ms = sum(self._draft_step_times) * 1000 / len(self._draft_step_times) print(f"[{_ts()}] [metrics] Avg draft step time (ms): {avg_ms:.2f}", flush=True) + if self._acceptance_rate_log_path and self._acceptance_lengths: + import json + avg_acc = sum(self._acceptance_lengths) / len(self._acceptance_lengths) + hit_rate = sum(self._cache_hits) / len(self._cache_hits) if self._cache_hits else 0 + print(f"[{_ts()}] [metrics] Avg acceptance length: {avg_acc:.2f} ({len(self._acceptance_lengths)} steps)", flush=True) + print(f"[{_ts()}] [metrics] Cache hit rate: {hit_rate:.2%} ({sum(self._cache_hits)}/{len(self._cache_hits)})", flush=True) + print(f"[{_ts()}] [metrics] All acceptance lengths: {self._acceptance_lengths}", flush=True) + print(f"[{_ts()}] [metrics] All cache hits: {self._cache_hits}", flush=True) + print(f"[{_ts()}] [metrics] Logging acceptance lengths and cache hits to: {self._acceptance_rate_log_path}", flush=True) + with open(self._acceptance_rate_log_path, "w") as f: + json.dump({ + "acceptance_lengths": self._acceptance_lengths, + "cache_hits": self._cache_hits, + }, f) self.exit() break From 4c6997ff67c7aa949669d95876699449c547a343 Mon Sep 17 00:00:00 2001 From: Avner May Date: Fri, 10 Apr 2026 10:49:17 -0700 Subject: [PATCH 5/6] Improvements to benchmarking --- bench/bench.py | 9 +- bench/bench_helpers.py | 9 +- bench/bench_paths.py | 10 +- 
bench/run_sglang_bench.py | 213 ++++++++++++++++++++++++++------------ 4 files changed, 172 insertions(+), 69 deletions(-) diff --git a/bench/bench.py b/bench/bench.py index b80f21955..5e013f099 100644 --- a/bench/bench.py +++ b/bench/bench.py @@ -37,7 +37,7 @@ def parse_arguments(): parser.add_argument("--fl", type=int, nargs='+', default=None, help="Fan out list (e.g., --fl 1 3 4 becomes [1, 3, 4])") parser.add_argument("--flh", type=int, nargs='+', default=None, help="Fan out list (e.g., --flh 1 3 4 becomes [1, 3, 4])") parser.add_argument("--flm", type=int, nargs='+', default=None, help="Fan out list miss (e.g., --flm 1 3 4 becomes [1, 3, 4])") - parser.add_argument("--backup", type=str, choices=["jit", "fast"], default="jit", help="Backup strategy (jit or fast)") + parser.add_argument("--backup", type=str, choices=["jit", "force-jit", "fast"], default="jit", help="Backup strategy (jit or fast)") # Memory and batching configuration parser.add_argument("--block_sz", type=int, default=256, help="KV cache block size (see config.py: kvcache_block_size)") @@ -129,7 +129,7 @@ def initialize_wandb(args, run_name): "gpus": args.gpus, "speculative_decoding": args.spec, "async_speculative": getattr(args, 'async', False), - "jit_speculative": args.backup == "jit", + "backup_strategy": args.backup, "k": args.k if args.spec else None, "f": args.f, "fan_out_list": args.flh, @@ -172,8 +172,11 @@ def create_llm_kwargs(args, draft_path): max_num_seqs=args.b, max_model_len=args.max_model_len, sampler_x=args.x, - jit_speculate=(args.backup == "jit"), + jit_speculate=(args.backup == "jit" or args.backup == "force-jit"), + force_jit_speculate=(args.backup == "force-jit"), max_steps=args.max_steps, + communicate_cache_hits=True, + communicate_logits=True, ) if args.flh is not None: diff --git a/bench/bench_helpers.py b/bench/bench_helpers.py index 4079cf3a6..17153ab2a 100644 --- a/bench/bench_helpers.py +++ b/bench/bench_helpers.py @@ -157,6 +157,7 @@ def load_dataset_token_ids( 
return None dataset_file_path = DATASET_PATHS[dataset_name] + print(f"Loading dataset '{dataset_name}' from: {dataset_file_path}") if not os.path.exists(dataset_file_path): print( f"Warning: Dataset file not found at {dataset_file_path}, falling back to random tokens") @@ -172,10 +173,16 @@ def load_dataset_token_ids( data = json.loads(line.strip()) text: str = data["text"] if use_chat_template and hasattr(tokenizer, 'apply_chat_template'): - tokens = tokenizer.apply_chat_template( + result = tokenizer.apply_chat_template( [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": text}], add_generation_prompt=True, ) + text_result = tokenizer.apply_chat_template( + [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": text}], + add_generation_prompt=True, + tokenize=False, + ) + tokens = result.input_ids if hasattr(result, 'input_ids') else result else: tokens = tokenizer.encode(text, add_special_tokens=False) diff --git a/bench/bench_paths.py b/bench/bench_paths.py index 5e2e5ec6a..c4dd72a48 100644 --- a/bench/bench_paths.py +++ b/bench/bench_paths.py @@ -52,6 +52,10 @@ def _required_env(var_name: str, note: str) -> str: "BENCH_LLAMA_1B", f"{HF_CACHE_DIR}/models--meta-llama--Llama-3.2-1B-Instruct", ), + "qwen_8b": os.environ.get( + "BENCH_QWEN_8B", + f"{HF_CACHE_DIR}/models--Qwen--Qwen3-8B", + ), "qwen_32b": os.environ.get( "BENCH_QWEN_32B", f"{HF_CACHE_DIR}/models--Qwen--Qwen3-32B", @@ -62,12 +66,16 @@ def _required_env(var_name: str, note: str) -> str: ), "eagle3_llama_70b": os.environ.get( "BENCH_EAGLE3_LLAMA_70B", - "lmsys/SGLang-EAGLE3-Llama-3.3-70B-Instruct-SpecForge", + f"{HF_CACHE_DIR}/models--lmsys--SGLang-EAGLE3-Llama-3.3-70B-Instruct-SpecForge", ), "eagle3_qwen_32b": os.environ.get( "BENCH_EAGLE3_QWEN_32B", "Zhihu-ai/Zhi-Create-Qwen3-32B-Eagle3", ), + "phoenix2_qwen_8b": os.environ.get( + "BENCH_PHOENIX2_QWEN_8B", + "togethercomputer/phnx2-llama-decagon-4layer-v1.0", + ), } diff 
--git a/bench/run_sglang_bench.py b/bench/run_sglang_bench.py index 2949f8be7..c76a7b2c6 100644 --- a/bench/run_sglang_bench.py +++ b/bench/run_sglang_bench.py @@ -6,7 +6,7 @@ Usage: python run_sglang_bench.py --llama # SD, Llama 70B python run_sglang_bench.py --qwen # SD, Qwen 32B - python run_sglang_bench.py --llama --mode ar # autoregressive baseline + python run_sglang_bench.py --llama --mode AR # autoregressive baseline python run_sglang_bench.py --llama --wandb --name myrun # log to wandb Set model paths via env vars (BENCH_LLAMA_70B, etc.) or edit bench_paths.py. @@ -23,77 +23,37 @@ from bench_paths import MODELS, resolve_snapshot -def get_server_cmd(args): - if args.llama: - target = resolve_snapshot(MODELS["llama_70b"]) - draft = resolve_snapshot(MODELS["llama_1b"]) - else: - target = resolve_snapshot(MODELS["qwen_32b"]) - draft = resolve_snapshot(MODELS["qwen_0.6b"]) - - cmd = [ - sys.executable, "-m", "sglang.launch_server", - "--model-path", target, - "--tp", str(args.tp), - "--mem-fraction-static", str(args.mem_frac), - "--max-running-requests", "1", - "--disable-radix-cache", - "--log-level", "warning", - "--port", str(args.port), - ] - - if args.mode == "sd": - # Speculative decoding with standalone draft model. - # Default: k=5 (num_steps=4, num_draft_tokens=5). - cmd += [ - "--speculative-algorithm", "STANDALONE", - "--speculative-draft-model-path", draft, - "--speculative-num-steps", str(args.num_steps), - "--speculative-eagle-topk", "1", - "--speculative-num-draft-tokens", str(args.num_draft_tokens), - ] - # mode == "ar": no speculative flags, just serve the target model. 
- - return cmd, target - - -def wait_for_server(port, timeout=900, interval=5): - url = f"http://localhost:{port}/health" - deadline = time.time() + timeout - while time.time() < deadline: - try: - if requests.get(url, timeout=2).status_code == 200: - return True - except requests.ConnectionError: - pass - time.sleep(interval) - return False - - -def kill_server(proc): - if proc.poll() is None: - os.killpg(os.getpgid(proc.pid), signal.SIGKILL) - proc.wait() - - def main(): parser = argparse.ArgumentParser(description="Launch SGLang server and benchmark it") parser.add_argument("--llama", action="store_true", default=True) parser.add_argument("--qwen", action="store_true") - parser.add_argument("--mode", choices=["ar", "sd"], default="sd", + parser.add_argument("--mode", choices=["AR", "STANDALONE", "ASYNC_STANDALONE", "EAGLE3", "ASYNC_EAGLE3", "PHOENIX", "ASYNC_PHOENIX"], default="STANDALONE", help="ar = autoregressive, sd = speculative decoding (default)") parser.add_argument("--tp", type=int, default=4) parser.add_argument("--port", type=int, default=40010) - parser.add_argument("--mem_frac", type=float, default=0.70) - parser.add_argument("--num_steps", type=int, default=4, help="draft chain depth (k = num_steps + 1)") - parser.add_argument("--num_draft_tokens", type=int, default=5) + parser.add_argument("--mem-frac", type=float, default=0.70) + parser.add_argument("--num-steps", type=int, default=4, help="draft chain depth (k = num_steps + 1)") + parser.add_argument("--context-length", type=int, default=4096) # Pass-through to eval client parser.add_argument("--numseqs", type=int, default=128) - parser.add_argument("--output_len", type=int, default=512) + parser.add_argument("--output-len", type=int, default=512) parser.add_argument("--temp", type=float, default=0.0) + parser.add_argument("--dataset", type=str, choices=["all", "humaneval", "alpaca", "c4", "ultrafeedback", "random", "example"], default="all") parser.add_argument("--wandb", action="store_true") - 
parser.add_argument("--group", type=str, default=None) + parser.add_argument("--group", type=str, default="ssd") parser.add_argument("--name", type=str, default=None) + + parser.add_argument("--f", type=int, default=4, help="Async fan out value") + parser.add_argument("--fl", type=int, nargs='+', default=None, help="Fan out list (e.g., --fl 1 3 4 becomes [1, 3, 4])") + parser.add_argument("--flh", type=int, nargs='+', default=None, help="Fan out list (e.g., --flh 1 3 4 becomes [1, 3, 4])") + parser.add_argument("--flm", type=int, nargs='+', default=None, help="Fan out list miss (e.g., --flm 1 3 4 becomes [1, 3, 4])") + parser.add_argument("--jit", action="store_true") + parser.add_argument("--force-jit", action="store_true") + parser.add_argument("--communicate-cache-hits", action="store_true") + parser.add_argument("--verbose", action="store_true") + parser.add_argument("--acceptance-rate-log", type=str, default=None, + help="Path to log acceptance rates (sets ACCEPTANCE_RATE_LOG env var for the server)") + args = parser.parse_args() if args.qwen: args.llama = False @@ -107,7 +67,12 @@ def main(): capture_output=True) time.sleep(2) - proc = subprocess.Popen(server_cmd, preexec_fn=os.setsid) + env = os.environ.copy() + if args.acceptance_rate_log: + env["ACCEPTANCE_RATE_LOG"] = args.acceptance_rate_log + print(f"ACCEPTANCE_RATE_LOG={args.acceptance_rate_log}") + + proc = subprocess.Popen(server_cmd, preexec_fn=os.setsid, env=env) try: print("Waiting for server...") if not wait_for_server(args.port): @@ -122,15 +87,16 @@ def main(): "--numseqs", str(args.numseqs), "--output_len", str(args.output_len), "--temp", str(args.temp), - "--all", "--b", "1", + f"--{args.dataset}", + "--b", "1", "--port", str(args.port), ] if args.llama: eval_cmd.append("--llama") else: eval_cmd.append("--qwen") - if args.mode == "sd": - eval_cmd += ["--draft", "1" if args.llama else "0.6"] + if is_eagle3(args.mode): + eval_cmd.append("--eagle") if args.wandb: eval_cmd += ["--wandb"] if 
args.group: @@ -145,5 +111,124 @@ def main(): print("Server stopped") +def is_spec(mode): + return mode in ["STANDALONE", "ASYNC_STANDALONE", "EAGLE3", "ASYNC_EAGLE3", "PHOENIX", "ASYNC_PHOENIX"] + + +def is_async(mode): + return mode in ["ASYNC_STANDALONE", "ASYNC_EAGLE3", "ASYNC_PHOENIX"] + + +def is_standalone(mode): + return mode in ["STANDALONE", "ASYNC_STANDALONE"] + +def is_eagle3(mode): + return mode in ["EAGLE3", "ASYNC_EAGLE3"] + + +def is_phoenix(mode): + return mode in ["PHOENIX", "ASYNC_PHOENIX"] + + +def get_server_cmd(args): + if args.llama: + target = resolve_snapshot(MODELS["llama_70b"]) + if is_standalone(args.mode): + draft = resolve_snapshot(MODELS["llama_1b"]) + + elif is_eagle3(args.mode): + draft = resolve_snapshot(MODELS["eagle3_llama_70b"]) + else: + raise ValueError(f"Unsupported mode for llama: {args.mode}") + else: + target = resolve_snapshot(MODELS["qwen_32b"]) + if is_standalone(args.mode): + draft = resolve_snapshot(MODELS["qwen_0.6b"]) + elif is_eagle3(args.mode): + draft = resolve_snapshot(MODELS["eagle3_qwen_32b"]) + elif is_phoenix(args.mode): + target = resolve_snapshot(MODELS["qwen_8b"]) + draft = resolve_snapshot(MODELS["phoenix2_qwen_8b"]) + else: + raise ValueError(f"Unsupported mode for qwen: {args.mode}") + + cmd = [ + sys.executable, "-m", "sglang.launch_server", + "--model-path", target, + "--tp", str(args.tp), + "--mem-fraction-static", str(args.mem_frac), + "--max-running-requests", "1", + # "--disable-radix-cache", + "--log-level", "warning", + "--port", str(args.port), + "--context-length", str(args.context_length), + ] + + if is_spec(args.mode): + # Speculative decoding with standalone draft model. + # Default: k=5 (num_steps=4, num_draft_tokens=5).
+ cmd += [ + "--speculative-algorithm", args.mode, + "--speculative-draft-model-path", draft, + "--speculative-num-steps", str(args.num_steps), + "--speculative-eagle-topk", "1", + "--speculative-num-draft-tokens", str(args.num_steps + 1), + ] + if is_async(args.mode): + cmd += [ + "--speculative-async-fan-out", str(args.f), + ] + if args.fl: + cmd += [ + "--speculative-async-fan-out-list", ",".join(map(str, args.fl)), + ] + if args.flh: + cmd += [ + "--speculative-async-fan-out-list-hit", ",".join(map(str, args.flh)), + ] + if args.flm: + cmd += [ + "--speculative-async-fan-out-list-miss", ",".join(map(str, args.flm)), + ] + if args.jit or args.force_jit: + cmd += [ + "--speculative-async-jit-speculate", + ] + if args.force_jit: + cmd += [ + "--speculative-async-force-jit-speculate", + ] + if args.communicate_cache_hits: + cmd += [ + "--speculative-async-communicate-cache-hits", + ] + if args.verbose: + cmd += [ + "--speculative-async-verbose", + ] + + # mode == "ar": no speculative flags, just serve the target model. 
+ return cmd, target + + +def wait_for_server(port, timeout=900, interval=5): + url = f"http://localhost:{port}/health" + deadline = time.time() + timeout + while time.time() < deadline: + try: + if requests.get(url, timeout=2).status_code == 200: + return True + except requests.ConnectionError: + pass + time.sleep(interval) + return False + + +def kill_server(proc): + if proc.poll() is None: + os.killpg(os.getpgid(proc.pid), signal.SIGKILL) + proc.wait() + + if __name__ == "__main__": main() From b417d75fba99ae531c1d42f2c3345d949c3ae463 Mon Sep 17 00:00:00 2001 From: Avner May Date: Fri, 10 Apr 2026 14:01:32 -0700 Subject: [PATCH 6/6] NIT: print cache_hits as ints --- ssd/engine/draft_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ssd/engine/draft_runner.py b/ssd/engine/draft_runner.py index c8799be38..5882b5fc7 100644 --- a/ssd/engine/draft_runner.py +++ b/ssd/engine/draft_runner.py @@ -357,7 +357,7 @@ def _service_spec_request(self): for i in range(B): accept_len = cache_keys[i, 1].item() + 1 self._acceptance_lengths.append(accept_len) - self._cache_hits.append(cache_hits[i].item()) + self._cache_hits.append(int(cache_hits[i].item())) speculation_response = SpeculationResponse( speculations=out_tokens.reshape(-1).to(torch.int64),