diff --git a/pyproject.toml b/pyproject.toml index 7c43d4e11..690a519db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,17 +14,18 @@ requires-python = ">=3.11,<3.13" dependencies = [ "torch==2.9.1", "triton", - "transformers==4.57.1", + "transformers>=5.3.0", "xxhash", "numpy", "safetensors", "tqdm", - "flashinfer-python==0.6.6", "sgl-kernel==0.3.21", "nvidia-cutlass-dsl>=4.3.4", "wandb==0.22.0", "hf_transfer", "tiktoken", + # Install from source for now, for latest support on Hopper + "flash-attn-4 @ git+https://github.com/Dao-AILab/flash-attention.git@5301a359f59ef8fa10f211618d9f7a69716a8898#subdirectory=flash_attn/cute", ] [project.urls] diff --git a/ssd/engine/helpers/cudagraph_helpers.py b/ssd/engine/helpers/cudagraph_helpers.py index 6c38eeddf..b2d41887d 100644 --- a/ssd/engine/helpers/cudagraph_helpers.py +++ b/ssd/engine/helpers/cudagraph_helpers.py @@ -1,7 +1,6 @@ import os import math import torch -import numpy as np from ssd.utils.context import set_context, get_context, reset_context from time import perf_counter @@ -122,9 +121,6 @@ def run_decode_cudagraph(model_runner, input_ids, positions, last_only, graph_va return logits -cache = {} - -_plan_event = None # Lazy-init CUDA event for plan() sync PROFILE = os.environ.get("SSD_PROFILE", "0") == "1" PROFILE_DRAFT = os.environ.get("SSD_PROFILE_DRAFT", "0") == "1" _draft_events = [] # [(step, label, start_event, end_event), ...] 
@@ -149,30 +145,23 @@ def flush_draft_profile(): @torch.inference_mode() def run_fi_tree_decode_cudagraph(model_runner, input_ids, positions, last_only, graph_vars, step, cache_hits, hidden_states=None): - # bs != len(input_ids, positions) now in multi-query seting, also need step-dependent mask context = get_context() - assert context.cu_seqlens_q is None, "ERROR in run_fi_tree_decode_cudagraph: cu_seqlens_q should be set to None so we don't take FA path" - K, F = model_runner.config.speculate_k, model_runner.config.async_fan_out - # MQ_LEN = F * (K+1) MQ_LEN = sum(model_runner.config.fan_out_list) orig_flat = input_ids.size(0) assert orig_flat % MQ_LEN == 0, f"ERROR in run_fi_tree_decode_cudagraph: flat_batch_size should be divisible by MQ_LEN, got {orig_flat} and {MQ_LEN}" orig_B = orig_flat // MQ_LEN - # Pick CUDA graph and wrapper bucket + # Pick CUDA graph bucket wrapper_bs = next( x for x in model_runner.graph_bs_list["fi_tree_decode"] if x >= orig_B) graph = model_runner.graphs["fi_tree_decode"][wrapper_bs] - wrapper = model_runner.prefill_wrappers[wrapper_bs] # Prepare padded inputs/context if needed if wrapper_bs > orig_B: - # print(f'PADDING--') pad_B = wrapper_bs - orig_B pad_flat = pad_B * MQ_LEN - # Pad queries (ids/rope positions) pad_ids = torch.zeros( pad_flat, dtype=input_ids.dtype, device=input_ids.device) pad_pos = torch.zeros( @@ -180,13 +169,11 @@ def run_fi_tree_decode_cudagraph(model_runner, input_ids, positions, last_only, input_ids = torch.cat([input_ids, pad_ids], dim=0) positions = torch.cat([positions, pad_pos], dim=0) - # Pad slot_mapping with -1 to skip KV writes for padded queries slot_map = torch.cat( [context.slot_mapping, torch.full((pad_flat,), -1, dtype=context.slot_mapping.dtype, device=context.slot_mapping.device)] ) - # Pad block_tables/context_lens by repeating the last real row bt = context.block_tables cl = context.context_lens pad_bt = bt[orig_B - 1:orig_B].expand(pad_B, -1).contiguous() @@ -194,205 +181,54 @@ def 
run_fi_tree_decode_cudagraph(model_runner, input_ids, positions, last_only, bt = torch.cat([bt, pad_bt], dim=0) cl = torch.cat([cl, pad_cl], dim=0) - # Set padded context for this replay set_context(is_prefill=False, slot_mapping=slot_map, - context_lens=cl, block_tables=bt) + context_lens=cl, block_tables=bt, + tree_cu_seqlens_q=graph_vars["tree_cu_seqlens_q"][wrapper_bs], + tree_mask_bias=graph_vars["tree_mask_bias"]) block_tables = bt context_lens = cl - flat_batch_size = input_ids.size(0) # == wrapper_bs * MQ_LEN + flat_batch_size = input_ids.size(0) B = wrapper_bs else: block_tables = context.block_tables context_lens = context.context_lens flat_batch_size = orig_flat B = orig_B - - if PROFILE: - torch.cuda.synchronize() - start_time = torch.cuda.Event(enable_timing=True) - end_time = torch.cuda.Event(enable_timing=True) - start_time.record() + # Set tree decode metadata on context for FA4 + context.tree_cu_seqlens_q = graph_vars["tree_cu_seqlens_q"][wrapper_bs] + context.tree_mask_bias = graph_vars["tree_mask_bias"] # in the case where we pad, we'll need cache_hits.shape[0] to match the padded batch size if cache_hits.shape[0] < B: cache_hits = torch.cat([cache_hits, torch.zeros(B - cache_hits.shape[0], device=cache_hits.device)]) - # PERFORMANCE: Step 0 -- precompute KV page metadata on CPU for all K steps. - # CPU tensors let plan() skip its internal .to("cpu") GPU->CPU syncs. - # For B<=8, CPU slicing also avoids GPU boolean indexing. 
- if step == 0: - cache["cu_seqlens_q_cpu"] = torch.arange(B + 1, dtype=torch.int32) * MQ_LEN - context_lens_list = context_lens.tolist() - cache["block_tables"] = block_tables - block_size = model_runner.block_size - cache["precomputed_kv"] = [] - cache["plan_cpu_args"] = [] - - if B <= 8: - # PERFORMANCE: CPU-only kv_indices via slicing (no GPU boolean indexing) - for s in range(K): - step_cls = [int(cl) + s * MQ_LEN for cl in context_lens_list] - step_counts = [(cl + block_size - 1) // block_size for cl in step_cls] - if B == 1: - kv_indices_s = block_tables[0, :step_counts[0]] - else: - kv_indices_s = torch.cat([block_tables[b, :step_counts[b]] for b in range(B)]) - cache["precomputed_kv"].append(kv_indices_s) - kv_indptr_cpu = torch.zeros(B + 1, dtype=torch.int32) - kv_indptr_cpu[1:] = torch.tensor(step_counts, dtype=torch.int32).cumsum(0) - kv_lpl_cpu = torch.tensor( - [cl % block_size if cl % block_size != 0 else block_size for cl in step_cls], - dtype=torch.int32) - cache["plan_cpu_args"].append((kv_indptr_cpu, kv_lpl_cpu)) - else: - # Large batch: GPU boolean indexing for kv_indices, CPU tensors for plan args - bt_upcast = torch.arange(block_tables.size(1), device=block_tables.device)[None, :] - step_offsets = torch.arange(K + 2, device=context_lens.device) * MQ_LEN - all_step_cls = context_lens.unsqueeze(1) + step_offsets.unsqueeze(0) - all_counts = (all_step_cls + block_size - 1) // block_size - all_masks = bt_upcast.unsqueeze(1) < all_counts.unsqueeze(2) - for s in range(K): - cache["precomputed_kv"].append(block_tables[all_masks[:, s, :]]) - step_cls = [int(cl) + s * MQ_LEN for cl in context_lens_list] - step_counts = [(cl + block_size - 1) // block_size for cl in step_cls] - kv_indptr_cpu = torch.zeros(B + 1, dtype=torch.int32) - kv_indptr_cpu[1:] = torch.tensor(step_counts, dtype=torch.int32).cumsum(0) - kv_lpl_cpu = torch.tensor( - [cl % block_size if cl % block_size != 0 else block_size for cl in step_cls], - dtype=torch.int32) - 
cache["plan_cpu_args"].append((kv_indptr_cpu, kv_lpl_cpu)) - - # CPU mask precompute: build all K packed masks using numpy at step 0. - # Eliminates per-step get_custom_mask (GPU) + segment_packbits + GPU->CPU syncs. - cache_hits_list = cache_hits[:B].tolist() - - if "glue_hit_np" not in cache: - _fol = model_runner.config.fan_out_list - _fol_miss = model_runner.config.fan_out_list_miss - _tril = np.tril(np.ones((K + 1, K + 1), dtype=np.uint8)) - cache["glue_hit_np"] = np.repeat(_tril, _fol, axis=0) - cache["glue_miss_np"] = np.repeat(_tril, _fol_miss, axis=0) - - _glue_hit = cache["glue_hit_np"] - _glue_miss = cache["glue_miss_np"] - _rows_np = np.arange(MQ_LEN) - - cache["cpu_packed_masks"] = [] - cache["cpu_packed_indptrs"] = [] - - for s in range(K): - ttl_added_s = (s + 1) * MQ_LEN + (K + 1) - packed_segs = [] - seg_packed_sizes = [] - - for b in range(B): - cols_b = int(context_lens_list[b]) + s * MQ_LEN - prefix_len_b = cols_b - ttl_added_s - - mask_b = np.zeros((MQ_LEN, cols_b), dtype=np.uint8) - mask_b[:, :prefix_len_b] = 1 - glue = _glue_hit if int(cache_hits_list[b]) == 1 else _glue_miss - mask_b[:, prefix_len_b:prefix_len_b + K + 1] = glue - diag_start = prefix_len_b + K + 1 - for blk in range(s + 1): - mask_b[_rows_np, diag_start + blk * MQ_LEN + _rows_np] = 1 - - packed = np.packbits(mask_b.ravel(), bitorder='little') - packed_segs.append(packed) - seg_packed_sizes.append(len(packed)) - - full_packed = np.concatenate(packed_segs) if B > 1 else packed_segs[0] - indptr = np.zeros(B + 1, dtype=np.int32) - indptr[1:] = np.cumsum(seg_packed_sizes) - - cache["cpu_packed_masks"].append( - torch.from_numpy(full_packed.copy()).to(model_runner.device, non_blocking=True)) - cache["cpu_packed_indptrs"].append( - torch.from_numpy(indptr.copy()).to(model_runner.device, non_blocking=True)) - - # Pre-transfer KV metadata to GPU (eliminates per-step pageable H2D transfers) - cache["qo_indptr_gpu"] = cache["cu_seqlens_q_cpu"].to(model_runner.device, non_blocking=True) 
- cache["kv_indptr_gpu"] = [] - cache["kv_lpl_gpu"] = [] - cache["kv_lens_gpu"] = [] - for s in range(K): - ki, kl = cache["plan_cpu_args"][s] - cache["kv_indptr_gpu"].append(ki.to(model_runner.device, non_blocking=True)) - cache["kv_lpl_gpu"].append(kl.to(model_runner.device, non_blocking=True)) - kv_lens = ((ki[1:] - ki[:-1] - 1) * model_runner.block_size + kl).to(torch.int32) - cache["kv_lens_gpu"].append(kv_lens.to(model_runner.device, non_blocking=True)) - if PROFILE: - end_time.record() torch.cuda.synchronize() - precompute_time = start_time.elapsed_time(end_time) + start_time = torch.cuda.Event(enable_timing=True) + end_time = torch.cuda.Event(enable_timing=True) start_time.record() - # Use precomputed CPU-packed masks (built at step 0) - if PROFILE_DRAFT: - _ev_mask0 = torch.cuda.Event(enable_timing=True); _ev_mask0.record() - - kv_indices = cache["precomputed_kv"][step] - kv_indptr_cpu, kv_lpl_cpu = cache["plan_cpu_args"][step] - qo_indptr_cpu = cache["cu_seqlens_q_cpu"] - - packed_mask = cache["cpu_packed_masks"][step] - packed_indptr = cache["cpu_packed_indptrs"][step] - wrapper._custom_mask_buf[:len(packed_mask)].copy_(packed_mask, non_blocking=True) - wrapper._mask_indptr_buf.copy_(packed_indptr, non_blocking=True) - - # GPU-to-GPU copies from pre-transferred tensors (no pageable H2D) - wrapper._qo_indptr_buf.copy_(cache["qo_indptr_gpu"], non_blocking=True) - wrapper._paged_kv_indptr_buf.copy_(cache["kv_indptr_gpu"][step], non_blocking=True) - wrapper._paged_kv_last_page_len_buf.copy_(cache["kv_lpl_gpu"][step], non_blocking=True) - wrapper._paged_kv_indices_buf[:len(kv_indices)].copy_(kv_indices, non_blocking=True) - - total_num_rows = int(qo_indptr_cpu[-1].item()) - wrapper._kv_lens_buffer[:len(kv_indptr_cpu) - 1].copy_(cache["kv_lens_gpu"][step], non_blocking=True) - - # Event-based sync: only wait for this stream's copies, not all CUDA streams. 
- global _plan_event - if _plan_event is None: - _plan_event = torch.cuda.Event() - _plan_event.record() - _plan_event.synchronize() - - if PROFILE_DRAFT: - _ev_plan0 = torch.cuda.Event(enable_timing=True); _ev_plan0.record() - - plan_args = [ - wrapper._float_workspace_buffer, wrapper._int_workspace_buffer, - wrapper._pin_memory_int_workspace_buffer, - qo_indptr_cpu, kv_indptr_cpu, cache["kv_lens_gpu"][step], - wrapper._max_total_num_rows or total_num_rows, - B, model_runner.hf_config.num_attention_heads, - model_runner.hf_config.num_key_value_heads, - model_runner.block_size, wrapper.is_cuda_graph_enabled, - model_runner.hf_config.head_dim, model_runner.hf_config.head_dim, - False, -1, - ] - if wrapper._backend == "fa2": - plan_args.extend([-1, False, 0]) # fixed_split_size, disable_split_kv, num_colocated_ctas - wrapper._plan_info = wrapper._cached_module.plan(*plan_args) - - if PROFILE_DRAFT: - _ev_plan1 = torch.cuda.Event(enable_timing=True); _ev_plan1.record() - - if PROFILE: - end_time.record() - torch.cuda.synchronize() - plan_time = start_time.elapsed_time(end_time) - start_time.record() + # Build tree mask bias for this step and copy into pre-allocated buffer + from ssd.layers.tree_mask import build_tree_mask_bias + K = model_runner.config.speculate_k + mask_bias = build_tree_mask_bias( + context_lens, step=step, K=K, MQ_LEN=MQ_LEN, + fan_out_list=model_runner.config.fan_out_list, + fan_out_list_miss=model_runner.config.fan_out_list_miss, + cache_hits=cache_hits, + max_kv_stride=model_runner.config.max_model_len, + device=model_runner.device, + ) + graph_vars["tree_mask_bias"][:len(mask_bias)] = mask_bias - # Copy inputs/context into graph buffers for padded size + # Copy inputs/context into graph buffers graph_vars["input_ids"][:flat_batch_size] = input_ids graph_vars["positions"][:flat_batch_size] = positions graph_vars["slot_mapping"][:flat_batch_size] = get_context().slot_mapping graph_vars["context_lens"][:B] = context_lens if hidden_states is not 
None and "hidden_states" in graph_vars: if hidden_states.shape[0] < flat_batch_size: - # Pad hidden_states to match padded batch pad_n = flat_batch_size - hidden_states.shape[0] hidden_states = torch.cat([hidden_states, torch.zeros(pad_n, hidden_states.shape[1], dtype=hidden_states.dtype, device=hidden_states.device)]) graph_vars["hidden_states"][:flat_batch_size] = hidden_states @@ -412,8 +248,6 @@ def run_fi_tree_decode_cudagraph(model_runner, input_ids, positions, last_only, if PROFILE_DRAFT: _ev_replay1 = torch.cuda.Event(enable_timing=True); _ev_replay1.record() - _draft_events.append((step, "mask+buf", _ev_mask0, _ev_plan0)) - _draft_events.append((step, "plan", _ev_plan0, _ev_plan1)) _draft_events.append((step, "replay", _ev_replay0, _ev_replay1)) if PROFILE: @@ -421,14 +255,12 @@ def run_fi_tree_decode_cudagraph(model_runner, input_ids, positions, last_only, torch.cuda.synchronize() replay_time = start_time.elapsed_time(end_time) - # Extract logits from graph_vars instead of computing them separately logits_all = graph_vars["logits"][:flat_batch_size] if PROFILE: - print(f"[cuda_graph_helpers.run_fi_tree_decode_cudagraph] step {step}: precompute={precompute_time:.3f}ms, plan={plan_time:.3f}ms, buffer={buffer_prep_time:.3f}ms, replay={replay_time:.3f}ms", flush=True) + print(f"[cuda_graph_helpers.run_fi_tree_decode_cudagraph] step {step}: buffer={buffer_prep_time:.3f}ms, replay={replay_time:.3f}ms", flush=True) logits_out = logits_all[:orig_flat] - # EAGLE draft: also return prenorm (outputs) for self-conditioning if "hidden_states" in graph_vars: prenorm = graph_vars["outputs"][:orig_flat] return logits_out, prenorm @@ -782,8 +614,6 @@ def capture_fi_tree_decode_cudagraph(model_runner): config = model_runner.config hf_config = config.hf_config max_bs = min(model_runner.config.max_num_seqs, 512) - K, F = model_runner.config.speculate_k, model_runner.config.async_fan_out - # MQ_LEN = F * (K+1) MQ_LEN = sum(model_runner.config.fan_out_list) max_flat_batch_size 
= max_bs * MQ_LEN @@ -792,12 +622,11 @@ def capture_fi_tree_decode_cudagraph(model_runner): input_ids = torch.zeros(max_flat_batch_size, dtype=torch.int64, device=model_runner.device) positions = torch.zeros(max_flat_batch_size, dtype=torch.int64, device=model_runner.device) slot_mapping = torch.zeros(max_flat_batch_size, dtype=torch.int32, device=model_runner.device) - context_lens = torch.full((max_bs,), config.max_model_len, dtype=torch.int32, device=model_runner.device) # make sure these are consistent with our dummy example + context_lens = torch.full((max_bs,), config.max_model_len, dtype=torch.int32, device=model_runner.device) block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32, device=model_runner.device) outputs = torch.empty(max_flat_batch_size, hf_config.hidden_size, device=model_runner.device) logits = torch.empty(max_flat_batch_size, hf_config.vocab_size, device=model_runner.device) - # Create graph_bs_list to match what will be used in cudagraph_helpers.py graph_bs_list = [1] for bs in [2, 4, 8] + list(range(16, max_bs + 1, 16)): if bs <= max_bs: @@ -809,60 +638,35 @@ def capture_fi_tree_decode_cudagraph(model_runner): graphs = {} graph_pool = None - # Eagle draft needs hidden_states for forward (d_model_draft, NOT 3*d_model_target) - # All callers project target acts via fc() BEFORE passing to CG - # MUST be outside the for-loop so all graphs share the same tensor fi_hidden_states = None if config.use_eagle and model_runner.is_draft: fi_hidden_states = torch.zeros(max_flat_batch_size, hf_config.hidden_size, dtype=hf_config.torch_dtype, device=model_runner.device) - print(f'[cuda_graph_helpers.capture_fi_tree_decode_cudagraph] About to capture FI cudagraphs for bs={graph_bs_list}', flush=True) + # Pre-allocate tree_cu_seqlens_q per batch size bucket (constant values, used by FA4) + tree_cu_seqlens_q_dict = {} + for bs in graph_bs_list: + tree_cu_seqlens_q_dict[bs] = torch.arange( + bs + 1, dtype=torch.int32, 
device=model_runner.device) * MQ_LEN - for bs in reversed(graph_bs_list): - graph = torch.cuda.CUDAGraph() + # Pre-allocate tree mask bias at max size (shared across all batch sizes, updated before replay) + tree_mask_bias = torch.zeros( + max_flat_batch_size * config.max_model_len, + dtype=torch.float32, device=model_runner.device) - # Build a self-consistent fake plan for capture: - # - q_len = MQ_LEN for each request - # - k_len = max_model_len for each request (use maximum context length) + print(f'[cuda_graph_helpers.capture_fi_tree_decode_cudagraph] About to capture FA4 tree decode cudagraphs for bs={graph_bs_list}', flush=True) - cu_seqlens_q = torch.arange( - bs + 1, dtype=torch.int32, device=model_runner.device) * MQ_LEN - # Use max_num_blocks pages per request for maximum context length - kv_indptr = torch.arange( - bs + 1, dtype=torch.int32, device=model_runner.device) * max_num_blocks - kv_indices = torch.zeros(int( - kv_indptr[-1].item()), dtype=torch.int32, device=model_runner.device) # page ids (dummy) - # Last page length for max model len context - last_page_len = config.max_model_len % model_runner.block_size - if last_page_len == 0: - last_page_len = model_runner.block_size - kv_last_page_len = torch.full( - (bs,), last_page_len, dtype=torch.int32, device=model_runner.device) - custom_mask = torch.ones(bs * MQ_LEN * config.max_model_len, - dtype=torch.bool, device=model_runner.device) - - # Set the fi_tensors buffers with our fake data - model_runner.prefill_wrappers[bs].plan( - cu_seqlens_q, - kv_indptr, - kv_indices, - kv_last_page_len, - hf_config.num_attention_heads, - hf_config.num_key_value_heads, - hf_config.head_dim, - model_runner.block_size, - custom_mask=custom_mask, - q_data_type=hf_config.torch_dtype, - kv_data_type=hf_config.torch_dtype, - ) + for bs in reversed(graph_bs_list): + graph = torch.cuda.CUDAGraph() - # Set minimal context needed for run + # Set context with FA4 metadata set_context( is_prefill=False, 
slot_mapping=slot_mapping[:bs * MQ_LEN], context_lens=context_lens[:bs], - block_tables=block_tables[:bs] + block_tables=block_tables[:bs], + tree_cu_seqlens_q=tree_cu_seqlens_q_dict[bs], + tree_mask_bias=tree_mask_bias, ) # Warmup run @@ -898,6 +702,8 @@ def capture_fi_tree_decode_cudagraph(model_runner): context_lens=context_lens, outputs=outputs, logits=logits, + tree_cu_seqlens_q=tree_cu_seqlens_q_dict, + tree_mask_bias=tree_mask_bias, ) if fi_hidden_states is not None: graph_vars["hidden_states"] = fi_hidden_states diff --git a/ssd/engine/helpers/runner_helpers.py b/ssd/engine/helpers/runner_helpers.py index c818311ce..aaad1d89d 100644 --- a/ssd/engine/helpers/runner_helpers.py +++ b/ssd/engine/helpers/runner_helpers.py @@ -27,6 +27,8 @@ def _dump_ts(): print(f"[{_ts()}] BANANA: Dumping tensors to {DUMP_TENSORS_DIR}") os.makedirs(DUMP_TENSORS_DIR, exist_ok=True) DUMP_TENSORS = True +else: + DUMP_TENSORS = False def list_to_str(lst: list[float] | list[list[float]], num_decimals: int = 4) -> str: assert len(lst) > 0 diff --git a/ssd/engine/model_runner.py b/ssd/engine/model_runner.py index e601ab45d..6c1223d5a 100644 --- a/ssd/engine/model_runner.py +++ b/ssd/engine/model_runner.py @@ -8,7 +8,6 @@ from multiprocessing.shared_memory import SharedMemory from transformers import AutoTokenizer, AutoConfig import os -import flashinfer from ssd.config import Config from ssd.engine.sequence import Sequence from ssd.models.qwen3 import Qwen3ForCausalLM @@ -35,7 +34,6 @@ capture_fi_tree_decode_cudagraph, capture_glue_decode_cudagraph, ) -from ssd.engine.helpers.mask_helpers import get_custom_mask NCCL_LOG = os.environ.get("SSD_NCCL_LOG", "0") == "1" @@ -61,7 +59,7 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event], is_dra self.hf_config = config.hf_config if not is_draft else config.draft_hf_config self.block_size = config.kvcache_block_size self.enforce_eager = config.enforce_eager - self.tokenizer = 
AutoTokenizer.from_pretrained(config.tokenizer_path if config.tokenizer_path else config.model, use_fast=True) + self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path if config.tokenizer_path else config.model, use_fast=True, trust_remote_code=True) self.max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size assert self.hf_config is not None, "ERROR in ModelRunner: hf_config is None" # this implies boundedness to the end @@ -98,11 +96,7 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event], is_dra self.device = torch.device(f'cuda:{self.rank}') self._cmd = torch.empty(1, dtype=torch.int64, device=self.device) - - # cudagraph logic for FlashInfer kernels, need diff wrapper for each batch size we make a graph for - if is_draft and config.draft_async: - self._init_flashinfer_wrappers() - + if self.verbose: print(f'INSIDE MODEL RUNNER INIT, DRAFT={is_draft}', flush=True) self.tp_pg = None @@ -167,56 +161,6 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event], is_dra if self.verbose: print(f'-----{model_type}MODEL RUNNER INITIALIZED----', flush=True) - def _init_flashinfer_wrappers(self): - """Initialize FlashInfer wrappers for draft async mode.""" - self.workspace_buffer = torch.zeros( - 768 * 1024 * 1024, dtype=torch.uint8, device=f"cuda:{self.rank}") - - if self.config.enforce_eager: - self.only_prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD") - else: - max_bs = min(self.config.max_num_seqs, 512) - max_num_blocks = (self.config.max_model_len + self.block_size - 1) // self.block_size - - # FlashInfer kernel tensors - # pages_for_max_len = (self.config.max_model_len + self.block_size - 1) // self.block_size - last_page_len_max_len = self.config.max_model_len % self.block_size - last_page_len_max_len = self.block_size if last_page_len_max_len == 0 else last_page_len_max_len - MQ_LEN = self.config.async_fan_out * (self.config.speculate_k + 
1) - - cu_seqlens_q = torch.empty(max_bs + 1, dtype=torch.int32, device=self.device) - kv_indptr = torch.empty(max_bs + 1, dtype=torch.int32, device=self.device) - kv_indices = torch.empty(max_bs * max_num_blocks, dtype=torch.int32, device=self.device) - kv_last_page_len = torch.empty(max_bs, dtype=torch.int32, device=self.device) - custom_mask_buf = torch.empty(max_bs * MQ_LEN * self.config.max_model_len, dtype=torch.uint8, device=self.device) - mask_indptr_buf = torch.empty(max_bs + 1, dtype=torch.int32, device=self.device) - - # Create graph_bs_list to match what will be used in cudagraph_helpers.py - graph_bs_list = [1] - for bs in [2, 4, 8] + list(range(16, max_bs + 1, 16)): - if bs <= max_bs: - graph_bs_list.append(bs) - if max_bs not in graph_bs_list: - graph_bs_list.append(max_bs) - graph_bs_list.sort() - - # Create a dict of wrappers, one for each bs we will touch in cudagraph_helpers.py - self.prefill_wrappers = {} - print(f'[model_runner about to wrapper.init()] graph_bs_list={graph_bs_list}', flush=True) - for bs in graph_bs_list: - self.prefill_wrappers[bs] = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - self.workspace_buffer, "NHD", - use_cuda_graph=True, - qo_indptr_buf=cu_seqlens_q[:bs + 1], - paged_kv_indptr_buf=kv_indptr[:bs + 1], - paged_kv_indices_buf=kv_indices[:bs * max_num_blocks], - paged_kv_last_page_len_buf=kv_last_page_len[:bs], - custom_mask_buf=custom_mask_buf[:bs * MQ_LEN * self.config.max_model_len], - mask_indptr_buf=mask_indptr_buf[:bs + 1], - ) - print(f'wrapper backend is {self.prefill_wrappers[bs]._backend}', flush=True) - - def setup_and_warmup_model_and_cudagraphs(self, config: Config, hf_config: AutoConfig, init_q=None, is_draft=False): # cudagraphs self.graph_vars = {} @@ -305,15 +249,14 @@ def setup_and_warmup_model_and_cudagraphs(self, config: Config, hf_config: AutoC assert sum(config.fan_out_list) == sum(config.fan_out_list_miss) == config.async_fan_out * (config.speculate_k + 1), "ERROR in ModelRunner: fancy sampling 
only supported for constant fan out for now." self.sampler = Sampler(sampler_x=config.sampler_x, async_fan_out=config.async_fan_out) - if self.verbose: - print(f'-----WARMING UP {model_type}MODEL----', flush=True) + print(f'[model_runner] Warming up {model_type}model...', flush=True) self.warmup_model() - if self.verbose: - print(f'-----ALLOCATING {model_type}KV CACHE----', flush=True) + print(f'[model_runner] Allocating {model_type}KV cache...', flush=True) self.allocate_kv_cache() if not self.enforce_eager: - # if not self.is_draft or (self.is_draft and self.config.draft_async and self.config.speculate): + print(f'[model_runner] Capturing CUDA graphs for {model_type}model...', flush=True) + # if not self.is_draft or (self.is_draft and self.config.draft_async and self.config.speculate): decode_graph_vars, decode_graph_pool, decode_graphs, decode_graph_bs_list = capture_cudagraph(self) # decode cudagraph, draft needs in spec and target in normal self.graph_vars["decode"] = decode_graph_vars self.graph_pools["decode"] = decode_graph_pool @@ -338,6 +281,7 @@ def setup_and_warmup_model_and_cudagraphs(self, config: Config, hf_config: AutoC self.graphs["glue_decode"] = glue_graphs self.graph_bs_list["glue_decode"] = glue_bs_list + print(f'[model_runner] {model_type}model initialization complete.', flush=True) if init_q is not None: # Signal the scheduler that we're fully initialized (model loaded, # KV cache allocated, CUDA graphs captured). 
Must happen after @@ -555,15 +499,21 @@ def allocate_kv_cache(self): ) print(f"allocate_kv_cache(): kv_cache shape = {self.kv_cache.shape}", flush=True) + + # Create tree_score_mod once (shared across all attention layers) + tree_score_mod = None + if self.is_draft and self.draft_async: + from ssd.layers.tree_mask import create_tree_score_mod + tree_score_mod = create_tree_score_mod(config.max_model_len) + layer_id = 0 for module in self.model.modules(): if hasattr(module, "k_cache") and hasattr(module, "v_cache"): module.k_cache = self.kv_cache[0, layer_id] module.v_cache = self.kv_cache[1, layer_id] - if self.is_draft and self.draft_async and not self.enforce_eager: - module.prefill_wrappers = self.prefill_wrappers - elif self.is_draft and self.draft_async and self.enforce_eager: - module.only_prefill_wrapper = self.only_prefill_wrapper # this will make it not None so it can be used on fwd + if self.is_draft and self.draft_async: + module.max_seqlen_k = config.max_model_len + module.tree_score_mod = tree_score_mod layer_id += 1 @@ -614,45 +564,21 @@ def prepare_sample(self, seqs: list[Sequence]): return temperatures def eager_tree_decode_plan(self, input_ids, positions, step, cache_hits): - """Plan FlashInfer for tree decode in eager mode""" + """Set up context metadata for FA4 tree decode in eager mode.""" assert self.is_draft and self.config.draft_async, "ERROR in eager_tree_decode_plan: not a draft async model" + from ssd.layers.tree_mask import build_tree_mask_bias context = get_context() - - K, F = self.config.speculate_k, self.config.async_fan_out - # MQ_LEN = F * (K+1) + K = self.config.speculate_k MQ_LEN = self.config.MQ_LEN - flat_batch_size = input_ids.size(0) - B = flat_batch_size // MQ_LEN # [N] tokens = B * sum(fan_out_list) - - # Convert block_tables to FlashInfer format - block_tables = context.block_tables # [B, M] - context_lens = context.context_lens # [B] - - counts = (context_lens + self.block_size - 1) // self.block_size # [B] - kv_indptr = 
torch.cat([torch.tensor([0], device=block_tables.device), - counts.cumsum(dim=0)]).to(torch.int32) - mask = torch.arange(block_tables.size(1), device=block_tables.device)[None, :] < counts[:, None] - kv_indices = block_tables[mask] # flattened page ids - - # Last-page actual token count per request - kv_last_page_len = (context_lens % self.block_size) - kv_last_page_len[kv_last_page_len == 0] = self.block_size - kv_last_page_len = kv_last_page_len.to(torch.int32) - cu_seqlens_q = torch.arange(B + 1, device=self.device, dtype=torch.int32) * MQ_LEN # assumes same MQ_LEN across batch dimension - custom_mask = get_custom_mask(self.config, context_lens, step, K, F, B, device=self.device, cache_hits=cache_hits) - - self.only_prefill_wrapper.plan( - cu_seqlens_q, - kv_indptr, - kv_indices, - kv_last_page_len, - self.hf_config.num_attention_heads, - self.hf_config.num_key_value_heads, - self.hf_config.head_dim, - self.block_size, - custom_mask=custom_mask, - q_data_type=self.hf_config.torch_dtype, - kv_data_type=self.hf_config.torch_dtype, + B = input_ids.size(0) // MQ_LEN + context.tree_cu_seqlens_q = torch.arange(B + 1, device=self.device, dtype=torch.int32) * MQ_LEN + context.tree_mask_bias = build_tree_mask_bias( + context.context_lens, step=step, K=K, MQ_LEN=MQ_LEN, + fan_out_list=self.config.fan_out_list, + fan_out_list_miss=self.config.fan_out_list_miss, + cache_hits=cache_hits, + max_kv_stride=self.config.max_model_len, + device=self.device, ) @torch.inference_mode() diff --git a/ssd/layers/attention.py b/ssd/layers/attention.py index ed5ec7b3a..6b1f61c7c 100644 --- a/ssd/layers/attention.py +++ b/ssd/layers/attention.py @@ -3,7 +3,8 @@ import triton import triton.language as tl -from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache +from flash_attn.cute.interface import flash_attn_varlen_func as fa4_varlen_func +from ssd.layers.tree_mask import create_tree_score_mod from ssd.utils.context import get_context @@ -65,10 +66,10 @@ def 
__init__( self.speculate = speculate self.draft_async = draft_async self.use_eagle = use_eagle - self.prefill_wrappers = {} self.F = F # async_fan_out self.K = K # speculate_k - self.only_prefill_wrapper = None + self.max_seqlen_k = 0 # set during KV cache allocation to config.max_model_len + self.tree_score_mod = None # set during KV cache allocation def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): o: torch.Tensor @@ -87,7 +88,7 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): k, v = k_cache, v_cache k, v = k.view(-1, self.num_kv_heads, self.head_dim), v.view(-1, self.num_kv_heads, self.head_dim) - o = flash_attn_varlen_func(q, k, v, + o, _ = fa4_varlen_func(q, k, v, max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q, max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k, softmax_scale=self.scale, causal=True) @@ -104,29 +105,45 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): if verify_or_glue: assert context.context_lens is not None - o = flash_attn_with_kvcache(q, k_cache, v_cache, - cache_seqlens=context.context_lens, page_table=context.block_tables, + o, _ = fa4_varlen_func(q, k_cache, v_cache, + cu_seqlens_q=context.cu_seqlens_q, + cu_seqlens_k=None, + max_seqlen_q=context.max_seqlen_q, + max_seqlen_k=self.max_seqlen_k, + seqused_k=context.context_lens, + page_table=context.block_tables, softmax_scale=self.scale, causal=True, - cu_seqlens_q=context.cu_seqlens_q, max_seqlen_q=context.max_seqlen_q, ) elif tree_decode: - if self.only_prefill_wrapper is not None: - prefill_wrapper = self.only_prefill_wrapper - else: - mq_len = self.F * (self.K+1) - bs = q.shape[0] // mq_len - wrapper_bs = None - for available_bs in sorted(self.prefill_wrappers.keys()): - if available_bs >= bs: - wrapper_bs = available_bs - break - prefill_wrapper = self.prefill_wrappers[wrapper_bs] - o = prefill_wrapper.run(q, (self.k_cache, self.v_cache)) + score_mod_kwargs = {} + if 
self.tree_score_mod is not None and context.tree_mask_bias is not None: + score_mod_kwargs["score_mod"] = self.tree_score_mod + score_mod_kwargs["aux_tensors"] = [context.tree_mask_bias] + o, _ = fa4_varlen_func( + q, + self.k_cache, + self.v_cache, + cu_seqlens_q=context.tree_cu_seqlens_q, + cu_seqlens_k=None, + max_seqlen_q=self.F * (self.K + 1), + max_seqlen_k=self.max_seqlen_k, + seqused_k=context.context_lens, + page_table=context.block_tables, + softmax_scale=self.scale, + causal=False, + **score_mod_kwargs, + ) else: # single query decode - q = q.unsqueeze(1) - o = flash_attn_with_kvcache(q, k_cache, v_cache, - cache_seqlens=context.context_lens, page_table=context.block_tables, + batch_size = context.context_lens.shape[0] + cu_seqlens_q = torch.arange(0, batch_size + 1, dtype=torch.int32, device=q.device) + o, _ = fa4_varlen_func(q, k_cache, v_cache, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=None, + max_seqlen_q=1, + max_seqlen_k=self.max_seqlen_k, + seqused_k=context.context_lens, + page_table=context.block_tables, softmax_scale=self.scale, causal=True, ) diff --git a/ssd/layers/tree_mask.py b/ssd/layers/tree_mask.py new file mode 100644 index 000000000..d44a7ec14 --- /dev/null +++ b/ssd/layers/tree_mask.py @@ -0,0 +1,100 @@ +"""Tree decode mask for FA4 via score_mod + aux_tensors. + +The tree mask is stored as a dense float32 bias tensor of shape +(max_total_q, max_kv_stride), flattened to 1D. Unmasked positions have +value 0.0; masked positions have a large negative value (-1e6). + +score_mod adds the bias to each attention score, effectively masking out +positions where the bias is -1e6. +""" + +import torch +import numpy as np +import cutlass +import cutlass.cute as cute + +# Large negative value used to mask attention scores. +_MASK_VAL = -1.0e6 + + +def create_tree_score_mod(max_kv_stride: int): + """Return a @cute.jit score_mod that reads a mask bias from aux_tensors[0]. 
+ + The aux_tensor is a 1D float32 tensor indexed by: + (offset_q + q_idx) * max_kv_stride + kv_idx + + where offset_q comes from seqlen_info for varlen sequences. + """ + + @cute.jit + def tree_score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + mask_bias = aux_tensors[0] + dtype = mask_bias.element_type + global_q = seqlen_info.offset_q + q_idx + flat_idx = global_q * max_kv_stride + kv_idx + idx_frag = cute.make_rmem_tensor(1, cutlass.Int32) + idx_frag.store(flat_idx) + val_frag = cute.make_rmem_tensor(1, dtype) + val_frag[0] = mask_bias[idx_frag[0]] + bias = (val_frag.load()).to(cutlass.Float32) + return tSrS_ssa + bias + + return tree_score_mod + + +def build_tree_mask_bias( + context_lens: torch.Tensor, + step: int, + K: int, + MQ_LEN: int, + fan_out_list: list[int], + fan_out_list_miss: list[int], + cache_hits: torch.Tensor, + max_kv_stride: int, + device: torch.device, +) -> torch.Tensor: + """Build the dense mask bias tensor for one tree decode step. + + Returns a 1D float32 tensor of shape (B * MQ_LEN * max_kv_stride,) + with 0.0 for attend and _MASK_VAL for masked positions. 
+ """ + B = context_lens.shape[0] + context_lens_list = context_lens.tolist() + cache_hits_list = cache_hits[:B].tolist() + + # Pre-compute glue patterns + tril = np.tril(np.ones((K + 1, K + 1), dtype=np.float32)) + fol = np.array(fan_out_list) + fol_miss = np.array(fan_out_list_miss) + glue_hit = np.repeat(tril, fol, axis=0) # (MQ_LEN, K+1) + glue_miss = np.repeat(tril, fol_miss, axis=0) + + ttl_added = (step + 1) * MQ_LEN + (K + 1) + rows = np.arange(MQ_LEN) + + # Build mask as numpy, then convert + bias = np.full((B * MQ_LEN, max_kv_stride), _MASK_VAL, dtype=np.float32) + + for b in range(B): + cols_b = int(context_lens_list[b]) + prefix_len_b = cols_b - ttl_added + row_offset = b * MQ_LEN + + # Prefix: attend to all + if prefix_len_b > 0: + bias[row_offset:row_offset + MQ_LEN, :prefix_len_b] = 0.0 + + # Glue pattern + glue = glue_hit if int(cache_hits_list[b]) == 1 else glue_miss + glue_start = prefix_len_b + glue_bias = np.where(glue > 0, 0.0, _MASK_VAL).astype(np.float32) + bias[row_offset:row_offset + MQ_LEN, glue_start:glue_start + K + 1] = glue_bias + + # Diagonal blocks + diag_start = prefix_len_b + K + 1 + for blk in range(step + 1): + col_indices = diag_start + blk * MQ_LEN + rows + valid = col_indices < max_kv_stride + bias[row_offset + rows[valid], col_indices[valid]] = 0.0 + + return torch.from_numpy(bias.reshape(-1)).to(device, non_blocking=True) diff --git a/ssd/utils/context.py b/ssd/utils/context.py index 91c744a27..cccb3459c 100644 --- a/ssd/utils/context.py +++ b/ssd/utils/context.py @@ -13,15 +13,17 @@ class Context: slot_mapping: torch.Tensor | None = None context_lens: torch.Tensor | None = None block_tables: torch.Tensor | None = None + tree_cu_seqlens_q: torch.Tensor | None = None + tree_mask_bias: torch.Tensor | None = None _CONTEXT = Context() def get_context(): return _CONTEXT -def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None, 
is_jit=False): +def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None, is_jit=False, tree_cu_seqlens_q=None, tree_mask_bias=None): global _CONTEXT - _CONTEXT = Context(is_prefill, is_jit, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables) + _CONTEXT = Context(is_prefill, is_jit, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables, tree_cu_seqlens_q, tree_mask_bias) def reset_context(): global _CONTEXT diff --git a/tests/test_attention_paths.py b/tests/test_attention_paths.py new file mode 100644 index 000000000..8bedf948e --- /dev/null +++ b/tests/test_attention_paths.py @@ -0,0 +1,388 @@ +"""Tests for all Attention code paths after migration from sgl_kernel to FA4. + +Covers: + 1. Prefill (contiguous Q/K/V with cu_seqlens) + 2. Verify/glue decode (paged KV cache with cu_seqlens_q) + 3. Single query decode (paged KV cache, 1 query per sequence) + 4. 
Tree decode is already covered in test_fa4_tree_decode.py +""" + +import pytest +import torch +from ssd.layers.attention import Attention +from ssd.utils.context import set_context, reset_context + + +DEVICE = "cuda" +DTYPE = torch.bfloat16 + + +@pytest.fixture(autouse=True) +def cleanup_context(): + yield + reset_context() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def make_attention( + num_heads=8, num_kv_heads=2, head_dim=128, + draft=False, speculate=False, draft_async=False, + F=1, K=1, +): + scale = head_dim ** -0.5 + return Attention( + num_heads=num_heads, head_dim=head_dim, scale=scale, + num_kv_heads=num_kv_heads, draft=draft, speculate=speculate, + draft_async=draft_async, use_eagle=False, F=F, K=K, + ) + + +def make_paged_kv_cache(num_pages, page_size, num_kv_heads, head_dim): + k_cache = torch.randn(num_pages, page_size, num_kv_heads, head_dim, dtype=DTYPE, device=DEVICE) + v_cache = torch.randn(num_pages, page_size, num_kv_heads, head_dim, dtype=DTYPE, device=DEVICE) + return k_cache, v_cache + + +def make_block_tables(batch_size, context_lens_list, page_size, max_pages_per_seq, page_offset=0): + block_tables = torch.zeros(batch_size, max_pages_per_seq, dtype=torch.int32, device=DEVICE) + for b in range(batch_size): + n_pages = (context_lens_list[b] + page_size - 1) // page_size + block_tables[b, :n_pages] = torch.arange(n_pages, dtype=torch.int32, device=DEVICE) + b * page_offset + return block_tables + + +# =========================================================================== +# 1. 
Prefill path +# =========================================================================== + +class TestPrefill: + """context.is_prefill=True, no paged KV cache (contiguous Q/K/V).""" + + @pytest.fixture(autouse=True) + def setup(self): + torch.manual_seed(0) + self.num_heads = 8 + self.num_kv_heads = 2 + self.head_dim = 128 + self.hidden = self.num_heads * self.head_dim + self.kv_hidden = self.num_kv_heads * self.head_dim + + def _run(self, seq_lens): + attn = make_attention( + num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + ) + # No KV cache for prefill without paging + total_tokens = sum(seq_lens) + q = torch.randn(total_tokens, self.hidden, dtype=DTYPE, device=DEVICE) + k = torch.randn(total_tokens, self.kv_hidden, dtype=DTYPE, device=DEVICE) + v = torch.randn(total_tokens, self.kv_hidden, dtype=DTYPE, device=DEVICE) + + cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32, device=DEVICE) + for i, sl in enumerate(seq_lens): + cu_seqlens[i + 1] = cu_seqlens[i] + sl + max_seqlen = max(seq_lens) + slot_mapping = torch.arange(total_tokens, dtype=torch.int32, device=DEVICE) + + set_context( + is_prefill=True, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + slot_mapping=slot_mapping, + ) + + with torch.inference_mode(): + out = attn(q, k, v) + return out + + def test_output_shape(self): + out = self._run([10, 15]) + assert out.shape == (25, self.hidden) + + def test_no_nan_inf(self): + out = self._run([10, 15]) + assert not torch.isnan(out).any(), "Output contains NaN" + assert not torch.isinf(out).any(), "Output contains Inf" + + def test_single_sequence(self): + out = self._run([20]) + assert out.shape == (20, self.hidden) + assert not torch.isnan(out).any() + + def test_different_seq_lens(self): + out = self._run([5, 30]) + out_seq0 = out[:5] + out_seq1 = out[5:] + assert not torch.allclose(out_seq0.mean(), out_seq1.mean()) + + def 
test_deterministic(self): + torch.manual_seed(0) + out1 = self._run([10, 15]) + torch.manual_seed(0) + out2 = self._run([10, 15]) + assert torch.allclose(out1, out2) + + +# =========================================================================== +# 2. Prefill with paged KV cache +# =========================================================================== + +class TestPrefillPaged: + """context.is_prefill=True with block_tables set (paged KV).""" + + @pytest.fixture(autouse=True) + def setup(self): + torch.manual_seed(1) + self.num_heads = 8 + self.num_kv_heads = 2 + self.head_dim = 128 + self.hidden = self.num_heads * self.head_dim + self.kv_hidden = self.num_kv_heads * self.head_dim + self.page_size = 1 + self.num_pages = 200 + self.max_pages_per_seq = 50 + + def _run(self, seq_lens): + attn = make_attention( + num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + ) + k_cache, v_cache = make_paged_kv_cache( + self.num_pages, self.page_size, self.num_kv_heads, self.head_dim, + ) + attn.k_cache = k_cache + attn.v_cache = v_cache + + total_tokens = sum(seq_lens) + q = torch.randn(total_tokens, self.hidden, dtype=DTYPE, device=DEVICE) + k = torch.randn(total_tokens, self.kv_hidden, dtype=DTYPE, device=DEVICE) + v = torch.randn(total_tokens, self.kv_hidden, dtype=DTYPE, device=DEVICE) + + cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32, device=DEVICE) + for i, sl in enumerate(seq_lens): + cu_seqlens[i + 1] = cu_seqlens[i] + sl + max_seqlen = max(seq_lens) + + slot_mapping = torch.arange(total_tokens, dtype=torch.int32, device=DEVICE) + block_tables = make_block_tables( + len(seq_lens), seq_lens, self.page_size, self.max_pages_per_seq, page_offset=50, + ) + + set_context( + is_prefill=True, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + slot_mapping=slot_mapping, + block_tables=block_tables, + ) + + with torch.inference_mode(): + out = attn(q, k, v) + return 
out + + def test_output_shape(self): + out = self._run([10, 15]) + assert out.shape == (25, self.hidden) + + def test_no_nan_inf(self): + out = self._run([10, 15]) + assert not torch.isnan(out).any() + assert not torch.isinf(out).any() + + +# =========================================================================== +# 3. Verify/glue decode path +# =========================================================================== + +class TestVerifyGlueDecode: + """speculate=True, cu_seqlens_q is not None → verify_or_glue path.""" + + @pytest.fixture(autouse=True) + def setup(self): + torch.manual_seed(2) + self.num_heads = 8 + self.num_kv_heads = 2 + self.head_dim = 128 + self.hidden = self.num_heads * self.head_dim + self.kv_hidden = self.num_kv_heads * self.head_dim + self.page_size = 1 + self.num_pages = 200 + self.max_pages_per_seq = 50 + self.max_model_len = 100 + + def _make_attn(self): + attn = make_attention( + num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, speculate=True, + ) + k_cache, v_cache = make_paged_kv_cache( + self.num_pages, self.page_size, self.num_kv_heads, self.head_dim, + ) + attn.k_cache = k_cache + attn.v_cache = v_cache + attn.max_seqlen_k = self.max_model_len + return attn + + def _run(self, query_lens, context_lens_list): + """ + query_lens: list of query tokens per sequence (e.g. 
[K+1, K+1] for verify) + context_lens_list: list of KV context lengths per sequence + """ + attn = self._make_attn() + B = len(query_lens) + total_q = sum(query_lens) + q = torch.randn(total_q, self.hidden, dtype=DTYPE, device=DEVICE) + k = torch.randn(total_q, self.kv_hidden, dtype=DTYPE, device=DEVICE) + v = torch.randn(total_q, self.kv_hidden, dtype=DTYPE, device=DEVICE) + + cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device=DEVICE) + for i, ql in enumerate(query_lens): + cu_seqlens_q[i + 1] = cu_seqlens_q[i] + ql + max_seqlen_q = max(query_lens) + + context_lens = torch.tensor(context_lens_list, dtype=torch.int32, device=DEVICE) + slot_mapping = torch.arange(total_q, dtype=torch.int32, device=DEVICE) + block_tables = make_block_tables( + B, context_lens_list, self.page_size, self.max_pages_per_seq, page_offset=50, + ) + + set_context( + is_prefill=False, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + slot_mapping=slot_mapping, + context_lens=context_lens, + block_tables=block_tables, + ) + + with torch.inference_mode(): + out = attn(q, k, v) + return out + + def test_output_shape(self): + # 2 sequences, each with K+1=4 query tokens, context 20 and 15 + out = self._run([4, 4], [20, 15]) + assert out.shape == (8, self.hidden) + + def test_no_nan_inf(self): + out = self._run([4, 4], [20, 15]) + assert not torch.isnan(out).any(), "Output contains NaN" + assert not torch.isinf(out).any(), "Output contains Inf" + + def test_single_sequence(self): + out = self._run([8], [30]) + assert out.shape == (8, self.hidden) + assert not torch.isnan(out).any() + + def test_variable_query_lens(self): + out = self._run([3, 6], [25, 10]) + assert out.shape == (9, self.hidden) + assert not torch.isnan(out).any() + + def test_deterministic(self): + torch.manual_seed(2) + out1 = self._run([4, 4], [20, 15]) + torch.manual_seed(2) + out2 = self._run([4, 4], [20, 15]) + assert torch.allclose(out1, out2) + + +# 
=========================================================================== +# 4. Single query decode path +# =========================================================================== + +class TestSingleQueryDecode: + """decode=True, not verify_or_glue, not tree_decode → single query decode.""" + + @pytest.fixture(autouse=True) + def setup(self): + torch.manual_seed(3) + self.num_heads = 8 + self.num_kv_heads = 2 + self.head_dim = 128 + self.hidden = self.num_heads * self.head_dim + self.kv_hidden = self.num_kv_heads * self.head_dim + self.page_size = 1 + self.num_pages = 200 + self.max_pages_per_seq = 50 + self.max_model_len = 100 + + def _make_attn(self): + # speculate=False (or draft=False, draft_async=False) so we don't enter + # verify_or_glue or tree_decode + attn = make_attention( + num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, speculate=False, + ) + k_cache, v_cache = make_paged_kv_cache( + self.num_pages, self.page_size, self.num_kv_heads, self.head_dim, + ) + attn.k_cache = k_cache + attn.v_cache = v_cache + attn.max_seqlen_k = self.max_model_len + return attn + + def _run(self, batch_size, context_lens_list): + attn = self._make_attn() + # Single query decode: 1 query token per sequence + total_q = batch_size + q = torch.randn(total_q, self.hidden, dtype=DTYPE, device=DEVICE) + k = torch.randn(total_q, self.kv_hidden, dtype=DTYPE, device=DEVICE) + v = torch.randn(total_q, self.kv_hidden, dtype=DTYPE, device=DEVICE) + + context_lens = torch.tensor(context_lens_list, dtype=torch.int32, device=DEVICE) + slot_mapping = torch.arange(total_q, dtype=torch.int32, device=DEVICE) + block_tables = make_block_tables( + batch_size, context_lens_list, self.page_size, self.max_pages_per_seq, page_offset=50, + ) + + set_context( + is_prefill=False, + cu_seqlens_q=None, # None → not verify_or_glue + slot_mapping=slot_mapping, + context_lens=context_lens, + block_tables=block_tables, + ) + + with torch.inference_mode(): + out = 
attn(q, k, v) + return out + + def test_output_shape(self): + out = self._run(2, [20, 15]) + assert out.shape == (2, self.hidden) + + def test_no_nan_inf(self): + out = self._run(2, [20, 15]) + assert not torch.isnan(out).any(), "Output contains NaN" + assert not torch.isinf(out).any(), "Output contains Inf" + + def test_single_sequence(self): + out = self._run(1, [30]) + assert out.shape == (1, self.hidden) + assert not torch.isnan(out).any() + + def test_large_batch(self): + B = 16 + ctx_lens = [5 + i * 2 for i in range(B)] # max = 5 + 15*2 = 35 < max_pages_per_seq + out = self._run(B, ctx_lens) + assert out.shape == (B, self.hidden) + assert not torch.isnan(out).any() + + def test_different_context_lens_produce_different_outputs(self): + out = self._run(2, [50, 5]) + assert not torch.allclose(out[0], out[1]) + + def test_deterministic(self): + torch.manual_seed(3) + out1 = self._run(2, [20, 15]) + torch.manual_seed(3) + out2 = self._run(2, [20, 15]) + assert torch.allclose(out1, out2) diff --git a/tests/test_fa4_tree_decode.py b/tests/test_fa4_tree_decode.py new file mode 100644 index 000000000..19102ad75 --- /dev/null +++ b/tests/test_fa4_tree_decode.py @@ -0,0 +1,201 @@ +"""Tests for FA4 flash_attn_varlen_func with paged KV cache (tree decode replacement).""" + +import pytest +import torch +from flash_attn.cute.interface import flash_attn_varlen_func as fa4_varlen_func +from ssd.layers.attention import Attention +from ssd.utils.context import set_context, reset_context + + +DEVICE = "cuda" +DTYPE = torch.bfloat16 + + +# --------------------------------------------------------------------------- +# FA4 varlen + page_table: basic correctness +# --------------------------------------------------------------------------- + +class TestFA4VarlenPageTable: + """Test flash_attn_varlen_func with page_table at various page sizes.""" + + @pytest.fixture(autouse=True) + def setup(self): + torch.manual_seed(42) + self.B = 2 + self.MQ_LEN = 6 + self.num_heads = 4 + 
self.num_kv_heads = 2 + self.head_dim = 128 + self.num_pages = 200 + self.max_pages_per_seq = 20 + + def _run(self, page_size, kv_lens): + total_q = self.B * self.MQ_LEN + q = torch.randn(total_q, self.num_heads, self.head_dim, dtype=DTYPE, device=DEVICE) + k_cache = torch.randn(self.num_pages, page_size, self.num_kv_heads, self.head_dim, dtype=DTYPE, device=DEVICE) + v_cache = torch.randn(self.num_pages, page_size, self.num_kv_heads, self.head_dim, dtype=DTYPE, device=DEVICE) + cu_seqlens_q = torch.arange(self.B + 1, dtype=torch.int32, device=DEVICE) * self.MQ_LEN + + page_table = torch.zeros(self.B, self.max_pages_per_seq, dtype=torch.int32, device=DEVICE) + for b in range(self.B): + n_pages = (kv_lens[b] + page_size - 1) // page_size + page_table[b, :n_pages] = torch.arange(n_pages, dtype=torch.int32, device=DEVICE) + b * 50 + + seqused_k = torch.tensor(kv_lens, dtype=torch.int32, device=DEVICE) + + out, lse = fa4_varlen_func( + q, k_cache, v_cache, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=None, + max_seqlen_q=self.MQ_LEN, + max_seqlen_k=max(kv_lens), + seqused_k=seqused_k, + page_table=page_table, + softmax_scale=self.head_dim ** -0.5, + causal=False, + ) + return out, lse + + @pytest.mark.parametrize("page_size", [1, 16, 128]) + def test_output_shape(self, page_size): + out, _ = self._run(page_size, kv_lens=[10, 5]) + assert out.shape == (self.B * self.MQ_LEN, self.num_heads, self.head_dim) + + @pytest.mark.parametrize("page_size", [1, 16, 128]) + def test_no_nan_inf(self, page_size): + out, _ = self._run(page_size, kv_lens=[10, 5]) + assert not torch.isnan(out).any(), "Output contains NaN" + assert not torch.isinf(out).any(), "Output contains Inf" + + @pytest.mark.parametrize("page_size", [1, 16, 128]) + def test_lse_returned_none_by_default(self, page_size): + _, lse = self._run(page_size, kv_lens=[10, 5]) + assert lse is None, "LSE should be None when return_lse=False (default)" + + def test_variable_kv_lengths(self): + """Sequences with very different 
KV lengths should both produce valid output.""" + self.max_pages_per_seq = 60 # accommodate kv_len=50 + out, _ = self._run(page_size=1, kv_lens=[50, 3]) + assert not torch.isnan(out).any() + # Check that the two sequences produce different outputs (they have different KV) + out_seq0 = out[:self.MQ_LEN] + out_seq1 = out[self.MQ_LEN:] + assert not torch.allclose(out_seq0, out_seq1), "Different KV should produce different outputs" + + def test_deterministic(self): + """Same inputs should produce same outputs.""" + out1, _ = self._run(page_size=1, kv_lens=[10, 5]) + torch.manual_seed(42) # reset seed to get same random inputs + out2, _ = self._run(page_size=1, kv_lens=[10, 5]) + assert torch.allclose(out1, out2), "Same inputs should produce identical outputs" + + def test_batch_size_1(self): + """Single-sequence batch should work.""" + self.B = 1 + out, _ = self._run(page_size=1, kv_lens=[10]) + assert out.shape == (self.MQ_LEN, self.num_heads, self.head_dim) + assert not torch.isnan(out).any() + + +# --------------------------------------------------------------------------- +# Attention layer integration: tree decode path +# --------------------------------------------------------------------------- + +class TestAttentionTreeDecode: + """Test the Attention module's tree_decode path end-to-end with FA4.""" + + @pytest.fixture(autouse=True) + def setup(self): + torch.manual_seed(42) + self.num_heads = 8 + self.num_kv_heads = 2 + self.head_dim = 128 + self.scale = self.head_dim ** -0.5 + self.F_fan = 2 + self.K_spec = 2 + self.MQ_LEN = self.F_fan * (self.K_spec + 1) + self.page_size = 1 + self.num_pages = 200 + self.max_pages_per_seq = 50 + self.max_model_len = 50 + yield + reset_context() + + def _make_attn(self): + attn = Attention( + num_heads=self.num_heads, head_dim=self.head_dim, scale=self.scale, + num_kv_heads=self.num_kv_heads, draft=True, speculate=True, + draft_async=True, use_eagle=False, F=self.F_fan, K=self.K_spec, + ) + attn.k_cache = torch.randn( + 
self.num_pages, self.page_size, self.num_kv_heads, self.head_dim, + dtype=DTYPE, device=DEVICE) + attn.v_cache = torch.randn( + self.num_pages, self.page_size, self.num_kv_heads, self.head_dim, + dtype=DTYPE, device=DEVICE) + attn.max_seqlen_k = self.max_model_len + return attn + + def _run(self, attn, B, context_lens_list): + total_tokens = B * self.MQ_LEN + q = torch.randn(total_tokens, self.num_heads * self.head_dim, dtype=DTYPE, device=DEVICE) + k = torch.randn(total_tokens, self.num_kv_heads * self.head_dim, dtype=DTYPE, device=DEVICE) + v = torch.randn(total_tokens, self.num_kv_heads * self.head_dim, dtype=DTYPE, device=DEVICE) + + context_lens = torch.tensor(context_lens_list, dtype=torch.int32, device=DEVICE) + slot_mapping = torch.arange(total_tokens, dtype=torch.int32, device=DEVICE) + + block_tables = torch.zeros(B, self.max_pages_per_seq, dtype=torch.int32, device=DEVICE) + for b in range(B): + n_pages = context_lens_list[b] # page_size=1, so pages == tokens + block_tables[b, :n_pages] = torch.arange(n_pages, dtype=torch.int32, device=DEVICE) + b * 50 + + cu_seqlens_q = torch.arange(B + 1, dtype=torch.int32, device=DEVICE) * self.MQ_LEN + + set_context( + is_prefill=False, + slot_mapping=slot_mapping, + context_lens=context_lens, + block_tables=block_tables, + tree_cu_seqlens_q=cu_seqlens_q, + ) + + with torch.inference_mode(): + out = attn(q, k, v) + return out + + def test_output_shape(self): + attn = self._make_attn() + out = self._run(attn, B=2, context_lens_list=[20, 15]) + expected = (2 * self.MQ_LEN, self.num_heads * self.head_dim) + assert out.shape == expected, f"Expected {expected}, got {out.shape}" + + def test_no_nan_inf(self): + attn = self._make_attn() + out = self._run(attn, B=2, context_lens_list=[20, 15]) + assert not torch.isnan(out).any(), "Output contains NaN" + assert not torch.isinf(out).any(), "Output contains Inf" + + def test_single_sequence(self): + attn = self._make_attn() + out = self._run(attn, B=1, context_lens_list=[30]) + 
expected = (self.MQ_LEN, self.num_heads * self.head_dim) + assert out.shape == expected + + def test_different_context_lens(self): + """Sequences with different context lengths should produce different outputs.""" + attn = self._make_attn() + out = self._run(attn, B=2, context_lens_list=[40, 10]) + out_seq0 = out[:self.MQ_LEN] + out_seq1 = out[self.MQ_LEN:] + assert not torch.allclose(out_seq0, out_seq1) + + def test_non_tree_decode_paths_unaffected(self): + """Verify that non-tree-decode paths still use the original kernels.""" + attn = Attention( + num_heads=self.num_heads, head_dim=self.head_dim, scale=self.scale, + num_kv_heads=self.num_kv_heads, draft=False, speculate=False, + draft_async=False, use_eagle=False, + ) + # This attention module should NOT take the tree_decode path + assert not (attn.speculate and attn.draft and attn.draft_async) diff --git a/tests/test_score_mod_basic.py b/tests/test_score_mod_basic.py new file mode 100644 index 000000000..e7ea7cdfe --- /dev/null +++ b/tests/test_score_mod_basic.py @@ -0,0 +1,155 @@ +"""Test that score_mod with aux_tensors works with FA4 varlen + page_table.""" + +import torch +import pytest +from flash_attn.cute.interface import flash_attn_varlen_func +from ssd.layers.tree_mask import create_tree_score_mod, build_tree_mask_bias + +DEVICE = "cuda" +DTYPE = torch.bfloat16 + + +class TestScoreModBasic: + """Verify score_mod compiles and runs with FA4 varlen + page_table.""" + + @pytest.fixture(autouse=True) + def setup(self): + torch.manual_seed(42) + self.B = 2 + self.MQ_LEN = 6 + self.num_heads = 4 + self.num_kv_heads = 2 + self.head_dim = 128 + self.num_pages = 200 + self.max_pages_per_seq = 50 + self.page_size = 1 + + def _make_inputs(self, kv_lens): + total_q = self.B * self.MQ_LEN + q = torch.randn(total_q, self.num_heads, self.head_dim, dtype=DTYPE, device=DEVICE) + k_cache = torch.randn(self.num_pages, self.page_size, self.num_kv_heads, self.head_dim, dtype=DTYPE, device=DEVICE) + v_cache = 
torch.randn(self.num_pages, self.page_size, self.num_kv_heads, self.head_dim, dtype=DTYPE, device=DEVICE) + cu_seqlens_q = torch.arange(self.B + 1, dtype=torch.int32, device=DEVICE) * self.MQ_LEN + page_table = torch.zeros(self.B, self.max_pages_per_seq, dtype=torch.int32, device=DEVICE) + for b in range(self.B): + n = kv_lens[b] + page_table[b, :n] = torch.arange(n, dtype=torch.int32, device=DEVICE) + b * 50 + seqused_k = torch.tensor(kv_lens, dtype=torch.int32, device=DEVICE) + return q, k_cache, v_cache, cu_seqlens_q, page_table, seqused_k + + def test_zero_bias_matches_no_scoremod(self): + """A score_mod that adds zero should produce identical output.""" + kv_lens = [10, 5] + max_kv_stride = 50 + q, k, v, cu, pt, sk = self._make_inputs(kv_lens) + + out_base, _ = flash_attn_varlen_func( + q, k, v, cu_seqlens_q=cu, cu_seqlens_k=None, + max_seqlen_q=self.MQ_LEN, max_seqlen_k=max(kv_lens), + seqused_k=sk, page_table=pt, + softmax_scale=self.head_dim ** -0.5, causal=False, + ) + + score_mod = create_tree_score_mod(max_kv_stride) + # All-zero bias = no masking + bias = torch.zeros(self.B * self.MQ_LEN * max_kv_stride, dtype=torch.float32, device=DEVICE) + + out_mod, _ = flash_attn_varlen_func( + q, k, v, cu_seqlens_q=cu, cu_seqlens_k=None, + max_seqlen_q=self.MQ_LEN, max_seqlen_k=max(kv_lens), + seqused_k=sk, page_table=pt, + softmax_scale=self.head_dim ** -0.5, causal=False, + score_mod=score_mod, aux_tensors=[bias], + ) + + assert torch.allclose(out_base, out_mod, atol=1e-2), \ + f"Zero bias should match base, max diff: {(out_base - out_mod).abs().max().item()}" + + def test_full_mask_produces_uniform_attention(self): + """Masking all but one KV position should concentrate attention there.""" + kv_lens = [10, 5] + max_kv_stride = 50 + q, k, v, cu, pt, sk = self._make_inputs(kv_lens) + + score_mod = create_tree_score_mod(max_kv_stride) + # Mask everything except KV position 0 for all queries + bias = torch.full((self.B * self.MQ_LEN * max_kv_stride,), -1e6, 
dtype=torch.float32, device=DEVICE) + for b in range(self.B): + for qi in range(self.MQ_LEN): + flat_idx = (b * self.MQ_LEN + qi) * max_kv_stride + 0 # only attend to kv_idx=0 + bias[flat_idx] = 0.0 + + out, _ = flash_attn_varlen_func( + q, k, v, cu_seqlens_q=cu, cu_seqlens_k=None, + max_seqlen_q=self.MQ_LEN, max_seqlen_k=max(kv_lens), + seqused_k=sk, page_table=pt, + softmax_scale=self.head_dim ** -0.5, causal=False, + score_mod=score_mod, aux_tensors=[bias], + ) + + assert not torch.isnan(out).any(), "Masked output has NaN" + assert not torch.isinf(out).any(), "Masked output has Inf" + + +class TestTreeMaskBuild: + """Test build_tree_mask_bias produces correct mask structure.""" + + def test_prefix_unmasked(self): + """All prefix positions should have bias=0 (attend).""" + B, K, MQ_LEN = 1, 2, 6 + fol = [2, 2, 2] + context_lens = torch.tensor([20], dtype=torch.int32) # prefix = 20 - (1*6 + 3) = 11 + cache_hits = torch.tensor([1]) + max_kv_stride = 50 + + bias = build_tree_mask_bias( + context_lens, step=0, K=K, MQ_LEN=MQ_LEN, + fan_out_list=fol, fan_out_list_miss=fol, + cache_hits=cache_hits, max_kv_stride=max_kv_stride, + device="cpu", + ) + bias_2d = bias.reshape(MQ_LEN, max_kv_stride) + prefix_len = 20 - (1 * MQ_LEN + K + 1) + # All prefix columns should be 0.0 (unmasked) + assert (bias_2d[:, :prefix_len] == 0.0).all(), "Prefix should be unmasked" + + def test_masked_positions_negative(self): + """Positions beyond the valid KV should be masked (large negative).""" + B, K, MQ_LEN = 1, 2, 6 + fol = [2, 2, 2] + context_lens = torch.tensor([20], dtype=torch.int32) + cache_hits = torch.tensor([1]) + max_kv_stride = 50 + + bias = build_tree_mask_bias( + context_lens, step=0, K=K, MQ_LEN=MQ_LEN, + fan_out_list=fol, fan_out_list_miss=fol, + cache_hits=cache_hits, max_kv_stride=max_kv_stride, + device="cpu", + ) + bias_2d = bias.reshape(MQ_LEN, max_kv_stride) + # Beyond context_lens should be masked + assert (bias_2d[:, 20:] < -1e5).all(), "Beyond context_lens should 
"""Correctness tests: verify FA4 tree mask matches the original flashinfer mask logic."""

import torch
import numpy as np
import pytest
from flash_attn.cute.interface import flash_attn_varlen_func
from ssd.layers.tree_mask import create_tree_score_mod, build_tree_mask_bias
from ssd.engine.helpers.mask_helpers import get_custom_mask

DEVICE = "cuda"
DTYPE = torch.bfloat16


class FakeConfig:
    """Minimal config for get_custom_mask.

    Mirrors only the attributes get_custom_mask reads; avoids pulling in the
    full engine config.
    """

    def __init__(self, K, fan_out_list, fan_out_list_miss, max_model_len):
        self.speculate_k = K
        self.fan_out_list = fan_out_list
        self.fan_out_list_miss = fan_out_list_miss
        self.max_model_len = max_model_len


class TestTreeMaskMatchesOriginal:
    """Verify that build_tree_mask_bias produces masks equivalent to get_custom_mask."""

    @pytest.fixture(autouse=True)
    def setup(self):
        self.K = 2
        self.F = 2
        self.fan_out_list = [2, 2, 2]  # F=2, K+1=3 groups
        self.fan_out_list_miss = [2, 2, 2]
        self.MQ_LEN = sum(self.fan_out_list)  # = 6

    def _compare_masks(self, B, context_lens_list, step, cache_hits_list):
        """Compare old (get_custom_mask) vs new (build_tree_mask_bias) for one step.

        Asserts that the attend/mask decision agrees at every (query, kv)
        position for every sequence in the batch.
        """
        context_lens = torch.tensor(context_lens_list, dtype=torch.int32, device=DEVICE)
        cache_hits = torch.tensor(cache_hits_list, dtype=torch.float32, device=DEVICE)
        max_model_len = 100

        config = FakeConfig(self.K, self.fan_out_list, self.fan_out_list_miss, max_model_len)

        # Old mask: 1D bool tensor, concatenation of per-seq (MQ_LEN x kv_len) masks
        old_mask = get_custom_mask(
            config, context_lens, step, self.K, self.F, B,
            device=DEVICE, cache_hits=cache_hits,
        )

        # New mask bias: (B * MQ_LEN * max_model_len,) float32
        new_bias = build_tree_mask_bias(
            context_lens, step=step, K=self.K, MQ_LEN=self.MQ_LEN,
            fan_out_list=self.fan_out_list,
            fan_out_list_miss=self.fan_out_list_miss,
            cache_hits=cache_hits,
            max_kv_stride=max_model_len,
            device=DEVICE,
        )
        new_bias_2d = new_bias.reshape(B * self.MQ_LEN, max_model_len)

        # Extract per-batch masks from old format and compare
        old_offset = 0
        for b in range(B):
            kv_len = context_lens_list[b]
            old_mask_b = old_mask[old_offset:old_offset + self.MQ_LEN * kv_len].reshape(self.MQ_LEN, kv_len)
            new_mask_b = new_bias_2d[b * self.MQ_LEN:(b + 1) * self.MQ_LEN, :kv_len]

            # Old: True = attend, False = mask
            # New: 0.0 = attend, -1e6 = mask
            new_attend = (new_mask_b == 0.0)
            old_attend = old_mask_b.bool()

            mismatches = (new_attend != old_attend).sum().item()
            assert mismatches == 0, (
                f"Mask mismatch at batch={b}, step={step}: {mismatches} positions differ\n"
                f"  old attend count: {old_attend.sum().item()}, new attend count: {new_attend.sum().item()}\n"
                f"  context_len={kv_len}, cache_hit={cache_hits_list[b]}"
            )
            old_offset += self.MQ_LEN * kv_len

    @pytest.mark.parametrize("step", [0, 1])
    def test_single_seq_cache_hit(self, step):
        # context_lens must be >= ttl_added = (step+1)*MQ_LEN + K+1
        cl = 30 + step * self.MQ_LEN
        self._compare_masks(B=1, context_lens_list=[cl], step=step, cache_hits_list=[1])

    @pytest.mark.parametrize("step", [0, 1])
    def test_single_seq_cache_miss(self, step):
        cl = 30 + step * self.MQ_LEN
        self._compare_masks(B=1, context_lens_list=[cl], step=step, cache_hits_list=[0])

    @pytest.mark.parametrize("step", [0, 1])
    def test_multi_seq_mixed_hits(self, step):
        base = 25 + step * self.MQ_LEN
        self._compare_masks(
            B=3,
            context_lens_list=[base, base + 10, base + 5],
            step=step,
            cache_hits_list=[1, 0, 1],
        )

    def test_step_2(self):
        cl = 40 + 2 * self.MQ_LEN
        self._compare_masks(B=2, context_lens_list=[cl, cl - 5], step=2, cache_hits_list=[1, 0])


class TestFA4WithTreeMask:
    """End-to-end: verify FA4 attention with tree mask produces valid, masked output."""

    @pytest.fixture(autouse=True)
    def setup(self):
        torch.manual_seed(42)
        self.B = 2
        self.K = 2
        self.MQ_LEN = 6
        self.num_heads = 4
        self.num_kv_heads = 2
        self.head_dim = 128
        self.num_pages = 200
        self.page_size = 1
        self.max_pages_per_seq = 50
        self.max_kv_stride = 50
        self.fan_out_list = [2, 2, 2]
        self.fan_out_list_miss = [2, 2, 2]

    def test_masked_vs_unmasked_differ(self):
        """Masked attention should produce different output than unmasked."""
        kv_lens = [20, 15]
        total_q = self.B * self.MQ_LEN
        q = torch.randn(total_q, self.num_heads, self.head_dim, dtype=DTYPE, device=DEVICE)
        k = torch.randn(self.num_pages, self.page_size, self.num_kv_heads, self.head_dim, dtype=DTYPE, device=DEVICE)
        v = torch.randn(self.num_pages, self.page_size, self.num_kv_heads, self.head_dim, dtype=DTYPE, device=DEVICE)
        cu = torch.arange(self.B + 1, dtype=torch.int32, device=DEVICE) * self.MQ_LEN
        # Page table: disjoint page ranges per sequence (page_size == 1, so one
        # page per kv position); batch b uses pages [b*50, b*50 + kv_len).
        pt = torch.zeros(self.B, self.max_pages_per_seq, dtype=torch.int32, device=DEVICE)
        for b in range(self.B):
            pt[b, :kv_lens[b]] = torch.arange(kv_lens[b], dtype=torch.int32, device=DEVICE) + b * 50
        sk = torch.tensor(kv_lens, dtype=torch.int32, device=DEVICE)

        # Unmasked (causal=False, no score_mod)
        out_unmasked, _ = flash_attn_varlen_func(
            q, k, v, cu_seqlens_q=cu, cu_seqlens_k=None,
            max_seqlen_q=self.MQ_LEN, max_seqlen_k=max(kv_lens),
            seqused_k=sk, page_table=pt,
            softmax_scale=self.head_dim ** -0.5, causal=False,
        )

        # Masked
        score_mod = create_tree_score_mod(self.max_kv_stride)
        # NOTE(review): context_lens/cache_hits are built on CPU here while the
        # attention tensors live on DEVICE — presumably build_tree_mask_bias
        # moves them via its device= argument; confirm against its signature.
        context_lens = torch.tensor(kv_lens, dtype=torch.int32)
        cache_hits = torch.tensor([1, 1])
        mask_bias = build_tree_mask_bias(
            context_lens, step=0, K=self.K, MQ_LEN=self.MQ_LEN,
            fan_out_list=self.fan_out_list, fan_out_list_miss=self.fan_out_list_miss,
            cache_hits=cache_hits, max_kv_stride=self.max_kv_stride, device=DEVICE,
        )
        out_masked, _ = flash_attn_varlen_func(
            q, k, v, cu_seqlens_q=cu, cu_seqlens_k=None,
            max_seqlen_q=self.MQ_LEN, max_seqlen_k=max(kv_lens),
            seqused_k=sk, page_table=pt,
            softmax_scale=self.head_dim ** -0.5, causal=False,
            score_mod=score_mod, aux_tensors=[mask_bias],
        )

        assert not torch.isnan(out_masked).any(), "Masked output has NaN"
        assert not torch.allclose(out_masked, out_unmasked, atol=1e-2), \
            "Masked and unmasked should produce different outputs"