diff --git a/internal/directives/copy_directive.go b/internal/directives/copy_directive.go index 9bfd545e2..fc4d73453 100644 --- a/internal/directives/copy_directive.go +++ b/internal/directives/copy_directive.go @@ -46,21 +46,18 @@ func (d *copyDirective) Run(ctx context.Context, stepCtx *StepContext) (Result, return failure, fmt.Errorf("could not convert config into %s config: %w", d.Name(), err) } - if err = d.run(ctx, stepCtx, cfg); err != nil { - return failure, err - } - return Result{Status: StatusSuccess}, nil + return d.run(ctx, stepCtx, cfg) } -func (d *copyDirective) run(ctx context.Context, stepCtx *StepContext, cfg CopyConfig) error { +func (d *copyDirective) run(ctx context.Context, stepCtx *StepContext, cfg CopyConfig) (Result, error) { // Secure join the paths to prevent path traversal attacks. inPath, err := securejoin.SecureJoin(stepCtx.WorkDir, cfg.InPath) if err != nil { - return fmt.Errorf("could not secure join inPath %q: %w", cfg.InPath, err) + return Result{Status: StatusFailure}, fmt.Errorf("could not secure join inPath %q: %w", cfg.InPath, err) } outPath, err := securejoin.SecureJoin(stepCtx.WorkDir, cfg.OutPath) if err != nil { - return fmt.Errorf("could not secure join outPath %q: %w", cfg.OutPath, err) + return Result{Status: StatusFailure}, fmt.Errorf("could not secure join outPath %q: %w", cfg.OutPath, err) } // Perform the copy operation. @@ -74,9 +71,9 @@ func (d *copyDirective) run(ctx context.Context, stepCtx *StepContext, cfg CopyC }, } if err = copy.Copy(inPath, outPath, opts); err != nil { - return fmt.Errorf("failed to copy %q to %q: %w", cfg.InPath, cfg.OutPath, err) + return Result{Status: StatusFailure}, fmt.Errorf("failed to copy %q to %q: %w", cfg.InPath, cfg.OutPath, err) } - return nil + return Result{Status: StatusSuccess}, nil } // sanitizePathError sanitizes the path in a path error to be relative to the diff --git a/internal/directives/copy_directive_test.go b/internal/directives/copy_directive_test.go index 00fe68b7c..4ac24fa5e 100644 --- a/internal/directives/copy_directive_test.go +++ b/internal/directives/copy_directive_test.go @@ -16,7 +16,7 @@ func Test_copyDirective_run(t *testing.T) { name string setupFiles func(*testing.T) string cfg CopyConfig - assertions func(*testing.T, string, error) + assertions func(*testing.T, string, Result, error) }{ { name: "succeeds copying file", @@ -32,8 +32,9 @@ func Test_copyDirective_run(t *testing.T) { InPath: "input.txt", OutPath: "output.txt", }, - assertions: func(t *testing.T, workDir string, err error) { + assertions: func(t *testing.T, workDir string, result Result, err error) { assert.NoError(t, err) + assert.Equal(t, Result{Status: StatusSuccess}, result) outPath := filepath.Join(workDir, "output.txt") b, err := os.ReadFile(outPath) @@ -63,20 +64,21 @@ func Test_copyDirective_run(t *testing.T) { InPath: "input/", OutPath: "output/", }, - assertions: func(t *testing.T, workDir string, err error) { + assertions: func(t *testing.T, workDir string, result Result, err error) { assert.NoError(t, err) + assert.Equal(t, Result{Status: StatusSuccess}, result) outDir := filepath.Join(workDir, "output") outPath := filepath.Join(outDir, "input.txt") b, err := os.ReadFile(outPath) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "test content", string(b)) nestedDir := filepath.Join(outDir, "nested") nestedPath := filepath.Join(nestedDir, "nested.txt") b, err = os.ReadFile(nestedPath) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "nested content", string(b)) }, }, @@ -100,8 +102,9 @@ func Test_copyDirective_run(t *testing.T) { InPath: "input/", OutPath: "output/", }, - assertions: func(t *testing.T, workDir string, err error) { + assertions: func(t *testing.T, workDir string, result Result, err error) { assert.NoError(t, err) + require.Equal(t, Result{Status: StatusSuccess}, result) outDir := filepath.Join(workDir, "output") @@ -124,7 +127,8 @@ func Test_copyDirective_run(t *testing.T) { cfg: CopyConfig{ InPath: "input.txt", }, - assertions: func(t *testing.T, _ string, err error) { + assertions: func(t *testing.T, _ string, result Result, err error) { + require.Equal(t, Result{Status: StatusFailure}, result) require.ErrorContains(t, err, "failed to copy") }, }, @@ -135,11 +139,8 @@ func Test_copyDirective_run(t *testing.T) { workDir := tt.setupFiles(t) d := ©Directive{} - tt.assertions( - t, - workDir, - d.run(context.Background(), &StepContext{WorkDir: workDir}, tt.cfg), - ) + result, err := d.run(context.Background(), &StepContext{WorkDir: workDir}, tt.cfg) + tt.assertions(t, workDir, result, err) }) } }