diff --git a/hack/benchmark/assets/x1b/loader.go b/hack/benchmark/assets/x1b/loader.go index ef9318ad53..0a5e9450f9 100644 --- a/hack/benchmark/assets/x1b/loader.go +++ b/hack/benchmark/assets/x1b/loader.go @@ -72,7 +72,7 @@ type ivecs struct { *file } -func open(fname string, elementSize int) (f *file, err error) { +func doOpen(fname string, elementSize int) (f *file, err error) { fp, err := os.Open(fname) if err != nil { return nil, err @@ -164,7 +164,7 @@ func (iv *ivecs) Load(i int) (interface{}, error) { } func NewUint8Vectors(fname string) (Uint8Vectors, error) { - f, err := open(fname, 1) + f, err := doOpen(fname, 1) if err != nil { return nil, err } @@ -172,7 +172,7 @@ func NewUint8Vectors(fname string) (Uint8Vectors, error) { } func NewFloatVectors(fname string) (FloatVectors, error) { - f, err := open(fname, 4) + f, err := doOpen(fname, 4) if err != nil { return nil, err } @@ -180,7 +180,7 @@ func NewFloatVectors(fname string) (FloatVectors, error) { } func NewInt32Vectors(fname string) (Int32Vectors, error) { - f, err := open(fname, 4) + f, err := doOpen(fname, 4) if err != nil { return nil, err } diff --git a/hack/benchmark/assets/x1b/loader_test.go b/hack/benchmark/assets/x1b/loader_test.go index 13ab633d07..f6319a95f4 100644 --- a/hack/benchmark/assets/x1b/loader_test.go +++ b/hack/benchmark/assets/x1b/loader_test.go @@ -21,7 +21,7 @@ import ( "github.com/vdaas/vald/internal/test/goleak" ) -func Test_open(t *testing.T) { +func Test_doOpen(t *testing.T) { t.Parallel() type args struct { fname string @@ -94,7 +94,7 @@ func Test_open(t *testing.T) { checkFunc = defaultCheckFunc } - gotF, err := open(test.args.fname, test.args.elementSize) + gotF, err := doOpen(test.args.fname, test.args.elementSize) if err := checkFunc(test.want, gotF, err); err != nil { tt.Errorf("error = %v", err) } diff --git a/internal/db/rdb/mysql/mysql.go b/internal/db/rdb/mysql/mysql.go index ef4a518faf..3771546f0d 100644 --- a/internal/db/rdb/mysql/mysql.go +++ b/internal/db/rdb/mysql/mysql.go @@ -391,7 +391,8 @@ func (m *mySQLClient) SetVectors(ctx context.Context, vecs ...Vector) error { return tx.Commit() } -func (m *mySQLClient) deleteVector(ctx context.Context, val string) error { +// DeleteVector deletes vector data from backup_vector table and podIPs from pod_ip table using vector's uuid. +func (m *mySQLClient) DeleteVector(ctx context.Context, val string) error { if !m.connected.Load().(bool) { return errors.ErrMySQLConnectionClosed } @@ -432,15 +433,10 @@ func (m *mySQLClient) deleteVector(ctx context.Context, val string) error { return tx.Commit() } -// DeleteVector deletes vector data from backup_vector table and podIPs from pod_ip table using vector's uuid. -func (m *mySQLClient) DeleteVector(ctx context.Context, uuid string) error { - return m.deleteVector(ctx, uuid) -} - // DeleteVectors is the same as DeleteVector() but it deletes multiple records. func (m *mySQLClient) DeleteVectors(ctx context.Context, uuids ...string) (err error) { for _, uuid := range uuids { - err = m.deleteVector(ctx, uuid) + err = m.DeleteVector(ctx, uuid) if err != nil { return err } diff --git a/internal/db/rdb/mysql/mysql_test.go b/internal/db/rdb/mysql/mysql_test.go index 412e8af00b..761e14bd0e 100644 --- a/internal/db/rdb/mysql/mysql_test.go +++ b/internal/db/rdb/mysql/mysql_test.go @@ -2793,10 +2793,10 @@ func Test_mySQLClient_SetVectors(t *testing.T) { } } -func Test_mySQLClient_deleteVector(t *testing.T) { +func Test_mySQLClient_DeleteVector(t *testing.T) { type args struct { - ctx context.Context - val string + ctx context.Context + uuid string } type fields struct { session dbr.Session @@ -2822,13 +2822,141 @@ func Test_mySQLClient_deleteVector(t *testing.T) { return nil } tests := []test{ + func() test { + return test{ + name: "return nil when deleteVector success with empty-uuid", + args: args{ + ctx: context.Background(), + uuid: "", + }, + fields: fields{ + session: &dbr.MockSession{ + BeginFunc: func() (dbr.Tx, error) { + return &dbr.MockTx{ + CommitFunc: func() error { + return nil + }, + RollbackUnlessCommittedFunc: func() {}, + SelectFunc: func(column ...string) dbr.SelectStmt { + s := new(dbr.MockSelect) + s.FromFunc = func(table interface{}) dbr.SelectStmt { + return s + } + s.WhereFunc = func(query interface{}, value ...interface{}) dbr.SelectStmt { + return s + } + s.LimitFunc = func(n uint64) dbr.SelectStmt { + return s + } + s.LoadContextFunc = func(ctx context.Context, value interface{}) (int, error) { + var id int64 + if reflect.TypeOf(value) == reflect.TypeOf(&id) { + id := int64(1) + reflect.ValueOf(value).Elem().Set(reflect.ValueOf(id)) + return 1, nil + } + return 0, nil + } + + return s + }, + DeleteFromFunc: func(table string) dbr.DeleteStmt { + s := new(dbr.MockDelete) + s.ExecContextFunc = func(ctx context.Context) (sql.Result, error) { + return nil, nil + } + s.WhereFunc = func(query interface{}, value ...interface{}) dbr.DeleteStmt { + return s + } + return s + }, + }, nil + }, + }, + connected: func() (v atomic.Value) { + v.Store(true) + return + }(), + dbr: &dbr.MockDBR{ + EqFunc: func(col string, val interface{}) dbr.Builder { + return dbr.New().Eq(col, val) + }, + }, + }, + want: want{}, + } + }(), + func() test { + return test{ + name: "return nil when deleteVector success with uuid", + args: args{ + ctx: context.Background(), + uuid: "vald-01", + }, + fields: fields{ + session: &dbr.MockSession{ + BeginFunc: func() (dbr.Tx, error) { + return &dbr.MockTx{ + CommitFunc: func() error { + return nil + }, + RollbackUnlessCommittedFunc: func() {}, + SelectFunc: func(column ...string) dbr.SelectStmt { + s := new(dbr.MockSelect) + s.FromFunc = func(table interface{}) dbr.SelectStmt { + return s + } + s.WhereFunc = func(query interface{}, value ...interface{}) dbr.SelectStmt { + return s + } + s.LimitFunc = func(n uint64) dbr.SelectStmt { + return s + } + s.LoadContextFunc = func(ctx context.Context, value interface{}) (int, error) { + var id int64 + if reflect.TypeOf(value) == reflect.TypeOf(&id) { + id := int64(1) + reflect.ValueOf(value).Elem().Set(reflect.ValueOf(id)) + return 1, nil + } + return 0, nil + } + + return s + }, + DeleteFromFunc: func(table string) dbr.DeleteStmt { + s := new(dbr.MockDelete) + s.ExecContextFunc = func(ctx context.Context) (sql.Result, error) { + return nil, nil + } + s.WhereFunc = func(query interface{}, value ...interface{}) dbr.DeleteStmt { + return s + } + return s + }, + }, nil + }, + }, + connected: func() (v atomic.Value) { + v.Store(true) + return + }(), + dbr: &dbr.MockDBR{ + EqFunc: func(col string, val interface{}) dbr.Builder { + return dbr.New().Eq(col, val) + }, + }, + }, + want: want{}, + } + }(), func() test { err := errors.ErrMySQLConnectionClosed return test{ name: "return error when MySQL connection is closed", args: args{ - ctx: context.Background(), - val: "vald-01", + ctx: context.Background(), + uuid: "vald-01", }, fields: fields{ connected: func() (v atomic.Value) { @@ -2845,8 +2973,8 @@ func Test_mySQLClient_deleteVector(t *testing.T) { return test{ name: "return error when MySQL session is nil", args: args{ - ctx: context.Background(), - val: "vald-01", + ctx: context.Background(), + uuid: "vald-01", }, fields: fields{ connected: func() (v atomic.Value) { @@ -2864,8 +2992,8 @@ func Test_mySQLClient_deleteVector(t *testing.T) { return test{ name: "return error when session.Begin returns error", args: args{ - ctx: context.Background(), - val: "vald-01", + ctx: context.Background(), + uuid: "vald-01", }, fields: fields{ session: &dbr.MockSession{ @@ -2888,8 +3016,8 @@ func Test_mySQLClient_deleteVector(t *testing.T) { return test{ name: "return error when transacton is nil", args: args{ - ctx: context.Background(), - val: "vald-01", + ctx: context.Background(), + uuid: "vald-01", }, fields: fields{ session: &dbr.MockSession{ @@ -2912,8 +3040,8 @@ func Test_mySQLClient_deleteVector(t *testing.T) { return test{ name: "return error when Select(idColumnName) returns error", args: args{ - ctx: context.Background(), - val: "vald-01", + ctx: context.Background(), + uuid: "vald-01", }, fields: fields{ session: &dbr.MockSession{ @@ -2964,8 +3092,8 @@ func Test_mySQLClient_deleteVector(t *testing.T) { return test{ name: "return error when returned id = 0 from Select statement", args: args{ - ctx: context.Background(), - val: uuid, + ctx: context.Background(), + uuid: uuid, }, fields: fields{ session: &dbr.MockSession{ @@ -3015,8 +3143,8 @@ func Test_mySQLClient_deleteVector(t *testing.T) { return test{ name: "return error when DeleteFromFunc(vectorTableName) returns error", args: args{ - ctx: context.Background(), - val: "vald-01", + ctx: context.Background(), + uuid: "vald-01", }, fields: fields{ session: &dbr.MockSession{ @@ -3085,8 +3213,8 @@ func Test_mySQLClient_deleteVector(t *testing.T) { return test{ name: "return error when DeleteFromFunc(podIPTableNmae) returns error", args: args{ - ctx: context.Background(), - val: "vald-01", + ctx: context.Background(), + uuid: "vald-01", }, fields: fields{ session: &dbr.MockSession{ @@ -3153,193 +3281,6 @@ func Test_mySQLClient_deleteVector(t *testing.T) { func() test { return test{ name: "return nil when no error occurs", - args: args{ - ctx: context.Background(), - val: "vald-01", - }, - fields: fields{ - session: &dbr.MockSession{ - BeginFunc: func() (dbr.Tx, error) { - return &dbr.MockTx{ - CommitFunc: func() error { - return nil - }, - RollbackUnlessCommittedFunc: func() {}, - SelectFunc: func(column ...string) dbr.SelectStmt { - s := new(dbr.MockSelect) - s.FromFunc = func(table interface{}) dbr.SelectStmt { - return s - } - s.WhereFunc = func(query interface{}, value ...interface{}) dbr.SelectStmt { - return s - } - s.LimitFunc = func(n uint64) dbr.SelectStmt { - return s - } - s.LoadContextFunc = func(ctx context.Context, value interface{}) (int, error) { - var id int64 - if reflect.TypeOf(value) == reflect.TypeOf(&id) { - id := int64(1) - reflect.ValueOf(value).Elem().Set(reflect.ValueOf(id)) - return 1, nil - } - return 0, nil - } - - return s - }, - DeleteFromFunc: func(table string) dbr.DeleteStmt { - s := new(dbr.MockDelete) - s.ExecContextFunc = func(ctx context.Context) (sql.Result, error) { - return nil, nil - } - s.WhereFunc = func(query interface{}, value ...interface{}) dbr.DeleteStmt { - return s - } - return s - }, - }, nil - }, - }, - connected: func() (v atomic.Value) { - v.Store(true) - return - }(), - dbr: &dbr.MockDBR{ - EqFunc: func(col string, val interface{}) dbr.Builder { - return dbr.New().Eq(col, val) - }, - }, - }, - want: want{}, - } - }(), - } - - for _, tc := range tests { - test := tc - t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(tt, goleakIgnoreOptions...) - if test.beforeFunc != nil { - test.beforeFunc(test.args) - } - if test.afterFunc != nil { - defer test.afterFunc(test.args) - } - checkFunc := test.checkFunc - if test.checkFunc == nil { - checkFunc = defaultCheckFunc - } - m := &mySQLClient{ - session: test.fields.session, - connected: test.fields.connected, - dbr: test.fields.dbr, - } - - err := m.deleteVector(test.args.ctx, test.args.val) - if err := checkFunc(test.want, err); err != nil { - tt.Errorf("error = %v", err) - } - }) - } -} - -func Test_mySQLClient_DeleteVector(t *testing.T) { - type args struct { - ctx context.Context - uuid string - } - type fields struct { - session dbr.Session - connected atomic.Value - dbr dbr.DBR - } - type want struct { - err error - } - type test struct { - name string - args args - fields fields - want want - checkFunc func(want, error) error - beforeFunc func(args) - afterFunc func(args) - } - defaultCheckFunc := func(w want, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) - } - return nil - } - tests := []test{ - func() test { - return test{ - name: "return nil when deleteVector success with empty-uuid", - args: args{ - ctx: context.Background(), - uuid: "", - }, - fields: fields{ - session: &dbr.MockSession{ - BeginFunc: func() (dbr.Tx, error) { - return &dbr.MockTx{ - CommitFunc: func() error { - return nil - }, - RollbackUnlessCommittedFunc: func() {}, - SelectFunc: func(column ...string) dbr.SelectStmt { - s := new(dbr.MockSelect) - s.FromFunc = func(table interface{}) dbr.SelectStmt { - return s - } - s.WhereFunc = func(query interface{}, value ...interface{}) dbr.SelectStmt { - return s - } - s.LimitFunc = func(n uint64) dbr.SelectStmt { - return s - } - s.LoadContextFunc = func(ctx context.Context, value interface{}) (int, error) { - var id int64 - if reflect.TypeOf(value) == reflect.TypeOf(&id) { - id := int64(1) - reflect.ValueOf(value).Elem().Set(reflect.ValueOf(id)) - return 1, nil - } - return 0, nil - } - - return s - }, - DeleteFromFunc: func(table string) dbr.DeleteStmt { - s := new(dbr.MockDelete) - s.ExecContextFunc = func(ctx context.Context) (sql.Result, error) { - return nil, nil - } - s.WhereFunc = func(query interface{}, value ...interface{}) dbr.DeleteStmt { - return s - } - return s - }, - }, nil - }, - }, - connected: func() (v atomic.Value) { - v.Store(true) - return - }(), - dbr: &dbr.MockDBR{ - EqFunc: func(col string, val interface{}) dbr.Builder { - return dbr.New().Eq(col, val) - }, - }, - }, - want: want{}, - } - }(), - func() test { - return test{ - name: "return nil when deleteVector success with uuid", args: args{ ctx: context.Background(), uuid: "vald-01", diff --git a/internal/file/file.go b/internal/file/file.go index 0445e2257a..3afff9455b 100644 --- a/internal/file/file.go +++ b/internal/file/file.go @@ -89,14 +89,14 @@ func Open(path string, flg int, perm fs.FileMode) (file *os.File, err error) { } func MoveDir(ctx context.Context, src, dst string) (err error) { - return moveDir(ctx, src, dst, true) + return doMoveDir(ctx, src, dst, true) } -func moveDir(ctx context.Context, src, dst string, rollback bool) (err error) { +func doMoveDir(ctx context.Context, src, dst string, rollback bool) (err error) { if len(src) == 0 || len(dst) == 0 || src == dst { return nil } - exits, fi, err := exists(src) + exits, fi, err := doExists(src) if !exits || fi == nil || !fi.IsDir() || err != nil { return errors.ErrDirectoryNotFound(err, src, fi) } @@ -105,7 +105,7 @@ func moveDir(ctx context.Context, src, dst string, rollback bool) (err error) { if err != nil { log.Debug(errors.ErrFailedToRenameDir(err, src, dst, nil, nil)) var tmpPath string - exits, fi, err := exists(dst) + exits, fi, err := doExists(dst) if exits && fi.IsDir() && err == nil { tmpPath = Join(filepath.Dir(dst), "tmp-"+strconv.FormatInt(fastime.UnixNanoNow(), 10)) _ = os.RemoveAll(tmpPath) @@ -123,7 +123,7 @@ func moveDir(ctx context.Context, src, dst string, rollback bool) (err error) { if err != nil && Exists(dst) { err = errors.ErrFailedToRemoveDir(err, dst, nil) if rollback { - err = errors.Wrap(moveDir(ctx, tmpPath, dst, false), errors.Wrapf(err, "trying to recover temporary file %s to rollback previous operation", tmpPath).Error()) + err = errors.Wrap(doMoveDir(ctx, tmpPath, dst, false), errors.Wrapf(err, "trying to recover temporary file %s to rollback previous operation", tmpPath).Error()) } log.Warn(err) return err @@ -131,7 +131,7 @@ func moveDir(ctx context.Context, src, dst string, rollback bool) (err error) { } log.Debugf("directory %s successfully moved to tmp location %s", dst, tmpPath) } - exits, fi, err = exists(src) + exits, fi, err = doExists(src) if exits && fi != nil && fi.IsDir() && err == nil { err = os.Rename(src, dst) if err != nil { @@ -140,7 +140,7 @@ func moveDir(ctx context.Context, src, dst string, rollback bool) (err error) { if err != nil { err = errors.ErrFailedToCopyDir(err, src, dst, fi, nil) if rollback { - err = errors.Wrap(moveDir(ctx, tmpPath, dst, false), errors.Wrapf(err, "trying to recover temporary file %s to rollback previous operation", tmpPath).Error()) + err = errors.Wrap(doMoveDir(ctx, tmpPath, dst, false), errors.Wrapf(err, "trying to recover temporary file %s to rollback previous operation", tmpPath).Error()) } log.Warn(err) return err @@ -223,7 +223,7 @@ func CopyFileWithPerm(ctx context.Context, src, dst string, perm fs.FileMode) (n } }() - exist, fi, err := exists(src) + exist, fi, err := doExists(src) switch { case !exist, fi == nil, fi.Size() == 0, fi.IsDir(): return 0, errors.Wrap(err, errors.ErrFileNotFound(src).Error()) @@ -258,23 +258,23 @@ func CopyFileWithPerm(ctx context.Context, src, dst string, perm fs.FileMode) (n } func WriteFile(ctx context.Context, target string, r io.Reader, perm fs.FileMode) (n int64, err error) { - return writeFile(ctx, target, r, os.O_CREATE|os.O_WRONLY|os.O_SYNC, perm) + return doWriteFile(ctx, target, r, os.O_CREATE|os.O_WRONLY|os.O_SYNC, perm) } func OverWriteFile(ctx context.Context, target string, r io.Reader, perm fs.FileMode) (n int64, err error) { - return writeFile(ctx, target, r, os.O_CREATE|os.O_TRUNC|os.O_WRONLY|os.O_SYNC, perm) + return doWriteFile(ctx, target, r, os.O_CREATE|os.O_TRUNC|os.O_WRONLY|os.O_SYNC, perm) } func AppendFile(ctx context.Context, target string, r io.Reader, perm fs.FileMode) (n int64, err error) { - return writeFile(ctx, target, r, os.O_CREATE|os.O_APPEND|os.O_RDWR|os.O_SYNC, perm) + return doWriteFile(ctx, target, r, os.O_CREATE|os.O_APPEND|os.O_RDWR|os.O_SYNC, perm) } -func writeFile(ctx context.Context, target string, r io.Reader, flg int, perm fs.FileMode) (n int64, err error) { +func doWriteFile(ctx context.Context, target string, r io.Reader, flg int, perm fs.FileMode) (n int64, err error) { if len(target) == 0 || r == nil { return 0, nil } - exist, fi, err := exists(target) + exist, fi, err := doExists(target) switch { case err == nil, exist, fi != nil && fi.Size() != 0, fi != nil && fi.IsDir(): err = errors.ErrFileAlreadyExists(target) @@ -359,13 +359,13 @@ func ReadFile(path string) (n []byte, err error) { // Exists returns file existence func Exists(path string) (e bool) { - e, _, _ = exists(path) + e, _, _ = doExists(path) return e } // ExistsWithDetail returns file existence func ExistsWithDetail(path string) (e bool, fi fs.FileInfo, err error) { - return exists(path) + return doExists(path) } // MkdirAll creates directory like mkdir -p @@ -375,7 +375,7 @@ func MkdirAll(path string, perm fs.FileMode) (err error) { fi fs.FileInfo merr, rerr error ) - exist, fi, err = exists(path) + exist, fi, err = doExists(path) if exist { if err == nil && fi != nil && fi.IsDir() { return nil @@ -447,8 +447,8 @@ func CreateTemp(baseDir string) (f *os.File, err error) { return nil, errors.ErrFailedToCreateFile(err, path, nil) } -// exists returns file existence with detailed information -func exists(path string) (exists bool, fi fs.FileInfo, err error) { +// doExists returns file existence with detailed information +func doExists(path string) (exists bool, fi fs.FileInfo, err error) { fi, err = os.Stat(path) if err != nil { if os.IsExist(err) { @@ -464,7 +464,7 @@ func exists(path string) (exists bool, fi fs.FileInfo, err error) { // ListInDir returns file list in directory func ListInDir(path string) ([]string, error) { - exists, fi, err := exists(path) + exists, fi, err := doExists(path) if !exists { return nil, err } @@ -484,7 +484,7 @@ func Join(paths ...string) (path string) { return "" } if len(paths) > 1 { - path = join(paths...) + path = doJoin(paths...) } else { path = replacer.Replace(paths[0]) } @@ -498,7 +498,7 @@ func Join(paths ...string) (path string) { log.Warn(err) return filepath.Clean(path) } - return filepath.Clean(join(root, path)) + return filepath.Clean(doJoin(root, path)) } var replacer = strings.NewReplacer( @@ -508,7 +508,7 @@ var replacer = strings.NewReplacer( string(os.PathSeparator), ) -func join(paths ...string) (path string) { +func doJoin(paths ...string) (path string) { for i, path := range paths { if path != "" { return replacer.Replace(strings.Join(paths[i:], string(os.PathSeparator))) diff --git a/internal/file/file_test.go b/internal/file/file_test.go index 8fc0d3b674..512b9181d7 100644 --- a/internal/file/file_test.go +++ b/internal/file/file_test.go @@ -484,7 +484,7 @@ func TestExistsWithDetail(t *testing.T) { } } -func Test_exists(t *testing.T) { +func Test_doExists(t *testing.T) { type args struct { path string } @@ -556,7 +556,7 @@ func Test_exists(t *testing.T) { test.checkFunc = defaultCheckFunc } - gotExists, gotFi, err := exists(test.args.path) + gotExists, gotFi, err := doExists(test.args.path) if err := test.checkFunc(test.want, gotExists, gotFi, err); err != nil { tt.Errorf("error = %v", err) } @@ -955,7 +955,7 @@ func Test_moveDir(t *testing.T) { checkFunc = defaultCheckFunc } - err := moveDir(test.args.ctx, test.args.src, test.args.dst, test.args.rollback) + err := doMoveDir(test.args.ctx, test.args.src, test.args.dst, test.args.rollback) if err := checkFunc(test.want, err); err != nil { tt.Errorf("error = %v", err) } @@ -1122,7 +1122,7 @@ func TestJoin(t *testing.T) { } } -func Test_join(t *testing.T) { +func Test_doJoin(t *testing.T) { type args struct { paths []string } @@ -1187,7 +1187,7 @@ func Test_join(t *testing.T) { checkFunc = defaultCheckFunc } - gotPath := join(test.args.paths...) + gotPath := doJoin(test.args.paths...) if err := checkFunc(test.want, gotPath); err != nil { tt.Errorf("error = %v", err) } @@ -1606,7 +1606,7 @@ func TestAppendFile(t *testing.T) { } } -func Test_writeFile(t *testing.T) { +func Test_doWriteFile(t *testing.T) { type args struct { ctx context.Context target string @@ -1687,7 +1687,7 @@ func Test_writeFile(t *testing.T) { checkFunc = defaultCheckFunc } - gotN, err := writeFile(test.args.ctx, test.args.target, test.args.r, test.args.flg, test.args.perm) + gotN, err := doWriteFile(test.args.ctx, test.args.target, test.args.r, test.args.flg, test.args.perm) if err := checkFunc(test.want, gotN, err); err != nil { tt.Errorf("error = %v", err) } diff --git a/internal/net/grpc/client.go b/internal/net/grpc/client.go index 5476390157..83e06a9445 100644 --- a/internal/net/grpc/client.go +++ b/internal/net/grpc/client.go @@ -326,7 +326,7 @@ func (g *gRPCClient) Range(ctx context.Context, case <-ctx.Done(): return false default: - g.do(ssctx, p, addr, true, func(ictx context.Context, + g.doConnect(ssctx, p, addr, true, func(ictx context.Context, conn *ClientConn, copts ...CallOption, ) (interface{}, error) { return nil, f(ictx, addr, conn, copts...) @@ -360,7 +360,7 @@ func (g *gRPCClient) RangeConcurrent(ctx context.Context, case <-egctx.Done(): return nil default: - g.do(ssctx, p, addr, true, func(ictx context.Context, + g.doConnect(ssctx, p, addr, true, func(ictx context.Context, conn *ClientConn, copts ...CallOption, ) (interface{}, error) { return nil, f(ictx, addr, conn, copts...) @@ -403,7 +403,7 @@ func (g *gRPCClient) OrderedRange(ctx context.Context, span.End() } }() - g.do(ssctx, p, addr, true, func(ictx context.Context, + g.doConnect(ssctx, p, addr, true, func(ictx context.Context, conn *ClientConn, copts ...CallOption, ) (interface{}, error) { return nil, f(ictx, addr, conn, copts...) @@ -447,7 +447,7 @@ func (g *gRPCClient) OrderedRangeConcurrent(ctx context.Context, case <-egctx.Done(): return nil default: - g.do(ssctx, p, addr, true, func(ictx context.Context, + g.doConnect(ssctx, p, addr, true, func(ictx context.Context, conn *ClientConn, copts ...CallOption, ) (interface{}, error) { return nil, f(ictx, addr, conn, copts...) @@ -475,7 +475,7 @@ func (g *gRPCClient) RoundRobin(ctx context.Context, f func(ctx context.Context, sctx = backoff.WithBackoffName(ctx, boName) } do := func(ctx context.Context, p pool.Conn, addr string, f func(ctx context.Context, conn *ClientConn, copts ...CallOption) (interface{}, error)) (r interface{}, ret bool, err error) { - r, err = g.do(ctx, p, addr, false, f) + r, err = g.doConnect(ctx, p, addr, false, f) if err != nil { st, ok := status.FromError(err) if !ok || st == nil { @@ -555,10 +555,10 @@ func (g *gRPCClient) Do(ctx context.Context, addr string, log.Warnf("gRPCClient.Do operation failed, grpc pool connection for %s is invalid,\terror: %v", addr, err) return nil, err } - return g.do(sctx, p, addr, true, f) + return g.doConnect(sctx, p, addr, true, f) } -func (g *gRPCClient) do(ctx context.Context, p pool.Conn, addr string, enableBackoff bool, +func (g *gRPCClient) doConnect(ctx context.Context, p pool.Conn, addr string, enableBackoff bool, f func(ctx context.Context, conn *ClientConn, copts ...CallOption) (interface{}, error), ) (data interface{}, err error) { diff --git a/internal/net/grpc/client_test.go b/internal/net/grpc/client_test.go index 2cebd3692f..d9fb54292a 100644 --- a/internal/net/grpc/client_test.go +++ b/internal/net/grpc/client_test.go @@ -1524,7 +1524,7 @@ func Test_gRPCClient_Do(t *testing.T) { } } -func Test_gRPCClient_do(t *testing.T) { +func Test_gRPCClient_doConnect(t *testing.T) { type args struct { ctx context.Context p pool.Conn @@ -1726,7 +1726,7 @@ func Test_gRPCClient_do(t *testing.T) { stopMonitor: test.fields.stopMonitor, } - gotData, err := g.do(test.args.ctx, test.args.p, test.args.addr, test.args.enableBackoff, test.args.f) + gotData, err := g.doConnect(test.args.ctx, test.args.p, test.args.addr, test.args.enableBackoff, test.args.f) if err := checkFunc(test.want, gotData, err); err != nil { tt.Errorf("error = %v", err) } diff --git a/internal/net/grpc/pool/pool.go b/internal/net/grpc/pool/pool.go index cc30fda979..eaa15a5d02 100644 --- a/internal/net/grpc/pool/pool.go +++ b/internal/net/grpc/pool/pool.go @@ -158,11 +158,11 @@ func (p *pool) Connect(ctx context.Context) (c Conn, err error) { } if p.isIP || !p.resolveDNS { - return p.connect(ctx) + return p.doConnect(ctx) } ips, err := p.lookupIPAddr(ctx) if err != nil { - return p.connect(ctx) + return p.doConnect(ctx) } p.reconnectHash = strings.Join(ips, "-") @@ -213,7 +213,7 @@ func (p *pool) load(idx int) (pc *poolConn, ok bool) { return } -func (p *pool) connect(ctx context.Context) (c Conn, err error) { +func (p *pool) doConnect(ctx context.Context) (c Conn, err error) { p.reconnectHash = p.host failCnt := uint64(0) for i := range p.pool { @@ -440,7 +440,7 @@ func (p *pool) Reconnect(ctx context.Context, force bool) (c Conn, err error) { if p.reconnectHash == "" { log.Debugf("connection history for %s not found starting first connection phase", p.addr) if p.isIP || !p.resolveDNS { - return p.connect(ctx) + return p.doConnect(ctx) } return p.Connect(ctx) } @@ -451,7 +451,7 @@ func (p *pool) Reconnect(ctx context.Context, force bool) (c Conn, err error) { if p.isIP { return nil, errors.ErrInvalidGRPCClientConn(p.addr) } - return p.connect(ctx) + return p.doConnect(ctx) } return p, nil } diff --git a/internal/net/grpc/pool/pool_test.go b/internal/net/grpc/pool/pool_test.go index 48e0ff14e2..97287c1dbc 100644 --- a/internal/net/grpc/pool/pool_test.go +++ b/internal/net/grpc/pool/pool_test.go @@ -413,7 +413,7 @@ func Test_pool_load(t *testing.T) { } } -func Test_pool_connect(t *testing.T) { +func Test_pool_doConnect(t *testing.T) { t.Parallel() type args struct { ctx context.Context @@ -555,7 +555,7 @@ func Test_pool_connect(t *testing.T) { reconnectHash: test.fields.reconnectHash, } - gotC, err := p.connect(test.args.ctx) + gotC, err := p.doConnect(test.args.ctx) if err := checkFunc(test.want, gotC, err); err != nil { tt.Errorf("error = %v", err) } diff --git a/internal/net/http/transport/roundtrip.go b/internal/net/http/transport/roundtrip.go index 6103a72a40..ba6b59b347 100644 --- a/internal/net/http/transport/roundtrip.go +++ b/internal/net/http/transport/roundtrip.go @@ -47,11 +47,15 @@ func NewExpBackoff(opts ...Option) http.RoundTripper { // It round trip the request and returns the response, and return any error occurred. // It returns errors.ErrTransportRetryable to indicate if the request is consider as retryable. func (e *ert) RoundTrip(req *http.Request) (res *http.Response, err error) { + if req != nil { + defer closeBody(req.Body) + } + if e.bo == nil { - return e.roundTrip(req) + return e.doRoundTrip(req) } _, err = e.bo.Do(req.Context(), func(ctx context.Context) (interface{}, bool, error) { - r, err := e.roundTrip(req) + r, err := e.doRoundTrip(req) if err != nil { return nil, errors.Is(err, errors.ErrTransportRetryable), err } @@ -65,7 +69,7 @@ func (e *ert) RoundTrip(req *http.Request) (res *http.Response, err error) { return res, nil } -func (e *ert) roundTrip(req *http.Request) (res *http.Response, err error) { +func (e *ert) doRoundTrip(req *http.Request) (res *http.Response, err error) { res, err = e.transport.RoundTrip(req) if err != nil { log.Error(err) diff --git a/internal/net/http/transport/roundtrip_test.go b/internal/net/http/transport/roundtrip_test.go index 30be01e59e..3490ac861e 100644 --- a/internal/net/http/transport/roundtrip_test.go +++ b/internal/net/http/transport/roundtrip_test.go @@ -330,7 +330,7 @@ func Test_ert_RoundTrip(t *testing.T) { } } -func Test_ert_roundTrip(t *testing.T) { +func Test_ert_doRoundTrip(t *testing.T) { t.Parallel() type args struct { req *http.Request @@ -457,10 +457,17 @@ func Test_ert_roundTrip(t *testing.T) { bo: test.fields.bo, } - gotRes, err := e.roundTrip(test.args.req) + gotRes, err := e.doRoundTrip(test.args.req) if err := checkFunc(test.want, gotRes, err); err != nil { tt.Errorf("error = %v", err) } + + if gotRes != nil { + defer closeBody(gotRes.Body) + } + if test.args.req != nil { + defer closeBody(test.args.req.Body) + } }) } } diff --git a/internal/safety/safety.go b/internal/safety/safety.go index dc7c10186f..9182c04ae4 100644 --- a/internal/safety/safety.go +++ b/internal/safety/safety.go @@ -27,14 +27,14 @@ import ( ) func RecoverFunc(fn func() error) func() error { - return recoverFunc(fn, true) + return recoverFn(fn, true) } func RecoverWithoutPanicFunc(fn func() error) func() error { - return recoverFunc(fn, false) + return recoverFn(fn, false) } -func recoverFunc(fn func() error, withPanic bool) func() error { +func recoverFn(fn func() error, withPanic bool) func() error { return func() (err error) { defer func() { if r := recover(); r != nil { diff --git a/internal/safety/safety_test.go b/internal/safety/safety_test.go index 69ee31c552..e3953b3cd8 100644 --- a/internal/safety/safety_test.go +++ b/internal/safety/safety_test.go @@ -236,7 +236,7 @@ func TestRecoverWithoutPanicFunc(t *testing.T) { } } -func Test_recoverFunc(t *testing.T) { +func Test_recoverFn(t *testing.T) { type args struct { fn func() error withPanic bool @@ -303,7 +303,7 @@ func Test_recoverFunc(t *testing.T) { checkFunc = defaultCheckFunc } - got := recoverFunc(test.args.fn, test.args.withPanic) + got := recoverFn(test.args.fn, test.args.withPanic) if err := checkFunc(test.want, got); err != nil { tt.Errorf("error = %v", err) } diff --git a/internal/test/data/vector/gen.go b/internal/test/data/vector/gen.go index 1171f8644b..63dde69a31 100644 --- a/internal/test/data/vector/gen.go +++ b/internal/test/data/vector/gen.go @@ -30,6 +30,10 @@ type ( const ( Gaussian Distribution = iota Uniform + + // NOTE: mean:128, sigma:128/3, all of 99.7% are in [0, 255]. + gaussianMean float64 = 128 + gaussianSigma float64 = 128 / 3 ) // ErrUnknownDistritbution represents an error which the distribution is unknown. @@ -59,8 +63,8 @@ func Uint8VectorGenerator(d Distribution) (Uint8VectorGeneratorFunc, error) { } } -// float32VectorGenerator return n float32 vectors with dim dimension -func float32VectorGenerator(n, dim int, gen func() float32) (ret [][]float32) { +// genFloat32Vec return n float32 vectors with dim dimension. +func genFloat32Vec(n, dim int, gen func() float32) (ret [][]float32) { ret = make([][]float32, 0, n) for i := 0; i < n; i++ { @@ -75,18 +79,18 @@ func float32VectorGenerator(n, dim int, gen func() float32) (ret [][]float32) { // UniformDistributedFloat32VectorGenerator returns n float32 vectors with dim dimension and their values under Uniform distribution func UniformDistributedFloat32VectorGenerator(n, dim int) [][]float32 { - return float32VectorGenerator(n, dim, rand.Float32) + return genFloat32Vec(n, dim, rand.Float32) } // GaussianDistributedFloat32VectorGenerator returns n float32 vectors with dim dimension and their values under Gaussian distribution func GaussianDistributedFloat32VectorGenerator(n, dim int) [][]float32 { - return float32VectorGenerator(n, dim, func() float32 { + return genFloat32Vec(n, dim, func() float32 { return float32(rand.NormFloat64()) }) } -// uint8VectorGenerator return n uint8 vectors with dim dimension -func uint8VectorGenerator(n, dim int, gen func() uint8) (ret [][]uint8) { +// genUint8Vec return n uint8 vectors with dim dimension +func genUint8Vec(n, dim int, gen func() uint8) (ret [][]uint8) { ret = make([][]uint8, 0, n) for i := 0; i < n; i++ { @@ -101,26 +105,16 @@ func uint8VectorGenerator(n, dim int, gen func() uint8) (ret [][]uint8) { // UniformDistributedUint8VectorGenerator returns n uint8 vectors with dim dimension and their values under Uniform distribution func UniformDistributedUint8VectorGenerator(n, dim int) [][]uint8 { - return uint8VectorGenerator(n, dim, func() uint8 { + return genUint8Vec(n, dim, func() uint8 { return uint8(irand.LimitedUint32(math.MaxUint8)) }) } // GaussianDistributedUint8VectorGenerator returns n uint8 vectors with dim dimension and their values under Gaussian distribution func GaussianDistributedUint8VectorGenerator(n, dim int) [][]uint8 { - // NOTE: mean:128, sigma:128/3, all of 99.7% are in [0, 255] - const ( - mean float64 = 128 - sigma float64 = 128 / 3 - ) - return gaussianDistributedUint8VectorGenerator(n, dim, mean, sigma) -} - -// gaussianDistributedUint8VectorGenerator returns n uint8 vectors with dim dimension and their values under Gaussian distribution with user-specified mean and sigma -func gaussianDistributedUint8VectorGenerator(n, dim int, mean, sigma float64) [][]uint8 { // NOTE: The boundary test is the main purpose for refactoring. Now, passing this function is dependent on the seed of the random generator. We should fix the randomness of the passing test. - return uint8VectorGenerator(n, dim, func() uint8 { - val := rand.NormFloat64()*sigma + mean + return genUint8Vec(n, dim, func() uint8 { + val := rand.NormFloat64()*gaussianSigma + gaussianMean if val < 0 { return 0 } else if val > math.MaxUint8 { diff --git a/internal/test/data/vector/gen_test.go b/internal/test/data/vector/gen_test.go index 31de19ea34..ce072b80b0 100644 --- a/internal/test/data/vector/gen_test.go +++ b/internal/test/data/vector/gen_test.go @@ -208,7 +208,7 @@ func TestUint8VectorGenerator(t *testing.T) { } } -func Test_float32VectorGenerator(t *testing.T) { +func Test_genFloat32Vec(t *testing.T) { type args struct { n int dim int @@ -279,7 +279,7 @@ func Test_float32VectorGenerator(t *testing.T) { checkFunc = defaultCheckFunc } - gotRet := float32VectorGenerator(test.args.n, test.args.dim, test.args.gen) + gotRet := genFloat32Vec(test.args.n, test.args.dim, test.args.gen) if err := checkFunc(test.want, gotRet); err != nil { tt.Errorf("error = %v", err) } @@ -439,7 +439,7 @@ func TestGaussianDistributedFloat32VectorGenerator(t *testing.T) { } } -func Test_uint8VectorGenerator(t *testing.T) { +func Test_genUint8Vec(t *testing.T) { type args struct { n int dim int @@ -510,7 +510,7 @@ func Test_uint8VectorGenerator(t *testing.T) { checkFunc = defaultCheckFunc } - gotRet := uint8VectorGenerator(test.args.n, test.args.dim, test.args.gen) + gotRet := genUint8Vec(test.args.n, test.args.dim, test.args.gen) if err := checkFunc(test.want, gotRet); err != nil { tt.Errorf("error = %v", err) } @@ -670,88 +670,6 @@ func TestGaussianDistributedUint8VectorGenerator(t *testing.T) { } } -func Test_gaussianDistributedUint8VectorGenerator(t *testing.T) { - type args struct { - n int - dim int - mean float64 - sigma float64 - } - type want struct { - want [][]uint8 - } - type test struct { - name string - args args - want want - checkFunc func(want, [][]uint8) error - beforeFunc func(args) - afterFunc func(args) - } - defaultCheckFunc := func(w want, got [][]uint8) error { - if !reflect.DeepEqual(got, w.want) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", got, w.want) - } - return nil - } - tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - n: 0, - dim: 0, - mean: 0, - sigma: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - n: 0, - dim: 0, - mean: 0, - sigma: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ - } - - for _, tc := range tests { - test := tc - t.Run(test.name, func(tt *testing.T) { - tt.Parallel() - defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) - if test.beforeFunc != nil { - test.beforeFunc(test.args) - } - if test.afterFunc != nil { - defer test.afterFunc(test.args) - } - checkFunc := test.checkFunc - if test.checkFunc == nil { - checkFunc = defaultCheckFunc - } - - got := gaussianDistributedUint8VectorGenerator(test.args.n, test.args.dim, test.args.mean, test.args.sigma) - if err := checkFunc(test.want, got); err != nil { - tt.Errorf("error = %v", err) - } - }) - } -} - func TestGenF32Vec(t *testing.T) { type args struct { dist Distribution diff --git a/internal/worker/queue.go b/internal/worker/queue.go index 6337bcb5eb..6dc2ee8664 100644 --- a/internal/worker/queue.go +++ b/internal/worker/queue.go @@ -135,29 +135,23 @@ func (q *queue) Push(ctx context.Context, job JobFunc) error { // Pop returns (JobFunc, nil) if the channnel, which will be used for queuing job, contains JobFunc. // It returns (nil ,error) if it failed to pop from the job queue. func (q *queue) Pop(ctx context.Context) (JobFunc, error) { - return q.pop(ctx, q.Len()) -} + tryCnt := int(q.Len()) + 1 // include the first try -func (q *queue) pop(ctx context.Context, retry uint64) (JobFunc, error) { - if !q.isRunning() { - return nil, errors.ErrQueueIsNotRunning() - } - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case job := <-q.outCh: - if job != nil { - return job, nil + for i := 0; i < tryCnt; i++ { + if !q.isRunning() { + return nil, errors.ErrQueueIsNotRunning() } - } - if retry <= 0 { - return nil, errors.ErrJobFuncIsNil() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case job := <-q.outCh: + if job != nil { + return job, nil + } + } } - - retry-- - return q.pop(ctx, retry) + return nil, errors.ErrJobFuncIsNil() } // Len returns the length of queue. diff --git a/internal/worker/queue_test.go b/internal/worker/queue_test.go index 71df9f5b80..e39f2496b5 100644 --- a/internal/worker/queue_test.go +++ b/internal/worker/queue_test.go @@ -670,88 +670,20 @@ func Test_queue_Pop(t *testing.T) { }, } }(), - } - - for _, tc := range tests { - test := tc - t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(tt, goleakIgnoreOptions...) - if test.beforeFunc != nil { - test.beforeFunc(test.args) - } - if test.afterFunc != nil { - defer test.afterFunc(test.args) - } - checkFunc := test.checkFunc - if test.checkFunc == nil { - checkFunc = defaultCheckFunc - } - q := &queue{ - buffer: test.fields.buffer, - eg: test.fields.eg, - qcdur: test.fields.qcdur, - inCh: test.fields.inCh, - outCh: test.fields.outCh, - qLen: test.fields.qLen, - running: test.fields.running, - } - - got, err := q.Pop(test.args.ctx) - if err := checkFunc(test.want, got, err); err != nil { - tt.Errorf("error = %v", err) - } - }) - } -} - -func Test_queue_pop(t *testing.T) { - type args struct { - ctx context.Context - retry uint64 - } - type fields struct { - buffer int - eg errgroup.Group - qcdur time.Duration - inCh chan JobFunc - outCh chan JobFunc - qLen atomic.Value - running atomic.Value - } - type want struct { - want JobFunc - err error - } - type test struct { - name string - args args - fields fields - want want - checkFunc func(want, JobFunc, error) error - beforeFunc func(args) - afterFunc func(args) - } - defaultCheckFunc := func(w want, got JobFunc, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) - } - if reflect.ValueOf(w.want).Pointer() != reflect.ValueOf(got).Pointer() { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", got, w.want) - } - return nil - } - tests := []test{ { name: "return (nil, error) when queue is not running.", args: args{ - ctx: context.Background(), - retry: 1, + ctx: context.Background(), }, fields: fields{ running: func() (v atomic.Value) { v.Store(false) return v }(), + qLen: func() (v atomic.Value) { + v.Store(uint64(1)) + return v + }(), }, want: want{ want: nil, @@ -767,8 +699,7 @@ func Test_queue_pop(t *testing.T) { return test{ name: "return (JobFunc, nil) when first pop is retry.", args: args{ - ctx: ctx, - retry: 10, + ctx: ctx, }, fields: fields{ buffer: 10, @@ -777,7 +708,7 @@ func Test_queue_pop(t *testing.T) { inCh: make(chan JobFunc, 10), outCh: outCh, qLen: func() (v atomic.Value) { - v.Store(uint64(0)) + v.Store(uint64(10)) return v }(), running: func() (v atomic.Value) { @@ -804,8 +735,7 @@ func Test_queue_pop(t *testing.T) { return test{ name: "return (nil, error) when retry is 1 and retry.", args: args{ - ctx: ctx, - retry: 1, + ctx: ctx, }, fields: fields{ buffer: 10, @@ -814,7 +744,7 @@ func Test_queue_pop(t *testing.T) { inCh: make(chan JobFunc, 10), outCh: outCh, qLen: func() (v atomic.Value) { - v.Store(uint64(0)) + v.Store(uint64(1)) return v }(), running: func() (v atomic.Value) { @@ -838,8 +768,7 @@ func Test_queue_pop(t *testing.T) { return test{ name: "return (JobFunc, error) when context canceled.", args: args{ - ctx: ctx, - retry: 0, + ctx: ctx, }, fields: fields{ buffer: 10, @@ -894,7 +823,7 @@ func Test_queue_pop(t *testing.T) { running: test.fields.running, } - got, err := q.pop(test.args.ctx, test.args.retry) + got, err := q.Pop(test.args.ctx) if err := checkFunc(test.want, got, err); err != nil { tt.Errorf("error = %v", err) } diff --git a/pkg/gateway/lb/handler/grpc/handler.go b/pkg/gateway/lb/handler/grpc/handler.go index 73af136a40..4b263b3616 100644 --- a/pkg/gateway/lb/handler/grpc/handler.go +++ b/pkg/gateway/lb/handler/grpc/handler.go @@ -245,7 +245,7 @@ func (s *server) Search(ctx context.Context, req *payload.Search_Request) (res * if req.Config != nil { req.Config.MinNum = 0 } - res, err = s.search(ctx, &payload.Search_Config{ + res, err = s.doSearch(ctx, &payload.Search_Config{ RequestId: cfg.GetRequestId(), Num: cfg.GetNum(), MinNum: mn, @@ -346,7 +346,7 @@ func (s *server) SearchByID(ctx context.Context, req *payload.Search_IDRequest) ResourceName: fmt.Sprintf("%s: %s(%s) to %v", apiName, s.name, s.ip, s.gateway.Addrs(ctx)), }) var serr error - res, serr = s.search(ctx, scfg, func(ctx context.Context, vc vald.Client, copts ...grpc.CallOption) (*payload.Search_Response, error) { + res, serr = s.doSearch(ctx, scfg, func(ctx context.Context, vc vald.Client, copts ...grpc.CallOption) (*payload.Search_Response, error) { return vc.SearchByID(ctx, req, copts...) }) if serr == nil { @@ -382,7 +382,7 @@ func (s *server) SearchByID(ctx context.Context, req *payload.Search_IDRequest) ResourceName: fmt.Sprintf("%s: %s(%s) to %v", apiName, s.name, s.ip, s.gateway.Addrs(ctx)), }, info.Get()) var serr error - res, serr = s.search(ctx, scfg, func(ctx context.Context, vc vald.Client, copts ...grpc.CallOption) (*payload.Search_Response, error) { + res, serr = s.doSearch(ctx, scfg, func(ctx context.Context, vc vald.Client, copts ...grpc.CallOption) (*payload.Search_Response, error) { return vc.SearchByID(ctx, req, copts...) }) if serr == nil { @@ -412,7 +412,7 @@ type DistPayload struct { distance *big.Float } -func (s *server) search(ctx context.Context, cfg *payload.Search_Config, +func (s *server) doSearch(ctx context.Context, cfg *payload.Search_Config, f func(ctx context.Context, vc vald.Client, copts ...grpc.CallOption) (*payload.Search_Response, error)) ( res *payload.Search_Response, err error, ) { @@ -1009,7 +1009,7 @@ func (s *server) LinearSearch(ctx context.Context, req *payload.Search_Request) if req.Config != nil { req.Config.MinNum = 0 } - res, err = s.search(ctx, &payload.Search_Config{ + res, err = s.doSearch(ctx, &payload.Search_Config{ RequestId: cfg.GetRequestId(), Num: cfg.GetNum(), MinNum: mn, @@ -1101,7 +1101,7 @@ func (s *server) LinearSearchByID(ctx context.Context, req *payload.Search_IDReq ResourceName: fmt.Sprintf("%s: %s(%s) to %v", apiName, s.name, s.ip, s.gateway.Addrs(ctx)), }) var serr error - res, serr = s.search(ctx, scfg, func(ctx context.Context, vc vald.Client, copts ...grpc.CallOption) (*payload.Search_Response, error) { + res, serr = s.doSearch(ctx, scfg, func(ctx context.Context, vc vald.Client, copts ...grpc.CallOption) (*payload.Search_Response, error) { return vc.LinearSearchByID(ctx, req, copts...) }) if serr == nil { @@ -1138,7 +1138,7 @@ func (s *server) LinearSearchByID(ctx context.Context, req *payload.Search_IDReq ResourceName: fmt.Sprintf("%s: %s(%s) to %v", apiName, s.name, s.ip, s.gateway.Addrs(ctx)), }, info.Get()) var serr error - res, serr = s.search(ctx, scfg, func(ctx context.Context, vc vald.Client, copts ...grpc.CallOption) (*payload.Search_Response, error) { + res, serr = s.doSearch(ctx, scfg, func(ctx context.Context, vc vald.Client, copts ...grpc.CallOption) (*payload.Search_Response, error) { return vc.LinearSearchByID(ctx, req, copts...) }) if serr == nil { diff --git a/pkg/gateway/lb/handler/grpc/handler_test.go b/pkg/gateway/lb/handler/grpc/handler_test.go index d70529d039..6d5fa4c611 100644 --- a/pkg/gateway/lb/handler/grpc/handler_test.go +++ b/pkg/gateway/lb/handler/grpc/handler_test.go @@ -436,7 +436,7 @@ func Test_server_SearchByID(t *testing.T) { } } -func Test_server_search(t *testing.T) { +func Test_server_doSearch(t *testing.T) { t.Parallel() type args struct { ctx context.Context @@ -541,7 +541,7 @@ func Test_server_search(t *testing.T) { streamConcurrency: test.fields.streamConcurrency, } - gotRes, err := s.search(test.args.ctx, test.args.cfg, test.args.f) + gotRes, err := s.doSearch(test.args.ctx, test.args.cfg, test.args.f) if err := checkFunc(test.want, gotRes, err); err != nil { tt.Errorf("error = %v", err) }