From dcb498ba6e94a06f9a52e2e6168e7a9c2d2bdcd0 Mon Sep 17 00:00:00 2001 From: hlts2 Date: Tue, 11 Jul 2023 17:43:39 +0900 Subject: [PATCH 1/3] add search context to agent service Signed-off-by: hlts2 --- internal/core/algorithm/ngt/ngt.go | 11 +++++++++-- pkg/agent/core/ngt/handler/grpc/search.go | 4 ++-- pkg/agent/core/ngt/service/ngt.go | 17 +++++++++++------ 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/internal/core/algorithm/ngt/ngt.go b/internal/core/algorithm/ngt/ngt.go index 99d73fac3d..28a40b2bfc 100644 --- a/internal/core/algorithm/ngt/ngt.go +++ b/internal/core/algorithm/ngt/ngt.go @@ -25,6 +25,7 @@ package ngt import "C" import ( + "context" "reflect" "sync" "unsafe" @@ -40,7 +41,7 @@ type ( // NGT is core interface. NGT interface { // Search returns search result as []SearchResult - Search(vec []float32, size int, epsilon, radius float32) ([]SearchResult, error) + Search(ctx context.Context, vec []float32, size int, epsilon, radius float32) ([]SearchResult, error) // Linear Search returns linear search result as []SearchResult LinearSearch(vec []float32, size int) ([]SearchResult, error) @@ -366,7 +367,7 @@ func (n *ngt) loadObjectSpace() error { } // Search returns search result as []SearchResult. -func (n *ngt) Search(vec []float32, size int, epsilon, radius float32) (result []SearchResult, err error) { +func (n *ngt) Search(ctx context.Context, vec []float32, size int, epsilon, radius float32) (result []SearchResult, err error) { if len(vec) != int(n.dimension) { return nil, errors.ErrIncompatibleDimensionSize(len(vec), int(n.dimension)) } @@ -415,6 +416,12 @@ func (n *ngt) Search(vec []float32, size int, epsilon, radius float32) (result [ result = make([]SearchResult, rsize) for i := range result { + select { + case <-ctx.Done(): + n.PutErrorBuffer(ebuf) + return result[:i], nil + default: + } d := C.ngt_get_result(results, C.uint32_t(i), ebuf) if d.id == 0 && d.distance == 0 { result[i] = SearchResult{0, 0, n.newGoError(ebuf)} diff --git a/pkg/agent/core/ngt/handler/grpc/search.go b/pkg/agent/core/ngt/handler/grpc/search.go index 328806f7e7..2a9c9d6fdc 100644 --- a/pkg/agent/core/ngt/handler/grpc/search.go +++ b/pkg/agent/core/ngt/handler/grpc/search.go @@ -69,7 +69,7 @@ func (s *server) Search(ctx context.Context, req *payload.Search_Request) (res * return nil, err } res, err = toSearchResponse( - s.ngt.Search( + s.ngt.Search(ctx, req.GetVector(), req.GetConfig().GetNum(), req.GetConfig().GetEpsilon(), @@ -195,7 +195,7 @@ func (s *server) SearchByID(ctx context.Context, req *payload.Search_IDRequest) } return nil, err } - vec, dst, err := s.ngt.SearchByID( + vec, dst, err := s.ngt.SearchByID(ctx, uuid, req.GetConfig().GetNum(), req.GetConfig().GetEpsilon(), diff --git a/pkg/agent/core/ngt/service/ngt.go b/pkg/agent/core/ngt/service/ngt.go index f83da3826c..56377b38a3 100644 --- a/pkg/agent/core/ngt/service/ngt.go +++ b/pkg/agent/core/ngt/service/ngt.go @@ -50,8 +50,8 @@ import ( type NGT interface { Start(ctx context.Context) <-chan error - Search(vec []float32, size uint32, epsilon, radius float32) ([]model.Distance, error) - SearchByID(uuid string, size uint32, epsilon, radius float32) ([]float32, []model.Distance, error) + Search(ctx context.Context, vec []float32, size uint32, epsilon, radius float32) ([]model.Distance, error) + SearchByID(ctx context.Context, uuid string, size uint32, epsilon, radius float32) ([]float32, []model.Distance, error) LinearSearch(vec []float32, size uint32) ([]model.Distance, error) LinearSearchByID(uuid string, size uint32) ([]float32, []model.Distance, error) Insert(uuid string, vec []float32) (err error) @@ -869,11 +869,11 @@ func (n *ngt) Start(ctx context.Context) <-chan error { return ech } -func (n *ngt) Search(vec []float32, size uint32, epsilon, radius float32) ([]model.Distance, error) { +func (n *ngt) Search(ctx context.Context, vec []float32, size uint32, epsilon, radius float32) ([]model.Distance, error) { if n.IsIndexing() { return nil, errors.ErrCreateIndexingIsInProgress } - sr, err := n.core.Search(vec, int(size), epsilon, radius) + sr, err := n.core.Search(ctx, vec, int(size), epsilon, radius) if err != nil { if n.IsIndexing() { return nil, errors.ErrCreateIndexingIsInProgress @@ -888,6 +888,11 @@ func (n *ngt) Search(vec []float32, size uint32, epsilon, radius float32) ([]mod ds := make([]model.Distance, 0, len(sr)) for _, d := range sr { + select { + case <-ctx.Done(): + return ds, nil + default: + } if err = d.Error; d.ID == 0 && err != nil { log.Warnf("an error occurred while searching: %s", err) continue @@ -906,7 +911,7 @@ func (n *ngt) Search(vec []float32, size uint32, epsilon, radius float32) ([]mod return ds, nil } -func (n *ngt) SearchByID(uuid string, size uint32, epsilon, radius float32) (vec []float32, dst []model.Distance, err error) { +func (n *ngt) SearchByID(ctx context.Context, uuid string, size uint32, epsilon, radius float32) (vec []float32, dst []model.Distance, err error) { if n.IsIndexing() { return nil, nil, errors.ErrCreateIndexingIsInProgress } @@ -914,7 +919,7 @@ func (n *ngt) SearchByID(uuid string, size uint32, epsilon, radius float32) (vec if err != nil { return nil, nil, err } - dst, err = n.Search(vec, size, epsilon, radius) + dst, err = n.Search(ctx, vec, size, epsilon, radius) if err != nil { return vec, nil, err } From b1a6435a372dbc18d76d8e451f52fd57b24dff9c Mon Sep 17 00:00:00 2001 From: hlts2 Date: Tue, 11 Jul 2023 18:46:41 +0900 Subject: [PATCH 2/3] fix test execution error Signed-off-by: hlts2 --- internal/core/algorithm/ngt/ngt_test.go | 206 ++++++++++++++++-------- 1 file changed, 135 insertions(+), 71 deletions(-) diff --git a/internal/core/algorithm/ngt/ngt_test.go b/internal/core/algorithm/ngt/ngt_test.go index 5b7c071276..70156d88ff 100644 --- a/internal/core/algorithm/ngt/ngt_test.go +++ b/internal/core/algorithm/ngt/ngt_test.go @@ -18,6 +18,7 @@ package ngt import ( + "context" "io/fs" "math" "os" @@ -230,11 +231,11 @@ func TestLoad(t *testing.T) { name string args args want want - checkFunc func(want, NGT, error) error + checkFunc func(context.Context, want, NGT, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, NGT) error } - defaultCheckFunc := func(w want, got NGT, err error) error { + defaultCheckFunc := func(_ context.Context, w want, got NGT, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } @@ -296,8 +297,8 @@ func TestLoad(t *testing.T) { mu: &sync.RWMutex{}, }, }, - checkFunc: func(w want, n NGT, e error) error { - if err := defaultCheckFunc(w, n, e); err != nil { + checkFunc: func(ctx context.Context, w want, n NGT, e error) error { + if err := defaultCheckFunc(ctx, w, n, e); err != nil { return err } @@ -308,7 +309,7 @@ func TestLoad(t *testing.T) { } // check no vector can be searched - vs, err := n.Search([]float32{0, 1, 2, 3, 4, 5, 6, 7, 8}, 10, 0, 0) + vs, err := n.Search(ctx, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8}, 10, 0, 0) if err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { return err } @@ -361,8 +362,8 @@ func TestLoad(t *testing.T) { mu: &sync.RWMutex{}, }, }, - checkFunc: func(w want, n NGT, e error) error { - if err := defaultCheckFunc(w, n, e); err != nil { + checkFunc: func(ctx context.Context, w want, n NGT, e error) error { + if err := defaultCheckFunc(ctx, w, n, e); err != nil { return err } @@ -376,7 +377,7 @@ func TestLoad(t *testing.T) { } // check inserted vector can be searched - vs, err := n.Search(vec, 10, 0, 0) + vs, err := n.Search(ctx, vec, 10, 0, 0) if err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { return err } @@ -426,8 +427,8 @@ func TestLoad(t *testing.T) { mu: &sync.RWMutex{}, }, }, - checkFunc: func(w want, n NGT, e error) error { - if err := defaultCheckFunc(w, n, e); err != nil { + checkFunc: func(ctx context.Context, w want, n NGT, e error) error { + if err := defaultCheckFunc(ctx, w, n, e); err != nil { return err } @@ -438,7 +439,7 @@ func TestLoad(t *testing.T) { } // check no vector can be searched - vs, err := n.Search([]float32{0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}, 10, 0, 0) + vs, err := n.Search(ctx, []float32{0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}, 10, 0, 0) if err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { return err } @@ -491,8 +492,8 @@ func TestLoad(t *testing.T) { mu: &sync.RWMutex{}, }, }, - checkFunc: func(w want, n NGT, e error) error { - if err := defaultCheckFunc(w, n, e); err != nil { + checkFunc: func(ctx context.Context, w want, n NGT, e error) error { + if err := defaultCheckFunc(ctx, w, n, e); err != nil { return err } @@ -506,7 +507,7 @@ func TestLoad(t *testing.T) { } // check inserted vector can be searched - vs, err := n.Search(vec, 10, 0, 0) + vs, err := n.Search(ctx, vec, 10, 0, 0) if err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { return err } @@ -537,8 +538,8 @@ func TestLoad(t *testing.T) { want: nil, err: errors.ErrIndexFileNotFound, }, - checkFunc: func(w want, n NGT, e error) error { - if err := defaultCheckFunc(w, n, e); err != nil { + checkFunc: func(ctx context.Context, w want, n NGT, e error) error { + if err := defaultCheckFunc(ctx, w, n, e); err != nil { return err } @@ -551,7 +552,7 @@ func TestLoad(t *testing.T) { } // check no vector can be searched - vs, err := n.Search(vec, 10, 0, 0) + vs, err := n.Search(ctx, vec, 10, 0, 0) if err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { return err } @@ -587,7 +588,7 @@ func TestLoad(t *testing.T) { t.Error(err) } }, - checkFunc: func(w want, n NGT, e error) error { + checkFunc: func(_ context.Context, w want, n NGT, e error) error { if e != nil && !errors.As(e, w.err) { t.Error(e) return e @@ -620,7 +621,10 @@ func TestLoad(t *testing.T) { tt.Error(err) } }() - if err := checkFunc(test.want, got, err); err != nil { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if err := checkFunc(ctx, test.want, got, err); err != nil { tt.Errorf("error = %v", err) } }) @@ -640,11 +644,11 @@ func Test_gen(t *testing.T) { name string args args want want - checkFunc func(want, NGT, error) error + checkFunc func(context.Context, want, NGT, error) error beforeFunc func(*testing.T, args) afterFunc func(*testing.T, NGT) error } - defaultCheckFunc := func(w want, got NGT, err error) error { + defaultCheckFunc := func(_ context.Context, w want, got NGT, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } @@ -727,8 +731,8 @@ func Test_gen(t *testing.T) { mu: &sync.RWMutex{}, }, }, - checkFunc: func(w want, n NGT, e error) error { - if err := defaultCheckFunc(w, n, e); err != nil { + checkFunc: func(ctx context.Context, w want, n NGT, e error) error { + if err := defaultCheckFunc(ctx, w, n, e); err != nil { return err } @@ -742,7 +746,7 @@ func Test_gen(t *testing.T) { } // check inserted vector can be searched - vs, err := n.Search(vec, 10, 0, 0) + vs, err := n.Search(ctx, vec, 10, 0, 0) if err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { return err } @@ -789,7 +793,10 @@ func Test_gen(t *testing.T) { tt.Error(err) } }() - if err := checkFunc(test.want, got, err); err != nil { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if err := checkFunc(ctx, test.want, got, err); err != nil { tt.Errorf("error = %v", err) } }) @@ -1308,6 +1315,7 @@ func Test_ngt_loadObjectSpace(t *testing.T) { func Test_ngt_Search(t *testing.T) { type args struct { + ctx context.Context vec []float32 size int epsilon float32 @@ -1381,6 +1389,7 @@ func Test_ngt_Search(t *testing.T) { { name: "return vector id after the same vector inserted (uint8)", args: args{ + ctx: context.Background(), vec: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8}, size: 5, epsilon: 0, @@ -1410,6 +1419,7 @@ func Test_ngt_Search(t *testing.T) { { name: "resturn vector id after the nearby vector inserted (uint8)", args: args{ + ctx: context.Background(), vec: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}, size: 5, }, @@ -1437,6 +1447,7 @@ func Test_ngt_Search(t *testing.T) { { name: "return vector ids after insert with multiple vectors (uint8)", args: args{ + ctx: context.Background(), vec: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}, size: 5, }, @@ -1470,6 +1481,7 @@ func Test_ngt_Search(t *testing.T) { { name: "return limited result after insert 10 vectors with limited size 3 (uint8)", args: args{ + ctx: context.Background(), vec: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}, size: 3, }, @@ -1510,6 +1522,7 @@ func Test_ngt_Search(t *testing.T) { { name: "return most accurate result after insert 10 vectors with limited size 5 (uint8)", args: args{ + ctx: context.Background(), vec: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}, size: 5, }, @@ -1553,6 +1566,7 @@ func Test_ngt_Search(t *testing.T) { { name: "return vector id after the same vector inserted (float)", args: args{ + ctx: context.Background(), vec: []float32{0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}, size: 5, epsilon: 0, @@ -1582,6 +1596,7 @@ func Test_ngt_Search(t *testing.T) { { name: "resturn vector id after the nearby vector inserted (float)", args: args{ + ctx: context.Background(), vec: []float32{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.91}, size: 5, }, @@ -1609,6 +1624,7 @@ func Test_ngt_Search(t *testing.T) { { name: "return vector ids after insert with multiple vectors (float)", args: args{ + ctx: context.Background(), vec: []float32{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}, size: 5, }, @@ -1641,6 +1657,7 @@ func Test_ngt_Search(t *testing.T) { { name: "return limited result after insert 10 vectors with limited size 3 (float)", args: args{ + ctx: context.Background(), vec: []float32{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}, size: 3, }, @@ -1681,6 +1698,7 @@ func Test_ngt_Search(t *testing.T) { { name: "return most accurate result after insert 10 vectors with limited size 5 (float)", args: args{ + ctx: context.Background(), vec: []float32{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}, size: 5, }, @@ -1724,6 +1742,7 @@ func Test_ngt_Search(t *testing.T) { { name: "return nothing if the search dimension is less than the inserted vector", args: args{ + ctx: context.Background(), vec: []float32{0, 1, 2, 3, 4, 5, 6, 7}, size: 5, epsilon: 0, @@ -1751,6 +1770,7 @@ func Test_ngt_Search(t *testing.T) { { name: "return nothing if the search dimension is more than the inserted vector", args: args{ + ctx: context.Background(), vec: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, size: 5, epsilon: 0, @@ -1778,6 +1798,32 @@ func Test_ngt_Search(t *testing.T) { { name: "return ErrEmptySearchResult error if there is no inserted vector", args: args{ + ctx: context.Background(), + vec: []float32{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}, + size: 3, + }, + fields: fields{ + inMemory: false, + idxPath: "/tmp/ngt-813", + bulkInsertChunkSize: 100, + dimension: 9, + objectType: Float, + radius: float32(-1.0), + epsilon: float32(0.1), + }, + createFunc: defaultCreateFunc, + want: want{ + err: errors.ErrEmptySearchResult, + }, + }, + { + name: "return ErrEmptySearchResult error if the context is canceled", + args: args{ + ctx: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + return ctx + }(), vec: []float32{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}, size: 3, }, @@ -1801,6 +1847,9 @@ func Test_ngt_Search(t *testing.T) { test := tc t.Run(test.name, func(tt *testing.T) { tt.Parallel() + ctx, cancel := context.WithCancel(test.args.ctx) + defer cancel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) if test.beforeFunc != nil { test.beforeFunc(test.args) @@ -1821,7 +1870,7 @@ func Test_ngt_Search(t *testing.T) { tt.Fatal(err) } - got, err := n.Search(test.args.vec, test.args.size, test.args.epsilon, test.args.radius) + got, err := n.Search(ctx, test.args.vec, test.args.size, test.args.epsilon, test.args.radius) if err := checkFunc(test.want, got, n, err); err != nil { tt.Errorf("error = %v", err) } @@ -1857,7 +1906,7 @@ func Test_ngt_Insert(t *testing.T) { fields fields createFunc func(*testing.T, fields) (NGT, error) want want - checkFunc func(want, uint, NGT, args, error) error + checkFunc func(context.Context, want, uint, NGT, args, error) error beforeFunc func(args) afterFunc func(*testing.T, NGT) error } @@ -1875,7 +1924,7 @@ func Test_ngt_Insert(t *testing.T) { WithDimension(fields.dimension), ) } - defaultCheckFunc := func(w want, got uint, n NGT, args args, err error) error { + defaultCheckFunc := func(ctx context.Context, w want, got uint, n NGT, args args, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } @@ -1888,7 +1937,7 @@ func Test_ngt_Insert(t *testing.T) { } // search before indexing, it should return nothing - r, err := n.Search(args.vec, 5, 0, 0) + r, err := n.Search(ctx, args.vec, 5, 0, 0) if err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { return err } @@ -1900,7 +1949,7 @@ func Test_ngt_Insert(t *testing.T) { if err := n.CreateIndex(1); err != nil { return err } - r, err = n.Search(args.vec, 5, 0, 0) + r, err = n.Search(ctx, args.vec, 5, 0, 0) if err != nil { return err } @@ -2086,9 +2135,11 @@ func Test_ngt_Insert(t *testing.T) { if err != nil { tt.Fatal(err) } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() got, err := n.Insert(test.args.vec) - if err := checkFunc(test.want, got, n, test.args, err); err != nil { + if err := checkFunc(ctx, test.want, got, n, test.args, err); err != nil { tt.Errorf("error = %v", err) } @@ -2124,7 +2175,7 @@ func Test_ngt_InsertCommit(t *testing.T) { fields fields createFunc func(*testing.T, fields) (NGT, error) want want - checkFunc func(want, uint, NGT, args, error) error + checkFunc func(context.Context, want, uint, NGT, args, error) error beforeFunc func(args) afterFunc func(*testing.T, NGT) error } @@ -2142,7 +2193,7 @@ func Test_ngt_InsertCommit(t *testing.T) { WithDimension(fields.dimension), ) } - defaultCheckFunc := func(w want, got uint, n NGT, args args, err error) error { + defaultCheckFunc := func(ctx context.Context, w want, got uint, n NGT, args args, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) } @@ -2153,7 +2204,8 @@ func Test_ngt_InsertCommit(t *testing.T) { if got == 0 { return nil } - r, err := n.Search(args.vec, 5, 0, 0) + + r, err := n.Search(ctx, args.vec, 5, 0, 0) if err != nil { return err } @@ -2339,9 +2391,11 @@ func Test_ngt_InsertCommit(t *testing.T) { if err != nil { tt.Fatal(err) } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() got, err := n.InsertCommit(test.args.vec, test.args.poolSize) - if err := checkFunc(test.want, got, n, test.args, err); err != nil { + if err := checkFunc(ctx, test.want, got, n, test.args, err); err != nil { tt.Errorf("error = %v", err) } @@ -2376,7 +2430,7 @@ func Test_ngt_BulkInsert(t *testing.T) { fields fields createFunc func(*testing.T, fields) (NGT, error) want want - checkFunc func(want, []uint, NGT, fields, args, []error) error + checkFunc func(context.Context, want, []uint, NGT, fields, args, []error) error beforeFunc func(args) afterFunc func(*testing.T, NGT) error } @@ -2394,7 +2448,7 @@ func Test_ngt_BulkInsert(t *testing.T) { WithDimension(fields.dimension), ) } - defaultCheckFunc := func(w want, got []uint, n NGT, fields fields, args args, got1 []error) error { + defaultCheckFunc := func(ctx context.Context, w want, got []uint, n NGT, fields fields, args args, got1 []error) error { if diff := comparator.Diff(w.want1, got1, comparator.ErrorComparer); diff != "" { return errors.New(diff) } @@ -2418,7 +2472,7 @@ func Test_ngt_BulkInsert(t *testing.T) { if len(vec) != fields.dimension { continue } - r, err := n.Search(vec, 1, 0, 0) + r, err := n.Search(ctx, vec, 1, 0, 0) if err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { return err } @@ -2435,7 +2489,7 @@ func Test_ngt_BulkInsert(t *testing.T) { if len(vec) != fields.dimension { continue } - r, err := n.Search(vec, 1, 0, 0) + r, err := n.Search(ctx, vec, 1, 0, 0) if err != nil { return err } @@ -2655,9 +2709,11 @@ func Test_ngt_BulkInsert(t *testing.T) { tt.Error(err) } }() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() got, got1 := n.BulkInsert(test.args.vecs) - if err := checkFunc(test.want, got, n, test.fields, test.args, got1); err != nil { + if err := checkFunc(ctx, test.want, got, n, test.fields, test.args, got1); err != nil { tt.Errorf("error = %v", err) } }) @@ -2689,7 +2745,7 @@ func Test_ngt_BulkInsertCommit(t *testing.T) { fields fields createFunc func(*testing.T, fields) (NGT, error) want want - checkFunc func(want, []uint, NGT, fields, args, []error) error + checkFunc func(context.Context, want, []uint, NGT, fields, args, []error) error beforeFunc func(args) afterFunc func(*testing.T, NGT) error } @@ -2707,7 +2763,7 @@ func Test_ngt_BulkInsertCommit(t *testing.T) { WithDimension(fields.dimension), ) } - defaultCheckFunc := func(w want, got []uint, n NGT, fields fields, args args, got1 []error) error { + defaultCheckFunc := func(ctx context.Context, w want, got []uint, n NGT, fields fields, args args, got1 []error) error { if diff := comparator.Diff(w.want1, got1, comparator.ErrorComparer); diff != "" { return errors.New(diff) } @@ -2729,7 +2785,7 @@ func Test_ngt_BulkInsertCommit(t *testing.T) { if len(vec) != fields.dimension { continue } - r, err := n.Search(vec, 1, 0, 0) + r, err := n.Search(ctx, vec, 1, 0, 0) if err != nil { return err } @@ -2949,9 +3005,11 @@ func Test_ngt_BulkInsertCommit(t *testing.T) { tt.Error(err) } }() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() got, got1 := n.BulkInsertCommit(test.args.vecs, test.args.poolSize) - if err := checkFunc(test.want, got, n, test.fields, test.args, got1); err != nil { + if err := checkFunc(ctx, test.want, got, n, test.fields, test.args, got1); err != nil { tt.Errorf("error = %v", err) } }) @@ -2981,7 +3039,7 @@ func Test_ngt_CreateAndSaveIndex(t *testing.T) { fields fields createFunc func(*testing.T, fields) (NGT, error) want want - checkFunc func(want, NGT, args, error) error + checkFunc func(context.Context, want, NGT, args, error) error beforeFunc func(args) afterFunc func(*testing.T, NGT) error } @@ -2999,7 +3057,7 @@ func Test_ngt_CreateAndSaveIndex(t *testing.T) { WithDimension(fields.dimension), ) } - defaultCheckFunc := func(w want, n NGT, args args, got error) error { + defaultCheckFunc := func(_ context.Context, w want, n NGT, args args, got error) error { if diff := comparator.Diff(w.err, got); diff != "" { return errors.New(diff) } @@ -3086,14 +3144,14 @@ func Test_ngt_CreateAndSaveIndex(t *testing.T) { return ngt, err }, - checkFunc: func(w want, n NGT, a args, e error) error { - if err := defaultCheckFunc(w, n, a, e); err != nil { + checkFunc: func(ctx context.Context, w want, n NGT, a args, e error) error { + if err := defaultCheckFunc(ctx, w, n, a, e); err != nil { return err } // search the inserted vector exists after create index for _, v := range ivs { - if rs, err := n.Search(v, 1, 0, 0); err != nil { + if rs, err := n.Search(ctx, v, 1, 0, 0); err != nil { if rs[0].Distance != 0 { return errors.Errorf("vector distance is invalid, got: %d, want: %d", rs[0].Distance, 0) } @@ -3145,14 +3203,14 @@ func Test_ngt_CreateAndSaveIndex(t *testing.T) { return ngt, err }, - checkFunc: func(w want, n NGT, a args, e error) error { - if err := defaultCheckFunc(w, n, a, e); err != nil { + checkFunc: func(ctx context.Context, w want, n NGT, a args, e error) error { + if err := defaultCheckFunc(ctx, w, n, a, e); err != nil { return err } // search the inserted vector exists after create index for _, v := range ivs { - if rs, err := n.Search(v, 1, 0, 0); err != nil { + if rs, err := n.Search(ctx, v, 1, 0, 0); err != nil { if rs[0].Distance != 0 { return errors.Errorf("vector distance is invalid, got: %d, want: %d", rs[0].Distance, 0) } @@ -3206,9 +3264,11 @@ func Test_ngt_CreateAndSaveIndex(t *testing.T) { tt.Error(err) } }() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() err = n.CreateAndSaveIndex(test.args.poolSize) - if err := checkFunc(test.want, n, test.args, err); err != nil { + if err := checkFunc(ctx, test.want, n, test.args, err); err != nil { tt.Errorf("error = %v", err) } }) @@ -3238,7 +3298,7 @@ func Test_ngt_CreateIndex(t *testing.T) { fields fields createFunc func(*testing.T, fields) (NGT, error) want want - checkFunc func(want, NGT, args, error) error + checkFunc func(context.Context, want, NGT, args, error) error beforeFunc func(args) afterFunc func(*testing.T, NGT) error } @@ -3256,7 +3316,7 @@ func Test_ngt_CreateIndex(t *testing.T) { WithDimension(fields.dimension), ) } - defaultCheckFunc := func(w want, n NGT, args args, got error) error { + defaultCheckFunc := func(_ context.Context, w want, n NGT, args args, got error) error { if diff := comparator.Diff(w.err, got); diff != "" { return errors.New(diff) } @@ -3343,14 +3403,14 @@ func Test_ngt_CreateIndex(t *testing.T) { return ngt, err }, - checkFunc: func(w want, n NGT, a args, e error) error { - if err := defaultCheckFunc(w, n, a, e); err != nil { + checkFunc: func(ctx context.Context, w want, n NGT, a args, e error) error { + if err := defaultCheckFunc(ctx, w, n, a, e); err != nil { return err } // search the inserted vector exists after create index for _, v := range ivs { - if rs, err := n.Search(v, 1, 0, 0); err != nil { + if rs, err := n.Search(ctx, v, 1, 0, 0); err != nil { if rs[0].Distance != 0 { return errors.Errorf("vector distance is invalid, got: %d, want: %d", rs[0].Distance, 0) } @@ -3402,14 +3462,14 @@ func Test_ngt_CreateIndex(t *testing.T) { return ngt, err }, - checkFunc: func(w want, n NGT, a args, e error) error { - if err := defaultCheckFunc(w, n, a, e); err != nil { + checkFunc: func(ctx context.Context, w want, n NGT, a args, e error) error { + if err := defaultCheckFunc(ctx, w, n, a, e); err != nil { return err } // search the inserted vector exists after create index for _, v := range ivs { - if rs, err := n.Search(v, 1, 0, 0); err != nil { + if rs, err := n.Search(ctx, v, 1, 0, 0); err != nil { if rs[0].Distance != 0 { return errors.Errorf("vector distance is invalid, got: %d, want: %d", rs[0].Distance, 0) } @@ -3463,9 +3523,11 @@ func Test_ngt_CreateIndex(t *testing.T) { tt.Error(err) } }() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() err = n.CreateIndex(test.args.poolSize) - if err := checkFunc(test.want, n, test.args, err); err != nil { + if err := checkFunc(ctx, test.want, n, test.args, err); err != nil { tt.Errorf("error = %v", err) } }) @@ -3495,7 +3557,7 @@ func Test_ngt_SaveIndex(t *testing.T) { fields fields createFunc func(*testing.T, fields) (NGT, error) want want - checkFunc func(want, NGT, args, error) error + checkFunc func(context.Context, want, NGT, args, error) error beforeFunc func(args) afterFunc func(*testing.T, NGT) error } @@ -3513,7 +3575,7 @@ func Test_ngt_SaveIndex(t *testing.T) { WithDimension(fields.dimension), ) } - defaultCheckFunc := func(w want, n NGT, args args, e error) error { + defaultCheckFunc := func(_ context.Context, w want, n NGT, args args, e error) error { if ngt, ok := n.(*ngt); ok { _, err := os.Stat(ngt.idxPath) // if ngt is in-memory mode, the index file should not be created @@ -3597,14 +3659,14 @@ func Test_ngt_SaveIndex(t *testing.T) { return ngt, err }, - checkFunc: func(w want, n NGT, a args, e error) error { - if err := defaultCheckFunc(w, n, a, e); err != nil { + checkFunc: func(ctx context.Context, w want, n NGT, a args, e error) error { + if err := defaultCheckFunc(ctx, w, n, a, e); err != nil { return err } // search the inserted vector exists after create index for _, v := range ivs { - if rs, err := n.Search(v, 1, 0, 0); err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { + if rs, err := n.Search(ctx, v, 1, 0, 0); err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { if rs[0].Distance != 0 { return errors.Errorf("vector distance is invalid, got: %d, want: %d", rs[0].Distance, 0) } @@ -3656,14 +3718,14 @@ func Test_ngt_SaveIndex(t *testing.T) { return ngt, err }, - checkFunc: func(w want, n NGT, a args, e error) error { - if err := defaultCheckFunc(w, n, a, e); err != nil { + checkFunc: func(ctx context.Context, w want, n NGT, a args, e error) error { + if err := defaultCheckFunc(ctx, w, n, a, e); err != nil { return err } // search the inserted vector exists after create index for _, v := range ivs { - if rs, err := n.Search(v, 1, 0, 0); err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { + if rs, err := n.Search(ctx, v, 1, 0, 0); err != nil && !errors.Is(err, errors.ErrEmptySearchResult) { if rs[0].Distance != 0 { return errors.Errorf("vector distance is invalid, got: %d, want: %d", rs[0].Distance, 0) } @@ -3717,9 +3779,11 @@ func Test_ngt_SaveIndex(t *testing.T) { tt.Error(err) } }() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() err = n.SaveIndex() - if err := checkFunc(test.want, n, test.args, err); err != nil { + if err := checkFunc(ctx, test.want, n, test.args, err); err != nil { tt.Errorf("error = %v", err) } }) From f029ca3acd907176795d529a1baa08bc29365398 Mon Sep 17 00:00:00 2001 From: hlts2 Date: Tue, 11 Jul 2023 19:04:00 +0900 Subject: [PATCH 3/3] fix build error of ngt stateful test Signed-off-by: hlts2 --- .../core/ngt/service/ngt_stateful_test.go | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/pkg/agent/core/ngt/service/ngt_stateful_test.go b/pkg/agent/core/ngt/service/ngt_stateful_test.go index fb3da107b3..a3a1bd6e61 100644 --- a/pkg/agent/core/ngt/service/ngt_stateful_test.go +++ b/pkg/agent/core/ngt/service/ngt_stateful_test.go @@ -511,9 +511,10 @@ var ( RunFunc: func( systemUnderTest commands.SystemUnderTest, ) commands.Result { - ngt := systemUnderTest.(*ngtSystem).ngt + ngtSys := systemUnderTest.(*ngtSystem) + ngt := ngtSys.ngt - res, err := ngt.Search([]float32{0.1, 0.1, 0.1}, 3, 0.1, -1.0) + res, err := ngt.Search(ngtSys.ctx, []float32{0.1, 0.1, 0.1}, 3, 0.1, -1.0) return &resultContainer{ err: err, results: res, @@ -583,9 +584,10 @@ var ( RunFunc: func( systemUnderTest commands.SystemUnderTest, ) commands.Result { - ngt := systemUnderTest.(*ngtSystem).ngt + ngtSys := systemUnderTest.(*ngtSystem) + ngt := ngtSys.ngt - _, res, err := ngt.SearchByID(idA, 3, 0.1, -1.0) + _, res, err := ngt.SearchByID(ngtSys.ctx, idA, 3, 0.1, -1.0) return &resultContainer{ err: err, results: res, @@ -662,9 +664,10 @@ var ( RunFunc: func( systemUnderTest commands.SystemUnderTest, ) commands.Result { - ngt := systemUnderTest.(*ngtSystem).ngt + ngtSys := systemUnderTest.(*ngtSystem) + ngt := ngtSys.ngt - _, res, err := ngt.SearchByID(idB, 3, 0.1, -1.0) + _, res, err := ngt.SearchByID(ngtSys.ctx, idB, 3, 0.1, -1.0) return &resultContainer{ err: err, results: res, @@ -741,9 +744,10 @@ var ( RunFunc: func( systemUnderTest commands.SystemUnderTest, ) commands.Result { - ngt := systemUnderTest.(*ngtSystem).ngt + ngtSys := systemUnderTest.(*ngtSystem) + ngt := ngtSys.ngt - _, res, err := ngt.SearchByID(idC, 3, 0.1, -1.0) + _, res, err := ngt.SearchByID(ngtSys.ctx, idC, 3, 0.1, -1.0) return &resultContainer{ err: err, results: res,