From 39a969b16bcd987cc82f862ac90a3a8bf9015be5 Mon Sep 17 00:00:00 2001 From: "C.C" Date: Sun, 24 Mar 2024 20:10:16 -0700 Subject: [PATCH] fix: ai service race (#768) --- pkg/bridge/ai/api_server.go | 9 --------- pkg/bridge/ai/service.go | 38 +++++++++++++++++-------------------- 2 files changed, 17 insertions(+), 30 deletions(-) diff --git a/pkg/bridge/ai/api_server.go b/pkg/bridge/ai/api_server.go index b8adc330e..461b3fa64 100644 --- a/pkg/bridge/ai/api_server.go +++ b/pkg/bridge/ai/api_server.go @@ -7,7 +7,6 @@ import ( "fmt" "net" "net/http" - "sync" "time" gonanoid "github.com/matoous/go-nanoid/v2" @@ -127,14 +126,6 @@ func HandleInvoke(w http.ResponseWriter, r *http.Request) { return } - ci := &CacheItem{ - wg: &sync.WaitGroup{}, - ResponseWriter: w, - } - if _, ok := service.cache[reqID]; !ok { - service.cache[reqID] = ci - } - var req ai.InvokeRequest req.ReqID = reqID diff --git a/pkg/bridge/ai/service.go b/pkg/bridge/ai/service.go index d7329d5eb..cd6ec6ff5 100644 --- a/pkg/bridge/ai/service.go +++ b/pkg/bridge/ai/service.go @@ -2,7 +2,6 @@ package ai import ( "fmt" - "net/http" "sync" "time" @@ -24,27 +23,19 @@ var ( services *expirable.LRU[string, *Service] ) -// CacheItem cache the http.ResponseWriter, which is used for writing response from reducer. -// TODO: http.ResponseWriter is from the SimpleRestfulServer interface, should be decoupled -// from here. -type CacheItem struct { - ResponseWriter http.ResponseWriter - wg *sync.WaitGroup - mu sync.Mutex -} - // Service is used to invoke LLM Provider to get the functions to be executed, // then, use source to send arguments which returned by llm provider to target // function. Finally, use reducer to aggregate all the results, and write the // result by the http.ResponseWriter. type Service struct { - credential string - zipperAddr string - md metadata.M - source yomo.Source - reducer yomo.StreamFunction - cache map[string]*CacheItem + credential string + zipperAddr string + md metadata.M + source yomo.Source + reducer yomo.StreamFunction + // cache map[string]*CacheItem sfnCallCache map[string]*sfnAsyncCall + muCallCache sync.Mutex LLMProvider } @@ -72,9 +63,9 @@ func DefaultExchangeMetadataFunc(credential string) (metadata.M, error) { func newService(credential string, zipperAddr string, aiProvider LLMProvider, exFn ExchangeMetadataFunc) (*Service, error) { s := &Service{ - credential: credential, - zipperAddr: zipperAddr, - cache: make(map[string]*CacheItem), + credential: credential, + zipperAddr: zipperAddr, + // cache: make(map[string]*CacheItem), LLMProvider: aiProvider, sfnCallCache: make(map[string]*sfnAsyncCall), } @@ -116,7 +107,6 @@ func (s *Service) Release() { if s.reducer != nil { s.reducer.Close() } - clear(s.cache) } func (s *Service) createSource() (yomo.Source, error) { @@ -156,7 +146,9 @@ func (s *Service) createReducer() (yomo.StreamFunction, error) { reqID := invoke.ReqID // write parallel function calling results to cache, after all the results are written, the reducer will be done + s.muCallCache.Lock() c, ok := s.sfnCallCache[reqID] + s.muCallCache.Unlock() if !ok { ylog.Error("[sfn-reducer] req_id not found", "req_id", reqID) return @@ -242,7 +234,9 @@ func (s *Service) runFunctionCalls(fns map[uint32][]*ai.ToolCall, reqID string) wg: &sync.WaitGroup{}, val: make(map[string]ai.ToolMessage), } + s.muCallCache.Lock() s.sfnCallCache[reqID] = asyncCall + s.muCallCache.Unlock() for tag, tcs := range fns { ylog.Debug("+++invoke toolCalls", "tag", tag, "len(toolCalls)", len(tcs), "reqID", reqID) @@ -262,11 +256,13 @@ func (s *Service) runFunctionCalls(fns map[uint32][]*ai.ToolCall, reqID string) arr := make([]ai.ToolMessage, 0) + asyncCall.mu.RLock() for _, call := range asyncCall.val { ylog.Debug("---invoke done", "id", call.ToolCallId, "content", call.Content) call.Role = "tool" arr = append(arr, call) } + asyncCall.mu.RUnlock() return arr, nil } @@ -307,6 +303,6 @@ func init() { type sfnAsyncCall struct { wg *sync.WaitGroup - mu sync.Mutex + mu sync.RWMutex val map[string]ai.ToolMessage }