Skip to content

Commit

Permalink
Propagate context to Search operation. (#2117)
Browse files Browse the repository at this point in the history
* add search context to agent service

Signed-off-by: hlts2 <hiroto.funakoshi.hiroto@gmail.com>

* fix test execution error

Signed-off-by: hlts2 <hiroto.funakoshi.hiroto@gmail.com>

* fix build error of ngt stateful test

Signed-off-by: hlts2 <hiroto.funakoshi.hiroto@gmail.com>

---------

Signed-off-by: hlts2 <hiroto.funakoshi.hiroto@gmail.com>
  • Loading branch information
hlts2 authored and ykadowak committed Jul 25, 2023
1 parent 5be892c commit a195594
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 89 deletions.
11 changes: 9 additions & 2 deletions internal/core/algorithm/ngt/ngt.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ package ngt
import "C"

import (
"context"
"reflect"
"sync"
"unsafe"
Expand All @@ -40,7 +41,7 @@ type (
// NGT is core interface.
NGT interface {
// Search returns search result as []SearchResult
Search(vec []float32, size int, epsilon, radius float32) ([]SearchResult, error)
Search(ctx context.Context, vec []float32, size int, epsilon, radius float32) ([]SearchResult, error)

// Linear Search returns linear search result as []SearchResult
LinearSearch(vec []float32, size int) ([]SearchResult, error)
Expand Down Expand Up @@ -366,7 +367,7 @@ func (n *ngt) loadObjectSpace() error {
}

// Search returns search result as []SearchResult.
func (n *ngt) Search(vec []float32, size int, epsilon, radius float32) (result []SearchResult, err error) {
func (n *ngt) Search(ctx context.Context, vec []float32, size int, epsilon, radius float32) (result []SearchResult, err error) {
if len(vec) != int(n.dimension) {
return nil, errors.ErrIncompatibleDimensionSize(len(vec), int(n.dimension))
}
Expand Down Expand Up @@ -415,6 +416,12 @@ func (n *ngt) Search(vec []float32, size int, epsilon, radius float32) (result [
result = make([]SearchResult, rsize)

for i := range result {
select {
case <-ctx.Done():
n.PutErrorBuffer(ebuf)
return result[:i], nil
default:
}
d := C.ngt_get_result(results, C.uint32_t(i), ebuf)
if d.id == 0 && d.distance == 0 {
result[i] = SearchResult{0, 0, n.newGoError(ebuf)}
Expand Down
Loading

0 comments on commit a195594

Please sign in to comment.