diff --git a/internal/wclayer/cim/mount.go b/internal/wclayer/cim/mount.go new file mode 100644 index 0000000000..abb15dabc9 --- /dev/null +++ b/internal/wclayer/cim/mount.go @@ -0,0 +1,87 @@ +package cim + +import ( + "context" + "fmt" + "os" + "sync" + + "github.com/Microsoft/go-winio/pkg/guid" + hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" + cimfs "github.com/Microsoft/hcsshim/pkg/cimfs" +) + +// a cache of cim layer to its mounted volume - The mount manager plugin currently doesn't have an option of +// querying a mounted cim to get the volume at which it is mounted, so we maintain a cache of that here +var ( + cimMounts map[string]string = make(map[string]string) + cimMountMapLock sync.Mutex + // A random GUID used as a namespace for generating cim mount volume GUIDs: 6827367b-c388-4e9b-95ec-961c6d2c936c + cimMountNamespace guid.GUID = guid.GUID{Data1: 0x6827367b, Data2: 0xc388, Data3: 0x4e9b, Data4: [8]byte{0x96, 0x1c, 0x6d, 0x2c, 0x93, 0x6c}} +) + +// MountCimLayer mounts the cim at path `cimPath` and returns the mount location of that cim. This method +// uses the `CimMountFlagCacheFiles` mount flag when mounting the cim. The containerID is used to generated +// the volumeID for the volume at which this CIM is mounted. containerID is used so that if the shim process +// crashes for any reason, the mounted cim can be correctly cleaned up during `shim delete` call. +func MountCimLayer(ctx context.Context, cimPath, containerID string) (string, error) { + volumeGUID, err := guid.NewV5(cimMountNamespace, []byte(containerID)) + if err != nil { + return "", fmt.Errorf("generated cim mount GUID: %w", err) + } + + vol, err := cimfs.Mount(cimPath, volumeGUID, hcsschema.CimMountFlagCacheFiles) + if err != nil { + return "", err + } + + cimMountMapLock.Lock() + defer cimMountMapLock.Unlock() + cimMounts[fmt.Sprintf("%s_%s", containerID, cimPath)] = vol + + return vol, nil +} + +// Unmount unmounts the cim at mounted for given container. +func UnmountCimLayer(ctx context.Context, cimPath, containerID string) error { + cimMountMapLock.Lock() + defer cimMountMapLock.Unlock() + if vol, ok := cimMounts[fmt.Sprintf("%s_%s", containerID, cimPath)]; !ok { + return fmt.Errorf("cim %s not mounted", cimPath) + } else { + delete(cimMounts, fmt.Sprintf("%s_%s", containerID, cimPath)) + err := cimfs.Unmount(vol) + if err != nil { + return err + } + } + return nil +} + +// GetCimMountPath returns the volume at which a cim is mounted. If the cim is not mounted returns error +func GetCimMountPath(cimPath, containerID string) (string, error) { + cimMountMapLock.Lock() + defer cimMountMapLock.Unlock() + + if vol, ok := cimMounts[fmt.Sprintf("%s_%s", containerID, cimPath)]; !ok { + return "", fmt.Errorf("cim %s not mounted", cimPath) + } else { + return vol, nil + } +} + +func CleanupContainerMounts(containerID string) error { + volumeGUID, err := guid.NewV5(cimMountNamespace, []byte(containerID)) + if err != nil { + return fmt.Errorf("generated cim mount GUID: %w", err) + } + + volPath := fmt.Sprintf("\\\\?\\Volume{%s}\\", volumeGUID.String()) + if _, err := os.Stat(volPath); err == nil { + err = cimfs.Unmount(volPath) + if err != nil { + return err + } + } + return nil +} diff --git a/pkg/cimfs/cim_test.go b/pkg/cimfs/cim_test.go index f73bb5b25c..193231ad44 100644 --- a/pkg/cimfs/cim_test.go +++ b/pkg/cimfs/cim_test.go @@ -15,6 +15,8 @@ import ( "time" "github.com/Microsoft/go-winio" + "github.com/Microsoft/go-winio/pkg/guid" + hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" "golang.org/x/sys/windows" ) @@ -102,7 +104,12 @@ func TestCimReadWrite(t *testing.T) { } // mount and read the contents of the cim - mountvol, err := Mount(cimPath) + volumeGUID, err := guid.NewV4() + if err != nil { + t.Fatalf("generate cim mount GUID: %s", err) + } + + mountvol, err := Mount(cimPath, volumeGUID, hcsschema.CimMountFlagCacheFiles) if err != nil { t.Fatalf("mount cim : %s", err) } diff --git a/pkg/cimfs/mount_cim.go b/pkg/cimfs/mount_cim.go index 0e7d8f9ac0..22f96e82d4 100644 --- a/pkg/cimfs/mount_cim.go +++ b/pkg/cimfs/mount_cim.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/Microsoft/go-winio/pkg/guid" - hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2" "github.com/Microsoft/hcsshim/internal/winapi" "github.com/pkg/errors" ) @@ -33,22 +32,13 @@ func (e *MountError) Error() string { return s } -func MountWithFlags(cimPath string, mountFlags uint32) (string, error) { - layerGUID, err := guid.NewV4() - if err != nil { - return "", &MountError{Cim: cimPath, Op: "Mount", Err: err} - } - if err := winapi.CimMountImage(filepath.Dir(cimPath), filepath.Base(cimPath), mountFlags, &layerGUID); err != nil { - return "", &MountError{Cim: cimPath, Op: "Mount", VolumeGUID: layerGUID, Err: err} +// Mount mounts the given cim at a volume with given GUID. Returns the full volume +// path if mount is successful. +func Mount(cimPath string, volumeGUID guid.GUID, mountFlags uint32) (string, error) { + if err := winapi.CimMountImage(filepath.Dir(cimPath), filepath.Base(cimPath), mountFlags, &volumeGUID); err != nil { + return "", &MountError{Cim: cimPath, Op: "Mount", VolumeGUID: volumeGUID, Err: err} } - return fmt.Sprintf("\\\\?\\Volume{%s}\\", layerGUID.String()), nil -} - -// Mount mounts the cim at path `cimPath` and returns the mount location of that cim. This method uses the -// `CimMountFlagCacheRegions` mount flag when mounting the cim, if some other mount flag is desired use the -// `MountWithFlags` method. -func Mount(cimPath string) (string, error) { - return MountWithFlags(cimPath, hcsschema.CimMountFlagCacheFiles) + return fmt.Sprintf("\\\\?\\Volume{%s}\\", volumeGUID.String()), nil } // Unmount unmounts the cim at mounted at path `volumePath`.