Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Device: Fix TPM name errors #13671

Merged
merged 5 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions lxd/device/tpm.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
deviceConfig "github.com/canonical/lxd/lxd/device/config"
"github.com/canonical/lxd/lxd/instance"
"github.com/canonical/lxd/lxd/instance/instancetype"
"github.com/canonical/lxd/lxd/storage/filesystem"
"github.com/canonical/lxd/lxd/subprocess"
"github.com/canonical/lxd/lxd/util"
"github.com/canonical/lxd/shared"
Expand Down Expand Up @@ -97,7 +98,7 @@ func (d *tpm) Start() (*deviceConfig.RunConfig, error) {
return nil, fmt.Errorf("Failed to validate environment: %w", err)
}

tpmDevPath := filepath.Join(d.inst.Path(), fmt.Sprintf("tpm.%s", d.name))
tpmDevPath := filepath.Join(d.inst.Path(), fmt.Sprintf("tpm.%s", filesystem.PathNameEncode(d.name)))

if !shared.PathExists(tpmDevPath) {
err := os.Mkdir(tpmDevPath, 0700)
Expand All @@ -114,8 +115,9 @@ func (d *tpm) Start() (*deviceConfig.RunConfig, error) {
}

func (d *tpm) startContainer() (*deviceConfig.RunConfig, error) {
tpmDevPath := filepath.Join(d.inst.Path(), fmt.Sprintf("tpm.%s", d.name))
logFileName := fmt.Sprintf("tpm.%s.log", d.name)
escapedDeviceName := filesystem.PathNameEncode(d.name)
tpmDevPath := filepath.Join(d.inst.Path(), fmt.Sprintf("tpm.%s", escapedDeviceName))
logFileName := fmt.Sprintf("tpm.%s.log", escapedDeviceName)
logPath := filepath.Join(d.inst.LogPath(), logFileName)

proc, err := subprocess.NewProcess("swtpm", []string{"chardev", "--tpm2", "--tpmstate", fmt.Sprintf("dir=%s", tpmDevPath), "--vtpm-proxy"}, logPath, "")
Expand All @@ -134,7 +136,7 @@ func (d *tpm) startContainer() (*deviceConfig.RunConfig, error) {
// Stop the TPM emulator if anything goes wrong.
revert.Add(func() { _ = proc.Stop() })

pidPath := filepath.Join(d.inst.DevicesPath(), fmt.Sprintf("%s.pid", d.name))
pidPath := filepath.Join(d.inst.DevicesPath(), fmt.Sprintf("%s.pid", escapedDeviceName))

err = proc.Save(pidPath)
if err != nil {
Expand Down Expand Up @@ -212,8 +214,9 @@ func (d *tpm) startVM() (*deviceConfig.RunConfig, error) {
revert := revert.New()
defer revert.Fail()

tpmDevPath := filepath.Join(d.inst.Path(), fmt.Sprintf("tpm.%s", d.name))
socketPath := filepath.Join(tpmDevPath, fmt.Sprintf("swtpm-%s.sock", d.name))
escapedDeviceName := filesystem.PathNameEncode(d.name)
tpmDevPath := filepath.Join(d.inst.Path(), fmt.Sprintf("tpm.%s", escapedDeviceName))
socketPath := filepath.Join(tpmDevPath, fmt.Sprintf("swtpm-%s.sock", escapedDeviceName))
runConf := deviceConfig.RunConfig{
TPMDevice: []deviceConfig.RunConfigItem{
{Key: "devName", Value: d.name},
Expand Down Expand Up @@ -280,7 +283,7 @@ func (d *tpm) startVM() (*deviceConfig.RunConfig, error) {

revert.Add(func() { _ = proc.Stop() })

pidPath := filepath.Join(d.inst.DevicesPath(), fmt.Sprintf("%s.pid", d.name))
pidPath := filepath.Join(d.inst.DevicesPath(), fmt.Sprintf("%s.pid", escapedDeviceName))

err = proc.Save(pidPath)
if err != nil {
Expand Down
47 changes: 7 additions & 40 deletions lxd/instance/drivers/driver_qemu.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@ import (
"bytes"
"compress/gzip"
"context"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -2183,8 +2181,8 @@ func (d *qemu) deviceStart(dev device.Device, instanceRunning bool) (*deviceConf
}

func (d *qemu) deviceAttachPath(deviceName string) (mountTag string, err error) {
deviceID := qemuHostDriveDeviceID(deviceName, "virtio-fs")
mountTag = d.generateQemuDeviceName(deviceName)
deviceID := qemuDeviceNameOrID(qemuDeviceIDPrefix, deviceName, "-virtio-fs", qemuDeviceIDMaxLength)
mountTag = qemuDeviceNameOrID(qemuDeviceNamePrefix, deviceName, "", qemuDeviceNameMaxLength)

// Detect virtiofsd path.
virtiofsdSockPath := filepath.Join(d.DevicesPath(), fmt.Sprintf("virtio-fs.%s.sock", filesystem.PathNameEncode(deviceName)))
Expand Down Expand Up @@ -2310,7 +2308,7 @@ func (d *qemu) deviceAttachBlockDevice(mount deviceConfig.MountEntryItem) error
}

func (d *qemu) deviceDetachPath(deviceName string) error {
deviceID := qemuHostDriveDeviceID(deviceName, "virtio-fs")
deviceID := qemuDeviceNameOrID(qemuDeviceIDPrefix, deviceName, "-virtio-fs", qemuDeviceIDMaxLength)

// Check if the agent is running.
monitor, err := qmp.Connect(d.monitorPath(), qemuSerialChardevName, d.getMonitorEventHandler())
Expand Down Expand Up @@ -2352,7 +2350,7 @@ func (d *qemu) deviceDetachBlockDevice(deviceName string) error {
}

deviceID := fmt.Sprintf("%s%s", qemuDeviceIDPrefix, filesystem.PathNameEncode(deviceName))
blockDevName := d.generateQemuDeviceName(deviceName)
blockDevName := qemuDeviceNameOrID(qemuDeviceNamePrefix, deviceName, "", qemuDeviceNameMaxLength)

err = monitor.RemoveFDFromFDSet(blockDevName)
if err != nil {
Expand Down Expand Up @@ -3730,7 +3728,7 @@ func (d *qemu) addRootDriveConfig(qemuDev map[string]string, mountInfo *storageP

// addDriveDirConfig adds the qemu config required for adding a supplementary drive directory share.
func (d *qemu) addDriveDirConfig(cfg *[]cfgSection, bus *qemuBus, fdFiles *[]*os.File, agentMounts *[]instancetype.VMAgentMount, driveConf deviceConfig.MountEntryItem) error {
mountTag := d.generateQemuDeviceName(driveConf.DevName)
mountTag := qemuDeviceNameOrID(qemuDeviceNamePrefix, driveConf.DevName, "", qemuDeviceNameMaxLength)

agentMount := instancetype.VMAgentMount{
Source: mountTag,
Expand Down Expand Up @@ -3961,7 +3959,7 @@ func (d *qemu) addDriveConfig(qemuDev map[string]string, bootIndexes map[string]
},
"discard": "unmap", // Forward as an unmap request. This is the same as `discard=on` in the qemu config file.
"driver": "file",
"node-name": d.generateQemuDeviceName(driveConf.DevName),
"node-name": qemuDeviceNameOrID(qemuDeviceNamePrefix, driveConf.DevName, "", qemuDeviceNameMaxLength),
"read-only": false,
}

Expand Down Expand Up @@ -8670,7 +8668,7 @@ func (d *qemu) checkFeatures(hostArch int, qemuPath string) (map[string]any, err

// Check io_uring feature.
blockDev := map[string]any{
"node-name": d.generateQemuDeviceName("feature-check"),
"node-name": fmt.Sprintf("%s%s", qemuDeviceNamePrefix, "feature-check"),
"driver": "file",
"filename": blockDevPath.Name(),
"aio": "io_uring",
Expand Down Expand Up @@ -8891,37 +8889,6 @@ func (d *qemu) deviceDetachUSB(usbDev deviceConfig.USBDeviceItem) error {
return nil
}

// hashIfLonger returns a full or partial hash of a name as to fit it within a size limit.
func hashIfLonger(name string, maxLength int) string {
if len(name) <= maxLength {
return name
}

// If the name is too long, hash it as SHA-256 (32 bytes).
// Then encode the SHA-256 binary hash as Base64 Raw URL format and trim down if needed.
hash := sha256.New()
hash.Write([]byte(name))
binaryHash := hash.Sum(nil)

// Raw URL avoids the use of "+" character and the padding "=" character which QEMU doesn't allow.
hashedName := base64.RawURLEncoding.EncodeToString(binaryHash)
if len(hashedName) > maxLength {
hashedName = hashedName[0:maxLength]
}

return hashedName
}

// Block node names and device tags may only be up to 31 characters long, so use a hash if longer.
// Also escapes / to -, and - to --.
func (d *qemu) generateQemuDeviceName(name string) string {
maxNameLength := qemuDeviceNameMaxLength - len(qemuDeviceNamePrefix)
name = hashIfLonger(filesystem.PathNameEncode(name), maxNameLength)

// Apply the lxd_ prefix.
return fmt.Sprintf("%s%s", qemuDeviceNamePrefix, name)
}

func (d *qemu) setCPUs(count int) error {
if count == 0 {
return nil
Expand Down
37 changes: 28 additions & 9 deletions lxd/instance/drivers/driver_qemu_templates.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package drivers

import (
"crypto/sha256"
"encoding/base64"
"fmt"
"strings"

Expand All @@ -9,11 +11,27 @@ import (
"github.com/canonical/lxd/shared/osarch"
)

// qemuHostDriveDeviceID returns the device ID to use for a host drive share.
func qemuHostDriveDeviceID(deviceName string, protocol string) string {
suffix := "-" + protocol
maxNameLength := qemuDeviceIDMaxLength - (len(qemuDeviceIDPrefix) + len(suffix))
return fmt.Sprintf("%s%s%s", qemuDeviceIDPrefix, hashIfLonger(filesystem.PathNameEncode(deviceName), maxNameLength), suffix)
// qemuDeviceNameOrID generates a QEMU device name or ID.
// Respects the property length limit by hashing the device name when necessary. Also escapes / to -, and - to --.
func qemuDeviceNameOrID(prefix string, deviceName string, suffix string, maxLength int) string {
baseName := filesystem.PathNameEncode(deviceName)
maxNameLength := maxLength - (len(prefix) + len(suffix))

if len(baseName) > maxNameLength {
// If the name is too long, hash it as SHA-256 (32 bytes).
// Then encode the SHA-256 binary hash as Base64 Raw URL format and trim down if needed.
hash := sha256.New()
hash.Write([]byte(baseName))
binaryHash := hash.Sum(nil)

// Raw URL avoids the use of "+" character and the padding "=" character which QEMU doesn't allow.
baseName = base64.RawURLEncoding.EncodeToString(binaryHash)
if len(baseName) > maxNameLength {
baseName = baseName[0:maxNameLength]
}
}

return fmt.Sprintf("%s%s%s", prefix, baseName, suffix)
}

type cfgEntry struct {
Expand Down Expand Up @@ -747,7 +765,7 @@ type qemuDriveDirOpts struct {
func qemuDriveDir(opts *qemuDriveDirOpts) []cfgSection {
return qemuHostDrive(&qemuHostDriveOpts{
dev: opts.dev,
id: qemuHostDriveDeviceID(opts.devName, opts.protocol),
id: qemuDeviceNameOrID(qemuDeviceIDPrefix, opts.devName, "-"+opts.protocol, qemuDeviceIDMaxLength),
// Devices use "lxd_" prefix indicating that this is a user named device.
name: fmt.Sprintf("lxd_%s", opts.devName),
comment: fmt.Sprintf("%s drive (%s)", opts.devName, opts.protocol),
Expand Down Expand Up @@ -874,8 +892,9 @@ type qemuTPMOpts struct {
}

func qemuTPM(opts *qemuTPMOpts) []cfgSection {
chardev := fmt.Sprintf("qemu_tpm-chardev_%s", opts.devName)
tpmdev := fmt.Sprintf("qemu_tpm-tpmdev_%s", opts.devName)
chardev := qemuDeviceNameOrID("qemu_tpm-chardev_", opts.devName, "", qemuDeviceIDMaxLength)
tpmdev := qemuDeviceNameOrID("qemu_tpm-tpmdev_", opts.devName, "", qemuDeviceIDMaxLength)
device := qemuDeviceNameOrID(qemuDeviceIDPrefix, opts.devName, "", qemuDeviceIDMaxLength)

return []cfgSection{{
name: fmt.Sprintf(`chardev "%s"`, chardev),
Expand All @@ -890,7 +909,7 @@ func qemuTPM(opts *qemuTPMOpts) []cfgSection {
{key: "chardev", value: chardev},
},
}, {
name: fmt.Sprintf(`device "dev-lxd_%s"`, opts.devName),
name: fmt.Sprintf(`device "%s"`, device),
entries: []cfgEntry{
{key: "driver", value: "tpm-crb"},
{key: "tpmdev", value: tpmdev},
Expand Down
Loading