Skip to content

Commit

Permalink
Merge pull request #59 from cdesiniotis/handle-kubelet-restarts
Browse files Browse the repository at this point in the history
Handle kubelet restarts
  • Loading branch information
rthallisey authored Sep 23, 2022
2 parents 186fe57 + b4eeae7 commit f7b0769
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 21 deletions.
62 changes: 50 additions & 12 deletions pkg/device_plugin/generic_device_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ type GenericDevicePlugin struct {
devs []*pluginapi.Device
server *grpc.Server
socketPath string
stop chan struct{}
stop chan struct{} // this channel signals to stop the DP
term chan bool // this channel detects kubelet restarts
healthy chan string
unhealthy chan string
devicePath string
Expand All @@ -74,6 +75,7 @@ func NewGenericDevicePlugin(deviceName string, devicePath string, devices []*plu
dpi := &GenericDevicePlugin{
devs: devices,
socketPath: serverSock,
term: make(chan bool, 1),
healthy: make(chan string),
unhealthy: make(chan string),
deviceName: deviceName,
Expand Down Expand Up @@ -166,12 +168,29 @@ func (dpi *GenericDevicePlugin) Stop() error {
return nil
}

// Send terminate signal to ListAndWatch()
dpi.term <- true

dpi.server.Stop()
dpi.server = nil

return dpi.cleanup()
}

// Restarts DP server
func (dpi *GenericDevicePlugin) restart() error {
log.Printf("Restarting %s device plugin server", dpi.deviceName)
if dpi.server == nil {
return fmt.Errorf("grpc server instance not found for %s", dpi.deviceName)
}

dpi.Stop()

// Create new instance of a grpc server
var stop = make(chan struct{})
return dpi.Start(stop)
}

// Register registers the device plugin for the given resourceName with Kubelet.
func (dpi *GenericDevicePlugin) Register() error {
conn, err := connect(pluginapi.KubeletSocket, connectionTimeout)
Expand Down Expand Up @@ -219,6 +238,8 @@ func (dpi *GenericDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.Dev
s.Send(&pluginapi.ListAndWatchResponse{Devices: dpi.devs})
case <-dpi.stop:
return nil
case <-dpi.term:
return nil
}
}
}
Expand Down Expand Up @@ -315,32 +336,40 @@ func (dpi *GenericDevicePlugin) GetPreferredAllocation(ctx context.Context, in *

//Health check of GPU devices
func (dpi *GenericDevicePlugin) healthCheck() error {
method := fmt.Sprintf("healthCheck(%s)", dpi.deviceName)
log.Printf("%s: invoked", method)
var pathDeviceMap = make(map[string]string)
log.Printf("In health check")
var path = dpi.devicePath
var health = ""

watcher, err := fsnotify.NewWatcher()
if err != nil {
log.Printf("Unable to create fsnotify watcher: %v", err)
return nil
log.Printf("%s: Unable to create fsnotify watcher: %v", method, err)
return err
}
defer watcher.Close()

err = watcher.Add(filepath.Dir(dpi.socketPath))
if err != nil {
log.Printf("%s: Unable to add device plugin socket path to fsnotify watcher: %v", method, err)
return err
}

_, err = os.Stat(path)
if err != nil {
if !os.IsNotExist(err) {
log.Printf("Unable to stat device: %v", err)
log.Printf("%s: Unable to stat device: %v", method, err)
return err
}
}

for _, dev := range dpi.devs {
log.Printf("Path %s", path)
log.Printf("Dev %s", dev.ID)
err = watcher.Add(filepath.Join(path, dev.ID))
pathDeviceMap[filepath.Join(path, dev.ID)] = dev.ID
devicePath := filepath.Join(path, dev.ID)
err = watcher.Add(devicePath)
pathDeviceMap[devicePath] = dev.ID
if err != nil {
log.Printf("Unable to add path to fsnotify watcher: %v", err)
log.Printf("%s: Unable to add device path to fsnotify watcher: %v", method, err)
return err
}
}

Expand All @@ -349,18 +378,27 @@ func (dpi *GenericDevicePlugin) healthCheck() error {
case <-dpi.stop:
return nil
case event := <-watcher.Events:
log.Printf("health Event Op: %v", event.Op)
log.Printf("health Event Name: %s", event.Name)
v, ok := pathDeviceMap[event.Name]
if ok {
// Health in this case is if the device path actually exists
if event.Op == fsnotify.Create {
health = v
dpi.healthy <- health
} else if (event.Op == fsnotify.Remove) || (event.Op == fsnotify.Rename) {
log.Printf("%s: Marking device unhealthy: %s", method, event.Name)
health = v
dpi.unhealthy <- health
}
} else if event.Name == dpi.socketPath && event.Op == fsnotify.Remove {
// Watcher event for removal of socket file
log.Printf("%s: Socket path for GPU device was removed, kubelet likely restarted", method)
// Trigger restart of the DP servers
if err := dpi.restart(); err != nil {
log.Printf("%s: Unable to restart server %v", method, err)
return err
}
log.Printf("%s: Successfully restarted %s device plugin server. Terminating.", method, dpi.deviceName)
return nil
}
}
}
Expand Down
100 changes: 91 additions & 9 deletions pkg/device_plugin/generic_vgpu_device_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ import (
"net"
"os"
"path"
"path/filepath"
"strings"

"github.com/NVIDIA/gpu-monitoring-tools/bindings/go/nvml"
"github.com/fsnotify/fsnotify"
"google.golang.org/grpc"
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
)
Expand All @@ -54,7 +56,8 @@ type GenericVGpuDevicePlugin struct {
devs []*pluginapi.Device
server *grpc.Server
socketPath string
stop chan struct{}
stop chan struct{} // this channel signals to stop the DP
term chan bool // this channel detects kubelet restarts
healthy chan string
unhealthy chan string
devicePath string
Expand All @@ -68,6 +71,7 @@ func NewGenericVGpuDevicePlugin(deviceName string, devicePath string, devices []
dpi := &GenericVGpuDevicePlugin{
devs: devices,
socketPath: serverSock,
term: make(chan bool, 1),
healthy: make(chan string),
unhealthy: make(chan string),
deviceName: deviceName,
Expand Down Expand Up @@ -125,12 +129,29 @@ func (dpi *GenericVGpuDevicePlugin) Stop() error {
return nil
}

// Send terminate signal to ListAndWatch()
dpi.term <- true

dpi.server.Stop()
dpi.server = nil

return dpi.cleanup()
}

// Restarts DP server
func (dpi *GenericVGpuDevicePlugin) restart() error {
log.Printf("Restarting %s device plugin server", dpi.deviceName)
if dpi.server == nil {
return fmt.Errorf("grpc server instance not found for %s", dpi.deviceName)
}

dpi.Stop()

// Create new instance of a grpc server
var stop = make(chan struct{})
return dpi.Start(stop)
}

// Register registers the device plugin for the given resourceName with Kubelet.
func (dpi *GenericVGpuDevicePlugin) Register() error {
conn, err := connect(pluginapi.KubeletSocket, connectionTimeout)
Expand Down Expand Up @@ -178,6 +199,8 @@ func (dpi *GenericVGpuDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi
s.Send(&pluginapi.ListAndWatchResponse{Devices: dpi.devs})
case <-dpi.stop:
return nil
case <-dpi.term:
return nil
}
}
}
Expand Down Expand Up @@ -256,18 +279,54 @@ func (dpi *GenericVGpuDevicePlugin) GetPreferredAllocation(ctx context.Context,

//Health check of vGPU devices
func (dpi *GenericVGpuDevicePlugin) healthCheck() error {
log.Printf("In health check")
log.Println("Loading NVML")
if err := nvmlInit(); err != nil {
log.Printf("Failed to initialize NVML: %s.", err)
method := fmt.Sprintf("healthCheck(%s)", dpi.deviceName)
log.Printf("%s: invoked", method)
var xids chan *nvml.Device
var pathDeviceMap = make(map[string]string)
var path = dpi.devicePath
var health = ""

log.Printf("%s: Loading NVML", method)
if err := nvmlInit(); err == nil {
defer func() { log.Printf("%s: Shutdown of NVML returned: %v", method, nvmlShutdown()) }()
devs := getDevices()
xids = make(chan *nvml.Device)
go watchXIDs(devs, xids)
} else {
log.Printf("%s: Failed to initialize NVML: %s", method, err)
}

watcher, err := fsnotify.NewWatcher()
if err != nil {
log.Printf("%s: Unable to create fsnotify watcher: %v", method, err)
return err
}
defer watcher.Close()

err = watcher.Add(filepath.Dir(dpi.socketPath))
if err != nil {
log.Printf("%s: Unable to add device plugin socket path to fsnotify watcher: %v", method, err)
return err
}

defer func() { log.Println("Shutdown of NVML returned:", nvmlShutdown()) }()
devs := getDevices()
_, err = os.Stat(path)
if err != nil {
if !os.IsNotExist(err) {
log.Printf("%s: Unable to stat device: %v", method, err)
return err
}
}

xids := make(chan *nvml.Device)
go watchXIDs(devs, xids)
for _, dev := range dpi.devs {
devicePath := filepath.Join(path, dev.ID)
log.Printf("%s: Adding watch for device path: %s", method, devicePath)
err = watcher.Add(devicePath)
pathDeviceMap[devicePath] = dev.ID
if err != nil {
log.Printf("%s: Unable to add device path to fsnotify watcher: %v", method, err)
return err
}
}

for {
select {
Expand All @@ -279,6 +338,29 @@ func (dpi *GenericVGpuDevicePlugin) healthCheck() error {
for _, id := range vGpuIDList {
dpi.unhealthy <- id
}
case event := <-watcher.Events:
v, ok := pathDeviceMap[event.Name]
if ok {
// Health in this case is if the device path actually exists
if event.Op == fsnotify.Create {
health = v
dpi.healthy <- health
} else if (event.Op == fsnotify.Remove) || (event.Op == fsnotify.Rename) {
log.Printf("%s: Marking device unhealthy: %s", method, event.Name)
health = v
dpi.unhealthy <- health
}
} else if event.Name == dpi.socketPath && event.Op == fsnotify.Remove {
// Watcher event for removal of socket file
log.Printf("%s: Socket path for GPU device was removed, kubelet likely restarted", method)
// Trigger restart of the DP servers
if err := dpi.restart(); err != nil {
log.Printf("%s: Unable to restart server %v", method, err)
return err
}
log.Printf("%s: Successfully restarted %s device plugin server. Terminating.", method, dpi.deviceName)
return nil
}
}
}

Expand Down

0 comments on commit f7b0769

Please sign in to comment.