diff --git a/examples/dynamo/attention_plugin_example.py b/examples/dynamo/attention_plugin_example.py new file mode 100644 index 0000000000..d0498eeeaa --- /dev/null +++ b/examples/dynamo/attention_plugin_example.py @@ -0,0 +1,756 @@ +""" +.. _attention_plugin_example: + +Custom Attention Plugin with KV Cache Management +================================================= + +This example demonstrates how to use a custom TensorRT AttentionPlugin that implements +efficient multi-head attention with Rotary Position Embedding (RoPE) and KV cache management +for autoregressive generation in Large Language Models (LLMs). + +**Plugin Library:** + +This example uses a custom TensorRT plugin shared library (``libNvInfer_edgellm_plugin.so``) +that replaces standard transformer attention operations and RoPE computations with optimized +CUDA kernels. The plugin source code is available at: + +https://github.com/chohk88/TensorRT-Edge-LLM/tree/feature/torch-tensorrt-python-runtime + +Build instructions and implementation details can be found in the repository above. +This implementation has been verified with TensorRT-Edge-LLM release 0.4.0. + +**Key Features:** + +- **Dual Kernel Support:** + + - **FMHA (Fused Multi-Head Attention)** for context phase when ``seq_len > 1`` (processing multiple tokens) + - **XQA (Extended Query Attention)** for decode phase when ``seq_len = 1`` (single token generation) + +- **KV Cache Management:** Efficiently manages key-value cache for autoregressive generation +- **Perfect Accuracy:** Achieves cosine similarity = 1.0 with PyTorch's ``scaled_dot_product_attention`` +- **Grouped Query Attention (GQA):** Supports efficient attention with fewer KV heads + +**What This Example Tests:** + +1. **XQA Kernel (seq_len=1):** Single token generation, with and without past context +2. **FMHA Kernel (seq_len>1):** Context processing with multiple tokens +3. 
**Multi-Step Generation:** Realistic LLM scenario - process prompt (FMHA), then generate tokens (XQA) +4. **Perfect Accuracy:** All tests achieve ``cosine_similarity ≥ 0.99`` with PyTorch SDPA + +**Installation Requirements:** + +.. code-block:: bash + + pip install torch torch_tensorrt tensorrt + +Build the AttentionPlugin shared library following instructions at the GitHub repository above. +The compiled library should be located at: ``/path/to/tensorrt-edgellm/build/libNvInfer_edgellm_plugin.so`` +""" + +# %% +# Imports and Setup +# ----------------- + +import ctypes +import os +from typing import Tuple + +import numpy as np +import tensorrt as trt +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_tensorrt +from torch_tensorrt.dynamo.conversion import ( + ConversionContext, + dynamo_tensorrt_converter, +) +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor + +# %% +# Enable plugin debug logging +# ---------------------------- +os.environ["TRT_EDGELLM_DEBUG_PLUGIN"] = "1" + +# %% +# Initialize CUDA and Load Plugin +# -------------------------------- +# CUDA must be initialized before loading the TensorRT plugin library + +print("Initializing CUDA context...") +DEVICE = torch.device("cuda:0") +_ = torch.zeros(1, device=DEVICE) # Initialize CUDA +print(f"CUDA initialized on {DEVICE}\n") + +PLUGIN_PATH = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "TensorRT-Edge-LLM", + "build", + "libNvInfer_edgellm_plugin.so", +) +ctypes.CDLL(PLUGIN_PATH) +print(f"Loaded plugin: {PLUGIN_PATH}\n") + +# %% +# Model Configuration +# ------------------- +# These hyperparameters match typical LLM architectures with Grouped Query Attention (GQA) + +BATCH_SIZE = 1 +NUM_Q_HEADS = 4 # Number of query heads +NUM_KV_HEADS = 2 # Number of key/value heads (GQA: fewer than query heads) +HEAD_DIM = 64 # Dimension per head +KV_CACHE_CAPACITY = 128 # Maximum sequence length +HIDDEN_DIM = 
def precompute_rope(head_dim: int, max_seq_len: int = 128, base: float = 10000.0):
    """
    Precompute the RoPE cos/sin table for all positions.

    Args:
        head_dim: Per-head feature size; one frequency is generated for every
            second index in ``range(0, head_dim, 2)``.
        max_seq_len: Number of positions to tabulate.
        base: Base of the inverse-frequency progression.

    Returns:
        FP32 tensor on ``DEVICE`` of shape [1, max_seq_len, head_dim] (for even
        ``head_dim``). The first half of the last dimension holds the cosines
        and the second half the sines.
    """
    inv_freq = 1.0 / (
        base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
    )
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    freqs = torch.outer(positions, inv_freq)
    rope = torch.cat([freqs.cos(), freqs.sin()], dim=-1)
    return rope.unsqueeze(0).to(DEVICE)


def apply_rope(x, rope_cache, position_ids):
    """
    Apply RoPE to an input tensor.

    The rotation is computed in FP32 and the result is cast back to the dtype
    of ``x`` (it is not forced to FP16), so the helper works for any ``DTYPE``.

    Args:
        x: [batch, num_heads, seq_len, head_dim]
        rope_cache: [1, max_seq_len, head_dim], cosines in the first half of the
            last dimension and sines in the second half
        position_ids: [seq_len] position indices into ``rope_cache``

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    # [1, seq_len, head_dim] -> [1, 1, seq_len, head_dim] so it broadcasts over heads
    rope = rope_cache[:, position_ids, :].unsqueeze(1)

    half_dim = x.shape[-1] // 2
    cos = rope[..., :half_dim]
    sin = rope[..., half_dim:]

    x_fp32 = x.float()
    x1 = x_fp32[..., :half_dim]
    x2 = x_fp32[..., half_dim:]

    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

    return rotated.to(x.dtype)


def repeat_kv(x, n_rep):
    """
    Repeat KV heads for Grouped Query Attention.

    Args:
        x: [batch, num_kv_heads, seq_len, head_dim]
        n_rep: Number of consecutive copies of every KV head.

    Returns:
        ``x`` itself when ``n_rep == 1``; otherwise a tensor of shape
        [batch, num_kv_heads * n_rep, seq_len, head_dim].
    """
    if n_rep == 1:
        return x
    bs, n_kv_heads, slen, head_dim = x.shape
    return (
        x[:, :, None, :, :]
        .expand(bs, n_kv_heads, n_rep, slen, head_dim)
        .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
    )


# %%
# PyTorch SDPA Reference Implementation
# -------------------------------------
# This serves as the ground truth for correctness validation


class SDPAModel(nn.Module):
    """Reference attention using PyTorch's scaled_dot_product_attention"""

    def __init__(self):
        super().__init__()
        self.num_q_heads = NUM_Q_HEADS
        self.num_kv_heads = NUM_KV_HEADS
        self.head_dim = HEAD_DIM
        self.num_key_value_groups = NUM_KV_GROUPS

        # Fused projection producing Q, K and V in a single matmul
        self.qkv = nn.Linear(
            HIDDEN_DIM, HIDDEN_DIM + 2 * NUM_KV_HEADS * HEAD_DIM, bias=True
        )
        self.out = nn.Linear(HIDDEN_DIM, HIDDEN_DIM, bias=False)

    def forward(self, x, kv_cache, ctx_len_tensor, rope):
        """
        Args:
            x: [batch, seq_len, hidden_dim]
            kv_cache: [batch, 2, num_kv_heads, capacity, head_dim]; updated in place
            ctx_len_tensor: [batch] - total context length including current tokens.
                Only element 0 is read, so every batch row shares one length.
            rope: [1, max_seq_len, head_dim]

        Returns:
            Tuple of (output [batch, seq_len, hidden_dim], the same ``kv_cache`` object).

        Raises:
            ValueError: If the context length is shorter than ``seq_len`` or
                exceeds the cache capacity.
        """
        batch_size, seq_len, _ = x.shape
        ctx_len = int(ctx_len_tensor[0].item())
        past_len = ctx_len - seq_len
        if past_len < 0 or ctx_len > kv_cache.shape[3]:
            raise ValueError(
                f"Invalid context length {ctx_len} for seq_len={seq_len} "
                f"and cache capacity {kv_cache.shape[3]}"
            )

        # QKV projection
        qkv = self.qkv(x)
        q_size = self.num_q_heads * self.head_dim
        kv_size = self.num_kv_heads * self.head_dim
        query, key, value = torch.split(qkv, [q_size, kv_size, kv_size], dim=-1)

        # Reshape to multi-head format
        query = query.view(
            batch_size, seq_len, self.num_q_heads, self.head_dim
        ).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(
            1, 2
        )
        value = value.view(
            batch_size, seq_len, self.num_kv_heads, self.head_dim
        ).transpose(1, 2)

        # Apply RoPE at the absolute positions of the new tokens
        position_ids = torch.arange(past_len, past_len + seq_len, device=x.device)
        query = apply_rope(query, rope, position_ids)
        key = apply_rope(key, rope, position_ids)

        # Update KV cache
        kv_cache[:, 0, :, past_len : past_len + seq_len, :] = key
        kv_cache[:, 1, :, past_len : past_len + seq_len, :] = value

        # Get full K/V from cache
        full_key = kv_cache[:, 0, :, :ctx_len, :]
        full_value = kv_cache[:, 1, :, :ctx_len, :]

        # Expand for GQA
        full_key = repeat_kv(full_key, self.num_key_value_groups)
        full_value = repeat_kv(full_value, self.num_key_value_groups)

        # Scaled dot-product attention.
        # ``is_causal=True`` aligns the causal mask to the top-left corner, which
        # is only correct when queries and keys cover the same positions. With
        # cached past tokens, query i sits at absolute position past_len + i and
        # must see keys 0..past_len + i, so an explicit offset mask is required.
        if seq_len > 1 and past_len > 0:
            attn_mask = torch.ones(
                seq_len, ctx_len, dtype=torch.bool, device=x.device
            ).tril(diagonal=past_len)
            is_causal = False
        else:
            attn_mask = None
            is_causal = seq_len > 1
        attn_out = F.scaled_dot_product_attention(
            query.contiguous(),
            full_key.contiguous(),
            full_value.contiguous(),
            attn_mask=attn_mask,
            dropout_p=0.0,
            is_causal=is_causal,
        )

        # Output projection: [B, Hq, S, D] -> [B, S, Hq * D]
        attn_out = (
            attn_out.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_len, self.num_q_heads * self.head_dim)
        )
        output = self.out(attn_out)

        return output, kv_cache
class PluginModel(nn.Module):
    """Attention block that delegates the attention core to the ``xqa_attn`` custom op.

    It has the same ``qkv`` / ``out`` linear layers as the SDPA reference model,
    so weights can be copied between the two.
    """

    def __init__(self):
        super().__init__()
        fused_qkv_width = HIDDEN_DIM + 2 * NUM_KV_HEADS * HEAD_DIM
        self.qkv = nn.Linear(HIDDEN_DIM, fused_qkv_width, bias=True)
        self.out = nn.Linear(HIDDEN_DIM, HIDDEN_DIM, bias=False)

    def forward(self, x, kv_cache, ctx_len_tensor, rope):
        """
        Project ``x`` to fused QKV, run the plugin op, and project the result.

        Args:
            x: [batch, seq_len, hidden_dim]
            kv_cache: KV cache tensor forwarded unchanged to the op
            ctx_len_tensor: context-length tensor forwarded unchanged to the op
            rope: RoPE table forwarded unchanged to the op

        Returns:
            Tuple of (projected attention output, second output of the op).
        """
        batch, num_tokens, _ = x.shape
        fused_qkv = self.qkv(x)

        # Per-batch start offset inside the KV cache; all zeros means every
        # sequence starts at the beginning of its cache.
        start_offsets = torch.zeros(batch, dtype=torch.int32, device=x.device)

        # The op takes five tensor inputs followed by the head configuration.
        op_outputs = torch.ops.tensorrt_edge_llm.xqa_attn.default(
            fused_qkv,
            kv_cache,
            ctx_len_tensor,
            rope,
            start_offsets,
            NUM_Q_HEADS,
            NUM_KV_HEADS,
            HEAD_DIM,
        )
        heads_out, kv_out = op_outputs

        # Collapse [B, S, num_heads, head_dim] into [B, S, hidden_dim]
        merged = heads_out.reshape(batch, num_tokens, HIDDEN_DIM)
        projected = self.out(merged)

        return projected, kv_out
+ """ + print(f"\n{name}") + + # Determine context length + past_len = 10 if has_past_context else 0 + ctx_len = torch.tensor([past_len + seq_len], dtype=torch.int32, device=DEVICE) + + # Initialize KV caches + sdpa_kv = torch.zeros( + BATCH_SIZE, + 2, + NUM_KV_HEADS, + KV_CACHE_CAPACITY, + HEAD_DIM, + dtype=DTYPE, + device=DEVICE, + ) + trt_kv = torch.zeros( + BATCH_SIZE, + 2, + NUM_KV_HEADS, + KV_CACHE_CAPACITY, + HEAD_DIM, + dtype=DTYPE, + device=DEVICE, + ) + + # Add past context if needed + if has_past_context: + past_values = torch.randn( + BATCH_SIZE, 2, NUM_KV_HEADS, past_len, HEAD_DIM, dtype=DTYPE, device=DEVICE + ) + sdpa_kv[:, :, :, :past_len, :] = past_values + trt_kv[:, :, :, :past_len, :] = past_values + print(f" Input: {seq_len} new tokens + {past_len} past tokens in cache") + else: + print(f" Input: {seq_len} tokens (empty KV cache)") + + # Generate input tokens + x = torch.randn(BATCH_SIZE, seq_len, HIDDEN_DIM, dtype=DTYPE, device=DEVICE) + + # Run both models + with torch.no_grad(): + sdpa_out, sdpa_kv_new = sdpa_model(x, sdpa_kv, ctx_len, rope) + trt_out, trt_kv_delta = trt_model(x, trt_kv, ctx_len, rope) + + # TRT plugin with enable_delta_kv_output=1 returns only delta KV + # Merge delta into main KV cache at the correct position + delta_seq_len = trt_kv_delta.shape[3] # Should be seq_len + trt_kv[:, :, :, past_len : past_len + delta_seq_len, :] = trt_kv_delta + + # Compute similarities + attn_sim = F.cosine_similarity( + sdpa_out.flatten().float(), trt_out.flatten().float(), dim=0 + ).item() + + # Compare the newly updated portion of KV cache (after merge) + new_kv_sim = F.cosine_similarity( + sdpa_kv_new[:, :, :, past_len : past_len + seq_len, :].flatten().float(), + trt_kv[:, :, :, past_len : past_len + seq_len, :].flatten().float(), + dim=0, + ).item() + + # Determine which kernel was used + kernel_type = "XQA (decode)" if seq_len == 1 else "FMHA (context)" + + # Print results + print(f" Kernel Used: {kernel_type}") + print(f" Attention 
if __name__ == "__main__":
    print("\nCustom Attention Plugin - Correctness Validation")

    # One pass/fail threshold for every similarity check below; ``>=`` matches
    # the criterion applied inside test_case().
    SIM_THRESHOLD = 0.99

    def _new_kv_cache():
        """Return a zero-filled [B, 2, Hkv, capacity, D] KV cache on DEVICE."""
        return torch.zeros(
            BATCH_SIZE,
            2,
            NUM_KV_HEADS,
            KV_CACHE_CAPACITY,
            HEAD_DIM,
            dtype=DTYPE,
            device=DEVICE,
        )

    def _cosine(a, b):
        """Cosine similarity of two tensors flattened to FP32 vectors."""
        return F.cosine_similarity(
            a.flatten().float(), b.flatten().float(), dim=0
        ).item()

    # Precompute RoPE
    rope = precompute_rope(HEAD_DIM, KV_CACHE_CAPACITY)

    # Create models
    print("\nCreating models...")
    sdpa_model = SDPAModel().to(DEVICE).to(DTYPE).eval()
    plugin_model = PluginModel().to(DEVICE).to(DTYPE).eval()

    # Share weights so both models compute the same function
    plugin_model.qkv.weight.data.copy_(sdpa_model.qkv.weight.data)
    plugin_model.qkv.bias.data.copy_(sdpa_model.qkv.bias.data)
    plugin_model.out.weight.data.copy_(sdpa_model.out.weight.data)
    print("Weights shared between models")

    # Compile with Torch-TensorRT (with dynamic shapes for seq_len)
    print("\nCompiling with Torch-TensorRT...")
    kv_example = _new_kv_cache()
    ctx_example = torch.tensor([1], dtype=torch.int32, device=DEVICE)

    # Only the seq_len dimension of the hidden-state input is dynamic
    inputs_spec = [
        torch_tensorrt.Input(
            min_shape=(BATCH_SIZE, 1, HIDDEN_DIM),
            opt_shape=(BATCH_SIZE, 8, HIDDEN_DIM),
            max_shape=(BATCH_SIZE, 32, HIDDEN_DIM),
            dtype=DTYPE,
        ),
        kv_example,
        ctx_example,
        rope,
    ]

    with torch_tensorrt.logging.errors():
        trt_model = torch_tensorrt.compile(
            plugin_model,
            inputs=inputs_spec,
            enabled_precisions={torch.float32},
            use_explicit_typing=True,
            use_fp32_acc=True,
            min_block_size=1,
            device=DEVICE,
        )
    print("Compilation complete")

    # %%
    # Run Test Cases
    # --------------
    # Three single-call cases: seq_len=1 with an empty cache, seq_len=1 with past
    # context, and seq_len>1 with an empty cache. (seq_len>1 with past context
    # is not exercised here.)

    print("\nRunning Test Cases")

    single_call_cases = [
        # (title, seq_len, has_past_context)
        ("Test 1: Single Token Generation (XQA) - Empty Cache", 1, False),
        ("Test 2: Single Token Generation (XQA) - With Past Context", 1, True),
        ("Test 3: Context Processing (FMHA) - Empty Cache", 16, False),
    ]

    results = []
    for case_title, case_seq_len, case_has_past in single_call_cases:
        results.append(
            test_case(
                case_title,
                seq_len=case_seq_len,
                has_past_context=case_has_past,
                sdpa_model=sdpa_model,
                trt_model=trt_model,
                rope=rope,
            )
        )

    # %%
    # Multi-Step Generation Test
    # ---------------------------
    # Realistic test: Process initial context (FMHA), then generate tokens one by one (XQA)
    # Note: With enable_delta_kv_output=1, we must merge delta KV into main buffer

    print("\nTest 4: Multi-Step Generation (FMHA -> XQA x 3)")
    print("Simulating real LLM generation:")
    print("  1. Process initial prompt with FMHA (seq_len=16)")
    print("  2. Generate tokens one by one with XQA (seq_len=1)")

    # Step 1: Process initial prompt (FMHA)
    initial_seq_len = 16
    x_init = torch.randn(
        BATCH_SIZE, initial_seq_len, HIDDEN_DIM, dtype=DTYPE, device=DEVICE
    )
    ctx_len_init = torch.tensor([initial_seq_len], dtype=torch.int32, device=DEVICE)

    sdpa_kv_multi = _new_kv_cache()
    trt_kv_multi = _new_kv_cache()

    with torch.no_grad():
        sdpa_out_init, sdpa_kv_multi = sdpa_model(
            x_init, sdpa_kv_multi, ctx_len_init, rope
        )
        trt_out_init, trt_kv_delta = trt_model(x_init, trt_kv_multi, ctx_len_init, rope)

    # Merge delta KV into main buffer (context phase: delta has shape [B, 2, H, seq_len, D])
    delta_len = trt_kv_delta.shape[3]
    trt_kv_multi[:, :, :, :delta_len, :] = trt_kv_delta

    init_sim = _cosine(sdpa_out_init, trt_out_init)

    print(f"\nStep 1: Initial prompt (FMHA, seq_len={initial_seq_len})")
    print(f"  Similarity: {init_sim:.6f}")

    # Step 2: Generate tokens one by one (XQA)
    num_gen_tokens = 3
    all_passed_multi = init_sim >= SIM_THRESHOLD
    # Measured similarities, kept so the summary reports real numbers
    attn_sims = [init_sim]
    kv_sims = []

    for gen_step in range(num_gen_tokens):
        current_pos = initial_seq_len + gen_step  # cache slot of the new token
        current_ctx_len = current_pos + 1
        x_gen = torch.randn(BATCH_SIZE, 1, HIDDEN_DIM, dtype=DTYPE, device=DEVICE)
        ctx_len_gen = torch.tensor([current_ctx_len], dtype=torch.int32, device=DEVICE)

        with torch.no_grad():
            sdpa_out_gen, sdpa_kv_multi = sdpa_model(
                x_gen, sdpa_kv_multi, ctx_len_gen, rope
            )
            trt_out_gen, trt_kv_delta = trt_model(
                x_gen, trt_kv_multi, ctx_len_gen, rope
            )

        # Merge delta KV into main buffer (generation phase: delta has shape [B, 2, H, 1, D])
        trt_kv_multi[:, :, :, current_pos:current_ctx_len, :] = trt_kv_delta

        gen_sim = _cosine(sdpa_out_gen, trt_out_gen)
        kv_sim_gen = _cosine(
            sdpa_kv_multi[:, :, :, :current_ctx_len, :],
            trt_kv_multi[:, :, :, :current_ctx_len, :],
        )
        attn_sims.append(gen_sim)
        kv_sims.append(kv_sim_gen)

        # Evaluate per step (not via min()) so a NaN similarity fails the test
        passed = gen_sim >= SIM_THRESHOLD and kv_sim_gen >= SIM_THRESHOLD
        all_passed_multi = all_passed_multi and passed

        print(f"\nStep {gen_step + 2}: Generate token {gen_step + 1} (XQA, seq_len=1)")
        print(f"  Attn similarity: {gen_sim:.6f}")
        print(f"  KV similarity: {kv_sim_gen:.6f}")

    # Report the worst measured similarities instead of placeholder 1.0 / 0.0
    results.append(
        (
            all_passed_multi,
            min(attn_sims),
            min(kv_sims, default=float("nan")),
        )
    )

    print(f"\nResult: {'PASS - All steps matched!' if all_passed_multi else 'FAIL'}")

    # %%
    # Summary
    # -------

    print("\nSUMMARY")

    test_names = [
        "Test 1: XQA - Empty Cache",
        "Test 2: XQA - With Past",
        "Test 3: FMHA - Empty Cache",
        "Test 4: Multi-Step (FMHA->XQA)",
    ]

    for name, (passed, attn_sim, kv_sim) in zip(test_names, results):
        status = "PASS" if passed else "FAIL"
        print(f"{name}: {status}")
        print(f"  Attention: {attn_sim:.4f}, KV Cache: {kv_sim:.4f}")

    all_passed = all(r[0] for r in results)

    if all_passed:
        print("SUCCESS: All tests passed!")
        print("Both FMHA and XQA kernels work correctly")
        print("KV cache management is accurate")
        print("Perfect agreement with PyTorch SDPA (cosine similarity >= 0.99)")
    else:
        print("FAILURE: Some tests failed")
def replace_attention(model, config):
    """
    Backward-compatible entry point for swapping in plugin attention.

    Forwards ``model`` and ``config`` to ``replace_attention_with_plugin``
    together with the module-level ``MAX_SEQ_LEN``, ``DEVICE`` and ``DTYPE``
    settings, and returns whatever that helper returns.
    """
    plugin_settings = (MAX_SEQ_LEN, DEVICE, DTYPE)
    return replace_attention_with_plugin(model, config, *plugin_settings)
def apply_repetition_penalty(logits, generated_ids, penalty):
    """
    Penalize the logits of tokens that were already generated.

    Logits selected by ``generated_ids`` (indices along dim 1) are multiplied
    by ``penalty`` when negative and divided by it otherwise. ``logits`` is
    modified in place and the same tensor is returned; a penalty of exactly
    1.0 leaves it untouched.
    """
    if penalty == 1.0:
        return logits

    seen_scores = logits.gather(1, generated_ids)
    penalized = torch.where(
        seen_scores < 0, seen_scores * penalty, seen_scores / penalty
    )
    # scatter_ writes in place and returns the tensor it was called on
    return logits.scatter_(1, generated_ids, penalized)
def _time_cuda_ms(workload):
    """Run ``workload()`` under ``torch.no_grad`` and return its CUDA-event time in ms."""
    torch.cuda.synchronize()
    started = torch.cuda.Event(enable_timing=True)
    finished = torch.cuda.Event(enable_timing=True)

    started.record()

    with torch.no_grad():
        workload()

    finished.record()
    # Wait for the recorded work to finish before reading the timer
    torch.cuda.synchronize()

    return started.elapsed_time(finished)


def run_pytorch_benchmark_manual(model, config, isl, osl):
    """Time ``osl`` greedy steps that re-run the full sequence each step (no KV cache).

    The prompt is ``isl`` random token ids drawn from ``config.vocab_size``.
    Prints a one-line report and returns the elapsed time in milliseconds.
    """
    prompt_ids = torch.randint(0, config.vocab_size, (1, isl), device=DEVICE)

    def greedy_loop():
        sequence = prompt_ids
        for _ in range(osl):
            step_logits = model(sequence, use_cache=False).logits
            next_token = torch.argmax(step_logits[:, -1, :], dim=-1).unsqueeze(0)
            sequence = torch.cat([sequence, next_token], dim=1)

    elapsed_ms = _time_cuda_ms(greedy_loop)
    print(
        f"PyTorch (Manual - No Cache) | ISL: {isl}, OSL: {osl} | Total Time: {elapsed_ms:.2f} ms | Tokens/sec: {osl / (elapsed_ms / 1000.0):.2f}"
    )
    return elapsed_ms


def run_pytorch_benchmark_generate(model, config, isl, osl):
    """Time one ``model.generate()`` call that produces exactly ``osl`` new tokens.

    The prompt is ``isl`` random token ids drawn from ``config.vocab_size``;
    decoding is greedy with the KV cache enabled. Prints a one-line report and
    returns the elapsed time in milliseconds.
    """
    prompt_ids = torch.randint(0, config.vocab_size, (1, isl), device=DEVICE)

    def generate_once():
        model.generate(
            prompt_ids,
            max_new_tokens=osl,
            min_new_tokens=osl,
            do_sample=False,
            use_cache=True,
            pad_token_id=config.eos_token_id,
        )

    elapsed_ms = _time_cuda_ms(generate_once)
    print(
        f"PyTorch (Generate) | ISL: {isl}, OSL: {osl} | Total Time: {elapsed_ms:.2f} ms | Tokens/sec: {osl / (elapsed_ms / 1000.0):.2f}"
    )
    return elapsed_ms
def verify_output(trt_model_func, model_pytorch, tokenizer, prompt, max_new_tokens=20):
    """
    Verify TensorRT output matches PyTorch reference.

    Generates text for ``prompt`` three ways and prints each result: the manual
    greedy reference loop (``generate_reference``), ``model_pytorch.generate``,
    and a prefill + decode loop driven by ``trt_model_func``. The reference-loop
    text and the TensorRT text are then compared for exact equality.

    Args:
        trt_model_func: Callable invoked as
            ``trt_model_func(input_ids, position_ids, kv_caches, ctx_len)`` and
            returning ``(logits, kv_cache_deltas)``.
        model_pytorch: PyTorch model used for both reference generations; its
            ``config`` and ``generation_config`` are also read.
        tokenizer: Tokenizer used to encode the prompt and decode outputs.
        prompt: Text prompt.
        max_new_tokens: Token budget applied to all three generations.
    """
    print(f"\nPrompt: '{prompt}'")

    # 1. PyTorch Reference Generation
    print("\n=== PyTorch Reference Generation ===")
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    input_ids = inputs.input_ids

    with torch.no_grad():
        # Use the caller's budget: a hard-coded value would make the reference
        # and TensorRT runs differ in length whenever max_new_tokens != 30.
        pyt_outputs = generate_reference(
            model_pytorch, tokenizer, prompt, max_new_tokens=max_new_tokens
        )
    print(f"PyTorch Reference Text Output: {pyt_outputs}")

    with torch.no_grad():
        pyt_outputs_generate_ids = model_pytorch.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    pyt_outputs_generate_text = tokenizer.decode(
        pyt_outputs_generate_ids[0], skip_special_tokens=True
    )
    print(f"PyTorch Generate Text Output: {pyt_outputs_generate_text}")

    pyt_text = pyt_outputs
    print(f"PyTorch Output: {pyt_text}")

    # 2. TensorRT Plugin Generation
    print("\n=== TensorRT Plugin Generation ===")

    repetition_penalty = getattr(
        model_pytorch.generation_config, "repetition_penalty", 1.0
    )
    print(
        f"DEBUG: Using repetition_penalty={repetition_penalty} for TensorRT Generation"
    )

    seq_len = input_ids.shape[1]
    position_ids = torch.arange(seq_len, dtype=torch.long, device=DEVICE).unsqueeze(0)

    config = model_pytorch.config
    kv_caches = create_kv_caches(config, MAX_SEQ_LEN, 1, DEVICE, DTYPE)

    generated_ids = input_ids
    eos_token_id = tokenizer.eos_token_id

    # Prefill: process the whole prompt in one call
    ctx_len = torch.tensor([seq_len], dtype=torch.int32, device=DEVICE)
    logits, kv_caches_delta = trt_model_func(
        input_ids, position_ids, kv_caches, ctx_len
    )

    # The model returns only the newly written KV entries; merge them into the
    # persistent caches at the start of the sequence.
    for layer_idx, delta in enumerate(kv_caches_delta):
        seq_len_out = delta.shape[3]
        kv_caches[layer_idx][:, :, :, :seq_len_out, :] = delta

    next_token_logits = logits[:, -1, :]
    next_token_logits = apply_repetition_penalty(
        next_token_logits, generated_ids, repetition_penalty
    )
    next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)

    # Decode: feed one token per step, starting right after the prompt
    cur_pos = seq_len

    # As in the reference loop, an EOS token ends generation without being
    # appended to the output.
    if next_token.item() != eos_token_id:
        generated_ids = torch.cat([generated_ids, next_token], dim=1)

        for _ in range(max_new_tokens - 1):
            input_ids_step = next_token
            position_ids_step = torch.tensor(
                [[cur_pos]], dtype=torch.long, device=DEVICE
            )
            ctx_len_step = torch.tensor([cur_pos + 1], dtype=torch.int32, device=DEVICE)

            logits, kv_caches_delta = trt_model_func(
                input_ids_step, position_ids_step, kv_caches, ctx_len_step
            )

            # Single-token delta goes into the slot of the token just processed
            for layer_idx, delta in enumerate(kv_caches_delta):
                kv_caches[layer_idx][:, :, :, cur_pos : cur_pos + 1, :] = delta

            next_token_logits = logits[:, -1, :]
            next_token_logits = apply_repetition_penalty(
                next_token_logits, generated_ids, repetition_penalty
            )
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)

            if next_token.item() == eos_token_id:
                break

            generated_ids = torch.cat([generated_ids, next_token], dim=1)
            cur_pos += 1

    trt_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    print(f"TensorRT Output: {trt_text}")

    # 3. Comparison
    print("\n=== Comparison ===")
    if pyt_text == trt_text:
        print("SUCCESS: Outputs match exactly!")
    else:
        print("FAILURE: Outputs differ.")
        print(f"PyTorch:  {pyt_text}")
        print(f"TensorRT: {trt_text}")
Benchmarks + benchmarks = [ + (128, 128), + (256, 128), + (512, 256), + ] + + print("\n=== Starting Benchmarks ===") + print(f"Device: {torch.cuda.get_device_name(0)}") + + for isl, osl in benchmarks: + print("-" * 60) + # PyTorch Manual Loop + run_pytorch_benchmark_manual(model_pytorch, config, isl, osl) + + # PyTorch Generate API + run_pytorch_benchmark_generate(model_pytorch, config, isl, osl) + + # TensorRT + benchmark_generation(trt_model_func, isl, osl, config, run_name="TensorRT") diff --git a/tools/llm/README.md b/tools/llm/README.md index cb921bcadb..55f0d960b9 100644 --- a/tools/llm/README.md +++ b/tools/llm/README.md @@ -7,10 +7,13 @@ This directory provides utilities and scripts for compiling, optimizing, and ben - **Model Support:** Works with popular LLMs such as Llama-3, Qwen2.5, etc. - **VLM Support:** Supports Visual Language Models like Qwen2.5-VL and Eagle2. - **Precision Modes:** Supports FP16, BF16, and FP32. +- **Multiple Backends:** + - **SDPA Backend** (`--backend sdpa`, default): Registers custom lowering pass for SDPA operations, converting attention to matmul+softmax+matmul with optional static KV cache support (`--cache static_v1`/`static_v2`) + - **IAttention Backend** (`--backend iattention`): Uses TensorRT's native `IAttention` layer for attention conversion. Single-pass inference only (KV cache not yet supported with this backend) + - **Plugin Backend** (`--backend plugin`): Uses TensorRT Edge-LLM attention plugin for optimized inference with built-in KV cache management +- **KV Cache:** Supports static KV cache for efficient autoregressive decoding (SDPA and Plugin backends). - **Quantization:** Supports FP8 and NVFP4 quantization formats for reduced memory usage and improved inference speed. -- **KV Cache:** Supports static and dynamic KV cache for efficient autoregressive decoding. - **Benchmarking:** Measures and compares throughput and latency for PyTorch and TensorRT backends. 
-- **Custom Attention:** Registers and converts custom scaled dot-product attention (SDPA) for compatibility with TensorRT. ### Supported Models @@ -38,29 +41,110 @@ We have officially verified support for the following models: #### Text-only LLMs: `run_llm.py` +**1. Generation with Output Verification** + +Compare PyTorch and TensorRT outputs to verify correctness: + +*SDPA Backend:* +```bash +python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --backend sdpa \ + --prompt "What is parallel programming?" --model_precision FP16 --num_tokens 30 --enable_pytorch_run +``` +
+Expected Output + +``` +========= PyTorch ========= +PyTorch model generated text: What is parallel programming? Parallel programming is a technique used to improve the performance of a program by dividing the work into smaller tasks and executing them simultaneously on multiple processors or cores. +=================================== +========= TensorRT ========= +TensorRT model generated text: What is parallel programming? Parallel programming is a technique used to improve the performance of a program by dividing the work into smaller tasks and executing them simultaneously on multiple processors or cores. +=================================== +PyTorch and TensorRT outputs match: True +``` +
+ +*IAttention Backend (native TRT IAttention layer, no KV cache):* +```bash +python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --backend iattention \ + --prompt "What is parallel programming?" --model_precision FP16 --num_tokens 30 --enable_pytorch_run +``` + +*Plugin Backend:* +```bash +python run_llm.py --model Qwen/Qwen2.5-0.5B-Instruct --backend plugin \ + --prompt "What is parallel programming?" --model_precision FP16 --num_tokens 30 --enable_pytorch_run +``` +
+Expected Output + +``` +========= PyTorch ========= +PyTorch model generated text: What is parallel programming? What are the benefits of parallel programming? What are the challenges of parallel programming? What are the different types of parallel programming? What are the advantages of +=================================== +========= TensorRT ========= +TensorRT model generated text: What is parallel programming? What are the benefits of parallel programming? What are the challenges of parallel programming? What are the different types of parallel programming? What are the advantages of +=================================== +PyTorch and TensorRT outputs match: True +``` +
+ +**2. Benchmarking for Performance Comparison** + +*Plugin Backend (compares TensorRT-Plugin vs PyTorch):* ```bash -python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --model_precision FP16 --num_tokens 128 --cache static_v2 --benchmark +python run_llm.py --model Qwen/Qwen2.5-0.5B-Instruct --backend plugin --model_precision FP16 \ + --benchmark --iterations 5 --isl 128 --num_tokens 20 --batch_size 1 --enable_pytorch_run ``` +*SDPA with Static Cache (compares TensorRT-SDPA-StaticCache vs PyTorch):* +```bash +python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --backend sdpa --cache static_v1 \ + --model_precision FP16 --benchmark --iterations 3 --isl 128 --num_tokens 128 --batch_size 1 --enable_pytorch_run +``` + +> **Note**: In benchmark mode, `--prompt` is not used. Random input tokens are generated based on `--isl` (input sequence length). + #### Vision Language Models: `run_vlm.py` +*Generation with Output Verification:* +```bash +python run_vlm.py --model nvidia/Eagle2-2B --precision FP16 --num_tokens 64 --cache static_v1 --enable_pytorch_run +``` + +*Benchmarking:* ```bash -python run_vlm.py --model nvidia/Eagle2-2B --precision FP16 --num_tokens 128 --cache static_v1 --enable_pytorch_run --benchmark +python run_vlm.py --model nvidia/Eagle2-2B --precision FP16 --cache static_v1 --benchmark --iterations 5 --num_tokens 128 ``` #### Key Arguments +**Model Configuration:** - `--model`: Name or path of the HuggingFace LLM/VLM. - `--tokenizer`: (Optional) Tokenizer name; defaults to model. -- `--prompt`: Input prompt for generation. +- `--backend`: Backend to use (`sdpa`, `iattention`, or `plugin`). Default is `sdpa`. Only applicable for LLM models. + - `sdpa`: Custom SDPA lowering pass + converter. Supports `--cache static_v1`/`static_v2` for KV caching. + - `iattention`: TensorRT native IAttention layer. No KV cache support yet (single-pass inference only). + - `plugin`: TensorRT Edge-LLM attention plugin. 
KV cache managed internally by the plugin. + +**Generation Settings:** +- `--prompt`: Input prompt for generation (generation mode only, ignored in benchmark mode). - `--image_path`: (Optional) Path to input image file for VLM models. If not provided, will use a sample image. - `--model_precision`: Precision of model weight/buffer (`FP16`, `BF16`, `FP32`). - `--quant_format`: (Optional) Quantization format (`int8`, `fp8`, `nvfp4`) to apply. - `--quant_algo`: (Optional) Quantization algorithm (`max`, `smoothquant`), by default it is `max`. - `--weight_only`: (Optional) weight only quantization flag, by default it False. - `--num_tokens`: Number of output tokens to generate. -- `--cache`: KV cache type (`static_v1`, `static_v2`, or empty for no KV caching). -- `--benchmark`: Enable benchmarking mode. + +**Cache and Optimization:** +- `--cache`: KV cache type for SDPA backend (`static_v1`, `static_v2`, or empty for no KV caching). + - Note: Not applicable for plugin backend (manages cache internally). + +**Benchmarking:** +- `--benchmark`: Enable benchmarking mode (uses random inputs instead of prompt). +- `--iterations`: Number of benchmark iterations. Default is 5. +- `--isl`: Input sequence length for benchmarking. Default is 2048. +- `--batch_size`: Batch size for benchmarking. Default is 1. - `--enable_pytorch_run`: Also run and compare PyTorch baseline. ### Quantization @@ -104,11 +188,74 @@ python run_llm.py --model meta-llama/Llama-3.1-8B --quant_format fp8 --prompt "W ### Caching Strategies +#### SDPA Backend (`--backend sdpa`) - **Static Cache v1/v2:** Adds static KV cache tensors as model inputs/outputs for efficient reuse. - **No Cache:** Standard autoregressive decoding. Please read our tutorial on how static cache is implemented. +#### IAttention Backend (`--backend iattention`) +The IAttention backend uses TensorRT's native `IAttention` layer for attention conversion. 
KV cache is **not yet supported** with this backend because the static cache FX passes (`static_cache_v1`/`v2`) look for `torch.nn.functional.scaled_dot_product_attention` nodes, which are only present after the SDPA lowering pass (used by the `sdpa` backend). The `--cache` option will be ignored if specified with `--backend iattention`. + +#### Plugin Backend (`--backend plugin`) +The plugin backend uses the TensorRT Edge-LLM AttentionPlugin which manages KV cache internally. The `--cache` option is not applicable and will be ignored if specified with `--backend plugin`. + +## Plugin Backend Setup + +To use the plugin backend (`--backend plugin`), you need to build the TensorRT Edge-LLM AttentionPlugin library. + +> **Note**: This implementation has been verified with TensorRT-Edge-LLM release 0.4.0. + +### Building the AttentionPlugin + +Currently, the plugin support requires a custom build from a feature branch: + +```bash +# Clone the repository with the torch-tensorrt-python-runtime feature +git clone -b feature/torch-tensorrt-python-runtime https://github.com/chohk88/TensorRT-Edge-LLM.git +cd TensorRT-Edge-LLM + +# Initialize submodules (required for nlohmann/json and googletest) +git submodule update --init --recursive + +# Build the plugin library +mkdir build && cd build + +# Configure with CMake (adjust paths based on your environment) +cmake .. -DTRT_PACKAGE_DIR=/usr -DCUDA_VERSION=12.9 + +# Build +make -j$(nproc) + +# The plugin library will be at: build/libNvInfer_edgellm_plugin.so +``` + +> **Note**: CMake configuration may vary depending on your system setup. Common options include: +> - `-DTRT_PACKAGE_DIR`: TensorRT installation directory (e.g., `/usr`, `/usr/local`) +> - `-DCUDA_VERSION`: CUDA version (e.g., `12.9`, `12.6`) +> +> Refer to the [TensorRT-Edge-LLM build documentation](https://github.com/chohk88/TensorRT-Edge-LLM/tree/feature/torch-tensorrt-python-runtime) for complete build instructions and dependencies. 
+ +After building, the plugin path defaults to `<repo_root>/TensorRT-Edge-LLM/build/libNvInfer_edgellm_plugin.so`, where `<repo_root>` is the directory two levels above `tools/llm/` (i.e. the `TensorRT-Edge-LLM` checkout is expected at the root of this repository). You can override this by passing `plugin_path` to `load_plugin()` or by updating `DEFAULT_PLUGIN_PATH` in `plugin_utils.py`. + +### Performance + +In our internal testing on NVIDIA A100 (FP16), the backends show roughly the following speedup over PyTorch eager inference: + +- **SDPA backend (no cache):** ~1.3–1.7x faster than PyTorch +- **SDPA backend (static_v1 cache):** ~4–5x faster than PyTorch +- **Plugin backend:** ~11–15x faster than PyTorch, ~3x faster than SDPA with static cache +- **IAttention backend (no cache):** Currently slower than PyTorch for autoregressive generation since KV cache is not yet supported + +> Exact speedup depends on model size, sequence length, and hardware. The plugin backend achieves the highest throughput thanks to its fused attention+RoPE+KV-cache kernel. All backends produce outputs that match PyTorch for the tested models. + +### Additional Examples + +Two comprehensive examples are provided in `examples/dynamo/` to demonstrate plugin usage: + +- **`attention_plugin_example.py`**: Standalone example showing how to use the AttentionPlugin with custom models +- **`end_to_end_llm_generation_example.py`**: End-to-end LLM generation example with plugin integration + ## Extension This codebase can be extended to @@ -129,4 +276,4 @@ This codebase can be extended to - `pip install qwen-vl-utils` (for Qwen2.5-VL-3B-Instruct model) - **Flash Attention**: For models using flash attention operations (e.g., Eagle2-2B), install one of the following: - **Fast installation (recommended)**: `pip install flash-attn==2.8.1` (pre-built wheel, should work) - - **Source build (slow)**: `pip install flash-attn --no-build-isolation -v` (fallback if pre-built wheels fail) \ No newline at end of file + - **Source build (slow)**: `MAX_JOBS=8 pip install flash-attn --no-build-isolation -v` (fallback if pre-built wheels fail) diff --git a/tools/llm/plugin_converter.py b/tools/llm/plugin_converter.py 
# new file mode 100644
# index 0000000000..49b58f5891
# --- /dev/null
# +++ b/tools/llm/plugin_converter.py
# @@ -0,0 +1,97 @@
"""
TensorRT converter for Edge-LLM attention plugin ops.

This module contains the TensorRT converter for the tensorrt_edge_llm::xqa_attn
custom op. It is kept in a separate file from plugin_utils.py for maintainability.
"""

import numpy as np
import tensorrt as trt
from plugin_utils import get_plugin_config, register_plugin_op
from torch_tensorrt.dynamo.conversion import (
    ConversionContext,
    dynamo_tensorrt_converter,
)
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor

# Ensure the custom op is registered before the converter decorator runs
register_plugin_op()

import torch  # noqa: E402 (must be after register_plugin_op so the op exists)


def _squeeze_trailing_unit_dim(ctx, tensor):
    """Reshape a rank-2 ``[B, 1]`` tensor to ``[B]``; return any other tensor as is."""
    if len(tensor.shape) == 2 and tensor.shape[1] == 1:
        shuffle_layer = ctx.net.add_shuffle(tensor)
        shuffle_layer.reshape_dims = (tensor.shape[0],)
        return shuffle_layer.get_output(0)
    return tensor


@dynamo_tensorrt_converter(
    torch.ops.tensorrt_edge_llm.xqa_attn.default, supports_dynamic_shapes=True
)
def convert_attn(ctx: ConversionContext, target, args, kwargs, name):
    """
    Convert tensorrt_edge_llm::xqa_attn op to TensorRT AttentionPlugin.

    TensorRT-Edge-LLM (0.4.0) plugin requires 5 inputs:
    - qkv, kv, ctx_len, rope, kv_cache_start_idx

    Plugin fields:
    - num_q_heads, num_kv_heads, head_size, enable_tree_attention, enable_delta_kv_output

    Args:
        ctx: Conversion context; the plugin layer is added to ``ctx.net``.
        target: The op being converted (not used here).
        args: Positional op arguments; the first eight are
            (qkv, kv, ctx_len, rope, kv_cache_start_idx, nq, nkv, d).
        kwargs: Keyword op arguments (not used here).
        name: Layer name; also used as prefix for the names of converted inputs.

    Returns:
        Tuple with the first two outputs of the plugin layer.

    Raises:
        RuntimeError: If the AttentionPlugin creator is not registered, or if the
            creator does not return a plugin for the given fields.
    """
    # args: qkv, kv, ctx_len, rope, kv_cache_start_idx, nq, nkv, d
    qkv, kv, ctx_len, rope, kv_cache_start_idx, nq, nkv, d = args[:8]

    creator = trt.get_plugin_registry().get_plugin_creator("AttentionPlugin", "1", "")
    if creator is None:
        raise RuntimeError("AttentionPlugin not found in TensorRT plugin registry!")

    # The global plugin config (set via set_plugin_config*) takes precedence over
    # the per-node values carried in args.
    config = get_plugin_config()
    if config:
        nq_val = config["num_attention_heads"]
        nkv_val = config["num_key_value_heads"]
        d_val = config["head_dim"]
    else:
        # Fallback to values from args (may not work correctly).
        # NOTE(review): 14 / 2 / 64 are hard-coded last-resort defaults, presumably
        # the geometry of one specific small model - confirm before relying on them.
        nq_val = nq if isinstance(nq, int) else 14
        nkv_val = nkv if isinstance(nkv, int) else 2
        d_val = d if isinstance(d, int) else 64

    # Plugin fields for TensorRT-Edge-LLM AttentionPlugin
    # Required: num_q_heads, num_kv_heads, head_size, enable_tree_attention
    # enable_delta_kv_output=1 enables delta KV output for Python/torch_tensorrt compatibility
    field_list = [
        trt.PluginField(
            field_name, np.array([field_val], dtype=np.int32), trt.PluginFieldType.INT32
        )
        for field_name, field_val in [
            ("num_q_heads", nq_val),
            ("num_kv_heads", nkv_val),
            ("head_size", d_val),
            ("enable_tree_attention", 0),
            ("enable_delta_kv_output", 1),
        ]
    ]

    fields = trt.PluginFieldCollection(field_list)
    plugin = creator.create_plugin(name, fields)
    if plugin is None:
        # Fail here with the offending field values instead of handing None to
        # add_plugin_v2 below.
        raise RuntimeError(
            f"AttentionPlugin creation failed for '{name}' "
            f"(num_q_heads={nq_val}, num_kv_heads={nkv_val}, head_size={d_val})"
        )

    # 5 inputs for release version: qkv, kv, ctx_len, rope, kv_cache_start_idx
    inputs = [
        (
            get_trt_tensor(ctx, i, f"{name}_i{idx}")
            if not isinstance(i, trt.ITensor)
            else i
        )
        for idx, i in enumerate([qkv, kv, ctx_len, rope, kv_cache_start_idx])
    ]

    # Handle ctx_len (index 2) and kv_cache_start_idx (index 4) shapes if needed
    # (squeeze if [B, 1] -> [B])
    inputs[2] = _squeeze_trailing_unit_dim(ctx, inputs[2])
    inputs[4] = _squeeze_trailing_unit_dim(ctx, inputs[4])

    layer = ctx.net.add_plugin_v2(inputs, plugin)
    return layer.get_output(0), layer.get_output(1)


# diff --git a/tools/llm/plugin_utils.py b/tools/llm/plugin_utils.py
# new file mode 100644
# index 0000000000..bcc547dd39
# --- /dev/null
# +++ b/tools/llm/plugin_utils.py
# @@ -0,0 +1,868 @@
"""
Plugin utilities for TensorRT LLM inference with custom attention plugins.

This module provides model-agnostic utilities for using TensorRT attention plugins
with various LLM architectures (Qwen, Llama, etc.).
"""

import ctypes
import inspect
import os
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import numpy as np
import tensorrt as trt
import torch
import torch.nn as nn
import torch_tensorrt

# Default plugin path - can be overridden
# Built from: https://github.com/chohk88/TensorRT-Edge-LLM/tree/feature/torch-tensorrt-python-runtime
DEFAULT_PLUGIN_PATH = os.path.join(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
    "TensorRT-Edge-LLM",
    "build",
    "libNvInfer_edgellm_plugin.so",
)

# Global configuration for plugin converter
_PLUGIN_CONFIG: Dict[str, Any] = {}


def load_plugin(plugin_path: Optional[str] = None) -> bool:
    """
    Load the TensorRT attention plugin library.

    Args:
        plugin_path: Path to the plugin .so file. If None (or empty), uses
            DEFAULT_PLUGIN_PATH.

    Returns:
        True once the library has been loaded. Failures are reported by raising;
        this function never returns False.

    Raises:
        RuntimeError: If plugin file does not exist.
        OSError: Propagated from ``ctypes.CDLL`` when the file exists but cannot
            be loaded as a shared library.
    """
    path = plugin_path or DEFAULT_PLUGIN_PATH
    if not os.path.exists(path):
        raise RuntimeError(f"Plugin not found at {path}")
    ctypes.CDLL(path)
    return True


def set_plugin_config(
    num_attention_heads: int,
    num_key_value_heads: int,
    head_dim: int,
    max_seq_len: int = 2048,
    max_batch_size: int = 4,
) -> None:
    """
    Set global configuration for the plugin converter.

    Replaces any previously stored configuration as a whole.

    Args:
        num_attention_heads: Number of query attention heads.
        num_key_value_heads: Number of key/value attention heads (for GQA).
        head_dim: Dimension of each attention head.
        max_seq_len: Maximum sequence length for KV cache.
        max_batch_size: Maximum batch size.
    """
    global _PLUGIN_CONFIG
    _PLUGIN_CONFIG = {
        "num_attention_heads": num_attention_heads,
        "num_key_value_heads": num_key_value_heads,
        "head_dim": head_dim,
        "max_seq_len": max_seq_len,
        "max_batch_size": max_batch_size,
    }


def get_plugin_config() -> Dict[str, Any]:
    """Get a shallow copy of the current plugin configuration (empty if never set)."""
    return _PLUGIN_CONFIG.copy()


def set_plugin_config_from_model(model_config: Any, max_seq_len: int = 2048) -> None:
    """
    Set plugin configuration from a HuggingFace model config.

    Args:
        model_config: HuggingFace model configuration object. Must provide
            ``num_attention_heads`` and either ``head_dim`` or ``hidden_size``.
        max_seq_len: Maximum sequence length for KV cache.
    """
    # Qwen3 has explicit head_dim in config that differs from hidden_size // num_attention_heads
    head_dim = getattr(model_config, "head_dim", None)
    if head_dim is None:
        head_dim = model_config.hidden_size // model_config.num_attention_heads

    # Configs without a separate KV-head count describe plain multi-head attention,
    # where every query head has its own key/value head.
    num_key_value_heads = getattr(model_config, "num_key_value_heads", None)
    if num_key_value_heads is None:
        num_key_value_heads = model_config.num_attention_heads

    set_plugin_config(
        num_attention_heads=model_config.num_attention_heads,
        num_key_value_heads=num_key_value_heads,
        head_dim=head_dim,
        max_seq_len=max_seq_len,
    )


# -----------------------------------------------------------------------------
# Plugin Op Registration
# -----------------------------------------------------------------------------


def _register_plugin_op_impl() -> None:
    """
    Internal implementation to register the tensorrt_edge_llm::xqa_attn custom op for PyTorch.

    The TensorRT-Edge-LLM plugin (0.4.0-based) requires 5 inputs:
    - qkv: [B, S, (Hq+Hk+Hv)*D] fused QKV tensor
    - kv: [B, 2, Hkv, Capacity, D] KV cache tensor
    - ctx_len: [B] context length per batch
    - rope: [1, MaxSeqLen, RotaryDim] rotary position encoding
    - kv_cache_start_idx: [B] starting index in KV cache

    With enable_delta_kv_output=1, output KV shape is [B, 2, H, SeqLen, D] (delta only).

    Both the eager implementation and the fake (meta) implementation only allocate
    outputs of the right shape; neither computes attention in PyTorch.
    """

    @torch.library.custom_op("tensorrt_edge_llm::xqa_attn", mutates_args=())
    def attn(
        qkv: torch.Tensor,
        kv: torch.Tensor,
        ctx_len: torch.Tensor,
        rope: torch.Tensor,
        kv_cache_start_idx: torch.Tensor,
        nq: int,
        nkv: int,
        d: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = qkv.shape[0]
        seq_len = qkv.shape[1]
        # Zero-filled placeholder for the attention output: [B, S, nq, d]
        attn_out = torch.zeros(
            batch_size, seq_len, nq, d, dtype=qkv.dtype, device=qkv.device
        )
        # Delta KV output: shape is [B, 2, H, SeqLen, D]
        updated_kv = torch.zeros(
            batch_size, 2, nkv, seq_len, d, dtype=qkv.dtype, device=qkv.device
        )
        return attn_out, updated_kv

    @torch.library.register_fake("tensorrt_edge_llm::xqa_attn")
    def _(qkv, kv, ctx_len, rope, kv_cache_start_idx, nq, nkv, d):
        batch_size = qkv.shape[0]
        seq_len = qkv.shape[1]
        attn_out = torch.empty(
            batch_size, seq_len, nq, d, dtype=qkv.dtype, device=qkv.device
        )
        # Delta KV output
        updated_kv = torch.empty(
            batch_size, 2, nkv, seq_len, d, dtype=qkv.dtype, device=qkv.device
        )
        return attn_out, updated_kv


def register_plugin_op() -> None:
    """
    Register the tensorrt_edge_llm::xqa_attn custom op for PyTorch.

    This function is idempotent - safe to call multiple times.
    """
    if hasattr(torch.ops, "tensorrt_edge_llm") and hasattr(
        torch.ops.tensorrt_edge_llm, "xqa_attn"
    ):
        return
    _register_plugin_op_impl()


# Register the op at module import time so the converter decorator works
# This is safe because the op registration is idempotent
register_plugin_op()

# The converter for tensorrt_edge_llm::xqa_attn is defined in plugin_converter.py.
# Import it here so that importing plugin_utils still registers the converter.
+from plugin_converter import convert_attn # noqa: F401 + +# ----------------------------------------------------------------------------- +# RoPE Cache Generation +# ----------------------------------------------------------------------------- + + +def get_plugin_rope_cache( + rotary_emb: nn.Module, + max_seq_len: int, + head_dim: int, + device: torch.device, +) -> torch.Tensor: + """ + Generate RoPE cache tensor for the plugin from a rotary embedding module. + + Args: + rotary_emb: The rotary embedding module from the model. + max_seq_len: Maximum sequence length. + head_dim: Dimension of each attention head. + device: Device to create the cache on. + + Returns: + RoPE cache tensor of shape [1, max_seq_len, head_dim]. + """ + inv_freq = rotary_emb.inv_freq.to(device).float() + attention_scaling = getattr(rotary_emb, "attention_scaling", 1.0) + t = torch.arange(max_seq_len, device=device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + cos_half = freqs.cos() * attention_scaling + sin_half = freqs.sin() * attention_scaling + rope = torch.cat([cos_half, sin_half], dim=-1) + return rope.unsqueeze(0) + + +# ----------------------------------------------------------------------------- +# Plugin Attention Module +# ----------------------------------------------------------------------------- + + +class PluginAttention(nn.Module): + """ + Model-agnostic Plugin Attention module that replaces standard attention. + + This module wraps the projection layers from the original attention module + and uses the tensorrt_edge_llm::xqa_attn plugin op for the attention computation. + + Supports: + - Qwen2.5, Llama: Standard attention + - Qwen3: Attention with QK Normalization (q_norm, k_norm) + """ + + def __init__( + self, + original_attn: nn.Module, + config: Any, + layer_idx: int, + rope_cache: torch.Tensor, + ): + """ + Initialize PluginAttention. + + Args: + original_attn: The original attention module to wrap. + config: Model configuration. 
+ layer_idx: Index of this layer in the model. + rope_cache: Pre-computed RoPE cache tensor. + """ + super().__init__() + self.q_proj = original_attn.q_proj + self.k_proj = original_attn.k_proj + self.v_proj = original_attn.v_proj + self.o_proj = original_attn.o_proj + + # Qwen3 has QK Normalization + self.q_norm = getattr(original_attn, "q_norm", None) + self.k_norm = getattr(original_attn, "k_norm", None) + + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + + # Qwen3 has explicit head_dim that may differ from hidden_size // num_attention_heads + if hasattr(config, "head_dim") and config.head_dim is not None: + self.head_dim = config.head_dim + else: + self.head_dim = config.hidden_size // config.num_attention_heads + + # For Qwen3, attention output size is num_heads * head_dim, not hidden_size + self.attn_hidden_size = self.num_heads * self.head_dim + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.register_buffer("rope_cache", rope_cache) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[torch.Tensor] = None, + ctx_len: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass using the plugin attention. + + Args: + hidden_states: Input tensor of shape [batch, seq_len, hidden_size]. + attention_mask: Unused (plugin handles masking internally). + position_ids: Position IDs (unused, plugin uses RoPE cache). + past_key_value: KV cache tensor of shape [batch, 2, num_kv_heads, capacity, head_dim]. + ctx_len: Context length tensor for each batch item. + + Returns: + Tuple of (output tensor, updated KV cache). 
+ """ + batch_size, seq_len, _ = hidden_states.shape + + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # Qwen3: Apply QK Normalization if available + if self.q_norm is not None: + # Reshape for per-head normalization: [B, S, num_heads, head_dim] + q = q.view(batch_size, seq_len, self.num_heads, self.head_dim) + q = self.q_norm(q) + q = q.view(batch_size, seq_len, -1) + + if self.k_norm is not None: + # Reshape for per-head normalization: [B, S, num_kv_heads, head_dim] + k = k.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim) + k = self.k_norm(k) + k = k.view(batch_size, seq_len, -1) + + qkv = torch.cat([q, k, v], dim=-1) + + if ctx_len is None: + ctx_len = torch.tensor( + [seq_len], dtype=torch.int32, device=hidden_states.device + ).expand(batch_size) + + rope_fp32 = self.rope_cache.float() + + if past_key_value is None: + raise ValueError("past_key_value (KV cache tensor) must be provided") + + # kv_cache_start_idx: starting position in KV cache for each batch + # For normal inference, this is 0 (start from beginning) + kv_cache_start_idx = torch.zeros( + batch_size, dtype=torch.int32, device=hidden_states.device + ) + + attn_out, updated_kv = torch.ops.tensorrt_edge_llm.xqa_attn.default( + qkv, + past_key_value, + ctx_len, + rope_fp32, + kv_cache_start_idx, + self.num_heads, + self.num_key_value_heads, + self.head_dim, + ) + + # Use attn_hidden_size for reshape (may differ from hidden_size in Qwen3) + attn_out = attn_out.reshape(batch_size, seq_len, self.attn_hidden_size) + output = self.o_proj(attn_out) + return output, updated_kv + + +# ----------------------------------------------------------------------------- +# Model Wrappers +# ----------------------------------------------------------------------------- + + +class LLMPluginWrapper(nn.Module): + """ + Generic wrapper for LLM models with plugin attention. 
+ + This wrapper handles the forward pass for models with replaced attention modules, + managing KV caches and context lengths appropriately. + """ + + def __init__(self, model: nn.Module, model_type: str = "auto"): + """ + Initialize the wrapper. + + Args: + model: The model with replaced attention modules. + model_type: Type of model ("qwen", "llama", or "auto" for auto-detection). + """ + super().__init__() + self.model = model + self.model_type = ( + self._detect_model_type(model) if model_type == "auto" else model_type + ) + + def _detect_model_type(self, model: nn.Module) -> str: + """Auto-detect model type from model structure.""" + model_class = model.__class__.__name__.lower() + if "qwen" in model_class: + return "qwen" + elif "llama" in model_class or "mistral" in model_class: + return "llama" + else: + # Default to generic transformer structure + return "generic" + + def _get_transformer(self) -> nn.Module: + """Get the transformer backbone based on model type.""" + if self.model_type == "qwen": + return self.model.model + elif self.model_type == "llama": + return self.model.model + else: + # Try common attribute names + for attr in ["model", "transformer", "backbone"]: + if hasattr(self.model, attr): + return getattr(self.model, attr) + raise ValueError( + f"Cannot find transformer backbone for model type: {self.model_type}" + ) + + def _get_layers(self, transformer: nn.Module) -> nn.ModuleList: + """Get the list of transformer layers.""" + for attr in ["layers", "h", "blocks"]: + if hasattr(transformer, attr): + return getattr(transformer, attr) + raise ValueError("Cannot find transformer layers") + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + ctx_len: torch.Tensor, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Forward pass with plugin attention. + + Args: + input_ids: Input token IDs [batch, seq_len]. + position_ids: Position IDs [batch, seq_len]. 
def replace_attention_with_plugin(
    model: nn.Module,
    config: Any,
    max_seq_len: int,
    device: torch.device,
    dtype: torch.dtype = torch.float16,
) -> nn.Module:
    """
    Swap every layer's ``self_attn`` for a ``PluginAttention`` wrapper.

    The model is modified in place and also returned.

    Args:
        model: The model to modify; its ``.model`` attribute is used as the
            backbone when present, otherwise the model itself.
        config: Model configuration (``head_dim`` or ``hidden_size`` /
            ``num_attention_heads`` are read from it).
        max_seq_len: Maximum sequence length passed to the RoPE cache builder.
        device: Device passed to the RoPE cache builder.
        dtype: Accepted for interface compatibility; not used by this function.

    Returns:
        The same ``model`` object with plugin attention installed.

    Raises:
        ValueError: If no rotary embedding or no layer container is found.
    """
    backbone = getattr(model, "model", model)

    # Locate the rotary embedding: on the backbone itself, or else on the
    # first layer's attention module.
    if hasattr(backbone, "rotary_emb"):
        rope_module = backbone.rotary_emb
    else:
        rope_module = None
        if hasattr(backbone, "layers") and len(backbone.layers) > 0:
            first_attn = getattr(backbone.layers[0], "self_attn", None)
            if hasattr(first_attn, "rotary_emb"):
                rope_module = first_attn.rotary_emb
    if rope_module is None:
        raise ValueError("Cannot find rotary embedding in model")

    # An explicit config.head_dim (e.g. Qwen3) wins over the derived value.
    explicit_head_dim = getattr(config, "head_dim", None)
    if explicit_head_dim is not None:
        head_dim = explicit_head_dim
    else:
        head_dim = config.hidden_size // config.num_attention_heads
    rope_cache = get_plugin_rope_cache(rope_module, max_seq_len, head_dim, device)

    # Find the layer container under one of its known names.
    for container_name in ("layers", "h"):
        if hasattr(backbone, container_name):
            decoder_layers = getattr(backbone, container_name)
            break
    else:
        raise ValueError("Cannot find transformer layers")

    # Wrap each original attention module, passing its layer index along.
    for layer_idx, decoder_layer in enumerate(decoder_layers):
        decoder_layer.self_attn = PluginAttention(
            decoder_layer.self_attn, config, layer_idx, rope_cache
        )

    return model
+ + Args: + model: The wrapped model (should be LLMPluginWrapper or similar). + config: Model configuration. + max_seq_len: Maximum sequence length. + device: Device for compilation. + dtype: Data type. + debug: Whether to enable debug logging. + + Returns: + Compiled TensorRT model function. + """ + # Prepare dummy inputs + num_layers = config.num_hidden_layers + num_kv_heads = config.num_key_value_heads + # Qwen3 has explicit head_dim that may differ from hidden_size // num_attention_heads + if hasattr(config, "head_dim") and config.head_dim is not None: + head_dim = config.head_dim + else: + head_dim = config.hidden_size // config.num_attention_heads + + dummy_input_ids = torch.tensor([[1, 2, 3]], device=device) + dummy_pos_ids = torch.tensor([[0, 1, 2]], device=device) + dummy_ctx_len = torch.tensor([3], dtype=torch.int32, device=device) + dummy_kvs = [ + torch.zeros( + 1, 2, num_kv_heads, max_seq_len, head_dim, dtype=dtype, device=device + ) + for _ in range(num_layers) + ] + + # Dynamic shapes + seq_len_dim = torch.export.Dim("seq_len", min=1, max=max_seq_len) + kv_cache_dynamics = [{}] * num_layers + dynamic_shapes = { + "input_ids": {1: seq_len_dim}, + "position_ids": {1: seq_len_dim}, + "kv_caches": kv_cache_dynamics, + "ctx_len": {}, + } + + # Export + ep = torch.export.export( + model, + args=(dummy_input_ids, dummy_pos_ids, dummy_kvs, dummy_ctx_len), + dynamic_shapes=dynamic_shapes, + strict=False, + ) + + # Compile + with torch_tensorrt.dynamo.Debugger() if debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[dummy_input_ids, dummy_pos_ids, dummy_kvs, dummy_ctx_len], + use_explicit_typing=True, + use_fp32_acc=True, + device=device, + disable_tf32=True, + min_block_size=1, + ) + + return trt_model + + +# ----------------------------------------------------------------------------- +# KV Cache Utilities +# ----------------------------------------------------------------------------- + + +def create_kv_caches( + config: 
def _split_model_output(output):
    """Normalize one model call result into ``(logits, kv_deltas)``.

    Accepted layouts:
      * a bare logits tensor              -> ``(logits, [])``
      * ``(logits, [kv_0, kv_1, ...])``   -> nested per-layer list
      * ``(logits, kv_0, kv_1, ...)``     -> flattened per-layer tensors,
        including the single-layer case ``(logits, kv_0)``
    """
    if not isinstance(output, (tuple, list)):
        return output, []
    logits, rest = output[0], output[1:]
    if len(rest) == 1:
        only = rest[0]
        if only is None:
            return logits, []
        if not isinstance(only, torch.Tensor):
            # Nested form: the second element is itself the per-layer list.
            return logits, list(only)
    return logits, list(rest)


def generate_with_plugin(
    model_func: Callable,
    input_ids: torch.Tensor,
    kv_caches: List[torch.Tensor],
    max_new_tokens: int,
    eos_token_id: Optional[int] = None,
    device: torch.device = torch.device("cuda:0"),
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """
    Greedily generate tokens with the plugin model.

    The first model call processes the whole prompt (prefill); every later
    call feeds only the token produced by the previous call (decode).

    Args:
        model_func: Callable invoked as
            ``model_func(input_ids, position_ids, kv_caches, ctx_len)``.
            It may return bare logits, ``(logits, [kv, ...])`` or
            ``(logits, kv, ...)``.
        input_ids: Input token IDs [batch, seq_len].
        kv_caches: List of KV cache tensors; updated in place with any
            per-layer deltas the model returns.
        max_new_tokens: Maximum number of new tokens to generate. Values
            ``<= 0`` return a copy of ``input_ids`` without calling the model.
        eos_token_id: EOS token ID for early stopping (optional). Generation
            stops once every sequence in the batch emitted it in the same step.
        device: Device on which ``position_ids`` and ``ctx_len`` are created.

    Returns:
        Tuple of (generated token IDs [batch, seq_len + n_new], kv_caches).
    """
    generated_ids = input_ids.clone()
    if max_new_tokens <= 0:
        return generated_ids, kv_caches

    batch_size = input_ids.shape[0]
    step_ids = input_ids  # whole prompt first, then one token per step
    cur_pos = 0  # number of tokens already consumed by the model

    for _ in range(max_new_tokens):
        step_len = step_ids.shape[1]
        position_ids = (
            torch.arange(cur_pos, cur_pos + step_len, dtype=torch.long, device=device)
            .unsqueeze(0)
            .repeat(batch_size, 1)
        )
        ctx_len = torch.full(
            (batch_size,), cur_pos + step_len, dtype=torch.int32, device=device
        )

        # Always re-read logits from this call, whatever the output layout.
        logits, delta_kvs = _split_model_output(
            model_func(step_ids, position_ids, kv_caches, ctx_len)
        )

        # Write each layer's delta at the positions this step covered.
        for i, delta in enumerate(delta_kvs):
            kv_caches[i][:, :, :, cur_pos : cur_pos + delta.shape[3], :] = delta

        # [batch] -> [batch, 1] so it concatenates along the sequence axis.
        next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(-1)
        generated_ids = torch.cat([generated_ids, next_token], dim=1)
        cur_pos += step_len

        if eos_token_id is not None and bool((next_token == eos_token_id).all()):
            break
        step_ids = next_token

    return generated_ids, kv_caches
dtype: torch.dtype = torch.float16, + run_name: str = "Plugin", +) -> float: + """ + Benchmark plugin model generation. + + Args: + model_func: The compiled model function. + config: Model configuration. + isl: Input sequence length. + osl: Output sequence length (number of tokens to generate). + max_seq_len: Maximum sequence length for KV cache. + device: Device for computation. + dtype: Data type. + run_name: Name for logging. + + Returns: + Elapsed time in milliseconds. + """ + # Check for extra kwargs the model might need + extra_kwargs = {} + if hasattr(model_func, "forward"): + sig = inspect.signature(model_func.forward) + if "arg_start_idx" in sig.parameters: + extra_kwargs["arg_start_idx"] = 0 + if "arg_end_idx" in sig.parameters: + extra_kwargs["arg_end_idx"] = 0 + + # Prepare inputs + input_ids = torch.randint(0, config.vocab_size, (1, isl), device=device) + kv_caches = create_kv_caches(config, max_seq_len, 1, device, dtype) + + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + + # Prefill + seq_len = isl + position_ids = torch.arange(seq_len, dtype=torch.long, device=device).unsqueeze(0) + ctx_len = torch.tensor([seq_len], dtype=torch.int32, device=device) + + output = model_func(input_ids, position_ids, kv_caches, ctx_len, **extra_kwargs) + + if isinstance(output, (tuple, list)): + if len(output) == 2: + logits, delta_kvs = output + else: + logits = output[0] + delta_kvs = output[1:] + else: + logits = output + delta_kvs = [] + + # Update KV caches + if len(delta_kvs) > 0: + for i, delta in enumerate(delta_kvs): + seq_len_out = delta.shape[3] + kv_caches[i][:, :, :, :seq_len_out, :] = delta + + next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0) + + # Decode + cur_pos = seq_len + + for _ in range(osl - 1): + input_ids_step = next_token + position_ids_step = torch.tensor([[cur_pos]], dtype=torch.long, device=device) + ctx_len_step = 
torch.tensor([cur_pos + 1], dtype=torch.int32, device=device) + + output = model_func( + input_ids_step, position_ids_step, kv_caches, ctx_len_step, **extra_kwargs + ) + + if isinstance(output, (tuple, list)): + if len(output) == 2: + logits, delta_kvs = output + else: + logits = output[0] + delta_kvs = output[1:] + + # Update KV caches + if len(delta_kvs) > 0: + for i, delta in enumerate(delta_kvs): + kv_caches[i][:, :, :, cur_pos : cur_pos + 1, :] = delta + + next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0) + cur_pos += 1 + + end_event.record() + torch.cuda.synchronize() + + elapsed_ms = start_event.elapsed_time(end_event) + print( + f"{run_name} | ISL: {isl}, OSL: {osl} | Total Time: {elapsed_ms:.2f} ms | Tokens/sec: {osl / (elapsed_ms / 1000.0):.2f}" + ) + return elapsed_ms diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py index 1053577f37..710149761e 100644 --- a/tools/llm/run_llm.py +++ b/tools/llm/run_llm.py @@ -19,12 +19,18 @@ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ import torch import torch_tensorrt -from modelopt.torch.quantization.utils import export_torch_mode -from quantize_utils import ( - convert_linear_to_tensorrt_quantized, - load_quantization_config, - quantize_model, -) + +try: + from modelopt.torch.quantization.utils import export_torch_mode + from quantize_utils import ( + convert_linear_to_tensorrt_quantized, + load_quantization_config, + quantize_model, + ) + + QUANTIZATION_AVAILABLE = True +except ImportError: + QUANTIZATION_AVAILABLE = False from torchtrt_ext import register_sdpa from transformers import AutoModelForCausalLM, AutoTokenizer from utils import ( @@ -35,6 +41,24 @@ time_generate, ) +# Import plugin utilities (optional) +try: + from plugin_utils import ( + LLMPluginWrapper, + benchmark_plugin_generation, + compile_plugin_model, + create_kv_caches, + generate_with_plugin, + load_plugin, + register_plugin_op, + replace_attention_with_plugin, + set_plugin_config_from_model, + ) + + PLUGIN_AVAILABLE = True +except 
ImportError as e: + PLUGIN_AVAILABLE = False + DEVICE = torch.device("cuda:0") @@ -56,31 +80,40 @@ def get_model(args): moved to CUDA device with the specified precision """ with torch.no_grad(): + # For plugin backend, we don't set attn_implementation + attn_impl_kwargs = {} + if args.backend in ("sdpa", "iattention"): + attn_impl_kwargs["attn_implementation"] = "sdpa" + model = ( AutoModelForCausalLM.from_pretrained( args.model, use_cache=False, - attn_implementation="sdpa", ignore_mismatched_sizes=True, + **attn_impl_kwargs, ) .eval() .cuda() ) - # register SDPA variant for the model - register_sdpa.enable_sdpa_converter(args.model, model.config) - - hf_quant_config = load_quantization_config(args.model) - if hf_quant_config: - model = convert_linear_to_tensorrt_quantized( - model, args.model_precision, hf_quant_config - ).cuda() - print( - f"Model is {hf_quant_config['quant_algo']} pre-quantized hf model. Quantized linear layers are applied" - ) - if args.quant_format: - raise RuntimeError( - f"Quantization cannot be applied for pre-quantized hf model" + # Register SDPA lowering pass only for sdpa backend. + # For iattention backend, the core TRT IAttention converters handle SDPA ops + # directly without needing the custom lowering pass. + if args.backend == "sdpa": + register_sdpa.enable_sdpa_converter(args.model, model.config) + + if QUANTIZATION_AVAILABLE: + hf_quant_config = load_quantization_config(args.model) + if hf_quant_config: + model = convert_linear_to_tensorrt_quantized( + model, args.model_precision, hf_quant_config + ).cuda() + print( + f"Model is {hf_quant_config['quant_algo']} pre-quantized hf model. 
Quantized linear layers are applied" ) + if args.quant_format: + raise RuntimeError( + f"Quantization cannot be applied for pre-quantized hf model" + ) if args.model_precision == "FP16": model = model.to(torch.float16) @@ -114,7 +147,10 @@ def compile_torchtrt(model, input_ids, args): for optimized inference """ max_seq_len = input_ids.shape[1] + args.num_tokens - with export_torch_mode(): + if QUANTIZATION_AVAILABLE: + with export_torch_mode(): + ep = export_llm(model, input_ids, max_seq_len=max_seq_len) + else: ep = export_llm(model, input_ids, max_seq_len=max_seq_len) position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) # Set precision specific flags @@ -217,6 +253,15 @@ def measure_perf(trt_model, input_signature, backend_name): default="FP16", help="Precision to use in the model. Options: FP16, BF16, FP32", ) + arg_parser.add_argument( + "--backend", + type=str, + default="sdpa", + help="Backend to use. Options: sdpa, iattention, plugin. " + "'sdpa' uses custom SDPA lowering pass + converter (matmul+softmax+matmul). " + "'iattention' uses TensorRT native IAttention layer (no KV cache support yet). " + "'plugin' uses TensorRT Edge-LLM attention plugin with built-in KV cache.", + ) arg_parser.add_argument( "--iterations", type=int, default=5, help="no. of iterations to run" ) @@ -275,6 +320,26 @@ def measure_perf(trt_model, input_signature, backend_name): ) args = arg_parser.parse_args() + # Validate arguments + if args.backend not in ("sdpa", "iattention", "plugin"): + raise ValueError( + f"Unknown backend '{args.backend}'. Options: sdpa, iattention, plugin" + ) + if args.backend == "plugin" and not PLUGIN_AVAILABLE: + raise RuntimeError( + "Plugin backend requested but plugin utilities are not available." + ) + if args.cache and args.backend == "plugin": + print("Warning: --cache is only applicable with 'sdpa' backend. 
Ignoring.") + args.cache = "" + if args.cache and args.backend == "iattention": + print( + "Warning: --cache is not supported with 'iattention' backend " + "(static_cache passes are incompatible with native IAttention converters). " + "Ignoring --cache." + ) + args.cache = "" + with torch.inference_mode(): model = get_model(args) @@ -299,7 +364,11 @@ def measure_perf(trt_model, input_signature, backend_name): pyt_timings = None pyt_stats = None - if args.quant_format != None: + if args.quant_format is not None: + if not QUANTIZATION_AVAILABLE: + raise RuntimeError( + "Quantization requested but modelopt is not installed." + ) model = quantize_model(model, args, tokenizer) if args.enable_pytorch_run: pyt_gen_tokens = generate( @@ -322,54 +391,117 @@ def measure_perf(trt_model, input_signature, backend_name): compile_time_s=None, ) - if args.cache == "static_v1": - # This import is required to register static v1 KV cache transformations as lowering passes - import static_cache_v1 - if args.cache == "static_v2": - # This import is required to register static v2 KV cache transformations as lowering passes - import static_cache_v2 + # Backend selection: sdpa, iattention, or plugin + if args.backend == "plugin": + # Plugin backend + if not PLUGIN_AVAILABLE: + raise RuntimeError("Plugin backend requested but not available") + + dtype = ( + torch.float16 + if args.model_precision == "FP16" + else ( + torch.bfloat16 if args.model_precision == "BF16" else torch.float32 + ) + ) + config = model.config + max_seq_len = max(2048, MAX_OUTPUT_SEQ_LENGTH) + + # Load plugin and register op + load_plugin() + register_plugin_op() + set_plugin_config_from_model(config, max_seq_len) + + # Replace attention with plugin + model = replace_attention_with_plugin( + model, config, max_seq_len, DEVICE, dtype + ) + wrapper = LLMPluginWrapper(model) - # Compile the model with Torch-TensorRT - trt_model = compile_torchtrt(model, input_ids, args) + # Compile plugin model + trt_model = 
compile_plugin_model( + wrapper, config, max_seq_len, DEVICE, dtype, args.debug + ) - if args.cache == "static_v1" or args.cache == "static_v2": - if args.cudagraph: - # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. - # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) - torch_tensorrt.runtime.set_cudagraphs_mode(True) + # Create KV caches + kv_caches = create_kv_caches( + config, max_seq_len, args.batch_size, DEVICE, dtype + ) - trt_gen_tokens = generate_with_static_cache( + # Generate + trt_gen_tokens, _ = generate_with_plugin( trt_model, input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, + kv_caches, + args.num_tokens, tokenizer.eos_token_id, + DEVICE, ) if args.benchmark: - trt_timings = time_generate( - generate_with_static_cache, + trt_timings = [] + for i in range(args.iterations): + elapsed_ms = benchmark_plugin_generation( + trt_model, + config, + input_ids.shape[1], + args.num_tokens, + max_seq_len, + DEVICE, + dtype, + ) + trt_timings.append(elapsed_ms / 1000.0) + else: + # SDPA or IAttention backend + # For iattention, args.cache is already cleared by validation above. + if args.cache == "static_v1": + # This import is required to register static v1 KV cache transformations as lowering passes + import static_cache_v1 + if args.cache == "static_v2": + # This import is required to register static v2 KV cache transformations as lowering passes + import static_cache_v2 + + # Compile the model with Torch-TensorRT + trt_model = compile_torchtrt(model, input_ids, args) + + if args.cache == "static_v1" or args.cache == "static_v2": + if args.cudagraph: + # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases. 
+ # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model) + torch_tensorrt.runtime.set_cudagraphs_mode(True) + + trt_gen_tokens = generate_with_static_cache( trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, - iterations=args.iterations, ) - else: - trt_gen_tokens = generate( - trt_model, - input_ids.clone(), - MAX_OUTPUT_SEQ_LENGTH, - tokenizer.eos_token_id, - ) - if args.benchmark: - trt_timings = time_generate( - generate, + + if args.benchmark: + trt_timings = time_generate( + generate_with_static_cache, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) + else: + trt_gen_tokens = generate( trt_model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id, - iterations=args.iterations, ) + if args.benchmark: + trt_timings = time_generate( + generate, + trt_model, + input_ids.clone(), + MAX_OUTPUT_SEQ_LENGTH, + tokenizer.eos_token_id, + iterations=args.iterations, + ) if args.benchmark: trt_stats = record_stats(