From ad184222b7a5007aabc1a5a7401ab996d2c8ea45 Mon Sep 17 00:00:00 2001 From: aby913 Date: Fri, 20 Dec 2024 21:55:31 +0800 Subject: [PATCH] fix: get cuda version (#84) --- pkg/gpu/module.go | 17 ++++++++++++++ pkg/gpu/tasks.go | 48 ++++++++++++++++++++++++++++++++++++++ pkg/phase/cluster/linux.go | 1 + pkg/terminus/ossystem.go | 5 ++-- 4 files changed, 69 insertions(+), 2 deletions(-) diff --git a/pkg/gpu/module.go b/pkg/gpu/module.go index c07212e..08b34d1 100644 --- a/pkg/gpu/module.go +++ b/pkg/gpu/module.go @@ -198,3 +198,20 @@ func (m *InstallPluginModule) Init() { installGPUShared, } } + +type GetCudaVersionModule struct { + common.KubeModule +} + +func (g *GetCudaVersionModule) Init() { + g.Name = "GetCudaVersion" + + getCudaVersion := &task.LocalTask{ + Name: "GetCudaVersion", + Action: new(GetCudaVersion), + } + + g.Tasks = []task.Interface{ + getCudaVersion, + } +} diff --git a/pkg/gpu/tasks.go b/pkg/gpu/tasks.go index 1435089..f8f9f4d 100644 --- a/pkg/gpu/tasks.go +++ b/pkg/gpu/tasks.go @@ -309,3 +309,51 @@ func (t *InstallGPUShared) Execute(runtime connector.Runtime) error { return nil } + +type GetCudaVersion struct { + common.KubeAction +} + +func (g *GetCudaVersion) Execute(runtime connector.Runtime) error { + var nvidiaSmiFile string + var systemInfo = runtime.GetSystemInfo() + + switch { + case systemInfo.IsWsl(): + nvidiaSmiFile = "/usr/lib/wsl/lib/nvidia-smi" + default: + nvidiaSmiFile = "/usr/bin/nvidia-smi" + } + + if !util.IsExist(nvidiaSmiFile) { + logger.Info("nvidia-smi not exists") + return nil + } + + var cudaVersion string + res, err := runtime.GetRunner().Host.Cmd(fmt.Sprintf("%s --version", nvidiaSmiFile), false, true) + if err != nil { + logger.Errorf("get cuda version error %v", err) + return nil + } + + lines := strings.Split(res, "\n") + + if lines == nil || len(lines) == 0 { + return nil + } + for _, line := range lines { + if strings.Contains(line, "CUDA Version") { + parts := strings.Split(line, ":") + if len(parts) != 2 { + break + } + cudaVersion = strings.TrimSpace(parts[1]) + } + } + if cudaVersion != "" { + common.TerminusGlobalEnvs["CUDA_VERSION"] = cudaVersion + } + + return nil +} diff --git a/pkg/phase/cluster/linux.go b/pkg/phase/cluster/linux.go index 8266a2e..a28000a 100644 --- a/pkg/phase/cluster/linux.go +++ b/pkg/phase/cluster/linux.go @@ -41,6 +41,7 @@ func (l *linuxInstallPhaseBuilder) installGpuPlugin() phase { return []module.Module{ &gpu.RestartK3sServiceModule{Skip: !(l.runtime.Arg.Kubetype == common.K3s)}, &gpu.InstallPluginModule{Skip: skipGpuPlugin}, + &gpu.GetCudaVersionModule{}, } } diff --git a/pkg/terminus/ossystem.go b/pkg/terminus/ossystem.go index 85dab38..a02930e 100644 --- a/pkg/terminus/ossystem.go +++ b/pkg/terminus/ossystem.go @@ -1,13 +1,14 @@ package terminus import ( - "bytetrade.io/web3os/installer/pkg/core/logger" - "bytetrade.io/web3os/installer/pkg/storage" "context" "fmt" "path" "time" + "bytetrade.io/web3os/installer/pkg/core/logger" + "bytetrade.io/web3os/installer/pkg/storage" + "bytetrade.io/web3os/installer/pkg/clientset" "bytetrade.io/web3os/installer/pkg/common" cc "bytetrade.io/web3os/installer/pkg/core/common"