diff --git a/src/TorchSharp/Tensor/Factories/Tensor.Factories.cs b/src/TorchSharp/Tensor/Factories/Tensor.Factories.cs index b306c0cd7..df758de80 100644 --- a/src/TorchSharp/Tensor/Factories/Tensor.Factories.cs +++ b/src/TorchSharp/Tensor/Factories/Tensor.Factories.cs @@ -647,7 +647,10 @@ private static void ValidateIntegerRange(long value, ScalarType dtype, string ar static torch() { deleters = new ConcurrentDictionary(); - TryInitializeDeviceType(DeviceType.CUDA); + if (!TryInitializeDeviceType(DeviceType.CUDA)) { + // CUDA not available, ensure CPU backend is loaded + LoadNativeBackend(false, out _); + } } #endregion } diff --git a/src/TorchSharp/Torch.cs b/src/TorchSharp/Torch.cs index dd7a07689..e0daf0bd7 100644 --- a/src/TorchSharp/Torch.cs +++ b/src/TorchSharp/Torch.cs @@ -116,7 +116,7 @@ private static void LoadNativeBackend(bool useCudaBackend, out StringBuilder? tr trace = null; if (!alreadyLoaded) { - bool ok; + bool ok = true; trace = new StringBuilder(); trace.AppendLine($""); trace.AppendLine($"TorchSharp: LoadNativeBackend: Initialising native backend, useCudaBackend = {useCudaBackend}"); @@ -134,30 +134,32 @@ private static void LoadNativeBackend(bool useCudaBackend, out StringBuilder? tr // Preloading these DLLs on windows seems to iron out problems where one native DLL // requests a load of another through dynamic linking techniques. // - ok = TryLoadNativeLibraryByName("cudnn_adv64_9", typeof(torch).Assembly, trace); - ok = TryLoadNativeLibraryByName("cudnn_cnn64_9", typeof(torch).Assembly, trace); - ok = TryLoadNativeLibraryByName("cudnn_ops64_9", typeof(torch).Assembly, trace); - ok = TryLoadNativeLibraryByName("cudnn_graph64_9.dll", typeof(torch).Assembly, trace); - ok = TryLoadNativeLibraryByName("cudnn_heuristic64_9.dll", typeof(torch).Assembly, trace); - ok = TryLoadNativeLibraryByName("cudnn_engines_precompiled64_9.dll", typeof(torch).Assembly, trace); - ok = TryLoadNativeLibraryByName("cudnn_engines_runtime_compiled64_9.dll", typeof(torch).Assembly, trace); - ok = TryLoadNativeLibraryByName("nvrtc-builtins64_128", typeof(torch).Assembly, trace); - ok = TryLoadNativeLibraryByName("caffe2_nvrtc", typeof(torch).Assembly, trace); - ok = TryLoadNativeLibraryByName("nvrtc64_120_0", typeof(torch).Assembly, trace); - ok = TryLoadNativeLibraryByName("cublasLt64_12", typeof(torch).Assembly, trace); - ok = TryLoadNativeLibraryByName("cufft64_11", typeof(torch).Assembly, trace); - ok = TryLoadNativeLibraryByName("cusparse64_12", typeof(torch).Assembly, trace); - ok = TryLoadNativeLibraryByName("cusolver64_11", typeof(torch).Assembly, trace); + // These are optional preloads to help Windows resolve transitive native dependencies. + // Failures here are non-fatal -- the critical loads are torch_cuda and LibTorchSharp below. + TryLoadNativeLibraryByName("cudnn_adv64_9", typeof(torch).Assembly, trace); + TryLoadNativeLibraryByName("cudnn_cnn64_9", typeof(torch).Assembly, trace); + TryLoadNativeLibraryByName("cudnn_ops64_9", typeof(torch).Assembly, trace); + TryLoadNativeLibraryByName("cudnn_graph64_9.dll", typeof(torch).Assembly, trace); + TryLoadNativeLibraryByName("cudnn_heuristic64_9.dll", typeof(torch).Assembly, trace); + TryLoadNativeLibraryByName("cudnn_engines_precompiled64_9.dll", typeof(torch).Assembly, trace); + TryLoadNativeLibraryByName("cudnn_engines_runtime_compiled64_9.dll", typeof(torch).Assembly, trace); + TryLoadNativeLibraryByName("nvrtc-builtins64_128", typeof(torch).Assembly, trace); + TryLoadNativeLibraryByName("caffe2_nvrtc", typeof(torch).Assembly, trace); + TryLoadNativeLibraryByName("nvrtc64_120_0", typeof(torch).Assembly, trace); + TryLoadNativeLibraryByName("cublasLt64_12", typeof(torch).Assembly, trace); + TryLoadNativeLibraryByName("cufft64_11", typeof(torch).Assembly, trace); + TryLoadNativeLibraryByName("cusparse64_12", typeof(torch).Assembly, trace); + TryLoadNativeLibraryByName("cusolver64_11", typeof(torch).Assembly, trace); } - ok = TryLoadNativeLibraryByName("torch_cuda", typeof(torch).Assembly, trace); + TryLoadNativeLibraryByName("torch_cuda", typeof(torch).Assembly, trace); ok = TryLoadNativeLibraryByName("LibTorchSharp", typeof(torch).Assembly, trace); } else { - ok = TryLoadNativeLibraryByName("torch_cpu", typeof(torch).Assembly, trace); + TryLoadNativeLibraryByName("torch_cpu", typeof(torch).Assembly, trace); ok = TryLoadNativeLibraryByName("LibTorchSharp", typeof(torch).Assembly, trace); } - trace.AppendLine($" Result from regular native load of LibTorchSharp is {ok}"); + trace.AppendLine($" Result from regular native load of backend libraries (CUDA preloads + torch_* + LibTorchSharp) is {ok}"); // Try dynamic load from package directories if (!ok) { @@ -299,7 +301,12 @@ public static bool TryInitializeDeviceType(DeviceType deviceType) return false; } - LoadNativeBackend(deviceType == DeviceType.CUDA, out _); + try { + LoadNativeBackend(deviceType == DeviceType.CUDA, out _); + } catch (NotSupportedException) { + return false; + } + if (deviceType == DeviceType.CUDA) { return cuda.CallTorchCudaIsAvailable(); } else {