From 8478bfbc1de28bfd973bbca36581e52c4cb38baa Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Thu, 19 Mar 2026 12:23:04 -0700 Subject: [PATCH 1/2] Migrate devtools tests and clean up remaining executorch/ CaptureConfig refs (#18135) Summary: Migrate etrecord_test, exported_op_graph_test, and size_analysis_tool_test from exir.capture to torch.export + to_edge. Remove dead CaptureConfig imports/usage from end2end tests, export_program, and qualcomm utils (delete qnn_capture_config() which returned a CaptureConfig). Reviewed By: Gasoonjia Differential Revision: D95605485 --- backends/qualcomm/utils/utils.py | 4 --- devtools/etrecord/tests/etrecord_test.py | 24 +++++++-------- .../size_analysis_tool_test.py | 10 +++++-- test/end2end/exported_module.py | 3 -- test/end2end/test_end2end.py | 29 +------------------ test/models/export_program.py | 3 +- 6 files changed, 20 insertions(+), 53 deletions(-) diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 19603e6219b..d45ef294bba 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -134,10 +134,6 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool: return True -def qnn_capture_config(): - return exir.CaptureConfig(enable_aot=True) - - def qnn_edge_config() -> exir.EdgeCompileConfig: return exir.EdgeCompileConfig( _check_ir_validity=False, diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index a57515bffee..2e58dd45ea4 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -618,12 +618,10 @@ def test_add_extra_export_modules(self): # Create additional module to add f2 = models.BasicSinMax() - captured_output2 = exir.capture( - f2, f2.get_random_inputs(), exir.CaptureConfig() - ) + captured_output2 = export(f2, f2.get_random_inputs(), strict=True) extra_modules = { - "new_module": captured_output2.exported_program, + "new_module": captured_output2, } # Add extra export modules @@ -640,7 +638,7 @@ def test_add_extra_export_modules(self): ) self.check_graph_closeness( etrecord.graph_map["new_module/forward"], - captured_output2.exported_program.graph_module, + captured_output2.graph_module, ) def test_add_extra_export_modules_reserved_name_validation(self): @@ -1066,13 +1064,11 @@ def test_add_exported_program_already_exists_exception(self): # Create another exported program to try to add f2 = models.BasicSinMax() - captured_output2 = exir.capture( - f2, f2.get_random_inputs(), exir.CaptureConfig() - ) + captured_output2 = export(f2, f2.get_random_inputs(), strict=True) # Verify that adding exported program raises RuntimeError with self.assertRaises(RuntimeError) as context: - etrecord.add_exported_program(captured_output2.exported_program) + etrecord.add_exported_program(captured_output2) self.assertIn( "Exported program already exists in the ETRecord", @@ -1202,11 +1198,11 @@ def test_add_edge_dialect_program_already_exists_exception(self): # Create another edge program to try to add f2 = models.BasicSinMax() - captured_output2 = exir.capture( - f2, f2.get_random_inputs(), exir.CaptureConfig() - ) - edge_output2 = captured_output2.to_edge( - exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=False) + edge_output2 = to_edge( + export(f2, f2.get_random_inputs(), strict=True), + compile_config=exir.EdgeCompileConfig( + _check_ir_validity=False, _use_edge_ops=False + ), ) # Verify that adding edge dialect program raises RuntimeError diff --git a/devtools/size_analysis_tool/size_analysis_tool_test.py b/devtools/size_analysis_tool/size_analysis_tool_test.py index 2212b2e15a9..69f1f07e5d6 100644 --- a/devtools/size_analysis_tool/size_analysis_tool_test.py +++ b/devtools/size_analysis_tool/size_analysis_tool_test.py @@ -11,6 +11,7 @@ XnnpackFloatingPointPartitioner, ) from executorch.backends.xnnpack.utils.configs import ( + get_transform_passes, get_xnnpack_edge_compile_config, get_xnnpack_executorch_backend_config, ) @@ -19,6 +20,7 @@ generate_model_size_information, ) from executorch.exir import to_edge +from executorch.exir.backend.backend_api import to_backend, validation_disabled from executorch.exir.passes.spec_prop_pass import SpecPropPass from torch.export import export @@ -56,10 +58,14 @@ def forward(self, x): edge_program = to_edge( export(mm, (test_input,), strict=True), compile_config=get_xnnpack_edge_compile_config(), - ) + ).transform(get_transform_passes()) partitioner = XnnpackFloatingPointPartitioner() - delegated_program = edge_program.to_backend(partitioner) + with validation_disabled(): + delegated_program = edge_program + delegated_program._edge_programs["forward"] = to_backend( + edge_program.exported_program(), partitioner + ) program = delegated_program.to_executorch( get_xnnpack_executorch_backend_config([SpecPropPass()]), diff --git a/test/end2end/exported_module.py b/test/end2end/exported_module.py index 97deda4adf1..f2a3a700a6e 100644 --- a/test/end2end/exported_module.py +++ b/test/end2end/exported_module.py @@ -67,7 +67,6 @@ def export( methods: Sequence[str] = ("forward",), ignore_to_out_var_failure: bool = False, dynamic_memory_planning_mode: DynamicMemoryPlanningMode = DynamicMemoryPlanningMode.UPPER_BOUND, - capture_config=None, export_joint_graph: bool = False, external_constants: bool = False, export_state_names: bool = False, @@ -146,8 +145,6 @@ def return_wrapper(): method_name_to_dynamic_shapes = None if hasattr(eager_module, "get_dynamic_shapes"): - assert capture_config is not None - assert capture_config.enable_aot is True trace_dynamic_shapes = eager_module.get_dynamic_shapes() # type: ignore[operator] method_name_to_dynamic_shapes = {} for method in methods: diff --git a/test/end2end/test_end2end.py b/test/end2end/test_end2end.py index 4f052798e30..bce7959d8e0 100644 --- a/test/end2end/test_end2end.py +++ b/test/end2end/test_end2end.py @@ -21,12 +21,7 @@ import executorch.extension.pytree as pytree import torch -from executorch.exir import ( - CaptureConfig, - EdgeCompileConfig, - ExecutorchBackendConfig, - memory, -) +from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, memory from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode from executorch.exir.emit import emit_program from executorch.exir.pass_manager import PassManager @@ -471,7 +466,6 @@ def maketest( allow_non_contiguous_tensor: bool = False, method: str = "forward", dynamic_memory_planning_mode: DynamicMemoryPlanningMode = DynamicMemoryPlanningMode.UPPER_BOUND, - capture_config=None, verify_graph: Optional[Callable] = None, ) -> Callable[[unittest.TestCase], None]: r"""Returns a TestCase method to test the provided module class and method. @@ -507,7 +501,6 @@ def wrapper(self: unittest.TestCase) -> None: methods=(method,), ignore_to_out_var_failure=ignore_to_out_var_failure, dynamic_memory_planning_mode=dynamic_memory_planning_mode, - capture_config=capture_config, ) if verify_graph: verify_graph(self, module.exported_program.graph_module) @@ -599,9 +592,6 @@ def test_ops_return_multi(self): def test_mem_planning_toy_model(self): maketest( ToyModelForMemPlanning, - capture_config=exir.CaptureConfig( - enable_dynamic_shape=True, - ), )(self) # TODO: add ops implementations and turn on 'run_executor' @@ -621,9 +611,6 @@ def test_containers(self): maketest( ModuleContainers, do_tree_flatten=True, - capture_config=exir.CaptureConfig( - enable_dynamic_shape=True, - ), )(self) # can not run the graph module since the out variance with tensor list out @@ -675,9 +662,6 @@ def test_intermediate_dynamic_shape(self): ModuleIntermediateDynamicShape, run_graph_module=False, allow_non_contiguous_tensor=True, - capture_config=exir.CaptureConfig( - enable_dynamic_shape=True, - ), )(self) # TODO(shunting): some non constant tensors for transformer are non-contiguous. @@ -697,10 +681,6 @@ def test_transformer_encode(self): def test_ft_cond_basic(self): maketest( FTCondBasic, - capture_config=exir.CaptureConfig( - enable_dynamic_shape=True, - enable_functionalization=False, # TODO enable functionalization - ), )(self) def test_ft_map_basic(self): @@ -746,10 +726,6 @@ def test_ft_map_basic(self): def test_ft_cond_dynshape(self): maketest( FTCondDynShape, - capture_config=exir.CaptureConfig( - enable_dynamic_shape=True, - enable_functionalization=False, # TODO enable functionalization - ), )(self) def test_ft_map_dynshape(self): @@ -802,9 +778,6 @@ def test_ft_map_dynshape(self): def test_batch_norm(self): maketest( BatchNormModel, - capture_config=exir.CaptureConfig( - enable_dynamic_shape=True, - ), verify_graph=BatchNormModel.verify_graph, # TODO: lean mode does not have native_batch_norm.out implemented # run this on aten mode. diff --git a/test/models/export_program.py b/test/models/export_program.py index 2ee7d9b5e38..c501dc8d45e 100644 --- a/test/models/export_program.py +++ b/test/models/export_program.py @@ -13,7 +13,6 @@ from typing import Any, Dict, List, Type import torch -from executorch.exir import CaptureConfig from executorch.exir.passes import MemoryPlanningPass from executorch.exir.program._program import ExecutorchProgramManager from torch import nn @@ -140,7 +139,7 @@ def get_memory_planning_pass(self): @staticmethod def get_export_kwargs(): - return {"capture_config": CaptureConfig(pt2_mode=True, enable_aot=True)} + return {} class ModuleAddMul(torch.nn.Module): From fcd30c9448eecdc92b531cddcf83cce546af6a03 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Thu, 19 Mar 2026 12:23:04 -0700 Subject: [PATCH 2/2] Delete deprecated test files that only tested exir.capture behavior Summary: Remove test_capture.py (tested exir.capture directly), test_tracer.py (tested CaptureConfig flags like _unlift and enable_dynamic_shape), test_backends.py (covered by test_backends_lifted.py), and Remove corresponding build targets from TARGETS/BUCK files. Differential Revision: D95605525 --- exir/backend/test/BUCK | 1 + exir/tests/TARGETS | 32 --- exir/tests/test_capture.py | 29 --- exir/tests/test_tracer.py | 497 ------------------------------------- 4 files changed, 1 insertion(+), 558 deletions(-) delete mode 100644 exir/tests/test_capture.py delete mode 100644 exir/tests/test_tracer.py diff --git a/exir/backend/test/BUCK b/exir/backend/test/BUCK index 10278befea0..057aaf4caa3 100644 --- a/exir/backend/test/BUCK +++ b/exir/backend/test/BUCK @@ -158,6 +158,7 @@ fbcode_target(_kind = runtime.python_library, ], ) + fbcode_target(_kind = runtime.python_test, name = "test_to_backend_multi_method", srcs = [ diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index 0db397798ba..c9136ce51da 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -67,24 +67,6 @@ runtime.python_library( ], ) -runtime.python_test( - name = "tracer", - srcs = [ - "test_tracer.py", - ], - # Static listing does not support tests generated in runtime. - supports_static_listing = False, - deps = [ - "fbsource//third-party/pypi/parameterized:parameterized", - ":lib", - ":models", - "//caffe2:torch", - "//caffe2/functorch:functorch_src", - "//executorch/exir:lib", - "//executorch/exir:tracer", - "//executorch/exir/dialects:lib", - ], -) python_unittest( name = "serde", @@ -339,20 +321,6 @@ python_unittest( ], ) -runtime.python_test( - name = "capture", - srcs = [ - "test_capture.py", - ], - # Static listing does not support tests generated in runtime. - supports_static_listing = False, - deps = [ - "fbsource//third-party/pypi/parameterized:parameterized", - ":models", - "//caffe2:torch", - "//executorch/exir:lib", - ], -) python_unittest( name = "dynamic_shape_propagation", diff --git a/exir/tests/test_capture.py b/exir/tests/test_capture.py deleted file mode 100644 index 1b649db191b..00000000000 --- a/exir/tests/test_capture.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -import unittest - -import executorch.exir as exir -import executorch.exir.tests.models as models -import torch - -from parameterized import parameterized - - -class TestCapture(unittest.TestCase): - # pyre-ignore - @parameterized.expand(models.MODELS) - def test_module_call(self, model_name: str, model: torch.nn.Module) -> None: - # pyre-fixme[29]: `Union[torch._tensor.Tensor, - # torch.nn.modules.module.Module]` is not a function. - inputs = model.get_random_inputs() - expected = model(*inputs) - # TODO(ycao): Replace it with capture_multiple - exported_program = exir.capture(model, inputs, exir.CaptureConfig()) - - self.assertTrue(torch.allclose(expected, exported_program(*inputs))) diff --git a/exir/tests/test_tracer.py b/exir/tests/test_tracer.py deleted file mode 100644 index 22e01f33332..00000000000 --- a/exir/tests/test_tracer.py +++ /dev/null @@ -1,497 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -import copy -import unittest -from typing import Dict, List, Tuple - -import executorch.exir as exir -import executorch.exir.tests.models as models - -import torch - -from executorch.exir import CaptureConfig -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.tests.common import register_additional_test_aten_ops -from executorch.exir.tracer import dynamo_trace, ExirDynamoConfig, using_dynamo -from functorch.experimental.control_flow import cond, map - -from parameterized import parameterized -from torch._export.verifier import SpecViolationError -from torch.fx.experimental.symbolic_shapes import is_concrete_int -from torch.testing import FileCheck - - -class TestTorchDispatchFXTracer(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - register_additional_test_aten_ops() - - def test_simple(self) -> None: - f = models.BasicSinMax() - f = ( - exir.capture(f, f.get_random_inputs(), exir.CaptureConfig()) - .to_edge() - .exported_program.graph_module - ) - - FileCheck().check("executorch_exir_dialects_edge__ops_aten_sin").run(f.code) - - def test_static_control_flow(self) -> None: - def f(pred: bool, x: torch.Tensor) -> torch.Tensor: - if pred: - return torch.sin(x).max() - else: - return torch.sin(x) - - pred = True - x = torch.randn(100) - f_true = ( - exir.capture(f, (pred, x), exir.CaptureConfig()) - .to_edge() - .exported_program.graph_module - ) - - FileCheck().check("executorch_exir_dialects_edge__ops_aten_sin").check( - "executorch_exir_dialects_edge__ops_aten_max" - ).run(f_true.code) - - pred = False - f_false = ( - exir.capture(f, (pred, x), exir.CaptureConfig()) - .to_edge() - .exported_program.graph_module - ) - FileCheck().check("executorch_exir_dialects_edge__ops_aten_sin").check_not( - "executorch_exir_dialects_edge__ops_aten_max" - ).run(f_false.code) - - def test_copy(self) -> None: - f = models.BasicSinMax() - f = ( - exir.capture(f, f.get_random_inputs(), exir.CaptureConfig()) - .to_edge() - .exported_program.graph_module - ) - - self.assertTrue(isinstance(f, torch.fx.GraphModule)) - g = copy.deepcopy(f) - self.assertTrue(isinstance(g, torch.fx.GraphModule)) - - def test_stacktrace(self) -> None: - def f(x: torch.Tensor) -> torch.Tensor: - return x + x - - traced_f = ( - exir.capture(f, (torch.rand(2, 2),), exir.CaptureConfig()) - .to_edge() - .exported_program.graph_module - ) - # Check that stacktrace is populated and retained (by checking twice) - self.assertTrue( - any(node.meta.get("stack_trace", None) for node in traced_f.graph.nodes) - ) - self.assertTrue( - any(node.meta.get("stack_trace", None) for node in traced_f.graph.nodes) - ) - - def test_ones(self) -> None: - class M(torch.nn.Module): - def forward(self, x): - y = torch.ones(x.shape[0]) - return x + y - - ep = torch.export.export( - M(), - (torch.ones(3),), - dynamic_shapes={"x": {0: torch.export.Dim("x")}}, - strict=True, - ) - exir.to_edge(ep) - - def test_possible_input_mutation(self) -> None: - def f(x: torch.Tensor) -> torch.Tensor: - return torch.add(torch.ones(5), torch.ones(5), out=x) - - with self.assertRaisesRegex( - SpecViolationError, - r"operator .* is not functional", - ): - exir.capture(f, (torch.zeros(5),), exir.CaptureConfig()).to_edge() - - def test_tensor_spec_for_const_tensors(self) -> None: - class Module(torch.nn.Module): - def __init__(self) -> None: - super(Module, self).__init__() - self.linear = torch.nn.Linear(2, 3) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear(x) - - def get_random_inputs(self) -> Tuple[torch.Tensor, ...]: - return (torch.randn(2),) - - model = Module() - graph_module = ( - exir.capture(model, model.get_random_inputs(), exir.CaptureConfig()) - # torch._ops.aten.t.default - .to_edge( - exir.EdgeCompileConfig(_check_ir_validity=False) - ).exported_program.graph_module - ) - num_get_attr_node = 0 - num_get_attr_node_with_tensorspec = 0 - for nd in graph_module.graph.nodes: - if nd.op == "get_attr": - num_get_attr_node += 1 - if nd.meta.get("val") is not None: - num_get_attr_node_with_tensorspec += 1 - - self.assertEqual(2, num_get_attr_node) - self.assertEqual(2, num_get_attr_node_with_tensorspec) - - def test_multiple_returns_spec(self) -> None: - def f(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - return torch.ops.aten.max.dim(x, 0, False) - - cnt = 0 - module = ( - exir.capture(f, (torch.zeros(1, 2, 3),), exir.CaptureConfig()) - .to_edge() - .exported_program.graph_module - ) - for node in module.graph.nodes: - if node.target == exir_ops.edge.aten.max.dim: - cnt += 1 - self.assertIsInstance(node.meta["val"], tuple) - self.assertEqual(cnt, 1) - - def test_multiple_returns_pt2_mode(self) -> None: - def f(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - a = x * x - b = x + a - return a, b - - inputs = (torch.ones(1, 2, 3),) - orig_res = f(*inputs) - module = ( - exir.capture( - f, - inputs, - exir.CaptureConfig(), - ) - .to_edge() - .exported_program.graph_module - ) - new_res = module(*inputs) - for node in module.graph.nodes: - if node.op == "output": - self.assertIsInstance(node.meta["val"], list) - self.assertEqual(len(node.meta["val"]), 2) - - self.assertTrue(torch.allclose(orig_res[0], new_res[0])) - self.assertTrue(torch.allclose(orig_res[1], new_res[1])) - - def test_dynamo_capture_scalar_outputs(self) -> None: - def f(x: torch.Tensor) -> float: - return x.item() - - gm, guards = dynamo_trace( - f, - (torch.ones(1),), - False, - "real", - ExirDynamoConfig(), - ) - - # pyre-ignore - @parameterized.expand([("stock_tensor",)]) - def test_embedding_dynamic_shape(self, input_type: str) -> None: - class Module(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - - def forward(self, x): - return x + x - - example_input = torch.ones(10, dtype=torch.int64) - m = Module() - gm = ( - exir.capture( - m, - (example_input,), - exir.CaptureConfig( - enable_functionalization=False, - enable_dynamic_shape=True, - ), - ) - .to_edge() - .exported_program.graph_module - ) - - print(gm.graph) - - def test_dynamic_shape(self) -> None: - def forward(x: torch.Tensor) -> torch.Tensor: - x = x.view(x.shape[0] - 1, -1) - return torch.cat([x, x]) - - gm = ( - exir.capture( - forward, - (torch.ones(3, 2, dtype=torch.int64),), - exir.CaptureConfig( - enable_functionalization=False, - enable_dynamic_shape=True, - _dynamo_config=ExirDynamoConfig(assume_static_by_default=True), - ), - # sym_size is not reg op - ) - .to_edge(exir.EdgeCompileConfig(_check_ir_validity=False)) - .exported_program.graph_module - ) - - for node in gm.graph.nodes: - if node.op in ("placeholder", "call_function"): - self.assertIn("val", node.meta) - - def test_dynamo_frontend_container_input(self) -> None: - class Module(torch.nn.Module): - def __init__(self) -> None: - super(Module, self).__init__() - - def forward( - self, x: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] - ) -> torch.Tensor: - a = x[0] - b = x[1] - cum = 0 - for i in b: - cum += i.sum() - return a.cos() + cum.sin() - - with using_dynamo(True): - inp = ((torch.ones(6), (torch.ones(6), torch.ones(6))),) - gm = exir.capture(Module(), inp, exir.CaptureConfig()) - self.assertTrue(torch.allclose(Module()(*inp), gm(*inp))) - - # TODO (tmanlaibaatar) remove this test - def test_pt2_mode_with_dynamo_config(self) -> None: - def f(x: torch.Tensor) -> torch.Tensor: - return x[: x.shape[0] - 1] - - inp = (torch.randn(4, 5),) - prog = exir.capture( - f, - inp, - # missing dispatch key - ).to_edge() - self.assertTrue(prog(torch.randn(4, 5)).shape[0], 3) - - def test_input_container_type(self) -> None: - def f(x: torch.Tensor, y: List[torch.Tensor]) -> Dict[str, torch.Tensor]: - # pyre-ignore - return {"a": x.sum() + sum(y).sum()} - - inp = (torch.randn(6, 5), [torch.randn(6, 5), torch.randn(6, 5)]) - - # pyre-fixme[23]: Unable to unpack `(...) -> Tuple[GraphModule, - # Set[torch._guards.Guard]]` into 2 values. - gm, _ = torch._dynamo.export(f, *inp, aten_graph=True, tracing_mode="symbolic") - prog = exir.capture(f, inp, config=exir.CaptureConfig()).to_edge() - - self.assertEqual(prog(*inp), f(*inp)) - - def test_assume_constant_by_default_prop(self) -> None: - def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - if x.shape[0] > 3: - return x.cos() - return x.sin() - - dynamo_config = ExirDynamoConfig(assume_static_by_default=True) - capture_config = exir.CaptureConfig( - enable_dynamic_shape=True, _dynamo_config=dynamo_config - ) - captured = exir.capture( - foo, (torch.ones(6, 2), torch.ones(6, 3)), capture_config - ).exported_program.graph_module - found = False - for node in captured.graph.nodes: - # at least one input needs to have concrete dims - if "val" in node.meta: - fake_val = node.meta["val"] - for dim in fake_val.shape: - if is_concrete_int(dim): - found = True - - self.assertTrue(found) - - def test_aot_config(self) -> None: - class FooWithBuffer(torch.nn.Module): - def __init__(self): - super().__init__() - self.register_buffer("buffer", torch.zeros(42)) - - def forward(self, x): - return x.cos() + self.buffer.sum() - - capture_config = exir.CaptureConfig(enable_aot=True) - captured_ep = exir.capture(FooWithBuffer(), (torch.ones(6, 2),), capture_config) - captured_gm = captured_ep.exported_program.graph_module - - placeholder_nodes = set() - print(captured_gm.graph) - for node in captured_gm.graph.nodes: - self.assertFalse(node.op == "get_attr") - if node.op == "placeholder": - placeholder_nodes.add(node) - if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor: - # make sure the placeholders are used - arg_0, arg_1 = node.args - self.assertEqual( - placeholder_nodes, - { - list(arg_0._input_nodes.keys())[0], - list(arg_1._input_nodes.keys())[0], - }, - ) - - self.assertEqual(len(placeholder_nodes), 2) - captured_ep.to_edge() - - def test_export_unlift(self) -> None: - class Foo(torch.nn.Module): - def __init__(self): - super().__init__() - self.register_buffer("buffer", torch.ones(6, 4)) - - def forward(self, x): - return x.cos() + self.buffer.sin() - - ep = exir.capture( - Foo(), - (torch.ones(6, 4),), - exir.CaptureConfig(enable_aot=True, _unlift=True), - ) - - self.assertTrue(torch.allclose(ep(torch.ones(6, 4)), Foo()(torch.ones(6, 4)))) - - def test_export_container_unlift(self) -> None: - class FooContainerInputOutput(torch.nn.Module): - def __init__(self): - super().__init__() - self.register_buffer("buffer", torch.ones(6, 4)) - - def forward(self, x): - return x[0][0].cos() + x[0][1].sin() + self.buffer.sin() - - inp = ((torch.ones(6, 4), torch.ones(6, 4)),) - ep = exir.capture( - FooContainerInputOutput(), - (inp,), - CaptureConfig(enable_aot=True, _unlift=True), - ) - self.assertTrue(torch.allclose(ep(inp), FooContainerInputOutput()(inp))) - - def test_export_container_input_unlift(self) -> None: - class FooContainerInputOutputV2(torch.nn.Module): - def __init__(self): - super().__init__() - self.register_buffer("buffer", torch.ones(6, 4)) - - def forward(self, x, y): - return x[0].cos() + y[0].sin() + self.buffer.sin() - - inp = ((torch.ones(6, 4),), (torch.ones(6, 4),)) - ep = exir.capture( - FooContainerInputOutputV2(), - inp, - CaptureConfig(enable_aot=True, _unlift=True), - ) - self.assertTrue(torch.allclose(ep(*inp), FooContainerInputOutputV2()(*inp))) - - def test_export_cond(self) -> None: - class A(torch.nn.Module): - def __init__(self): - super().__init__() - self.register_buffer("buffer", torch.ones(6, 4)) - - def forward(self): - return self.buffer.cos() - - class Foo(torch.nn.Module): - def __init__(self): - super().__init__() - self.a = A() - - def forward(self, x): - def true_fn(x): - return x.cos() + self.a().sum() - - def false_fn(x): - return x.sin() - - return cond(x.shape[0] > 4, true_fn, false_fn, [x]) - - inp = torch.ones(6, 4) - ep = exir.capture( - Foo(), - (inp,), - CaptureConfig(enable_aot=True, _unlift=True), - ) - self.assertTrue(torch.allclose(ep(torch.ones(6, 4)), Foo()(torch.ones(6, 4)))) - - def test_export_cond_map(self) -> None: - class A(torch.nn.Module): - def __init__(self): - super().__init__() - self.register_buffer("buffer", torch.ones(6, 4)) - - def forward(self): - return self.buffer.sum() - - class Module(torch.nn.Module): - def __init__(self): - super().__init__() - self.a = A() - - def inner(self, x, pred): - def true_fn(x): - return x + x + self.a() - - def false_fn(x): - return x * x - self.a() - - return cond(pred, true_fn, false_fn, [x]) - - def forward(self, pred, xs): - def body(x, pred): - return self.inner(x, pred) + self.a() - - return map(body, xs, pred) - - inp = torch.randn(3, 2, 1) - ep = exir.capture( - Module(), - (torch.tensor(True), inp), - CaptureConfig(enable_aot=True, _unlift=True), - ) - - inp_test = torch.randn(3, 2, 1) - self.assertTrue( - torch.allclose( - ep(torch.tensor(True), inp_test), - Module()(torch.tensor(True), inp_test), - ) - ) - self.assertTrue( - torch.allclose( - ep(torch.tensor(False), inp_test), - Module()(torch.tensor(False), inp_test), - ) - )