diff --git a/GDeflateConsole/NativeLibrary.cs b/GDeflateConsole/NativeLibrary.cs index 13c2912..fccd0c0 100644 --- a/GDeflateConsole/NativeLibrary.cs +++ b/GDeflateConsole/NativeLibrary.cs @@ -33,16 +33,29 @@ static NativeLibrary() { if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { - string programFiles = Environment.GetFolderPath(Environment.SpecialFolder.ProgramFiles); - string nvidiaGpuComputingToolkit = Path.Combine(programFiles, "NVIDIA GPU Computing Toolkit", "CUDA"); - if (Directory.Exists(nvidiaGpuComputingToolkit)) + string[] programFilesPaths = { + Environment.GetFolderPath(Environment.SpecialFolder.ProgramFiles), + Environment.GetFolderPath(Environment.SpecialFolder.ProgramFilesX86) + }; + + foreach (var programFilesPath in programFilesPaths.Distinct()) { - var versions = Directory.GetDirectories(nvidiaGpuComputingToolkit, "v*.*") - .Select(path => new { Path = path, Version = GetVersionFromPath(path) }) - .OrderByDescending(x => x.Version) - .ToList(); + if (string.IsNullOrEmpty(programFilesPath)) continue; + + string nvidiaGpuComputingToolkit = Path.Combine(programFilesPath, "NVIDIA GPU Computing Toolkit", "CUDA"); + if (Directory.Exists(nvidiaGpuComputingToolkit)) + { + var versions = Directory.GetDirectories(nvidiaGpuComputingToolkit, "v*.*") + .Select(path => new { Path = path, Version = GetVersionFromPath(path) }) + .OrderByDescending(x => x.Version) + .ToList(); - return versions.FirstOrDefault()?.Path; + var latestVersion = versions.FirstOrDefault(); + if (latestVersion != null) + { + return latestVersion.Path; + } + } } } else @@ -72,6 +85,17 @@ static NativeLibrary() private static string? FindCudart() { + // Search in common paths first + string[] searchPaths = { Directory.GetCurrentDirectory(), AppContext.BaseDirectory }; + foreach (var path in searchPaths) + { + var dlls = Directory.GetFiles(path, "cudart64_*.dll", SearchOption.AllDirectories) + .Select(p => new { Path = p, Version = GetVersionFromFileName(p) }) + .OrderByDescending(x => x.Version) + .ToList(); + if (dlls.Any()) return dlls.First().Path; + } + if (_cudaToolkitPath == null) return null; string binPath = Path.Combine(_cudaToolkitPath, "bin"); @@ -101,15 +125,49 @@ static NativeLibrary() private static string? FindNvcomp() { + // Search in common paths first + string[] searchPaths = { Directory.GetCurrentDirectory(), AppContext.BaseDirectory }; + foreach (var path in searchPaths) + { + var nvcompPaths = Directory.GetFiles(path, "nvcomp*.dll", SearchOption.AllDirectories); + if (nvcompPaths.Any()) return nvcompPaths.First(); + } + + // Search in standard nvCOMP installation path on Windows + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + string[] programFilesPaths = { + Environment.GetFolderPath(Environment.SpecialFolder.ProgramFiles), + Environment.GetFolderPath(Environment.SpecialFolder.ProgramFilesX86) + }; + + foreach (var programFilesPath in programFilesPaths.Distinct()) + { + if (string.IsNullOrEmpty(programFilesPath)) continue; + + string nvcompInstallPath = Path.Combine(programFilesPath, "NVIDIA nvCOMP"); + if (Directory.Exists(nvcompInstallPath)) + { + var nvcompDllPaths = Directory.GetFiles(nvcompInstallPath, "nvcomp*.dll", SearchOption.AllDirectories); + if (nvcompDllPaths.Any()) + { + return nvcompDllPaths.First(); + } + } + } + } + if (_cudaToolkitPath == null) return null; string binPath = Path.Combine(_cudaToolkitPath, "bin"); if (Directory.Exists(binPath)) { - string nvcompPath = Path.Combine(binPath, "nvcomp.dll"); - if (File.Exists(nvcompPath)) + // Search recursively for nvcomp.dll + var nvcompPaths = Directory.GetFiles(binPath, "nvcomp*.dll", SearchOption.AllDirectories); + if (nvcompPaths.Length > 0) { - return nvcompPath; + // Return the first match + return nvcompPaths[0]; } } return null;