Skip to content

Commit

Permalink
making ecs-init tests os-agnostic
Browse files Browse the repository at this point in the history
  • Loading branch information
singholt committed Jan 25, 2025
1 parent 7f838b7 commit 4d05823
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 36 deletions.
3 changes: 1 addition & 2 deletions ecs-init/apparmor/apparmor.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ var (
isProfileLoaded = aaprofile.IsLoaded
loadPath = loadProfile
createFile = os.Create
statFile = os.Stat
)

// loadPath runs `apparmor_parser -Kr` on a specified apparmor profile to
Expand Down Expand Up @@ -152,7 +151,7 @@ func LoadDefaultProfile(profileName string) error {
}

func fileExists(path string) (bool, error) {
_, err := statFile(path)
_, err := config.OsStat(path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return false, nil
Expand Down
5 changes: 3 additions & 2 deletions ecs-init/apparmor/apparmor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"path/filepath"
"testing"

"github.com/aws/amazon-ecs-agent/ecs-init/config"
aaprofile "github.com/docker/docker/profiles/apparmor"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -108,7 +109,7 @@ func TestLoadDefaultProfile(t *testing.T) {
isProfileLoaded = aaprofile.IsLoaded
loadPath = loadProfile
createFile = os.Create
statFile = os.Stat
config.OsStat = os.Stat
}()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
Expand All @@ -123,7 +124,7 @@ func TestLoadDefaultProfile(t *testing.T) {
return f, err
}

statFile = func(fileName string) (os.FileInfo, error) {
config.OsStat = func(fileName string) (os.FileInfo, error) {
relativePath, err := filepath.Rel(appArmorProfileDir, fileName)
require.NoError(t, err)
return nil, tc.statErrors[relativePath]
Expand Down
4 changes: 3 additions & 1 deletion ecs-init/cache/dependencies.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
"os"
"path/filepath"

cfg "github.com/aws/amazon-ecs-agent/ecs-init/config"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
Expand Down Expand Up @@ -185,7 +187,7 @@ func (s *standardFS) Open(name string) (io.ReadCloser, error) {
}

func (s *standardFS) Stat(name string) (fileSizeInfo, error) {
return os.Stat(name)
return cfg.OsStat(name)
}

func (s *standardFS) Base(path string) string {
Expand Down
7 changes: 5 additions & 2 deletions ecs-init/config/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ const (
ECSAgentAppArmorDefaultProfileName = "ecs-agent-default"
)

// OsStat is useful for mocking in unit tests
var OsStat = os.Stat

// partitionBucketRegion provides the "partitional" bucket region
// suitable for downloading agent from.
var partitionBucketRegion = map[string]string{
Expand Down Expand Up @@ -258,15 +261,15 @@ func MountDirectoryEBS() string {

// HostCertsDirPath() returns the CA store path on the host
func HostCertsDirPath() string {
if _, err := os.Stat(hostCertsDirPath); err != nil {
if _, err := OsStat(hostCertsDirPath); err != nil {
return ""
}
return hostCertsDirPath
}

// HostPKIDirPath() returns the CA store path on the host
func HostPKIDirPath() string {
if _, err := os.Stat(hostPKIDirPath); err != nil {
if _, err := OsStat(hostPKIDirPath); err != nil {
return ""
}
return hostPKIDirPath
Expand Down
10 changes: 7 additions & 3 deletions ecs-init/docker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ const (
// nvidiaGPUDevicesPresentMaxRetries specifies the maximum number of retries to attempt for checking if NVIDIA
// GPU devices are present.
nvidiaGPUDevicesPresentMaxRetries = 10

// lsblk lists information about block devices. This is used by the ECS agent for the EBS task attach functionality.
// Ref: https://man7.org/linux/man-pages/man8/lsblk.8.html
lsblkDir = "/usr/bin/lsblk"
)

// Do NOT include "CAP_" in capability string
Expand Down Expand Up @@ -495,7 +499,7 @@ func getCredentialsFetcherSocketBind() (string, bool) {
credentialsFetcherUnixSocketHostPath, ok := config.HostCredentialsFetcherPath()
if ok && credentialsFetcherUnixSocketHostPath != "" {
// check whether the path to the credentials fetcher socket exists
_, err := os.Stat(credentialsFetcherUnixSocketHostPath)
_, err := config.OsStat(credentialsFetcherUnixSocketHostPath)
if err != nil {
if os.IsNotExist(err) {
return "", false
Expand All @@ -508,7 +512,7 @@ func getCredentialsFetcherSocketBind() (string, bool) {
}

// getDockerSocketBind returns the bind for Docker socket.
// Value for the bind is as follow:
// Value for the bind is as follows:
// 1. DOCKER_HOST (as in os.Getenv) not set: source /var/run, dest /var/run
// 2. DOCKER_HOST (as in os.Getenv) set: source DOCKER_HOST (as in os.Getenv, trim unix:// prefix),
// dest DOCKER_HOST (as in /etc/ecs/ecs.config, trim unix:// prefix)
Expand Down Expand Up @@ -562,7 +566,7 @@ func getCapabilityBinds() []string {
}

func defaultIsPathValid(path string, shouldBeDirectory bool) bool {
fileInfo, err := os.Stat(path)
fileInfo, err := config.OsStat(path)
if err != nil {
return false
}
Expand Down
6 changes: 3 additions & 3 deletions ecs-init/docker/docker_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ func createHostConfig(binds []string) *godocker.HostConfig {
iptablesExecutableHostDir+":"+iptablesExecutableContainerDir+readOnly,
iptablesAltDir+":"+iptablesAltDir+readOnly,
iptablesLegacyDir+":"+iptablesLegacyDir+readOnly,
"/usr/bin/lsblk:/usr/bin/lsblk",
lsblkDir+":"+lsblkDir,
)
binds = append(binds, getNsenterBinds(os.Stat)...)
binds = append(binds, getModInfoBinds(os.Stat)...)
binds = append(binds, getNsenterBinds(config.OsStat)...)
binds = append(binds, getModInfoBinds(config.OsStat)...)

logConfig := config.AgentDockerLogDriverConfiguration()

Expand Down
75 changes: 53 additions & 22 deletions ecs-init/docker/docker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,6 @@ import (
"github.com/stretchr/testify/assert"
)

// defaultExpectedAgentBinds is the total number of agent host config binds.
// Note: Change this value every time when a new bind mount is added to
// agent for the tests to pass
const (
defaultExpectedAgentBinds = 22
)

func TestIsAgentImageLoadedListFailure(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
Expand Down Expand Up @@ -197,7 +190,7 @@ func TestStartAgentNoEnvFile(t *testing.T) {
mockFS.EXPECT().ReadFile(config.InstanceConfigFile()).Return(nil, errors.New("not found")).AnyTimes()
mockFS.EXPECT().ReadFile(config.AgentConfigFile()).Return(nil, errors.New("test error")).AnyTimes()
mockDocker.EXPECT().CreateContainer(gomock.Any()).Do(func(opts godocker.CreateContainerOptions) {
validateCommonCreateContainerOptions(t, opts, defaultExpectedAgentBinds)
validateCommonCreateContainerOptions(t, opts)
}).Return(&godocker.Container{
ID: containerID,
}, nil)
Expand All @@ -218,7 +211,6 @@ func TestStartAgentNoEnvFile(t *testing.T) {
func validateCommonCreateContainerOptions(
t *testing.T,
opts godocker.CreateContainerOptions,
expectedAgentBinds int,
) {
if opts.Name != "ecs-agent" {
t.Errorf("Expected container Name to be %s but was %s", "ecs-agent", opts.Name)
Expand Down Expand Up @@ -250,7 +242,6 @@ func validateCommonCreateContainerOptions(

hostCfg := opts.HostConfig

assert.Len(t, hostCfg.Binds, expectedAgentBinds)
binds := make(map[string]struct{})
for _, binding := range hostCfg.Binds {
binds[binding] = struct{}{}
Expand All @@ -261,6 +252,9 @@ func validateCommonCreateContainerOptions(
expectKey(config.AgentDataDirectory()+":/data", binds, t)
expectKey(config.AgentConfigDirectory()+":"+config.AgentConfigDirectory(), binds, t)
expectKey(config.CacheDirectory()+":"+config.CacheDirectory(), binds, t)
expectKey(config.CgroupMountpoint()+":"+DefaultCgroupMountpoint, binds, t)
expectKey(config.InstanceConfigDirectory()+":"+config.InstanceConfigDirectory(), binds, t)
expectKey(filepath.Join(config.LogDirectory(), execAgentLogRelativePath)+":"+filepath.Join(logDir, execAgentLogRelativePath), binds, t)
expectKey(config.ProcFS+":"+hostProcDir+":ro", binds, t)
expectKey(iptablesUsrLibDir+":"+iptablesUsrLibDir+":ro", binds, t)
expectKey(iptablesLibDir+":"+iptablesLibDir+":ro", binds, t)
Expand All @@ -269,11 +263,25 @@ func validateCommonCreateContainerOptions(
expectKey(iptablesExecutableHostDir+":"+iptablesExecutableContainerDir+":ro", binds, t)
expectKey(iptablesAltDir+":"+iptablesAltDir+":ro", binds, t)
expectKey(iptablesLegacyDir+":"+iptablesLegacyDir+":ro", binds, t)
expectKey(config.HostPKIDirPath()+":"+config.HostPKIDirPath()+":ro", binds, t)
expectKey(lsblkDir+":"+lsblkDir, binds, t)
expectKey(config.LogDirectory()+"/exec:/log/exec", binds, t)
for _, pluginDir := range pluginDirs {
expectKey(pluginDir+":"+pluginDir+readOnly, binds, t)
}

// verify nsenter binds are present in the hostConfig
nsEnterBinds := getNsenterBinds(config.OsStat)
for _, nsEnterDir := range nsEnterBinds {
expectKey(nsEnterDir, binds, t)
}

// verify modInfo binds are present in the hostConfig
modInfoBinds := getModInfoBinds(config.OsStat)
for _, modInfoBind := range modInfoBinds {
expectKey(modInfoBind, binds, t)
}

if hostCfg.NetworkMode != networkMode {
t.Errorf("Expected network mode to be %s, got %s", networkMode, hostCfg.NetworkMode)
}
Expand Down Expand Up @@ -318,6 +326,13 @@ func TestStartAgentEnvFile(t *testing.T) {
isPathValid = defaultIsPathValid
}()

config.OsStat = func(name string) (os.FileInfo, error) {
return nil, nil
}
defer func() {
config.OsStat = os.Stat
}()

envFile := "\nAGENT_TEST_VAR=val\nAGENT_TEST_VAR2=val2\n"
containerID := "container id"

Expand All @@ -327,7 +342,7 @@ func TestStartAgentEnvFile(t *testing.T) {
mockFS.EXPECT().ReadFile(config.InstanceConfigFile()).Return(nil, errors.New("not found")).AnyTimes()
mockFS.EXPECT().ReadFile(config.AgentConfigFile()).Return([]byte(envFile), nil).AnyTimes()
mockDocker.EXPECT().CreateContainer(gomock.Any()).Do(func(opts godocker.CreateContainerOptions) {
validateCommonCreateContainerOptions(t, opts, defaultExpectedAgentBinds)
validateCommonCreateContainerOptions(t, opts)
cfg := opts.Config

envVariables := make(map[string]struct{})
Expand Down Expand Up @@ -362,10 +377,15 @@ func TestStartAgentWithGPUConfig(t *testing.T) {
isPathValid = defaultIsPathValid
}()

config.OsStat = func(name string) (os.FileInfo, error) {
return nil, nil
}
defer func() {
config.OsStat = os.Stat
}()

envFile := "\nECS_ENABLE_GPU_SUPPORT=true\n"
containerID := "container id"
expectedAgentBinds := defaultExpectedAgentBinds
expectedAgentBinds += 1

defer func() {
MatchFilePatternForGPU = FilePatternMatchForGPU
Expand All @@ -380,7 +400,7 @@ func TestStartAgentWithGPUConfig(t *testing.T) {
mockFS.EXPECT().ReadFile(config.InstanceConfigFile()).Return([]byte(envFile), nil).AnyTimes()
mockFS.EXPECT().ReadFile(config.AgentConfigFile()).Return(nil, errors.New("not found")).AnyTimes()
mockDocker.EXPECT().CreateContainer(gomock.Any()).Do(func(opts godocker.CreateContainerOptions) {
validateCommonCreateContainerOptions(t, opts, expectedAgentBinds)
validateCommonCreateContainerOptions(t, opts)
var found bool
for _, bind := range opts.HostConfig.Binds {
if bind == gpu.GPUInfoDirPath+":"+gpu.GPUInfoDirPath {
Expand Down Expand Up @@ -421,6 +441,13 @@ func TestStartAgentWithGPUConfigNoDevices(t *testing.T) {
isPathValid = defaultIsPathValid
}()

config.OsStat = func(name string) (os.FileInfo, error) {
return nil, nil
}
defer func() {
config.OsStat = os.Stat
}()

envFile := "\nECS_ENABLE_GPU_SUPPORT=true\n"
containerID := "container id"

Expand All @@ -438,7 +465,7 @@ func TestStartAgentWithGPUConfigNoDevices(t *testing.T) {
mockFS.EXPECT().ReadFile(config.InstanceConfigFile()).Return([]byte(envFile), nil).AnyTimes()
mockFS.EXPECT().ReadFile(config.AgentConfigFile()).Return(nil, errors.New("not found")).AnyTimes()
mockDocker.EXPECT().CreateContainer(gomock.Any()).Do(func(opts godocker.CreateContainerOptions) {
validateCommonCreateContainerOptions(t, opts, defaultExpectedAgentBinds)
validateCommonCreateContainerOptions(t, opts)
cfg := opts.Config

envVariables := make(map[string]struct{})
Expand Down Expand Up @@ -864,6 +891,16 @@ func TestStartAgentWithExecBinds(t *testing.T) {
isPathValid = func(path string, isDir bool) bool {
return true
}
defer func() {
isPathValid = defaultIsPathValid
}()
config.OsStat = func(name string) (os.FileInfo, error) {
return nil, nil
}
defer func() {
config.OsStat = os.Stat
}()

hostCapabilityExecResourcesDir := filepath.Join(hostResourcesRootDir, execCapabilityName)
containerCapabilityExecResourcesDir := filepath.Join(containerResourcesRootDir, execCapabilityName)

Expand All @@ -876,21 +913,15 @@ func TestStartAgentWithExecBinds(t *testing.T) {
hostConfigDir + ":" + containerConfigDir,
}

expectedAgentBinds := defaultExpectedAgentBinds
expectedAgentBinds += len(expectedExecBinds)
// bind mount for the config folder is already included in expectedAgentBinds since it's always added
expectedExecBinds = append(expectedExecBinds, hostConfigDir+":"+containerConfigDir)
defer func() {
isPathValid = defaultIsPathValid
}()

mockFS := NewMockfileSystem(mockCtrl)
mockDocker := NewMockdockerclient(mockCtrl)

mockFS.EXPECT().ReadFile(config.InstanceConfigFile()).Return(nil, errors.New("not found")).AnyTimes()
mockFS.EXPECT().ReadFile(config.AgentConfigFile()).Return(nil, errors.New("not found")).AnyTimes()
mockDocker.EXPECT().CreateContainer(gomock.Any()).Do(func(opts godocker.CreateContainerOptions) {
validateCommonCreateContainerOptions(t, opts, expectedAgentBinds)
validateCommonCreateContainerOptions(t, opts)

// verify that exec binds are added
assert.Subset(t, opts.HostConfig.Binds, expectedExecBinds)
Expand Down
4 changes: 3 additions & 1 deletion ecs-init/volumes/state_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ import (
"path/filepath"
"sync"

"github.com/aws/amazon-ecs-agent/ecs-init/config"
"github.com/aws/amazon-ecs-agent/ecs-init/volumes/types"

"github.com/cihub/seelog"
)

Expand Down Expand Up @@ -138,7 +140,7 @@ func saveState(b []byte) error {
var fileExists = checkFile

func checkFile(filename string) bool {
_, err := os.Stat(filename)
_, err := config.OsStat(filename)
return !os.IsNotExist(err)
}

Expand Down

0 comments on commit 4d05823

Please sign in to comment.