From 8d81c4edc07091a8153bc44d1de336d49a1b04d7 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Thu, 19 Feb 2026 17:30:25 +0000 Subject: [PATCH 1/8] feat(//py/torch_tensorrt/dynamo): Adding comprehensive support for complex numerics, including complex tensor I/O Introduce a new infrastructure in the replace complex pass to handle a number of cases where simply just unpacking complex tensors is not sufficent for supporting the numerics correctly. This pass also now captures meta data about the original call signature so that during graph construction, the original calling convention is preserved and the runtimes do not need any specialization on supporting complex types. --- examples/dynamo/torch_export_3d_rope.py | 369 +++ py/torch_tensorrt/dynamo/_compiler.py | 103 + .../dynamo/lowering/_SubgraphBuilder.py | 89 + py/torch_tensorrt/dynamo/lowering/__init__.py | 1 + .../lowering/passes/complex_graph_rewrite.py | 1303 ++++++++++- .../dynamo/partitioning/common.py | 17 +- py/torch_tensorrt/dynamo/utils.py | 37 +- tests/py/dynamo/hlo/__init__.py | 0 tests/py/dynamo/hlo/test_complex_ops.py | 2070 +++++++++++++++++ tests/py/dynamo/hlo/test_rope_embedding.py | 526 +++++ 10 files changed, 4416 insertions(+), 99 deletions(-) create mode 100644 examples/dynamo/torch_export_3d_rope.py create mode 100644 py/torch_tensorrt/dynamo/lowering/_SubgraphBuilder.py create mode 100644 tests/py/dynamo/hlo/__init__.py create mode 100644 tests/py/dynamo/hlo/test_complex_ops.py create mode 100644 tests/py/dynamo/hlo/test_rope_embedding.py diff --git a/examples/dynamo/torch_export_3d_rope.py b/examples/dynamo/torch_export_3d_rope.py new file mode 100644 index 0000000000..8851beb59e --- /dev/null +++ b/examples/dynamo/torch_export_3d_rope.py @@ -0,0 +1,369 @@ +""" +3D Rotary Position Embedding (RoPE) + Attention compiled with Torch-TensorRT +============================================================================= + +3D RoPE is the positional encoding used in video generation transformers such +as CogVideoX, Wan, and HunyuanVideo. Unlike 1D RoPE (used in language models) +which encodes a single sequence index, 3D RoPE independently encodes three +axes — temporal (T), height (H), and width (W) — and assigns each axis a +dedicated slice of the per-head frequency vector: + + head-dim slots 0 .. d//3-1 → temporal frequencies + head-dim slots d//3.. 2d//3-1 → height frequencies + head-dim slots 2d//3.. d//2-1 → width frequencies + +The rotation is expressed with complex arithmetic: + + xq_rotated = view_as_real(view_as_complex(xq) * freqs_cis) + +PyTorch complex ops (view_as_complex, complex mul) are not natively supported +by TensorRT. Torch-TensorRT's ``complex_graph_detection`` lowering pass +intercepts them before partitioning and rewrites the subgraph to equivalent +real arithmetic — splitting the last dimension into (..., 2) real/imag pairs +and computing (ac-bd, ad+bc) manually — so the TRT engine only sees standard +float32 ops and the caller never needs to change anything. + +This example: + 1. Defines a 3D-RoPE frequency precomputation helper (complex64 output). + 2. Defines a VideoAttentionBlock: linear QKV projection → 3D RoPE → SDPA. + 3. Runs a PyTorch baseline forward pass. + 4. Exports with torch.export.export() and dynamic T/H/W dimensions. + 5. Compiles to TensorRT via torch_tensorrt.dynamo.compile(). + 6. Verifies numerical accuracy (cosine similarity on the output tensor). + 7. (Optional) benchmarks latency of both backends. + +Usage +----- +# Quick correctness check (static shapes) +python examples/dynamo/torch_export_3d_rope.py + +# Dynamic T/H/W shapes +python examples/dynamo/torch_export_3d_rope.py --dynamic + +# Larger config + benchmark +python examples/dynamo/torch_export_3d_rope.py --heads 16 --head-dim 96 --t 8 --h 16 --w 16 --benchmark +""" + +import argparse +import timeit + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_tensorrt +from torch.export import Dim + +DEVICE = torch.device("cuda:0") + + +# --------------------------------------------------------------------------- +# Frequency precomputation +# --------------------------------------------------------------------------- + + +def precompute_freqs_3d( + head_dim: int, + t: int, + h: int, + w: int, + theta: float = 10000.0, +) -> torch.Tensor: + """Pre-compute 3D RoPE unit-complex frequency tensor. + + Returns a complex64 tensor of shape (t, h, w, head_dim // 2) where the + last dimension is split evenly across the three spatial axes. + + Args: + head_dim: Channels per attention head (must be even, head_dim//2 + must be divisible by 3). + t: Number of temporal frames. + h: Spatial height in patches. + w: Spatial width in patches. + theta: Base for the geometric frequency progression. + """ + half = head_dim // 2 + d_t = half // 3 + d_h = half // 3 + d_w = half - d_t - d_h # absorbs any remainder from integer division + + def _axis_freqs(d: int, n: int) -> torch.Tensor: + """1-D complex exponentials, shape (n, d).""" + inv_freq = 1.0 / (theta ** (torch.arange(0, d * 2, 2).float() / (d * 2))) + positions = torch.arange(n, dtype=torch.float32) + angles = torch.outer(positions, inv_freq) + return torch.polar(torch.ones_like(angles), angles) # complex64 + + freqs_t = _axis_freqs(d_t, t)[:, None, None, :].expand(t, h, w, d_t) + freqs_h = _axis_freqs(d_h, h)[None, :, None, :].expand(t, h, w, d_h) + freqs_w = _axis_freqs(d_w, w)[None, None, :, :].expand(t, h, w, d_w) + + # Concatenate along last dim → (t, h, w, half), complex64 + return torch.cat([freqs_t, freqs_h, freqs_w], dim=-1).contiguous() + + +# --------------------------------------------------------------------------- +# Model +# --------------------------------------------------------------------------- + + +class VideoAttentionBlock(nn.Module): + """Single attention block for video latents with 3D RoPE. + + Inputs + ------ + x : (B, T, H, W, C) float32 video patch features + freqs_cis_real: (T, H, W, C // n_heads) float32 + The RoPE frequency tensor pre-flattened from complex64 via + ``view_as_real(...).flatten(-2)``. The module reconstructs the + complex form internally with ``view_as_complex``. + + Passing frequencies as a plain real-valued input avoids exposing a + complex tensor at the model boundary (TRT inputs must be real). + + Output + ------ + (B, T, H, W, C) float32 + """ + + def __init__(self, channels: int = 512, n_heads: int = 8) -> None: + super().__init__() + assert channels % n_heads == 0 + self.n_heads = n_heads + self.head_dim = channels // n_heads + self.norm = nn.LayerNorm(channels) + self.qkv = nn.Linear(channels, 3 * channels, bias=False) + self.proj = nn.Linear(channels, channels, bias=False) + + def _apply_rope(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + """Apply 3D RoPE to a single Q or K tensor. + + The complex multiply ``xc * freqs_cis`` is what Torch-TensorRT rewrites + to real arithmetic via the complex_graph_detection lowering pass. + + Args: + x : (B, T, H, W, n_heads, head_dim) float32 + freqs_cis: (T, H, W, head_dim // 2) complex64 + Returns: + Rotated tensor, same shape as ``x``, float32. + """ + B, T, H, W, Nh, D = x.shape + # Interpret consecutive pairs of head-dim channels as complex numbers. + xc = torch.view_as_complex(x.reshape(B, T, H, W, Nh, D // 2, 2)) + # freqs_cis broadcast over batch (dim 0) and head (dim 4). + freqs = freqs_cis[None, :, :, :, None, :] # (1, T, H, W, 1, D//2) + return torch.view_as_real(xc * freqs).flatten(-2) # (B,T,H,W,Nh,D) + + def forward( + self, + x: torch.Tensor, + freqs_cis_real: torch.Tensor, + ) -> torch.Tensor: + B, T, H, W, C = x.shape + Nh, D = self.n_heads, self.head_dim + + h = self.norm(x) + qkv = self.qkv(h).reshape(B, T, H, W, 3, Nh, D) + q, k, v = qkv.unbind(dim=4) # each (B, T, H, W, Nh, D) + + # Recover complex frequencies from the real-valued input. + # freqs_cis_real: (T, H, W, D) → reshape to (T, H, W, D//2, 2) → complex + freqs_cis = torch.view_as_complex(freqs_cis_real.reshape(T, H, W, D // 2, 2)) + + q = self._apply_rope(q, freqs_cis) + k = self._apply_rope(k, freqs_cis) + + # Flatten spatial dims for attention: (B, Nh, T*H*W, D) + N = T * H * W + q = q.reshape(B, N, Nh, D).permute(0, 2, 1, 3) + k = k.reshape(B, N, Nh, D).permute(0, 2, 1, 3) + v = v.reshape(B, N, Nh, D).permute(0, 2, 1, 3) + + out = F.scaled_dot_product_attention(q, k, v) # (B, Nh, N, D) + out = out.permute(0, 2, 1, 3).reshape(B, T, H, W, C) + return x + self.proj(out) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_inputs( + B: int, T: int, H: int, W: int, C: int, n_heads: int +) -> tuple[torch.Tensor, torch.Tensor]: + """Return (x, freqs_cis_real) on DEVICE.""" + x = torch.randn(B, T, H, W, C, dtype=torch.float32, device=DEVICE) + freqs_cis = precompute_freqs_3d(C // n_heads, t=T, h=H, w=W).to(DEVICE) + freqs_cis_real = torch.view_as_real(freqs_cis).flatten(-2) # (T,H,W,D) + return x, freqs_cis_real + + +def benchmark(fn, *args, iterations: int = 20, label: str = "") -> float: + fn(*args) # warmup + torch.cuda.synchronize() + total = 0.0 + for _ in range(iterations): + t0 = timeit.default_timer() + fn(*args) + torch.cuda.synchronize() + total += timeit.default_timer() - t0 + avg_ms = total / iterations * 1000 + print(f"[{label}] avg latency over {iterations} iters: {avg_ms:.2f} ms") + return avg_ms + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def parse_args(): + p = argparse.ArgumentParser( + description="3D RoPE attention block compiled with Torch-TensorRT" + ) + p.add_argument("--heads", type=int, default=8, help="Number of attention heads") + p.add_argument( + "--head-dim", + dest="head_dim", + type=int, + default=48, + help="Channels per head. head_dim//2 must be divisible by 3 (default: 48)", + ) + p.add_argument("--t", type=int, default=4, help="Temporal frames (default: 4)") + p.add_argument( + "--h", type=int, default=8, help="Spatial height patches (default: 8)" + ) + p.add_argument( + "--w", type=int, default=8, help="Spatial width patches (default: 8)" + ) + p.add_argument( + "--dynamic", + action="store_true", + help="Export with dynamic T/H/W dims and compile with min/opt/max shapes", + ) + p.add_argument( + "--benchmark", action="store_true", help="Benchmark PyTorch vs TRT latency" + ) + p.add_argument("--iterations", type=int, default=20) + return p.parse_args() + + +def main(): + args = parse_args() + + if (args.head_dim // 2) % 3 != 0: + raise ValueError( + f"head_dim // 2 = {args.head_dim // 2} must be divisible by 3 " + "for the T/H/W frequency split. Try --head-dim 48, 60, 96, or 192." + ) + + B, T, H, W = 1, args.t, args.h, args.w + C = args.heads * args.head_dim + + print(f"VideoAttentionBlock with 3D RoPE") + print(f" heads={args.heads} head_dim={args.head_dim} channels={C}") + print(f" input shape: ({B}, {T}, {H}, {W}, {C})") + + model = VideoAttentionBlock(channels=C, n_heads=args.heads).eval().to(DEVICE) + + # ------------------------------------------------------------------ + # 1. Build inputs + # ------------------------------------------------------------------ + x, freqs_cis_real = make_inputs(B, T, H, W, C, args.heads) + inputs = (x, freqs_cis_real) + print(f"\n x shape : {x.shape}") + print(f" freqs_cis_real shape: {freqs_cis_real.shape}") + + # ------------------------------------------------------------------ + # 2. PyTorch baseline + # ------------------------------------------------------------------ + with torch.inference_mode(): + pyt_out = model(*inputs) + print(f"\n--- PyTorch baseline ---") + print(f" output shape: {pyt_out.shape} dtype: {pyt_out.dtype}") + + # ------------------------------------------------------------------ + # 3. Export + # ------------------------------------------------------------------ + print("\nExporting model ...") + if args.dynamic: + t_dim = Dim("T", min=1, max=32) + h_dim = Dim("H", min=4, max=64) + w_dim = Dim("W", min=4, max=64) + dynamic_shapes = ( + # x: (B, T, H, W, C) + {1: t_dim, 2: h_dim, 3: w_dim}, + # freqs_cis_real: (T, H, W, D) + {0: t_dim, 1: h_dim, 2: w_dim}, + ) + ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) + print(" Exported with dynamic T / H / W dimensions.") + else: + ep = torch.export.export(model, inputs) + print(" Exported with static shapes.") + + # ------------------------------------------------------------------ + # 4. Compile with Torch-TensorRT + # + # No special flags are required for the complex arithmetic rewrite. + # The complex_graph_detection lowering pass automatically detects + # view_as_complex / complex-mul / view_as_real subgraphs and rewrites + # them to real-arithmetic ops before the TRT engine is built. + # ------------------------------------------------------------------ + print("\nCompiling with Torch-TensorRT ...") + D = C // args.heads # freqs_cis_real last dim + if args.dynamic: + trt_inputs = [ + torch_tensorrt.Input( + min_shape=(B, 1, 4, 4, C), + opt_shape=(B, T, H, W, C), + max_shape=(B, 32, 64, 64, C), + dtype=torch.float32, + ), + torch_tensorrt.Input( + min_shape=(1, 4, 4, D), + opt_shape=(T, H, W, D), + max_shape=(32, 64, 64, D), + dtype=torch.float32, + ), + ] + else: + trt_inputs = list(inputs) + + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=trt_inputs, + enabled_precisions={torch.float32}, + min_block_size=1, + ) + + # ------------------------------------------------------------------ + # 5. TRT inference & accuracy check + # ------------------------------------------------------------------ + with torch.inference_mode(): + trt_out = trt_model(*inputs) + + pyt_flat = pyt_out.float().flatten() + trt_flat = trt_out.float().flatten() + cos_sim = (pyt_flat @ trt_flat / (pyt_flat.norm() * trt_flat.norm())).item() + max_diff = (pyt_out.float() - trt_out.float()).abs().max().item() + + print(f"\n--- TensorRT vs PyTorch ---") + print(f" output shape : {trt_out.shape}") + print(f" cosine sim : {cos_sim:.6f}") + print(f" max |Δ| : {max_diff:.2e}") + assert cos_sim > 0.99, f"Cosine similarity {cos_sim:.4f} below threshold 0.99!" + print(" PASSED") + + # ------------------------------------------------------------------ + # 6. (Optional) benchmark + # ------------------------------------------------------------------ + if args.benchmark: + print("\n--- Benchmarking ---") + with torch.inference_mode(): + benchmark(model, *inputs, iterations=args.iterations, label="PyTorch") + benchmark(trt_model, *inputs, iterations=args.iterations, label="TensorRT") + + +if __name__ == "__main__": + main() diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index bc3cdc5721..21d7e802c0 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -44,6 +44,7 @@ resource_partition, ) from torch_tensorrt.dynamo.utils import ( + COMPLEX_TO_REAL_DTYPE, deallocate_module, get_cpu_memory_usage, get_flat_args_with_check, @@ -801,6 +802,104 @@ def compile( return trt_gm +def _insert_complex_io_adapters( + partitioned_module: torch.fx.GraphModule, + gm: torch.fx.GraphModule, + settings: CompilationSettings, +) -> None: + """Insert view_as_real / view_as_complex boundary nodes for complex I/O. + + complex_graph_detection rewrites complex subgraphs to real arithmetic before + partitioning, but when a model has complex inputs or outputs the outer wrapper + graph still needs adapters at the TRT block boundary: + + Inputs: insert view_as_real (+ optional cast for complex128+truncate_double) + after each placeholder that was unpacked by the rewriter. + Outputs: insert view_as_complex before the output node for each originally-complex + output that comes from a TRT block. + + Leverages metadata that was captued when the complex rewriter pass was run + """ + complex_input_names = gm.meta.get("complex_input_names", []) + complex_input_dtypes = gm.meta.get("complex_input_dtypes", {}) + complex_output_indices = gm.meta.get("complex_output_indices", []) + + if not complex_input_names and not complex_output_indices: + return + + graph_modified = False + + # --- Input boundary: view_as_real for complex inputs --- + # complex_graph_detection renames complex placeholder 'foo' to 'foo_unpacked_complex' + # with float dtype. The outer graph still has 'foo_unpacked_complex' as a placeholder, + # but the caller passes the original complex tensor. Insert view_as_real after + # each such placeholder so the graph unpacks it transparently. + reshaped_names = {f"{n}_unpacked_complex" for n in complex_input_names} + for node in list(partitioned_module.graph.nodes): + if node.op != "placeholder" or node.name not in reshaped_names: + continue + with partitioned_module.graph.inserting_after(node): + real_node = partitioned_module.graph.call_function( + torch.ops.aten.view_as_real.default, args=(node,) + ) + # For complex128 with truncate_double, the rewriter produced float32 + # TRT engine inputs but view_as_real gives float64 — add an explicit cast. + orig_name = node.name[: -len("_unpacked_complex")] + orig_dtype = complex_input_dtypes.get(orig_name, None) + + if orig_dtype == torch.complex128 and settings.truncate_double: + logger.info( + f"Input '{orig_name}' is complex128 with truncate_double=True: unpacked " + f"float64 components will be cast to float32." + ) + with partitioned_module.graph.inserting_after(real_node): + cast_node = partitioned_module.graph.call_function( + torch.ops.aten.to.dtype, + args=(real_node, torch.float32), + ) + node.replace_all_uses_with(cast_node) + cast_node.args = (real_node, torch.float32) + real_node.args = (node,) + logger.info( + f"Inserted view_as_real + cast-to-float32 for complex128 input placeholder '{node.name}' (truncate_double=True)" + ) + else: + node.replace_all_uses_with(real_node) + # fix the self-reference created by replace_all_uses_with + real_node.args = (node,) + logger.info( + f"Inserted view_as_real for complex input placeholder '{node.name}'" + ) + graph_modified = True + + # --- Output boundary: view_as_complex for complex outputs from TRT blocks --- + if complex_output_indices: + output_node = list(partitioned_module.graph.nodes)[-1] + outputs = list(output_node.args[0]) + for idx in complex_output_indices: + if idx >= len(outputs): + continue + src = outputs[idx] + if not isinstance(src, torch.fx.Node): + continue + if src.op == "call_module" and "_run_on_acc" in str(src.target): + with partitioned_module.graph.inserting_before(output_node): + complex_node = partitioned_module.graph.call_function( + torch.ops.aten.view_as_complex.default, args=(src,) + ) + logger.info( + f"Inserted view_as_complex for complex output index {idx} " + f"from TRT block '{src.target}'" + ) + outputs[idx] = complex_node + graph_modified = True + output_node.args = (tuple(outputs),) + + if graph_modified: + partitioned_module.graph.lint() + partitioned_module.recompile() + + @fn_supports_debugger # type: ignore[misc] def compile_module( gm: torch.fx.GraphModule, @@ -1097,6 +1196,10 @@ def preserve_module_specs( trt_module = getattr(partitioned_module, name) trt_module.setup_engine() + # Post-partition complex I/O boundary pass — runs in both normal and dryrun mode + # so the wrapper graph reflects the exact graph that will be executed/built. + _insert_complex_io_adapters(partitioned_module, gm, settings) + # Only set output tensors as unowned if not in dryrun mode (TRT modules exist) if not settings.dryrun: output_node = list(partitioned_module.graph.nodes)[-1] diff --git a/py/torch_tensorrt/dynamo/lowering/_SubgraphBuilder.py b/py/torch_tensorrt/dynamo/lowering/_SubgraphBuilder.py new file mode 100644 index 0000000000..e5ef07dd7e --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/_SubgraphBuilder.py @@ -0,0 +1,89 @@ +"""Cursor-based FX graph node builder.""" + +from __future__ import annotations + +import logging +from types import TracebackType +from typing import List, Optional, Type + +import torch +import torch.fx +from torch.fx.node import Node + +logger = logging.getLogger(__name__) + + +def _fmt_node(n: object) -> str: + """Return a compact string summary of an FX node or a plain value.""" + if not isinstance(n, Node): + return repr(n) + val = n.meta.get("val", None) + if val is not None and hasattr(val, "shape") and hasattr(val, "dtype"): + return f"%{n.name}[{tuple(val.shape)},{val.dtype}]" + return f"%{n.name}" + + +def _fmt_args(args: tuple) -> str: + return "(" + ", ".join(_fmt_node(a) for a in args) + ")" + + +# NB: Its pretty tedious to go through and hand write all the graph insert afters +# Could not find a Pytorch utility that simplifies this so we have this class. I want +# remove it if we find a PyTorch alternative +class SubgraphBuilder: + """Cursor-based helper for inserting a sequence of FX ``call_function`` nodes. + + Construct it with the graph and an anchor node, then call it like a + function to append each new node immediately after the current cursor:: + + with SubgraphBuilder(graph, node) as b: + re = b(aten.select.int, inp, -1, 0) + im = b(aten.select.int, inp, -1, 1) + out = b(aten.add.Tensor, re, im) + + Each call inserts one ``call_function`` node right after the cursor and + advances the cursor to that node. Scalar / list arguments are forwarded + as-is. + + On ``__exit__`` the graph is linted to catch any malformed nodes inserted + during the block. Exceptions from user code propagate normally; lint + errors are only raised when the block itself succeeds. + """ + + __slots__ = ("_g", "_anchor_desc", "_cursor", "_inserted") + + def __init__(self, graph: torch.fx.Graph, cursor: Node) -> None: + self._g = graph + # Snapshot the description now — the anchor node is erased inside the block. + self._anchor_desc: str = _fmt_node(cursor) + self._cursor = cursor + self._inserted: List[Node] = [] + + @property + def cursor(self) -> Node: + return self._cursor + + def __call__(self, op: object, *args: object) -> Node: + with self._g.inserting_after(self._cursor): + node = self._g.call_function(op, args=args) + self._cursor = node + self._inserted.append(node) + return node + + def __enter__(self) -> "SubgraphBuilder": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + if exc_type is None: + if logger.isEnabledFor(logging.DEBUG) and self._inserted: + lines = [f" rewrite {self._anchor_desc} ->"] + for n in self._inserted: + op_name = getattr(n.target, "__name__", str(n.target)) + lines.append(f" {_fmt_node(n)} = {op_name}{_fmt_args(n.args)}") + logger.debug("\n".join(lines)) + self._g.lint() diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py index bec5e407b5..3c73e42f86 100644 --- a/py/torch_tensorrt/dynamo/lowering/__init__.py +++ b/py/torch_tensorrt/dynamo/lowering/__init__.py @@ -5,3 +5,4 @@ ) from ._decompositions import get_decompositions # noqa: F401 from .passes import * +from torch_tensorrt.dynamo.lowering._SubgraphBuilder import SubgraphBuilder diff --git a/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py b/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py index c3ead218aa..ea36a9deb5 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py @@ -1,17 +1,88 @@ import logging -from typing import Callable, List, Set, Tuple +import math +from typing import Callable, List, Optional, Set, Tuple import torch from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx import GraphModule, Node from torch.fx.experimental.proxy_tensor import unset_fake_temporarily + from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering._SubgraphBuilder import SubgraphBuilder from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( clean_up_graph_after_modifications, ) logger = logging.getLogger(__name__) +# Ops that are elementwise-safe on the [..., 2] real layout used to represent +# complex tensors. These ops apply independently to every scalar in the tensor +# (including both the real and imaginary components stored in the last dim) so +# no explicit rewrite is needed — the pass-through behaviour is correct. +# +# NOTE: add.Scalar / sub.Scalar are NOT in this set. (a+bi)+s = (a+s)+bi +# adds the scalar only to the real part, but on the [...,2] layout +# add.Scalar would add to both parts. Those need explicit rewrites. +_ELEMENTWISE_SAFE: frozenset = frozenset( + { + # Arithmetic — component-wise operations are correct by construction + torch.ops.aten.add.Tensor, + torch.ops.aten.sub.Tensor, + torch.ops.aten.neg.default, + torch.ops.aten.mul.Scalar, # scalar*(re,im) — both parts scaled equally + torch.ops.aten.div.Scalar, # (re,im)/scalar — both parts divided equally + # Structural / copy — operate on the whole tensor without touching content. + # Note: permute.default is NOT here; it needs an explicit rewrite to append + # the trailing real/imag dimension index to the dims list. + torch.ops.aten.clone.default, + torch.ops.aten.detach.default, + torch.ops.aten.alias.default, + torch.ops.aten.expand.default, + torch.ops.aten.t.default, + # Construction — producing zero/one tensors of the same shape is layout-neutral + torch.ops.aten.zeros_like.default, + torch.ops.aten.ones_like.default, + # Conditional selection — correct on the real layout when mask broadcasts + torch.ops.aten.where.self, + # Rounding — applies to each float independently; complex rounding is + # undefined in PyTorch so these only appear after the rewrite anyway + torch.ops.aten.ceil.default, + torch.ops.aten.floor.default, + torch.ops.aten.round.default, + torch.ops.aten.trunc.default, + } +) + + +def _complex_unpacker(*ops: object) -> Callable: + """Decorator that registers a rewrite method for a complex aten op into a real value subgraph. + + Usage:: + + @_complex_unpacker(aten.sin.default, aten.cos.default) + def _rewrite_sin_cos(self, node): ... + + The ops are stored on the function as ``._complex_unpacker_ops`` and picked up by + ``@_register_unpackers`` when the class is fully defined. + """ + + def decorator(fn: Callable) -> Callable: + fn._complex_unpacker_ops = ops + return fn + + return decorator + + +def _register_unpackers(cls: type) -> type: + """Class decorator that builds ``cls._DISPATCH`` from all methods tagged + with ``@_complex_unpacker``. Applied once at class-definition time.""" + dispatch: dict = {} + for attr in vars(cls).values(): + for op in getattr(attr, "_complex_unpacker_ops", ()): + dispatch[op] = attr + cls._DISPATCH = dispatch + return cls + class ComplexSubGraphInfo: def __init__( @@ -44,16 +115,19 @@ def is_complex_dtype(self, node: Node) -> bool: if hasattr(val, "dtype"): dtype = val.dtype - logger.debug(f"dtype of node: {dtype}") return dtype in {torch.complex64, torch.complex128} + def has_complex_input(self, node: Node) -> bool: + """Return True if any input to node has complex dtype.""" + return any(self.is_complex_dtype(inp) for inp in node.all_input_nodes) + def node_include_in_subgraph(self, node: Node) -> bool: - # Include only call_function ops on complex tensors - if node.op == "call_function" and self.is_complex_dtype(node): - logger.debug( - f"node.op is added to subgraph: {node.op}, node name: {node.name} is complex" - ) - return node.op == "call_function" and self.is_complex_dtype(node) + # Include call_function ops that either output complex OR consume complex inputs. + # The second condition catches real-output ops like abs, angle, real, imag whose + # inputs are complex and must be rewritten alongside the rest of the subgraph. + if node.op != "call_function": + return False + return self.is_complex_dtype(node) or self.has_complex_input(node) def subgraph_from_anchor(self, anchor_node: Node) -> ComplexSubGraphInfo: subgraph_nodes: Set[Node] = set() @@ -64,15 +138,19 @@ def subgraph_from_anchor(self, anchor_node: Node) -> ComplexSubGraphInfo: if n in subgraph_nodes: continue subgraph_nodes.add(n) - logger.debug(f"node {n.name} is added to subgraph") for inp in n.all_input_nodes: if self.node_include_in_subgraph(inp): stack.append(inp) else: input_nodes.add(inp) - return ComplexSubGraphInfo( - [anchor_node], list(subgraph_nodes), list(input_nodes) - ) + # Sort subgraph_nodes in topological (graph) order so the rewriter + # processes producers before consumers. The set has no stable order, + # which caused bugs when e.g. mul(sin, sin) was processed before sin + # was rewritten (sin still had complex dtype, so the mul pattern ran + # against the original complex node and produced wrong results). + node_order = {n: i for i, n in enumerate(anchor_node.graph.nodes)} + ordered_subgraph = sorted(subgraph_nodes, key=lambda n: node_order.get(n, 0)) + return ComplexSubGraphInfo([anchor_node], ordered_subgraph, list(input_nodes)) def find_complex_op_subgraphs( self, gm: GraphModule, anchor_target: str @@ -86,10 +164,13 @@ def find_complex_op_subgraphs( for existing_sub in complex_op_subgraphs: if set(existing_sub.subgraph_nodes) & set(new_sub.subgraph_nodes): logger.debug(f"merging subgraphs {existing_sub} {new_sub}") - # merge the two subgraphs - existing_sub.subgraph_nodes = list( - set(existing_sub.subgraph_nodes) - | set(new_sub.subgraph_nodes) + # merge the two subgraphs, preserving topological order + merged_nodes = set(existing_sub.subgraph_nodes) | set( + new_sub.subgraph_nodes + ) + node_order = {n: i for i, n in enumerate(gm.graph.nodes)} + existing_sub.subgraph_nodes = sorted( + merged_nodes, key=lambda n: node_order.get(n, 0) ) existing_sub.input_nodes = list( set(existing_sub.input_nodes) | set(new_sub.input_nodes) @@ -103,7 +184,41 @@ def find_complex_op_subgraphs( complex_op_subgraphs.append(new_sub) return complex_op_subgraphs + def find_all_complex_subgraphs(self, gm: GraphModule) -> List[ComplexSubGraphInfo]: + """Forward scan: collect all complex-dtype call_function nodes as one subgraph. + + Unlike find_complex_op_subgraphs (which walks backwards from a single anchor), + this scans forward over every node and collects all call_function nodes whose + output is complex — regardless of whether they are bounded by view_as_real. + This ensures complex ops that feed directly into graph outputs (no view_as_real) + are still rewritten to real arithmetic. + """ + subgraph_nodes: Set[Node] = set() + input_nodes: Set[Node] = set() + for node in gm.graph.nodes: + if not self.node_include_in_subgraph(node): + continue + subgraph_nodes.add(node) + for inp in node.all_input_nodes: + if not self.node_include_in_subgraph(inp): + input_nodes.add(inp) + if not subgraph_nodes: + return [] + # Sort in topological (graph) order so the rewriter processes producers + # before consumers, avoiding the case where e.g. a mul node is rewritten + # before its sin/cos inputs are rewritten (which causes wrong results). + node_order = {n: i for i, n in enumerate(gm.graph.nodes)} + ordered = sorted(subgraph_nodes, key=lambda n: node_order.get(n, 0)) + return [ + ComplexSubGraphInfo( + anchor_nodes=ordered, + subgraph_nodes=ordered, + input_nodes=list(input_nodes), + ) + ] + +@_register_unpackers class ComplexGraphRewriter: def __init__(self, gm: GraphModule, truncate_double: bool = False) -> None: self.gm = gm @@ -146,21 +261,60 @@ def get_attr_tensor(self, target): # type: ignore f"Attribute {target} not found in gm parameters or buffers." ) - def replace_input_node(self, input_node: Node) -> None: + def replace_input_node( + self, input_node: Node, fake_mode: Optional[FakeTensorMode] = None + ) -> None: modified = False - logger.debug(f"Replacing input node: {input_node.name}") new_shape, new_dtype, device = self.extract_shape_dtype_device(input_node) - real_tensor = torch.empty(new_shape, dtype=new_dtype, device=device) if input_node.op == "placeholder": - with FakeTensorMode() as fake_mode: + if fake_mode is None: + fake_mode = FakeTensorMode() + # Preserve symbolic dimensions from the original placeholder's fake + # tensor so that dynamic-shape information (SymInt ranges from + # torch.export) survives the rewrite. We build the new fake tensor + # by appending a concrete 2 to the original symbolic shape. + # + # We use the *original* placeholder's FakeTensorMode + # (which owns the ShapeEnv with the export's range constraints) so + # that the new SymInt dimensions belong to the same ShapeEnv as all + # other nodes in the graph. Using shared_fake_mode would create a + # separate ShapeEnv and cause "symbol from different env" errors + # during FakeTensorProp. + orig_fake = input_node.meta.get("val", None) + if orig_fake is not None and hasattr(orig_fake, "shape"): + # orig_fake.shape contains the symbolic sizes; append 2 for real/imag. + sym_shape = list(orig_fake.shape) + [2] + orig_mode = getattr(orig_fake, "fake_mode", None) + create_mode = orig_mode if orig_mode is not None else fake_mode + with create_mode: + fake_tensor = torch.empty(sym_shape, dtype=new_dtype, device=device) + else: + concrete_shape = tuple( + int(s) if not isinstance(s, int) else s for s in new_shape + ) + real_tensor = torch.empty( + concrete_shape, dtype=new_dtype, device=device + ) fake_tensor = fake_mode.from_tensor(real_tensor) with self.gm.graph.inserting_before(input_node): - new_node = self.gm.graph.placeholder(input_node.target + "_reshaped") + new_node = self.gm.graph.placeholder( + input_node.target + "_unpacked_complex" + ) new_node.meta["val"] = fake_tensor + logger.debug( + " unpack placeholder %s%s -> %s%s", + input_node.name, + tuple(fake_tensor.shape[:-1]), + new_node.name, + tuple(fake_tensor.shape), + ) elif input_node.op == "get_attr": - new_attr_name = input_node.target + "_reshaped" + # Sanitize dots from nested-module targets (e.g. "block1.freq") + # so register_buffer does not raise KeyError on dotted names. + sanitized = input_node.target.replace(".", "__") # type: ignore + new_attr_name = sanitized + "_unpacked_complex" with unset_fake_temporarily(): original_tensor = self.get_attr_tensor(input_node.target) # type: ignore stacked_tensor = torch.stack( @@ -169,93 +323,944 @@ def replace_input_node(self, input_node: Node) -> None: self.gm.register_buffer(new_attr_name, stacked_tensor) with self.gm.graph.inserting_after(input_node): new_node = self.gm.graph.get_attr(new_attr_name) - else: logger.debug( - f"Unsupported node type in replacement of input node: {input_node.op}" - ) - logger.debug( - "This complex subgraph inputnode type does not need to replaced" + " unpack get_attr %s%s -> %s%s", + input_node.target, + tuple(original_tensor.shape), + new_attr_name, + tuple(stacked_tensor.shape), ) + else: + pass # call_function inputs are rewritten in-place by the op handlers input_node.replace_all_uses_with(new_node) self.gm.graph.erase_node(input_node) clean_up_graph_after_modifications(self.gm) + # ------------------------------------------------------------------ + # Private graph-building helpers + # + # Each helper takes a SubgraphBuilder and emits a sub-sequence of nodes, + # advancing the builder's cursor. They return the last node(s) they + # inserted. + # ------------------------------------------------------------------ + + @staticmethod + def _inline_select_re_im(b: SubgraphBuilder, inp: Node) -> Tuple[Node, Node]: + """Select re (index 0) and im (index 1) from a [..., 2] tensor.""" + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + return re, im + + @staticmethod + def _inline_cat_re_im(b: SubgraphBuilder, out_re: Node, out_im: Node) -> Node: + """Rebuild a [..., 2] complex-layout tensor from re and im nodes.""" + re_u = b(torch.ops.aten.unsqueeze.default, out_re, -1) + im_u = b(torch.ops.aten.unsqueeze.default, out_im, -1) + return b(torch.ops.aten.cat.default, [re_u, im_u], -1) + + @staticmethod + def _inline_complex_log( + b: SubgraphBuilder, re: Node, im: Node + ) -> Tuple[Node, Node]: + """log(a+bi) = 0.5*log(a²+b²) + i*atan2(b, a)""" + re2 = b(torch.ops.aten.mul.Tensor, re, re) + im2 = b(torch.ops.aten.mul.Tensor, im, im) + r2 = b(torch.ops.aten.add.Tensor, re2, im2) + log_r2 = b(torch.ops.aten.log.default, r2) + log_re = b(torch.ops.aten.mul.Tensor, log_r2, 0.5) + log_im = b(torch.ops.aten.atan2.default, im, re) + return log_re, log_im + + @staticmethod + def _inline_complex_exp( + b: SubgraphBuilder, re: Node, im: Node + ) -> Tuple[Node, Node]: + """exp(a+bi) = e^a*cos(b) + i*e^a*sin(b)""" + ea = b(torch.ops.aten.exp.default, re) + cos_b = b(torch.ops.aten.cos.default, im) + sin_b = b(torch.ops.aten.sin.default, im) + exp_re = b(torch.ops.aten.mul.Tensor, ea, cos_b) + exp_im = b(torch.ops.aten.mul.Tensor, ea, sin_b) + return exp_re, exp_im + + @staticmethod + def _inline_complex_mul( + b: SubgraphBuilder, re1: Node, im1: Node, re2: Node, im2: Node + ) -> Tuple[Node, Node]: + """(a+bi)(c+di) = (ac-bd) + (ad+bc)i""" + ac = b(torch.ops.aten.mul.Tensor, re1, re2) + bd = b(torch.ops.aten.mul.Tensor, im1, im2) + ad = b(torch.ops.aten.mul.Tensor, re1, im2) + bc = b(torch.ops.aten.mul.Tensor, im1, re2) + out_re = b(torch.ops.aten.sub.Tensor, ac, bd) + out_im = b(torch.ops.aten.add.Tensor, ad, bc) + return out_re, out_im + + @staticmethod + def _inline_complex_div( + b: SubgraphBuilder, re1: Node, im1: Node, re2: Node, im2: Node + ) -> Tuple[Node, Node]: + """(a+bi)/(c+di) = ((ac+bd) + (bc-ad)i) / (c²+d²)""" + c2 = b(torch.ops.aten.mul.Tensor, re2, re2) + d2 = b(torch.ops.aten.mul.Tensor, im2, im2) + denom = b(torch.ops.aten.add.Tensor, c2, d2) + ac = b(torch.ops.aten.mul.Tensor, re1, re2) + bd = b(torch.ops.aten.mul.Tensor, im1, im2) + bc = b(torch.ops.aten.mul.Tensor, im1, re2) + ad = b(torch.ops.aten.mul.Tensor, re1, im2) + numer_re = b(torch.ops.aten.add.Tensor, ac, bd) + numer_im = b(torch.ops.aten.sub.Tensor, bc, ad) + out_re = b(torch.ops.aten.div.Tensor, numer_re, denom) + out_im = b(torch.ops.aten.div.Tensor, numer_im, denom) + return out_re, out_im + + @staticmethod + def _inline_complex_sqrt( + b: SubgraphBuilder, re: Node, im: Node + ) -> Tuple[Node, Node]: + """sqrt(z) = r^0.5 * (cos(θ/2) + i*sin(θ/2))""" + re2 = b(torch.ops.aten.mul.Tensor, re, re) + im2 = b(torch.ops.aten.mul.Tensor, im, im) + r2 = b(torch.ops.aten.add.Tensor, re2, im2) + r = b(torch.ops.aten.sqrt.default, r2) + r_sq = b(torch.ops.aten.pow.Tensor_Scalar, r, 0.5) + theta = b(torch.ops.aten.atan2.default, im, re) + half_theta = b(torch.ops.aten.mul.Tensor, theta, 0.5) + cos_ht = b(torch.ops.aten.cos.default, half_theta) + sin_ht = b(torch.ops.aten.sin.default, half_theta) + sq_re = b(torch.ops.aten.mul.Tensor, r_sq, cos_ht) + sq_im = b(torch.ops.aten.mul.Tensor, r_sq, sin_ht) + return sq_re, sq_im + + # ------------------------------------------------------------------ + # Per-op rewrite handlers + # + # Each method receives the node to rewrite and returns True if it + # modified the graph. They are registered in _build_dispatch_table() + # which is called at the end of __init__. + # ------------------------------------------------------------------ + + @_complex_unpacker(torch.ops.aten.view_as_complex.default) + def _rewrite_view_as_complex(self, node: Node) -> bool: + node.replace_all_uses_with(node.args[0]) + self.gm.graph.erase_node(node) + return False # bypass only, no structural change that needs propagation + + @_complex_unpacker(torch.ops.aten.view_as_real.default) + def _rewrite_view_as_real(self, node: Node) -> bool: + node.replace_all_uses_with(node.args[0]) + self.gm.graph.erase_node(node) + return False + + @_complex_unpacker(torch.ops.aten.permute.default) + def _rewrite_permute(self, node: Node) -> bool: + # permute on a complex tensor: after rewrite the tensor has an extra + # trailing dim of size 2 (real/imag). Append the index for that + # trailing dim so the permutation stays valid. + inp = node.args[0] + orig_dims = list(node.args[1]) + n_orig = len(orig_dims) + new_dims = [d % n_orig for d in orig_dims] + [n_orig] + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.permute.default, inp, new_dims) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.mul.Tensor, torch.ops.aten.div.Tensor) + def _rewrite_mul_div_tensor(self, node: Node) -> bool: + arg0_is_node = isinstance(node.args[0], torch.fx.Node) + arg1_is_node = isinstance(node.args[1], torch.fx.Node) + + if not arg0_is_node and not arg1_is_node: + return False # both scalars + + if node.target == torch.ops.aten.mul.Tensor and ( + not arg0_is_node or not arg1_is_node + ): + return False # scalar * complex — elementwise-safe + + if node.target == torch.ops.aten.div.Tensor and not arg1_is_node: + return False # complex / scalar — elementwise-safe + + if node.target == torch.ops.aten.div.Tensor and not arg0_is_node: + # scalar / complex: s/(a+bi) = (s*a/(a²+b²)) + i*(-s*b/(a²+b²)) + scalar_val = node.args[0] + z_node = node.args[1] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, z_node, -1, 0) + im = b(torch.ops.aten.select.int, z_node, -1, 1) + re2 = b(torch.ops.aten.mul.Tensor, re, re) + im2 = b(torch.ops.aten.mul.Tensor, im, im) + denom = b(torch.ops.aten.add.Tensor, re2, im2) + re_s = b(torch.ops.aten.mul.Tensor, re, scalar_val) + out_re = b(torch.ops.aten.div.Tensor, re_s, denom) + im_s = b(torch.ops.aten.mul.Tensor, im, scalar_val) + neg_im_s = b(torch.ops.aten.neg.default, im_s) + out_im = b(torch.ops.aten.div.Tensor, neg_im_s, denom) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + # Both args are Nodes from here on. + if node.target == torch.ops.aten.div.Tensor: + detector = ComplexOpDetector() + + def _is_complex_layout(n: Node) -> bool: + if detector.is_complex_dtype(n): + return True + val = n.meta.get("val", None) + if val is not None and hasattr(val, "shape"): + return len(val.shape) >= 1 and val.shape[-1] == 2 + return False + + arg0_layout = _is_complex_layout(node.args[0]) + arg1_layout = _is_complex_layout(node.args[1]) + + if arg0_layout and not arg1_layout: + # complex_layout / real — unsqueeze denom for correct broadcast + with SubgraphBuilder(self.gm.graph, node) as b: + denom_unsq = b(torch.ops.aten.unsqueeze.default, node.args[1], -1) + out = b(torch.ops.aten.div.Tensor, node.args[0], denom_unsq) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + elif not arg0_layout and not arg1_layout: + return False # both real — elementwise-safe + else: + # complex / complex — full div rewrite + x_pf = node.args[0].op != "get_attr" + y_pf = node.args[1].op != "get_attr" + original_div, replacement = complex_div_replacement(x_pf, y_pf) + + def match_complex_div( + match: torch.fx.subgraph_rewriter.Match, + original_graph: object, + pattern_graph: object, + ) -> bool: + for original_node in match.nodes_map.values(): + if not isinstance(original_node, torch.fx.Node): + continue + if original_node.name == node.name: + return True + return False + + torch.fx.subgraph_rewriter.replace_pattern_with_filters( + self.gm, + original_div, + replacement, + match_filters=[match_complex_div], + ignore_literals=True, + ) + return True + + # mul.Tensor, both nodes — complex × complex + # Use SubgraphBuilder directly rather than replace_pattern_with_filters so + # that self-multiplication (mul(x, x)) is handled correctly. + # replace_pattern_with_filters requires distinct placeholder nodes for x and y, + # so it silently produces no matches when both args are the same node. + if node in self._originally_complex: + x, y = node.args[0], node.args[1] + x_is_get_attr = x.op == "get_attr" + y_is_get_attr = y.op == "get_attr" + + if not x_is_get_attr and not y_is_get_attr: + # Both are ITensors — use select.int (TRT-compatible) + with SubgraphBuilder(self.gm.graph, node) as b: + x_re = b(torch.ops.aten.select.int, x, -1, 0) + x_im = b(torch.ops.aten.select.int, x, -1, 1) + y_re = b(torch.ops.aten.select.int, y, -1, 0) + y_im = b(torch.ops.aten.select.int, y, -1, 1) + ac = b(torch.ops.aten.mul.Tensor, x_re, y_re) + bd = b(torch.ops.aten.mul.Tensor, x_im, y_im) + ad = b(torch.ops.aten.mul.Tensor, x_re, y_im) + bc = b(torch.ops.aten.mul.Tensor, x_im, y_re) + out_re = b(torch.ops.aten.sub.Tensor, ac, bd) + out_im = b(torch.ops.aten.add.Tensor, ad, bc) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + else: + # At least one arg is a get_attr buffer — fall back to the + # pattern rewriter which uses tensor indexing for get_attr nodes. + x_pf = not x_is_get_attr + y_pf = not y_is_get_attr + original_mul, replacement = complex_mul_replacement(x_pf, y_pf) + + def match_complex_mul( + match: torch.fx.subgraph_rewriter.Match, + original_graph: object, + pattern_graph: object, + ) -> bool: + for original_node in match.nodes_map.values(): + if not isinstance(original_node, torch.fx.Node): + continue + if original_node.name == node.name: + return True + return False + + torch.fx.subgraph_rewriter.replace_pattern_with_filters( + self.gm, + original_mul, + replacement, + match_filters=[match_complex_mul], + ignore_literals=True, + ) + return True + return False + + @_complex_unpacker(torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor) + def _rewrite_add_sub_tensor_scalar(self, node: Node) -> bool: + # add.Tensor(z, scalar) / sub.Tensor(z, scalar): scalar applies to real part only. + if len(node.args) < 2 or isinstance(node.args[1], torch.fx.Node): + return False + inp, scalar = node.args[0], node.args[1] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + new_re = b(node.target, re, scalar) + out = self._inline_cat_re_im(b, new_re, im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten._conj.default) + def _rewrite_conj(self, node: Node) -> bool: + # conj(a+bi) = a - bi + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + neg_im = b(torch.ops.aten.neg.default, im) + out = self._inline_cat_re_im(b, re, neg_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.abs.default) + def _rewrite_abs(self, node: Node) -> bool: + # |a+bi| = sqrt(a²+b²) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + re2 = b(torch.ops.aten.mul.Tensor, re, re) + im2 = b(torch.ops.aten.mul.Tensor, im, im) + sum_ = b(torch.ops.aten.add.Tensor, re2, im2) + out = b(torch.ops.aten.sqrt.default, sum_) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.exp.default) + def _rewrite_exp(self, node: Node) -> bool: + # exp(a+bi) = e^a*cos(b) + i*e^a*sin(b) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + exp_re, exp_im = self._inline_complex_exp(b, re, im) + out = self._inline_cat_re_im(b, exp_re, exp_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.log.default) + def _rewrite_log(self, node: Node) -> bool: + # log(a+bi) = 0.5*log(a²+b²) + i*atan2(b, a) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + log_re, log_im = self._inline_complex_log(b, re, im) + out = self._inline_cat_re_im(b, log_re, log_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.pow.Tensor_Scalar, torch.ops.aten.sqrt.default) + def _rewrite_pow_sqrt(self, node: Node) -> bool: + # pow(a+bi, n) / sqrt via polar form: r^n*(cos(n*θ) + i*sin(n*θ)) + inp = node.args[0] + exponent = ( + node.args[1] if node.target == torch.ops.aten.pow.Tensor_Scalar else 0.5 + ) + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + re2 = b(torch.ops.aten.mul.Tensor, re, re) + im2 = b(torch.ops.aten.mul.Tensor, im, im) + r2 = b(torch.ops.aten.add.Tensor, re2, im2) + r = b(torch.ops.aten.sqrt.default, r2) + rn = b(torch.ops.aten.pow.Tensor_Scalar, r, exponent) + theta = b(torch.ops.aten.atan2.default, im, re) + n_theta = b(torch.ops.aten.mul.Tensor, theta, exponent) + cos_n = b(torch.ops.aten.cos.default, n_theta) + sin_n = b(torch.ops.aten.sin.default, n_theta) + out_re = b(torch.ops.aten.mul.Tensor, rn, cos_n) + out_im = b(torch.ops.aten.mul.Tensor, rn, sin_n) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.real.default) + def _rewrite_real(self, node: Node) -> bool: + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.select.int, node.args[0], -1, 0) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.imag.default) + def _rewrite_imag(self, node: Node) -> bool: + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.select.int, node.args[0], -1, 1) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.angle.default) + def _rewrite_angle(self, node: Node) -> bool: + # angle(a+bi) = atan2(b, a) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + out = b(torch.ops.aten.atan2.default, im, re) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.polar.default) + def _rewrite_polar(self, node: Node) -> bool: + # polar(r, theta) = r*cos(theta) + i*r*sin(theta) + r_arg, theta_arg = node.args[0], node.args[1] + with SubgraphBuilder(self.gm.graph, node) as b: + cos_t = b(torch.ops.aten.cos.default, theta_arg) + sin_t = b(torch.ops.aten.sin.default, theta_arg) + out_re = b(torch.ops.aten.mul.Tensor, r_arg, cos_t) + out_im = b(torch.ops.aten.mul.Tensor, r_arg, sin_t) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.add.Scalar, torch.ops.aten.sub.Scalar) + def _rewrite_add_sub_scalar(self, node: Node) -> bool: + # (a+bi) ± s = (a±s) + bi — scalar applies to real part only + inp, scalar = node.args[0], node.args[1] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + new_re = b(node.target, re, scalar) + out = self._inline_cat_re_im(b, new_re, im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.log2.default, torch.ops.aten.log10.default) + def _rewrite_log2_log10(self, node: Node) -> bool: + # log_b(z) = log(z) / log(b) + base_val = ( + math.log(2.0) + if node.target == torch.ops.aten.log2.default + else math.log(10.0) + ) + inp = node.args[0] + inv_base = 1.0 / base_val + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + log_re, log_im = self._inline_complex_log(b, re, im) + out_re = b(torch.ops.aten.mul.Tensor, log_re, inv_base) + out_im = b(torch.ops.aten.mul.Tensor, log_im, inv_base) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.isnan.default, torch.ops.aten.isinf.default) + def _rewrite_isnan_isinf(self, node: Node) -> bool: + # isnan/isinf(z) = isnan/isinf(re) | isnan/isinf(im) + inp = node.args[0] + op = node.target + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + re_flag = b(op, re) + im_flag = b(op, im) + out = b(torch.ops.aten.logical_or.default, re_flag, im_flag) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.log1p.default) + def _rewrite_log1p(self, node: Node) -> bool: + # log1p(a+bi) = log((a+1) + bi) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + re1 = b(torch.ops.aten.add.Tensor, re, 1.0) + log_re, log_im = self._inline_complex_log(b, re1, im) + out = self._inline_cat_re_im(b, log_re, log_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.expm1.default) + def _rewrite_expm1(self, node: Node) -> bool: + # expm1(a+bi) = (exp(a)*cos(b) - 1) + i*(exp(a)*sin(b)) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + exp_re, exp_im = self._inline_complex_exp(b, re, im) + out_re = b(torch.ops.aten.sub.Tensor, exp_re, 1.0) + out = self._inline_cat_re_im(b, out_re, exp_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.sin.default) + def _rewrite_sin(self, node: Node) -> bool: + # sin(a+bi) = sin(a)*cosh(b) + i*cos(a)*sinh(b) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + sin_a = b(torch.ops.aten.sin.default, re) + cosh_b = b(torch.ops.aten.cosh.default, im) + cos_a = b(torch.ops.aten.cos.default, re) + sinh_b = b(torch.ops.aten.sinh.default, im) + out_re = b(torch.ops.aten.mul.Tensor, sin_a, cosh_b) + out_im = b(torch.ops.aten.mul.Tensor, cos_a, sinh_b) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.cos.default) + def _rewrite_cos(self, node: Node) -> bool: + # cos(a+bi) = cos(a)*cosh(b) - i*sin(a)*sinh(b) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + cos_a = b(torch.ops.aten.cos.default, re) + cosh_b = b(torch.ops.aten.cosh.default, im) + sin_a = b(torch.ops.aten.sin.default, re) + sinh_b = b(torch.ops.aten.sinh.default, im) + out_re = b(torch.ops.aten.mul.Tensor, cos_a, cosh_b) + raw_im = b(torch.ops.aten.mul.Tensor, sin_a, sinh_b) + out_im = b(torch.ops.aten.neg.default, raw_im) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.sinh.default) + def _rewrite_sinh(self, node: Node) -> bool: + # sinh(a+bi) = sinh(a)*cos(b) + i*cosh(a)*sin(b) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + sinh_a = b(torch.ops.aten.sinh.default, re) + cos_b = b(torch.ops.aten.cos.default, im) + cosh_a = b(torch.ops.aten.cosh.default, re) + sin_b = b(torch.ops.aten.sin.default, im) + out_re = b(torch.ops.aten.mul.Tensor, sinh_a, cos_b) + out_im = b(torch.ops.aten.mul.Tensor, cosh_a, sin_b) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.cosh.default) + def _rewrite_cosh(self, node: Node) -> bool: + # cosh(a+bi) = cosh(a)*cos(b) + i*sinh(a)*sin(b) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + cosh_a = b(torch.ops.aten.cosh.default, re) + cos_b = b(torch.ops.aten.cos.default, im) + sinh_a = b(torch.ops.aten.sinh.default, re) + sin_b = b(torch.ops.aten.sin.default, im) + out_re = b(torch.ops.aten.mul.Tensor, cosh_a, cos_b) + out_im = b(torch.ops.aten.mul.Tensor, sinh_a, sin_b) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.tan.default) + def _rewrite_tan(self, node: Node) -> bool: + # tan(a+bi) = sin(2a)/(cos(2a)+cosh(2b)) + i*sinh(2b)/(cos(2a)+cosh(2b)) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + two_re = b(torch.ops.aten.mul.Tensor, re, 2.0) + two_im = b(torch.ops.aten.mul.Tensor, im, 2.0) + sin_2a = b(torch.ops.aten.sin.default, two_re) + cos_2a = b(torch.ops.aten.cos.default, two_re) + sinh_2b = b(torch.ops.aten.sinh.default, two_im) + cosh_2b = b(torch.ops.aten.cosh.default, two_im) + denom = b(torch.ops.aten.add.Tensor, cos_2a, cosh_2b) + out_re = b(torch.ops.aten.div.Tensor, sin_2a, denom) + out_im = b(torch.ops.aten.div.Tensor, sinh_2b, denom) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.tanh.default) + def _rewrite_tanh(self, node: Node) -> bool: + # tanh(a+bi) = sinh(2a)/(cosh(2a)+cos(2b)) + i*sin(2b)/(cosh(2a)+cos(2b)) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + two_re = b(torch.ops.aten.mul.Tensor, re, 2.0) + two_im = b(torch.ops.aten.mul.Tensor, im, 2.0) + sinh_2a = b(torch.ops.aten.sinh.default, two_re) + cosh_2a = b(torch.ops.aten.cosh.default, two_re) + sin_2b = b(torch.ops.aten.sin.default, two_im) + cos_2b = b(torch.ops.aten.cos.default, two_im) + denom = b(torch.ops.aten.add.Tensor, cosh_2a, cos_2b) + out_re = b(torch.ops.aten.div.Tensor, sinh_2a, denom) + out_im = b(torch.ops.aten.div.Tensor, sin_2b, denom) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.asinh.default) + def _rewrite_asinh(self, node: Node) -> bool: + # asinh(z) = log(z + sqrt(z² + 1)) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re, im = self._inline_select_re_im(b, inp) + re2 = b(torch.ops.aten.mul.Tensor, re, re) + im2 = b(torch.ops.aten.mul.Tensor, im, im) + z2_re = b(torch.ops.aten.sub.Tensor, re2, im2) + re_im = b(torch.ops.aten.mul.Tensor, re, im) + z2_im = b(torch.ops.aten.mul.Tensor, re_im, 2.0) + w_re = b(torch.ops.aten.add.Scalar, z2_re, 1.0) # w = z²+1 + sq_re, sq_im = self._inline_complex_sqrt(b, w_re, z2_im) + sum_re = b(torch.ops.aten.add.Tensor, re, sq_re) + sum_im = b(torch.ops.aten.add.Tensor, im, sq_im) + log_re, log_im = self._inline_complex_log(b, sum_re, sum_im) + out = self._inline_cat_re_im(b, log_re, log_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.acosh.default) + def _rewrite_acosh(self, node: Node) -> bool: + # acosh(z) = log(z + sqrt(z² - 1)) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re, im = self._inline_select_re_im(b, inp) + re2 = b(torch.ops.aten.mul.Tensor, re, re) + im2 = b(torch.ops.aten.mul.Tensor, im, im) + z2_re = b(torch.ops.aten.sub.Tensor, re2, im2) + re_im = b(torch.ops.aten.mul.Tensor, re, im) + z2_im = b(torch.ops.aten.mul.Tensor, re_im, 2.0) + w_re = b(torch.ops.aten.sub.Scalar, z2_re, 1.0) # w = z²-1 + sq_re, sq_im = self._inline_complex_sqrt(b, w_re, z2_im) + sum_re = b(torch.ops.aten.add.Tensor, re, sq_re) + sum_im = b(torch.ops.aten.add.Tensor, im, sq_im) + log_re, log_im = self._inline_complex_log(b, sum_re, sum_im) + out = self._inline_cat_re_im(b, log_re, log_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.atanh.default) + def _rewrite_atanh(self, node: Node) -> bool: + # atanh(z) = (1/2) * log((1+z) / (1-z)) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re, im = self._inline_select_re_im(b, inp) + p_re = b(torch.ops.aten.add.Scalar, re, 1.0) # 1+re + q_re = b(torch.ops.aten.sub.Scalar, re, 1.0) # re-1 + neg_q_re = b(torch.ops.aten.neg.default, q_re) # 1-re + neg_im = b(torch.ops.aten.neg.default, im) + div_re, div_im = self._inline_complex_div(b, p_re, im, neg_q_re, neg_im) + log_re, log_im = self._inline_complex_log(b, div_re, div_im) + out_re = b(torch.ops.aten.mul.Tensor, log_re, 0.5) + out_im = b(torch.ops.aten.mul.Tensor, log_im, 0.5) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.asin.default) + def _rewrite_asin(self, node: Node) -> bool: + # asin(z) = -i * log(iz + sqrt(1 - z²)) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re, im = self._inline_select_re_im(b, inp) + iz_re = b(torch.ops.aten.neg.default, im) # iz = (-im, re) + re2 = b(torch.ops.aten.mul.Tensor, re, re) + im2 = b(torch.ops.aten.mul.Tensor, im, im) + z2_re = b(torch.ops.aten.sub.Tensor, re2, im2) + re_im = b(torch.ops.aten.mul.Tensor, re, im) + z2_im = b(torch.ops.aten.mul.Tensor, re_im, 2.0) + ones = b(torch.ops.aten.ones_like.default, z2_re) + w_re = b(torch.ops.aten.sub.Tensor, ones, z2_re) # 1-z² + w_im = b(torch.ops.aten.neg.default, z2_im) + sq_re, sq_im = self._inline_complex_sqrt(b, w_re, w_im) + sum_re = b(torch.ops.aten.add.Tensor, iz_re, sq_re) + sum_im = b(torch.ops.aten.add.Tensor, re, sq_im) # iz_im = re + log_re, log_im = self._inline_complex_log(b, sum_re, sum_im) + # -i*(log_re + i*log_im) = log_im + i*(-log_re) + out_im = b(torch.ops.aten.neg.default, log_re) + out = self._inline_cat_re_im(b, log_im, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.acos.default) + def _rewrite_acos(self, node: Node) -> bool: + # acos(z) = -i * log(z + i*sqrt(1 - z²)) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re, im = self._inline_select_re_im(b, inp) + re2 = b(torch.ops.aten.mul.Tensor, re, re) + im2 = b(torch.ops.aten.mul.Tensor, im, im) + z2_re = b(torch.ops.aten.sub.Tensor, re2, im2) + re_im = b(torch.ops.aten.mul.Tensor, re, im) + z2_im = b(torch.ops.aten.mul.Tensor, re_im, 2.0) + ones = b(torch.ops.aten.ones_like.default, z2_re) + w_re = b(torch.ops.aten.sub.Tensor, ones, z2_re) # 1-z² + w_im = b(torch.ops.aten.neg.default, z2_im) + sq_re, sq_im = self._inline_complex_sqrt(b, w_re, w_im) + isq_re = b(torch.ops.aten.neg.default, sq_im) # i*sqrt = (-sq_im, sq_re) + sum_re = b(torch.ops.aten.add.Tensor, re, isq_re) + sum_im = b(torch.ops.aten.add.Tensor, im, sq_re) + log_re, log_im = self._inline_complex_log(b, sum_re, sum_im) + # -i*(log_re + i*log_im) = log_im + i*(-log_re) + out_im = b(torch.ops.aten.neg.default, log_re) + out = self._inline_cat_re_im(b, log_im, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.atan.default) + def _rewrite_atan(self, node: Node) -> bool: + # atan(z) = (i/2) * log((1-iz) / (1+iz)) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re, im = self._inline_select_re_im(b, inp) + iz_re = b(torch.ops.aten.neg.default, im) # iz = (-im, re) + ones = b(torch.ops.aten.ones_like.default, re) + p_re = b(torch.ops.aten.sub.Tensor, ones, iz_re) # 1-iz + p_im = b(torch.ops.aten.neg.default, re) + q_re = b(torch.ops.aten.add.Tensor, ones, iz_re) # 1+iz + q_im = re # iz_im = re + div_re, div_im = self._inline_complex_div(b, p_re, p_im, q_re, q_im) + log_re, log_im = self._inline_complex_log(b, div_re, div_im) + # (i/2)*(log_re+i*log_im) = (-log_im/2) + i*(log_re/2) + out_re = b(torch.ops.aten.mul.Tensor, log_im, -0.5) + out_im = b(torch.ops.aten.mul.Tensor, log_re, 0.5) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.pow.Tensor_Tensor) + def _rewrite_pow_tensor_tensor(self, node: Node) -> bool: + # z1**z2 = exp(z2 * log(z1)) + z1_inp, z2_inp = node.args[0], node.args[1] + with SubgraphBuilder(self.gm.graph, node) as b: + re1, im1 = self._inline_select_re_im(b, z1_inp) + re2 = b(torch.ops.aten.select.int, z2_inp, -1, 0) + im2 = b(torch.ops.aten.select.int, z2_inp, -1, 1) + log_re, log_im = self._inline_complex_log(b, re1, im1) + mul_re, mul_im = self._inline_complex_mul(b, re2, im2, log_re, log_im) + exp_re, exp_im = self._inline_complex_exp(b, mul_re, mul_im) + out = self._inline_cat_re_im(b, exp_re, exp_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.scalar_tensor.default) + def _rewrite_scalar_tensor(self, node: Node) -> bool: + # scalar_tensor(val, dtype=complex64) → scalar_tensor(0.0, float32) + if dict(node.kwargs).get("dtype") not in (torch.complex64, torch.complex128): + return False + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.scalar_tensor.default, 0.0) + out.kwargs = {"dtype": torch.float32} # type: ignore[assignment] + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.where.self) + def _rewrite_where(self, node: Node) -> bool: + # where.self: unsqueeze mask and optionally expand true-branch for complex layout. + if len(node.args) != 3: + return False + node_val = node.meta.get("val", None) + if node_val is None or not hasattr(node_val, "dtype"): + return False + if node_val.dtype not in (torch.complex64, torch.complex128): + return False + mask_node, true_node, other_node = node.args + target_shape = list(node_val.shape) + [2] + with SubgraphBuilder(self.gm.graph, node) as b: + mask_unsq = b(torch.ops.aten.unsqueeze.default, mask_node, -1) + true_arg = true_node + if isinstance(true_node, torch.fx.Node): + true_val = true_node.meta.get("val", None) + if ( + true_val is not None + and hasattr(true_val, "shape") + and list(true_val.shape) == [2] + ): + true_arg = b(torch.ops.aten.expand.default, true_node, target_shape) + out = b(torch.ops.aten.where.self, mask_unsq, true_arg, other_node) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + def rewrite_subgraph_nodes(self, subgraphs: List[ComplexSubGraphInfo]) -> None: modified = False + # Detect the existing FakeTensorMode from the graph's placeholders + # *before* any rewrites. We pass this to replace_input_node so that + # new placeholder fake tensors are created under the same mode as the + # rest of the graph. Using a fresh FakeTensorMode would cause "mode + # mismatch" assertions under torch.compile (where a mode is already + # active) and would lose SymInt information for torch.export graphs. + detected_fake_mode = torch._export.utils._detect_fake_mode_from_gm(self.gm) + + # Record the set of all nodes that have complex dtype BEFORE any rewriting. + # This is needed because after replace_input_node (which changes dtype from + # complex to float32), is_complex_dtype() would return False for those nodes — + # but we still need to know they were originally complex when we later decide + # whether a mul.Tensor operand should be treated as complex-layout. + detector = ComplexOpDetector() + self._originally_complex: Set[Node] = set() + for subgraph in subgraphs: + for node in subgraph.input_nodes: + if detector.is_complex_dtype(node): + self._originally_complex.add(node) + for node in subgraph.subgraph_nodes: + if detector.is_complex_dtype(node): + self._originally_complex.add(node) + + # _DISPATCH maps op -> unbound method; bind self here once per call. + dispatch = {op: method.__get__(self) for op, method in self._DISPATCH.items()} + + logger.debug( + "complex_graph_rewrite begin subgraphs=%d nodes=%s", + len(subgraphs), + [n.name for s in subgraphs for n in s.subgraph_nodes], + ) + for subgraph in subgraphs: for input_node in subgraph.input_nodes: - logger.debug(f"Input node rewrite: {input_node.name}") if input_node.op not in ("call_function"): - self.replace_input_node(input_node) + # Only rewrite inputs that are themselves complex — real inputs + # to complex-output ops (e.g. r, theta for polar) must NOT be + # renamed to *_unpacked_complex. + if not detector.is_complex_dtype(input_node): + continue + self.replace_input_node(input_node, fake_mode=detected_fake_mode) for node in subgraph.subgraph_nodes: - logger.debug(f"Subgraph Node rewrite: {node.name}") - if node.target == torch.ops.aten.view_as_complex.default: - node.replace_all_uses_with(node.args[0]) - self.gm.graph.erase_node(node) - elif node.target == torch.ops.aten.mul.Tensor: - # this is complex mul where inputs = a+ib and output = c+id. - # complex mul returns (ac - bd) + (ad + bc)i - # which is then view_as_real as (ac-bd), (ad+bc) stacked along the last dimension with last dimension size 2 - x_placeholder_or_func = ( - True if node.args[0].op != "get_attr" else False - ) - y_placeholder_or_func = ( - True if node.args[1].op != "get_attr" else False - ) - - replaced_nodes = [] - original_mul, replacement = complex_mul_replacement( - x_placeholder_or_func, y_placeholder_or_func + # Skip nodes that were already erased by a previous pattern replacement + if node.graph is not self.gm.graph: + continue + handler = dispatch.get(node.target) + if handler is not None: + if handler(node): + modified = True + elif node.target in _ELEMENTWISE_SAFE: + logger.debug(" pass-through %s (elementwise-safe)", node.name) + else: + logger.warning( + "Complex op '%s' has no explicit rewrite rule. " + "It will be passed through as-is on the real [..., 2] layout, " + "which may produce incorrect results or fail TRT compilation. " + "Consider adding a rewrite in complex_graph_rewrite.py.", + node.target, ) - - def match_complex_mul( # type: ignore[no-untyped-def] - match: torch.fx.subgraph_rewriter.Match, - original_graph, - pattern_graph, - ) -> bool: - for original_node in match.nodes_map.values(): - if original_node.name == node.name: - return True - return False - - nodes = torch.fx.subgraph_rewriter.replace_pattern_with_filters( - self.gm, - original_mul, - replacement, - match_filters=[match_complex_mul], - ignore_literals=True, + if modified: + # After rewriting complex ops, any view_as_real node that now receives a + # real tensor must be erased. The subgraph_rewriter replaces the original + # complex mul with a cat of real/imag parts; view_as_real on that result + # is invalid. We detect this by checking whether the input to view_as_real + # is no longer complex-typed (its meta val dtype is real, or has no val yet + # but its target is the real-arithmetic cat output). + for node in list(self.gm.graph.nodes): + if node.target != torch.ops.aten.view_as_real.default: + continue + inp = node.args[0] + if not isinstance(inp, torch.fx.Node): + continue + inp_val = inp.meta.get("val", None) + # If meta is available and dtype is real, erase view_as_real + is_real_input = ( + inp_val is not None + and hasattr(inp_val, "dtype") + and inp_val.dtype not in {torch.complex64, torch.complex128} + ) + # If meta not yet propagated, use the target as a heuristic: + # the real-arithmetic replacement ends with aten.cat.default + if inp_val is None: + is_real_input = inp.target == torch.ops.aten.cat.default + if is_real_input: + inp_desc = ( + f"{inp.name}[{tuple(inp_val.shape)},{inp_val.dtype}]" + if inp_val is not None and hasattr(inp_val, "shape") + else inp.name ) - replaced_nodes += nodes - modified = True - elif node.target == torch.ops.aten.view_as_real.default: - node.replace_all_uses_with(node.args[0]) - self.gm.graph.erase_node(node) - else: - logger.debug(f"Unsupported node target: {node.target}") logger.debug( - "This complex subgraphnode type does not need to replaced" + " erase view_as_real %s (input %s is already real)", + node.name, + inp_desc, ) - - if modified: - self.propagate_metadata() + node.replace_all_uses_with(inp) + self.gm.graph.erase_node(node) + logger.debug("complex_graph_rewrite propagating metadata") + self.propagate_metadata(detected_fake_mode) self.gm.graph.lint() self.gm.recompile() + logger.debug("complex_graph_rewrite done") - def propagate_metadata(self) -> None: - fake_inputs = [] - from torch._subclasses.fake_tensor import FakeTensorMode + def propagate_metadata( + self, existing_fake_mode: Optional[FakeTensorMode] = None + ) -> None: + """Re-propagate FakeTensor metadata after graph rewrites via FakeTensorProp. + + Uses *existing_fake_mode* (detected from the graph's placeholder fake + tensors) when available. This ensures the propagation mode matches the + mode under which the graph was originally traced — critical for both + torch.compile (where a FakeTensorMode is already active on the thread) + and torch.export (where we must preserve the ShapeEnv / SymInt ranges). + + Falls back to a fresh FakeTensorMode only for plain FX graphs that have + no fake tensor metadata at all. + """ from torch.fx.passes.fake_tensor_prop import FakeTensorProp + fake_inputs = [] for node in self.gm.graph.nodes: if node.op == "placeholder": if "val" in node.meta: - with FakeTensorMode(allow_non_fake_inputs=True): - fake_val = node.meta["val"] - fake_inputs.append( - fake_val.to("cuda") - if fake_val.device.type == "cuda" - else fake_val - ) + fake_val = node.meta["val"] + fake_inputs.append( + fake_val.to("cuda") + if fake_val.device.type == "cuda" + else fake_val + ) else: fake_tensor = torch.empty( [s if s != 0 else 1 for s in node.meta["tensor_meta"].shape], @@ -263,9 +1268,13 @@ def propagate_metadata(self) -> None: device=node.meta["tensor_meta"].device, ) fake_inputs.append(fake_tensor) - FakeTensorProp( - self.gm, mode=FakeTensorMode(allow_non_fake_inputs=True) - ).propagate(*fake_inputs) + + prop_mode = ( + existing_fake_mode + if existing_fake_mode is not None + else FakeTensorMode(allow_non_fake_inputs=True) + ) + FakeTensorProp(self.gm, mode=prop_mode).propagate(*fake_inputs) def extract_real_imag(input, placeholder_or_func: bool = True): # type: ignore @@ -337,6 +1346,108 @@ def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return (original_mul, replacement) +def complex_div_replacement( + x_placeholder_or_func: bool = True, y_placeholder_or_func: bool = True +) -> Tuple[ + Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + Callable[[torch.Tensor, torch.Tensor], torch.Tensor], +]: + """Constructs the original and replacement functions for complex division. + + (a+bi)/(c+di) = ((ac+bd) + (bc-ad)i) / (c²+d²) + """ + + def original_div(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.ops.aten.div.Tensor(x, y) + + def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + x_real, x_imag = extract_real_imag(x, x_placeholder_or_func) + y_real, y_imag = extract_real_imag(y, y_placeholder_or_func) + + denom = torch.ops.aten.add.Tensor( + torch.ops.aten.mul.Tensor(y_real, y_real), + torch.ops.aten.mul.Tensor(y_imag, y_imag), + ) + real = torch.ops.aten.div.Tensor( + torch.ops.aten.add.Tensor( + torch.ops.aten.mul.Tensor(x_real, y_real), + torch.ops.aten.mul.Tensor(x_imag, y_imag), + ), + denom, + ) + imag = torch.ops.aten.div.Tensor( + torch.ops.aten.sub.Tensor( + torch.ops.aten.mul.Tensor(x_imag, y_real), + torch.ops.aten.mul.Tensor(x_real, y_imag), + ), + denom, + ) + + return torch.ops.aten.cat.default( + [ + torch.ops.aten.unsqueeze.default(real, -1), + torch.ops.aten.unsqueeze.default(imag, -1), + ], + -1, + ) + + return (original_div, replacement) + + +def _get_complex_output_indices(gm: GraphModule) -> List[int]: + """Return indices of output nodes that have complex dtype, before rewriting.""" + complex_dtypes = {torch.complex64, torch.complex128} + output_node = next((n for n in reversed(gm.graph.nodes) if n.op == "output"), None) + if output_node is None: + return [] + # output args is a tuple of the return values + outputs = output_node.args[0] + if not isinstance(outputs, (list, tuple)): + outputs = (outputs,) + indices = [] + for i, out in enumerate(outputs): + if isinstance(out, torch.fx.Node) and "val" in out.meta: + val = out.meta["val"] + if hasattr(val, "dtype") and val.dtype in complex_dtypes: + indices.append(i) + return indices + + +def _get_complex_input_names(gm: GraphModule) -> List[str]: + """Return the original names of placeholder nodes that have complex dtype, before rewriting. + + complex_graph_detection renames complex placeholders from 'name' to 'name_unpacked_complex' + and changes their dtype to float. This captures the original names so the post-partition + pass can insert view_as_real at the graph input boundary. + """ + complex_dtypes = {torch.complex64, torch.complex128} + names = [] + for node in gm.graph.nodes: + if node.op != "placeholder": + continue + val = node.meta.get("val", None) + if val is not None and hasattr(val, "dtype") and val.dtype in complex_dtypes: + names.append(node.name) + return names + + +def _get_complex_input_dtypes(gm: GraphModule) -> dict: + """Return a mapping of placeholder name -> complex dtype for complex-dtype inputs. + + Used by the post-partition boundary pass to know which inputs were complex128 + so it can insert float32 casts when truncate_double=True. + """ + complex_dtypes = {torch.complex64, torch.complex128} + dtypes = {} + for node in gm.graph.nodes: + if node.op != "placeholder": + continue + val = node.meta.get("val", None) + if val is not None and hasattr(val, "dtype") and val.dtype in complex_dtypes: + dtypes[node.name] = val.dtype + return dtypes + + # This lowering pass is used to detect and rewrite complex subgraphs in the graph def complex_graph_detection( gm: GraphModule, settings: CompilationSettings @@ -350,10 +1461,20 @@ def complex_graph_detection( Returns: The modified GraphModule with complex subgraphs rewritten """ + # Capture I/O signature before rewriting — used post-partition to restore + # the complex tensor interface at the graph boundaries. + gm.meta["complex_output_indices"] = _get_complex_output_indices(gm) + gm.meta["complex_input_names"] = _get_complex_input_names(gm) + gm.meta["complex_input_dtypes"] = _get_complex_input_dtypes(gm) + if gm.meta["complex_output_indices"]: + logger.debug( + f"Complex output indices captured: {gm.meta['complex_output_indices']}" + ) + if gm.meta["complex_input_names"]: + logger.debug(f"Complex input names captured: {gm.meta['complex_input_names']}") + complex_op_detector = ComplexOpDetector() - complex_subgraphs = complex_op_detector.find_complex_op_subgraphs( - gm, anchor_target=torch.ops.aten.view_as_real.default - ) + complex_subgraphs = complex_op_detector.find_all_complex_subgraphs(gm) for subgraph in complex_subgraphs: logger.debug(f"Complex subgraph info: {subgraph}") complex_graph_rewriter = ComplexGraphRewriter(gm, settings.truncate_double) diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py index 3a250085f1..98b478db77 100644 --- a/py/torch_tensorrt/dynamo/partitioning/common.py +++ b/py/torch_tensorrt/dynamo/partitioning/common.py @@ -6,7 +6,11 @@ from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch_tensorrt._Input import Input -from torch_tensorrt.dynamo.utils import contains_sym_int, extract_var_range_info +from torch_tensorrt.dynamo.utils import ( + COMPLEX_TO_REAL_DTYPE, + contains_sym_int, + extract_var_range_info, +) logger = logging.getLogger(__name__) @@ -85,6 +89,17 @@ def get_input( """ Based on type of dimensions in the input_shape, construct regular or dynamic shaped inputs """ + if dtype in COMPLEX_TO_REAL_DTYPE: + real_dtype = COMPLEX_TO_REAL_DTYPE[dtype] + real_shape = torch.Size(list(input_shape) + [2]) + logger.info( + f"Input '{name}' has complex dtype {dtype}. TensorRT does not support complex " + f"tensors natively; it will be implicitly unpacked to a real tensor of shape " + f"{real_shape} and dtype {real_dtype} (last dim = [real, imag])." + ) + dtype = real_dtype + input_shape = real_shape + if contains_sym_int(input_shape): return construct_dynamic_input( input_shape, diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 0de257f7c6..0c7166655a 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -27,6 +27,8 @@ from torch._subclasses.fake_tensor import FakeTensor from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch.utils._sympy.numbers import int_oo + +from packaging import version from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES @@ -37,8 +39,6 @@ from torch_tensorrt.dynamo._engine_cache import BaseEngineCache from torch_tensorrt.dynamo._settings import CompilationSettings -from packaging import version - from .types import TRTDataType logger = logging.getLogger(__name__) @@ -99,6 +99,12 @@ class Frameworks(Enum): } +COMPLEX_TO_REAL_DTYPE: Dict[torch.dtype, torch.dtype] = { + torch.complex64: torch.float32, + torch.complex128: torch.float64, +} + + def unified_dtype_converter( dtype: Union[TRTDataType, torch.dtype, np.dtype], to: Frameworks ) -> Union[np.dtype, torch.dtype, TRTDataType]: @@ -313,6 +319,20 @@ def prepare_inputs( return inputs elif isinstance(inputs, (torch.Tensor, int, float, bool)): + if isinstance(inputs, torch.Tensor) and inputs.is_complex(): + # Complex tensors are lowered to real tensors with an extra last + # dimension of size 2 (real, imag) by complex_graph_detection. + # Build an Input whose shape/dtype reflects the lowered representation + # while keeping the original complex tensor for tracing (torch.export + # needs the complex tensor to trace the model correctly). + real_view = torch.view_as_real(inputs.contiguous()) + inp = Input.from_tensor( + real_view, disable_memory_format_check=disable_memory_format_check + ) + # Restore the original complex tensor so dynamo_trace can export + # the model with the correct input dtype. + inp.torch_tensor = inputs + return inp return Input.from_tensor( torch.tensor(inputs), disable_memory_format_check=disable_memory_format_check, @@ -818,9 +838,9 @@ def copy_metadata(match_and_replacements: List[Any]) -> None: """ for match_and_replacement in match_and_replacements: anchor_node = match_and_replacement.nodes_map[match_and_replacement.anchor] - assert ( - len(match_and_replacement.replacements) == 1 - ), "Found more than 1 replacements for the anchor node." + assert len(match_and_replacement.replacements) == 1, ( + "Found more than 1 replacements for the anchor node." + ) replacement_node = match_and_replacement.replacements[0] replacement_node.meta = anchor_node.meta @@ -859,10 +879,13 @@ def get_output_dtypes(output: Any, truncate_double: bool = False) -> List[dtype] # Placeholder output (e.g. unused slot in flash attention return tuple) pass elif isinstance(output_meta, (FakeTensor, torch.Tensor)): - if truncate_double and output_meta.dtype == torch.float64: + out_dtype = output_meta.dtype + if out_dtype in COMPLEX_TO_REAL_DTYPE: + out_dtype = COMPLEX_TO_REAL_DTYPE[out_dtype] + if truncate_double and out_dtype == torch.float64: output_dtypes.append(dtype.float32) else: - output_dtypes.append(dtype._from(output_meta.dtype)) + output_dtypes.append(dtype._from(out_dtype)) elif isinstance(output_meta, torch.SymInt): output_dtypes.append(dtype.int64) elif "tensor_meta" in output.meta: diff --git a/tests/py/dynamo/hlo/__init__.py b/tests/py/dynamo/hlo/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/py/dynamo/hlo/test_complex_ops.py b/tests/py/dynamo/hlo/test_complex_ops.py new file mode 100644 index 0000000000..27d135c4f1 --- /dev/null +++ b/tests/py/dynamo/hlo/test_complex_ops.py @@ -0,0 +1,2070 @@ +""" +Numerical accuracy stress tests for complex tensor decomposition in torch-tensorrt. + +The complex_graph_detection lowering pass rewrites complex-dtype ops to equivalent +real-arithmetic ops before TRT compilation. These tests verify correctness across: + + - I/O boundaries: complex inputs, complex outputs, mixed real/complex I/O + - Internal subgraphs: complex ops entirely within a TRT block + - Operator coverage: mul, add, sub, abs, angle, conj, real/imag extraction, + gather/scatter (select, slice, index), reshape/view, cat/stack, where, + unsqueeze/squeeze, expand/broadcast, type casting + - Chains: multiple sequential complex ops + - Multiple complex tensors interacting in one graph + - Dynamic shapes: batch and seq_len as symbolic dims + +All tests compare PyTorch (CPU/CUDA reference) vs TRT compiled output via +cosine similarity > COSINE_THRESHOLD on both real and imaginary parts. +""" + +import pytest +import torch +import torch.nn as nn +import torch_tensorrt as torchtrt +from torch.export import Dim +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_freqs(seq: int, dim: int, theta: float = 10000.0) -> torch.Tensor: + """Complex unit-magnitude frequency tensor, shape (seq, dim//2).""" + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(seq, dtype=torch.float32) + freqs = torch.outer(t, freqs) + return torch.polar(torch.ones_like(freqs), freqs).cuda() + + +def _cossim_complex(py_out: torch.Tensor, trt_out: torch.Tensor, tag: str) -> None: + """Assert cosine similarity on real and imaginary parts separately.""" + assert trt_out.is_complex(), f"{tag}: expected complex output, got {trt_out.dtype}" + assert ( + trt_out.shape == py_out.shape + ), f"{tag}: shape mismatch {trt_out.shape} vs {py_out.shape}" + r = cosine_similarity(py_out.real.contiguous(), trt_out.real.contiguous()) + i = cosine_similarity(py_out.imag.contiguous(), trt_out.imag.contiguous()) + assert ( + r > COSINE_THRESHOLD + ), f"{tag}: real part cosine sim {r:.4f} < {COSINE_THRESHOLD}" + assert ( + i > COSINE_THRESHOLD + ), f"{tag}: imag part cosine sim {i:.4f} < {COSINE_THRESHOLD}" + + +def _cossim_real(py_out: torch.Tensor, trt_out: torch.Tensor, tag: str) -> None: + """Assert cosine similarity on a real-valued output.""" + assert not trt_out.is_complex(), f"{tag}: expected real output, got {trt_out.dtype}" + s = cosine_similarity(py_out.contiguous(), trt_out.contiguous()) + assert s > COSINE_THRESHOLD, f"{tag}: cosine sim {s:.4f} < {COSINE_THRESHOLD}" + + +_COMPILE = dict(ir="dynamo", min_block_size=1, pass_through_build_failures=True) + + +# =========================================================================== +# 1. I/O boundary tests +# =========================================================================== + + +class ComplexInputRealOutput(nn.Module): + """Complex input → real output (magnitude).""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + # z: complex, output: real magnitude + r = torch.view_as_real(z) + real = r[..., 0] + imag = r[..., 1] + return torch.sqrt(real * real + imag * imag) + + +@pytest.mark.unit +def test_complex_input_real_output(): + model = ComplexInputRealOutput().eval().cuda() + z = _make_freqs(8, 64) # (8, 32) complex64 + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "complex_input_real_output") + torch._dynamo.reset() + + +class RealInputComplexOutput(nn.Module): + """Real input → complex output (no view_as_real at graph output).""" + + def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + # x: real (B, S, H, D), freqs: complex (S, D//2) + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + return xc * freqs[None, :, None, :] # complex output + + +@pytest.mark.unit +def test_real_input_complex_output(): + model = RealInputComplexOutput().eval().cuda() + x = torch.randn(2, 8, 4, 64).cuda() + freqs = _make_freqs(8, 64) + inputs = (x, freqs) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "real_input_complex_output") + torch._dynamo.reset() + + +class ComplexInputComplexOutput(nn.Module): + """Complex input × complex input → complex output.""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return a * b + + +@pytest.mark.unit +def test_complex_input_complex_output(): + model = ComplexInputComplexOutput().eval().cuda() + a = _make_freqs(8, 64) + b = _make_freqs(8, 64) + inputs = (a, b) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_input_complex_output") + torch._dynamo.reset() + + +class MixedRealComplexInputRealOutput(nn.Module): + """One real input, one complex input, real output.""" + + def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + prod = xc * freqs[None, :, None, :] + return torch.view_as_real(prod).flatten(3) + + +@pytest.mark.unit +def test_mixed_real_complex_input_real_output(): + model = MixedRealComplexInputRealOutput().eval().cuda() + x = torch.randn(2, 8, 4, 64).cuda() + freqs = _make_freqs(8, 64) + inputs = (x, freqs) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "mixed_io_real_output") + torch._dynamo.reset() + + +# =========================================================================== +# 2. Operator coverage +# =========================================================================== + + +class ComplexAdd(nn.Module): + """Complex addition: (a+bi) + (c+di) = (a+c) + (b+d)i.""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return a + b + + +@pytest.mark.unit +def test_complex_add_output(): + model = ComplexAdd().eval().cuda() + a = _make_freqs(8, 64) + b = _make_freqs(8, 64) + inputs = (a, b) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_add") + torch._dynamo.reset() + + +class ComplexSub(nn.Module): + """Complex subtraction.""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return a - b + + +@pytest.mark.unit +def test_complex_sub_output(): + model = ComplexSub().eval().cuda() + a = _make_freqs(8, 64) + b = _make_freqs(8, 64) + inputs = (a, b) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_sub") + torch._dynamo.reset() + + +class ComplexMulChain(nn.Module): + """Chain of two complex multiplications: (a * b) * c.""" + + def forward( + self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor + ) -> torch.Tensor: + return (a * b) * c + + +@pytest.mark.unit +def test_complex_mul_chain(): + model = ComplexMulChain().eval().cuda() + a = _make_freqs(8, 64) + b = _make_freqs(8, 64) + c = _make_freqs(8, 64) + inputs = (a, b, c) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_mul_chain") + torch._dynamo.reset() + + +class ComplexDiv(nn.Module): + """Complex division: (a+bi)/(c+di) = ((ac+bd) + (bc-ad)i) / (c²+d²).""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return a / b + + +@pytest.mark.unit +def test_complex_div(): + model = ComplexDiv().eval().cuda() + a = _make_freqs(8, 64) + b = _make_freqs( + 8, 64, theta=500.0 + ) # different theta → different angles → non-trivial imaginary + inputs = (a, b) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_div") + torch._dynamo.reset() + + +class ComplexScalarMul(nn.Module): + """Scale a complex tensor by a real scalar.""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + # Scale by 2.0 — real * complex + r = torch.view_as_real(z) + scaled = r * 2.0 + return torch.view_as_complex(scaled) + + +@pytest.mark.unit +def test_complex_scalar_mul_output(): + model = ComplexScalarMul().eval().cuda() + z = _make_freqs(8, 64) + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_scalar_mul") + torch._dynamo.reset() + + +class ComplexAbs(nn.Module): + """Complex magnitude: |z| = sqrt(re^2 + im^2) — real output.""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + r = torch.view_as_real(z) + return (r * r).sum(-1).sqrt() + + +@pytest.mark.unit +def test_complex_abs(): + model = ComplexAbs().eval().cuda() + z = _make_freqs(8, 64) + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "complex_abs") + torch._dynamo.reset() + + +class ComplexAbsNative(nn.Module): + """torch.abs on a complex tensor — exercises the aten.abs.default rewrite.""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.abs(z) + + +@pytest.mark.unit +def test_complex_abs_native(): + model = ComplexAbsNative().eval().cuda() + z = torch.polar(2 * torch.ones(8, 32), torch.randn(8, 32)).cuda() + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "complex_abs_native") + torch._dynamo.reset() + + +class ComplexExp(nn.Module): + """torch.exp on a complex tensor: exp(a+bi) = e^a*(cos b + i sin b).""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.exp(z) + + +@pytest.mark.unit +def test_complex_exp(): + model = ComplexExp().eval().cuda() + # small magnitudes to keep exp from overflowing + z = torch.polar(0.1 * torch.ones(8, 32), torch.randn(8, 32)).cuda() + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_exp") + torch._dynamo.reset() + + +class ComplexLog(nn.Module): + """torch.log on a complex tensor: log(a+bi) = log|z| + i*angle(z).""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.log(z) + + +@pytest.mark.unit +def test_complex_log(): + model = ComplexLog().eval().cuda() + z = torch.polar(2 * torch.ones(8, 32), torch.randn(8, 32)).cuda() + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_log") + torch._dynamo.reset() + + +class ComplexPow(nn.Module): + """z**n via polar form: r^n * (cos(nθ) + i*sin(nθ)).""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + return z**3 + + +@pytest.mark.unit +def test_complex_pow(): + model = ComplexPow().eval().cuda() + z = torch.polar(torch.ones(8, 32), torch.randn(8, 32)).cuda() + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_pow") + torch._dynamo.reset() + + +class ComplexSqrt(nn.Module): + """torch.sqrt on a complex tensor: z**0.5.""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.sqrt(z) + + +@pytest.mark.unit +def test_complex_sqrt(): + model = ComplexSqrt().eval().cuda() + z = torch.polar(4 * torch.ones(8, 32), torch.randn(8, 32)).cuda() + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_sqrt") + torch._dynamo.reset() + + +class ComplexConj(nn.Module): + """torch.conj on a complex tensor — exercises the _conj rewrite.""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.conj(z) + + +@pytest.mark.unit +def test_complex_conj(): + model = ComplexConj().eval().cuda() + z = _make_freqs(8, 64) + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs).resolve_conj() + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_conj") + torch._dynamo.reset() + + +class ComplexConjMul(nn.Module): + """z * conj(z) = |z|^2 — real-valued result returned as complex.""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + r = torch.view_as_real(z) + re, im = r[..., 0], r[..., 1] + # conj(z) has same real, negated imag + real_part = re * re + im * im # ac - b(-d) = ac + bd when c=a, d=-b + imag_part = torch.zeros_like(real_part) + return torch.view_as_complex(torch.stack([real_part, imag_part], dim=-1)) + + +@pytest.mark.unit +def test_complex_conj_mul(): + model = ComplexConjMul().eval().cuda() + z = _make_freqs(8, 64) + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_conj_mul") + torch._dynamo.reset() + + +# =========================================================================== +# 3. Gather/scatter: select, slice, index +# =========================================================================== + + +class ComplexSelect(nn.Module): + """Select a slice along a dimension from a complex tensor.""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + # z: (S, D) complex — select first half of S, then mul + half = z[:4, :] # slice along seq dim + return half * z[4:, :] # element-wise complex mul, real output via view_as_real + # returns complex — covered by complex output test + + +@pytest.mark.unit +def test_complex_select_and_mul(): + model = ComplexSelect().eval().cuda() + z = _make_freqs(8, 64) # (8, 32) complex64 + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_select_mul") + torch._dynamo.reset() + + +class ComplexSlice(nn.Module): + """Slice two halves of a complex tensor and multiply them.""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + # z: (S, D) complex — split into first and second half along D + half_d = z.shape[-1] // 2 + a = z[:, :half_d] # (S, D//2) complex + b = z[:, half_d:] # (S, D//2) complex + return a * b # complex output + + +@pytest.mark.unit +def test_complex_slice_and_mul(): + model = ComplexSlice().eval().cuda() + z = _make_freqs(8, 64) # (8, 32) complex64 + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_slice_mul") + torch._dynamo.reset() + + +# =========================================================================== +# 4. Shape manipulation: reshape, unsqueeze, squeeze, expand, flatten +# =========================================================================== + + +class ComplexReshapeAndMul(nn.Module): + """Reshape a complex tensor, then multiply.""" + + def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + # x: real (B, S, H*D), freqs: complex (S, D//2) + B, S, HD = x.shape + H = 4 + D = HD // H + xr = x.view(B, S, H, D) + xc = torch.view_as_complex(xr.reshape(B, S, H, -1, 2)) # (B,S,H,D//2) complex + return torch.view_as_real(xc * freqs[None, :, None, :]).flatten(3) + + +@pytest.mark.unit +def test_complex_reshape_and_mul(): + model = ComplexReshapeAndMul().eval().cuda() + x = torch.randn(2, 8, 64).cuda() + freqs = _make_freqs(8, 16) # (8, 8) complex, head_dim=16 -> D//2=8 + inputs = (x, freqs) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "complex_reshape_mul") + torch._dynamo.reset() + + +class ComplexUnsqueezeExpand(nn.Module): + """Unsqueeze and expand a complex tensor before multiplication.""" + + def forward(self, z: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + # z: (S, D) complex, freqs: (D,) complex + # unsqueeze freqs to broadcast over S + return z * freqs.unsqueeze(0) # (S,D) complex output + + +@pytest.mark.unit +def test_complex_unsqueeze_expand(): + model = ComplexUnsqueezeExpand().eval().cuda() + z = _make_freqs(8, 64) # (8, 32) + freqs = _make_freqs(1, 64).squeeze(0) # (32,) + inputs = (z, freqs) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_unsqueeze_expand") + torch._dynamo.reset() + + +# =========================================================================== +# 5. Concatenation and stacking +# =========================================================================== + + +class ComplexCat(nn.Module): + """Concatenate two complex tensors along the sequence dimension.""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return torch.cat([a, b], dim=0) # (2S, D) complex output + + +@pytest.mark.unit +def test_complex_cat(): + model = ComplexCat().eval().cuda() + a = _make_freqs(4, 64) + b = _make_freqs(4, 64) + inputs = (a, b) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_cat") + torch._dynamo.reset() + + +class ComplexCatThenMul(nn.Module): + """Concatenate two complex tensors then multiply by a third.""" + + def forward( + self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor + ) -> torch.Tensor: + ab = torch.cat([a, b], dim=0) # (2S, D) + return ab * c # complex output + + +@pytest.mark.unit +def test_complex_cat_then_mul(): + model = ComplexCatThenMul().eval().cuda() + a = _make_freqs(4, 64) + b = _make_freqs(4, 64) + c = _make_freqs(8, 64) + inputs = (a, b, c) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_cat_then_mul") + torch._dynamo.reset() + + +class ComplexStackRealView(nn.Module): + """Stack real-view representations of two complex tensors, then multiply. + + Tests that the rewriter correctly handles complex ops on stacked real tensors: + view_as_real(a) and view_as_real(b) are stacked, then used to form two + independent complex multiplications. + """ + + def forward( + self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor + ) -> torch.Tensor: + # a, b: (S, D) complex, c: (S, D) complex + # Multiply each independently and add — tests multiple complex paths + return torch.view_as_real(a * c).flatten(-2) + torch.view_as_real( + b * c + ).flatten(-2) + + +@pytest.mark.unit +def test_complex_stack_real_view(): + model = ComplexStackRealView().eval().cuda() + a = _make_freqs(8, 64) + b = _make_freqs(8, 64) + c = _make_freqs(8, 64) + inputs = (a, b, c) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "complex_stack_real_view") + torch._dynamo.reset() + + +# =========================================================================== +# 6. Where / masked selection +# =========================================================================== + + +class ComplexWhere(nn.Module): + """Conditional selection between two complex tensors.""" + + def forward( + self, mask: torch.Tensor, a: torch.Tensor, b: torch.Tensor + ) -> torch.Tensor: + # Operate on real/imag separately — where doesn't support complex natively + ar = torch.view_as_real(a) + br = torch.view_as_real(b) + m = mask.unsqueeze(-1) # broadcast over last (2,) dim + out = torch.where(m, ar, br) + return torch.view_as_complex(out.contiguous()) + + +@pytest.mark.unit +def test_complex_where(): + model = ComplexWhere().eval().cuda() + a = _make_freqs(8, 64) + b = _make_freqs(8, 64) + mask = (torch.randn(8, 32) > 0).cuda() + inputs = (mask, a, b) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_where") + torch._dynamo.reset() + + +# =========================================================================== +# 7. Multiple complex subgraphs in one model +# =========================================================================== + + +class DualComplexPath(nn.Module): + """Two independent complex multiplications merged at the output. + + freqs is passed already broadcast-ready (same shape as the complex view of x/y) + so no indexing/unsqueeze is needed on the complex tensor. + """ + + def forward( + self, + x: torch.Tensor, + y: torch.Tensor, + freqs: torch.Tensor, + ) -> torch.Tensor: + # Path A: x rotated by freqs + xa = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + out_a = torch.view_as_real(xa * freqs).flatten(3) + # Path B: y rotated by same freqs + xb = torch.view_as_complex(y.reshape(*y.shape[:-1], -1, 2)) + out_b = torch.view_as_real(xb * freqs).flatten(3) + return out_a + out_b # real output + + +@pytest.mark.unit +def test_dual_complex_path(): + model = DualComplexPath().eval().cuda() + x = torch.randn(2, 8, 4, 64).cuda() + y = torch.randn(2, 8, 4, 64).cuda() + # freqs must match the complex view shape (2,8,4,32) — broadcast via register_buffer + freqs = ( + _make_freqs(8, 64).unsqueeze(0).unsqueeze(2).expand(2, 8, 4, 32).contiguous() + ) + inputs = (x, y, freqs) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "dual_complex_path") + torch._dynamo.reset() + + +# =========================================================================== +# 8. Complex ops interleaved with real ops +# =========================================================================== + + +class ComplexSandwich(nn.Module): + """Real → complex → real → linear → complex → real sandwich. + + Uses a buffer for freqs so the complex tensor is a get_attr (not placeholder), + which the rewriter handles via stacked real tensor. + """ + + def __init__(self, freqs: torch.Tensor) -> None: + super().__init__() + self.register_buffer("freqs", freqs) + self.linear = nn.Linear(64, 64, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # real → complex rotation using buffer freqs (get_attr path) + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + rotated = torch.view_as_real(xc * self.freqs).flatten(3) # (B,S,H,D) real + # real linear + out = self.linear(rotated) + # another complex rotation + outc = torch.view_as_complex(out.reshape(*out.shape[:-1], -1, 2)) + return torch.view_as_real(outc * self.freqs).flatten(3) + + +@pytest.mark.unit +def test_complex_sandwich(): + freqs = ( + _make_freqs(8, 64).unsqueeze(0).unsqueeze(2).expand(2, 8, 4, 32).contiguous() + ) + model = ComplexSandwich(freqs).eval().cuda() + x = torch.randn(2, 8, 4, 64).cuda() + inputs = (x,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "complex_sandwich") + torch._dynamo.reset() + + +# =========================================================================== +# 9. Complex nn.Parameter (get_attr path) +# =========================================================================== + + +class ComplexParamMul(nn.Module): + """Complex weight stored as nn.Parameter — exercises the get_attr rewrite path.""" + + def __init__(self, freqs: torch.Tensor) -> None: + super().__init__() + # nn.Parameter, not register_buffer — still a get_attr node in the exported graph + self.freqs = nn.Parameter(freqs, requires_grad=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + return torch.view_as_real(xc * self.freqs).flatten(3) + + +@pytest.mark.unit +def test_complex_param_get_attr(): + freqs = ( + _make_freqs(8, 64).unsqueeze(0).unsqueeze(2).expand(2, 8, 4, 32).contiguous() + ) + model = ComplexParamMul(freqs).eval().cuda() + x = torch.randn(2, 8, 4, 64).cuda() + inputs = (x,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "complex_param_get_attr") + torch._dynamo.reset() + + +# =========================================================================== +# 10. Dynamic shapes +# =========================================================================== + + +class ComplexMulDynamic(nn.Module): + """Complex RoPE with dynamic seq_len.""" + + def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + return torch.view_as_real(xc * freqs[None, :, None, :]).flatten(3) + + +@pytest.mark.unit +def test_complex_mul_dynamic_seqlen(): + """Dynamic seq_len: x has shape (B, seq, H, D), freqs has shape (seq, D//2).""" + model = ComplexMulDynamic().eval().cuda() + x = torch.randn(2, 8, 4, 64).cuda() + freqs = _make_freqs(8, 64) # (8, 32) + inputs = (x, freqs) + + # x dim-1 and freqs dim-0 are both the seq dimension — share the same Dim + seq = Dim("seq", min=2, max=64) + dynamic_shapes = ({1: seq}, {0: seq}) + ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) + trt_model = torchtrt.dynamo.compile( + ep, inputs=inputs, min_block_size=1, pass_through_build_failures=True + ) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "complex_mul_dynamic_seqlen") + torch._dynamo.reset() + + +class ComplexOutputDynamic(nn.Module): + """Complex output with dynamic batch.""" + + def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + return xc * freqs[None, :, None, :] # complex output + + +@pytest.mark.unit +def test_complex_output_dynamic_batch(): + model = ComplexOutputDynamic().eval().cuda() + x = torch.randn(2, 8, 4, 64).cuda() + freqs = _make_freqs(8, 64) + inputs = (x, freqs) + + batch = Dim("batch", min=1, max=8) + dynamic_shapes = ({0: batch}, {}) + ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) + trt_model = torchtrt.dynamo.compile( + ep, inputs=inputs, min_block_size=1, pass_through_build_failures=True + ) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_output_dynamic_batch") + torch._dynamo.reset() + + +# =========================================================================== +# 11. Numerical precision: complex64 vs truncated complex128 +# =========================================================================== + + +class Complex128Model(nn.Module): + """Uses complex128 (double precision).""" + + def forward(self, z: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + return torch.view_as_real(z * w).flatten(-2) + + +@pytest.mark.unit +def test_complex128_truncated_to_float32(): + """complex128 with truncate_double=True should compile to float32 arithmetic.""" + model = Complex128Model().eval().cuda() + z = torch.polar( + torch.ones(8, 32, dtype=torch.float64), + torch.randn(8, 32, dtype=torch.float64), + ).cuda() + w = torch.polar( + torch.ones(8, 32, dtype=torch.float64), + torch.randn(8, 32, dtype=torch.float64), + ).cuda() + inputs = (z, w) + trt_model = torchtrt.compile( + model, + inputs=inputs, + ir="dynamo", + min_block_size=1, + pass_through_build_failures=True, + truncate_double=True, + ) + py_out = model(*inputs).float() # cast reference to float32 for comparison + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "complex128_truncated") + torch._dynamo.reset() + + +# =========================================================================== +# 12. End-to-end: full attention-style block with complex RoPE +# =========================================================================== + + +class AttentionWithComplexRoPE(nn.Module): + """Multi-head self-attention with complex-number RoPE and real output.""" + + def __init__(self, d_model: int = 64, n_heads: int = 4) -> None: + super().__init__() + self.n_heads = n_heads + self.head_dim = d_model // n_heads + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + self.out_proj = nn.Linear(d_model, d_model, bias=False) + + def _apply_rope(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + B, S, H, D = x.shape + xc = torch.view_as_complex(x.reshape(B, S, H, -1, 2)) + return torch.view_as_real(xc * freqs[None, :, None, :]).flatten(3) + + def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + B, S, C = x.shape + H, D = self.n_heads, self.head_dim + q = self.q_proj(x).view(B, S, H, D) + k = self.k_proj(x).view(B, S, H, D) + v = self.v_proj(x).view(B, S, H, D) + q = self._apply_rope(q, freqs) + k = self._apply_rope(k, freqs) + # Scaled dot-product attention + scale = D**-0.5 + attn = torch.einsum("bshd,bthd->bhst", q, k) * scale + attn = torch.softmax(attn, dim=-1) + out = torch.einsum("bhst,bthd->bshd", attn, v).reshape(B, S, C) + return self.out_proj(out) + + +@pytest.mark.unit +def test_attention_with_complex_rope_static(): + model = AttentionWithComplexRoPE(d_model=64, n_heads=4).eval().cuda() + x = torch.randn(2, 8, 64).cuda() + freqs = _make_freqs(8, 16) # head_dim=16, D//2=8 + inputs = (x, freqs) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "attention_with_complex_rope") + torch._dynamo.reset() + + +# =========================================================================== +# 13. Elementwise-safe structural ops (clone, permute) +# =========================================================================== + + +class ComplexClone(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return z.clone() * z + + +@pytest.mark.unit +def test_complex_clone(): + model = ComplexClone().eval().cuda() + z = _make_freqs(8, 32) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_clone") + torch._dynamo.reset() + + +class ComplexPermute(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + # permute spatial dims only, then apply mul so complex subgraph is detected + return z.permute(1, 0) * z.permute(1, 0) + + +@pytest.mark.unit +def test_complex_permute(): + model = ComplexPermute().eval().cuda() + z = _make_freqs(8, 32) # (8, 16) complex64 + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_permute") + torch._dynamo.reset() + + +# =========================================================================== +# 14. Extraction / construction ops +# =========================================================================== + + +class ComplexRealExtract(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return z.real + + +@pytest.mark.unit +def test_complex_real(): + model = ComplexRealExtract().eval().cuda() + z = _make_freqs(8, 32) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + py_out = model(z) + trt_out = trt_model(z) + _cossim_real(py_out, trt_out, "complex_real") + torch._dynamo.reset() + + +class ComplexImagExtract(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return z.imag + + +@pytest.mark.unit +def test_complex_imag(): + model = ComplexImagExtract().eval().cuda() + z = _make_freqs(8, 32) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + py_out = model(z) + trt_out = trt_model(z) + _cossim_real(py_out, trt_out, "complex_imag") + torch._dynamo.reset() + + +class ComplexAngle(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.angle(z) + + +@pytest.mark.unit +def test_complex_angle(): + model = ComplexAngle().eval().cuda() + z = _make_freqs(8, 32) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + py_out = model(z) + trt_out = trt_model(z) + _cossim_real(py_out, trt_out, "complex_angle") + torch._dynamo.reset() + + +class ComplexPolar(nn.Module): + def forward(self, r: torch.Tensor, theta: torch.Tensor) -> torch.Tensor: + return torch.polar(r, theta) + + +@pytest.mark.unit +def test_complex_polar(): + r = torch.rand(8, 16, device="cuda") + 0.1 + theta = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + model = ComplexPolar().eval().cuda() + trt_model = torchtrt.compile(model, inputs=(r, theta), **_COMPILE) + py_out = model(r, theta) + trt_out = trt_model(r, theta) + _cossim_complex(py_out, trt_out, "complex_polar") + torch._dynamo.reset() + + +class ComplexReciprocal(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.reciprocal(z) + + +@pytest.mark.unit +def test_complex_reciprocal(): + model = ComplexReciprocal().eval().cuda() + # Use non-unit magnitude to avoid trivial 1/z=conj(z) for |z|=1 + z = _make_freqs(8, 32) * 2.0 + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_reciprocal") + torch._dynamo.reset() + + +class ComplexRsqrt(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.rsqrt(z) + + +@pytest.mark.unit +def test_complex_rsqrt(): + model = ComplexRsqrt().eval().cuda() + # Use polar form with r > 0 so rsqrt is well-defined + r = torch.rand(8, 16, device="cuda") + 0.5 + theta = torch.rand(8, 16, device="cuda") * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_rsqrt") + torch._dynamo.reset() + + +class ComplexAddScalar(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + # Exercise add.Scalar: (a+bi)+2 = (a+2)+bi + return torch.view_as_complex(torch.view_as_real(z).add(0.0)) + 2.0 + + +@pytest.mark.unit +def test_complex_add_scalar(): + """add.Scalar: scalar adds to real part only — (a+2) + bi.""" + model = ComplexAddScalar().eval().cuda() + z = _make_freqs(8, 32) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_add_scalar") + torch._dynamo.reset() + + +class ComplexSgn(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.sgn(z) + + +@pytest.mark.unit +def test_complex_sgn(): + """sgn(z) = z/|z|, sgn(0) = 0.""" + model = ComplexSgn().eval().cuda() + r = torch.rand(8, 16, device="cuda") + 0.1 + theta = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + # Include one zero entry + r[0, 0] = 0.0 + theta[0, 0] = 0.0 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_sgn") + torch._dynamo.reset() + + +class ComplexLog2(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.log2(z) + + +@pytest.mark.unit +def test_complex_log2(): + model = ComplexLog2().eval().cuda() + r = torch.rand(8, 16, device="cuda") + 0.5 + theta = torch.rand(8, 16, device="cuda") * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_log2") + torch._dynamo.reset() + + +class ComplexLog10(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.log10(z) + + +@pytest.mark.unit +def test_complex_log10(): + model = ComplexLog10().eval().cuda() + r = torch.rand(8, 16, device="cuda") + 0.5 + theta = torch.rand(8, 16, device="cuda") * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_log10") + torch._dynamo.reset() + + +class ComplexLog1p(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.log1p(z) + + +@pytest.mark.unit +def test_complex_log1p(): + model = ComplexLog1p().eval().cuda() + # |z| < 1 for numerical stability + r = torch.rand(8, 16, device="cuda") * 0.5 + theta = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_log1p") + torch._dynamo.reset() + + +class ComplexExpm1(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.expm1(z) + + +@pytest.mark.unit +def test_complex_expm1(): + model = ComplexExpm1().eval().cuda() + # Small magnitude to avoid exp overflow + r = torch.rand(8, 16, device="cuda") * 0.3 + theta = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_expm1") + torch._dynamo.reset() + + +class ComplexIsnan(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + # Output bool → cast to float for cosine sim + return torch.isnan(z).float() + + +@pytest.mark.unit +def test_complex_isnan(): + model = ComplexIsnan().eval().cuda() + z = _make_freqs(8, 32) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + py_out = model(z) + trt_out = trt_model(z) + # All-zero output: check element-wise equality + assert torch.allclose(py_out, trt_out), "complex_isnan: output mismatch" + torch._dynamo.reset() + + +class ComplexIsinf(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.isinf(z).float() + + +@pytest.mark.unit +def test_complex_isinf(): + model = ComplexIsinf().eval().cuda() + z = _make_freqs(8, 32) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + py_out = model(z) + trt_out = trt_model(z) + assert torch.allclose(py_out, trt_out), "complex_isinf: output mismatch" + torch._dynamo.reset() + + +# =========================================================================== +# 15. Trigonometric ops +# =========================================================================== + + +class ComplexSin(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.sin(z) + + +@pytest.mark.unit +def test_complex_sin(): + model = ComplexSin().eval().cuda() + r = torch.ones(8, 16, device="cuda") * 0.5 + theta = torch.linspace(0.1, 1.5, 16, device="cuda").unsqueeze(0).expand(8, -1) + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_sin") + torch._dynamo.reset() + + +class ComplexCos(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.cos(z) + + +@pytest.mark.unit +def test_complex_cos(): + model = ComplexCos().eval().cuda() + r = torch.ones(8, 16, device="cuda") * 0.5 + theta = torch.linspace(0.1, 1.5, 16, device="cuda").unsqueeze(0).expand(8, -1) + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_cos") + torch._dynamo.reset() + + +class ComplexSinh(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.sinh(z) + + +@pytest.mark.unit +def test_complex_sinh(): + model = ComplexSinh().eval().cuda() + # Small imaginary part to avoid cosh overflow + r = torch.rand(8, 16, device="cuda") * 0.5 + theta = torch.rand(8, 16, device="cuda") * 0.5 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_sinh") + torch._dynamo.reset() + + +class ComplexCosh(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.cosh(z) + + +@pytest.mark.unit +def test_complex_cosh(): + model = ComplexCosh().eval().cuda() + r = torch.rand(8, 16, device="cuda") * 0.5 + theta = torch.rand(8, 16, device="cuda") * 0.5 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_cosh") + torch._dynamo.reset() + + +class ComplexTan(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.tan(z) + + +@pytest.mark.unit +def test_complex_tan(): + model = ComplexTan().eval().cuda() + # Avoid a = ±pi/4 where denom → 0 + r = torch.rand(8, 16, device="cuda") * 0.4 + theta = torch.rand(8, 16, device="cuda") * 0.3 + 0.5 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_tan") + torch._dynamo.reset() + + +class ComplexTanh(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.tanh(z) + + +@pytest.mark.unit +def test_complex_tanh(): + model = ComplexTanh().eval().cuda() + r = torch.rand(8, 16, device="cuda") * 0.5 + theta = torch.rand(8, 16, device="cuda") * 0.5 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_tanh") + torch._dynamo.reset() + + +# =========================================================================== +# 16. Inverse trigonometric ops +# =========================================================================== + + +class ComplexAsinh(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.asinh(z) + + +@pytest.mark.unit +def test_complex_asinh(): + model = ComplexAsinh().eval().cuda() + r = torch.rand(8, 16, device="cuda") * 0.8 + 0.1 + theta = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_asinh") + torch._dynamo.reset() + + +class ComplexAcosh(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.acosh(z) + + +@pytest.mark.unit +def test_complex_acosh(): + model = ComplexAcosh().eval().cuda() + # |Re(z)| > 1 for non-trivial (non-purely-imaginary) result + r = torch.rand(8, 16, device="cuda") * 0.5 + 1.5 + theta = torch.rand(8, 16, device="cuda") * 0.4 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_acosh") + torch._dynamo.reset() + + +class ComplexAtanh(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.atanh(z) + + +@pytest.mark.unit +def test_complex_atanh(): + model = ComplexAtanh().eval().cuda() + # |z| < 1 to stay within principal domain + r = torch.rand(8, 16, device="cuda") * 0.6 + 0.1 + theta = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_atanh") + torch._dynamo.reset() + + +class ComplexAsin(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.asin(z) + + +@pytest.mark.unit +def test_complex_asin(): + """asin(z) = -i*log(iz + sqrt(1-z²)). + Tested with |z| < 1 to avoid branch-cut ambiguity on the real axis.""" + model = ComplexAsin().eval().cuda() + r = torch.rand(8, 16, device="cuda") * 0.6 + 0.1 + theta = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_asin") + torch._dynamo.reset() + + +class ComplexAcos(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.acos(z) + + +@pytest.mark.unit +def test_complex_acos(): + """acos(z) = -i*log(z + i*sqrt(1-z²)). + Tested with |z| < 1 to avoid branch-cut ambiguity on the real axis.""" + model = ComplexAcos().eval().cuda() + r = torch.rand(8, 16, device="cuda") * 0.6 + 0.1 + theta = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_acos") + torch._dynamo.reset() + + +class ComplexAtan(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.atan(z) + + +@pytest.mark.unit +def test_complex_atan(): + """atan(z) = (i/2)*log((1-iz)/(1+iz)). + Tested with |z| < 1.""" + model = ComplexAtan().eval().cuda() + r = torch.rand(8, 16, device="cuda") * 0.6 + 0.1 + theta = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_atan") + torch._dynamo.reset() + + +# =========================================================================== +# 17. Complex-complex power (pow.Tensor_Tensor) +# =========================================================================== + + +class ComplexPowTensorTensor(nn.Module): + def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor: + return torch.pow(z1, z2) + + +@pytest.mark.unit +def test_complex_pow_tensor_tensor(): + """z1**z2 = exp(z2 * log(z1)), both complex.""" + model = ComplexPowTensorTensor().eval().cuda() + # Use unit-magnitude base to keep values bounded + r1 = torch.ones(8, 16, device="cuda") + theta1 = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + z1 = torch.polar(r1, theta1) + # Small exponent magnitude to avoid overflow + r2 = torch.rand(8, 16, device="cuda") * 0.3 + theta2 = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + z2 = torch.polar(r2, theta2) + trt_model = torchtrt.compile(model, inputs=(z1, z2), **_COMPILE) + _cossim_complex(model(z1, z2), trt_model(z1, z2), "complex_pow_tensor_tensor") + torch._dynamo.reset() + + +# =========================================================================== +# 18. Composite complex-only multi-op chains +# =========================================================================== + + +class ComplexLogExp(nn.Module): + """exp(log(z)) ≈ z — round-trip through log and exp.""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.exp(torch.log(z)) + + +@pytest.mark.unit +def test_complex_log_exp(): + """exp(log(z)) ≈ z: round-trip verifies log and exp rewrites compose correctly.""" + model = ComplexLogExp().eval().cuda() + r = torch.rand(8, 16, device="cuda") * 0.8 + 0.2 + theta = torch.rand(8, 16, device="cuda") * 1.5 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_log_exp") + torch._dynamo.reset() + + +class ComplexMulAddSub(nn.Module): + """(a*b)+c-d — four complex operands, two muls and add/sub.""" + + def forward(self, a, b, c, d): + return (a * b) + c - d + + +@pytest.mark.unit +def test_complex_mul_add_sub(): + """(a*b)+c-d with four complex inputs.""" + model = ComplexMulAddSub().eval().cuda() + + def _rc(): + return torch.polar( + torch.rand(8, 16, device="cuda") + 0.2, + torch.rand(8, 16, device="cuda") * 2 * 3.14159, + ) + + a, b, c, d = _rc(), _rc(), _rc(), _rc() + trt_model = torchtrt.compile(model, inputs=(a, b, c, d), **_COMPILE) + _cossim_complex(model(a, b, c, d), trt_model(a, b, c, d), "complex_mul_add_sub") + torch._dynamo.reset() + + +class ComplexConjThenMul(nn.Module): + """conj(a) * b.""" + + def forward(self, a, b): + return torch.conj(a) * b + + +@pytest.mark.unit +def test_complex_conj_then_mul(): + """conj(a)*b: conjugate followed by complex multiply.""" + model = ComplexConjThenMul().eval().cuda() + + def _rc(): + return torch.polar( + torch.rand(8, 16, device="cuda") + 0.2, + torch.rand(8, 16, device="cuda") * 2 * 3.14159, + ) + + a, b = _rc(), _rc() + trt_model = torchtrt.compile(model, inputs=(a, b), **_COMPILE) + _cossim_complex(model(a, b), trt_model(a, b), "complex_conj_then_mul") + torch._dynamo.reset() + + +class ComplexAbsThenLog(nn.Module): + """log(abs(z)) — chain ending in real output.""" + + def forward(self, z): + return torch.log(torch.abs(z)) + + +@pytest.mark.unit +def test_complex_abs_then_log(): + """log(|z|): abs(complex) → log(real), result is real.""" + model = ComplexAbsThenLog().eval().cuda() + r = torch.rand(8, 16, device="cuda") * 0.8 + 0.2 + z = torch.polar(r, torch.rand(8, 16, device="cuda") * 2 * 3.14159) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + py_out = model(z) + trt_out = trt_model(z) + _cossim_real(py_out, trt_out, "complex_abs_then_log") + torch._dynamo.reset() + + +class ComplexSqrtThenMul(nn.Module): + """sqrt(a) * sqrt(b) — two sqrt rewrites in one graph.""" + + def forward(self, a, b): + return torch.sqrt(a) * torch.sqrt(b) + + +@pytest.mark.unit +def test_complex_sqrt_then_mul(): + """sqrt(a)*sqrt(b) ≈ sqrt(a*b) — exercises two sqrt rewrites in one graph.""" + model = ComplexSqrtThenMul().eval().cuda() + r = torch.rand(8, 16, device="cuda") + 0.5 + a = torch.polar(r, torch.rand(8, 16, device="cuda") * 3.14159) + b = torch.polar(r, torch.rand(8, 16, device="cuda") * 3.14159) + trt_model = torchtrt.compile(model, inputs=(a, b), **_COMPILE) + _cossim_complex(model(a, b), trt_model(a, b), "complex_sqrt_then_mul") + torch._dynamo.reset() + + +class ComplexPowThenAdd(nn.Module): + """z**2 + z — polynomial evaluation via pow + add.""" + + def forward(self, z): + return z**2 + z + + +@pytest.mark.unit +def test_complex_pow_then_add(): + """z² + z — quadratic in z, exercises pow.Tensor_Scalar → add chain.""" + model = ComplexPowThenAdd().eval().cuda() + z = torch.polar( + torch.rand(8, 16, device="cuda") * 0.8 + 0.2, + torch.rand(8, 16, device="cuda") * 2 * 3.14159, + ) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_pow_then_add") + torch._dynamo.reset() + + +class ComplexSinCosPythagorean(nn.Module): + """sin(z)² + cos(z)² — Pythagorean identity over ℂ.""" + + def forward(self, z): + s = torch.sin(z) + c = torch.cos(z) + return s * s + c * c + + +@pytest.mark.unit +def test_complex_sin_cos_pythagorean(): + """sin²(z) + cos²(z): TRT vs PyTorch agree numerically.""" + model = ComplexSinCosPythagorean().eval().cuda() + r = torch.rand(8, 16, device="cuda") * 0.4 + theta = ( + torch.linspace(0.1, 1.2, 16, device="cuda") + .unsqueeze(0) + .expand(8, -1) + .contiguous() + ) + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + py_out = model(z) + trt_out = trt_model(z) + _cossim_real( + py_out.real.contiguous(), + trt_out.real.contiguous(), + "complex_sin_cos_pythagorean", + ) + torch._dynamo.reset() + + +class ComplexExpThenAbs(nn.Module): + """|exp(z)| = exp(Re(z)) — chain: exp → abs, result is real.""" + + def forward(self, z): + return torch.abs(torch.exp(z)) + + +@pytest.mark.unit +def test_complex_exp_then_abs(): + """|exp(z)| = exp(Re(z)): exercises exp rewrite feeding into abs rewrite.""" + model = ComplexExpThenAbs().eval().cuda() + r = torch.rand(8, 16, device="cuda") * 0.3 + theta = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + py_out = model(z) + trt_out = trt_model(z) + _cossim_real(py_out, trt_out, "complex_exp_then_abs") + torch._dynamo.reset() + + +class ComplexNormalize(nn.Module): + """z / |z| — normalize to unit circle via abs + divide.""" + + def forward(self, z): + mag = torch.abs(z) + # Avoid aten.complex.default — build complex divisor via view_as_complex. + mag_c = torch.view_as_complex(torch.stack([mag, torch.zeros_like(mag)], dim=-1)) + return z / mag_c + + +@pytest.mark.unit +def test_complex_normalize(): + """z/|z|: unit-normalize a complex tensor.""" + model = ComplexNormalize().eval().cuda() + z = torch.polar( + torch.rand(8, 16, device="cuda") * 0.8 + 0.2, + torch.rand(8, 16, device="cuda") * 2 * 3.14159, + ) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_normalize") + torch._dynamo.reset() + + +# =========================================================================== +# 19. Complex + real interleaved computations +# =========================================================================== + + +class ComplexMulThenRealLinear(nn.Module): + """Complex rotation followed by a real-valued linear projection (core RoPE pattern).""" + + def __init__(self) -> None: + super().__init__() + self.proj = nn.Linear(64, 32, bias=False) + + def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + rotated = torch.view_as_real(xc * freqs).flatten(-2) + return self.proj(rotated) + + +@pytest.mark.unit +def test_complex_mul_then_real_linear(): + """Complex RoPE rotation followed by a real linear layer.""" + model = ComplexMulThenRealLinear().eval().cuda() + x = torch.randn(2, 8, 64, device="cuda") + freqs = _make_freqs(8, 64) + trt_model = torchtrt.compile(model, inputs=(x, freqs), **_COMPILE) + py_out = model(x, freqs) + trt_out = trt_model(x, freqs) + _cossim_real(py_out, trt_out, "complex_mul_then_real_linear") + torch._dynamo.reset() + + +class RealNormThenComplexMul(nn.Module): + """LayerNorm on the real input, then rotate with complex freqs.""" + + def __init__(self, d: int = 64) -> None: + super().__init__() + self.norm = nn.LayerNorm(d) + + def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + x = self.norm(x) + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + return torch.view_as_real(xc * freqs).flatten(-2) + + +@pytest.mark.unit +def test_real_norm_then_complex_mul(): + """LayerNorm (real) → view_as_complex → complex mul → view_as_real.""" + model = RealNormThenComplexMul(d=64).eval().cuda() + x = torch.randn(2, 8, 64, device="cuda") + freqs = _make_freqs(8, 64) + trt_model = torchtrt.compile(model, inputs=(x, freqs), **_COMPILE) + py_out = model(x, freqs) + trt_out = trt_model(x, freqs) + _cossim_real(py_out, trt_out, "real_norm_then_complex_mul") + torch._dynamo.reset() + + +class ComplexMulThenRealActivation(nn.Module): + """Complex rotation → real view → GELU activation.""" + + def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + real_out = torch.view_as_real(xc * freqs).flatten(-2) + return torch.nn.functional.gelu(real_out) + + +@pytest.mark.unit +def test_complex_mul_then_gelu(): + """Complex rotation followed by GELU on the real-valued output.""" + model = ComplexMulThenRealActivation().eval().cuda() + x = torch.randn(2, 8, 64, device="cuda") + freqs = _make_freqs(8, 64) + trt_model = torchtrt.compile(model, inputs=(x, freqs), **_COMPILE) + py_out = model(x, freqs) + trt_out = trt_model(x, freqs) + _cossim_real(py_out, trt_out, "complex_mul_then_gelu") + torch._dynamo.reset() + + +class RealScaleThenComplexAddSub(nn.Module): + """Scale two real tensors, pack as complex, do add and sub.""" + + def __init__(self) -> None: + super().__init__() + self.scale = nn.Parameter(torch.ones(1)) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # x, y: real (B, D, 2) — pack as complex + xa = x * self.scale + ya = y * self.scale + zx = torch.view_as_complex(xa) + zy = torch.view_as_complex(ya) + return torch.view_as_real(zx + zy - zx) + + +@pytest.mark.unit +def test_real_scale_then_complex_add_sub(): + """Real scale → pack as complex → add/sub → unpack.""" + model = RealScaleThenComplexAddSub().eval().cuda() + x = torch.randn(4, 16, 2, device="cuda") + y = torch.randn(4, 16, 2, device="cuda") + trt_model = torchtrt.compile(model, inputs=(x, y), **_COMPILE) + py_out = model(x, y) + trt_out = trt_model(x, y) + _cossim_real(py_out, trt_out, "real_scale_then_complex_add_sub") + torch._dynamo.reset() + + +class ComplexMagPhaseRecompose(nn.Module): + """Decompose into magnitude + phase, apply real ops to each, recompose.""" + + def __init__(self) -> None: + super().__init__() + self.mag_scale = nn.Parameter(torch.ones(1)) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + mag = torch.abs(z) + phase = torch.angle(z) + mag2 = mag * self.mag_scale.abs() + phase2 = torch.clamp(phase, -1.5, 1.5) + return torch.polar(mag2, phase2) + + +@pytest.mark.unit +def test_complex_mag_phase_recompose(): + """Decompose z → (|z|, angle) → scale+clip → polar recompose.""" + model = ComplexMagPhaseRecompose().eval().cuda() + r = torch.rand(8, 16, device="cuda") + 0.3 + theta = torch.rand(8, 16, device="cuda") * 2.0 - 1.0 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_mag_phase_recompose") + torch._dynamo.reset() + + +class ComplexResidual(nn.Module): + """Complex residual: z + exp(log(z)).""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + return z + torch.exp(torch.log(z)) + + +@pytest.mark.unit +def test_complex_residual(): + """z + exp(log(z)) ≈ 2z — residual connection through complex ops.""" + model = ComplexResidual().eval().cuda() + r = torch.rand(8, 16, device="cuda") + 0.5 + theta = torch.rand(8, 16, device="cuda") * 1.5 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_residual") + torch._dynamo.reset() + + +class ComplexGatedMul(nn.Module): + """Real sigmoid gate applied to a complex tensor.""" + + def __init__(self) -> None: + super().__init__() + self.gate_proj = nn.Linear(32, 16, bias=False) + + def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor: + gate = torch.sigmoid(self.gate_proj(x)) + gate_c = torch.view_as_complex( + torch.stack([gate, torch.zeros_like(gate)], dim=-1) + ) + return z * gate_c + + +@pytest.mark.unit +def test_complex_gated_mul(): + """Real sigmoid gate × complex tensor — real and complex subgraphs in one model.""" + model = ComplexGatedMul().eval().cuda() + x = torch.randn(4, 32, device="cuda") + z = torch.polar( + torch.rand(4, 16, device="cuda") + 0.3, + torch.rand(4, 16, device="cuda") * 2 * 3.14159, + ) + trt_model = torchtrt.compile(model, inputs=(x, z), **_COMPILE) + _cossim_complex(model(x, z), trt_model(x, z), "complex_gated_mul") + torch._dynamo.reset() + + +# =========================================================================== +# 20. Multi-layer and branching subgraph integration tests +# =========================================================================== + + +class MultiHeadRoPE(nn.Module): + """Apply independent RoPE rotations to Q, K, V and compute attention logits.""" + + def __init__(self, seq: int = 8, dim: int = 32) -> None: + super().__init__() + self.freq_q = nn.Parameter(_make_freqs(seq, dim).detach()) + self.freq_k = nn.Parameter(_make_freqs(seq, dim).detach()) + self.freq_v = nn.Parameter(_make_freqs(seq, dim).detach()) + + def forward(self, q, k, v): + def rope(x, freq): + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + return torch.view_as_real(xc * freq).flatten(-2) + + q_r = rope(q, self.freq_q) + k_r = rope(k, self.freq_k) + v_r = rope(v, self.freq_v) + scores = torch.bmm(q_r, k_r.transpose(1, 2)) / (q_r.shape[-1] ** 0.5) + return torch.bmm(scores, v_r) + + +@pytest.mark.unit +def test_multi_head_rope(): + """Q/K/V independently rotated by RoPE, then bmm attention — 3 complex subgraphs.""" + model = MultiHeadRoPE(seq=8, dim=32).eval().cuda() + B, S, D = 2, 8, 32 + q = torch.randn(B, S, D, device="cuda") + k = torch.randn(B, S, D, device="cuda") + v = torch.randn(B, S, D, device="cuda") + trt_model = torchtrt.compile(model, inputs=(q, k, v), **_COMPILE) + _cossim_real(model(q, k, v), trt_model(q, k, v), "multi_head_rope") + torch._dynamo.reset() + + +class ParallelComplexBranches(nn.Module): + """One complex input forks into two independent rotation paths, then concat + project.""" + + def __init__(self, dim: int = 16) -> None: + super().__init__() + self.freq_a = nn.Parameter(_make_freqs(8, dim * 2).detach()) + self.freq_b = nn.Parameter(_make_freqs(8, dim * 2).detach()) + self.proj = nn.Linear(dim * 4, dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + z = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + real_a = torch.view_as_real(z * self.freq_a).flatten(-2) + real_b = torch.view_as_real(z * self.freq_b).flatten(-2) + return self.proj(torch.cat([real_a, real_b], dim=-1)) + + +@pytest.mark.unit +def test_parallel_complex_branches(): + """One complex input forks into two rotation paths, concat, then project.""" + model = ParallelComplexBranches(dim=16).eval().cuda() + x = torch.randn(2, 8, 32, device="cuda") + trt_model = torchtrt.compile(model, inputs=(x,), **_COMPILE) + _cossim_real(model(x), trt_model(x), "parallel_complex_branches") + torch._dynamo.reset() + + +class TransformerLikeBlock(nn.Module): + """One layer: RoPE rotation + real FFN with residual.""" + + def __init__(self, d: int = 32) -> None: + super().__init__() + self.freq = nn.Parameter(_make_freqs(8, d).detach()) + self.norm = nn.LayerNorm(d) + self.ff1 = nn.Linear(d, d * 2, bias=False) + self.ff2 = nn.Linear(d * 2, d, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + rotated = torch.view_as_real(xc * self.freq).flatten(-2) + h = self.norm(rotated) + h = torch.nn.functional.gelu(self.ff1(h)) + h = self.ff2(h) + return rotated + h + + +class StackedTransformerBlocks(nn.Module): + """Two sequential transformer-like blocks, each with complex RoPE.""" + + def __init__(self, d: int = 32) -> None: + super().__init__() + self.block1 = TransformerLikeBlock(d) + self.block2 = TransformerLikeBlock(d) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.block2(self.block1(x)) + + +@pytest.mark.unit +def test_stacked_transformer_blocks(): + """Two stacked transformer-like blocks, each containing a complex RoPE sub-graph.""" + model = StackedTransformerBlocks(d=32).eval().cuda() + x = torch.randn(2, 8, 32, device="cuda") + trt_model = torchtrt.compile(model, inputs=(x,), **_COMPILE) + _cossim_real(model(x), trt_model(x), "stacked_transformer_blocks") + torch._dynamo.reset() + + +class FourComplexInputsMulAdd(nn.Module): + """z1*z2 + z3*z4 — four distinct complex runtime inputs.""" + + def forward(self, z1, z2, z3, z4): + return z1 * z2 + z3 * z4 + + +@pytest.mark.unit +def test_four_complex_inputs_mul_add(): + """z1*z2 + z3*z4 — four complex runtime inputs, two muls and one add.""" + model = FourComplexInputsMulAdd().eval().cuda() + + def _rc(shape): + return torch.polar( + torch.rand(*shape, device="cuda") * 0.8 + 0.2, + torch.rand(*shape, device="cuda") * 2 * 3.14159, + ) + + z1, z2, z3, z4 = [_rc((4, 16)) for _ in range(4)] + trt_model = torchtrt.compile(model, inputs=(z1, z2, z3, z4), **_COMPILE) + _cossim_complex( + model(z1, z2, z3, z4), trt_model(z1, z2, z3, z4), "four_complex_inputs_mul_add" + ) + torch._dynamo.reset() + + +class CrossAttentionComplexQ(nn.Module): + """Cross-attention: complex-rotated queries, real key/value projections.""" + + def __init__(self, d_q: int = 32, d_kv: int = 64) -> None: + super().__init__() + self.freq = nn.Parameter(_make_freqs(8, d_q).detach()) + self.norm_q = nn.LayerNorm(d_q) + self.Wk = nn.Linear(d_kv, d_q, bias=False) + self.Wv = nn.Linear(d_kv, d_q, bias=False) + + def forward(self, q_real, kv): + qc = torch.view_as_complex(q_real.reshape(*q_real.shape[:-1], -1, 2)) + q = self.norm_q(torch.view_as_real(qc * self.freq).flatten(-2)) + k = self.Wk(kv) + v = self.Wv(kv) + scores = torch.bmm(q, k.transpose(1, 2)) / (q.shape[-1] ** 0.5) + return torch.bmm(scores, v) + + +@pytest.mark.unit +def test_cross_attention_complex_q(): + """Cross-attention: complex-rotated query, real key/value projections.""" + model = CrossAttentionComplexQ(d_q=32, d_kv=64).eval().cuda() + q_real = torch.randn(2, 8, 32, device="cuda") + kv = torch.randn(2, 12, 64, device="cuda") + trt_model = torchtrt.compile(model, inputs=(q_real, kv), **_COMPILE) + _cossim_real(model(q_real, kv), trt_model(q_real, kv), "cross_attention_complex_q") + torch._dynamo.reset() + + +class ComplexRotator(nn.Module): + """Single complex rotation layer wrapping a learnable frequency buffer.""" + + def __init__(self, seq: int = 8, dim: int = 32) -> None: + super().__init__() + self.freq = nn.Parameter(_make_freqs(seq, dim).detach()) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + return torch.view_as_real(xc * self.freq).flatten(-2) + + +class NestedComplexRotators(nn.Module): + """Two ComplexRotator sub-modules with a real LayerNorm between them.""" + + def __init__(self, d: int = 32) -> None: + super().__init__() + self.rot1 = ComplexRotator(seq=8, dim=d) + self.norm = nn.LayerNorm(d) + self.rot2 = ComplexRotator(seq=8, dim=d) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.rot1(x) + x = self.norm(x) + return self.rot2(x) + + +@pytest.mark.unit +def test_nested_complex_rotators(): + """Two nested ComplexRotator sub-modules with a real LayerNorm between them.""" + model = NestedComplexRotators(d=32).eval().cuda() + x = torch.randn(2, 8, 32, device="cuda") + trt_model = torchtrt.compile(model, inputs=(x,), **_COMPILE) + _cossim_real(model(x), trt_model(x), "nested_complex_rotators") + torch._dynamo.reset() + + +class ComplexNormThenProject(nn.Module): + """abs(z) → LayerNorm → rescale z: real and complex subgraphs share an edge.""" + + def __init__(self, dim: int = 16) -> None: + super().__init__() + self.norm = nn.LayerNorm(dim) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + mag = torch.abs(z) + scale = self.norm(mag) + scale_c = torch.view_as_complex( + torch.stack([scale, torch.zeros_like(scale)], dim=-1) + ) + return z * scale_c + + +@pytest.mark.unit +def test_complex_norm_then_project(): + """abs(z) → LayerNorm → rescale z: real and complex subgraphs share an edge.""" + model = ComplexNormThenProject(dim=16).eval().cuda() + z = torch.polar( + torch.rand(4, 16, device="cuda") * 0.8 + 0.2, + torch.rand(4, 16, device="cuda") * 2 * 3.14159, + ) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_norm_then_project") + torch._dynamo.reset() + + +class ComplexRotateProject(nn.Module): + """Two complex rotations separated by a real linear layer.""" + + def __init__(self, d: int = 32) -> None: + super().__init__() + self.freq1 = nn.Parameter(_make_freqs(8, d).detach()) + self.freq2 = nn.Parameter(_make_freqs(8, d).detach()) + self.proj = nn.Linear(d, d, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + r1 = torch.view_as_real(xc * self.freq1).flatten(-2) + r2 = self.proj(r1) + xc2 = torch.view_as_complex(r2.reshape(*r2.shape[:-1], -1, 2)) + return torch.view_as_real(xc2 * self.freq2).flatten(-2) + + +@pytest.mark.unit +def test_complex_rotate_project(): + """Two complex rotations separated by a real linear layer.""" + model = ComplexRotateProject(d=32).eval().cuda() + x = torch.randn(2, 8, 32, device="cuda") + trt_model = torchtrt.compile(model, inputs=(x,), **_COMPILE) + _cossim_real(model(x), trt_model(x), "complex_rotate_project") + torch._dynamo.reset() diff --git a/tests/py/dynamo/hlo/test_rope_embedding.py b/tests/py/dynamo/hlo/test_rope_embedding.py new file mode 100644 index 0000000000..bbd76c7a92 --- /dev/null +++ b/tests/py/dynamo/hlo/test_rope_embedding.py @@ -0,0 +1,526 @@ +""" +Tests for Rotary Position Embedding (RoPE) compilation with torch-tensorrt. + +RoPE is a critical subgraph used in modern LLMs (LLaMA, Qwen, Mistral, etc.). +Two common forms are tested: + 1. HuggingFace-style: rotate_half + apply_rotary_pos_emb using cos/sin tensors + 2. Complex-number form: view_as_complex + complex multiply + view_as_real + +Both static and dynamic shapes (varying seq_len, batch) are covered, as well as +RoPE embedded inside a larger attention block (a common failure mode). +""" + +import os + +import pytest +import torch +import torch.nn as nn +import torch_tensorrt as torchtrt +from torch.export import Dim +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + +# --------------------------------------------------------------------------- +# Shared helper modules +# --------------------------------------------------------------------------- + + +class HFRotaryEmbedding(nn.Module): + """HuggingFace-style RoPE as used in LLaMA / Qwen / Mistral. + + Identical to ``apply_rotary_pos_emb`` in transformers: + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + """ + + def rotate_half(self, x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ): + # cos/sin shape: (batch, seq_len, head_dim) – unsqueeze head dim + cos = cos.unsqueeze(1) # (batch, 1, seq_len, head_dim) + sin = sin.unsqueeze(1) + q_embed = (q * cos) + (self.rotate_half(q) * sin) + k_embed = (k * cos) + (self.rotate_half(k) * sin) + return q_embed, k_embed + + +class ComplexRotaryEmbedding(nn.Module): + """Complex-number RoPE as used in original LLaMA / Meta models. + + Applies pre-computed complex frequency tensor via: + x_complex = view_as_complex(x.reshape(..., -1, 2)) + out = view_as_real(x_complex * freqs_cis).flatten(-2) + """ + + def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + # x: (batch, seq_len, n_heads, head_dim) + # freqs_cis: (seq_len, head_dim // 2) complex + x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis[None, :, None, :] # broadcast over batch and heads + x_out = torch.view_as_real(x_ * freqs_cis).flatten(3) + return x_out.type_as(x) + + +def _make_freqs_cis( + seq_len: int, head_dim: int, theta: float = 10000.0 +) -> torch.Tensor: + """Pre-compute complex frequency tensor on CUDA.""" + freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim)) + t = torch.arange(seq_len, dtype=torch.float32) + freqs = torch.outer(t, freqs) + return torch.polar(torch.ones_like(freqs), freqs).cuda() + + +# --------------------------------------------------------------------------- +# Test 1: HuggingFace-style RoPE – static shapes +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +def test_rope_hf_style_static(): + """HF rotate_half RoPE compiles and produces correct outputs (static shapes).""" + model = HFRotaryEmbedding().eval().cuda() + + q = torch.randn(1, 12, 5, 128, dtype=torch.float32).cuda() + k = torch.randn(1, 12, 5, 128, dtype=torch.float32).cuda() + # cos/sin: (batch, seq_len, head_dim) + cos = torch.randn(1, 5, 128, dtype=torch.float32).cuda() + sin = torch.randn(1, 5, 128, dtype=torch.float32).cuda() + inputs = (q, k, cos, sin) + + compile_spec = { + "inputs": inputs, + "ir": "dynamo", + "min_block_size": 1, + "pass_through_build_failures": True, + } + trt_model = torchtrt.compile(model, **compile_spec) + + py_q, py_k = model(*inputs) + trt_q, trt_k = trt_model(*inputs) + + cos_sim_q = cosine_similarity(py_q, trt_q) + cos_sim_k = cosine_similarity(py_k, trt_k) + assert cos_sim_q > COSINE_THRESHOLD, ( + f"test_rope_hf_style_static: q outputs differ. " + f"Cosine sim: {cos_sim_q:.4f} < threshold {COSINE_THRESHOLD}" + ) + assert cos_sim_k > COSINE_THRESHOLD, ( + f"test_rope_hf_style_static: k outputs differ. " + f"Cosine sim: {cos_sim_k:.4f} < threshold {COSINE_THRESHOLD}" + ) + torch._dynamo.reset() + + +# --------------------------------------------------------------------------- +# Test 2: HuggingFace-style RoPE – dynamic seq_len +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +def test_rope_hf_style_dynamic(): + """HF rotate_half RoPE compiles and produces correct outputs (dynamic seq_len).""" + model = HFRotaryEmbedding().eval().cuda() + + q = torch.randn(1, 12, 5, 128, dtype=torch.float32).cuda() + k = torch.randn(1, 12, 5, 128, dtype=torch.float32).cuda() + cos = torch.randn(1, 5, 128, dtype=torch.float32).cuda() + sin = torch.randn(1, 5, 128, dtype=torch.float32).cuda() + inputs = (q, k, cos, sin) + + seq_len = Dim("seq_len", min=2, max=2048) + # q/k: (batch, n_heads, seq_len, head_dim) – seq_len is dim 2 + # cos/sin: (batch, seq_len, head_dim) – seq_len is dim 1 + dynamic_shapes = ( + {2: seq_len}, + {2: seq_len}, + {1: seq_len}, + {1: seq_len}, + ) + exp_program = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) + + compile_spec = { + "inputs": inputs, + "min_block_size": 1, + "pass_through_build_failures": True, + } + trt_model = torchtrt.dynamo.compile(exp_program, **compile_spec) + + py_q, py_k = model(*inputs) + trt_q, trt_k = trt_model(*inputs) + + cos_sim_q = cosine_similarity(py_q, trt_q) + cos_sim_k = cosine_similarity(py_k, trt_k) + assert cos_sim_q > COSINE_THRESHOLD, ( + f"test_rope_hf_style_dynamic: q outputs differ. " + f"Cosine sim: {cos_sim_q:.4f} < threshold {COSINE_THRESHOLD}" + ) + assert cos_sim_k > COSINE_THRESHOLD, ( + f"test_rope_hf_style_dynamic: k outputs differ. " + f"Cosine sim: {cos_sim_k:.4f} < threshold {COSINE_THRESHOLD}" + ) + torch._dynamo.reset() + + +# --------------------------------------------------------------------------- +# Test 3: Complex-number RoPE – static shapes +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +def test_rope_complex_form_static(): + """Complex (view_as_complex/view_as_real) RoPE compiles correctly (static shapes).""" + BATCH, SEQ_LEN, N_HEADS, HEAD_DIM = 2, 8, 4, 64 + model = ComplexRotaryEmbedding().eval().cuda() + + x = torch.randn(BATCH, SEQ_LEN, N_HEADS, HEAD_DIM, dtype=torch.float32).cuda() + freqs_cis = _make_freqs_cis(SEQ_LEN, HEAD_DIM) + inputs = (x, freqs_cis) + + compile_spec = { + "inputs": inputs, + "ir": "dynamo", + "min_block_size": 1, + "pass_through_build_failures": True, + } + trt_model = torchtrt.compile(model, **compile_spec) + + py_out = model(*inputs) + trt_out = trt_model(*inputs) + + cos_sim = cosine_similarity(py_out, trt_out) + assert cos_sim > COSINE_THRESHOLD, ( + f"test_rope_complex_form_static: outputs differ. " + f"Cosine sim: {cos_sim:.4f} < threshold {COSINE_THRESHOLD}" + ) + torch._dynamo.reset() + + +# --------------------------------------------------------------------------- +# Test 4: Complex-number RoPE – dynamic batch and seq_len +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +def test_rope_complex_form_dynamic(): + """Complex RoPE compiles correctly with dynamic batch and seq_len.""" + BATCH, SEQ_LEN, N_HEADS, HEAD_DIM = 2, 8, 4, 64 + model = ComplexRotaryEmbedding().eval().cuda() + + x = torch.randn(BATCH, SEQ_LEN, N_HEADS, HEAD_DIM, dtype=torch.float32).cuda() + freqs_cis = _make_freqs_cis(SEQ_LEN, HEAD_DIM) + inputs = (x, freqs_cis) + + batch = Dim("batch", min=1, max=4) + seq_len = Dim("seq_len", min=2, max=512) + # x: (batch, seq_len, n_heads, head_dim) + # freqs_cis: (seq_len, head_dim//2) complex – dim 0 is seq_len + dynamic_shapes = ( + {0: batch, 1: seq_len}, + {0: seq_len}, + ) + exp_program = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) + + compile_spec = { + "inputs": inputs, + "min_block_size": 1, + "pass_through_build_failures": True, + } + trt_model = torchtrt.dynamo.compile(exp_program, **compile_spec) + + py_out = model(*inputs) + trt_out = trt_model(*inputs) + + cos_sim = cosine_similarity(py_out, trt_out) + assert cos_sim > COSINE_THRESHOLD, ( + f"test_rope_complex_form_dynamic: outputs differ. " + f"Cosine sim: {cos_sim:.4f} < threshold {COSINE_THRESHOLD}" + ) + torch._dynamo.reset() + + +# --------------------------------------------------------------------------- +# Test 5: RoPE embedded inside an attention block – static shapes +# --------------------------------------------------------------------------- + + +class AttentionWithRoPE(nn.Module): + """Minimal self-attention block with HF-style RoPE, as found in LLaMA/Qwen. + + This exercises RoPE inside a larger graph—a common failure mode where + the shape inference for cos/sin unsqueeze interacts with the projection + output shapes. + """ + + def __init__(self, embed_dim: int = 64, n_heads: int = 4): + super().__init__() + self.n_heads = n_heads + self.head_dim = embed_dim // n_heads + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.o_proj = nn.Linear(embed_dim, embed_dim, bias=False) + + def rotate_half(self, x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb( + self, + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ): + cos = cos.unsqueeze(1) # add head dim + sin = sin.unsqueeze(1) + return (q * cos) + (self.rotate_half(q) * sin), (k * cos) + ( + self.rotate_half(k) * sin + ) + + def forward( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + batch, seq_len, _ = hidden_states.shape + q = ( + self.q_proj(hidden_states) + .view(batch, seq_len, self.n_heads, self.head_dim) + .transpose(1, 2) + ) + k = ( + self.k_proj(hidden_states) + .view(batch, seq_len, self.n_heads, self.head_dim) + .transpose(1, 2) + ) + v = ( + self.v_proj(hidden_states) + .view(batch, seq_len, self.n_heads, self.head_dim) + .transpose(1, 2) + ) + + q, k = self.apply_rotary_pos_emb(q, k, cos, sin) + + attn_out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, is_causal=True + ) + attn_out = attn_out.transpose(1, 2).reshape(batch, seq_len, -1) + return self.o_proj(attn_out) + + +@pytest.mark.unit +def test_rope_in_attention_block_static(): + """RoPE inside a full attention block compiles correctly (static shapes).""" + EMBED_DIM, N_HEADS, BATCH, SEQ_LEN = 64, 4, 2, 16 + HEAD_DIM = EMBED_DIM // N_HEADS + + model = AttentionWithRoPE(EMBED_DIM, N_HEADS).eval().cuda() + + hidden = torch.randn(BATCH, SEQ_LEN, EMBED_DIM, dtype=torch.float32).cuda() + cos = torch.randn(BATCH, SEQ_LEN, HEAD_DIM, dtype=torch.float32).cuda() + sin = torch.randn(BATCH, SEQ_LEN, HEAD_DIM, dtype=torch.float32).cuda() + inputs = (hidden, cos, sin) + + compile_spec = { + "inputs": inputs, + "ir": "dynamo", + "min_block_size": 1, + "pass_through_build_failures": True, + } + trt_model = torchtrt.compile(model, **compile_spec) + + py_out = model(*inputs) + trt_out = trt_model(*inputs) + + cos_sim = cosine_similarity(py_out, trt_out) + assert cos_sim > COSINE_THRESHOLD, ( + f"test_rope_in_attention_block_static: outputs differ. " + f"Cosine sim: {cos_sim:.4f} < threshold {COSINE_THRESHOLD}" + ) + torch._dynamo.reset() + + +# --------------------------------------------------------------------------- +# Test 6: RoPE embedded inside an attention block – dynamic seq_len +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +def test_rope_in_attention_block_dynamic(): + """RoPE inside a full attention block compiles correctly (dynamic seq_len).""" + EMBED_DIM, N_HEADS, BATCH, SEQ_LEN = 64, 4, 2, 16 + HEAD_DIM = EMBED_DIM // N_HEADS + + model = AttentionWithRoPE(EMBED_DIM, N_HEADS).eval().cuda() + + hidden = torch.randn(BATCH, SEQ_LEN, EMBED_DIM, dtype=torch.float32).cuda() + cos = torch.randn(BATCH, SEQ_LEN, HEAD_DIM, dtype=torch.float32).cuda() + sin = torch.randn(BATCH, SEQ_LEN, HEAD_DIM, dtype=torch.float32).cuda() + inputs = (hidden, cos, sin) + + seq_len = Dim("seq_len", min=2, max=2048) + # hidden: (batch, seq_len, embed_dim) – seq_len is dim 1 + # cos/sin: (batch, seq_len, head_dim) – seq_len is dim 1 + dynamic_shapes = ( + {1: seq_len}, + {1: seq_len}, + {1: seq_len}, + ) + exp_program = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) + + compile_spec = { + "inputs": inputs, + "min_block_size": 1, + "pass_through_build_failures": True, + } + trt_model = torchtrt.dynamo.compile(exp_program, **compile_spec) + + py_out = model(*inputs) + trt_out = trt_model(*inputs) + + cos_sim = cosine_similarity(py_out, trt_out) + assert cos_sim > COSINE_THRESHOLD, ( + f"test_rope_in_attention_block_dynamic: outputs differ. " + f"Cosine sim: {cos_sim:.4f} < threshold {COSINE_THRESHOLD}" + ) + torch._dynamo.reset() + + +# --------------------------------------------------------------------------- +# Test 7: Complex RoPE – serialization with retrace=True then inference +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +def test_rope_complex_form_serialization_retrace(tmp_path): + """Complex RoPE survives save(retrace=True) + load + inference round-trip. + + When retrace=True, torch_tensorrt.save re-exports the compiled GraphModule + via torch.export.export (strict=False), inlining the view_as_real unpacking + ops that live in the Python runtime forward(). The reloaded ExportedProgram + must accept the original complex inputs and produce correct results. + """ + BATCH, SEQ_LEN, N_HEADS, HEAD_DIM = 2, 8, 4, 64 + model = ComplexRotaryEmbedding().eval().cuda() + + x = torch.randn(BATCH, SEQ_LEN, N_HEADS, HEAD_DIM, dtype=torch.float32).cuda() + freqs_cis = _make_freqs_cis(SEQ_LEN, HEAD_DIM) + inputs = (x, freqs_cis) + + # Step 1: compile + compile_spec = { + "inputs": inputs, + "ir": "dynamo", + "min_block_size": 1, + "pass_through_build_failures": True, + } + trt_model = torchtrt.compile(model, **compile_spec) + + py_out = model(*inputs) + trt_out_before = trt_model(*inputs) + + cos_sim_before = cosine_similarity(py_out, trt_out_before) + assert cos_sim_before > COSINE_THRESHOLD, ( + f"test_rope_complex_form_serialization_retrace: pre-save TRT output wrong. " + f"Cosine sim: {cos_sim_before:.4f} < threshold {COSINE_THRESHOLD}" + ) + + # Step 2: save with retrace=True — re-exports the compiled GraphModule so + # the view_as_real input-unpacking is inlined into the exported graph. + ep_path = str(tmp_path / "rope_complex_trt.ep") + torchtrt.save( + trt_model, + ep_path, + output_format="exported_program", + arg_inputs=list(inputs), + retrace=True, + ) + assert os.path.exists(ep_path), "Serialized .ep file was not created" + + # Step 3: reload + loaded_ep = torchtrt.load(ep_path) + # torch_tensorrt.load returns ExportedProgram; call .module() to get the + # callable GraphModule. + loaded_module = loaded_ep.module() + + # Step 4: inference on reloaded model + trt_out_after = loaded_module(*inputs) + + cos_sim_after = cosine_similarity(py_out, trt_out_after) + assert cos_sim_after > COSINE_THRESHOLD, ( + f"test_rope_complex_form_serialization_retrace: post-load TRT output wrong. " + f"Cosine sim: {cos_sim_after:.4f} < threshold {COSINE_THRESHOLD}" + ) + torch._dynamo.reset() + + +# --------------------------------------------------------------------------- +# Test 8: Complex output – model whose output is a complex tensor +# --------------------------------------------------------------------------- + + +class ComplexOutputModel(nn.Module): + """A model that outputs a complex tensor. + + This exercises the post-partition complex output restoration pass: + complex_graph_detection rewrites the internal complex ops to real + arithmetic before partitioning, and the compiler must re-insert + view_as_complex at the output boundary when the tail block is TRT. + """ + + def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + # x: (batch, seq_len, n_heads, head_dim) – real + # freqs_cis: (seq_len, head_dim // 2) – complex + # Returns: complex tensor of shape (batch, seq_len, n_heads, head_dim // 2) + x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis[None, :, None, :] + return x_ * freqs_cis # complex output – no view_as_real + + +@pytest.mark.unit +def test_complex_output_static(): + """Model with a complex tensor output compiles and produces correct results.""" + model = ComplexOutputModel().eval().cuda() + + x = torch.randn(1, 4, 8, 64, dtype=torch.float32).cuda() + freqs_cis = _make_freqs_cis(4, 64) # shape (4, 32), complex64 + inputs = (x, freqs_cis) + + compile_spec = { + "inputs": inputs, + "ir": "dynamo", + "min_block_size": 1, + "pass_through_build_failures": True, + } + trt_model = torchtrt.compile(model, **compile_spec) + + py_out = model(*inputs) + trt_out = trt_model(*inputs) + + assert ( + trt_out.is_complex() + ), f"test_complex_output_static: TRT output should be complex, got dtype {trt_out.dtype}" + assert ( + trt_out.shape == py_out.shape + ), f"test_complex_output_static: shape mismatch {trt_out.shape} vs {py_out.shape}" + # Compare real and imaginary parts via cosine similarity + cos_sim_real = cosine_similarity(py_out.real, trt_out.real) + cos_sim_imag = cosine_similarity(py_out.imag, trt_out.imag) + assert ( + cos_sim_real > COSINE_THRESHOLD + ), f"test_complex_output_static: real part cosine sim {cos_sim_real:.4f} < {COSINE_THRESHOLD}" + assert ( + cos_sim_imag > COSINE_THRESHOLD + ), f"test_complex_output_static: imag part cosine sim {cos_sim_imag:.4f} < {COSINE_THRESHOLD}" + torch._dynamo.reset() From 0b449a7c92d0206b02931ee23dd16b25b4f06a50 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Thu, 5 Mar 2026 18:09:25 +0000 Subject: [PATCH 2/8] feat(//py/torch_tensorrt/dynamo): Allow the refit system to cache complex numerics --- py/torch_tensorrt/dynamo/_refit.py | 46 +++- .../dynamo/conversion/_TRTInterpreter.py | 35 +++ tests/py/dynamo/models/test_model_refit.py | 237 ++++++++++++++++++ 3 files changed, 310 insertions(+), 8 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 0b6af849fa..cf1fe5a191 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -43,6 +43,7 @@ from torch_tensorrt.dynamo.utils import ( CPU_DEVICE, check_module_output, + check_output_equal, get_model_device, get_torch_inputs, to_torch_device, @@ -110,6 +111,17 @@ def construct_refit_mapping_from_weight_name_map( engine_weight_name.split(" ")[-1].lower() ) + elif isinstance(sd_weight_name, tuple): + # Buffer-slice mapping created by Stage 3 of _save_weight_mapping. + # Encodes (state_dict_key, dim, index) for weights that are slices + # of a source buffer (e.g. real/imag parts of an unpacked complex buffer). + sd_key, dim, idx = sd_weight_name + if sd_key not in state_dict: + continue + engine_weight_map[engine_weight_name] = ( + state_dict[sd_key].select(dim, idx).to(to_torch_device(settings.device)) + ) + elif sd_weight_name not in state_dict: # If weights is not in sd, we can leave it unchanged continue @@ -587,14 +599,33 @@ def refit_module_weights( if verify_output and arg_inputs is not None: new_gm.to(to_torch_device(settings.device)) - if check_module_output( - new_module=new_gm, - refitted_module=compiled_module, - arg_inputs=torch_inputs, - kwarg_inputs=torch_kwarg_inputs, - ): + # complex_graph_detection rewrites complex placeholders to real (view_as_real). + # The compiled TRT module handles complex→real internally, but the lowered + # PyTorch reference module (new_gm) expects real-unpacked inputs directly. + has_complex_inputs = any( + isinstance(x, torch.Tensor) and x.is_complex() for x in torch_inputs + ) + if has_complex_inputs: + lowered_inputs = [ + ( + torch.view_as_real(x).contiguous() + if isinstance(x, torch.Tensor) and x.is_complex() + else x + ) + for x in torch_inputs + ] + trt_outputs = compiled_module(*torch_inputs) + ref_outputs = new_gm(*lowered_inputs, **torch_kwarg_inputs) + outputs_match = check_output_equal(trt_outputs, ref_outputs) + else: + outputs_match = check_module_output( + new_module=new_gm, + refitted_module=compiled_module, + arg_inputs=torch_inputs, + kwarg_inputs=torch_kwarg_inputs, + ) + if outputs_match: logger.info("Refitting Succeed!") - new_gm.to(CPU_DEVICE) else: if weight_name_map: logger.warning( @@ -610,7 +641,6 @@ def refit_module_weights( in_place=in_place, ) logger.error("Refitting Failed! The outputs do not match.") - new_gm.to(CPU_DEVICE) else: logger.info("Refitting Completed! Output verification skipped.") diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index d4735baa12..fa982bc6da 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -587,6 +587,41 @@ def _save_weight_mapping(self) -> None: weight_refit_map[engine_weight_name].dtype, ] + # Stage 3: Slice matching for unmatched non-scalar CONSTANT weights. + # complex_graph_detection unpacks complex buffers to real: + # freqs (S,D complex64) → freqs_unpacked_complex (S,D,2 float32) + # The real and imag slices (freqs_unpacked_complex[...,0] and [...,1]) are + # embedded as separate TRT constants, but their shapes differ from the source + # buffer, so Stage 2 value matching fails. Here we try selecting each slice + # along the last dimension of every sd entry to find the match. + for engine_weight_name, val in list(weight_name_map.items()): + if not isinstance(val, list) or len(val) != 2: + continue + sd_weight_name, dtype_val = val + if sd_weight_name != "" or engine_weight_name not in weight_refit_map: + continue + ew_tensor = weight_refit_map[engine_weight_name].to(torch_device) + if ew_tensor.numel() <= 1: + continue # scalars are handled via constant_mapping + matched = False + for sd_key, sd_tensor in sd.items(): + if sd_tensor.dim() < 1 or sd_tensor.shape[-1] < 2: + continue + last_dim = sd_tensor.dim() - 1 + for idx in range(sd_tensor.shape[last_dim]): + sd_slice = sd_tensor.select(last_dim, idx) + if TRTInterpreter.check_weight_equal( + sd_slice, ew_tensor, torch_device + ): + weight_name_map[engine_weight_name] = [ + (sd_key, last_dim, idx), + dtype_val, + ] + matched = True + break + if matched: + break + weight_name_map["constant_mapping"] = constant_mapping self.weight_name_map = weight_name_map diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index 1bdbd2dc60..3f8fafe7d2 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -522,6 +522,10 @@ def test_refit_one_engine_bert_with_weightmap(): torch._dynamo.reset() +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) @unittest.skipIf( not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", @@ -582,6 +586,10 @@ def test_refit_one_engine_inline_runtime_with_weightmap(tmpdir): torch._dynamo.reset() +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) @unittest.skipIf( not torch_trt.ENABLED_FEATURES.refit, "Refit feature is not supported in Python 3.13 or higher", @@ -773,6 +781,10 @@ def forward(self, x): torch._dynamo.reset() +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) @unittest.skipIf( not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", @@ -892,6 +904,10 @@ def test_refit_one_engine_bert_without_weightmap(): torch._dynamo.reset() +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) @unittest.skipIf( not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", @@ -949,6 +965,10 @@ def test_refit_one_engine_inline_runtime_without_weightmap(tmpdir): torch._dynamo.reset() +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) @unittest.skipIf( not torch_trt.ENABLED_FEATURES.refit, "Refit feature is not supported in Python 3.13 or higher", @@ -1128,3 +1148,220 @@ def forward(self, x): # Clean up model env torch._dynamo.reset() + + +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) +@pytest.mark.unit +def test_complex_buffer_refit(): + """Refit a model whose weights include a complex-valued buffer (e.g. RoPE freqs). + + Exercises the combined complex_graph_detection + refit_module_weights path: + - complex get_attr buffer is unpacked to real by the lowering pass + - complex placeholder input goes through view_as_real at the TRT boundary + - after refitting with new frequencies the output matches the new model + """ + + class ComplexFreqModel(nn.Module): + def __init__(self, freqs: torch.Tensor): + super().__init__() + self.register_buffer("freqs", freqs.cuda()) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + # complex mul then expose as real so TRT can produce a real output + return torch.view_as_real(z * self.freqs) + + SEQ, DIM = 8, 32 + + def make_freqs() -> torch.Tensor: + angles = torch.rand(SEQ, DIM // 2) + return torch.polar(torch.ones_like(angles), angles).cuda() + + freqs1 = make_freqs() + freqs2 = make_freqs() + + model1 = ComplexFreqModel(freqs1).eval() + model2 = ComplexFreqModel(freqs2).eval() + + z = torch.randn(SEQ, DIM // 2, dtype=torch.complex64).cuda() + inputs = [z] + + exp_program1 = torch.export.export(model1, tuple(inputs)) + exp_program2 = torch.export.export(model2, tuple(inputs)) + + trt_gm = torchtrt.dynamo.compile( + exp_program1, + tuple(inputs), + use_python_runtime=True, + enabled_precisions={torch.float}, + min_block_size=1, + immutable_weights=False, + ) + + new_trt_gm = refit_module_weights( + compiled_module=trt_gm, + new_weight_module=exp_program2, + arg_inputs=inputs, + use_weight_map_cache=True, + verify_output=True, + ) + + expected_output = exp_program2.module()(*inputs) + refitted_output = new_trt_gm(*inputs) + + assertions.assertTrue( + torch.allclose(expected_output, refitted_output, atol=1e-2, rtol=1e-2), + "Refit with complex buffer failed: output mismatch after refit", + ) + + torch._dynamo.reset() + + +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) +@pytest.mark.unit +def test_complex_buffer_with_real_param_refit(): + """Refit a model that mixes a complex buffer with a real nn.Linear weight. + + Verifies that Stage 3 slice-matching for complex buffer constants coexists + correctly with ordinary weight-name-map entries for real parameters. + After refitting both the frequencies and the projection matrix, the output + should match the new model exactly. + """ + + SEQ, DIM = 8, 32 + + class ComplexRotateAndProject(nn.Module): + def __init__(self, freqs: torch.Tensor): + super().__init__() + self.register_buffer("freqs", freqs.cuda()) + self.proj = nn.Linear(DIM, DIM, bias=False) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + rotated = z * self.freqs # complex mul, (SEQ, DIM//2) + r = torch.view_as_real(rotated) # (SEQ, DIM//2, 2) + flat = r.reshape(z.shape[0], -1) # (SEQ, DIM) + return self.proj(flat) # (SEQ, DIM) real output + + def make_freqs() -> torch.Tensor: + angles = torch.rand(SEQ, DIM // 2) + return torch.polar(torch.ones_like(angles), angles).cuda() + + model1 = ComplexRotateAndProject(make_freqs()).eval().cuda() + model2 = ComplexRotateAndProject(make_freqs()).eval().cuda() + + z = torch.randn(SEQ, DIM // 2, dtype=torch.complex64).cuda() + inputs = [z] + + exp_program1 = torch.export.export(model1, tuple(inputs)) + exp_program2 = torch.export.export(model2, tuple(inputs)) + + trt_gm = torchtrt.dynamo.compile( + exp_program1, + tuple(inputs), + use_python_runtime=True, + enabled_precisions={torch.float}, + min_block_size=1, + immutable_weights=False, + ) + + new_trt_gm = refit_module_weights( + compiled_module=trt_gm, + new_weight_module=exp_program2, + arg_inputs=inputs, + use_weight_map_cache=True, + verify_output=True, + ) + + expected_output = exp_program2.module()(*inputs) + refitted_output = new_trt_gm(*inputs) + + assertions.assertTrue( + torch.allclose(expected_output, refitted_output, atol=1e-2, rtol=1e-2), + "Refit with complex buffer + real param failed: output mismatch", + ) + + torch._dynamo.reset() + + +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) +@pytest.mark.unit +def test_dual_complex_buffer_refit(): + """Refit a model with two independent complex buffers. + + Ensures Stage 3 value-based matching correctly distinguishes the real and + imaginary slices of freqs_a from those of freqs_b when both are unpacked to + separate _unpacked_complex state-dict entries with the same shape. + """ + + SEQ, DIM = 8, 32 + + class DualComplexFreqModel(nn.Module): + def __init__(self, freqs_a: torch.Tensor, freqs_b: torch.Tensor): + super().__init__() + self.register_buffer("freqs_a", freqs_a.cuda()) + self.register_buffer("freqs_b", freqs_b.cuda()) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + ra = torch.view_as_real(z * self.freqs_a) # (SEQ, DIM//2, 2) + rb = torch.view_as_real(z * self.freqs_b) # (SEQ, DIM//2, 2) + return ra + rb # real output + + def make_freqs() -> torch.Tensor: + angles = torch.rand(SEQ, DIM // 2) + return torch.polar(torch.ones_like(angles), angles).cuda() + + model1 = DualComplexFreqModel(make_freqs(), make_freqs()).eval() + model2 = DualComplexFreqModel(make_freqs(), make_freqs()).eval() + + z = torch.randn(SEQ, DIM // 2, dtype=torch.complex64).cuda() + inputs = [z] + + exp_program1 = torch.export.export(model1, tuple(inputs)) + exp_program2 = torch.export.export(model2, tuple(inputs)) + + trt_gm = torchtrt.dynamo.compile( + exp_program1, + tuple(inputs), + use_python_runtime=True, + enabled_precisions={torch.float}, + min_block_size=1, + immutable_weights=False, + ) + + new_trt_gm = refit_module_weights( + compiled_module=trt_gm, + new_weight_module=exp_program2, + arg_inputs=inputs, + use_weight_map_cache=True, + verify_output=True, + ) + + expected_output = exp_program2.module()(*inputs) + refitted_output = new_trt_gm(*inputs) + + assertions.assertTrue( + torch.allclose(expected_output, refitted_output, atol=1e-2, rtol=1e-2), + "Refit with dual complex buffers failed: output mismatch", + ) + + torch._dynamo.reset() From ed435c15e27a2e1e1fa25aa230f89e9cffa47e2f Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Thu, 5 Mar 2026 19:07:40 +0000 Subject: [PATCH 3/8] docs: Add documentation on how complex numerics works --- .../contributors/complex_number_support.rst | 210 ++++++++++++++++++ docsrc/tutorials/advanced_usage.rst | 3 +- .../complex_tensors.rst | 6 + docsrc/tutorials/complex_numerics/index.rst | 10 + docsrc/tutorials/deployment/index.rst | 1 - .../extensibility/lowering/index.rst | 1 + .../lowering/subgraph_builder.rst | 105 +++++++++ examples/dynamo/README.rst | 3 +- 8 files changed, 336 insertions(+), 3 deletions(-) rename docsrc/tutorials/{deployment => complex_numerics}/complex_tensors.rst (96%) create mode 100644 docsrc/tutorials/complex_numerics/index.rst create mode 100644 docsrc/tutorials/extensibility/lowering/subgraph_builder.rst diff --git a/docsrc/contributors/complex_number_support.rst b/docsrc/contributors/complex_number_support.rst index c224c76d6d..fa5b4dd9ae 100644 --- a/docsrc/contributors/complex_number_support.rst +++ b/docsrc/contributors/complex_number_support.rst @@ -146,8 +146,218 @@ Key Implementation Invariants Nested submodule parameter names (e.g. ``layers.0.weight``) must have ``.`` replaced with ``__`` before registration. +The Decomposition System — How It Is Built +------------------------------------------- + +The rewriter is split across two classes and wired together by a lightweight +dispatch mechanism. This section walks through each piece and explains the +design decisions. + +ComplexOpDetector — Subgraph Discovery +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``ComplexOpDetector`` walks the graph to find the set of nodes that participate +in complex arithmetic. + +``node_include_in_subgraph`` +"""""""""""""""""""""""""""" + +A node is included in a complex subgraph if: + +1. Its output dtype is ``complex64`` or ``complex128`` (``is_complex_dtype``), **or** +2. Any of its inputs are complex (``has_complex_input``). + +The second condition is necessary to catch real-output ops — ``abs``, ``angle``, +``real``, ``imag`` — whose inputs are complex. These must be rewritten alongside +the rest of the subgraph even though their outputs are real. + +``subgraph_from_anchor`` +"""""""""""""""""""""""" + +For ``view_as_real``-bounded subgraphs, detection starts at a ``view_as_real`` +*anchor* node and performs a backward BFS: + +.. code-block:: text + + view_as_real ← mul (complex) ← reshape ← placeholder (complex) + ↑ anchor ↑ subgraph ↑ subgraph ↑ input + +At each step, if an upstream node satisfies ``node_include_in_subgraph`` it is +added to the subgraph; otherwise it becomes an *input node* (the boundary). The +result is a ``ComplexSubGraphInfo`` containing anchor nodes, subgraph nodes, and +input nodes. + +After collection the subgraph is **sorted in topological order** (by position in +the graph's node list). This is critical: without it a ``mul`` node could be +processed before its ``sin`` or ``cos`` operands, causing the rewriter to see the +original complex node instead of the already-rewritten real node. + +``find_complex_op_subgraphs`` and subgraph merging +""""""""""""""""""""""""""""""""""""""""""""""""""" + +When a model has multiple ``view_as_real`` anchors that share upstream nodes +(e.g. ``xq_out`` and ``xk_out`` in a RoPE layer both descend from the same +``freqs_cis`` placeholder), their subgraphs would otherwise be detected +separately. ``find_complex_op_subgraphs`` merges overlapping subgraphs by +set intersection so each node is rewritten exactly once. + +``find_all_complex_subgraphs`` — unbounded complex ops +""""""""""""""""""""""""""""""""""""""""""""""""""""""" + +Some models produce a complex tensor as a graph *output* without passing it +through ``view_as_real``. ``find_all_complex_subgraphs`` is a forward scan that +collects every ``call_function`` node with a complex output, regardless of +anchoring. The resulting subgraph is processed the same way as an +anchor-bounded one. + +ComplexGraphRewriter — Dispatch-Based Rewriting +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``ComplexGraphRewriter`` is decorated with ``@_register_unpackers``, which at +class-definition time scans every method for the ``@_complex_unpacker(op, ...)`` +decorator and builds a ``cls._DISPATCH`` dictionary mapping aten ops to rewrite +methods. + +.. code-block:: python + + @_complex_unpacker(torch.ops.aten.mul.Tensor) + def _rewrite_mul(self, node: Node, b: SubgraphBuilder, ...): + ... + +The entry point ``rewrite_subgraph_nodes`` iterates over the (topologically +ordered) subgraph nodes and for each node: + +1. Looks up ``node.target`` in ``_DISPATCH``. +2. If found, calls the corresponding rewrite method. +3. If not found but the op is in ``_ELEMENTWISE_SAFE``, skips it (the op applies + independently to every scalar, so the ``(..., 2)`` real layout is already + correct). +4. Otherwise logs a warning and leaves the node unchanged. + +``_ELEMENTWISE_SAFE`` +""""""""""""""""""""" + +The ``_ELEMENTWISE_SAFE`` set contains ops that apply to every element of the +tensor independently — ``add.Tensor``, ``sub.Tensor``, ``neg``, ``mul.Scalar``, +``clone``, ``where``, etc. On the ``(..., 2)`` real layout these are already +correct: adding two complex tensors element-wise is the same as adding their +real and imaginary parts independently. + +Notably **excluded** from this set: + +* ``permute.default`` — must append the trailing real/imag dim index. +* ``add.Scalar`` / ``sub.Scalar`` — a scalar added to a complex number only + shifts the real part; on the ``(..., 2)`` layout both parts would be shifted. +* ``reshape`` / ``view`` — shape arguments need updating for the extra ``2`` dim. + +Complex Multiply Decomposition +""""""""""""""""""""""""""""""" + +The most important rewrite is ``mul.Tensor`` between two complex operands. +The rewriter calls ``complex_mul_replacement``: + +.. code-block:: python + + # inputs a, b have shape (..., 2) — last dim is [real, imag] + re_a = select(a, -1, 0); im_a = select(a, -1, 1) + re_b = select(b, -1, 0); im_b = select(b, -1, 1) + real_out = re_a * re_b - im_a * im_b # ac - bd + imag_out = re_a * im_b + im_a * re_b # ad + bc + result = stack([real_out, imag_out], dim=-1) + +Each step is inserted via a ``SubgraphBuilder`` anchored at the ``mul`` node, +so all six new nodes appear immediately after it in topological order. The +original ``mul`` node is then replaced and erased. + +See :ref:`subgraph_builder` for more on how ``SubgraphBuilder`` manages +cursor-based insertion. + +The ``originally_complex`` Invariant +""""""""""""""""""""""""""""""""""""" + +Input replacement (Stage 2) converts complex ``placeholder`` nodes to +``float32``. After that, ``is_complex_dtype(node)`` returns ``False`` for those +nodes even though they logically represent complex quantities. + +To avoid missed rewrites, the rewriter records the set of nodes that were complex +*before any rewrites* in ``originally_complex``. The ``mul.Tensor`` dispatch +handler only triggers the full complex-multiply decomposition when the ``mul`` +node appears in ``originally_complex``; real multiplies that happen to follow a +complex input (e.g. an ``abs`` followed by a real-valued scale) are left alone. + +FakeTensorMode Reuse for Dynamic Shapes +""""""""""""""""""""""""""""""""""""""""" + +When inserting a new ``placeholder`` for a complex input, the pass must populate +``meta["val"]`` with a ``FakeTensor`` of the new real shape. Using a fresh +``FakeTensorMode()`` would create a *new* ``ShapeEnv``, which is incompatible +with the one that ``torch.export`` used to encode dynamic shape constraints +(SymInt ranges). + +The fix is to extract the ``FakeTensorMode`` from the *original* placeholder's +``meta["val"].fake_mode`` and reuse it. The new fake tensor is then constructed +by appending a concrete ``2`` to the symbolic shape list: + +.. code-block:: python + + orig_fake = input_node.meta["val"] + sym_shape = list(orig_fake.shape) + [2] + with orig_fake.fake_mode: + fake_tensor = torch.empty(sym_shape, dtype=new_dtype, device=device) + +This preserves all SymInt identity across the graph and keeps +dynamic-shape exports working correctly. + +Entry Point: ``complex_graph_detection`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The public entry point called by the lowering pipeline is +``complex_graph_detection(gm, settings)``. It: + +1. Instantiates ``ComplexOpDetector`` and ``ComplexGraphRewriter``. +2. Calls ``find_complex_op_subgraphs`` anchored on ``view_as_real`` to find + bounded complex subgraphs. +3. Calls ``find_all_complex_subgraphs`` for any remaining complex nodes that + are not ``view_as_real``-bounded. +4. For each subgraph: + + a. Calls ``replace_input_node`` on every boundary input node (Stage 2). + b. Calls ``rewrite_subgraph_nodes`` on the ordered subgraph (Stage 3). + c. Calls ``clean_up_graph_after_modifications`` to remove dead nodes. + +5. Returns the modified ``GraphModule``. + +Adding New Op Rewrites +^^^^^^^^^^^^^^^^^^^^^^^ + +To teach the rewriter about a new complex op, add a method to +``ComplexGraphRewriter`` tagged with ``@_complex_unpacker``: + +.. code-block:: python + + @_complex_unpacker(torch.ops.aten.my_new_op.default) + def _rewrite_my_new_op( + self, + node: Node, + originally_complex: set, + ) -> None: + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + result = b(my_real_impl, re, im) + node.replace_all_uses_with(result) + self.gm.graph.erase_node(node) + +``@_register_unpackers`` (applied to the class) picks up the new entry +automatically at import time — no other registration is required. + +If the new op is elementwise-safe on the ``(..., 2)`` layout (i.e. it acts +independently on every scalar), add it to ``_ELEMENTWISE_SAFE`` instead. + Related ------- * :ref:`lowering` — the complex rewrite is a lowering pass. +* :ref:`subgraph_builder` — the ``SubgraphBuilder`` helper used in every rewrite method. * :ref:`lowering_passes_catalog` — pass ordering and management. diff --git a/docsrc/tutorials/advanced_usage.rst b/docsrc/tutorials/advanced_usage.rst index 11090dc5a0..28480a4a44 100644 --- a/docsrc/tutorials/advanced_usage.rst +++ b/docsrc/tutorials/advanced_usage.rst @@ -2,7 +2,7 @@ Advanced Usage ============== Step-by-step tutorials covering engine caching, quantization, custom kernels, -dynamic shapes, weight streaming, debugging, and more. +dynamic shapes, weight streaming, debugging, complex numerics, and more. .. toctree:: :maxdepth: 2 @@ -14,5 +14,6 @@ dynamic shapes, weight streaming, debugging, and more. weight_refit/index runtime_opt/index deployment/index + complex_numerics/index Example: Distributed Inference <_rendered_examples/distributed_inference/index> ../indices/supported_ops diff --git a/docsrc/tutorials/deployment/complex_tensors.rst b/docsrc/tutorials/complex_numerics/complex_tensors.rst similarity index 96% rename from docsrc/tutorials/deployment/complex_tensors.rst rename to docsrc/tutorials/complex_numerics/complex_tensors.rst index 57716f181d..c7507685a9 100644 --- a/docsrc/tutorials/deployment/complex_tensors.rst +++ b/docsrc/tutorials/complex_numerics/complex_tensors.rst @@ -11,6 +11,12 @@ compilation. This page explains what the rewriter does, which patterns are supported, and what limitations to be aware of when compiling models with complex inputs. +.. seealso:: + + :doc:`../_rendered_examples/dynamo/torch_export_3d_rope` — a runnable + end-to-end example compiling a video-transformer 3D RoPE attention block + (CogVideoX / Wan / HunyuanVideo style) with dynamic T×H×W shapes. + ---- How the Rewriter Works diff --git a/docsrc/tutorials/complex_numerics/index.rst b/docsrc/tutorials/complex_numerics/index.rst new file mode 100644 index 0000000000..0494d84dad --- /dev/null +++ b/docsrc/tutorials/complex_numerics/index.rst @@ -0,0 +1,10 @@ +Complex Numerics +=================== + +Compatiblity support for numerical datatypes like complex numerics which are not natively supported by TensorRT + +.. toctree:: + :maxdepth: 1 + + complex_tensors + Example: 3D RoPE with Complex Numerics <../_rendered_examples/dynamo/torch_export_3d_rope> diff --git a/docsrc/tutorials/deployment/index.rst b/docsrc/tutorials/deployment/index.rst index 7df88922e5..40383bfd65 100644 --- a/docsrc/tutorials/deployment/index.rst +++ b/docsrc/tutorials/deployment/index.rst @@ -12,4 +12,3 @@ complex-valued model support. cross_compile_windows Example: Cross-runtime Compilation for Windows <../_rendered_examples/dynamo/cross_runtime_compilation_for_windows> distributed_inference - complex_tensors diff --git a/docsrc/tutorials/extensibility/lowering/index.rst b/docsrc/tutorials/extensibility/lowering/index.rst index 487fe5b4ec..e44dcb66c4 100644 --- a/docsrc/tutorials/extensibility/lowering/index.rst +++ b/docsrc/tutorials/extensibility/lowering/index.rst @@ -8,3 +8,4 @@ rewrite ATen ops before TensorRT compilation. :maxdepth: 1 writing_dynamo_aten_lowering_passes + subgraph_builder diff --git a/docsrc/tutorials/extensibility/lowering/subgraph_builder.rst b/docsrc/tutorials/extensibility/lowering/subgraph_builder.rst new file mode 100644 index 0000000000..b7f6131e0d --- /dev/null +++ b/docsrc/tutorials/extensibility/lowering/subgraph_builder.rst @@ -0,0 +1,105 @@ +.. _subgraph_builder: + +SubgraphBuilder — Cursor-Based FX Node Insertion +================================================= + +Writing lowering passes that replace one node with several new nodes requires +careful management of insertion order: each new node must be inserted +*after the previous one* so that the topological ordering of the graph is +preserved. Doing this by hand with repeated ``graph.inserting_after(cursor)`` +context managers is verbose and error-prone. + +``SubgraphBuilder`` is a lightweight context-manager helper in +``torch_tensorrt.dynamo.lowering._SubgraphBuilder`` that automates this +cursor-tracking pattern. + +Basic Usage +----------- + +Construct a ``SubgraphBuilder`` with the target graph and the *anchor* node — +the node immediately before where you want to start inserting. Then use it +as a callable inside a ``with`` block to add nodes one at a time: + +.. code-block:: python + + from torch_tensorrt.dynamo.lowering._SubgraphBuilder import SubgraphBuilder + import torch.ops.aten as aten + + # Inside a lowering pass, given a node `mul_node` to replace: + with SubgraphBuilder(gm.graph, mul_node) as b: + # Each call inserts a node after the current cursor and advances it. + re_a = b(aten.select.int, a, -1, 0) # a[..., 0] (real part of a) + im_a = b(aten.select.int, a, -1, 1) # a[..., 1] (imag part of a) + re_b = b(aten.select.int, b_node, -1, 0) + im_b = b(aten.select.int, b_node, -1, 1) + real = b(aten.sub.Tensor, b(aten.mul.Tensor, re_a, re_b), + b(aten.mul.Tensor, im_a, im_b)) # ac - bd + imag = b(aten.add.Tensor, b(aten.mul.Tensor, re_a, im_b), + b(aten.mul.Tensor, im_a, re_b)) # ad + bc + result = b(aten.stack, [real, imag], -1) + + mul_node.replace_all_uses_with(result) + gm.graph.erase_node(mul_node) + +On ``__exit__``, the builder automatically calls ``graph.lint()`` to validate +the modified graph. If your code raises an exception inside the block, the +lint is skipped so you see the original error rather than a secondary graph +validation failure. + +How It Works +------------ + +The builder maintains a *cursor* — initially the anchor node passed to +``__init__``. Every time you call it: + +1. A new ``call_function`` node is inserted via ``graph.inserting_after(cursor)``. +2. The cursor advances to the newly inserted node. +3. The new node is appended to an internal ``_inserted`` list for debug logging. + +This ensures that successive calls produce a correctly ordered chain: + +.. code-block:: text + + anchor → node_0 → node_1 → node_2 → ... + +without any manual bookkeeping. + +Debug Logging +------------- + +When the ``torch_tensorrt`` logger is set to ``DEBUG``, the builder emits a +compact summary of all inserted nodes after a successful block, for example:: + + rewrite %mul_17[(4, 32, 2),torch.float32] -> + %select_72[(4, 32),torch.float32] = select_int(%inp_0, -1, 0) + %select_73[(4, 32),torch.float32] = select_int(%inp_0, -1, 1) + %mul_18[(4, 32),torch.float32] = mul_Tensor(%select_72, %select_73) + ... + +This makes it easy to trace exactly which nodes were produced by a particular +rewrite rule. + +API Reference +------------- + +.. autoclass:: torch_tensorrt.dynamo.lowering._SubgraphBuilder.SubgraphBuilder + :members: + :undoc-members: + +When to Use SubgraphBuilder +--------------------------- + +Use ``SubgraphBuilder`` whenever a lowering pass needs to **expand one node into +a sequence of several nodes** in a single linear chain. Typical use cases: + +* Replacing a complex-arithmetic op with real-arithmetic equivalents + (e.g. the ``complex_mul_replacement`` in :ref:`complex_number_support_design`). +* Decomposing a high-level op (e.g. ``layer_norm``) into its ATen primitives + when a custom replacement strategy is needed beyond the standard decomposition + table. +* Inserting diagnostic nodes (shape probes, debug prints) around a target op. + +If you only need to insert a *single* node, a plain +``graph.inserting_after(node)`` is simpler. If you need to insert into multiple +disconnected locations in the same pass, create a separate ``SubgraphBuilder`` +for each anchor. diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst index fade7a3ee5..219d825af3 100644 --- a/examples/dynamo/README.rst +++ b/examples/dynamo/README.rst @@ -25,4 +25,5 @@ Model Zoo * :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`) * :ref:`_torch_export_sam2`: Compiling SAM2 model using AOT workflow (`ir=dynamo`) * :ref:`_torch_export_flux_dev`: Compiling FLUX.1-dev model using AOT workflow (`ir=dynamo`) -* :ref:`debugger_example`: Debugging Torch-TensorRT Compilation \ No newline at end of file +* :ref:`debugger_example`: Debugging Torch-TensorRT Compilation +* :ref:`torch_export_3d_rope`: Compiling a 3D RoPE video-transformer block with complex numerics support \ No newline at end of file From f1b04cdbc3cfd028f59e3b0eebe1aa97dcbc187d Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Fri, 6 Mar 2026 19:37:26 +0000 Subject: [PATCH 4/8] Instead of keying on shapes we add metadata prior to subgraph replacement that marks nodes that are complex --- .../lowering/passes/complex_graph_rewrite.py | 600 +++++++- .../dynamo/lowering/test_complex_rewrite.py | 1229 +++++++++++++++++ 2 files changed, 1799 insertions(+), 30 deletions(-) create mode 100644 tests/py/dynamo/lowering/test_complex_rewrite.py diff --git a/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py b/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py index ea36a9deb5..a72d326be5 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py @@ -1,6 +1,6 @@ import logging import math -from typing import Callable, List, Optional, Set, Tuple +from typing import Callable, List, Optional, Tuple import torch from torch._subclasses.fake_tensor import FakeTensorMode @@ -37,11 +37,18 @@ torch.ops.aten.clone.default, torch.ops.aten.detach.default, torch.ops.aten.alias.default, - torch.ops.aten.expand.default, - torch.ops.aten.t.default, - # Construction — producing zero/one tensors of the same shape is layout-neutral + # NOTE: expand.default is NOT here — it takes a shape arg that must + # include the trailing real/imag dim. It has an explicit handler below. + # NOTE: t.default is NOT here — it requires an explicit handler since t() + # raises on tensors with more than 2 dimensions (which the [..., 2] real + # layout always is). + # squeeze.default (no dim arg) squeezes all size-1 dims; the trailing + # real/imag dim is always size 2 so it is never accidentally squeezed. + torch.ops.aten.squeeze.default, + # Construction — zeros_like is layout-neutral (zeros everywhere = 0+0i). + # ones_like is NOT here: ones([a, b]) in real layout = [1, 1] per element + # = 1+1i, but we want 1+0i. It has an explicit handler below. torch.ops.aten.zeros_like.default, - torch.ops.aten.ones_like.default, # Conditional selection — correct on the real layout when mask broadcasts torch.ops.aten.where.self, # Rounding — applies to each float independently; complex rounding is @@ -302,6 +309,7 @@ def replace_input_node( input_node.target + "_unpacked_complex" ) new_node.meta["val"] = fake_tensor + new_node.meta["is_complex_layout"] = True logger.debug( " unpack placeholder %s%s -> %s%s", input_node.name, @@ -323,6 +331,21 @@ def replace_input_node( self.gm.register_buffer(new_attr_name, stacked_tensor) with self.gm.graph.inserting_after(input_node): new_node = self.gm.graph.get_attr(new_attr_name) + # Set fake-tensor metadata on the new node so that _is_complex_layout_node + # can identify it as a complex-layout [..., 2] tensor later when + # processing ops that use this buffer. + if fake_mode is not None: + try: + with unset_fake_temporarily(): + real_tensor = torch.empty( + stacked_tensor.shape, + dtype=stacked_tensor.dtype, + device=stacked_tensor.device, + ) + new_node.meta["val"] = fake_mode.from_tensor(real_tensor) + except Exception: + pass # best-effort + new_node.meta["is_complex_layout"] = True logger.debug( " unpack get_attr %s%s -> %s%s", input_node.target, @@ -356,7 +379,9 @@ def _inline_cat_re_im(b: SubgraphBuilder, out_re: Node, out_im: Node) -> Node: """Rebuild a [..., 2] complex-layout tensor from re and im nodes.""" re_u = b(torch.ops.aten.unsqueeze.default, out_re, -1) im_u = b(torch.ops.aten.unsqueeze.default, out_im, -1) - return b(torch.ops.aten.cat.default, [re_u, im_u], -1) + out = b(torch.ops.aten.cat.default, [re_u, im_u], -1) + out.meta["is_complex_layout"] = True + return out @staticmethod def _inline_complex_log( @@ -444,13 +469,15 @@ def _inline_complex_sqrt( def _rewrite_view_as_complex(self, node: Node) -> bool: node.replace_all_uses_with(node.args[0]) self.gm.graph.erase_node(node) - return False # bypass only, no structural change that needs propagation + # Return True so the caller triggers propagate_metadata + gm.recompile(). + # Without recompile the compiled forward still calls the erased node. + return True @_complex_unpacker(torch.ops.aten.view_as_real.default) def _rewrite_view_as_real(self, node: Node) -> bool: node.replace_all_uses_with(node.args[0]) self.gm.graph.erase_node(node) - return False + return True # triggers recompile, same reason as above @_complex_unpacker(torch.ops.aten.permute.default) def _rewrite_permute(self, node: Node) -> bool: @@ -464,6 +491,7 @@ def _rewrite_permute(self, node: Node) -> bool: with SubgraphBuilder(self.gm.graph, node) as b: out = b(torch.ops.aten.permute.default, inp, new_dims) node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True self.gm.graph.erase_node(node) return True @@ -505,24 +533,15 @@ def _rewrite_mul_div_tensor(self, node: Node) -> bool: # Both args are Nodes from here on. if node.target == torch.ops.aten.div.Tensor: - detector = ComplexOpDetector() - - def _is_complex_layout(n: Node) -> bool: - if detector.is_complex_dtype(n): - return True - val = n.meta.get("val", None) - if val is not None and hasattr(val, "shape"): - return len(val.shape) >= 1 and val.shape[-1] == 2 - return False - - arg0_layout = _is_complex_layout(node.args[0]) - arg1_layout = _is_complex_layout(node.args[1]) + arg0_layout = self._is_complex_layout_node(node.args[0]) + arg1_layout = self._is_complex_layout_node(node.args[1]) if arg0_layout and not arg1_layout: # complex_layout / real — unsqueeze denom for correct broadcast with SubgraphBuilder(self.gm.graph, node) as b: denom_unsq = b(torch.ops.aten.unsqueeze.default, node.args[1], -1) out = b(torch.ops.aten.div.Tensor, node.args[0], denom_unsq) + out.meta["is_complex_layout"] = True node.replace_all_uses_with(out) self.gm.graph.erase_node(node) return True @@ -555,15 +574,41 @@ def match_complex_div( ) return True - # mul.Tensor, both nodes — complex × complex + # mul.Tensor, both nodes # Use SubgraphBuilder directly rather than replace_pattern_with_filters so # that self-multiplication (mul(x, x)) is handled correctly. # replace_pattern_with_filters requires distinct placeholder nodes for x and y, # so it silently produces no matches when both args are the same node. - if node in self._originally_complex: + if node.meta.get("is_complex_layout", False): x, y = node.args[0], node.args[1] x_is_get_attr = x.op == "get_attr" y_is_get_attr = y.op == "get_attr" + x_is_complex = self._is_complex_layout_node(x) + y_is_complex = self._is_complex_layout_node(y) + + # complex × real (or real × complex): just scale both components + if x_is_complex and not y_is_complex and not x_is_get_attr: + with SubgraphBuilder(self.gm.graph, node) as b: + x_re = b(torch.ops.aten.select.int, x, -1, 0) + x_im = b(torch.ops.aten.select.int, x, -1, 1) + out_re = b(torch.ops.aten.mul.Tensor, x_re, y) + out_im = b(torch.ops.aten.mul.Tensor, x_im, y) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + if not x_is_complex and y_is_complex and not y_is_get_attr: + with SubgraphBuilder(self.gm.graph, node) as b: + y_re = b(torch.ops.aten.select.int, y, -1, 0) + y_im = b(torch.ops.aten.select.int, y, -1, 1) + out_re = b(torch.ops.aten.mul.Tensor, x, y_re) + out_im = b(torch.ops.aten.mul.Tensor, x, y_im) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True if not x_is_get_attr and not y_is_get_attr: # Both are ITensors — use select.int (TRT-compatible) @@ -580,6 +625,7 @@ def match_complex_div( out_im = b(torch.ops.aten.add.Tensor, ad, bc) out = self._inline_cat_re_im(b, out_re, out_im) node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True self.gm.graph.erase_node(node) return True else: @@ -623,6 +669,148 @@ def _rewrite_add_sub_tensor_scalar(self, node: Node) -> bool: new_re = b(node.target, re, scalar) out = self._inline_cat_re_im(b, new_re, im) node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.ones_like.default) + def _rewrite_ones_like(self, node: Node) -> bool: + # ones_like in [..., 2] layout produces [1, 1] = 1+1i. We want 1+0i. + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re_slice = b(torch.ops.aten.select.int, inp, -1, 0) + out_re = b(torch.ops.aten.ones_like.default, re_slice) + out_im = b(torch.ops.aten.zeros_like.default, re_slice) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.full_like.default) + def _rewrite_full_like(self, node: Node) -> bool: + # full_like(z, fill_value) in [..., 2] layout fills both re and im with + # fill_value → fill_value + fill_value*i. We want fill_value + 0i. + inp = node.args[0] + fill_value = node.args[1] + with SubgraphBuilder(self.gm.graph, node) as b: + re_slice = b(torch.ops.aten.select.int, inp, -1, 0) + out_re = b(torch.ops.aten.full_like.default, re_slice, fill_value) + out_im = b(torch.ops.aten.zeros_like.default, re_slice) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.sum.dim_IntList) + def _rewrite_sum_dim(self, node: Node) -> bool: + # sum.dim_IntList(inp, dim_list, keepdim=False, dtype=None) + # Negative dims must be shifted by -1 to skip the trailing real/imag dim. + inp = node.args[0] + dims = list(node.args[1]) + new_dims = [d - 1 if d < 0 else d for d in dims] + if new_dims == dims: + return False # all positive — pass-through is correct + keepdim = node.args[2] if len(node.args) > 2 else False + extra = list(node.args[3:]) + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.sum.dim_IntList, inp, new_dims, keepdim, *extra) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.mean.dim) + def _rewrite_mean_dim(self, node: Node) -> bool: + # mean.dim(inp, dim_list, keepdim=False, dtype=None) + inp = node.args[0] + dims = list(node.args[1]) + new_dims = [d - 1 if d < 0 else d for d in dims] + if new_dims == dims: + return False # all positive — pass-through is correct + keepdim = node.args[2] if len(node.args) > 2 else False + extra = list(node.args[3:]) + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.mean.dim, inp, new_dims, keepdim, *extra) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.prod.dim_int) + def _rewrite_prod_dim(self, node: Node) -> bool: + # prod.dim_int(inp, dim, keepdim=False, dtype=None) + inp = node.args[0] + dim = node.args[1] + if dim >= 0: + return False # positive dim — pass-through is correct + new_dim = dim - 1 + keepdim = node.args[2] if len(node.args) > 2 else False + extra = list(node.args[3:]) + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.prod.dim_int, inp, new_dim, keepdim, *extra) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.narrow.default) + def _rewrite_narrow(self, node: Node) -> bool: + # narrow(inp, dim, start, length) — shift negative dim by -1 + inp, dim, start, length = node.args + if dim >= 0: + return False + new_dim = dim - 1 + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.narrow.default, inp, new_dim, start, length) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.roll.default) + def _rewrite_roll(self, node: Node) -> bool: + # roll(inp, shifts, dims) — shift negative dims by -1 + inp = node.args[0] + shifts = node.args[1] + dims = list(node.args[2]) if len(node.args) > 2 else [] + new_dims = [d - 1 if d < 0 else d for d in dims] + if new_dims == dims: + return False + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.roll.default, inp, shifts, new_dims) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.flip.default) + def _rewrite_flip(self, node: Node) -> bool: + # flip(inp, dims) — shift negative dims by -1 + inp = node.args[0] + dims = list(node.args[1]) + new_dims = [d - 1 if d < 0 else d for d in dims] + if new_dims == dims: + return False + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.flip.default, inp, new_dims) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.repeat.default) + def _rewrite_repeat(self, node: Node) -> bool: + # repeat(inp, repeats) — repeats must include a trailing 1 for the + # real/imag dim so the layout is not disrupted. + inp = node.args[0] + repeats = list(node.args[1]) + new_repeats = repeats + [1] + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.repeat.default, inp, new_repeats) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True self.gm.graph.erase_node(node) return True @@ -636,6 +824,26 @@ def _rewrite_conj(self, node: Node) -> bool: neg_im = b(torch.ops.aten.neg.default, im) out = self._inline_cat_re_im(b, re, neg_im) node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.reciprocal.default) + def _rewrite_reciprocal(self, node: Node) -> bool: + # 1/(a+bi) = a/(a²+b²) - ib/(a²+b²) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + re2 = b(torch.ops.aten.mul.Tensor, re, re) + im2 = b(torch.ops.aten.mul.Tensor, im, im) + denom = b(torch.ops.aten.add.Tensor, re2, im2) + out_re = b(torch.ops.aten.div.Tensor, re, denom) + neg_im = b(torch.ops.aten.neg.default, im) + out_im = b(torch.ops.aten.div.Tensor, neg_im, denom) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True self.gm.graph.erase_node(node) return True @@ -1103,6 +1311,337 @@ def _rewrite_scalar_tensor(self, node: Node) -> bool: self.gm.graph.erase_node(node) return True + # ------------------------------------------------------------------ + # Shape-manipulation handlers + # + # All of these work on the same principle: in the [..., 2] real layout + # the trailing dimension stores real/imag. Dimension indices that refer + # to the *last* complex dimension (dim=-1) must be shifted by -1 to + # avoid touching or conflating with the trailing 2 dim. + # + # Rule: new_dim = dim - 1 if dim < 0 else dim + # ------------------------------------------------------------------ + + @_complex_unpacker( + torch.ops.aten.reshape.default, + torch.ops.aten.view.default, + torch.ops.aten._unsafe_view.default, + ) + def _rewrite_reshape_view(self, node: Node) -> bool: + # Append 2 to the target shape so the trailing real/imag dim is + # preserved after the reshape. E.g. complex [a,b] reshaped to [c] + # becomes float [a,b,2] reshaped to [c,2]. + inp = node.args[0] + new_shape = list(node.args[1]) + [2] + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(node.target, inp, new_shape) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.flatten.using_ints) + def _rewrite_flatten(self, node: Node) -> bool: + inp = node.args[0] + start_dim = node.args[1] if len(node.args) > 1 else 0 + end_dim = node.args[2] if len(node.args) > 2 else -1 + # Shift negative dims by -1 so end_dim=-1 (last complex dim) maps to + # the second-to-last dim in the real layout, keeping the trailing 2 intact. + new_start = start_dim - 1 if start_dim < 0 else start_dim + new_end = end_dim - 1 if end_dim < 0 else end_dim + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.flatten.using_ints, inp, new_start, new_end) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.unsqueeze.default) + def _rewrite_unsqueeze(self, node: Node) -> bool: + inp = node.args[0] + dim = node.args[1] + # Negative dims: shift by -1 so dim=-1 inserts *before* the trailing + # real/imag dim rather than *after* it. + new_dim = dim - 1 if dim < 0 else dim + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.unsqueeze.default, inp, new_dim) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.squeeze.dim, torch.ops.aten.squeeze.dims) + def _rewrite_squeeze_dim(self, node: Node) -> bool: + inp = node.args[0] + # squeeze.dim(inp, int) vs squeeze.dims(inp, List[int]) + is_multi = node.target == torch.ops.aten.squeeze.dims + raw_dim = node.args[1] + dims_list = list(raw_dim) if is_multi else [raw_dim] + # Shift negative dims so that complex dim=-1 (last complex dim) maps to + # real-layout dim=-2 (second-to-last), keeping the trailing real/imag dim. + # A squeeze on a valid complex dim can never accidentally hit the trailing + # 2 dim: for rank-n complex, valid dims are [-n, n-1]; after the shift + # they land in [-n-1, n-1], all safely before the trailing dim at index n. + new_dims = [d - 1 if d < 0 else d for d in dims_list] + with SubgraphBuilder(self.gm.graph, node) as b: + if is_multi: + out = b(torch.ops.aten.squeeze.dims, inp, new_dims) + else: + out = b(torch.ops.aten.squeeze.dim, inp, new_dims[0]) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.cat.default) + def _rewrite_cat(self, node: Node) -> bool: + tensors = node.args[0] + dim = node.args[1] if len(node.args) > 1 else 0 + # Negative dims: shift by -1 to avoid concatenating into the trailing + # real/imag dim. E.g. cat(tensors, dim=-1) on complex tensors should + # concat along the last *complex* dimension, not the trailing 2. + new_dim = dim - 1 if dim < 0 else dim + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.cat.default, list(tensors), new_dim) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.stack.default) + def _rewrite_stack(self, node: Node) -> bool: + tensors = node.args[0] + dim = node.args[1] if len(node.args) > 1 else 0 + # Negative dims: shift by -1 so a new dim inserted at position -1 lands + # before the trailing real/imag dim, not after it. + new_dim = dim - 1 if dim < 0 else dim + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.stack.default, list(tensors), new_dim) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.t.default) + def _rewrite_t(self, node: Node) -> bool: + # t() is the 2-D transpose shorthand. After unpacking, the tensor is + # 3-D ([..., 2]) so t() would raise. Replace with transpose(0, 1). + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.transpose.int, inp, 0, 1) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.transpose.int) + def _rewrite_transpose(self, node: Node) -> bool: + inp = node.args[0] + dim0, dim1 = node.args[1], node.args[2] + # Get the original complex rank from node metadata (not yet re-propagated). + node_val = node.meta.get("val", None) + if node_val is None or not hasattr(node_val, "shape"): + logger.warning( + "transpose on complex tensor '%s': no metadata, skipping rewrite. " + "This may produce incorrect results or fail TRT compilation.", + node.name, + ) + return False + n = len(node_val.shape) # original complex rank + # Normalize dims to absolute indices in [0, n-1]: same indices are valid + # in the real layout too (both are before the trailing 2). + abs0 = dim0 % n if dim0 < 0 else dim0 + abs1 = dim1 % n if dim1 < 0 else dim1 + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.transpose.int, inp, abs0, abs1) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.select.int) + def _rewrite_select(self, node: Node) -> bool: + # select.int on a complex tensor selects along a batch/sequence dim. + # In the real layout the trailing dim encodes real/imag, so negative + # dim indices must be shifted by -1 to avoid selecting from that dim. + inp = node.args[0] + dim = node.args[1] + idx = node.args[2] + if dim >= 0: + return False # non-negative dims are unchanged in real layout + new_dim = dim - 1 + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.select.int, inp, new_dim, idx) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.slice.Tensor) + def _rewrite_slice(self, node: Node) -> bool: + inp = node.args[0] + dim = node.args[1] if len(node.args) > 1 else 0 + start = node.args[2] if len(node.args) > 2 else None + end = node.args[3] if len(node.args) > 3 else None + step = node.args[4] if len(node.args) > 4 else 1 + if dim >= 0: + return False # non-negative dims are safe in real layout + new_dim = dim - 1 + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.slice.Tensor, inp, new_dim, start, end, step) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker( + torch.ops.aten.split.Tensor, + torch.ops.aten.split_with_sizes.default, + ) + def _rewrite_split(self, node: Node) -> bool: + inp = node.args[0] + size_or_sizes = node.args[1] + dim = node.args[2] if len(node.args) > 2 else 0 + new_dim = dim - 1 if dim < 0 else dim + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(node.target, inp, size_or_sizes, new_dim) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.chunk.default) + def _rewrite_chunk(self, node: Node) -> bool: + inp = node.args[0] + chunks = node.args[1] + dim = node.args[2] if len(node.args) > 2 else 0 + new_dim = dim - 1 if dim < 0 else dim + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.chunk.default, inp, chunks, new_dim) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.expand.default) + def _rewrite_expand(self, node: Node) -> bool: + # expand(input, size) — size must include the trailing real/imag dim. + # Append 2 to the size list. Negative sizes (-1 = keep dim) are left as-is; + # only the trailing 2 is appended for the real/imag encoding dim. + inp = node.args[0] + size = list(node.args[1]) + new_size = size + [2] + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.expand.default, inp, new_size) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + # ------------------------------------------------------------------ + # Matrix-multiplication handlers + # + # Complex mm: (A+iB)(C+iD) = (AC-BD) + i(AD+BC) — 4 real matmuls. + # ------------------------------------------------------------------ + + def _inline_complex_mm_op( + self, + b: "SubgraphBuilder", + matmul_op: object, + x: Node, + y: Node, + x_was_complex: bool, + y_was_complex: bool, + ) -> "Tuple[Node, Node]": + """Emit real/imag components of a complex matmul using *matmul_op*.""" + if x_was_complex and y_was_complex: + x_re = b(torch.ops.aten.select.int, x, -1, 0) + x_im = b(torch.ops.aten.select.int, x, -1, 1) + y_re = b(torch.ops.aten.select.int, y, -1, 0) + y_im = b(torch.ops.aten.select.int, y, -1, 1) + ac = b(matmul_op, x_re, y_re) + bd = b(matmul_op, x_im, y_im) + ad = b(matmul_op, x_re, y_im) + bc = b(matmul_op, x_im, y_re) + out_re = b(torch.ops.aten.sub.Tensor, ac, bd) + out_im = b(torch.ops.aten.add.Tensor, ad, bc) + elif x_was_complex: + # x is complex, y is real: (A+iB)*C = AC + iBC + x_re = b(torch.ops.aten.select.int, x, -1, 0) + x_im = b(torch.ops.aten.select.int, x, -1, 1) + out_re = b(matmul_op, x_re, y) + out_im = b(matmul_op, x_im, y) + else: + # x is real, y is complex: A*(C+iD) = AC + iAD + y_re = b(torch.ops.aten.select.int, y, -1, 0) + y_im = b(torch.ops.aten.select.int, y, -1, 1) + out_re = b(matmul_op, x, y_re) + out_im = b(matmul_op, x, y_im) + return out_re, out_im + + def _is_complex_layout_node(self, n: Node) -> bool: + """True if *n* is in real [..., 2] layout representing a complex tensor. + + All complex nodes are annotated with node.meta["is_complex_layout"] = True + during the detection phase (or by each rewrite handler as it emits new + nodes), so this is a direct metadata lookup — no shape heuristics needed. + """ + return n.meta.get("is_complex_layout", False) + + @_complex_unpacker(torch.ops.aten.mm.default) + def _rewrite_mm(self, node: Node) -> bool: + if not node.meta.get("is_complex_layout", False): + return False + x, y = node.args[0], node.args[1] + x_c = self._is_complex_layout_node(x) + y_c = self._is_complex_layout_node(y) + if not x_c and not y_c: + return False + with SubgraphBuilder(self.gm.graph, node) as b: + out_re, out_im = self._inline_complex_mm_op( + b, torch.ops.aten.mm.default, x, y, x_c, y_c + ) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.bmm.default) + def _rewrite_bmm(self, node: Node) -> bool: + if not node.meta.get("is_complex_layout", False): + return False + x, y = node.args[0], node.args[1] + x_c = self._is_complex_layout_node(x) + y_c = self._is_complex_layout_node(y) + if not x_c and not y_c: + return False + with SubgraphBuilder(self.gm.graph, node) as b: + out_re, out_im = self._inline_complex_mm_op( + b, torch.ops.aten.bmm.default, x, y, x_c, y_c + ) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.matmul.default) + def _rewrite_matmul(self, node: Node) -> bool: + if not node.meta.get("is_complex_layout", False): + return False + x, y = node.args[0], node.args[1] + x_c = self._is_complex_layout_node(x) + y_c = self._is_complex_layout_node(y) + if not x_c and not y_c: + return False + with SubgraphBuilder(self.gm.graph, node) as b: + out_re, out_im = self._inline_complex_mm_op( + b, torch.ops.aten.matmul.default, x, y, x_c, y_c + ) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + @_complex_unpacker(torch.ops.aten.where.self) def _rewrite_where(self, node: Node) -> bool: # where.self: unsqueeze mask and optionally expand true-branch for complex layout. @@ -1127,6 +1666,7 @@ def _rewrite_where(self, node: Node) -> bool: ): true_arg = b(torch.ops.aten.expand.default, true_node, target_shape) out = b(torch.ops.aten.where.self, mask_unsq, true_arg, other_node) + out.meta["is_complex_layout"] = True node.replace_all_uses_with(out) self.gm.graph.erase_node(node) return True @@ -1141,20 +1681,20 @@ def rewrite_subgraph_nodes(self, subgraphs: List[ComplexSubGraphInfo]) -> None: # active) and would lose SymInt information for torch.export graphs. detected_fake_mode = torch._export.utils._detect_fake_mode_from_gm(self.gm) - # Record the set of all nodes that have complex dtype BEFORE any rewriting. - # This is needed because after replace_input_node (which changes dtype from - # complex to float32), is_complex_dtype() would return False for those nodes — - # but we still need to know they were originally complex when we later decide - # whether a mul.Tensor operand should be treated as complex-layout. + # Annotate all nodes that have complex dtype BEFORE any rewriting. + # We stamp node.meta["is_complex_layout"] = True on every complex-dtype node + # so that later passes can reliably distinguish real [..., 2] layout tensors + # (created by this rewriter) from coincidentally-shaped real tensors. + # This is stable across rewrites: after replace_input_node changes dtype to + # float32, is_complex_dtype() would return False, but the metadata flag persists. detector = ComplexOpDetector() - self._originally_complex: Set[Node] = set() for subgraph in subgraphs: for node in subgraph.input_nodes: if detector.is_complex_dtype(node): - self._originally_complex.add(node) + node.meta["is_complex_layout"] = True for node in subgraph.subgraph_nodes: if detector.is_complex_dtype(node): - self._originally_complex.add(node) + node.meta["is_complex_layout"] = True # _DISPATCH maps op -> unbound method; bind self here once per call. dispatch = {op: method.__get__(self) for op, method in self._DISPATCH.items()} diff --git a/tests/py/dynamo/lowering/test_complex_rewrite.py b/tests/py/dynamo/lowering/test_complex_rewrite.py new file mode 100644 index 0000000000..dfd752df9b --- /dev/null +++ b/tests/py/dynamo/lowering/test_complex_rewrite.py @@ -0,0 +1,1229 @@ +"""Comprehensive numerical-equivalence tests for complex_graph_detection lowering pass. + +Each test verifies: + lowered_gm(view_as_real(z)) ≡ original_model(z) (numerically) + +The lowering pass rewrites complex-dtype ops to real arithmetic on a [..., 2] +layout (trailing dim encodes real/imag). After lowering, all inputs and outputs +are in that real layout; the test harness converts back to complex before +comparison. + +Organisation +------------ + 1. Infrastructure helpers + 2. Elementwise arithmetic (mul / div / add / sub variants) + 3. Complex-specific ops (real, imag, conj, abs, angle, polar) + 4. Transcendental functions (exp, log, pow, sin/cos/tan …) + 5. Shape manipulation (permute, reshape/view, flatten, squeeze/unsqueeze, + cat, stack, select, slice, split, chunk, expand, + transpose, t, clone, narrow, roll, flip) + 6. Matrix multiplication (mm, bmm, matmul) + 7. Elementwise-safe pass-through verification + 8. Reduction ops (sum / mean — positive dims pass, negative = xfail) + 9. Creation-op bugs (ones_like → xfail, zeros_like → pass, full_like → xfail) +10. Chain / composition tests + +xfail tests document known bugs or missing handlers. They are expected to fail. +If they start passing a handler was fixed — remove the xfail marker. +""" + +from __future__ import annotations + +from typing import Any, Tuple + +import pytest +import torch +import torch.nn as nn + +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes.complex_graph_rewrite import ( + complex_graph_detection, +) + +# --------------------------------------------------------------------------- +# 1. Infrastructure +# --------------------------------------------------------------------------- + +_RTOL = 1e-4 +_ATOL = 1e-4 + + +def _export_and_lower( + model: nn.Module, example_inputs: Tuple[Any, ...] +) -> torch.fx.GraphModule: + """Export *model* and apply the complex_graph_detection lowering pass.""" + with torch.no_grad(): + exp = torch.export.export(model.eval(), example_inputs) + gm = exp.module() + complex_graph_detection(gm, CompilationSettings()) + return gm + + +def _real_inputs(inputs: Tuple[Any, ...]) -> Tuple[Any, ...]: + """Convert complex tensors to [..., 2] real layout.""" + return tuple( + torch.view_as_real(x).contiguous() if isinstance(x, torch.Tensor) and x.is_complex() + else x + for x in inputs + ) + + +def _to_complex_if_needed(out: torch.Tensor, ref: torch.Tensor) -> torch.Tensor: + """Reinterpret *out* as complex if *ref* is complex and *out* has trailing dim 2.""" + if ref.is_complex() and not out.is_complex() and out.shape[-1] == 2: + return torch.view_as_complex(out.contiguous()) + return out + + +def _assert_close(ref: torch.Tensor, got: torch.Tensor, tag: str) -> None: + if ref.dtype == torch.bool: + assert torch.equal(got, ref), f"{tag}: bool tensor mismatch" + return + if ref.is_complex(): + assert got.is_complex(), f"{tag}: expected complex output, got {got.dtype}" + assert ref.shape == got.shape, f"{tag}: shape {got.shape} != {ref.shape}" + torch.testing.assert_close( + got.real.float(), ref.real.float(), rtol=_RTOL, atol=_ATOL + ) + torch.testing.assert_close( + got.imag.float(), ref.imag.float(), rtol=_RTOL, atol=_ATOL + ) + else: + assert not got.is_complex(), f"{tag}: expected real output, got {got.dtype}" + assert ref.shape == got.shape, f"{tag}: shape {got.shape} != {ref.shape}" + torch.testing.assert_close(got.float(), ref.float(), rtol=_RTOL, atol=_ATOL) + + +def _check_op(model: nn.Module, inputs: Tuple[Any, ...], tag: str) -> None: + """Full pipeline: run model → export+lower → compare.""" + with torch.no_grad(): + ref = model(*inputs) + + gm = _export_and_lower(model, inputs) + raw = gm(*_real_inputs(inputs)) + + if isinstance(raw, (list, tuple)): + ref_list = list(ref) if isinstance(ref, (list, tuple)) else [ref] + for i, (r, o) in enumerate(zip(ref_list, raw)): + got = _to_complex_if_needed(o, r) + _assert_close(r, got, f"{tag}[{i}]") + else: + got = _to_complex_if_needed(raw, ref) + _assert_close(ref, got, tag) + + +# Convenience: 2-D complex inputs used by most tests +def _z(rows: int = 3, cols: int = 4) -> torch.Tensor: + return torch.randn(rows, cols, dtype=torch.complex64) + + +def _z3d(b: int = 2, m: int = 3, n: int = 4) -> torch.Tensor: + return torch.randn(b, m, n, dtype=torch.complex64) + + +# =========================================================================== +# 2. Elementwise arithmetic +# =========================================================================== + + +@pytest.mark.unit +def test_mul_complex_complex(): + class M(nn.Module): + def forward(self, x, y): + return x * y + + z1, z2 = _z(), _z() + _check_op(M(), (z1, z2), "mul_cc") + + +@pytest.mark.unit +def test_mul_complex_real(): + """Complex × real tensor — only the complex part gets the mul handler.""" + + class M(nn.Module): + def forward(self, z, r): + return z * r # z complex, r real — result is complex + + z = _z() + r = torch.randn(3, 4) + _check_op(M(), (z, r), "mul_cr") + + +@pytest.mark.unit +def test_mul_scalar(): + """z * scalar — both re/im scaled equally (elementwise-safe).""" + + class M(nn.Module): + def forward(self, z): + return z * 3.0 + + _check_op(M(), (_z(),), "mul_scalar") + + +@pytest.mark.unit +def test_div_complex_complex(): + class M(nn.Module): + def forward(self, x, y): + return x / y + + _check_op(M(), (_z(), _z() + 0.1), "div_cc") + + +@pytest.mark.unit +def test_div_complex_scalar(): + class M(nn.Module): + def forward(self, z): + return z / 2.0 + + _check_op(M(), (_z(),), "div_cscalar") + + +@pytest.mark.unit +def test_div_scalar_complex(): + """scalar / complex — s/(a+bi) = (sa - sbi)/(a²+b²).""" + + class M(nn.Module): + def forward(self, z): + return 4.0 / (z + 0.1) + + _check_op(M(), (_z(),), "div_scalar_c") + + +@pytest.mark.unit +def test_add_tensor(): + """z1 + z2 — both complex; elementwise-safe (component-wise).""" + + class M(nn.Module): + def forward(self, x, y): + return x + y + + _check_op(M(), (_z(), _z()), "add_tensor") + + +@pytest.mark.unit +def test_sub_tensor(): + class M(nn.Module): + def forward(self, x, y): + return x - y + + _check_op(M(), (_z(), _z()), "sub_tensor") + + +@pytest.mark.unit +def test_add_scalar(): + """(a+bi) + s = (a+s) + bi — scalar added to real part only.""" + + class M(nn.Module): + def forward(self, z): + return z + 2.5 + + _check_op(M(), (_z(),), "add_scalar") + + +@pytest.mark.unit +def test_sub_scalar(): + class M(nn.Module): + def forward(self, z): + return z - 1.0 + + _check_op(M(), (_z(),), "sub_scalar") + + +@pytest.mark.unit +def test_neg(): + """Negation is elementwise-safe (flips sign of both re/im).""" + + class M(nn.Module): + def forward(self, z): + return -z + + _check_op(M(), (_z(),), "neg") + + +# =========================================================================== +# 3. Complex-specific ops +# =========================================================================== + + +@pytest.mark.unit +def test_real(): + """z.real → real tensor (select re component).""" + + class M(nn.Module): + def forward(self, z): + return z.real + + _check_op(M(), (_z(3, 5),), "real") # shape (3,5) so last dim≠2 + + +@pytest.mark.unit +def test_imag(): + class M(nn.Module): + def forward(self, z): + return z.imag + + _check_op(M(), (_z(3, 5),), "imag") + + +@pytest.mark.unit +def test_conj(): + """conj(a+bi) = a - bi.""" + + class M(nn.Module): + def forward(self, z): + return torch.conj(z) + + _check_op(M(), (_z(),), "conj") + + +@pytest.mark.unit +def test_abs(): + """|a+bi| = sqrt(a²+b²) — real output.""" + + class M(nn.Module): + def forward(self, z): + return torch.abs(z) + + _check_op(M(), (_z(3, 5),), "abs") + + +@pytest.mark.unit +def test_angle(): + """angle(a+bi) = atan2(b, a) — real output.""" + + class M(nn.Module): + def forward(self, z): + return torch.angle(z) + + _check_op(M(), (_z(3, 5),), "angle") + + +@pytest.mark.unit +def test_polar(): + """polar(r, theta) = r*cos(theta) + i*r*sin(theta).""" + + class M(nn.Module): + def forward(self, r, theta): + return torch.polar(r, theta) + + r = torch.rand(3, 4) + 0.1 + theta = torch.randn(3, 4) + _check_op(M(), (r, theta), "polar") + + +# =========================================================================== +# 4. Transcendental functions +# =========================================================================== + + +@pytest.mark.unit +def test_exp(): + """exp(a+bi) = e^a*(cos(b) + i*sin(b)).""" + + class M(nn.Module): + def forward(self, z): + return torch.exp(z) + + _check_op(M(), (_z(),), "exp") + + +@pytest.mark.unit +def test_log(): + class M(nn.Module): + def forward(self, z): + return torch.log(z) + + _check_op(M(), (_z() + 0.1,), "log") + + +@pytest.mark.unit +def test_log2(): + class M(nn.Module): + def forward(self, z): + return torch.log2(z) + + _check_op(M(), (_z() + 0.1,), "log2") + + +@pytest.mark.unit +def test_log10(): + class M(nn.Module): + def forward(self, z): + return torch.log10(z) + + _check_op(M(), (_z() + 0.1,), "log10") + + +@pytest.mark.unit +def test_log1p(): + class M(nn.Module): + def forward(self, z): + return torch.log1p(z) + + _check_op(M(), (_z() + 0.1,), "log1p") + + +@pytest.mark.unit +def test_expm1(): + class M(nn.Module): + def forward(self, z): + return torch.expm1(z) + + # Use small values to keep numbers from overflowing + z = torch.randn(3, 4, dtype=torch.complex64) * 0.3 + _check_op(M(), (z,), "expm1") + + +@pytest.mark.unit +def test_sqrt(): + class M(nn.Module): + def forward(self, z): + return torch.sqrt(z) + + _check_op(M(), (_z(),), "sqrt") + + +@pytest.mark.unit +def test_pow_scalar(): + """z**n via polar form.""" + + class M(nn.Module): + def forward(self, z): + return z**2.0 + + _check_op(M(), (_z() + 0.1,), "pow_scalar") + + +@pytest.mark.unit +def test_pow_tensor(): + """z1**z2 = exp(z2 * log(z1)).""" + + class M(nn.Module): + def forward(self, x, y): + return x**y + + z1 = _z() + 0.5 + z2 = torch.randn(3, 4, dtype=torch.complex64) * 0.3 + _check_op(M(), (z1, z2), "pow_tensor") + + +@pytest.mark.unit +def test_sin(): + class M(nn.Module): + def forward(self, z): + return torch.sin(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.5 + _check_op(M(), (z,), "sin") + + +@pytest.mark.unit +def test_cos(): + class M(nn.Module): + def forward(self, z): + return torch.cos(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.5 + _check_op(M(), (z,), "cos") + + +@pytest.mark.unit +def test_tan(): + class M(nn.Module): + def forward(self, z): + return torch.tan(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.3 + _check_op(M(), (z,), "tan") + + +@pytest.mark.unit +def test_sinh(): + class M(nn.Module): + def forward(self, z): + return torch.sinh(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.5 + _check_op(M(), (z,), "sinh") + + +@pytest.mark.unit +def test_cosh(): + class M(nn.Module): + def forward(self, z): + return torch.cosh(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.5 + _check_op(M(), (z,), "cosh") + + +@pytest.mark.unit +def test_tanh(): + class M(nn.Module): + def forward(self, z): + return torch.tanh(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.5 + _check_op(M(), (z,), "tanh") + + +@pytest.mark.unit +def test_asin(): + class M(nn.Module): + def forward(self, z): + return torch.asin(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.5 + _check_op(M(), (z,), "asin") + + +@pytest.mark.unit +def test_acos(): + class M(nn.Module): + def forward(self, z): + return torch.acos(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.5 + _check_op(M(), (z,), "acos") + + +@pytest.mark.unit +def test_atan(): + class M(nn.Module): + def forward(self, z): + return torch.atan(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.5 + _check_op(M(), (z,), "atan") + + +@pytest.mark.unit +def test_asinh(): + class M(nn.Module): + def forward(self, z): + return torch.asinh(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.5 + _check_op(M(), (z,), "asinh") + + +@pytest.mark.unit +def test_acosh(): + class M(nn.Module): + def forward(self, z): + return torch.acosh(z) + + # acosh needs |z| > 1 to avoid NaN + z = torch.randn(3, 4, dtype=torch.complex64) + 2.0 + _check_op(M(), (z,), "acosh") + + +@pytest.mark.unit +def test_atanh(): + class M(nn.Module): + def forward(self, z): + return torch.atanh(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.3 + _check_op(M(), (z,), "atanh") + + +@pytest.mark.unit +def test_isnan(): + """isnan/isinf: boolean output, checks re|im.""" + + class M(nn.Module): + def forward(self, z): + return torch.isnan(z) + + _check_op(M(), (_z(3, 5),), "isnan") + + +@pytest.mark.unit +def test_isinf(): + class M(nn.Module): + def forward(self, z): + return torch.isinf(z) + + _check_op(M(), (_z(3, 5),), "isinf") + + +# =========================================================================== +# 5. Shape manipulation +# =========================================================================== + + +@pytest.mark.unit +def test_view_as_real_complex_bypass(): + """view_as_real → view_as_complex is a round-trip no-op after lowering.""" + + class M(nn.Module): + def forward(self, z): + r = torch.view_as_real(z) + return torch.view_as_complex(r) + + _check_op(M(), (_z(),), "var_vac_bypass") + + +@pytest.mark.unit +def test_permute(): + class M(nn.Module): + def forward(self, z): + return z.permute(1, 0) + + _check_op(M(), (_z(),), "permute_2d") + + +@pytest.mark.unit +def test_permute_3d(): + class M(nn.Module): + def forward(self, z): + return z.permute(0, 2, 1) + + _check_op(M(), (_z3d(),), "permute_3d") + + +@pytest.mark.unit +def test_reshape(): + class M(nn.Module): + def forward(self, z): + return z.reshape(12) + + _check_op(M(), (_z(),), "reshape") + + +@pytest.mark.unit +def test_reshape_batch(): + class M(nn.Module): + def forward(self, z): + return z.reshape(2, 6) + + _check_op(M(), (torch.randn(3, 4, dtype=torch.complex64),), "reshape_batch") + + +@pytest.mark.unit +def test_view(): + class M(nn.Module): + def forward(self, z): + return z.view(12) + + _check_op(M(), (torch.randn(3, 4, dtype=torch.complex64).contiguous(),), "view") + + +@pytest.mark.unit +def test_flatten_all(): + class M(nn.Module): + def forward(self, z): + return z.flatten() + + _check_op(M(), (_z3d(),), "flatten_all") + + +@pytest.mark.unit +def test_flatten_partial(): + class M(nn.Module): + def forward(self, z): + return z.flatten(1, -1) + + _check_op(M(), (_z3d(),), "flatten_partial") + + +@pytest.mark.unit +def test_flatten_start_neg(): + class M(nn.Module): + def forward(self, z): + return z.flatten(-2, -1) + + _check_op(M(), (_z3d(),), "flatten_neg_dims") + + +@pytest.mark.unit +def test_unsqueeze_pos(): + class M(nn.Module): + def forward(self, z): + return z.unsqueeze(0) + + _check_op(M(), (_z(),), "unsqueeze_pos") + + +@pytest.mark.unit +def test_unsqueeze_neg(): + class M(nn.Module): + def forward(self, z): + return z.unsqueeze(-1) + + _check_op(M(), (_z(),), "unsqueeze_neg") + + +@pytest.mark.unit +def test_unsqueeze_mid_neg(): + class M(nn.Module): + def forward(self, z): + return z.unsqueeze(-2) + + _check_op(M(), (_z3d(),), "unsqueeze_mid_neg") + + +@pytest.mark.unit +def test_squeeze_pos(): + class M(nn.Module): + def forward(self, z): + return z.squeeze(0) + + _check_op(M(), (torch.randn(1, 4, dtype=torch.complex64),), "squeeze_pos") + + +@pytest.mark.unit +def test_squeeze_neg(): + class M(nn.Module): + def forward(self, z): + return z.squeeze(-2) + + _check_op(M(), (torch.randn(3, 1, 4, dtype=torch.complex64),), "squeeze_neg") + + +@pytest.mark.unit +def test_squeeze_last_dim(): + """squeeze(dim=-1) removes the last *complex* dim (not real/imag encoding).""" + + class M(nn.Module): + def forward(self, z): + return z.squeeze(-1) + + _check_op(M(), (torch.randn(3, 1, dtype=torch.complex64),), "squeeze_last") + + +@pytest.mark.unit +def test_cat_dim0(): + class M(nn.Module): + def forward(self, x, y): + return torch.cat([x, y], dim=0) + + _check_op(M(), (_z(2, 4), _z(3, 4)), "cat_dim0") + + +@pytest.mark.unit +def test_cat_dim1(): + class M(nn.Module): + def forward(self, x, y): + return torch.cat([x, y], dim=1) + + _check_op(M(), (_z(3, 2), _z(3, 3)), "cat_dim1") + + +@pytest.mark.unit +def test_cat_neg_dim(): + """cat(tensors, dim=-1) on complex — must concat the last *complex* dim.""" + + class M(nn.Module): + def forward(self, x, y): + return torch.cat([x, y], dim=-1) + + _check_op(M(), (_z(3, 2), _z(3, 3)), "cat_neg_dim") + + +@pytest.mark.unit +def test_stack_dim0(): + class M(nn.Module): + def forward(self, x, y): + return torch.stack([x, y], dim=0) + + _check_op(M(), (_z(), _z()), "stack_dim0") + + +@pytest.mark.unit +def test_stack_neg_dim(): + """stack(tensors, dim=-1) — new dim must land before real/imag encoding.""" + + class M(nn.Module): + def forward(self, x, y): + return torch.stack([x, y], dim=-1) + + _check_op(M(), (_z(), _z()), "stack_neg_dim") + + +@pytest.mark.unit +def test_select_pos(): + class M(nn.Module): + def forward(self, z): + return z[1] + + _check_op(M(), (_z(),), "select_pos") + + +@pytest.mark.unit +def test_select_neg_dim(): + """select along the last complex dim (dim=-1).""" + + class M(nn.Module): + def forward(self, z): + return z.select(-1, 2) + + _check_op(M(), (_z(),), "select_neg_dim") + + +@pytest.mark.unit +def test_slice_pos(): + class M(nn.Module): + def forward(self, z): + return z[1:] + + _check_op(M(), (_z(),), "slice_pos") + + +@pytest.mark.unit +def test_slice_neg_dim(): + class M(nn.Module): + def forward(self, z): + return z[..., 1:3] + + _check_op(M(), (torch.randn(3, 6, dtype=torch.complex64),), "slice_neg_dim") + + +@pytest.mark.unit +def test_split(): + class M(nn.Module): + def forward(self, z): + a, b = z.split(2, dim=0) + return a + b + + _check_op(M(), (torch.randn(4, 3, dtype=torch.complex64),), "split_pos") + + +@pytest.mark.unit +def test_split_neg_dim(): + class M(nn.Module): + def forward(self, z): + a, b = z.split(2, dim=-1) + return a + b + + _check_op(M(), (torch.randn(3, 4, dtype=torch.complex64),), "split_neg") + + +@pytest.mark.unit +def test_chunk(): + class M(nn.Module): + def forward(self, z): + a, b = z.chunk(2, dim=0) + return a * b + + _check_op(M(), (torch.randn(4, 3, dtype=torch.complex64),), "chunk_pos") + + +@pytest.mark.unit +def test_transpose_2d(): + class M(nn.Module): + def forward(self, z): + return z.transpose(0, 1) + + _check_op(M(), (_z(),), "transpose_2d") + + +@pytest.mark.unit +def test_transpose_neg(): + class M(nn.Module): + def forward(self, z): + return z.transpose(-2, -1) + + _check_op(M(), (_z3d(),), "transpose_neg") + + +@pytest.mark.unit +def test_t_default(): + """t.default (2D transpose) is elementwise-safe.""" + + class M(nn.Module): + def forward(self, z): + return z.t() + + _check_op(M(), (_z(),), "t_default") + + +@pytest.mark.unit +def test_expand(): + class M(nn.Module): + def forward(self, z): + return z.expand(3, 4) + + _check_op(M(), (torch.randn(1, 4, dtype=torch.complex64),), "expand") + + +@pytest.mark.unit +def test_narrow_pos(): + """narrow along a non-negative dim — pass-through is correct.""" + + class M(nn.Module): + def forward(self, z): + return z.narrow(0, 1, 2) + + _check_op(M(), (_z(),), "narrow_pos") + + +@pytest.mark.unit +def test_narrow_neg_dim(): + class M(nn.Module): + def forward(self, z): + return z.narrow(-1, 1, 2) + + _check_op(M(), (torch.randn(3, 5, dtype=torch.complex64),), "narrow_neg_dim") + + +@pytest.mark.unit +def test_roll_pos(): + """roll along a positive dim — pass-through is correct.""" + + class M(nn.Module): + def forward(self, z): + return z.roll(1, 0) + + _check_op(M(), (_z(),), "roll_pos") + + +@pytest.mark.unit +def test_roll_neg_dim(): + class M(nn.Module): + def forward(self, z): + return z.roll(1, -1) + + _check_op(M(), (_z(),), "roll_neg_dim") + + +@pytest.mark.unit +def test_flip_pos(): + """flip along a positive dim — pass-through is correct.""" + + class M(nn.Module): + def forward(self, z): + return z.flip([0]) + + _check_op(M(), (_z(),), "flip_pos") + + +@pytest.mark.unit +def test_flip_neg_dim(): + class M(nn.Module): + def forward(self, z): + return z.flip([-1]) + + _check_op(M(), (_z(),), "flip_neg_dim") + + +@pytest.mark.unit +def test_repeat(): + class M(nn.Module): + def forward(self, z): + return z.repeat(2, 1) + + _check_op(M(), (_z(),), "repeat") + + +# =========================================================================== +# 6. Matrix multiplication +# =========================================================================== + + +@pytest.mark.unit +def test_mm(): + class M(nn.Module): + def forward(self, x, y): + return torch.mm(x, y) + + x = torch.randn(3, 4, dtype=torch.complex64) + y = torch.randn(4, 5, dtype=torch.complex64) + _check_op(M(), (x, y), "mm") + + +@pytest.mark.unit +def test_bmm(): + class M(nn.Module): + def forward(self, x, y): + return torch.bmm(x, y) + + x = torch.randn(2, 3, 4, dtype=torch.complex64) + y = torch.randn(2, 4, 5, dtype=torch.complex64) + _check_op(M(), (x, y), "bmm") + + +@pytest.mark.unit +def test_matmul_2d(): + class M(nn.Module): + def forward(self, x, y): + return x @ y + + x = torch.randn(3, 4, dtype=torch.complex64) + y = torch.randn(4, 5, dtype=torch.complex64) + _check_op(M(), (x, y), "matmul_2d") + + +@pytest.mark.unit +def test_matmul_3d(): + class M(nn.Module): + def forward(self, x, y): + return x @ y + + x = torch.randn(2, 3, 4, dtype=torch.complex64) + y = torch.randn(2, 4, 5, dtype=torch.complex64) + _check_op(M(), (x, y), "matmul_3d") + + +@pytest.mark.unit +def test_mm_self_multiply(): + """mm(z, z) — self-multiplication should use the same node twice correctly.""" + + class M(nn.Module): + def forward(self, z): + return torch.mm(z, z.t()) + + z = torch.randn(4, 4, dtype=torch.complex64) + _check_op(M(), (z,), "mm_self") + + +# =========================================================================== +# 7. Elementwise-safe pass-through verification +# =========================================================================== + + +@pytest.mark.unit +def test_clone(): + class M(nn.Module): + def forward(self, z): + return z.clone() + + _check_op(M(), (_z(),), "clone") + + +@pytest.mark.unit +def test_zeros_like(): + """zeros_like(z) → 0+0i (correct — all zeros in [..., 2] layout).""" + + class M(nn.Module): + def forward(self, z): + return torch.zeros_like(z) + + _check_op(M(), (_z(),), "zeros_like") + + +@pytest.mark.unit +def test_mul_scalar_elementwise(): + """mul.Scalar is elementwise-safe: scales both re and im.""" + + class M(nn.Module): + def forward(self, z): + return torch.ops.aten.mul.Scalar(z, 2.5) + + _check_op(M(), (_z(),), "mul_scalar_aten") + + +@pytest.mark.unit +def test_div_scalar_elementwise(): + class M(nn.Module): + def forward(self, z): + return z / 4.0 + + _check_op(M(), (_z(),), "div_scalar_elementwise") + + +# =========================================================================== +# 8. Reduction ops +# — positive dims: pass-through gives correct results +# — negative dims: no handler yet → xfail +# =========================================================================== + + +@pytest.mark.unit +def test_sum_pos_dim(): + """sum(z, dim=0) — positive dim, pass-through is correct.""" + + class M(nn.Module): + def forward(self, z): + return z.sum(dim=0) + + _check_op(M(), (_z(),), "sum_pos") + + +@pytest.mark.unit +def test_sum_pos_dim_keepdim(): + class M(nn.Module): + def forward(self, z): + return z.sum(dim=1, keepdim=True) + + _check_op(M(), (_z3d(),), "sum_pos_keepdim") + + +@pytest.mark.unit +def test_sum_neg_dim(): + class M(nn.Module): + def forward(self, z): + return z.sum(dim=-1) + + _check_op(M(), (_z(),), "sum_neg") + + +@pytest.mark.unit +def test_mean_pos_dim(): + class M(nn.Module): + def forward(self, z): + return z.mean(dim=0) + + _check_op(M(), (_z(),), "mean_pos") + + +@pytest.mark.unit +def test_mean_neg_dim(): + class M(nn.Module): + def forward(self, z): + return z.mean(dim=-1) + + _check_op(M(), (_z(),), "mean_neg") + + +# =========================================================================== +# 9. Creation-op bugs (xfail = documented known failures) +# =========================================================================== + + +@pytest.mark.unit +def test_ones_like_bug(): + """ones_like(z) should give 1+0i everywhere, not 1+1i.""" + + class M(nn.Module): + def forward(self, z): + return torch.ones_like(z) + + _check_op(M(), (_z(),), "ones_like") + + +@pytest.mark.unit +def test_full_like_bug(): + """full_like(z, 3.0) should give 3+0i everywhere.""" + + class M(nn.Module): + def forward(self, z): + return torch.full_like(z, 3.0) + + _check_op(M(), (_z(),), "full_like") + + +# =========================================================================== +# 10. Chain / composition tests +# =========================================================================== + + +@pytest.mark.unit +def test_mul_then_exp(): + class M(nn.Module): + def forward(self, x, y): + return torch.exp(x * y) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.3 + _check_op(M(), (z, z.clone()), "mul_then_exp") + + +@pytest.mark.unit +def test_reshape_then_mul(): + class M(nn.Module): + def forward(self, x, y): + return x.reshape(12) * y + + x = _z() + y = torch.randn(12, dtype=torch.complex64) + _check_op(M(), (x, y), "reshape_then_mul") + + +@pytest.mark.unit +def test_mm_then_reshape(): + class M(nn.Module): + def forward(self, x, y): + return (x @ y).reshape(15) + + x = torch.randn(3, 4, dtype=torch.complex64) + y = torch.randn(4, 5, dtype=torch.complex64) + _check_op(M(), (x, y), "mm_then_reshape") + + +@pytest.mark.unit +def test_cat_then_exp(): + class M(nn.Module): + def forward(self, x, y): + return torch.exp(torch.cat([x, y], dim=0)) + + z = torch.randn(2, 4, dtype=torch.complex64) * 0.3 + _check_op(M(), (z, z.clone()), "cat_then_exp") + + +@pytest.mark.unit +def test_unsqueeze_squeeze_round_trip(): + class M(nn.Module): + def forward(self, z): + return z.unsqueeze(1).squeeze(1) + + _check_op(M(), (_z(),), "unsqueeze_squeeze_rt") + + +@pytest.mark.unit +def test_permute_mul(): + class M(nn.Module): + def forward(self, x, y): + return x.permute(1, 0) * y.permute(1, 0) + + _check_op(M(), (_z(), _z()), "permute_mul") + + +@pytest.mark.unit +def test_transpose_then_mm(): + class M(nn.Module): + def forward(self, x, y): + return x @ y.transpose(-2, -1) + + x = torch.randn(3, 4, dtype=torch.complex64) + y = torch.randn(5, 4, dtype=torch.complex64) + _check_op(M(), (x, y), "transpose_mm") + + +@pytest.mark.unit +def test_rope_style_pattern(): + """RoPE-like pattern: split → mul with freqs → cat.""" + + class M(nn.Module): + def forward(self, q, freqs): + # q: [B, T, D] complex, freqs: [T, D] complex + return q * freqs.unsqueeze(0) + + q = _z3d(2, 8, 4) + freqs = _z(8, 4) + _check_op(M(), (q, freqs), "rope_style") + + +@pytest.mark.unit +def test_multiop_chain(): + """sin(exp(z) + conj(z)) — exercises several handlers in sequence.""" + + class M(nn.Module): + def forward(self, z): + return torch.sin(torch.exp(z * 0.1) + torch.conj(z)) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.2 + _check_op(M(), (z,), "multiop_chain") + + +@pytest.mark.unit +def test_abs_then_mul(): + """abs(z) is real; multiplying by a real scalar stays real.""" + + class M(nn.Module): + def forward(self, z): + return torch.abs(z) * 2.0 + + _check_op(M(), (_z(3, 5),), "abs_then_mul") + + +@pytest.mark.unit +def test_split_then_mul_then_cat(): + """split → element-wise mul → cat.""" + + class M(nn.Module): + def forward(self, z): + a, b = z.split(2, dim=1) # [3,2] each + return torch.cat([a * b, b * a], dim=1) + + _check_op(M(), (torch.randn(3, 4, dtype=torch.complex64),), "split_mul_cat") From 3fcd86b2220123460880ed553e01455d357f8c00 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Fri, 6 Mar 2026 19:55:43 +0000 Subject: [PATCH 5/8] docs: update for new metadata approach --- .../contributors/complex_number_support.rst | 53 ++++++++++++------- 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/docsrc/contributors/complex_number_support.rst b/docsrc/contributors/complex_number_support.rst index fa5b4dd9ae..f4fbe96f70 100644 --- a/docsrc/contributors/complex_number_support.rst +++ b/docsrc/contributors/complex_number_support.rst @@ -135,10 +135,15 @@ runtime modules handle the conversion: Key Implementation Invariants ------------------------------- -* **``originally_complex`` set** — the set of nodes that were complex-dtype - *before* any rewrites. After ``replace_input_node``, complex placeholders become - ``float32`` so ``is_complex_dtype()`` returns ``False``. The ``originally_complex`` - set is used to decide which ``mul.Tensor`` nodes need the complex mul rewrite. +* **``node.meta["is_complex_layout"]``** — every node that represents a complex + quantity (either originally complex-dtype, or a real ``(..., 2)`` tensor produced + by the rewriter) is annotated with ``node.meta["is_complex_layout"] = True``. + This annotation is set during the detection phase (before any rewrites begin) and + propagated by every rewrite handler as it emits new nodes. It survives dtype + changes: after ``replace_input_node`` converts a ``placeholder`` from complex to + ``float32``, the dtype-based check ``is_complex_dtype()`` would return ``False``, + but the metadata flag remains. ``_is_complex_layout_node(n)`` is simply + ``n.meta.get("is_complex_layout", False)`` — no shape heuristics or recursion. * **FakeTensorMode reuse** — ``propagate_metadata`` must use the ``FakeTensorMode`` from existing placeholder fake tensors (not a fresh mode) to avoid mode-mismatch errors under ``torch.compile`` and to preserve SymInt for dynamic shapes. @@ -272,18 +277,28 @@ original ``mul`` node is then replaced and erased. See :ref:`subgraph_builder` for more on how ``SubgraphBuilder`` manages cursor-based insertion. -The ``originally_complex`` Invariant -""""""""""""""""""""""""""""""""""""" +The ``is_complex_layout`` Metadata Invariant +""""""""""""""""""""""""""""""""""""""""""""" Input replacement (Stage 2) converts complex ``placeholder`` nodes to ``float32``. After that, ``is_complex_dtype(node)`` returns ``False`` for those nodes even though they logically represent complex quantities. -To avoid missed rewrites, the rewriter records the set of nodes that were complex -*before any rewrites* in ``originally_complex``. The ``mul.Tensor`` dispatch -handler only triggers the full complex-multiply decomposition when the ``mul`` -node appears in ``originally_complex``; real multiplies that happen to follow a -complex input (e.g. an ``abs`` followed by a real-valued scale) are left alone. +To avoid missed rewrites, every node that represents a complex quantity is +annotated with ``node.meta["is_complex_layout"] = True`` during the detection +phase (lines in ``rewrite_subgraph_nodes`` before any rewrites begin). The +annotation is then propagated forward by every rewrite handler: + +* ``replace_input_node`` stamps it on the new placeholder and ``get_attr`` nodes. +* ``_inline_cat_re_im`` stamps it on every ``[re_u, im_u]`` concatenation node, + covering all math handlers (``exp``, ``log``, ``sin``, ``mul``, etc.) at once. +* Each shape-manipulation handler (``reshape``, ``permute``, ``unsqueeze``, + ``cat``, ``stack``, etc.) stamps it on its output node explicitly. + +``_is_complex_layout_node(n)`` is therefore a direct metadata lookup — no shape +heuristics (``val.shape[-1] == 2``), no recursive ``_SHAPE_TRANSPARENT_OPS`` +propagation. This also eliminates false-positives on real parameters that +coincidentally have a trailing dimension of size 2. FakeTensorMode Reuse for Dynamic Shapes """"""""""""""""""""""""""""""""""""""""" @@ -336,18 +351,18 @@ To teach the rewriter about a new complex op, add a method to .. code-block:: python @_complex_unpacker(torch.ops.aten.my_new_op.default) - def _rewrite_my_new_op( - self, - node: Node, - originally_complex: set, - ) -> None: + def _rewrite_my_new_op(self, node: Node) -> bool: inp = node.args[0] with SubgraphBuilder(self.gm.graph, node) as b: re = b(torch.ops.aten.select.int, inp, -1, 0) im = b(torch.ops.aten.select.int, inp, -1, 1) - result = b(my_real_impl, re, im) - node.replace_all_uses_with(result) - self.gm.graph.erase_node(node) + out = b(my_real_impl, re, im) + # If the output is still a complex-layout [..., 2] tensor, annotate it. + # (Not needed if using _inline_cat_re_im, which sets the flag automatically.) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True ``@_register_unpackers`` (applied to the class) picks up the new entry automatically at import time — no other registration is required. From e16b61227ac5bef5ea517b11f631f5c1102b7b04 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Wed, 11 Mar 2026 19:30:40 +0000 Subject: [PATCH 6/8] feat: Complex operations which are not supported will now fallback to pytorch rather than fail to build --- .github/workflows/build-test-linux-x86_64.yml | 3 +- .../workflows/build-test-linux-x86_64_rtx.yml | 3 +- .github/workflows/build-test-windows.yml | 1 + .github/workflows/build-test-windows_rtx.yml | 1 + .../lowering/passes/complex_graph_rewrite.py | 73 ++++- .../partitioning/_adjacency_partitioner.py | 11 + .../partitioning/_global_partitioner.py | 32 ++- pyproject.toml | 3 + .../py/dynamo/hlo/test_complex_graph_break.py | 250 ++++++++++++++++++ 9 files changed, 370 insertions(+), 7 deletions(-) create mode 100644 tests/py/dynamo/hlo/test_complex_graph_break.py diff --git a/.github/workflows/build-test-linux-x86_64.yml b/.github/workflows/build-test-linux-x86_64.yml index a437c284c0..883c756edd 100644 --- a/.github/workflows/build-test-linux-x86_64.yml +++ b/.github/workflows/build-test-linux-x86_64.yml @@ -107,7 +107,7 @@ jobs: set -euo pipefail pushd . cd tests/py/dynamo - python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_converter_tests_results.xml --dist=loadscope --maxfail=20 conversion/ + python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_converter_tests_results.xml --maxfail=20 conversion/ popd L0-dynamo-core-tests: @@ -141,6 +141,7 @@ jobs: python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_* python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/test_000_* python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/ + python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_hlo_tests_results.xml hlo/ popd L0-py-core-tests: diff --git a/.github/workflows/build-test-linux-x86_64_rtx.yml b/.github/workflows/build-test-linux-x86_64_rtx.yml index 5315cdd762..3b2f913f25 100644 --- a/.github/workflows/build-test-linux-x86_64_rtx.yml +++ b/.github/workflows/build-test-linux-x86_64_rtx.yml @@ -107,7 +107,7 @@ jobs: set -euo pipefail pushd . cd tests/py/dynamo - python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_converter_tests_results.xml --dist=loadscope --maxfail=20 conversion/ + python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_converter_tests_results.xml --maxfail=20 conversion/ popd L0-dynamo-core-tests: @@ -142,6 +142,7 @@ jobs: python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_* python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/ python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/ + python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_hlo_tests_results.xml hlo/ popd L0-py-core-tests: diff --git a/.github/workflows/build-test-windows.yml b/.github/workflows/build-test-windows.yml index 4106b65046..a03e2209c6 100644 --- a/.github/workflows/build-test-windows.yml +++ b/.github/workflows/build-test-windows.yml @@ -140,6 +140,7 @@ jobs: ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_* ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/test_000_* ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/ + ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_hlo_tests_results.xml hlo/ popd L0-py-core-tests: diff --git a/.github/workflows/build-test-windows_rtx.yml b/.github/workflows/build-test-windows_rtx.yml index 6fdbc1eab3..104551cd14 100644 --- a/.github/workflows/build-test-windows_rtx.yml +++ b/.github/workflows/build-test-windows_rtx.yml @@ -144,6 +144,7 @@ jobs: ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_* ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/ ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/ + ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_hlo_tests_results.xml hlo/ popd L0-py-core-tests: diff --git a/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py b/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py index a72d326be5..f6bff89fa5 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py @@ -1,5 +1,6 @@ import logging import math +import operator from typing import Callable, List, Optional, Tuple import torch @@ -57,6 +58,16 @@ torch.ops.aten.floor.default, torch.ops.aten.round.default, torch.ops.aten.trunc.default, + # Structural list indexing — extracts one element from a split/chunk output. + # The element is still in real [..., 2] complex layout; the flag is already + # set by the pre-rewrite annotation loop. No view_as_complex wrapping needed. + operator.getitem, + # Shape queries — sym_size.int reads a tensor's dimension value, which is not + # affected by the complex [..., 2] layout. Without this entry the fallback + # wrapper inserts view_as_complex before the sym_size node, causing the shape + # to be computed from a complex tensor in the PyTorch fallback and returning + # a raw SymInt backing value (garbage) to TRT for reshape dims. + torch.ops.aten.sym_size.int, } ) @@ -467,7 +478,14 @@ def _inline_complex_sqrt( @_complex_unpacker(torch.ops.aten.view_as_complex.default) def _rewrite_view_as_complex(self, node: Node) -> bool: - node.replace_all_uses_with(node.args[0]) + inp = node.args[0] + # The input to view_as_complex is a (..., 2) real-layout tensor that + # represents a complex tensor. After erasing view_as_complex, downstream + # consumers (e.g. mul.Tensor) need to know that this node is in complex + # layout so the correct rewrite branch is chosen. + if isinstance(inp, torch.fx.Node): + inp.meta["is_complex_layout"] = True + node.replace_all_uses_with(inp) self.gm.graph.erase_node(node) # Return True so the caller triggers propagate_metadata + gm.recompile(). # Without recompile the compiled forward still calls the erased node. @@ -1727,11 +1745,58 @@ def rewrite_subgraph_nodes(self, subgraphs: List[ComplexSubGraphInfo]) -> None: else: logger.warning( "Complex op '%s' has no explicit rewrite rule. " - "It will be passed through as-is on the real [..., 2] layout, " - "which may produce incorrect results or fail TRT compilation. " - "Consider adding a rewrite in complex_graph_rewrite.py.", + "Wrapping with view_as_complex/view_as_real so the op " + "receives genuine complex tensors and TRT graph-breaks " + "around it into a PyTorch fallback block.", node.target, ) + # Generic fallback: for each arg that is a real-layout + # complex node, insert view_as_complex before the node so + # the op sees genuine complex-dtype tensors (correct + # semantics); then, if the node itself originally produced + # a complex-layout output, wrap it with view_as_real and + # redirect all users back onto the real [..., 2] path. + # TRT has no complex-dtype support so it will refuse to + # compile the view_as_complex/op/view_as_real cluster, + # causing the partitioner to create a PyTorch fallback + # block around it — exactly the graph break we want. + new_args = list(node.args) + any_complexified = False + for i, arg in enumerate(node.args): + if not isinstance(arg, torch.fx.Node): + continue + if not arg.meta.get("is_complex_layout", False): + continue + # Skip when val is a list/tuple (e.g. a residual split + # output that wasn't caught by the getitem pass-through). + # Allow None (newly created node without metadata yet). + arg_val = arg.meta.get("val") + if isinstance(arg_val, (list, tuple)): + continue + with self.gm.graph.inserting_before(node): + vc = self.gm.graph.call_function( + torch.ops.aten.view_as_complex.default, + (arg,), + ) + # view_as_complex produces a genuine complex node — + # do NOT set is_complex_layout; it is not a + # real-layout stand-in. + new_args[i] = vc + any_complexified = True + if any_complexified: + node.args = tuple(new_args) + if any_complexified and node.meta.get("is_complex_layout", False): + with self.gm.graph.inserting_after(node): + vr = self.gm.graph.call_function( + torch.ops.aten.view_as_real.default, + (node,), + ) + vr.meta["is_complex_layout"] = True + node.replace_all_uses_with( + vr, + delete_user_cb=lambda user: user is not vr, + ) + modified = True if modified: # After rewriting complex ops, any view_as_real node that now receives a # real tensor must be erased. The subgraph_rewriter replaces the original diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index 72d0be42c7..04c8c50dbe 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -22,6 +22,9 @@ from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( ConverterRegistry, ) +from torch_tensorrt.dynamo.partitioning._global_partitioner import ( + TorchTensorRTOperatorSupport, +) logger = logging.getLogger(__name__) @@ -42,6 +45,14 @@ def is_node_supported( ) -> bool: node_name = ConverterRegistry.qualified_name_or_str(node.target) + if TorchTensorRTOperatorSupport._has_complex_dtype(node): + # Complex-dtype tensors are not supported by TensorRT; force PyTorch fallback + if not node.is_impure(): + self.unsupported_operators[node_name] = ( + self.unsupported_operators.get(node_name, 0) + 1 + ) + return False + if ( (node in CONVERTERS or node.op == "get_attr") and node_name not in self.torch_executed_ops diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py index 707497b227..8d02076607 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py @@ -1,5 +1,5 @@ import logging -from typing import Collection, Dict, List, Mapping, Optional, Sequence, Tuple +from typing import Collection, Dict, List, Mapping, Optional, Sequence, Tuple, Set import torch from torch.fx.graph_module import GraphModule @@ -144,11 +144,41 @@ def __init__( self.unsupported_operators: Dict[str, int] = {} self.torch_executed_ops: Collection[Target] = torch_executed_ops + @staticmethod + def _has_complex_dtype(node: torch.fx.Node) -> bool: + """Return True if the node output or any of its tensor inputs is complex-dtype. + + TensorRT has no native complex-type support. Any node that produces or + consumes a complex tensor must run in the PyTorch fallback so the graph + breaks naturally around it. + """ + _COMPLEX = {torch.complex64, torch.complex128} + + def _dtype(n: torch.fx.Node) -> Optional[torch.dtype]: + val = n.meta.get("val") + return getattr(val, "dtype", None) if val is not None else None + + if _dtype(node) in _COMPLEX: + return True + for arg in node.all_input_nodes: + if _dtype(arg) in _COMPLEX: + return True + return False + def is_node_supported( self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node ) -> bool: node_name = ConverterRegistry.qualified_name_or_str(node.target) + if self._has_complex_dtype(node): + # Complex-dtype tensors are not supported by TensorRT; force PyTorch fallback + # so the graph breaks around the complex cluster inserted by complex_graph_detection. + if not node.is_impure(): + self.unsupported_operators[node_name] = ( + self.unsupported_operators.get(node_name, 0) + 1 + ) + return False + if ( (node in CONVERTERS or node.op == "get_attr") and node_name not in self.torch_executed_ops diff --git a/pyproject.toml b/pyproject.toml index 47d18ed8fe..911e22d3e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ lint = [ dev = [ {include-group = "lint"}, + {include-group = "test"}, "pre-commit>=2.20.0", "typos", "mypy", @@ -77,6 +78,7 @@ debug = [ test = [ "pytest", "pytest-xdist", + "pytest-forked>=1.6.0", "parameterized>=0.2.0", "expecttest==0.1.6", ] @@ -114,6 +116,7 @@ include-package-data = false [tool.pytest.ini_options] testpaths = ["tests/py"] +addopts = "-n auto --dist=loadfile" norecursedirs = [ "bazel-*", ".venv", diff --git a/tests/py/dynamo/hlo/test_complex_graph_break.py b/tests/py/dynamo/hlo/test_complex_graph_break.py new file mode 100644 index 0000000000..4781c80255 --- /dev/null +++ b/tests/py/dynamo/hlo/test_complex_graph_break.py @@ -0,0 +1,250 @@ +"""Tests for complex tensor graph-break behavior in torch-tensorrt. + +These tests verify that when a model contains complex tensor operations mixed with +ops that have no handler in the complex-lowering rewriter, the compiler: + + 1. Wraps the unsupported op with ``view_as_complex`` / ``view_as_real`` so it + receives genuine complex-dtype inputs and returns a real-layout output. + 2. TRT, which has no complex-dtype support, naturally graph-breaks around the + wrapped cluster and runs it as a PyTorch fallback block. + 3. The lowerable complex ops on both sides compile to TRT via + ``complex_graph_detection``. + 4. The overall model produces numerically correct results end-to-end. + +Background +---------- +``complex_graph_detection`` rewrites complex-dtype ATen ops to equivalent +real-arithmetic ops before TRT compilation. When an op is *not* registered +with ``@_complex_unpacker`` and is not in ``_ELEMENTWISE_SAFE`` the rewriter +inserts ``view_as_complex`` before each complex-layout input and +``view_as_real`` after the output, preserving correct semantics and letting +TRT's lack of complex support create the graph break automatically. + +``cumsum`` is used as the representative unsupported op: it has well-defined +PyTorch semantics on complex tensors but has no handler in the rewriter. +""" + +import pytest +import torch +import torch.nn as nn +import torch_tensorrt as torchtrt +from torch_tensorrt.dynamo.lowering.passes.complex_graph_rewrite import ( + complex_graph_detection, +) +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + +try: + from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule + + _PYTHON_RUNTIME_AVAILABLE = True +except ImportError: # pragma: no cover + _PYTHON_RUNTIME_AVAILABLE = False + + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + + +def _make_freqs(seq: int, dim: int, theta: float = 10000.0) -> torch.Tensor: + """Complex unit-magnitude frequency tensor on CUDA, shape ``(seq, dim//2)``.""" + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(seq, dtype=torch.float32) + freqs = torch.outer(t, freqs) + return torch.polar(torch.ones_like(freqs), freqs).cuda() + + +def _cossim_real(py_out: torch.Tensor, trt_out: torch.Tensor, tag: str) -> None: + """Assert cosine similarity > COSINE_THRESHOLD on a real-valued output.""" + assert not trt_out.is_complex(), f"{tag}: expected real output, got {trt_out.dtype}" + s = cosine_similarity(py_out.contiguous(), trt_out.contiguous()) + assert s > COSINE_THRESHOLD, f"{tag}: cosine sim {s:.4f} < {COSINE_THRESHOLD}" + + +def _count_trt_modules(mod: torch.nn.Module) -> int: + """Return the number of ``PythonTorchTensorRTModule`` submodules (-1 if unavailable).""" + if not _PYTHON_RUNTIME_AVAILABLE: + return -1 + return sum( + 1 for _, m in mod.named_modules() if isinstance(m, PythonTorchTensorRTModule) + ) + + +def _export_and_lower(model: nn.Module, inputs: tuple) -> torch.fx.GraphModule: + """Export model and apply complex_graph_detection lowering pass.""" + with torch.no_grad(): + ep = torch.export.export(model.eval(), inputs) + gm = ep.module() + complex_graph_detection(gm, CompilationSettings()) + return gm + + +# =========================================================================== +# Test 1 — unsupported op gets view_as_complex/view_as_real wrapper +# =========================================================================== + + +class ComplexMulThenCumsum(nn.Module): + """Complex mul (lowerable) followed by cumsum (no rewriter handler). + + After ``complex_graph_detection`` the rewriter cannot handle ``cumsum``. + It inserts ``view_as_complex`` before cumsum's input and ``view_as_real`` + after its output so the op runs in PyTorch with correct complex semantics + while TRT compiles the surrounding real-arithmetic blocks. + """ + + def forward(self, z: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + rotated = z * freqs # complex mul — lowered to real arithmetic by rewriter + accumulated = torch.cumsum(rotated, dim=0) # no handler → PyTorch fallback + return torch.view_as_real(accumulated).flatten(-2) + + +@pytest.mark.unit +def test_unsupported_op_gets_complexify_wrap() -> None: + """The rewriter wraps cumsum with view_as_complex/view_as_real. + + Structural check (no TRT required): + - After lowering, the graph contains ``view_as_complex`` immediately + before ``cumsum`` and ``view_as_real`` immediately after. + - The ``view_as_complex`` input is the real-layout output of the + rewritten complex mul — confirming it is a float32 ``(..., 2)`` node. + - The ``view_as_real`` output feeds the downstream flatten. + - The PyTorch cumsum receives a complex-dtype tensor (correct semantics). + """ + model = ComplexMulThenCumsum().eval().cuda() + z = _make_freqs(8, 64) + freqs = _make_freqs(8, 64) + + gm = _export_and_lower(model, (z, freqs)) + + nodes_by_target: dict = {} + for n in gm.graph.nodes: + nodes_by_target.setdefault(n.target, []).append(n) + + # view_as_complex must be present (inserted by the fallback wrapper) + assert torch.ops.aten.view_as_complex.default in nodes_by_target, ( + "Expected view_as_complex to be inserted before cumsum, but it was not found" + ) + + # cumsum must still be present (it was NOT removed) + assert torch.ops.aten.cumsum.default in nodes_by_target, ( + "cumsum should remain in the graph (runs as PyTorch fallback)" + ) + + # The view_as_complex output feeds directly into cumsum + vc_node = nodes_by_target[torch.ops.aten.view_as_complex.default][0] + cumsum_node = nodes_by_target[torch.ops.aten.cumsum.default][0] + assert cumsum_node.args[0] is vc_node, ( + f"cumsum's first arg should be the view_as_complex node, got {cumsum_node.args[0]}" + ) + + # The view_as_complex input is a real-layout (is_complex_layout) node + vc_input = vc_node.args[0] + assert isinstance(vc_input, torch.fx.Node), "view_as_complex input must be a Node" + assert vc_input.meta.get("is_complex_layout", False), ( + "view_as_complex input should be a real-layout complex node (is_complex_layout=True)" + ) + + # view_as_real must follow cumsum + assert torch.ops.aten.view_as_real.default in nodes_by_target, ( + "Expected view_as_real to be inserted after cumsum, but it was not found" + ) + vr_node = nodes_by_target[torch.ops.aten.view_as_real.default][0] + assert vr_node.args[0] is cumsum_node, ( + f"view_as_real's arg should be the cumsum node, got {vr_node.args[0]}" + ) + + # After metadata propagation, cumsum receives a complex-dtype tensor + vc_val = vc_node.meta.get("val") + if vc_val is not None: + assert vc_val.dtype in (torch.complex64, torch.complex128), ( + f"view_as_complex output should be complex, got {vc_val.dtype}" + ) + + +# =========================================================================== +# Test 2 — lowerable ops TRT, unsupported op PyTorch (with complex input), +# lowerable ops TRT again; end-to-end numerical correctness +# =========================================================================== + + +class ComplexTwoTRTBlocksAroundCumsum(nn.Module): + """Two complex-rotation TRT blocks with cumsum (PyTorch) in between. + + Expected graph after ``complex_graph_detection``: + + [Block A — TRT] + z_real, freqs_real → re/im arithmetic for z * freqs → rotated_real + + [PyTorch fallback — complex inputs] + view_as_complex(rotated_real) → cumsum(complex) → view_as_real → acc_real + + [Block B — TRT] + acc_real, freqs_real → re/im arithmetic for acc * freqs → result_real + result_real → view_as_real substitute → flatten → output + """ + + def forward(self, z: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + # Block A: complex rotate — lowered to real arithmetic + rotated = z * freqs + + # Unsupported complex op — rewriter inserts view_as_complex/view_as_real; + # TRT graph-breaks here; cumsum runs in PyTorch on a complex tensor + accumulated = torch.cumsum(rotated, dim=0) + + # Block B: second complex rotate — lowered to real arithmetic + result = accumulated * freqs + return torch.view_as_real(result).flatten(-2) + + +@pytest.mark.unit +def test_complex_partial_lowering_with_graph_break() -> None: + """Lowerable complex ops compile to TRT; cumsum runs in PyTorch on complex input. + + Asserts: + 1. The compiled model is numerically correct (cosine sim > threshold). + 2. At least one ``PythonTorchTensorRTModule`` submodule exists — confirming + the lowerable complex ops were compiled to TRT, not all relegated to + PyTorch fallback. + 3. After lowering, cumsum receives a complex-dtype tensor (the + view_as_complex wrapper was inserted correctly). + """ + model = ComplexTwoTRTBlocksAroundCumsum().eval().cuda() + z = _make_freqs(8, 64) + freqs = _make_freqs(8, 64) + inputs = (z, freqs) + + # Structural check: verify cumsum gets a complex input after lowering + gm = _export_and_lower(model, inputs) + for n in gm.graph.nodes: + if n.target == torch.ops.aten.cumsum.default: + vc_val = n.args[0].meta.get("val") + if vc_val is not None: + assert vc_val.dtype in (torch.complex64, torch.complex128), ( + f"cumsum should receive a complex tensor, got {vc_val.dtype}" + ) + break + + # End-to-end: compile and verify numerical correctness + ep = torch.export.export(model, inputs) + trt_model = torchtrt.dynamo.compile( + ep, + inputs=inputs, + min_block_size=1, + pass_through_build_failures=True, + use_python_runtime=True, + ) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "complex_partial_lowering_with_graph_break") + + # Verify at least one TRT block was created for the lowerable complex ops + n_trt = _count_trt_modules(trt_model) + if n_trt >= 0: + assert n_trt >= 1, ( + f"Expected at least one TRT submodule (lowerable complex ops should " + f"compile to TRT) but found {n_trt}." + ) + + torch._dynamo.reset() From 00d038969d28b6618e43f0618896ddca9fbbcbdf Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Tue, 17 Mar 2026 20:03:45 +0000 Subject: [PATCH 7/8] chore: addressing PR AIs --- .github/workflows/build-test-linux-x86_64.yml | 2 +- .../workflows/build-test-linux-x86_64_rtx.yml | 2 +- .github/workflows/build-test-windows.yml | 2 +- .github/workflows/build-test-windows_rtx.yml | 2 +- py/torch_tensorrt/dynamo/utils.py | 6 +-- .../py/dynamo/hlo/test_complex_graph_break.py | 50 ++++++++++--------- .../dynamo/lowering/test_complex_rewrite.py | 7 ++- 7 files changed, 38 insertions(+), 33 deletions(-) diff --git a/.github/workflows/build-test-linux-x86_64.yml b/.github/workflows/build-test-linux-x86_64.yml index 883c756edd..8bff2f8805 100644 --- a/.github/workflows/build-test-linux-x86_64.yml +++ b/.github/workflows/build-test-linux-x86_64.yml @@ -141,7 +141,6 @@ jobs: python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_* python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/test_000_* python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/ - python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_hlo_tests_results.xml hlo/ popd L0-py-core-tests: @@ -237,6 +236,7 @@ jobs: cd tests/py/dynamo python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_* python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_partitioning_tests_results.xml partitioning/test_001_* + python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_hlo_tests_results.xml hlo/ popd diff --git a/.github/workflows/build-test-linux-x86_64_rtx.yml b/.github/workflows/build-test-linux-x86_64_rtx.yml index 3b2f913f25..de407803e8 100644 --- a/.github/workflows/build-test-linux-x86_64_rtx.yml +++ b/.github/workflows/build-test-linux-x86_64_rtx.yml @@ -142,7 +142,6 @@ jobs: python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_* python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/ python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/ - python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_hlo_tests_results.xml hlo/ popd L0-py-core-tests: @@ -205,6 +204,7 @@ jobs: pushd . cd tests/py/dynamo python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_* + python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_hlo_tests_results.xml hlo/ popd L1-dynamo-compile-tests: diff --git a/.github/workflows/build-test-windows.yml b/.github/workflows/build-test-windows.yml index a03e2209c6..f25d2a2a3b 100644 --- a/.github/workflows/build-test-windows.yml +++ b/.github/workflows/build-test-windows.yml @@ -140,7 +140,6 @@ jobs: ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_* ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/test_000_* ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/ - ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_hlo_tests_results.xml hlo/ popd L0-py-core-tests: @@ -227,6 +226,7 @@ jobs: cd tests/py/dynamo ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_* ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_partitioning_tests_results.xml partitioning/test_001_* + ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_hlo_tests_results.xml hlo/ popd L1-dynamo-compile-tests: diff --git a/.github/workflows/build-test-windows_rtx.yml b/.github/workflows/build-test-windows_rtx.yml index 104551cd14..d25ed8b770 100644 --- a/.github/workflows/build-test-windows_rtx.yml +++ b/.github/workflows/build-test-windows_rtx.yml @@ -144,7 +144,6 @@ jobs: ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_* ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/ ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/ - ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_hlo_tests_results.xml hlo/ popd L0-py-core-tests: @@ -201,6 +200,7 @@ jobs: pushd . cd tests/py/dynamo ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_* + ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_hlo_tests_results.xml hlo/ popd L1-dynamo-compile-tests: diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 0c7166655a..5c797a3940 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -838,9 +838,9 @@ def copy_metadata(match_and_replacements: List[Any]) -> None: """ for match_and_replacement in match_and_replacements: anchor_node = match_and_replacement.nodes_map[match_and_replacement.anchor] - assert len(match_and_replacement.replacements) == 1, ( - "Found more than 1 replacements for the anchor node." - ) + assert ( + len(match_and_replacement.replacements) == 1 + ), "Found more than 1 replacements for the anchor node." replacement_node = match_and_replacement.replacements[0] replacement_node.meta = anchor_node.meta diff --git a/tests/py/dynamo/hlo/test_complex_graph_break.py b/tests/py/dynamo/hlo/test_complex_graph_break.py index 4781c80255..e48d749cef 100644 --- a/tests/py/dynamo/hlo/test_complex_graph_break.py +++ b/tests/py/dynamo/hlo/test_complex_graph_break.py @@ -123,44 +123,45 @@ def test_unsupported_op_gets_complexify_wrap() -> None: nodes_by_target.setdefault(n.target, []).append(n) # view_as_complex must be present (inserted by the fallback wrapper) - assert torch.ops.aten.view_as_complex.default in nodes_by_target, ( - "Expected view_as_complex to be inserted before cumsum, but it was not found" - ) + assert ( + torch.ops.aten.view_as_complex.default in nodes_by_target + ), "Expected view_as_complex to be inserted before cumsum, but it was not found" # cumsum must still be present (it was NOT removed) - assert torch.ops.aten.cumsum.default in nodes_by_target, ( - "cumsum should remain in the graph (runs as PyTorch fallback)" - ) + assert ( + torch.ops.aten.cumsum.default in nodes_by_target + ), "cumsum should remain in the graph (runs as PyTorch fallback)" # The view_as_complex output feeds directly into cumsum vc_node = nodes_by_target[torch.ops.aten.view_as_complex.default][0] cumsum_node = nodes_by_target[torch.ops.aten.cumsum.default][0] - assert cumsum_node.args[0] is vc_node, ( - f"cumsum's first arg should be the view_as_complex node, got {cumsum_node.args[0]}" - ) + assert ( + cumsum_node.args[0] is vc_node + ), f"cumsum's first arg should be the view_as_complex node, got {cumsum_node.args[0]}" # The view_as_complex input is a real-layout (is_complex_layout) node vc_input = vc_node.args[0] assert isinstance(vc_input, torch.fx.Node), "view_as_complex input must be a Node" - assert vc_input.meta.get("is_complex_layout", False), ( - "view_as_complex input should be a real-layout complex node (is_complex_layout=True)" - ) + assert vc_input.meta.get( + "is_complex_layout", False + ), "view_as_complex input should be a real-layout complex node (is_complex_layout=True)" # view_as_real must follow cumsum - assert torch.ops.aten.view_as_real.default in nodes_by_target, ( - "Expected view_as_real to be inserted after cumsum, but it was not found" - ) + assert ( + torch.ops.aten.view_as_real.default in nodes_by_target + ), "Expected view_as_real to be inserted after cumsum, but it was not found" vr_node = nodes_by_target[torch.ops.aten.view_as_real.default][0] - assert vr_node.args[0] is cumsum_node, ( - f"view_as_real's arg should be the cumsum node, got {vr_node.args[0]}" - ) + assert ( + vr_node.args[0] is cumsum_node + ), f"view_as_real's arg should be the cumsum node, got {vr_node.args[0]}" # After metadata propagation, cumsum receives a complex-dtype tensor vc_val = vc_node.meta.get("val") if vc_val is not None: - assert vc_val.dtype in (torch.complex64, torch.complex128), ( - f"view_as_complex output should be complex, got {vc_val.dtype}" - ) + assert vc_val.dtype in ( + torch.complex64, + torch.complex128, + ), f"view_as_complex output should be complex, got {vc_val.dtype}" # =========================================================================== @@ -221,9 +222,10 @@ def test_complex_partial_lowering_with_graph_break() -> None: if n.target == torch.ops.aten.cumsum.default: vc_val = n.args[0].meta.get("val") if vc_val is not None: - assert vc_val.dtype in (torch.complex64, torch.complex128), ( - f"cumsum should receive a complex tensor, got {vc_val.dtype}" - ) + assert vc_val.dtype in ( + torch.complex64, + torch.complex128, + ), f"cumsum should receive a complex tensor, got {vc_val.dtype}" break # End-to-end: compile and verify numerical correctness diff --git a/tests/py/dynamo/lowering/test_complex_rewrite.py b/tests/py/dynamo/lowering/test_complex_rewrite.py index dfd752df9b..3d82d7d87b 100644 --- a/tests/py/dynamo/lowering/test_complex_rewrite.py +++ b/tests/py/dynamo/lowering/test_complex_rewrite.py @@ -62,8 +62,11 @@ def _export_and_lower( def _real_inputs(inputs: Tuple[Any, ...]) -> Tuple[Any, ...]: """Convert complex tensors to [..., 2] real layout.""" return tuple( - torch.view_as_real(x).contiguous() if isinstance(x, torch.Tensor) and x.is_complex() - else x + ( + torch.view_as_real(x).contiguous() + if isinstance(x, torch.Tensor) and x.is_complex() + else x + ) for x in inputs ) From fb4fc99f47db2989fa0850e6e3d70a553058b9a3 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Wed, 18 Mar 2026 23:00:36 +0000 Subject: [PATCH 8/8] remove the old complex subgraph detection --- .../lowering/passes/complex_graph_rewrite.py | 58 +------------------ uv.lock | 37 +++++++++++- 2 files changed, 35 insertions(+), 60 deletions(-) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py b/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py index f6bff89fa5..7c03080c8e 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py @@ -147,66 +147,10 @@ def node_include_in_subgraph(self, node: Node) -> bool: return False return self.is_complex_dtype(node) or self.has_complex_input(node) - def subgraph_from_anchor(self, anchor_node: Node) -> ComplexSubGraphInfo: - subgraph_nodes: Set[Node] = set() - input_nodes: Set[Node] = set() - stack = [anchor_node] - while stack: - n = stack.pop() - if n in subgraph_nodes: - continue - subgraph_nodes.add(n) - for inp in n.all_input_nodes: - if self.node_include_in_subgraph(inp): - stack.append(inp) - else: - input_nodes.add(inp) - # Sort subgraph_nodes in topological (graph) order so the rewriter - # processes producers before consumers. The set has no stable order, - # which caused bugs when e.g. mul(sin, sin) was processed before sin - # was rewritten (sin still had complex dtype, so the mul pattern ran - # against the original complex node and produced wrong results). - node_order = {n: i for i, n in enumerate(anchor_node.graph.nodes)} - ordered_subgraph = sorted(subgraph_nodes, key=lambda n: node_order.get(n, 0)) - return ComplexSubGraphInfo([anchor_node], ordered_subgraph, list(input_nodes)) - - def find_complex_op_subgraphs( - self, gm: GraphModule, anchor_target: str - ) -> List[ComplexSubGraphInfo]: - complex_op_subgraphs: List[ComplexSubGraphInfo] = [] - for node in gm.graph.nodes: - if node.target == anchor_target: - new_sub = self.subgraph_from_anchor(node) - # if any intersecting nodes between seen and sub.subgraph_nodes they should be merged - merged = False - for existing_sub in complex_op_subgraphs: - if set(existing_sub.subgraph_nodes) & set(new_sub.subgraph_nodes): - logger.debug(f"merging subgraphs {existing_sub} {new_sub}") - # merge the two subgraphs, preserving topological order - merged_nodes = set(existing_sub.subgraph_nodes) | set( - new_sub.subgraph_nodes - ) - node_order = {n: i for i, n in enumerate(gm.graph.nodes)} - existing_sub.subgraph_nodes = sorted( - merged_nodes, key=lambda n: node_order.get(n, 0) - ) - existing_sub.input_nodes = list( - set(existing_sub.input_nodes) | set(new_sub.input_nodes) - ) - existing_sub.anchor_nodes = list( - set(existing_sub.anchor_nodes) | set(new_sub.anchor_nodes) - ) - merged = True - break - if not merged: - complex_op_subgraphs.append(new_sub) - return complex_op_subgraphs - def find_all_complex_subgraphs(self, gm: GraphModule) -> List[ComplexSubGraphInfo]: """Forward scan: collect all complex-dtype call_function nodes as one subgraph. - Unlike find_complex_op_subgraphs (which walks backwards from a single anchor), - this scans forward over every node and collects all call_function nodes whose + Scans forward over every node and collects all call_function nodes whose output is complex — regardless of whether they are bounded by view_as_real. This ensures complex ops that feed directly into graph outputs (no view_as_real) are still rewritten to real arithmetic. diff --git a/uv.lock b/uv.lock index 82459b609a..c320b1b50b 100644 --- a/uv.lock +++ b/uv.lock @@ -32,9 +32,6 @@ required-markers = [ "python_full_version < '3.14' and platform_machine == 'AMD64' and sys_platform == 'win32'", ] -[options] -prerelease-mode = "allow" - [[package]] name = "accelerate" version = "1.12.0" @@ -2900,6 +2897,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/99/6c/64cafaceea3f99927e84b38a362ec6a8f24f33061c90bda77dfe1cd4c3c6/pulp-3.3.0-py3-none-any.whl", hash = "sha256:dd6ad2d63f196d1254eddf9dcff5cd224912c1f046120cb7c143c5b0eda63fae", size = 16387700, upload-time = "2025-09-18T08:14:53.368Z" }, ] +[[package]] +name = "py" +version = "1.11.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/98/ff/fec109ceb715d2a6b4c4a85a61af3b40c723a961e8828319fbcb15b868dc/py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719", size = 207796, upload-time = "2021-11-04T17:17:01.377Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/f0/10642828a8dfb741e5f3fbaac830550a518a775c7fff6f04a007259b0548/py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378", size = 98708, upload-time = "2021-11-04T17:17:00.152Z" }, +] + [[package]] name = "py-cpuinfo" version = "9.0.0" @@ -3157,6 +3163,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, ] +[[package]] +name = "pytest-forked" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "py", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "pytest", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8c/c9/93ad2ba2413057ee694884b88cf7467a46c50c438977720aeac26e73fdb7/pytest-forked-1.6.0.tar.gz", hash = "sha256:4dafd46a9a600f65d822b8f605133ecf5b3e1941ebb3588e943b4e3eb71a5a3f", size = 9977, upload-time = "2023-02-12T23:22:27.544Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/af/9c0bda43e486a3c9bf1e0f876d0f241bc3f229d7d65d09331a0868db9629/pytest_forked-1.6.0-py3-none-any.whl", hash = "sha256:810958f66a91afb1a1e2ae83089d8dc1cd2437ac96b12963042fbb9fb4d16af0", size = 4897, upload-time = "2023-02-12T23:22:26.022Z" }, +] + [[package]] name = "pytest-xdist" version = "3.8.0" @@ -4046,9 +4065,14 @@ debug = [ dev = [ { name = "black", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "clang-format", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "expecttest", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "isort", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "mypy", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "parameterized", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "pre-commit", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "pytest", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "pytest-forked", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "pytest-xdist", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "pyyaml", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "ruff", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "typos", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -4074,6 +4098,7 @@ test = [ { name = "expecttest", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "parameterized", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "pytest", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "pytest-forked", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "pytest-xdist", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] test-ext = [ @@ -4108,9 +4133,14 @@ debug = [ dev = [ { name = "black", specifier = ">=24.0.0" }, { name = "clang-format", specifier = "==14.0.6" }, + { name = "expecttest", specifier = "==0.1.6" }, { name = "isort" }, { name = "mypy" }, + { name = "parameterized", specifier = ">=0.2.0" }, { name = "pre-commit", specifier = ">=2.20.0" }, + { name = "pytest" }, + { name = "pytest-forked", specifier = ">=1.6.0" }, + { name = "pytest-xdist" }, { name = "pyyaml" }, { name = "ruff" }, { name = "typos" }, @@ -4134,6 +4164,7 @@ test = [ { name = "expecttest", specifier = "==0.1.6" }, { name = "parameterized", specifier = ">=0.2.0" }, { name = "pytest" }, + { name = "pytest-forked", specifier = ">=1.6.0" }, { name = "pytest-xdist" }, ] test-ext = [