diff --git a/cmd/sql/flow_test.go b/cmd/sql/flow_test.go index f6bc23dbd..62621e55b 100644 --- a/cmd/sql/flow_test.go +++ b/cmd/sql/flow_test.go @@ -55,6 +55,7 @@ func patchDockerClientInit() error { mockDockerBinder.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(nil) mockDockerBinder.On("ContainerWait", mock.Anything, mock.Anything, mock.Anything).Return(getContainerWaitResponse(false)) mockDockerBinder.On("ContainerLogs", mock.Anything, mock.Anything, mock.Anything).Return(sampleLog, nil) + mockDockerBinder.On("ContainerRemove", mock.Anything, mock.Anything, mock.Anything).Return(nil) return mockDockerBinder, nil } return nil diff --git a/sql/docker_interface.go b/sql/docker_interface.go index b5b8e6da8..b966559fc 100644 --- a/sql/docker_interface.go +++ b/sql/docker_interface.go @@ -21,6 +21,7 @@ type DockerBind interface { ContainerStart(ctx context.Context, containerID string, options types.ContainerStartOptions) error ContainerWait(ctx context.Context, containerID string, condition container.WaitCondition) (<-chan container.ContainerWaitOKBody, <-chan error) ContainerLogs(ctx context.Context, container string, options types.ContainerLogsOptions) (io.ReadCloser, error) + ContainerRemove(ctx context.Context, containerID string, options types.ContainerRemoveOptions) error } func (d DockerBinder) ImageBuild(ctx context.Context, buildContext io.Reader, options *types.ImageBuildOptions) (types.ImageBuildResponse, error) { @@ -43,6 +44,10 @@ func (d DockerBinder) ContainerLogs(ctx context.Context, containerID string, opt return d.cli.ContainerLogs(ctx, containerID, options) } +func (d DockerBinder) ContainerRemove(ctx context.Context, containerID string, options types.ContainerRemoveOptions) error { + return d.cli.ContainerRemove(ctx, containerID, options) +} + func NewDockerClient() (DockerBind, error) { cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation()) if err != nil { diff --git a/sql/flow.go b/sql/flow.go index 6aaec6301..7095ef756 100644 --- a/sql/flow.go +++ b/sql/flow.go @@ -123,5 +123,9 @@ func CommonDockerUtil(cmd, args []string, flags map[string]string, mountDirs []s return err } + if err := cli.ContainerRemove(ctx, resp.ID, types.ContainerRemoveOptions{}); err != nil { + return fmt.Errorf("docker remove failed %w", err) + } + return nil } diff --git a/sql/flow_test.go b/sql/flow_test.go index ad8db9e2e..a1fe4c3ba 100644 --- a/sql/flow_test.go +++ b/sql/flow_test.go @@ -52,6 +52,7 @@ func TestCommonDockerUtilSuccess(t *testing.T) { mockDockerBinder.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(nil) mockDockerBinder.On("ContainerWait", mock.Anything, mock.Anything, mock.Anything).Return(getContainerWaitResponse(false)) mockDockerBinder.On("ContainerLogs", mock.Anything, mock.Anything, mock.Anything).Return(sampleLog, nil) + mockDockerBinder.On("ContainerRemove", mock.Anything, mock.Anything, mock.Anything).Return(nil) return mockDockerBinder, nil } err := CommonDockerUtil(testCommand, nil, map[string]string{"flag": "value"}, []string{"mountDirectory"}) @@ -186,3 +187,23 @@ func TestCommonDockerUtilLogsCopyFailure(t *testing.T) { assert.Equal(t, expectedErr, err) mockDockerBinder.AssertExpectations(t) } + +func TestContainerRemoveFailure(t *testing.T) { + mockDockerBinder := new(mocks.DockerBind) + DockerClientInit = func() (DockerBind, error) { + mockDockerBinder.On("ImageBuild", mock.Anything, mock.Anything, mock.Anything).Return(imageBuildResponse, nil) + mockDockerBinder.On("ContainerCreate", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(containerCreateCreatedBody, nil) + mockDockerBinder.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(nil) + mockDockerBinder.On("ContainerWait", mock.Anything, mock.Anything, mock.Anything).Return(getContainerWaitResponse(false)) + mockDockerBinder.On("ContainerLogs", mock.Anything, mock.Anything, mock.Anything).Return(sampleLog, nil) + mockDockerBinder.On("ContainerRemove", mock.Anything, mock.Anything, mock.Anything).Return(errMock) + return mockDockerBinder, nil + } + ioCopy = func(dst io.Writer, src io.Reader) (written int64, err error) { + return 0, nil + } + err := CommonDockerUtil(testCommand, nil, nil, nil) + expectedErr := fmt.Errorf("docker remove failed %w", errMock) + assert.Equal(t, expectedErr, err) + mockDockerBinder.AssertExpectations(t) +} diff --git a/sql/mocks/flow.go b/sql/mocks/flow.go index c3a669538..2e32ae8f2 100644 --- a/sql/mocks/flow.go +++ b/sql/mocks/flow.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.14.0. DO NOT EDIT. +// Code generated by mockery v2.14.1. DO NOT EDIT. package mocks @@ -67,6 +67,20 @@ func (_m *DockerBind) ContainerLogs(ctx context.Context, _a1 string, options typ return r0, r1 } +// ContainerRemove provides a mock function with given fields: ctx, containerID, options +func (_m *DockerBind) ContainerRemove(ctx context.Context, containerID string, options types.ContainerRemoveOptions) error { + ret := _m.Called(ctx, containerID, options) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, types.ContainerRemoveOptions) error); ok { + r0 = rf(ctx, containerID, options) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // ContainerStart provides a mock function with given fields: ctx, containerID, options func (_m *DockerBind) ContainerStart(ctx context.Context, containerID string, options types.ContainerStartOptions) error { ret := _m.Called(ctx, containerID, options)