diff --git a/src/utilities.jl b/src/utilities.jl index 5be24a2b21..59c5b97c2f 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -84,7 +84,7 @@ function versioninfo(io::IO=stdout) println(io, length(devs), " devices:") end for (i, dev) in enumerate(devs) - if has_nvml() + function query_nvml() mig = uuid(dev) != parent_uuid(dev) nvml_gpu = NVML.Device(parent_uuid(dev)) nvml_dev = NVML.Device(uuid(dev); mig) @@ -92,13 +92,33 @@ function versioninfo(io::IO=stdout) str = NVML.name(nvml_dev) cap = NVML.compute_capability(nvml_gpu) mem = NVML.memory_info(nvml_dev) - else + + (; str, cap, mem) + end + + function query_cuda() str = name(dev) cap = capability(dev) mem = device!(dev) do # this requires a device context, so we prefer NVML (free=available_memory(), total=total_memory()) end + (; str, cap, mem) + end + + str, cap, mem = if has_nvml() + try + query_nvml() + catch err + @show err + if !isa(err, NVML.NVMLError) || + !in(err.code, [NVML.ERROR_NOT_SUPPORTED, NVML.ERROR_NO_PERMISSION]) + rethrow() + end + query_cuda() + end + else + query_cuda() end println(io, " $(i-1): $str (sm_$(cap.major)$(cap.minor), $(Base.format_bytes(mem.free)) / $(Base.format_bytes(mem.total)) available)") end