From 2f2659adc80769ae807f01427f6f446b33e974ed Mon Sep 17 00:00:00 2001 From: Kosuke Morimoto Date: Thu, 21 May 2020 10:02:50 +0900 Subject: [PATCH] fix search --- .../cli/loadtest/service/search/search.go | 59 ++++++++++--------- .../loadtest/service/search/search_option.go | 25 +++++--- 2 files changed, 47 insertions(+), 37 deletions(-) diff --git a/pkg/tools/cli/loadtest/service/search/search.go b/pkg/tools/cli/loadtest/service/search/search.go index 021bbb7bfd3..5a12fc8d19a 100644 --- a/pkg/tools/cli/loadtest/service/search/search.go +++ b/pkg/tools/cli/loadtest/service/search/search.go @@ -17,27 +17,29 @@ package search import ( "context" - "reflect" - "sync" - - "github.com/vdaas/vald/internal/client" + "fmt" + "github.com/vdaas/vald/apis/grpc/gateway/vald" + "github.com/vdaas/vald/apis/grpc/payload" "github.com/vdaas/vald/internal/errgroup" "github.com/vdaas/vald/internal/errors" + "github.com/vdaas/vald/internal/net/grpc" "github.com/vdaas/vald/internal/safety" "github.com/vdaas/vald/pkg/tools/cli/loadtest/assets" + "reflect" ) type search struct { - eg errgroup.Group - r client.Reader - c int - n string - req []*client.SearchRequest + eg errgroup.Group + client grpc.Client + addr string + concurrency int + dataset string + req []*payload.Search_Request } -func New(opts ...SearchOption) (s *search, err error) { +func New(opts ...Option) (s *search, err error) { s = new(search) - for _, opt := range append(defaultSearchOpts, opts...) { + for _, opt := range append(defaultOpts, opts...) { if err = opt(s); err != nil { return nil, errors.ErrOptionFailed(err, reflect.ValueOf(opt)) } @@ -47,14 +49,18 @@ func New(opts ...SearchOption) (s *search, err error) { } func (s *search) Prepare(ctx context.Context) error { - dataset, err := assets.Data(s.n)() + fn := assets.Data(s.dataset) + if fn == nil { + return fmt.Errorf("dataset load function is nil: %s", s.dataset) + } + dataset, err := assets.Data(s.dataset)() if err != nil { return err } vectors := dataset.Query() - s.req = make([]*client.SearchRequest, len(vectors)) + s.req = make([]*payload.Search_Request, len(vectors)) for i, v := range vectors { - s.req[i] = &client.SearchRequest{ + s.req[i] = &payload.Search_Request{ Vector: v, } } @@ -63,27 +69,24 @@ func (s *search) Prepare(ctx context.Context) error { } func (s *search) Do(ctx context.Context) <-chan error { - errCh := make(chan error, len(s.req)*10) + errCh := make(chan error, len(s.req)) s.eg.Go(safety.RecoverFunc(func() error { defer close(errCh) - wg := new(sync.WaitGroup) - sem := make(chan struct{}, s.c) + eg, egctx := errgroup.New(ctx) + eg.Limitation(s.concurrency) for _, req := range s.req { - wg.Add(1) - sem <- struct{}{} - go func(r *client.SearchRequest) { - defer wg.Done() - defer func() { - <-sem - }() - _, err := s.r.Search(ctx, r) + r := req + eg.Go(func() error { + _, err := s.client.Do(egctx, s.addr, func(Ctx context.Context, conn *grpc.ClientConn, copts ...grpc.CallOption) (interface{}, error) { + return vald.NewValdClient(conn).Search(ctx, r, copts...) + }) if err != nil { errCh <- err } - }(req) + return nil + }) } - wg.Wait() - return nil + return eg.Wait() })) return errCh } diff --git a/pkg/tools/cli/loadtest/service/search/search_option.go b/pkg/tools/cli/loadtest/service/search/search_option.go index 20e4c876a89..7c6d8c86d2a 100644 --- a/pkg/tools/cli/loadtest/service/search/search_option.go +++ b/pkg/tools/cli/loadtest/service/search/search_option.go @@ -16,34 +16,41 @@ package search import ( - "github.com/vdaas/vald/internal/client" + "github.com/vdaas/vald/internal/net/grpc" ) -type SearchOption func(*search) error +type Option func(*search) error var ( - defaultSearchOpts = []SearchOption{ + defaultOpts = []Option{ WithConcurrency(100), } ) -func WithReader(r client.Reader) SearchOption { +func WithAddr(a string) Option { return func(s *search) error { - s.r = r + s.addr = a return nil } } -func WithConcurrency(c int) SearchOption { +func WithClient(c grpc.Client) Option { return func(s *search) error { - s.c = c + s.client = c return nil } } -func WithDataset(n string) SearchOption { +func WithConcurrency(c int) Option { + return func(s *search) error { + s.concurrency = c + return nil + } +} + +func WithDataset(n string) Option { return func(s *search) (err error) { - s.n = n + s.dataset = n return nil } }