Skip to content
This repository has been archived by the owner on May 12, 2021. It is now read-only.

Use PCI Addresses to determine the device names for virtio-blk devices #227

Merged
merged 4 commits into from
May 3, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"os"
"os/exec"
"os/signal"
"path/filepath"
"runtime"
"runtime/debug"
"strings"
Expand All @@ -23,6 +24,7 @@ import (
"time"

"github.com/gogo/protobuf/proto"
"github.com/kata-containers/agent/pkg/uevent"
pb "github.com/kata-containers/agent/protocols/grpc"
"github.com/opencontainers/runc/libcontainer"
"github.com/opencontainers/runc/libcontainer/configs"
Expand Down Expand Up @@ -74,6 +76,8 @@ type sandbox struct {
mounts []string
subreaper reaper
server *grpc.Server
pciDeviceMap map[string]string
deviceWatchers map[string](chan string)
}

type namespace struct {
Expand Down Expand Up @@ -305,6 +309,55 @@ func (s *sandbox) teardownSharedPidNs() error {
return nil
}

// listenToUdevEvents runs the sandbox-wide uevent listening loop. It only
// returns if the uevent handler cannot be created; otherwise it loops forever
// and is therefore meant to be run in its own goroutine.
//
// For every "add" uevent that carries a device node name and whose device
// path lies under rootBusPath, it records the devpath -> node name mapping in
// s.pciDeviceMap and delivers the node name to every watcher registered in
// s.deviceWatchers whose PCI address (joined onto rootBusPath) is a prefix of
// the event's device path. A notified watcher's channel is closed and the
// watcher is removed from the map.
func (s *sandbox) listenToUdevEvents() {
	baseLogger := agentLog.WithField("subsystem", "udevlistener")

	uEvHandler, err := uevent.NewHandler()
	if err != nil {
		baseLogger.Warnf("Error starting uevent listening loop %s", err)
		return
	}
	defer uEvHandler.Close()

	for {
		uEv, err := uEvHandler.Read()
		if err != nil {
			baseLogger.Error(err)
			continue
		}

		// Derive the per-event logger from the base logger on every iteration
		// so that fields of a previous event never leak into later log lines
		// (in particular into the read-error log above).
		evLogger := baseLogger.WithFields(logrus.Fields{
			"uevent-action":    uEv.Action,
			"uevent-devpath":   uEv.DevPath,
			"uevent-subsystem": uEv.SubSystem,
			"uevent-seqnum":    uEv.SeqNum,
			"uevent-devname":   uEv.DevName,
		})

		// Only hotplug events that result in a device node being created
		// under the root bus are of interest.
		if uEv.DevName == "" || uEv.Action != "add" || !strings.HasPrefix(uEv.DevPath, rootBusPath) {
			continue
		}

		// Lock is needed to safely read and modify the pciDeviceMap and deviceWatchers.
		// This makes sure that watchers do not access the maps while they are being updated.
		s.Lock()

		// Add the device node name to the pci device map.
		s.pciDeviceMap[uEv.DevPath] = uEv.DevName

		// Notify watchers that are interested in the udev event.
		for devPCIAddress, ch := range s.deviceWatchers {
			if ch == nil || !strings.HasPrefix(uEv.DevPath, filepath.Join(rootBusPath, devPCIAddress)) {
				continue
			}

			evLogger.WithField("watched-pci-address", devPCIAddress).Debug("Notifying device watcher")

			// getBlockDeviceNodeName registers its channel with a buffer of
			// one, so this send does not block while the lock is held.
			ch <- uEv.DevName
			close(ch)

			// The map is keyed by PCI address, not by device node name. The
			// entry must be removed under that key, otherwise the closed
			// channel stays registered and the next matching event panics
			// with a send on a closed channel.
			delete(s.deviceWatchers, devPCIAddress)
		}

		s.Unlock()
	}
}

// This loop is meant to be run inside a separate Go routine.
func (s *sandbox) reaperLoop(sigCh chan os.Signal) {
for sig := range sigCh {
Expand Down Expand Up @@ -643,8 +696,10 @@ func main() {
running: false,
// pivot_root won't work for init, see
// Documention/filesystem/ramfs-rootfs-initramfs.txt
noPivotRoot: os.Getpid() == 1,
subreaper: r,
noPivotRoot: os.Getpid() == 1,
subreaper: r,
pciDeviceMap: make(map[string]string),
deviceWatchers: make(map[string](chan string)),
}

if err = s.initLogger(); err != nil {
Expand All @@ -665,5 +720,7 @@ func main() {
// Start gRPC server.
s.startGRPC()

go s.listenToUdevEvents()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please let s.wg.Wait() also wait for the udev listener goroutine.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The udev listener routine does not quit for any reason, that's why I have not added the wait for it.
It should quit when the main thread quits.


s.wg.Wait()
}
132 changes: 116 additions & 16 deletions device.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ const (
driverSCSIType = "scsi"
)

// rootBusPath is the prefix matched against uevent device paths to select
// events for devices located under the root PCI bus (pci0000:00). Watched PCI
// addresses are joined onto it when matching hotplug events.
const rootBusPath = "/devices/pci0000:00"

var (
	// sysBusPrefix is the sysfs directory under which PCI devices are looked
	// up by address. It is a variable rather than a constant so that unit
	// tests can point it at a temporary directory.
	sysBusPrefix = "/sys/bus/pci/devices"

	// pciBusPathFormat is the format used to build the directory listing the
	// bus exposed by a bridge; its arguments are sysBusPrefix and the
	// bridge's complete PCI address.
	pciBusPathFormat = "%s/%s/pci_bus/"

	// systemDevPath is the directory joined with a device node name to form
	// the device node path handed back to callers.
	systemDevPath = "/dev"
)

// SCSI variables
var (
// Here in "0:0", the first number is the SCSI host number because
Expand All @@ -42,31 +50,123 @@ var (
scsiHostPath = filepath.Join(sysClassPrefix, "scsi_host")
)

// NOTE(review): the declaration directly below looks like the removed side of
// a diff line (two-argument form); the handlers visible in this file all use
// the three-argument form that follows — confirm and drop the stale one.
type deviceHandler func(device pb.Device, spec *pb.Spec) error

// deviceHandler is the signature shared by the per-driver device handlers: a
// handler receives the device (by value), the spec and the sandbox, and
// returns an error on failure.
type deviceHandler func(device pb.Device, spec *pb.Spec, s *sandbox) error

// deviceHandlerList associates each supported driver type with the handler
// responsible for devices of that type.
// NOTE(review): presumably looked up by addDevice using device.Type — the
// lookup itself is not visible here, confirm against addDevice.
var deviceHandlerList = map[string]deviceHandler{
	driverBlkType:  virtioBlkDeviceHandler,
	driverSCSIType: virtioSCSIDeviceHandler,
}

func virtioBlkDeviceHandler(device pb.Device, spec *pb.Spec) error {
// First need to make sure the expected device shows up properly,
// and then we need to retrieve its device info (such as major and
// minor numbers), useful to update the device provided
// through the OCI specification.
devName := strings.TrimPrefix(device.VmPath, devPrefix)
checkUevent := func(uEv *uevent.Uevent) bool {
return (uEv.Action == "add" &&
filepath.Base(uEv.DevPath) == devName)
// getDevicePCIAddress fetches the complete PCI address in sysfs, based on the PCI
// identifier provided. This should be in the format: "bridgeAddr/deviceAddr".
// Here, bridgeAddr is the address at which the bridge is attached on the root bus,
// while deviceAddr is the address at which the device is attached on the bridge.
//
// The returned string has the form "<bridge PCI address>/<device PCI address>",
// e.g. "0000:00:02.0/0000:01:03.0" for the identifier "02/03" when the bridge
// exposes bus "0000:01".
//
// An error is returned when pciID does not contain exactly two "/"-separated
// parts, when the bridge's pci_bus directory cannot be read, or when that
// directory does not contain exactly one entry.
func getDevicePCIAddress(pciID string) (string, error) {
	tokens := strings.Split(pciID, "/")

	// Exactly one bridge component and one device component are expected.
	if len(tokens) != 2 {
		return "", fmt.Errorf("PCI Identifier for device should be of format [bridgeAddr/deviceAddr], got %s", pciID)
	}

	bridgeID := tokens[0]
	deviceID := tokens[1]

	// Deduce the complete bridge address based on the bridge address identifier passed
	// and the fact that bridges are attached on the main bus with function 0.
	pciBridgeAddr := fmt.Sprintf("0000:00:%s.0", bridgeID)

	// Find out the bus exposed by bridge
	bridgeBusPath := fmt.Sprintf(pciBusPathFormat, sysBusPrefix, pciBridgeAddr)

	files, err := ioutil.ReadDir(bridgeBusPath)
	if err != nil {
		return "", fmt.Errorf("Error with getting bridge pci bus : %s", err)
	}

	// The directory must hold a single entry, whose name is the bus.
	busNum := len(files)
	if busNum != 1 {
		return "", fmt.Errorf("Expected an entry for bus in %s, got %d entries instead", bridgeBusPath, busNum)
	}
	// NOTE(review): the next line references device, devName and checkUevent,
	// none of which exist in this function; it looks like a leftover from the
	// pre-change virtioBlkDeviceHandler (removed side of the diff) — confirm
	// and drop it.
	if err := waitForDevice(device.VmPath, devName, checkUevent); err != nil {

	bus := files[0].Name()

	// Device address is based on the bus of the bridge to which it is attached.
	// We do not pass devices as multifunction, hence the trailing 0 in the address.
	pciDeviceAddr := fmt.Sprintf("%s:%s.0", bus, deviceID)

	// The bridge address and the device address are reported together,
	// separated by a slash.
	bridgeDevicePCIAddr := fmt.Sprintf("%s/%s", pciBridgeAddr, pciDeviceAddr)
	agentLog.WithField("completePCIAddr", bridgeDevicePCIAddr).Info("Fetched PCI address for device")

	return bridgeDevicePCIAddr, nil
}

func getBlockDeviceNodeName(s *sandbox, pciID string) (string, error) {
pciAddr, err := getDevicePCIAddress(pciID)
if err != nil {
return "", err
}

var devName string
var notifyChan chan string

fieldLogger := agentLog.WithField("pciID", pciID)

// Check if the PCI identifier is in PCI device map.
s.Lock()
for key, value := range s.pciDeviceMap {
if strings.Contains(key, pciAddr) {
devName = value
fieldLogger.Info("Device found in pci device map")
break
}
}

// If device is not found in the device map, hotplug event has not
// been received yet, create and add channel to the watchers map.
// The key of the watchers map is the device we are interested in.
// Note this is done inside the lock, not to miss any events from the
// global udev listener.
if devName == "" {
notifyChan := make(chan string, 1)
s.deviceWatchers[pciAddr] = notifyChan
}
s.Unlock()

if devName == "" {
fieldLogger.Info("Waiting on channel for device notification")
select {
case devName = <-notifyChan:
case <-time.After(time.Duration(timeoutHotplug) * time.Second):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove the watcher and close its notification channel on timeout. Otherwise the channel may exist forever if the device never shows up.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

s.Lock()
delete(s.deviceWatchers, pciAddr)
close(notifyChan)
s.Unlock()

return "", grpcStatus.Errorf(codes.DeadlineExceeded,
"Timeout reached after %ds waiting for device %s",
timeoutHotplug, pciAddr)
}
}

return filepath.Join(systemDevPath, devName), nil
}

// virtioBlkDeviceHandler resolves the device node path of a virtio-blk device
// from its PCI address and passes the device on to updateSpecDeviceList.
//
// device.Id should be the PCI address in the format "bridgeAddr/deviceAddr",
// where bridgeAddr is the address at which the bridge is attached on the root
// bus and deviceAddr is the address at which the device is attached on the
// bridge.
func virtioBlkDeviceHandler(device pb.Device, spec *pb.Spec, s *sandbox) error {
	// Look up (or wait for) the device node matching the PCI address.
	nodePath, lookupErr := getBlockDeviceNodeName(s, device.Id)
	if lookupErr != nil {
		return lookupErr
	}

	// device is received by value, so only this local copy carries the
	// resolved path into the spec update.
	device.VmPath = nodePath
	return updateSpecDeviceList(device, spec)
}

func virtioSCSIDeviceHandler(device pb.Device, spec *pb.Spec) error {
// device.Id should be the SCSI address of the disk in the format "scsiID:lunID"
func virtioSCSIDeviceHandler(device pb.Device, spec *pb.Spec, s *sandbox) error {
// Retrieve the device path from SCSI address.
devPath, err := getSCSIDevPath(device.Id)
if err != nil {
Expand Down Expand Up @@ -270,13 +370,13 @@ func getSCSIDevPath(scsiAddr string) (string, error) {
return filepath.Join(devPrefix, scsiDiskName), nil
}

func addDevices(devices []*pb.Device, spec *pb.Spec) error {
func addDevices(devices []*pb.Device, spec *pb.Spec, s *sandbox) error {
for _, device := range devices {
if device == nil {
continue
}

err := addDevice(device, spec)
err := addDevice(device, spec, s)
if err != nil {
return err
}
Expand All @@ -286,7 +386,7 @@ func addDevices(devices []*pb.Device, spec *pb.Spec) error {
return nil
}

func addDevice(device *pb.Device, spec *pb.Spec) error {
func addDevice(device *pb.Device, spec *pb.Spec, s *sandbox) error {
if device == nil {
return grpcStatus.Error(codes.InvalidArgument, "invalid device")
}
Expand Down Expand Up @@ -326,5 +426,5 @@ func addDevice(device *pb.Device, spec *pb.Spec) error {
"Unknown device type %q", device.Type)
}

return devHandler(*device, spec)
return devHandler(*device, spec, s)
}
57 changes: 51 additions & 6 deletions device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func testVirtioBlkDeviceHandlerFailure(t *testing.T, device pb.Device, spec *pb.
device.VmPath = devPath
device.ContainerPath = "some-not-empty-path"

err = virtioBlkDeviceHandler(device, spec)
err = virtioBlkDeviceHandler(device, spec, &sandbox{})
assert.NotNil(t, err, "blockDeviceHandler() should have failed")
}

Expand Down Expand Up @@ -73,6 +73,49 @@ func TestVirtioBlkDeviceHandlerEmptyLinuxDevicesSpecFailure(t *testing.T) {
testVirtioBlkDeviceHandlerFailure(t, device, spec)
}

// TestGetPCIAddress exercises getDevicePCIAddress against a fake sysfs tree
// rooted in a temporary directory: malformed identifiers, a missing bridge
// bus directory, an empty one, and finally a bridge exposing exactly one bus.
func TestGetPCIAddress(t *testing.T) {
	testDir, err := ioutil.TempDir("", "kata-agent-tmp-")
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll(testDir)

	// Identifiers that are not exactly "bridgeAddr/deviceAddr" are rejected.
	pciID := "02"
	_, err = getDevicePCIAddress(pciID)
	assert.NotNil(t, err)

	pciID = "02/03/04"
	_, err = getDevicePCIAddress(pciID)
	assert.NotNil(t, err)

	bridgeID := "02"
	deviceID := "03"
	pciBus := "0000:01"
	expectedPCIAddress := "0000:00:02.0/0000:01:03.0"
	pciID = fmt.Sprintf("%s/%s", bridgeID, deviceID)

	// Set sysBusPrefix to test directory for unit tests, and restore it
	// afterwards so that other tests see the real value.
	savedSysBusPrefix := sysBusPrefix
	sysBusPrefix = testDir
	defer func() { sysBusPrefix = savedSysBusPrefix }()

	bridgeBusPath := fmt.Sprintf(pciBusPathFormat, sysBusPrefix, "0000:00:02.0")

	// The bridge bus directory does not exist yet.
	_, err = getDevicePCIAddress(pciID)
	assert.NotNil(t, err)

	err = os.MkdirAll(bridgeBusPath, mountPerm)
	assert.Nil(t, err)

	// The bridge bus directory exists but holds no bus entry.
	_, err = getDevicePCIAddress(pciID)
	assert.NotNil(t, err)

	err = os.MkdirAll(filepath.Join(bridgeBusPath, pciBus), mountPerm)
	assert.Nil(t, err)

	// Exactly one bus entry: the complete address can be built.
	addr, err := getDevicePCIAddress(pciID)
	assert.Nil(t, err)

	// assert.Equal takes the expected value first.
	assert.Equal(t, expectedPCIAddress, addr)
}

func TestScanSCSIBus(t *testing.T) {
testDir, err := ioutil.TempDir("", "kata-agent-tmp-")
if err != nil {
Expand Down Expand Up @@ -112,7 +155,7 @@ func TestScanSCSIBus(t *testing.T) {
}

// testAddDevicesSuccessful asserts that addDevices returns no error for the
// given devices and spec.
func testAddDevicesSuccessful(t *testing.T, devices []*pb.Device, spec *pb.Spec) {
	// NOTE(review): the two-argument call below looks like the removed side
	// of a diff line; the call after it, which passes an empty sandbox, is
	// the one matching the other addDevices callers shown in this file —
	// confirm and drop the stale one.
	err := addDevices(devices, spec)
	err := addDevices(devices, spec, &sandbox{})
	assert.Nil(t, err, "addDevices() failed: %v", err)
}

Expand All @@ -133,11 +176,11 @@ func TestAddDevicesNilMountsSuccessful(t *testing.T) {
testAddDevicesSuccessful(t, devices, spec)
}

// NOTE(review): the signature directly below looks like the removed side of
// a diff line (no sandbox parameter) — confirm and drop it.
func noopDeviceHandlerReturnNil(device pb.Device, spec *pb.Spec) error {
// noopDeviceHandlerReturnNil is a device handler stub that ignores its
// arguments and always succeeds.
func noopDeviceHandlerReturnNil(device pb.Device, spec *pb.Spec, s *sandbox) error {
	return nil
}

// NOTE(review): the signature directly below looks like the removed side of
// a diff line (no sandbox parameter) — confirm and drop it.
func noopDeviceHandlerReturnError(device pb.Device, spec *pb.Spec) error {
// noopDeviceHandlerReturnError is a device handler stub that ignores its
// arguments and always fails with a fixed error.
func noopDeviceHandlerReturnError(device pb.Device, spec *pb.Spec, s *sandbox) error {
	return fmt.Errorf("Noop handler failure")
}

Expand All @@ -159,7 +202,7 @@ func TestAddDevicesNoopHandlerSuccessful(t *testing.T) {
}

// testAddDevicesFailure asserts that addDevices returns an error for the
// given devices and spec.
func testAddDevicesFailure(t *testing.T, devices []*pb.Device, spec *pb.Spec) {
	// NOTE(review): the two-argument call below looks like the removed side
	// of a diff line; the call after it, which passes an empty sandbox, is
	// the one matching the other addDevices callers shown in this file —
	// confirm and drop the stale one.
	err := addDevices(devices, spec)
	err := addDevices(devices, spec, &sandbox{})
	assert.NotNil(t, err, "addDevices() should have failed")
}

Expand Down Expand Up @@ -319,8 +362,10 @@ func TestAddDevice(t *testing.T) {
},
}

s := &sandbox{}

for i, d := range data {
err := addDevice(d.device, d.spec)
err := addDevice(d.device, d.spec, s)
if d.expectError {
assert.Errorf(err, "test %d (%+v)", i, d)
} else {
Expand Down
Loading