Skip to content

Commit

Permalink
[patch] change gateway vald's mutex lock (#765)
Browse files Browse the repository at this point in the history
* [patch] change gateway vald's mutex lock
Signed-off-by: kpango <i.can.feel.gravity@gmail.com>
Co-authored-by: Rintaro Okamura <rintaro.okamura@gmail.com>
  • Loading branch information
Yusuke Kato authored Oct 13, 2020
1 parent 6cd738f commit ac40a15
Showing 1 changed file with 37 additions and 43 deletions.
80 changes: 37 additions & 43 deletions pkg/gateway/vald/handler/grpc/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ func (s *server) search(ctx context.Context, cfg *payload.Search_Config,
span.End()
}
}()
maxDist := uint32(math.MaxUint32)
var maxDist uint32
atomic.StoreUint32(&maxDist, math.Float32bits(math.MaxFloat32))
num := int(cfg.GetNum())
res = new(payload.Search_Response)
res.Results = make([]*payload.Object_Distance, 0, s.gateway.GetAgentCount(ctx)*num)
Expand All @@ -147,32 +148,23 @@ func (s *server) search(ctx context.Context, cfg *payload.Search_Config,

eg.Go(safety.RecoverFunc(func() error {
defer cancel()
// cl := new(checkList)
visited := make(map[string]bool, len(res.Results))
mu := sync.RWMutex{}
visited := new(sync.Map)
return s.gateway.BroadCast(ectx, func(ctx context.Context, target string, ac agent.AgentClient, copts ...grpc.CallOption) error {
r, err := f(ctx, ac, copts...)
if err != nil {
log.Debug("ignoring error:", err)
return nil
}
for _, dist := range r.GetResults() {
if dist.GetDistance() > math.Float32frombits(atomic.LoadUint32(&maxDist)) {
if dist.GetDistance() >= math.Float32frombits(atomic.LoadUint32(&maxDist)) {
return nil
}
id := dist.GetId()
mu.Lock()
if !visited[id] {
visited[id] = true
mu.Unlock()
if dist == nil {
continue
}
if _, already := visited.LoadOrStore(dist.GetId(), struct{}{}); !already {
dch <- dist
} else {
mu.Unlock()
}
// if !cl.Exists(id) {
// dch <- dist
// cl.Check(id)
// }
}
return nil
})
Expand All @@ -188,8 +180,8 @@ func (s *server) search(ctx context.Context, cfg *payload.Search_Config,
if len(res.GetResults()) > num && num != 0 {
res.Results = res.Results[:num]
}
uuids := make([]string, 0, len(res.Results))
for _, r := range res.Results {
uuids := make([]string, 0, len(res.GetResults()))
for _, r := range res.GetResults() {
uuids = append(uuids, r.GetId())
}
if s.metadata != nil {
Expand Down Expand Up @@ -217,44 +209,46 @@ func (s *server) search(ctx context.Context, cfg *payload.Search_Config,
}
return res, nil
case dist := <-dch:
if len(res.GetResults()) >= num {
if dist.GetDistance() < math.Float32frombits(atomic.LoadUint32(&maxDist)) {
atomic.StoreUint32(&maxDist, math.Float32bits(dist.GetDistance()))
} else {
continue
}
rl := len(res.GetResults()) // result length
if rl >= num && dist.GetDistance() >= math.Float32frombits(atomic.LoadUint32(&maxDist)) {
continue
}
switch len(res.GetResults()) {
switch rl {
case 0:
res.Results = append(res.Results, dist)
continue
case 1:
if res.GetResults()[0].GetDistance() <= dist.GetDistance() {
res.Results = append(res.Results, dist)
} else {
res.Results = append([]*payload.Object_Distance{dist}, res.Results[0])
}
continue
}
default:
var pos int
for idx := rl; idx >= 1; idx-- {
if res.GetResults()[idx-1].GetDistance() <= dist.GetDistance() {
pos = idx - 1
break
}
}

pos := len(res.GetResults())
for idx := pos; idx >= 1; idx-- {
if res.GetResults()[idx-1].GetDistance() <= dist.GetDistance() {
pos = idx - 1
break
switch {
case pos == len(res.GetResults()):
res.Results = append([]*payload.Object_Distance{dist}, res.Results...)
case pos == len(res.GetResults())-1:
res.Results = append(res.GetResults(), dist)
case pos >= 0:
res.Results = append(res.GetResults()[:pos+1], res.GetResults()[pos:]...)
res.Results[pos+1] = dist
}
}
switch {
case pos == len(res.GetResults()):
res.Results = append([]*payload.Object_Distance{dist}, res.Results...)
case pos == len(res.GetResults())-1:
res.Results = append(res.GetResults(), dist)
case pos >= 0:
res.Results = append(res.GetResults()[:pos+1], res.GetResults()[pos:]...)
res.Results[pos+1] = dist
}
if len(res.GetResults()) > num && num != 0 {
rl = len(res.GetResults())
if rl > num && num != 0 {
res.Results = res.GetResults()[:num]
rl = len(res.GetResults())
}
if distEnd := res.GetResults()[rl].GetDistance(); rl >= num &&
distEnd < math.Float32frombits(atomic.LoadUint32(&maxDist)) {
atomic.StoreUint32(&maxDist, math.Float32bits(distEnd))
}
}
}
Expand Down

0 comments on commit ac40a15

Please sign in to comment.