diff --git a/agent.go b/agent.go index 7e2a86df48..95aae80ccd 100644 --- a/agent.go +++ b/agent.go @@ -156,6 +156,22 @@ var collatedTrace = false // if true, coredump when an internal error occurs or a fatal signal is received var crashOnError = false +// commType is used to denote the communication channel type used. +type commType int + +const ( + // virtio-serial channel + serialCh commType = iota + + // vsock channel + vsockCh + + // channel type not passed explicitly + unknownCh +) + +var commCh = unknownCh + // This is the list of file descriptors we can properly close after the process // has been started. When the new process is exec(), those file descriptors are // duplicated and it is our responsibility to close them since we have opened diff --git a/channel.go b/channel.go index f39ef50873..6590fbc887 100644 --- a/channel.go +++ b/channel.go @@ -50,24 +50,29 @@ func newChannel(ctx context.Context) (channel, error) { defer span.Finish() var serialErr error - var serialPath string var vsockErr error - var vSockSupported bool + var ch channel for i := 0; i < channelExistMaxTries; i++ { - // check vsock path - if _, err := os.Stat(vSockDevPath); err == nil { - if vSockSupported, vsockErr = isAFVSockSupportedFunc(); vSockSupported && vsockErr == nil { - span.SetTag("channel-type", "vsock") - return &vSockChannel{}, nil + switch commCh { + case serialCh: + if ch, serialErr = checkForSerialChannel(ctx); serialErr == nil && ch.(*serialChannel) != nil { + return ch, nil + } + case vsockCh: + if ch, vsockErr = checkForVsockChannel(ctx); vsockErr == nil && ch.(*vSockChannel) != nil { + return ch, nil } - } - // Check serial port path - if serialPath, serialErr = findVirtualSerialPath(serialChannelName); serialErr == nil { - span.SetTag("channel-type", "serial") - span.SetTag("serial-path", serialPath) - return &serialChannel{serialPath: serialPath}, nil + case unknownCh: + // If we have not been explicitly passed if vsock is used or not, maybe due to + // an older runtime, try to check for vsock support. + if ch, vsockErr = checkForVsockChannel(ctx); vsockErr == nil && ch.(*vSockChannel) != nil { + return ch, nil + } + if ch, serialErr = checkForSerialChannel(ctx); serialErr == nil && ch.(*serialChannel) != nil { + return ch, nil + } } time.Sleep(channelExistWaitTime) @@ -84,6 +89,44 @@ func newChannel(ctx context.Context) (channel, error) { return nil, fmt.Errorf("Neither vsocks nor serial ports were found") } +func checkForSerialChannel(ctx context.Context) (*serialChannel, error) { + span, _ := trace(ctx, "channel", "checkForSerialChannel") + defer span.Finish() + + // Check serial port path + serialPath, serialErr := findVirtualSerialPath(serialChannelName) + if serialErr == nil { + span.SetTag("channel-type", "serial") + span.SetTag("serial-path", serialPath) + agentLog.Debug("Serial channel type detected") + return &serialChannel{serialPath: serialPath}, nil + } + + return nil, serialErr +} + +func checkForVsockChannel(ctx context.Context) (*vSockChannel, error) { + span, _ := trace(ctx, "channel", "checkForVsockChannel") + defer span.Finish() + + // check vsock path + var err error + _, err = os.Stat(vSockDevPath) + + if err != nil { + return nil, err + } + + vSockSupported, vsockErr := isAFVSockSupportedFunc() + if vSockSupported && vsockErr == nil { + span.SetTag("channel-type", "vsock") + agentLog.Debug("Vsock channel type detected") + return &vSockChannel{}, nil + } + + return nil, fmt.Errorf("Vsock not found : %s", vsockErr) +} + type vSockChannel struct { } @@ -228,23 +271,51 @@ func (c *serialChannel) teardown() error { return c.serialConn.Close() } +// isAFVSockSupported checks if vsock channel is used by the runtime +// by checking for devices under the vhost-vsock driver path. +// It returns true if more a device is found for the driver. func isAFVSockSupported() (bool, error) { - fd, err := unix.Socket(unix.AF_VSOCK, unix.SOCK_STREAM, 0) - if err != nil { - // This case is valid. It means AF_VSOCK is not a supported - // domain on this system. - if err == unix.EAFNOSUPPORT { - return false, nil - } + // Driver path for virtio-vsock + sysVsockPath := "/sys/bus/virtio/drivers/vmw_vsock_virtio_transport/" + files, err := ioutil.ReadDir(sysVsockPath) + + // This should not happen for a hypervisor with vsock driver + if err != nil { return false, err } - if err := unix.Close(fd); err != nil { - return true, err + // standard driver files that should be ignored + driverFiles := []string{"bind", "uevent", "unbind"} + + for _, file := range files { + for _, f := range driverFiles { + if file.Name() == f { + continue + } + } + + fPath := filepath.Join(sysVsockPath, file.Name()) + fInfo, err := os.Lstat(fPath) + if err != nil { + return false, err + } + + if fInfo.Mode()&os.ModeSymlink == 0 { + continue + } + + link, err := os.Readlink(fPath) + if err != nil { + return false, err + } + + if strings.Contains(link, "devices") { + return true, nil + } } - return true, nil + return false, nil } func findVirtualSerialPath(serialName string) (string, error) { diff --git a/config.go b/config.go index 0f31e6ea1a..9d4ecd2364 100644 --- a/config.go +++ b/config.go @@ -8,6 +8,7 @@ package main import ( "io/ioutil" + "strconv" "strings" "github.com/sirupsen/logrus" @@ -20,6 +21,7 @@ const ( logLevelFlag = optionPrefix + "log" devModeFlag = optionPrefix + "devmode" traceModeFlag = optionPrefix + "trace" + useVsockFlag = optionPrefix + "use_vsock" kernelCmdlineFile = "/proc/cmdline" traceValueIsolated = "isolated" traceValueCollated = "collated" @@ -102,6 +104,16 @@ func (c *agentConfig) parseCmdlineOption(option string) error { case traceValueCollated: enableTracing(true) } + case useVsockFlag: + if flag, err := strconv.ParseBool(split[valuePosition]); err == nil { + if flag { + agentLog.Debug("Param passed to use vsock channel") + commCh = vsockCh + } else { + agentLog.Debug("Param passed to NOT use vsock channel") + commCh = serialCh + } + } default: if strings.HasPrefix(split[optionPosition], optionPrefix) { return grpcStatus.Errorf(codes.NotFound, "Unknown option %s", split[optionPosition])