Skip to content

Commit

Permalink
device: Allow uevent handler to be stopped
Browse files Browse the repository at this point in the history
Pass a context to `waitForDevice()` to allow the goroutine started by
this function to be stopped. This specifically for testing
purposes.

Due to how and where `waitForDevice()` is called, this change required
adding a `context` parameter to a number of other functions, notable to
the `deviceHandler` and `storageHandler` functions.

Also moved the `timeoutHotplug` `const` variable to `device.go` (where
it is actually used) and made it a true `var` to allow the tests to
manipulate it.

Enabling the stopping of the uevent handler also allowed a lot of new
unit tests to be added.

Signed-off-by: James O. D. Hunt <james.o.hunt@intel.com>
  • Loading branch information
jodh-intel committed Jul 18, 2019
1 parent 8eb2134 commit d4a22d1
Show file tree
Hide file tree
Showing 5 changed files with 582 additions and 77 deletions.
85 changes: 49 additions & 36 deletions device.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package main

import (
"context"
"errors"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -40,10 +41,15 @@ const (
)

var (
sysBusPrefix = sysfsDir + "/bus/pci/devices"
pciBusRescanFile = sysfsDir + "/bus/pci/rescan"
pciBusPathFormat = "%s/%s/pci_bus/"
systemDevPath = "/dev"
sysBusPrefix = sysfsDir + "/bus/pci/devices"
pciBusRescanFile = sysfsDir + "/bus/pci/rescan"
pciBusPathFormat = "%s/%s/pci_bus/"
systemDevPath = "/dev"
timeoutHotplug = 3
getSCSIDevPath = getSCSIDevPathImpl
getPCIDeviceName = getPCIDeviceNameImpl
getDevicePCIAddress = getDevicePCIAddressImpl
scanSCSIBus = scanSCSIBusImpl
)

// SCSI variables
Expand All @@ -59,7 +65,7 @@ var (
scsiHostPath = filepath.Join(sysClassPrefix, "scsi_host")
)

type deviceHandler func(device pb.Device, spec *pb.Spec, s *sandbox) error
type deviceHandler func(ctx context.Context, device pb.Device, spec *pb.Spec, s *sandbox) error

var deviceHandlerList = map[string]deviceHandler{
driverMmioBlkType: virtioMmioBlkDeviceHandler,
Expand All @@ -76,7 +82,7 @@ func rescanPciBus() error {
// identifier provided. This should be in the format: "bridgeAddr/deviceAddr".
// Here, bridgeAddr is the address at which the brige is attached on the root bus,
// while deviceAddr is the address at which the device is attached on the bridge.
func getDevicePCIAddress(pciID string) (string, error) {
func getDevicePCIAddressImpl(pciID string) (string, error) {
tokens := strings.Split(pciID, "/")

if len(tokens) != 2 {
Expand Down Expand Up @@ -115,7 +121,7 @@ func getDevicePCIAddress(pciID string) (string, error) {
return bridgeDevicePCIAddr, nil
}

func getPCIDeviceName(s *sandbox, pciID string) (string, error) {
func getPCIDeviceNameImpl(s *sandbox, pciID string) (string, error) {
pciAddr, err := getDevicePCIAddress(pciID)
if err != nil {
return "", err
Expand Down Expand Up @@ -174,7 +180,7 @@ func getPCIDeviceName(s *sandbox, pciID string) (string, error) {

// device.Id should be the predicted device name (vda, vdb, ...)
// device.VmPath already provides a way to send it in
func virtioMmioBlkDeviceHandler(device pb.Device, spec *pb.Spec, s *sandbox) error {
func virtioMmioBlkDeviceHandler(_ context.Context, device pb.Device, spec *pb.Spec, s *sandbox) error {
if device.VmPath == "" {
return fmt.Errorf("Invalid path for virtioMmioBlkDevice")
}
Expand All @@ -185,7 +191,7 @@ func virtioMmioBlkDeviceHandler(device pb.Device, spec *pb.Spec, s *sandbox) err
// device.Id should be the PCI address in the format "bridgeAddr/deviceAddr".
// Here, bridgeAddr is the address at which the brige is attached on the root bus,
// while deviceAddr is the address at which the device is attached on the bridge.
func virtioBlkDeviceHandler(device pb.Device, spec *pb.Spec, s *sandbox) error {
func virtioBlkDeviceHandler(_ context.Context, device pb.Device, spec *pb.Spec, s *sandbox) error {
// Get the device node path based on the PCI device address
devPath, err := getPCIDeviceName(s, device.Id)
if err != nil {
Expand All @@ -197,9 +203,9 @@ func virtioBlkDeviceHandler(device pb.Device, spec *pb.Spec, s *sandbox) error {
}

// device.Id should be the SCSI address of the disk in the format "scsiID:lunID"
func virtioSCSIDeviceHandler(device pb.Device, spec *pb.Spec, s *sandbox) error {
func virtioSCSIDeviceHandler(ctx context.Context, device pb.Device, spec *pb.Spec, s *sandbox) error {
// Retrieve the device path from SCSI address.
devPath, err := getSCSIDevPath(device.Id)
devPath, err := getSCSIDevPath(ctx, device.Id)
if err != nil {
return err
}
Expand All @@ -208,7 +214,7 @@ func virtioSCSIDeviceHandler(device pb.Device, spec *pb.Spec, s *sandbox) error
return updateSpecDeviceList(device, spec)
}

func nvdimmDeviceHandler(device pb.Device, spec *pb.Spec, s *sandbox) error {
func nvdimmDeviceHandler(_ context.Context, device pb.Device, spec *pb.Spec, s *sandbox) error {
return updateSpecDeviceList(device, spec)
}

Expand Down Expand Up @@ -288,7 +294,7 @@ func updateSpecDeviceList(device pb.Device, spec *pb.Spec) error {

type checkUeventCb func(uEv *uevent.Uevent) bool

func waitForDevice(devicePath, deviceName string, checkUevent checkUeventCb) error {
func waitForDevice(ctx context.Context, devicePath, deviceName string, checkUevent checkUeventCb) error {
if devicePath == "" {
return errors.New("need device path")
}
Expand Down Expand Up @@ -324,25 +330,32 @@ func waitForDevice(devicePath, deviceName string, checkUevent checkUeventCb) err
// This loop will be either ended if the hotplugged device is
// found by listening to the netlink socket, or it will end
// after the function returns and the uevent handler is closed.
outer:
for {
uEv, err := uEvHandler.Read()
if err != nil {
fieldLogger.Error(err)
continue
}
select {
case <-ctx.Done():
break
default:

uEv, err := uEvHandler.Read()
if err != nil {
fieldLogger.Error(err)
continue
}

fieldLogger = fieldLogger.WithFields(logrus.Fields{
"uevent-action": uEv.Action,
"uevent-devpath": uEv.DevPath,
"uevent-subsystem": uEv.SubSystem,
"uevent-seqnum": uEv.SeqNum,
})
fieldLogger = fieldLogger.WithFields(logrus.Fields{
"uevent-action": uEv.Action,
"uevent-devpath": uEv.DevPath,
"uevent-subsystem": uEv.SubSystem,
"uevent-seqnum": uEv.SeqNum,
})

fieldLogger.Info("Got uevent")
fieldLogger.Info("Got uevent")

if checkUevent(uEv) {
fieldLogger.Info("Hotplug event received")
break
if checkUevent(uEv) {
fieldLogger.Info("Hotplug event received")
break outer
}
}
}

Expand All @@ -361,7 +374,7 @@ func waitForDevice(devicePath, deviceName string, checkUevent checkUeventCb) err
}

// scanSCSIBus scans SCSI bus for the given SCSI address(SCSI-Id and LUN)
func scanSCSIBus(scsiAddr string) error {
func scanSCSIBusImpl(scsiAddr string) error {
files, err := ioutil.ReadDir(scsiHostPath)
if err != nil {
return err
Expand Down Expand Up @@ -407,10 +420,10 @@ func findSCSIDisk(scsiPath string) (string, error) {
return files[0].Name(), nil
}

// getSCSIDevPath scans SCSI bus looking for the provided SCSI address, then
// getSCSIDevPathImpl scans SCSI bus looking for the provided SCSI address, then
// it waits for the SCSI disk to become available and returns the device path
// associated with the disk.
func getSCSIDevPath(scsiAddr string) (string, error) {
func getSCSIDevPathImpl(ctx context.Context, scsiAddr string) (string, error) {
if err := scanSCSIBus(scsiAddr); err != nil {
return "", err
}
Expand All @@ -422,7 +435,7 @@ func getSCSIDevPath(scsiAddr string) (string, error) {
return (uEv.Action == "add" &&
strings.Contains(uEv.DevPath, devSubPath))
}
if err := waitForDevice(devPath, scsiAddr, checkUevent); err != nil {
if err := waitForDevice(ctx, devPath, scsiAddr, checkUevent); err != nil {
return "", err
}

Expand All @@ -434,13 +447,13 @@ func getSCSIDevPath(scsiAddr string) (string, error) {
return filepath.Join(devPrefix, scsiDiskName), nil
}

func addDevices(devices []*pb.Device, spec *pb.Spec, s *sandbox) error {
func addDevices(ctx context.Context, devices []*pb.Device, spec *pb.Spec, s *sandbox) error {
for _, device := range devices {
if device == nil {
continue
}

err := addDevice(device, spec, s)
err := addDevice(ctx, device, spec, s)
if err != nil {
return err
}
Expand All @@ -450,7 +463,7 @@ func addDevices(devices []*pb.Device, spec *pb.Spec, s *sandbox) error {
return nil
}

func addDevice(device *pb.Device, spec *pb.Spec, s *sandbox) error {
func addDevice(ctx context.Context, device *pb.Device, spec *pb.Spec, s *sandbox) error {
if device == nil {
return grpcStatus.Error(codes.InvalidArgument, "invalid device")
}
Expand Down Expand Up @@ -490,5 +503,5 @@ func addDevice(device *pb.Device, spec *pb.Spec, s *sandbox) error {
"Unknown device type %q", device.Type)
}

return devHandler(*device, spec, s)
return devHandler(ctx, *device, spec, s)
}
Loading

0 comments on commit d4a22d1

Please sign in to comment.