diff --git a/helper/helper_test.go b/helper/helper_test.go index 3b8309b..e267c61 100644 --- a/helper/helper_test.go +++ b/helper/helper_test.go @@ -5,6 +5,9 @@ import ( "testing" "github.com/incu6us/goimports-reviser/v3/reviser" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestDetermineProjectName(t *testing.T) { @@ -15,10 +18,9 @@ func TestDetermineProjectName(t *testing.T) { filePath string } tests := []struct { - name string - args args - want string - wantErr bool + name string + args args + want string }{ { name: "success with manual filepath", @@ -26,26 +28,19 @@ func TestDetermineProjectName(t *testing.T) { projectName: "", filePath: func() string { dir, err := os.Getwd() - if err != nil { - t.Fatal(err) - } - + require.NoError(t, err) return dir }(), }, - want: "github.com/incu6us/goimports-reviser/v3", - wantErr: false, + want: "github.com/incu6us/goimports-reviser/v3", }, { name: "success with stdin", args: args{ projectName: "", - filePath: func() string { - return reviser.StandardInput - }(), + filePath: reviser.StandardInput, }, - want: "github.com/incu6us/goimports-reviser/v3", - wantErr: false, + want: "github.com/incu6us/goimports-reviser/v3", }, } for _, tt := range tests { @@ -54,13 +49,8 @@ func TestDetermineProjectName(t *testing.T) { t.Parallel() got, err := DetermineProjectName(tt.args.projectName, tt.args.filePath) - if (err != nil) != tt.wantErr { - t.Errorf("DetermineProjectName() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("DetermineProjectName() got = %v, want %v", got, tt.want) - } + require.NoError(t, err) + assert.Equal(t, tt.want, got) }) } } diff --git a/pkg/module/error_test.go b/pkg/module/error_test.go index 693057b..74b7c03 100644 --- a/pkg/module/error_test.go +++ b/pkg/module/error_test.go @@ -1,6 +1,10 @@ package module -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/assert" +) func TestPathIsNotSetError_Error(t *testing.T) { t.Parallel() @@ -18,9 +22,8 @@ func TestPathIsNotSetError_Error(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { e := &PathIsNotSetError{} - if got := e.Error(); got != tt.want { - t.Errorf("Error() = %v, want %v", got, tt.want) - } + got := e.Error() + assert.Equal(t, tt.want, got) }) } } @@ -40,9 +43,8 @@ func TestUndefinedModuleError_Error(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { e := &UndefinedModuleError{} - if got := e.Error(); got != tt.want { - t.Errorf("Error() = %v, want %v", got, tt.want) - } + got := e.Error() + assert.Equal(t, tt.want, got) }) } } diff --git a/pkg/module/module_test.go b/pkg/module/module_test.go index c0e8775..afca75a 100644 --- a/pkg/module/module_test.go +++ b/pkg/module/module_test.go @@ -4,74 +4,46 @@ import ( "os" "path/filepath" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestGoModRootPathAndName(t *testing.T) { t.Parallel() - type args struct { - dir string - } - tests := []struct { - name string - args args - want string - wantErr bool - }{ - { - name: "success", - args: args{ - dir: func() string { - dir, err := os.Getwd() - if err != nil { - t.Fatal(err) - } - return dir - }(), - }, - want: "github.com/incu6us/goimports-reviser/v3", - wantErr: false, - }, - { - name: "path is not set error", - args: args{ - dir: "", - }, - want: "", - wantErr: true, - }, - { - name: "path with '.'", - args: args{ - dir: ".", - }, - want: "", - wantErr: true, - }, - } + t.Run("success", func(t *testing.T) { + t.Parallel() - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() + dir, err := os.Getwd() + require.NoError(t, err) - goModRootPath, err := GoModRootPath(tt.args.dir) - if err != nil && !tt.wantErr { - t.Errorf("GoModRootPath() error = %v, wantErr %v", err, tt.wantErr) - return - } + goModRootPath, err := GoModRootPath(dir) + require.NoError(t, err) - got, err := Name(goModRootPath) - if (err != nil) != tt.wantErr { - t.Errorf("Name() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if got != tt.want { - t.Errorf("Name() path = %v, want %v", got, tt.want) - } - }) - } + got, err := Name(goModRootPath) + require.NoError(t, err) + assert.Equal(t, "github.com/incu6us/goimports-reviser/v3", got) + }) + + t.Run("path is not set error", func(t *testing.T) { + t.Parallel() + + goModPath, err := GoModRootPath("") + assert.Error(t, err) + assert.Empty(t, goModPath) + }) + + t.Run("path is empty", func(t *testing.T) { + t.Parallel() + + goModPath, err := GoModRootPath(".") + assert.NoError(t, err) + + got, err := Name(goModPath) + assert.Error(t, err) + assert.Empty(t, got) + }) } func TestName(t *testing.T) { @@ -80,44 +52,28 @@ func TestName(t *testing.T) { tests := []struct { name string prepareFn func() string - want string - wantErr bool }{ { name: "read empty go.mod", prepareFn: func() string { dir := t.TempDir() f, err := os.Create(filepath.Join(dir, "go.mod")) - if err != nil { - t.Fatal(err) - } - if err := f.Close(); err != nil { - t.Fatal(err) - } + require.NoError(t, err) + require.NoError(t, f.Close()) return dir }, - want: "", - wantErr: true, }, { name: "check failed parsing of go.mod", prepareFn: func() string { dir := t.TempDir() file, err := os.Create(filepath.Join(dir, "go.mod")) - if err != nil { - t.Fatal(err) - } - - if _, err := file.WriteString("mod test"); err != nil { - t.Fatal(err) - } - if err := file.Close(); err != nil { - t.Fatal(err) - } + require.NoError(t, err) + _, err = file.WriteString("mod test") + require.NoError(t, err) + require.NoError(t, file.Close()) return dir }, - want: "", - wantErr: true, }, } for _, tt := range tests { @@ -127,13 +83,8 @@ func TestName(t *testing.T) { goModRootPath := tt.prepareFn() got, err := Name(goModRootPath) - if (err != nil) != tt.wantErr { - t.Errorf("Name() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("Name() got = %v, want %v", got, tt.want) - } + require.Error(t, err) + assert.Empty(t, got) }) } } @@ -146,10 +97,9 @@ func TestDetermineProjectName(t *testing.T) { filePath string } tests := []struct { - name string - args args - want string - wantErr bool + name string + args args + want string }{ { name: "success with auto determining", @@ -157,14 +107,11 @@ func TestDetermineProjectName(t *testing.T) { projectName: "", filePath: func() string { dir, err := os.Getwd() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) return filepath.Join(dir, "module.go") }(), }, - want: "github.com/incu6us/goimports-reviser/v3", - wantErr: false, + want: "github.com/incu6us/goimports-reviser/v3", }, { @@ -173,8 +120,7 @@ func TestDetermineProjectName(t *testing.T) { projectName: "github.com/incu6us/goimports-reviser/v3", filePath: "", }, - want: "github.com/incu6us/goimports-reviser/v3", - wantErr: false, + want: "github.com/incu6us/goimports-reviser/v3", }, } for _, tt := range tests { @@ -183,13 +129,8 @@ func TestDetermineProjectName(t *testing.T) { t.Parallel() got, err := DetermineProjectName(tt.args.projectName, tt.args.filePath) - if (err != nil) != tt.wantErr { - t.Errorf("DetermineProjectName() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("DetermineProjectName() got = %v, want %v", got, tt.want) - } + require.NoError(t, err) + assert.Equal(t, tt.want, got) }) } }