Skip to content

Commit

Permalink
Add a health server for the liveness probe and support for an isolated-devices file.
Browse files Browse the repository at this point in the history
Signed-off-by: huozhixin.hzx <huozhixin.hzx@alibaba-inc.com>
  • Loading branch information
huozhixin.hzx committed Mar 29, 2024
1 parent 21ca876 commit ad934c1
Show file tree
Hide file tree
Showing 15 changed files with 1,212 additions and 5 deletions.
26 changes: 26 additions & 0 deletions cmd/nvidia-device-plugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/NVIDIA/k8s-device-plugin/internal/resource"
"os"
"strconv"
"syscall"
"time"

Expand All @@ -38,6 +40,8 @@ import (
"github.com/NVIDIA/k8s-device-plugin/internal/watch"
)

var healthyCheckPortFlag string

func main() {
var configFile string

Expand Down Expand Up @@ -125,6 +129,13 @@ func main() {
Usage: "the path on the host where MPS-specific mounts and files are created by the MPS control daemon manager",
EnvVars: []string{"MPS_ROOT"},
},
&cli.StringFlag{
Name: "healthy-check-port",
Value: resource.HealthyServerPort,
Usage: "the healthy check server port of nvidia device plugin",
Destination: &healthyCheckPortFlag,
EnvVars: []string{"HEALTHY_CHECK_PORT"},
},
}

err := c.Run(os.Args)
Expand All @@ -149,6 +160,10 @@ func validateFlags(config *spec.Config) error {
return fmt.Errorf("invalid --device-id-strategy option: %v", *config.Flags.Plugin.DeviceIDStrategy)
}

if _, err := strconv.Atoi(healthyCheckPortFlag); err != nil {
return fmt.Errorf("invalid healthy-check-port option: %v", healthyCheckPortFlag)
}

if config.Sharing.SharingStrategy() == spec.SharingStrategyMPS {
if *config.Flags.MigStrategy == spec.MigStrategyMixed {
return fmt.Errorf("using --mig-strategy=mixed is not supported with MPS")
Expand Down Expand Up @@ -185,6 +200,17 @@ func start(c *cli.Context, flags []cli.Flag) error {
klog.Info("Starting OS watcher.")
sigs := watch.Signals(syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)

healthServer, err := resource.NewHealthServer(healthyCheckPortFlag)
if err != nil {
return fmt.Errorf("failed to start health server: %v", err)
}
go func() {
klog.Info("Starting health server.")
if err := healthServer.Serve(); err != nil {
klog.Infof("Health server error: %v", err)
}
}()

var started bool
var restartTimeout <-chan time.Time
var plugins []plugin.Interface
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ require (
github.com/cpuguy83/go-md2man/v2 v2.0.3 // indirect
github.com/cyphar/filepath-securejoin v0.2.4 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/deckarep/golang-set v1.8.0 // indirect
github.com/distribution/reference v0.5.0 // indirect
github.com/docker/cli v25.0.3+incompatible // indirect
github.com/docker/distribution v2.8.3+incompatible // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ github.com/cyphar/filepath-securejoin v0.2.4/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxG
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/deckarep/golang-set v1.8.0 h1:sk9/l/KqpunDwP7pSjUg0keiOOLEnOBHzykLrsPppp4=
github.com/deckarep/golang-set v1.8.0/go.mod h1:5nI87KwE7wgsBU1F4GKAw2Qod7p5kyS383rP6+o6qqo=
github.com/distribution/distribution/v3 v3.0.0-20221208165359-362910506bc2 h1:aBfCb7iqHmDEIp6fBvC/hQUddQfg+3qdYjwzaiP9Hnc=
github.com/distribution/distribution/v3 v3.0.0-20221208165359-362910506bc2/go.mod h1:WHNsWjnIn2V1LYOrME7e8KxSeKunYHsxEm4am0BUtcI=
github.com/distribution/reference v0.5.0 h1:/FUIFXtfc/x2gpa5/VGfiGLuOIdYa1t65IKK2OFGvA0=
Expand Down
31 changes: 27 additions & 4 deletions internal/plugin/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ package plugin
import (
"errors"
"fmt"
"github.com/NVIDIA/k8s-device-plugin/internal/resource"
"github.com/NVIDIA/k8s-device-plugin/internal/watch"
"github.com/fsnotify/fsnotify"
"log"
"net"
"os"
"path"
Expand Down Expand Up @@ -66,6 +70,8 @@ type NvidiaDevicePlugin struct {

mpsDaemon *mps.Daemon
mpsHostRoot mps.Root

watcher *fsnotify.Watcher
}

// NewNvidiaDevicePlugin returns an initialized NvidiaDevicePlugin
Expand Down Expand Up @@ -104,9 +110,10 @@ func NewNvidiaDevicePlugin(config *spec.Config, resourceManager rm.ResourceManag

// These will be reinitialized every
// time the plugin server is restarted.
server: nil,
health: nil,
stop: nil,
server: nil,
health: nil,
stop: nil,
watcher: nil,
}
return &plugin, nil
}
Expand All @@ -115,13 +122,20 @@ func (plugin *NvidiaDevicePlugin) initialize() {
plugin.server = grpc.NewServer([]grpc.ServerOption{}...)
plugin.health = make(chan *rm.Device)
plugin.stop = make(chan interface{})
fsWatcher, err := watch.Files(resource.DevicePluginConfigPath)
if err != nil {
log.Println("failed to create file system watcher.")
return
}
plugin.watcher = fsWatcher
}

// cleanup releases the per-run state created by initialize: it closes the
// stop channel, drops the gRPC server and health channel references, and
// closes the file system watcher.
func (plugin *NvidiaDevicePlugin) cleanup() {
	close(plugin.stop)
	plugin.server = nil
	plugin.health = nil
	plugin.stop = nil

	// initialize returns early without setting the watcher when watch.Files
	// fails, so the watcher may legitimately be nil here.
	if plugin.watcher == nil {
		return
	}
	if err := plugin.watcher.Close(); err != nil {
		log.Printf("failed to close file system watcher: %v", err)
	}
}

// Devices returns the full set of devices associated with the plugin.
Expand Down Expand Up @@ -278,7 +292,7 @@ func (plugin *NvidiaDevicePlugin) GetDevicePluginOptions(context.Context, *plugi

// ListAndWatch lists devices and update that list according to the health status
func (plugin *NvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.DevicePlugin_ListAndWatchServer) error {
if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.apiDevices()}); err != nil {
if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.healthyDevices()}); err != nil {
return err
}

Expand All @@ -293,6 +307,11 @@ func (plugin *NvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.D
if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.apiDevices()}); err != nil {
return nil
}
case _ = <-plugin.watcher.Events:
klog.Infof("find file %s changed, start resubmit devices", resource.IsolatedDevicesFilePath)
s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.healthyDevices()})
case err := <-plugin.watcher.Errors:
klog.Infof("inotify: %s", err)
}
}
}
Expand Down Expand Up @@ -483,6 +502,10 @@ func (plugin *NvidiaDevicePlugin) apiDevices() []*pluginapi.Device {
return plugin.rm.Devices().GetPluginDevices()
}

// healthyDevices returns the plugin API devices of the resource manager's
// device set as filtered by Devices.GetHealthyDevice.
func (plugin *NvidiaDevicePlugin) healthyDevices() []*pluginapi.Device {
	devices := plugin.rm.Devices()
	return devices.GetHealthyDevice()
}

// updateResponseForDeviceListEnvvar sets the environment variable for the requested devices.
func (plugin *NvidiaDevicePlugin) updateResponseForDeviceListEnvvar(response *pluginapi.ContainerAllocateResponse, deviceIDs ...string) {
response.Envs[plugin.deviceListEnvvar] = strings.Join(deviceIDs, ",")
Expand Down
85 changes: 85 additions & 0 deletions internal/resource/health.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package resource

import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net/http"
"os"
"strconv"
"time"
)

const (
	// DevicePluginConfigPath is the directory that holds the device plugin's
	// configuration files. It keeps its trailing slash so file names can be
	// appended directly.
	DevicePluginConfigPath = "/etc/nvidia-device-plugin/"
	// IsolatedDevicesFilePath is the JSON file listing devices to isolate.
	// It is derived from DevicePluginConfigPath so the two cannot drift apart.
	IsolatedDevicesFilePath = DevicePluginConfigPath + "unhealthyDevices.json"
	// HealthyServerPort is the default TCP port of the health server, kept as
	// a string because it is used as a string flag default.
	HealthyServerPort = "7123"
)

// UnhealthyDevices is the decoded form of the isolated-devices JSON file.
// A device may be listed by index, by UUID, or both.
type UnhealthyDevices struct {
	// GPUIndex holds device indices, read from the "index" JSON key.
	GPUIndex []string `json:"index"`
	// GPUUuid holds device UUIDs, read from the "uuid" JSON key.
	GPUUuid []string `json:"uuid"`
}

// HealthServer is a small HTTP server that answers GET-style probes on
// /health with 200 OK.
type HealthServer struct {
	httpServer *http.Server
	mux        *http.ServeMux
}

// NewHealthServer builds a HealthServer listening on all interfaces at the
// given port. portString must be a decimal integer in [1, 65535]; otherwise
// an error is returned and no server is created. The server is not started
// until Serve is called.
func NewHealthServer(portString string) (*HealthServer, error) {
	// Bound every phase of a request so a stalled client cannot hold a
	// connection to the probe endpoint open indefinitely.
	const timeout = 5 * time.Second

	port, err := strconv.Atoi(portString)
	if err != nil {
		log.Println("Port set for health server is invalid.")
		return nil, fmt.Errorf("parsing health server port %q: %w", portString, err)
	}
	if port < 1 || port > 65535 {
		return nil, fmt.Errorf("port set for health server is invalid, it should be in [1, 65535], got %d", port)
	}

	healthServer := &HealthServer{
		httpServer: &http.Server{
			Addr:              fmt.Sprintf(":%v", port),
			ReadHeaderTimeout: timeout,
			ReadTimeout:       timeout,
			WriteTimeout:      timeout,
			IdleTimeout:       timeout,
		},
		mux: http.NewServeMux(),
	}
	healthServer.init()

	return healthServer, nil
}

// init registers the /health route and installs the mux as the server's
// handler.
func (h *HealthServer) init() {
	h.mux.HandleFunc("/health", h.serveHealthyHandler)
	h.httpServer.Handler = h.mux
}

// Serve starts listening and blocks until the underlying http.Server stops,
// returning the error reported by ListenAndServe.
func (h *HealthServer) Serve() error {
	return h.httpServer.ListenAndServe()
}

// serveHealthyHandler replies 200 OK with an empty body.
func (h *HealthServer) serveHealthyHandler(w http.ResponseWriter, _ *http.Request) {
	w.WriteHeader(http.StatusOK)
}

// FindUnhealthyDevices loads the isolated-devices list from
// IsolatedDevicesFilePath.
//
// It returns (nil, nil) when the file does not exist, meaning no isolation is
// configured. Any other stat, read or decode failure is returned as an error
// that wraps the underlying cause.
//
// NOTE: when the file exists, the call blocks for three seconds before
// reading it (see below).
func FindUnhealthyDevices() (*UnhealthyDevices, error) {
	if _, err := os.Stat(IsolatedDevicesFilePath); err != nil {
		if os.IsNotExist(err) {
			return nil, nil
		}
		return nil, fmt.Errorf("failed to stat file %s: %w", IsolatedDevicesFilePath, err)
	}

	// To wait for write file: gives a writer time to finish before the
	// content is read.
	time.Sleep(3 * time.Second)

	jsonData, err := ioutil.ReadFile(IsolatedDevicesFilePath)
	if err != nil {
		// The file may have been removed between the stat and the read;
		// treat that the same as it never having existed.
		if os.IsNotExist(err) {
			return nil, nil
		}
		return nil, fmt.Errorf("failed to read file %s: %w", IsolatedDevicesFilePath, err)
	}

	var unhealthyDevices UnhealthyDevices
	if err := json.Unmarshal(jsonData, &unhealthyDevices); err != nil {
		return nil, fmt.Errorf("failed to unmarshal json file %s: %w", IsolatedDevicesFilePath, err)
	}

	return &unhealthyDevices, nil
}
22 changes: 22 additions & 0 deletions internal/rm/devices.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package rm

import (
"fmt"
"github.com/NVIDIA/k8s-device-plugin/internal/resource"
set "github.com/deckarep/golang-set"
"strconv"
"strings"

Expand Down Expand Up @@ -184,6 +186,26 @@ func (ds Devices) GetPluginDevices() []*pluginapi.Device {
return res
}

// GetHealthyDevice returns the plugin API devices in ds, leaving out every
// device whose ID or index is listed by resource.FindUnhealthyDevices.
// When that lookup fails or reports no list, all devices are returned.
func (ds Devices) GetHealthyDevice() []*pluginapi.Device {
	isolated, err := resource.FindUnhealthyDevices()
	filtering := err == nil && isolated != nil

	var isolatedByIndex, isolatedByUUID set.Set
	if filtering {
		isolatedByIndex = set.NewSetFromSlice(StringSliceToInterfaceSlice(isolated.GPUIndex))
		isolatedByUUID = set.NewSetFromSlice(StringSliceToInterfaceSlice(isolated.GPUUuid))
	}

	var healthy []*pluginapi.Device
	for _, dev := range ds {
		// A device is dropped when it is listed by UUID or by index.
		if filtering && (isolatedByUUID.Contains(dev.GetID()) || isolatedByIndex.Contains(dev.Index)) {
			continue
		}
		healthy = append(healthy, &dev.Device)
	}
	return healthy
}

// GetIndices returns the Indices from all devices in the Devices
func (ds Devices) GetIndices() []string {
var res []string
Expand Down
8 changes: 8 additions & 0 deletions internal/rm/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,11 @@ func (s int8Slice) String() string {
}
return string(b)
}

// StringSliceToInterfaceSlice copies slice into a new []interface{} of the
// same length and order. The result is always non-nil, even for a nil or
// empty input.
func StringSliceToInterfaceSlice(slice []string) []interface{} {
	converted := make([]interface{}, 0, len(slice))
	for _, s := range slice {
		converted = append(converted, s)
	}
	return converted
}
22 changes: 22 additions & 0 deletions vendor/github.com/deckarep/golang-set/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 22 additions & 0 deletions vendor/github.com/deckarep/golang-set/LICENSE

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit ad934c1

Please sign in to comment.