diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 0de257f7c6..ca71ebb40e 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -312,7 +312,17 @@ def prepare_inputs(
     elif isinstance(inputs, Input):
         return inputs
-    elif isinstance(inputs, (torch.Tensor, int, float, bool)):
+    elif isinstance(inputs, torch.Tensor):
+        # Pass the tensor directly — torch.tensor() would create a full
+        # data copy, which wastes GPU memory when torch.compile lifts
+        # model parameters as graph inputs.
+        # Input.from_tensor only reads shape/dtype/format metadata.
+        return Input.from_tensor(
+            inputs,
+            disable_memory_format_check=disable_memory_format_check,
+        )
+
+    elif isinstance(inputs, (int, float, bool)):
         return Input.from_tensor(
             torch.tensor(inputs),
             disable_memory_format_check=disable_memory_format_check,
         )
diff --git a/tests/py/dynamo/runtime/test_000_compiler_utils.py b/tests/py/dynamo/runtime/test_000_compiler_utils.py
index 02b0d63523..42d226911a 100644
--- a/tests/py/dynamo/runtime/test_000_compiler_utils.py
+++ b/tests/py/dynamo/runtime/test_000_compiler_utils.py
@@ -88,6 +88,40 @@ def test_prepare_mixed_type_compound_tensor_input(self):
             same_output_format(inputs, prepared_inputs_trt, enforce_tensor_type=False)
         )
 
+    def test_prepare_tensor_does_not_copy_data(self):
+        """Verify that prepare_inputs does not duplicate GPU tensor data.
+
+        When torch.compile lifts model parameters as graph inputs,
+        prepare_inputs receives every weight tensor. Previously,
+        torch.tensor(t) created a full copy of each tensor, doubling GPU
+        memory usage. Input.from_tensor only needs shape/dtype metadata,
+        so no copy is necessary.
+        """
+        original = torch.randn(1024, 1024, device="cuda")
+        before = torch.cuda.memory_allocated()
+        result = prepare_inputs([original])
+        after = torch.cuda.memory_allocated()
+        # No significant new allocation (allow small overhead, but not a full copy)
+        self.assertLess(
+            after - before,
+            original.nelement() * original.element_size(),
+            "prepare_inputs should not allocate a full copy of the input tensor",
+        )
+        # Result should preserve shape and dtype
+        self.assertEqual(result[0].shape, original.shape)
+        self.assertEqual(result[0].dtype, original.dtype)
+
+    def test_prepare_scalar_inputs(self):
+        """Verify that scalar inputs are still converted to tensors."""
+        int_result = prepare_inputs(42)
+        self.assertIsInstance(int_result, torch_tensorrt.Input)
+
+        float_result = prepare_inputs(3.14)
+        self.assertIsInstance(float_result, torch_tensorrt.Input)
+
+        bool_result = prepare_inputs(True)
+        self.assertIsInstance(bool_result, torch_tensorrt.Input)
+
 
 if __name__ == "__main__":
     unittest.main()