From bcaac990bb35a0117c36f4dd58c52424b4a79b1b Mon Sep 17 00:00:00 2001 From: Hiroto Funakoshi Date: Mon, 6 Jun 2022 13:36:34 +0900 Subject: [PATCH] Fix race error of server package (#1689) * fix race error and deleted duplicate test function Signed-off-by: hlts2 * apply suggestion Signed-off-by: hlts2 * apply suggestion Signed-off-by: hlts2 * Apply suggestions from code review Co-authored-by: Kevin Diu * fix deepsource warning Signed-off-by: hlts2 Co-authored-by: Kevin Diu --- internal/servers/option_test.go | 3 +- internal/servers/server/server_test.go | 945 ++----------------------- internal/servers/servers_test.go | 307 ++------ 3 files changed, 139 insertions(+), 1116 deletions(-) diff --git a/internal/servers/option_test.go b/internal/servers/option_test.go index 8468ec7ad5..87c95cfb9c 100644 --- a/internal/servers/option_test.go +++ b/internal/servers/option_test.go @@ -16,6 +16,7 @@ package servers import ( + "context" "reflect" "testing" "time" @@ -97,7 +98,7 @@ func TestWithErrorGroup(t *testing.T) { tests := []test{ func() test { - eg := errgroup.Get() + eg, _ := errgroup.New(context.Background()) return test{ name: "set success", diff --git a/internal/servers/server/server_test.go b/internal/servers/server/server_test.go index e614788ae9..76ead14f03 100644 --- a/internal/servers/server/server_test.go +++ b/internal/servers/server/server_test.go @@ -20,8 +20,6 @@ import ( "crypto/tls" "net/http" "net/http/httptest" - "reflect" - "sync" "testing" "time" @@ -30,12 +28,10 @@ import ( "github.com/vdaas/vald/internal/log" "github.com/vdaas/vald/internal/log/logger" "github.com/vdaas/vald/internal/net" - "github.com/vdaas/vald/internal/net/control" "github.com/vdaas/vald/internal/net/grpc" - "github.com/vdaas/vald/internal/test/goleak" ) -func TestString(t *testing.T) { +func TestServerMode_String(t *testing.T) { type test struct { name string m ServerMode @@ -312,7 +308,7 @@ func TestNew(t *testing.T) { } } -func TestIsRunning(t *testing.T) { +func Test_server_IsRunning(t *testing.T) { type test struct { name string s *server @@ -347,7 +343,7 @@ func TestIsRunning(t *testing.T) { } } -func TestName(t *testing.T) { +func Test_server_Name(t *testing.T) { type test struct { name string s *server @@ -374,7 +370,7 @@ func TestName(t *testing.T) { } } -func TestListenAndServe(t *testing.T) { +func Test_server_ListenAndServe(t *testing.T) { type args struct { ctx context.Context errCh chan error @@ -403,14 +399,15 @@ func TestListenAndServe(t *testing.T) { } tests := []test{ - { - name: "returns nil when server is already running", - field: field{ - running: true, - }, - want: nil, - }, - + func() test { + return test{ + name: "returns nil when server is already running", + field: field{ + running: true, + }, + want: nil, + } + }(), func() test { err := errors.New("faild to prestart") @@ -425,8 +422,10 @@ func TestListenAndServe(t *testing.T) { want: err, } }(), - func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, ctx := errgroup.New(ctx) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) }) @@ -439,7 +438,7 @@ func TestListenAndServe(t *testing.T) { name: "returns nil when serving of REST server is successes", field: field{ mode: REST, - eg: errgroup.Get(), + eg: eg, httpSrvStarter: srv.Serve, host: "vald", port: 8081, @@ -449,18 +448,25 @@ func TestListenAndServe(t *testing.T) { }, running: false, }, + afterFunc: func() { + srv.Shutdown(ctx) + cancel() + eg.Wait() + }, want: nil, } }(), - func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, _ := errgroup.New(ctx) + srv := new(grpc.Server) return test{ name: "returns nil when serving of gRPC server is successes", field: field{ mode: GRPC, - eg: errgroup.Get(), + eg: eg, httpSrvStarter: srv.Serve, grpcSrv: srv, host: "vald", @@ -471,6 +477,10 @@ func TestListenAndServe(t *testing.T) { }, running: false, }, + afterFunc: func() { + cancel() + eg.Wait() + }, want: nil, } }(), @@ -518,7 +528,7 @@ func TestListenAndServe(t *testing.T) { } } -func TestShutdown(t *testing.T) { +func Test_server_Shutdown(t *testing.T) { type args struct { ctx context.Context } @@ -543,19 +553,32 @@ func TestShutdown(t *testing.T) { want error } - tests := []test{ - { - name: "returns nil when server is not running", - checkFunc: func(s *server, got, want error) error { - if want != got { - return errors.Errorf("Shutdown returns error: %v", got) - } - return nil - }, - want: nil, - }, + defaultCheckFunc := func(s *server, got, want error) error { + if want != got { + return errors.Errorf("Shutdown returns error: %v", got) + } + var running bool + s.mu.RLock() + running = s.running + s.mu.RUnlock() + + if running { + return errors.New("server is running") + } + return nil + } + tests := []test{ func() test { + return test{ + name: "returns nil when server is not running", + want: nil, + } + }(), + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, ctx := errgroup.New(ctx) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) }) @@ -564,11 +587,11 @@ func TestShutdown(t *testing.T) { return test{ name: "returns nil when shutdown of REST server is successes", args: args{ - ctx: context.Background(), + ctx: ctx, }, field: field{ mode: REST, - eg: errgroup.Get(), + eg: eg, pwt: 10 * time.Millisecond, sddur: 1 * time.Second, running: true, @@ -577,30 +600,28 @@ func TestShutdown(t *testing.T) { return nil }, }, - checkFunc: func(s *server, got, want error) error { - if want != got { - return errors.Errorf("Shutdown returns error: %v", got) - } - return nil - }, afterFunc: func() { testSrv.Close() + cancel() + eg.Wait() }, want: nil, } }(), - func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, ctx := errgroup.New(ctx) + grpcSrv := grpc.NewServer() return test{ name: "returns nil when shutdown of gRPC server is successes", args: args{ - ctx: context.Background(), + ctx: ctx, }, field: field{ mode: GRPC, - eg: errgroup.Get(), + eg: eg, pwt: 10 * time.Millisecond, sddur: 1 * time.Second, running: true, @@ -609,13 +630,9 @@ func TestShutdown(t *testing.T) { return nil }, }, - checkFunc: func(s *server, got, want error) error { - if want != got { - return errors.Errorf("Shutdown returns error: %v", got) - } - return nil - }, afterFunc: func() { + cancel() + eg.Wait() }, want: nil, } @@ -630,6 +647,9 @@ func TestShutdown(t *testing.T) { defer tt.afterFunc() } }() + if tt.checkFunc == nil { + tt.checkFunc = defaultCheckFunc + } s := &server{ mode: tt.field.mode, @@ -662,830 +682,3 @@ func TestShutdown(t *testing.T) { }) } } - -func TestServerMode_String(t *testing.T) { - type want struct { - want string - } - type test struct { - name string - m ServerMode - want want - checkFunc func(want, string) error - beforeFunc func() - afterFunc func() - } - defaultCheckFunc := func(w want, got string) error { - if !reflect.DeepEqual(got, w.want) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", got, w.want) - } - return nil - } - tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ - } - - for _, tc := range tests { - test := tc - t.Run(test.name, func(tt *testing.T) { - tt.Parallel() - defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) - if test.beforeFunc != nil { - test.beforeFunc() - } - if test.afterFunc != nil { - defer test.afterFunc() - } - checkFunc := test.checkFunc - if test.checkFunc == nil { - checkFunc = defaultCheckFunc - } - - got := test.m.String() - if err := checkFunc(test.want, got); err != nil { - tt.Errorf("error = %v", err) - } - }) - } -} - -func Test_server_IsRunning(t *testing.T) { - type fields struct { - mode ServerMode - name string - mu sync.RWMutex - wg sync.WaitGroup - eg errgroup.Group - http struct { - srv *http.Server - h http.Handler - starter func(net.Listener) error - } - grpc struct { - srv *grpc.Server - keepAlive *grpcKeepalive - opts []grpc.ServerOption - regs []func(*grpc.Server) - } - lc *net.ListenConfig - tcfg *tls.Config - pwt time.Duration - sddur time.Duration - rht time.Duration - rt time.Duration - wt time.Duration - it time.Duration - ctrl control.SocketController - sockFlg control.SocketFlag - network net.NetworkType - socketPath string - port uint16 - host string - enableRestart bool - shuttingDown bool - running bool - preStartFunc func() error - preStopFunc func() error - } - type want struct { - want bool - } - type test struct { - name string - fields fields - want want - checkFunc func(want, bool) error - beforeFunc func() - afterFunc func() - } - defaultCheckFunc := func(w want, got bool) error { - if !reflect.DeepEqual(got, w.want) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", got, w.want) - } - return nil - } - tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - fields: fields { - mode: nil, - name: "", - mu: sync.RWMutex{}, - wg: sync.WaitGroup{}, - eg: nil, - http: nil, - grpc: nil, - lc: nil, - tcfg: nil, - pwt: nil, - sddur: nil, - rht: nil, - rt: nil, - wt: nil, - it: nil, - ctrl: nil, - sockFlg: nil, - network: nil, - socketPath: "", - port: 0, - host: "", - enableRestart: false, - shuttingDown: false, - running: false, - preStartFunc: nil, - preStopFunc: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - fields: fields { - mode: nil, - name: "", - mu: sync.RWMutex{}, - wg: sync.WaitGroup{}, - eg: nil, - http: nil, - grpc: nil, - lc: nil, - tcfg: nil, - pwt: nil, - sddur: nil, - rht: nil, - rt: nil, - wt: nil, - it: nil, - ctrl: nil, - sockFlg: nil, - network: nil, - socketPath: "", - port: 0, - host: "", - enableRestart: false, - shuttingDown: false, - running: false, - preStartFunc: nil, - preStopFunc: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ - } - - for _, tc := range tests { - test := tc - t.Run(test.name, func(tt *testing.T) { - tt.Parallel() - defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) - if test.beforeFunc != nil { - test.beforeFunc() - } - if test.afterFunc != nil { - defer test.afterFunc() - } - checkFunc := test.checkFunc - if test.checkFunc == nil { - checkFunc = defaultCheckFunc - } - s := &server{ - mode: test.fields.mode, - name: test.fields.name, - mu: test.fields.mu, - wg: test.fields.wg, - eg: test.fields.eg, - http: test.fields.http, - grpc: test.fields.grpc, - lc: test.fields.lc, - tcfg: test.fields.tcfg, - pwt: test.fields.pwt, - sddur: test.fields.sddur, - rht: test.fields.rht, - rt: test.fields.rt, - wt: test.fields.wt, - it: test.fields.it, - ctrl: test.fields.ctrl, - sockFlg: test.fields.sockFlg, - network: test.fields.network, - socketPath: test.fields.socketPath, - port: test.fields.port, - host: test.fields.host, - enableRestart: test.fields.enableRestart, - shuttingDown: test.fields.shuttingDown, - running: test.fields.running, - preStartFunc: test.fields.preStartFunc, - preStopFunc: test.fields.preStopFunc, - } - - got := s.IsRunning() - if err := checkFunc(test.want, got); err != nil { - tt.Errorf("error = %v", err) - } - }) - } -} - -func Test_server_Name(t *testing.T) { - type fields struct { - mode ServerMode - name string - mu sync.RWMutex - wg sync.WaitGroup - eg errgroup.Group - http struct { - srv *http.Server - h http.Handler - starter func(net.Listener) error - } - grpc struct { - srv *grpc.Server - keepAlive *grpcKeepalive - opts []grpc.ServerOption - regs []func(*grpc.Server) - } - lc *net.ListenConfig - tcfg *tls.Config - pwt time.Duration - sddur time.Duration - rht time.Duration - rt time.Duration - wt time.Duration - it time.Duration - ctrl control.SocketController - sockFlg control.SocketFlag - network net.NetworkType - socketPath string - port uint16 - host string - enableRestart bool - shuttingDown bool - running bool - preStartFunc func() error - preStopFunc func() error - } - type want struct { - want string - } - type test struct { - name string - fields fields - want want - checkFunc func(want, string) error - beforeFunc func() - afterFunc func() - } - defaultCheckFunc := func(w want, got string) error { - if !reflect.DeepEqual(got, w.want) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", got, w.want) - } - return nil - } - tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - fields: fields { - mode: nil, - name: "", - mu: sync.RWMutex{}, - wg: sync.WaitGroup{}, - eg: nil, - http: nil, - grpc: nil, - lc: nil, - tcfg: nil, - pwt: nil, - sddur: nil, - rht: nil, - rt: nil, - wt: nil, - it: nil, - ctrl: nil, - sockFlg: nil, - network: nil, - socketPath: "", - port: 0, - host: "", - enableRestart: false, - shuttingDown: false, - running: false, - preStartFunc: nil, - preStopFunc: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - fields: fields { - mode: nil, - name: "", - mu: sync.RWMutex{}, - wg: sync.WaitGroup{}, - eg: nil, - http: nil, - grpc: nil, - lc: nil, - tcfg: nil, - pwt: nil, - sddur: nil, - rht: nil, - rt: nil, - wt: nil, - it: nil, - ctrl: nil, - sockFlg: nil, - network: nil, - socketPath: "", - port: 0, - host: "", - enableRestart: false, - shuttingDown: false, - running: false, - preStartFunc: nil, - preStopFunc: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ - } - - for _, tc := range tests { - test := tc - t.Run(test.name, func(tt *testing.T) { - tt.Parallel() - defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) - if test.beforeFunc != nil { - test.beforeFunc() - } - if test.afterFunc != nil { - defer test.afterFunc() - } - checkFunc := test.checkFunc - if test.checkFunc == nil { - checkFunc = defaultCheckFunc - } - s := &server{ - mode: test.fields.mode, - name: test.fields.name, - mu: test.fields.mu, - wg: test.fields.wg, - eg: test.fields.eg, - http: test.fields.http, - grpc: test.fields.grpc, - lc: test.fields.lc, - tcfg: test.fields.tcfg, - pwt: test.fields.pwt, - sddur: test.fields.sddur, - rht: test.fields.rht, - rt: test.fields.rt, - wt: test.fields.wt, - it: test.fields.it, - ctrl: test.fields.ctrl, - sockFlg: test.fields.sockFlg, - network: test.fields.network, - socketPath: test.fields.socketPath, - port: test.fields.port, - host: test.fields.host, - enableRestart: test.fields.enableRestart, - shuttingDown: test.fields.shuttingDown, - running: test.fields.running, - preStartFunc: test.fields.preStartFunc, - preStopFunc: test.fields.preStopFunc, - } - - got := s.Name() - if err := checkFunc(test.want, got); err != nil { - tt.Errorf("error = %v", err) - } - }) - } -} - -func Test_server_ListenAndServe(t *testing.T) { - type args struct { - ctx context.Context - ech chan<- error - } - type fields struct { - mode ServerMode - name string - mu sync.RWMutex - wg sync.WaitGroup - eg errgroup.Group - http struct { - srv *http.Server - h http.Handler - starter func(net.Listener) error - } - grpc struct { - srv *grpc.Server - keepAlive *grpcKeepalive - opts []grpc.ServerOption - regs []func(*grpc.Server) - } - lc *net.ListenConfig - tcfg *tls.Config - pwt time.Duration - sddur time.Duration - rht time.Duration - rt time.Duration - wt time.Duration - it time.Duration - ctrl control.SocketController - sockFlg control.SocketFlag - network net.NetworkType - socketPath string - port uint16 - host string - enableRestart bool - shuttingDown bool - running bool - preStartFunc func() error - preStopFunc func() error - } - type want struct { - err error - } - type test struct { - name string - args args - fields fields - want want - checkFunc func(want, error) error - beforeFunc func(args) - afterFunc func(args) - } - defaultCheckFunc := func(w want, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) - } - return nil - } - tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - ctx: nil, - ech: nil, - }, - fields: fields { - mode: nil, - name: "", - mu: sync.RWMutex{}, - wg: sync.WaitGroup{}, - eg: nil, - http: nil, - grpc: nil, - lc: nil, - tcfg: nil, - pwt: nil, - sddur: nil, - rht: nil, - rt: nil, - wt: nil, - it: nil, - ctrl: nil, - sockFlg: nil, - network: nil, - socketPath: "", - port: 0, - host: "", - enableRestart: false, - shuttingDown: false, - running: false, - preStartFunc: nil, - preStopFunc: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - ctx: nil, - ech: nil, - }, - fields: fields { - mode: nil, - name: "", - mu: sync.RWMutex{}, - wg: sync.WaitGroup{}, - eg: nil, - http: nil, - grpc: nil, - lc: nil, - tcfg: nil, - pwt: nil, - sddur: nil, - rht: nil, - rt: nil, - wt: nil, - it: nil, - ctrl: nil, - sockFlg: nil, - network: nil, - socketPath: "", - port: 0, - host: "", - enableRestart: false, - shuttingDown: false, - running: false, - preStartFunc: nil, - preStopFunc: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ - } - - for _, tc := range tests { - test := tc - t.Run(test.name, func(tt *testing.T) { - tt.Parallel() - defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) - if test.beforeFunc != nil { - test.beforeFunc(test.args) - } - if test.afterFunc != nil { - defer test.afterFunc(test.args) - } - checkFunc := test.checkFunc - if test.checkFunc == nil { - checkFunc = defaultCheckFunc - } - s := &server{ - mode: test.fields.mode, - name: test.fields.name, - mu: test.fields.mu, - wg: test.fields.wg, - eg: test.fields.eg, - http: test.fields.http, - grpc: test.fields.grpc, - lc: test.fields.lc, - tcfg: test.fields.tcfg, - pwt: test.fields.pwt, - sddur: test.fields.sddur, - rht: test.fields.rht, - rt: test.fields.rt, - wt: test.fields.wt, - it: test.fields.it, - ctrl: test.fields.ctrl, - sockFlg: test.fields.sockFlg, - network: test.fields.network, - socketPath: test.fields.socketPath, - port: test.fields.port, - host: test.fields.host, - enableRestart: test.fields.enableRestart, - shuttingDown: test.fields.shuttingDown, - running: test.fields.running, - preStartFunc: test.fields.preStartFunc, - preStopFunc: test.fields.preStopFunc, - } - - err := s.ListenAndServe(test.args.ctx, test.args.ech) - if err := checkFunc(test.want, err); err != nil { - tt.Errorf("error = %v", err) - } - }) - } -} - -func Test_server_Shutdown(t *testing.T) { - type args struct { - ctx context.Context - } - type fields struct { - mode ServerMode - name string - mu sync.RWMutex - wg sync.WaitGroup - eg errgroup.Group - http struct { - srv *http.Server - h http.Handler - starter func(net.Listener) error - } - grpc struct { - srv *grpc.Server - keepAlive *grpcKeepalive - opts []grpc.ServerOption - regs []func(*grpc.Server) - } - lc *net.ListenConfig - tcfg *tls.Config - pwt time.Duration - sddur time.Duration - rht time.Duration - rt time.Duration - wt time.Duration - it time.Duration - ctrl control.SocketController - sockFlg control.SocketFlag - network net.NetworkType - socketPath string - port uint16 - host string - enableRestart bool - shuttingDown bool - running bool - preStartFunc func() error - preStopFunc func() error - } - type want struct { - err error - } - type test struct { - name string - args args - fields fields - want want - checkFunc func(want, error) error - beforeFunc func(args) - afterFunc func(args) - } - defaultCheckFunc := func(w want, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) - } - return nil - } - tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - ctx: nil, - }, - fields: fields { - mode: nil, - name: "", - mu: sync.RWMutex{}, - wg: sync.WaitGroup{}, - eg: nil, - http: nil, - grpc: nil, - lc: nil, - tcfg: nil, - pwt: nil, - sddur: nil, - rht: nil, - rt: nil, - wt: nil, - it: nil, - ctrl: nil, - sockFlg: nil, - network: nil, - socketPath: "", - port: 0, - host: "", - enableRestart: false, - shuttingDown: false, - running: false, - preStartFunc: nil, - preStopFunc: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - ctx: nil, - }, - fields: fields { - mode: nil, - name: "", - mu: sync.RWMutex{}, - wg: sync.WaitGroup{}, - eg: nil, - http: nil, - grpc: nil, - lc: nil, - tcfg: nil, - pwt: nil, - sddur: nil, - rht: nil, - rt: nil, - wt: nil, - it: nil, - ctrl: nil, - sockFlg: nil, - network: nil, - socketPath: "", - port: 0, - host: "", - enableRestart: false, - shuttingDown: false, - running: false, - preStartFunc: nil, - preStopFunc: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ - } - - for _, tc := range tests { - test := tc - t.Run(test.name, func(tt *testing.T) { - tt.Parallel() - defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) - if test.beforeFunc != nil { - test.beforeFunc(test.args) - } - if test.afterFunc != nil { - defer test.afterFunc(test.args) - } - checkFunc := test.checkFunc - if test.checkFunc == nil { - checkFunc = defaultCheckFunc - } - s := &server{ - mode: test.fields.mode, - name: test.fields.name, - mu: test.fields.mu, - wg: test.fields.wg, - eg: test.fields.eg, - http: test.fields.http, - grpc: test.fields.grpc, - lc: test.fields.lc, - tcfg: test.fields.tcfg, - pwt: test.fields.pwt, - sddur: test.fields.sddur, - rht: test.fields.rht, - rt: test.fields.rt, - wt: test.fields.wt, - it: test.fields.it, - ctrl: test.fields.ctrl, - sockFlg: test.fields.sockFlg, - network: test.fields.network, - socketPath: test.fields.socketPath, - port: test.fields.port, - host: test.fields.host, - enableRestart: test.fields.enableRestart, - shuttingDown: test.fields.shuttingDown, - running: test.fields.running, - preStartFunc: test.fields.preStartFunc, - preStopFunc: test.fields.preStopFunc, - } - - err := s.Shutdown(test.args.ctx) - if err := checkFunc(test.want, err); err != nil { - tt.Errorf("error = %v", err) - } - }) - } -} diff --git a/internal/servers/servers_test.go b/internal/servers/servers_test.go index 1f5d78aedf..47a4cd413d 100644 --- a/internal/servers/servers_test.go +++ b/internal/servers/servers_test.go @@ -24,7 +24,6 @@ import ( "github.com/vdaas/vald/internal/errgroup" "github.com/vdaas/vald/internal/errors" "github.com/vdaas/vald/internal/servers/server" - "github.com/vdaas/vald/internal/test/goleak" ) func TestNew(t *testing.T) { @@ -90,7 +89,7 @@ func TestNew(t *testing.T) { } } -func TestListenAndServe(t *testing.T) { +func Test_listener_ListenAndServe(t *testing.T) { type args struct { ctx context.Context } @@ -106,11 +105,15 @@ func TestListenAndServe(t *testing.T) { args args field field checkFunc func(got, want <-chan error) error + afterFunc func() want <-chan error } tests := []test{ func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, ctx := errgroup.New(ctx) + srv1 := &mockServer{ IsRunningFunc: func() bool { return false @@ -151,13 +154,13 @@ func TestListenAndServe(t *testing.T) { name: "ListenAndServe is success", args: args{ ctx: func() context.Context { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(ctx) defer cancel() return ctx }(), }, field: field{ - eg: errgroup.Get(), + eg: eg, servers: servers, sus: sus, }, @@ -184,6 +187,10 @@ func TestListenAndServe(t *testing.T) { } return nil }, + afterFunc: func() { + cancel() + eg.Wait() + }, } }(), } @@ -192,6 +199,9 @@ func TestListenAndServe(t *testing.T) { t.Run(tt.name, func(t *testing.T) { ctx, cancel := context.WithCancel(tt.args.ctx) defer cancel() + if tt.afterFunc != nil { + defer tt.afterFunc() + } l := &listener{ eg: tt.field.eg, @@ -207,7 +217,7 @@ func TestListenAndServe(t *testing.T) { } } -func TestShutdown(t *testing.T) { +func Test_listener_Shutdown(t *testing.T) { type args struct { ctx context.Context } @@ -224,11 +234,22 @@ func TestShutdown(t *testing.T) { args args field field checkFunc func(got, want error) error + afterFunc func() want error } + defaultCheckFunc := func(got, want error) error { + if !errors.Is(want, got) { + return errors.Errorf("not equals. want: %v, got: %v", want, got) + } + return nil + } + tests := []test{ func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, ctx := errgroup.New(ctx) + srv1 := &mockServer{ IsRunningFunc: func() bool { return true @@ -260,44 +281,47 @@ func TestShutdown(t *testing.T) { return test{ name: "Shutdown is success", args: args{ - ctx: context.Background(), + ctx: ctx, }, field: field{ - eg: errgroup.Get(), + eg: eg, servers: servers, sds: sds, }, - - checkFunc: func(got, want error) error { - if got != nil { - return errors.Errorf("return error: %v", got) - } - return nil + afterFunc: func() { + cancel() + eg.Wait() }, want: nil, } }(), - { - name: "server not found error", - args: args{ - ctx: context.Background(), - }, - field: field{ - eg: errgroup.Get(), - servers: map[string]server.Server{}, - sds: []string{ - "srv1", + func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, ctx := errgroup.New(ctx) + + return test{ + name: "server not found error", + args: args{ + ctx: ctx, }, - }, - checkFunc: func(got, want error) error { - if !errors.Is(want, got) { - return errors.Errorf("not equals. want: %v, got: %v", want, got) - } - return nil - }, - want: errors.ErrServerNotFound("srv1"), - }, + field: field{ + eg: eg, + servers: map[string]server.Server{}, + sds: []string{ + "srv1", + }, + }, + afterFunc: func() { + cancel() + eg.Wait() + }, + want: errors.ErrServerNotFound("srv1"), + } + }(), func() test { + ctx, cancel := context.WithCancel(context.Background()) + eg, ctx := errgroup.New(ctx) + want := errors.Wrap(errors.Errorf("unexpected error"), "faild to shutdown") srv1 := &mockServer{ @@ -320,18 +344,16 @@ func TestShutdown(t *testing.T) { return test{ name: "unexpected error", args: args{ - ctx: context.Background(), + ctx: ctx, }, field: field{ - eg: errgroup.Get(), + eg: eg, servers: servers, sds: sds, }, - checkFunc: func(got, want error) error { - if got.Error() != want.Error() { - return errors.Errorf("not equals. want: %v, got: %v", want, got) - } - return nil + afterFunc: func() { + cancel() + eg.Wait() }, want: want, } @@ -342,6 +364,13 @@ func TestShutdown(t *testing.T) { t.Run(tt.name, func(t *testing.T) { ctx, cancel := context.WithCancel(tt.args.ctx) defer cancel() + if tt.afterFunc != nil { + defer tt.afterFunc() + } + + if tt.checkFunc == nil { + tt.checkFunc = defaultCheckFunc + } l := &listener{ eg: tt.field.eg, @@ -357,203 +386,3 @@ func TestShutdown(t *testing.T) { }) } } - -func Test_listener_ListenAndServe(t *testing.T) { - type args struct { - ctx context.Context - } - type fields struct { - servers map[string]server.Server - eg errgroup.Group - sus []string - sds []string - sddur time.Duration - } - type want struct { - want <-chan error - } - type test struct { - name string - args args - fields fields - want want - checkFunc func(want, <-chan error) error - beforeFunc func(args) - afterFunc func(args) - } - defaultCheckFunc := func(w want, got <-chan error) error { - if !reflect.DeepEqual(got, w.want) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", got, w.want) - } - return nil - } - tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - ctx: nil, - }, - fields: fields { - servers: nil, - eg: nil, - sus: nil, - sds: nil, - sddur: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - ctx: nil, - }, - fields: fields { - servers: nil, - eg: nil, - sus: nil, - sds: nil, - sddur: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ - } - - for _, test := range tests { - t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(t) - if test.beforeFunc != nil { - test.beforeFunc(test.args) - } - if test.afterFunc != nil { - defer test.afterFunc(test.args) - } - checkFunc := test.checkFunc - if test.checkFunc == nil { - checkFunc = defaultCheckFunc - } - l := &listener{ - servers: test.fields.servers, - eg: test.fields.eg, - sus: test.fields.sus, - sds: test.fields.sds, - sddur: test.fields.sddur, - } - - got := l.ListenAndServe(test.args.ctx) - if err := checkFunc(test.want, got); err != nil { - tt.Errorf("error = %v", err) - } - }) - } -} - -func Test_listener_Shutdown(t *testing.T) { - type args struct { - ctx context.Context - } - type fields struct { - servers map[string]server.Server - eg errgroup.Group - sus []string - sds []string - sddur time.Duration - } - type want struct { - err error - } - type test struct { - name string - args args - fields fields - want want - checkFunc func(want, error) error - beforeFunc func(args) - afterFunc func(args) - } - defaultCheckFunc := func(w want, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) - } - return nil - } - tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - ctx: nil, - }, - fields: fields { - servers: nil, - eg: nil, - sus: nil, - sds: nil, - sddur: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - ctx: nil, - }, - fields: fields { - servers: nil, - eg: nil, - sus: nil, - sds: nil, - sddur: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ - } - - for _, test := range tests { - t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(t) - if test.beforeFunc != nil { - test.beforeFunc(test.args) - } - if test.afterFunc != nil { - defer test.afterFunc(test.args) - } - checkFunc := test.checkFunc - if test.checkFunc == nil { - checkFunc = defaultCheckFunc - } - l := &listener{ - servers: test.fields.servers, - eg: test.fields.eg, - sus: test.fields.sus, - sds: test.fields.sds, - sddur: test.fields.sddur, - } - - err := l.Shutdown(test.args.ctx) - if err := checkFunc(test.want, err); err != nil { - tt.Errorf("error = %v", err) - } - }) - } -}