diff --git a/pkg/scale/result.go b/pkg/scale/result.go index 21f93f4c75..5fd30879a4 100644 --- a/pkg/scale/result.go +++ b/pkg/scale/result.go @@ -61,15 +61,21 @@ func NewResult(okIn, errIn interface{}) (res Result) { func (r *Result) Set(mode ResultMode, in interface{}) (err error) { switch mode { case OK: - if reflect.TypeOf(r.ok) != reflect.TypeOf(in) { + if reflect.TypeOf(r.ok) == reflect.TypeOf(empty{}) && in == nil { + r.mode = mode + return + } else if reflect.TypeOf(r.ok) != reflect.TypeOf(in) { err = fmt.Errorf("type mistmatch for result.ok: %T, and inputted: %T", r.ok, in) return } r.ok = in r.mode = mode case Err: - if reflect.TypeOf(r.err) != reflect.TypeOf(in) { - err = fmt.Errorf("type mistmatch for result.ok: %T, and inputted: %T", r.ok, in) + if reflect.TypeOf(r.err) == reflect.TypeOf(empty{}) && in == nil { + r.mode = mode + return + } else if reflect.TypeOf(r.err) != reflect.TypeOf(in) { + err = fmt.Errorf("type mistmatch for result.err: %T, and inputted: %T", r.ok, in) return } r.err = in diff --git a/pkg/scale/result_test.go b/pkg/scale/result_test.go index 4140a4254b..d5b5320974 100644 --- a/pkg/scale/result_test.go +++ b/pkg/scale/result_test.go @@ -216,29 +216,45 @@ func TestResult_Set(t *testing.T) { in interface{} } tests := []struct { - name string - res Result - args args - wantErr bool + name string + res Result + args args + wantErr bool + wantResult Result }{ - // TODO: Add test cases. { args: args{ mode: Unset, }, + res: NewResult(nil, nil), wantErr: true, + wantResult: Result{ + ok: empty{}, err: empty{}, + }, }, { args: args{ mode: OK, in: nil, }, + res: NewResult(nil, nil), + wantResult: Result{ + ok: empty{}, + err: empty{}, + mode: OK, + }, }, { args: args{ mode: Err, in: nil, }, + res: NewResult(nil, nil), + wantResult: Result{ + ok: empty{}, + err: empty{}, + mode: Err, + }, }, { args: args{ @@ -246,6 +262,11 @@ func TestResult_Set(t *testing.T) { in: true, }, res: NewResult(true, nil), + wantResult: Result{ + ok: true, + err: empty{}, + mode: OK, + }, }, { args: args{ @@ -253,6 +274,11 @@ func TestResult_Set(t *testing.T) { in: true, }, res: NewResult(nil, true), + wantResult: Result{ + ok: empty{}, + err: true, + mode: Err, + }, }, { args: args{ @@ -261,6 +287,10 @@ func TestResult_Set(t *testing.T) { }, res: NewResult("ok", "err"), wantErr: true, + wantResult: Result{ + ok: "ok", + err: "err", + }, }, { args: args{ @@ -269,6 +299,10 @@ func TestResult_Set(t *testing.T) { }, res: NewResult(nil, true), wantErr: true, + wantResult: Result{ + ok: empty{}, + err: true, + }, }, } for _, tt := range tests { @@ -277,6 +311,9 @@ func TestResult_Set(t *testing.T) { if err := r.Set(tt.args.mode, tt.args.in); (err != nil) != tt.wantErr { t.Errorf("Result.Set() error = %v, wantErr %v", err, tt.wantErr) } + if !reflect.DeepEqual(tt.wantResult, r) { + t.Errorf("Result.Unwrap() = %v, want %v", tt.wantResult, r) + } }) } }