|
1 | 1 | // Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. |
2 | 2 | using System; |
3 | 3 | using System.IO; |
4 | | -using System.Linq; |
5 | 4 | using System.Runtime.InteropServices; |
6 | | -using static TorchSharp.torch; |
7 | | -using static TorchSharp.torch.nn; |
8 | 5 | using Xunit; |
9 | 6 |
|
10 | 7 | #nullable enable |
11 | 8 |
|
12 | 9 | namespace TorchSharp |
13 | 10 | { |
14 | | - /// <summary> |
15 | | - /// Fact attribute that only runs on the platform matching the compiled .pt2 test models. |
16 | | - /// AOTInductor produces platform-specific native code, so the checked-in models |
17 | | - /// (compiled on macOS arm64) can only be loaded on that platform. |
18 | | - /// </summary> |
19 | | - public sealed class ExportTestFactAttribute : FactAttribute |
20 | | - { |
21 | | - public ExportTestFactAttribute() |
22 | | - { |
23 | | - if (!(RuntimeInformation.IsOSPlatform(OSPlatform.OSX) && |
24 | | - RuntimeInformation.ProcessArchitecture == Architecture.Arm64)) |
25 | | - { |
26 | | - Skip = "Test .pt2 models were compiled for macOS arm64"; |
27 | | - } |
28 | | - } |
29 | | - } |
30 | | - |
31 | 11 | [Collection("Sequential")] |
32 | 12 | public class TestExport |
33 | 13 | { |
34 | | - [ExportTestFact] |
35 | | - public void TestLoadExport_SimpleLinear() |
36 | | - { |
37 | | - // Test loading a simple linear model (inference-only) |
38 | | - using var exported = torch.export.load(@"simple_linear.export.pt2"); |
39 | | - Assert.NotNull(exported); |
40 | | - |
41 | | - var input = torch.ones(10); |
42 | | - var results = exported.run(input); |
43 | | - |
44 | | - Assert.NotNull(results); |
45 | | - Assert.Single(results); |
46 | | - Assert.Equal(new long[] { 5 }, results[0].shape); |
47 | | - Assert.Equal(torch.float32, results[0].dtype); |
48 | | - } |
49 | | - |
50 | | - [ExportTestFact] |
51 | | - public void TestLoadExport_LinearReLU() |
52 | | - { |
53 | | - // Test loading a Linear + ReLU model with typed output |
54 | | - using var exported = torch.export.load<Tensor>(@"linrelu.export.pt2"); |
55 | | - Assert.NotNull(exported); |
56 | | - |
57 | | - var input = torch.ones(10); |
58 | | - var result = exported.call(input); |
59 | | - |
60 | | - Assert.Equal(new long[] { 6 }, result.shape); |
61 | | - Assert.Equal(torch.float32, result.dtype); |
62 | | - |
63 | | - // ReLU should zero out negative values |
64 | | - Assert.True(result.data<float>().All(v => v >= 0)); |
65 | | - } |
66 | | - |
67 | | - [ExportTestFact] |
68 | | - public void TestLoadExport_TwoInputs() |
| 14 | + [Fact] |
| 15 | + public void TestExport_LoadNonExistentFile() |
69 | 16 | { |
70 | | - // Test loading a model with two inputs |
71 | | - using var exported = torch.export.load(@"two_inputs.export.pt2"); |
72 | | - Assert.NotNull(exported); |
73 | | - |
74 | | - var input1 = torch.ones(10); |
75 | | - var input2 = torch.ones(10) * 2; |
76 | | - var results = exported.forward(input1, input2); |
77 | | - |
78 | | - Assert.NotNull(results); |
79 | | - Assert.Single(results); |
80 | | - Assert.Equal(new long[] { 10 }, results[0].shape); |
81 | | - |
82 | | - // Should be input1 + input2 = 1 + 2 = 3 |
83 | | - var expected = torch.ones(10) * 3; |
84 | | - Assert.True(expected.allclose(results[0])); |
| 17 | + Assert.Throws<ExternalException>(() => |
| 18 | + torch.export.load("nonexistent.pt2")); |
85 | 19 | } |
86 | 20 |
|
87 | | - [ExportTestFact] |
88 | | - public void TestLoadExport_TupleOutput() |
| 21 | + [Fact] |
| 22 | + public void TestExport_LoadInvalidFile() |
89 | 23 | { |
90 | | - // Test loading a model that returns a tuple |
91 | | - using var exported = torch.export.load<(Tensor, Tensor)>(@"tuple_out.export.pt2"); |
92 | | - Assert.NotNull(exported); |
93 | | - |
94 | | - var x = torch.rand(3, 4); |
95 | | - var y = torch.rand(3, 4); |
96 | | - var result = exported.call(x, y); |
97 | | - |
98 | | - Assert.IsType<ValueTuple<Tensor, Tensor>>(result); |
99 | | - var (sum, diff) = result; |
100 | | - |
101 | | - Assert.Equal(x.shape, sum.shape); |
102 | | - Assert.Equal(x.shape, diff.shape); |
103 | | - Assert.True((x + y).allclose(sum)); |
104 | | - Assert.True((x - y).allclose(diff)); |
| 24 | + var tmpFile = Path.GetTempFileName(); |
| 25 | + try |
| 26 | + { |
| 27 | + File.WriteAllBytes(tmpFile, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF }); |
| 28 | + Assert.ThrowsAny<Exception>(() => |
| 29 | + torch.export.load(tmpFile)); |
| 30 | + } |
| 31 | + finally |
| 32 | + { |
| 33 | + File.Delete(tmpFile); |
| 34 | + } |
105 | 35 | } |
106 | 36 |
|
107 | | - [ExportTestFact] |
108 | | - public void TestLoadExport_ListOutput() |
| 37 | + [Fact] |
| 38 | + public void TestExport_LoadEmptyPath() |
109 | 39 | { |
110 | | - // Test loading a model that returns a list |
111 | | - using var exported = torch.export.load<Tensor[]>(@"list_out.export.pt2"); |
112 | | - Assert.NotNull(exported); |
113 | | - |
114 | | - var x = torch.rand(3, 4); |
115 | | - var y = torch.rand(3, 4); |
116 | | - var result = exported.forward(x, y); |
117 | | - |
118 | | - Assert.IsType<Tensor[]>(result); |
119 | | - Assert.Equal(2, result.Length); |
120 | | - |
121 | | - Assert.True((x + y).allclose(result[0])); |
122 | | - Assert.True((x - y).allclose(result[1])); |
| 40 | + Assert.ThrowsAny<Exception>(() => |
| 41 | + torch.export.load("")); |
123 | 42 | } |
124 | 43 |
|
125 | | - [ExportTestFact] |
126 | | - public void TestLoadExport_ThreeOutputs() |
| 44 | + [Fact] |
| 45 | + public void TestExport_DisposeIsIdempotent() |
127 | 46 | { |
128 | | - // Test loading a model that returns a 3-tuple |
129 | | - using var exported = torch.export.load<(Tensor, Tensor, Tensor)>(@"three_out.export.pt2"); |
130 | | - Assert.NotNull(exported); |
131 | | - |
132 | | - var x = torch.rand(3, 4); |
133 | | - var y = torch.rand(3, 4); |
134 | | - var result = exported.call(x, y); |
| 47 | + // Verify that double-dispose doesn't throw. |
| 48 | + // We can't construct a valid ExportedProgram without a real model, |
| 49 | + // so we catch the load error and verify we can still call Dispose |
| 50 | + // without crashing (the constructor should have cleaned up already). |
| 51 | + ExportedProgram? program = null; |
| 52 | + try |
| 53 | + { |
| 54 | + program = torch.export.load("nonexistent.pt2"); |
| 55 | + } |
| 56 | + catch (ExternalException) |
| 57 | + { |
| 58 | + // Expected - the file doesn't exist |
| 59 | + } |
135 | 60 |
|
136 | | - Assert.IsType<ValueTuple<Tensor, Tensor, Tensor>>(result); |
137 | | - var (sum, diff, prod) = result; |
| 61 | + // If somehow a program was created (shouldn't happen), dispose it twice |
| 62 | + if (program != null) |
| 63 | + { |
| 64 | + program.Dispose(); |
| 65 | + program.Dispose(); // second dispose should not throw |
| 66 | + } |
138 | 67 |
|
139 | | - Assert.Equal(x.shape, sum.shape); |
140 | | - Assert.Equal(x.shape, diff.shape); |
141 | | - Assert.Equal(x.shape, prod.shape); |
142 | | - Assert.True((x + y).allclose(sum)); |
143 | | - Assert.True((x - y).allclose(diff)); |
144 | | - Assert.True((x * y).allclose(prod)); |
| 68 | + // The fact that we reach here without crashing validates idempotent cleanup |
145 | 69 | } |
146 | 70 |
|
147 | | - [ExportTestFact] |
148 | | - public void TestLoadExport_Sequential() |
| 71 | + [Fact] |
| 72 | + public void TestExport_GenericLoadNonExistentFile() |
149 | 73 | { |
150 | | - // Test loading a sequential model |
151 | | - using var exported = torch.export.load<Tensor>(@"sequential.export.pt2"); |
152 | | - Assert.NotNull(exported); |
153 | | - |
154 | | - var input = torch.ones(1000); |
155 | | - var result = exported.call(input); |
156 | | - |
157 | | - Assert.Equal(new long[] { 10 }, result.shape); |
158 | | - Assert.Equal(torch.float32, result.dtype); |
| 74 | + Assert.Throws<ExternalException>(() => |
| 75 | + torch.export.load<torch.Tensor>("nonexistent.pt2")); |
159 | 76 | } |
160 | 77 |
|
161 | 78 | [Fact] |
162 | | - public void TestExport_LoadNonExistentFile() |
| 79 | + public void TestExport_GenericLoadInvalidFile() |
163 | 80 | { |
164 | | - // This test is platform-independent - it validates error handling |
165 | | - Assert.Throws<System.Runtime.InteropServices.ExternalException>(() => |
166 | | - torch.export.load(@"nonexistent.pt2")); |
| 81 | + var tmpFile = Path.GetTempFileName(); |
| 82 | + try |
| 83 | + { |
| 84 | + File.WriteAllBytes(tmpFile, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF }); |
| 85 | + Assert.ThrowsAny<Exception>(() => |
| 86 | + torch.export.load<torch.Tensor>(tmpFile)); |
| 87 | + } |
| 88 | + finally |
| 89 | + { |
| 90 | + File.Delete(tmpFile); |
| 91 | + } |
167 | 92 | } |
168 | 93 | } |
169 | 94 | } |
0 commit comments