From b062cee9efc455b6a7c947a647800f5a09113bf4 Mon Sep 17 00:00:00 2001 From: woorui Date: Tue, 13 Aug 2024 15:57:02 +0800 Subject: [PATCH 1/3] refactor: new caller from interfaces --- pkg/bridge/ai/api_server.go | 105 ++--- pkg/bridge/ai/api_server_test.go | 14 +- pkg/bridge/ai/call_syncer.go | 38 -- pkg/bridge/ai/call_syncer_test.go | 22 +- pkg/bridge/ai/caller.go | 707 ++++-------------------------- pkg/bridge/ai/caller_test.go | 497 +-------------------- pkg/bridge/ai/service.go | 613 ++++++++++++++++++++++++++ pkg/bridge/ai/service_test.go | 495 +++++++++++++++++++++ 8 files changed, 1276 insertions(+), 1215 deletions(-) create mode 100644 pkg/bridge/ai/service.go create mode 100644 pkg/bridge/ai/service_test.go diff --git a/pkg/bridge/ai/api_server.go b/pkg/bridge/ai/api_server.go index e42817df7..ddf2bcc62 100644 --- a/pkg/bridge/ai/api_server.go +++ b/pkg/bridge/ai/api_server.go @@ -35,7 +35,6 @@ type BasicAPIServer struct { zipperAddr string credential string httpHandler http.Handler - logger *slog.Logger } // Serve starts the Basic API Server @@ -44,19 +43,20 @@ func Serve(config *Config, zipperListenAddr string, credential string, logger *s if err != nil { return err } - srv, err := NewBasicAPIServer(config, zipperListenAddr, provider, credential, logger) + srv, err := NewBasicAPIServer(config, zipperListenAddr, credential, provider, logger) if err != nil { return err } logger.Info("start bridge server", "addr", config.Server.Addr, "provider", provider.Name()) - return srv.ServeAddr(config.Server.Addr) + return http.ListenAndServe(config.Server.Addr, srv.httpHandler) } -func BridgeHTTPHanlder(provider provider.LLMProvider, decorater func(http.Handler) http.Handler) http.Handler { +// NewServeMux creates a new http.ServeMux for the llm bridge server. +func NewServeMux(service *Service) *http.ServeMux { var ( + h = &Handler{service} mux = http.NewServeMux() - h = NewHandler(provider) ) // GET /overview mux.HandleFunc("/overview", h.HandleOverview) @@ -65,57 +65,59 @@ func BridgeHTTPHanlder(provider provider.LLMProvider, decorater func(http.Handle // POST /v1/chat/completions (OpenAI compatible interface) mux.HandleFunc("/v1/chat/completions", h.HandleChatCompletions) - return decorater(mux) + return mux +} + +// DecorateHandler decorates the http.Handler. +func DecorateHandler(h http.Handler, decorates ...func(handler http.Handler) http.Handler) http.Handler { + // decorate the http.Handler + for i := len(decorates) - 1; i >= 0; i-- { + h = decorates[i](h) + } + return h } // NewBasicAPIServer creates a new restful service -func NewBasicAPIServer(config *Config, zipperAddr string, provider provider.LLMProvider, credential string, logger *slog.Logger) (*BasicAPIServer, error) { +func NewBasicAPIServer(config *Config, zipperAddr, credential string, provider provider.LLMProvider, logger *slog.Logger) (*BasicAPIServer, error) { zipperAddr = parseZipperAddr(zipperAddr) - cp := NewCallerProvider(zipperAddr, DefaultExchangeMetadataFunc) + logger = logger.With("component", "bridge") + + service := NewService(provider, DefaultComponentCreator(zipperAddr), &ServiceOption{ + Logger: logger, + Tracer: otel.Tracer("yomo-llm-bridge"), + CredentialFunc: func(r *http.Request) (string, error) { return credential, nil }, + }) + + mux := NewServeMux(service) server := &BasicAPIServer{ zipperAddr: zipperAddr, credential: credential, - httpHandler: BridgeHTTPHanlder(provider, decorateReqContext(cp, logger, credential)), - logger: logger.With("component", "bridge"), + httpHandler: DecorateHandler(mux, decorateReqContext(service, logger)), } return server, nil } -// ServeAddr starts a http server that provides some endpoints to bridge up the http server and YoMo. -// User can chat to the http server and interact with the YoMo's stream function. -func (a *BasicAPIServer) ServeAddr(addr string) error { - return http.ListenAndServe(addr, a.httpHandler) -} - -// decorateReqContext decorates the context of the request, it injects a transID and a caller into the context. -func decorateReqContext(cp CallerProvider, logger *slog.Logger, credential string) func(handler http.Handler) http.Handler { - tracer := otel.Tracer("yomo-llm-bridge") - - caller, err := cp.Provide(credential) - if err != nil { - logger.Info("can't load caller", "err", err) - - return func(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - RespondWithError(w, http.StatusInternalServerError, err) - }) - } - } - - caller.SetTracer(tracer) - +// decorateReqContext decorates the context of the request, it injects a transID into the request's context, +// log the request information and start tracing the request. +func decorateReqContext(service *Service, logger *slog.Logger) func(handler http.Handler) http.Handler { host, _ := os.Hostname() return func(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + caller, err := service.LoadOrCreateCaller(r) + if err != nil { + RespondWithError(w, http.StatusBadRequest, err) + return + } + ctx = WithCallerContext(ctx, caller) + // trace every request - ctx, span := tracer.Start( + ctx, span := service.option.Tracer.Start( ctx, r.URL.Path, trace.WithSpanKind(trace.SpanKindServer), @@ -125,7 +127,6 @@ func decorateReqContext(cp CallerProvider, logger *slog.Logger, credential strin transID := id.New(32) ctx = WithTransIDContext(ctx, transID) - ctx = WithCallerContext(ctx, caller) logger.Info("request", "method", r.Method, "path", r.URL.Path, "transID", transID) @@ -136,24 +137,16 @@ func decorateReqContext(cp CallerProvider, logger *slog.Logger, credential strin // Handler handles the http request. type Handler struct { - provider provider.LLMProvider -} - -// NewHandler returns a new Handler. -func NewHandler(provider provider.LLMProvider) *Handler { - return &Handler{provider} + service *Service } // HandleOverview is the handler for GET /overview func (h *Handler) HandleOverview(w http.ResponseWriter, r *http.Request) { - caller := FromCallerContext(r.Context()) - w.Header().Set("Content-Type", "application/json") - tcs, err := register.ListToolCalls(caller.Metadata()) + tcs, err := register.ListToolCalls(FromCallerContext(r.Context()).Metadata()) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + RespondWithError(w, http.StatusInternalServerError, err) return } @@ -172,7 +165,6 @@ var baseSystemMessage = `You are a very helpful assistant. Your job is to choose func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) { var ( ctx = r.Context() - caller = FromCallerContext(ctx) transID = FromTransIDContext(ctx) ) defer r.Body.Close() @@ -185,14 +177,14 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), RequestTimeout) defer cancel() - res, err := GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, req.IncludeCallStack, caller, h.provider) + w.Header().Set("Content-Type", "application/json") + + res, err := h.service.GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, FromCallerContext(ctx), req.IncludeCallStack) if err != nil { - w.Header().Set("Content-Type", "application/json") RespondWithError(w, http.StatusInternalServerError, err) return } - w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _ = json.NewEncoder(w).Encode(res) } @@ -201,7 +193,6 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) { func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request) { var ( ctx = r.Context() - caller = FromCallerContext(ctx) transID = FromTransIDContext(ctx) ) defer r.Body.Close() @@ -214,7 +205,7 @@ func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request) ctx, cancel := context.WithTimeout(r.Context(), RequestTimeout) defer cancel() - if err := GetChatCompletions(ctx, req, transID, h.provider, caller, w); err != nil { + if err := h.service.GetChatCompletions(ctx, req, transID, FromCallerContext(ctx), w); err != nil { RespondWithError(w, http.StatusBadRequest, err) return } @@ -258,17 +249,17 @@ func getLocalIP() (string, error) { type callerContextKey struct{} // WithCallerContext adds the caller to the request context -func WithCallerContext(ctx context.Context, caller Caller) context.Context { +func WithCallerContext(ctx context.Context, caller *Caller) context.Context { return context.WithValue(ctx, callerContextKey{}, caller) } // FromCallerContext returns the caller from the request context -func FromCallerContext(ctx context.Context) Caller { - service, ok := ctx.Value(callerContextKey{}).(Caller) +func FromCallerContext(ctx context.Context) *Caller { + caller, ok := ctx.Value(callerContextKey{}).(*Caller) if !ok { return nil } - return service + return caller } type transIDContextKey struct{} diff --git a/pkg/bridge/ai/api_server_test.go b/pkg/bridge/ai/api_server_test.go index 5242c6161..f0e456b29 100644 --- a/pkg/bridge/ai/api_server_test.go +++ b/pkg/bridge/ai/api_server_test.go @@ -4,13 +4,15 @@ import ( "bytes" "fmt" "io" - "log/slog" "net/http" "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/assert" + "github.com/yomorun/yomo" "github.com/yomorun/yomo/ai" + "github.com/yomorun/yomo/core/metadata" "github.com/yomorun/yomo/pkg/bridge/ai/provider" "github.com/yomorun/yomo/pkg/bridge/ai/register" ) @@ -38,11 +40,15 @@ func TestServer(t *testing.T) { t.Fatal(err) } - cp := newMockCallerProvider() + cc := &testComponentCreator{flow: newMockDataFlow(newHandler(2 * time.Hour).handle)} - cp.provideFunc = mockCallerProvideFunc(map[uint32][]mockFunctionCall{}) + newCaller := func(_ yomo.Source, _ yomo.StreamFunction, _ metadata.M, _ time.Duration) (*Caller, error) { + return mockCaller(nil), err + } + + service := newService(pd, cc, newCaller, nil) - handler := BridgeHTTPHanlder(pd, decorateReqContext(cp, slog.Default(), "")) + handler := DecorateHandler(NewServeMux(service), decorateReqContext(service, service.logger)) // create a test server server := httptest.NewServer(handler) diff --git a/pkg/bridge/ai/call_syncer.go b/pkg/bridge/ai/call_syncer.go index 048b4e7a0..594fa9f85 100644 --- a/pkg/bridge/ai/call_syncer.go +++ b/pkg/bridge/ai/call_syncer.go @@ -6,9 +6,7 @@ import ( "time" openai "github.com/sashabaranov/go-openai" - "github.com/yomorun/yomo" "github.com/yomorun/yomo/ai" - "github.com/yomorun/yomo/serverless" ) // CallSyncer fires a bunch of function callings, and wait the result of these function callings. @@ -223,39 +221,3 @@ func (f *callSyncer) background() { } } } - -// ToReducer converts a stream function to a reducer that can reduce the function calling result. -func ToReducer(sfn yomo.StreamFunction, logger *slog.Logger, ch chan ReduceMessage) { - // set observe data tags - sfn.SetObserveDataTags(ai.ReducerTag) - // set reduce handler - sfn.SetHandler(func(ctx serverless.Context) { - invoke, err := ctx.LLMFunctionCall() - if err != nil { - ch <- ReduceMessage{ReqID: ""} - logger.Error("parse function calling invoke", "err", err.Error()) - return - } - logger.Debug("sfn-reducer", "req_id", invoke.ReqID, "tool_call_id", invoke.ToolCallID, "result", string(invoke.Result)) - - message := openai.ChatCompletionMessage{ - Role: openai.ChatMessageRoleTool, - Content: invoke.Result, - ToolCallID: invoke.ToolCallID, - } - - ch <- ReduceMessage{ReqID: invoke.ReqID, Message: message} - }) -} - -// ToSource convert a yomo source to the source that can send function calling body to the llm function. -func ToSource(source yomo.Source, logger *slog.Logger, ch chan TagFunctionCall) { - go func() { - for c := range ch { - buf, _ := c.FunctionCall.Bytes() - if err := source.Write(c.Tag, buf); err != nil { - logger.Error("send data to zipper", "err", err.Error()) - } - } - }() -} diff --git a/pkg/bridge/ai/call_syncer_test.go b/pkg/bridge/ai/call_syncer_test.go index 1eb6e525b..31884aafb 100644 --- a/pkg/bridge/ai/call_syncer_test.go +++ b/pkg/bridge/ai/call_syncer_test.go @@ -27,13 +27,10 @@ func TestTimeoutCallSyncer(t *testing.T) { flow := newMockDataFlow(h.handle) defer flow.Close() - reqs := make(chan TagFunctionCall) - ToSource(flow, slog.Default(), reqs) + req, _ := sourceWriteToChan(flow, slog.Default()) + res, _ := reduceToChan(flow, slog.Default()) - messages := make(chan ReduceMessage) - ToReducer(flow, slog.Default(), messages) - - syncer := NewCallSyncer(slog.Default(), reqs, messages, time.Millisecond) + syncer := NewCallSyncer(slog.Default(), req, res, time.Millisecond) go flow.run() var ( @@ -61,13 +58,10 @@ func TestCallSyncer(t *testing.T) { flow := newMockDataFlow(h.handle) defer flow.Close() - reqs := make(chan TagFunctionCall) - ToSource(flow, slog.Default(), reqs) - - messages := make(chan ReduceMessage) - ToReducer(flow, slog.Default(), messages) + req, _ := sourceWriteToChan(flow, slog.Default()) + res, _ := reduceToChan(flow, slog.Default()) - syncer := NewCallSyncer(slog.Default(), reqs, messages, 0) + syncer := NewCallSyncer(slog.Default(), req, res, 0) go flow.run() var ( @@ -118,7 +112,7 @@ func (h *handler) result() []openai.ChatCompletionMessage { return want } -// mockDataFlow mocks the data flow of ai bridge. +// mockDataFlow mocks the data flow of llm bridge. // The data flow is: source -> hander -> reducer, // It is `Write() -> handler() -> reducer()` in this mock implementation. type mockDataFlow struct { @@ -160,11 +154,11 @@ var _ yomo.StreamFunction = (*mockDataFlow)(nil) // The test will not use blowing function in this mock implementation. func (t *mockDataFlow) SetObserveDataTags(tag ...uint32) {} +func (t *mockDataFlow) Connect() error { return nil } func (t *mockDataFlow) Init(fn func() error) error { panic("unimplemented") } func (t *mockDataFlow) SetCronHandler(spec string, fn core.CronHandler) error { panic("unimplemented") } func (t *mockDataFlow) SetPipeHandler(fn core.PipeHandler) error { panic("unimplemented") } func (t *mockDataFlow) SetWantedTarget(string) { panic("unimplemented") } func (t *mockDataFlow) Wait() { panic("unimplemented") } -func (t *mockDataFlow) Connect() error { panic("unimplemented") } func (t *mockDataFlow) SetErrorHandler(fn func(err error)) { panic("unimplemented") } func (t *mockDataFlow) WriteWithTarget(_ uint32, _ []byte, _ string) error { panic("unimplemented") } diff --git a/pkg/bridge/ai/caller.go b/pkg/bridge/ai/caller.go index 33c4b69a7..ddbfa04e7 100644 --- a/pkg/bridge/ai/caller.go +++ b/pkg/bridge/ai/caller.go @@ -1,134 +1,46 @@ package ai import ( - "context" - "encoding/json" - "fmt" - "io" "log/slog" - "net/http" - "strings" "sync/atomic" "time" - "github.com/hashicorp/golang-lru/v2/expirable" openai "github.com/sashabaranov/go-openai" "github.com/yomorun/yomo" "github.com/yomorun/yomo/ai" + "github.com/yomorun/yomo/core" "github.com/yomorun/yomo/core/metadata" "github.com/yomorun/yomo/core/ylog" - "github.com/yomorun/yomo/pkg/bridge/ai/provider" - "github.com/yomorun/yomo/pkg/bridge/ai/register" - "github.com/yomorun/yomo/pkg/id" - "go.opentelemetry.io/otel/trace" - "go.opentelemetry.io/otel/trace/noop" + "github.com/yomorun/yomo/serverless" ) -var ( - // CallerProviderCacheSize is the size of the caller provider cache - CallerProviderCacheSize = 1024 - // CallerProviderCacheTTL is the time to live of the provider cache - CallerProviderCacheTTL = time.Minute * 0 -) - -// CallerProvider provides the caller, which is used to interact with YoMo's stream function. -type CallerProvider interface { - Provide(credential string) (Caller, error) -} - -type callerProvider struct { - zipperAddr string - exFn ExchangeMetadataFunc - provideFunc provideFunc - callers *expirable.LRU[string, Caller] -} - -type provideFunc func(string, string, ExchangeMetadataFunc) (Caller, error) - -// NewCallerProvider returns a new caller provider. -func NewCallerProvider(zipperAddr string, exFn ExchangeMetadataFunc) CallerProvider { - return newCallerProvider(zipperAddr, exFn, NewCaller) -} - -func newCallerProvider(zipperAddr string, exFn ExchangeMetadataFunc, provideFunc provideFunc) CallerProvider { - p := &callerProvider{ - zipperAddr: zipperAddr, - exFn: exFn, - provideFunc: provideFunc, - callers: expirable.NewLRU(CallerProviderCacheSize, func(_ string, caller Caller) { caller.Close() }, CallerProviderCacheTTL), - } - - return p -} - -// Provide provides the caller according to the credential. -func (p *callerProvider) Provide(credential string) (Caller, error) { - caller, ok := p.callers.Get(credential) - if ok { - return caller, nil - } - - caller, err := p.provideFunc(credential, p.zipperAddr, p.exFn) - if err != nil { - return nil, err - } - p.callers.Add(credential, caller) - - return caller, nil -} - // Caller calls the invoke function and keeps the metadata and system prompt. -type Caller interface { - // Call calls the invoke function. - CallSyncer - // SetSystemPrompt sets the system prompt of the caller. - SetSystemPrompt(string) - // GetSystemPrompt returns the system prompt of the caller. - GetSystemPrompt() string - // SetTracer sets the tracer of the caller. - SetTracer(trace.Tracer) - // GetTracer returns the tracer of the caller. - GetTracer() trace.Tracer - // Metadata returns the metadata of the caller. - Metadata() metadata.M - // Close closes the caller, if the caller is closed, the caller will not be reused. - Close() error -} - -type caller struct { +type Caller struct { CallSyncer - source yomo.Source - reducer yomo.StreamFunction - - tracer atomic.Value - credential string + source yomo.Source + reducer yomo.StreamFunction md metadata.M systemPrompt atomic.Value logger *slog.Logger } // NewCaller returns a new caller. -func NewCaller(credential string, zipperAddr string, exFn ExchangeMetadataFunc) (Caller, error) { +func NewCaller(source yomo.Source, reducer yomo.StreamFunction, md metadata.M, callTimeout time.Duration) (*Caller, error) { logger := ylog.Default() - source, reqCh, err := ChanToSource(zipperAddr, credential, logger) + reqCh, err := sourceWriteToChan(source, logger) if err != nil { return nil, err } - reducer, resCh, err := ReduceToChan(zipperAddr, credential, logger) + resCh, err := reduceToChan(reducer, logger) if err != nil { return nil, err } - callSyncer := NewCallSyncer(logger, reqCh, resCh, 60*time.Second) - - md, err := exFn(credential) - if err != nil { - return nil, err - } + callSyncer := NewCallSyncer(logger, reqCh, resCh, callTimeout) - caller := &caller{ + caller := &Caller{ CallSyncer: callSyncer, source: source, reducer: reducer, @@ -136,59 +48,73 @@ func NewCaller(credential string, zipperAddr string, exFn ExchangeMetadataFunc) logger: logger, } - caller.SetSystemPrompt("") - return caller, nil } -// ChanToSource creates a yomo source and a channel, -// The ai.FunctionCall objects are continuously be received from the channel and be sent by the source. -func ChanToSource(zipperAddr, credential string, logger *slog.Logger) (yomo.Source, chan<- TagFunctionCall, error) { - source := yomo.NewSource( - "fc-source", - zipperAddr, - yomo.WithSourceReConnect(), - yomo.WithCredential(credential), - ) +// sourceWriteToChan makes source write data to the channel. +// The TagFunctionCall objects are continuously be received from the channel and be sent by the source. +func sourceWriteToChan(source yomo.Source, logger *slog.Logger) (chan<- TagFunctionCall, error) { err := source.Connect() if err != nil { - return nil, nil, err + return nil, err } ch := make(chan TagFunctionCall) - ToSource(source, logger, ch) + go func() { + for c := range ch { + buf, _ := c.FunctionCall.Bytes() + if err := source.Write(c.Tag, buf); err != nil { + logger.Error("send data to zipper", "err", err.Error()) + } + } + }() - return source, ch, nil + return ch, nil } -// ReduceToChan creates a yomo stream function to reduce the messages and returns both. -func ReduceToChan(zipperAddr, credential string, logger *slog.Logger) (yomo.StreamFunction, <-chan ReduceMessage, error) { - reducer := yomo.NewStreamFunction( - "ai-reducer", - zipperAddr, - yomo.WithSfnReConnect(), - yomo.WithSfnCredential(credential), - yomo.DisableOtelTrace(), - ) +// reduceToChan configures the reducer and returns a channel to accept messages from the reducer. +func reduceToChan(reducer yomo.StreamFunction, logger *slog.Logger) (<-chan ReduceMessage, error) { reducer.SetObserveDataTags(ai.ReducerTag) messages := make(chan ReduceMessage) - ToReducer(reducer, logger, messages) + + reducer.SetObserveDataTags(ai.ReducerTag) + reducer.SetHandler(reduceFunc(messages, logger)) if err := reducer.Connect(); err != nil { - return reducer, nil, err + return nil, err } - return reducer, messages, nil + return messages, nil +} + +func reduceFunc(messages chan ReduceMessage, logger *slog.Logger) core.AsyncHandler { + return func(ctx serverless.Context) { + invoke, err := ctx.LLMFunctionCall() + if err != nil { + messages <- ReduceMessage{ReqID: ""} + logger.Error("parse function calling invoke", "err", err.Error()) + return + } + logger.Debug("sfn-reducer", "req_id", invoke.ReqID, "tool_call_id", invoke.ToolCallID, "result", string(invoke.Result)) + + message := openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleTool, + Content: invoke.Result, + ToolCallID: invoke.ToolCallID, + } + + messages <- ReduceMessage{ReqID: invoke.ReqID, Message: message} + } } // SetSystemPrompt sets the system prompt -func (c *caller) SetSystemPrompt(prompt string) { +func (c *Caller) SetSystemPrompt(prompt string) { c.systemPrompt.Store(prompt) } // SetSystemPrompt gets the system prompt -func (c *caller) GetSystemPrompt() string { +func (c *Caller) GetSystemPrompt() string { if v := c.systemPrompt.Load(); v != nil { return v.(string) } @@ -196,25 +122,12 @@ func (c *caller) GetSystemPrompt() string { } // Metadata returns the metadata of caller. -func (c *caller) Metadata() metadata.M { +func (c *Caller) Metadata() metadata.M { return c.md } -// SetTracer sets the otel tracer. -func (c *caller) SetTracer(tracer trace.Tracer) { - c.tracer.Store(tracer) -} - -// GetTracer gets the otel tracer. -func (c *caller) GetTracer() trace.Tracer { - if v := c.tracer.Load(); v != nil { - return v.(trace.Tracer) - } - return noop.NewTracerProvider().Tracer("yomo-llm-bridge") -} - // Close closes the caller. -func (c *caller) Close() error { +func (c *Caller) Close() error { _ = c.CallSyncer.Close() var err error @@ -229,500 +142,42 @@ func (c *caller) Close() error { return err } -// GetInvoke returns the invoke response -func GetInvoke( - ctx context.Context, - userInstruction string, baseSystemMessage string, transID string, - includeCallStack bool, - caller Caller, provider provider.LLMProvider, -) (*ai.InvokeResponse, error) { - md := caller.Metadata().Clone() - // read tools attached to the metadata - tcs, err := register.ListToolCalls(md) - if err != nil { - return &ai.InvokeResponse{}, err - } - // prepare tools - tools := prepareToolCalls(tcs) - - chainMessage := ai.ChainMessage{} - messages := prepareMessages(baseSystemMessage, userInstruction, chainMessage, tools, true) - req := openai.ChatCompletionRequest{ - Messages: messages, - } - // with tools - if len(tools) > 0 { - req.Tools = tools - } - var ( - promptUsage int - completionUsage int - ) - _, span := caller.GetTracer().Start(ctx, "first_call") - chatCompletionResponse, err := provider.GetChatCompletions(ctx, req, md) - if err != nil { - return nil, err - } - span.End() - promptUsage = chatCompletionResponse.Usage.PromptTokens - completionUsage = chatCompletionResponse.Usage.CompletionTokens - - // convert ChatCompletionResponse to InvokeResponse - res, err := ai.ConvertToInvokeResponse(&chatCompletionResponse, tcs) - if err != nil { - return nil, err - } - // if no tool_calls fired, just return the llm text result - if res.FinishReason != string(openai.FinishReasonToolCalls) { - return res, nil - } - - // run llm function calls - ylog.Debug(">>>> start 1st call response", - "res_toolcalls", fmt.Sprintf("%+v", res.ToolCalls), - "res_assistant_msgs", fmt.Sprintf("%+v", res.AssistantMessage)) - - ylog.Debug(">> run function calls", "transID", transID, "res.ToolCalls", fmt.Sprintf("%+v", res.ToolCalls)) - - _, span = caller.GetTracer().Start(ctx, "run_sfn") - reqID := id.New(16) - llmCalls, err := caller.Call(ctx, transID, reqID, res.ToolCalls) - if err != nil { - return nil, err - } - span.End() - - ylog.Debug(">>>> start 2nd call with", "calls", fmt.Sprintf("%+v", llmCalls), "preceeding_assistant_message", fmt.Sprintf("%+v", res.AssistantMessage)) - - chainMessage.PreceedingAssistantMessage = res.AssistantMessage - chainMessage.ToolMessages = transToolMessage(llmCalls) - // do not attach toolMessage to prompt in 2nd call - messages2 := prepareMessages(baseSystemMessage, userInstruction, chainMessage, tools, false) - req2 := openai.ChatCompletionRequest{ - Messages: messages2, - } - _, span = caller.GetTracer().Start(ctx, "second_call") - chatCompletionResponse2, err := provider.GetChatCompletions(ctx, req2, md) - if err != nil { - return nil, err - } - span.End() - - chatCompletionResponse2.Usage.PromptTokens += promptUsage - chatCompletionResponse2.Usage.CompletionTokens += completionUsage - - res2, err := ai.ConvertToInvokeResponse(&chatCompletionResponse2, tcs) - if err != nil { - return nil, err - } - - // INFO: call stack infomation - if includeCallStack { - res2.ToolCalls = res.ToolCalls - res2.ToolMessages = transToolMessage(llmCalls) - } - ylog.Debug("<<<< complete 2nd call", "res2", fmt.Sprintf("%+v", res2)) - - return res2, err -} - -// GetChatCompletions accepts openai.ChatCompletionRequest and responds to http.ResponseWriter. -func GetChatCompletions( - ctx context.Context, - req openai.ChatCompletionRequest, transID string, - provider provider.LLMProvider, caller Caller, - w http.ResponseWriter, -) error { - reqCtx, reqSpan := caller.GetTracer().Start(ctx, "completions_request") - md := caller.Metadata().Clone() - - // 1. find all hosting tool sfn - tagTools, err := register.ListToolCalls(md) - if err != nil { - return err - } - // 2. add those tools to request - req = addToolsToRequest(req, tagTools) - - // 3. over write system prompt to request - req = overWriteSystemPrompt(req, caller.GetSystemPrompt()) - - var ( - promptUsage = 0 - completionUsage = 0 - totalUsage = 0 - reqMessages = req.Messages - toolCallsMap = make(map[int]openai.ToolCall) - toolCalls = []openai.ToolCall{} - assistantMessage = openai.ChatCompletionMessage{} - ) - // 4. request first chat for getting tools - if req.Stream { - _, firstCallSpan := caller.GetTracer().Start(reqCtx, "first_call_request") - var ( - flusher = eventFlusher(w) - isFunctionCall = false - ) - resStream, err := provider.GetChatCompletionsStream(reqCtx, req, md) - if err != nil { - return err - } - - var ( - i int // number of chunks - j int // number of tool call chunks - firstRespSpan trace.Span - respSpan trace.Span - ) - for { - if i == 0 { - _, firstRespSpan = caller.GetTracer().Start(reqCtx, "first_call_response_in_stream") - } - streamRes, err := resStream.Recv() - if err == io.EOF { - break - } - if err != nil { - return err - } - if len(streamRes.Choices) == 0 { - continue - } - if streamRes.Usage != nil { - promptUsage = streamRes.Usage.PromptTokens - completionUsage = streamRes.Usage.CompletionTokens - totalUsage = streamRes.Usage.TotalTokens - } - if tc := streamRes.Choices[0].Delta.ToolCalls; len(tc) > 0 { - isFunctionCall = true - if j == 0 { - firstCallSpan.End() - } - for _, t := range tc { - // this index should be toolCalls slice's index, the index field only appares in stream response - index := *t.Index - item, ok := toolCallsMap[index] - if !ok { - toolCallsMap[index] = openai.ToolCall{ - Index: t.Index, - ID: t.ID, - Type: t.Type, - Function: openai.FunctionCall{}, - } - item = toolCallsMap[index] - } - if t.Function.Arguments != "" { - item.Function.Arguments += t.Function.Arguments - } - if t.Function.Name != "" { - item.Function.Name = t.Function.Name - } - toolCallsMap[index] = item - } - j++ - } else if streamRes.Choices[0].FinishReason != openai.FinishReasonToolCalls { - _ = writeStreamEvent(w, flusher, streamRes) - } - if i == 0 && j == 0 && !isFunctionCall { - reqSpan.End() - recordTTFT(ctx, caller.GetTracer()) - _, respSpan = caller.GetTracer().Start(ctx, "response_in_stream(TBT)") - } - i++ - } - if !isFunctionCall { - respSpan.End() - return writeStreamDone(w, flusher) - } - firstRespSpan.End() - toolCalls = mapToSliceTools(toolCallsMap) - - assistantMessage = openai.ChatCompletionMessage{ - ToolCalls: toolCalls, - Role: openai.ChatMessageRoleAssistant, - } - reqSpan.End() - flusher.Flush() - } else { - _, firstCallSpan := caller.GetTracer().Start(reqCtx, "first_call") - resp, err := provider.GetChatCompletions(ctx, req, md) - if err != nil { - return err - } - reqSpan.End() - - promptUsage = resp.Usage.PromptTokens - completionUsage = resp.Usage.CompletionTokens - totalUsage = resp.Usage.CompletionTokens - - ylog.Debug(" #1 first call", "response", fmt.Sprintf("%+v", resp)) - // it is a function call - if resp.Choices[0].FinishReason == openai.FinishReasonToolCalls { - toolCalls = append(toolCalls, resp.Choices[0].Message.ToolCalls...) - assistantMessage = resp.Choices[0].Message - firstCallSpan.End() - } else { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) - return nil - } - } - - resCtx, resSpan := caller.GetTracer().Start(ctx, "completions_response") - defer resSpan.End() - - _, sfnSpan := caller.GetTracer().Start(resCtx, "run_sfn") - - // 5. find sfns that hit the function call - fnCalls := findTagTools(tagTools, toolCalls) - - // 6. run llm function calls - reqID := id.New(16) - llmCalls, err := caller.Call(ctx, transID, reqID, fnCalls) - if err != nil { - return err - } - sfnSpan.End() - - // 7. do the second call (the second call messages are from user input, first call resopnse and sfn calls result) - req.Messages = append(reqMessages, assistantMessage) - req.Messages = append(req.Messages, llmCalls...) - req.Tools = nil // reset tools field - - ylog.Debug(" #2 second call", "request", fmt.Sprintf("%+v", req)) - - if req.Stream { - _, secondCallSpan := caller.GetTracer().Start(resCtx, "second_call_request") - flusher := w.(http.Flusher) - resStream, err := provider.GetChatCompletionsStream(resCtx, req, md) - if err != nil { - return err - } - secondCallSpan.End() - - var ( - i int - secondRespSpan trace.Span - ) - for { - if i == 0 { - recordTTFT(resCtx, caller.GetTracer()) - _, secondRespSpan = caller.GetTracer().Start(resCtx, "second_call_response_in_stream(TBT)") - } - i++ - streamRes, err := resStream.Recv() - if err == io.EOF { - secondRespSpan.End() - return writeStreamDone(w, flusher) - } - if err != nil { - return err - } - if streamRes.Usage != nil { - streamRes.Usage.PromptTokens += promptUsage - streamRes.Usage.CompletionTokens += completionUsage - streamRes.Usage.TotalTokens += totalUsage - } - _ = writeStreamEvent(w, flusher, streamRes) - } - } else { - _, secondCallSpan := caller.GetTracer().Start(resCtx, "second_call") - - resp, err := provider.GetChatCompletions(resCtx, req, md) - if err != nil { - return err - } - - resp.Usage.PromptTokens += promptUsage - resp.Usage.CompletionTokens += completionUsage - resp.Usage.TotalTokens += totalUsage - - secondCallSpan.End() - w.Header().Set("Content-Type", "application/json") - return json.NewEncoder(w).Encode(resp) - } -} - -// ExchangeMetadataFunc is used to exchange metadata -type ExchangeMetadataFunc func(credential string) (metadata.M, error) - -// DefaultExchangeMetadataFunc is the default ExchangeMetadataFunc, It returns an empty metadata. -func DefaultExchangeMetadataFunc(credential string) (metadata.M, error) { - return metadata.M{}, nil +// ComponentCreator creates unconnected source, unconnected reducer, and exchange metadata from credential. +type ComponentCreator interface { + // CreateSource should creates an unconnected source. + CreateSource(credential string) yomo.Source + // CreateReducer should creates an unconnected reducer. + CreateReducer(credential string) yomo.StreamFunction + // ExchangeMetadata exchanges metadata from the credential. + ExchangeMetadata(credential string) (metadata.M, error) } -func addToolsToRequest(req openai.ChatCompletionRequest, tagTools map[uint32]openai.Tool) openai.ChatCompletionRequest { - toolCalls := prepareToolCalls(tagTools) - - if len(toolCalls) > 0 { - req.Tools = toolCalls - } - - ylog.Debug(" #1 first call", "request", fmt.Sprintf("%+v", req)) - - return req +type defaultComponentCreator struct { + zipperAddr string } -func overWriteSystemPrompt(req openai.ChatCompletionRequest, sysPrompt string) openai.ChatCompletionRequest { - // do nothing if system prompt is empty - if sysPrompt == "" { - return req +// DefaultComponentCreator returns a ComponentCreator that creates unconnected source, +// unconnected reducer, and exchange metadata from credential. +func DefaultComponentCreator(zipperAddr string) ComponentCreator { + return &defaultComponentCreator{ + zipperAddr: zipperAddr, } - // over write system prompt - isOverWrite := false - for i, msg := range req.Messages { - if msg.Role != "system" { - continue - } - req.Messages[i] = openai.ChatCompletionMessage{ - Role: msg.Role, - Content: sysPrompt, - } - isOverWrite = true - } - // append system prompt - if !isOverWrite { - req.Messages = append(req.Messages, openai.ChatCompletionMessage{ - Role: "system", - Content: sysPrompt, - }) - } - - ylog.Debug(" #1 first call after overwrite", "request", fmt.Sprintf("%+v", req)) - - return req -} - -func findTagTools(tagTools map[uint32]openai.Tool, toolCalls []openai.ToolCall) map[uint32][]*openai.ToolCall { - fnCalls := make(map[uint32][]*openai.ToolCall) - // functions may be more than one - for _, call := range toolCalls { - for tag, tc := range tagTools { - if tc.Function.Name == call.Function.Name && tc.Type == call.Type { - currentCall := call - fnCalls[tag] = append(fnCalls[tag], ¤tCall) - } - } - } - return fnCalls } -func writeStreamEvent(w http.ResponseWriter, flusher http.Flusher, streamRes openai.ChatCompletionStreamResponse) error { - if _, err := io.WriteString(w, "data: "); err != nil { - return err - } - if err := json.NewEncoder(w).Encode(streamRes); err != nil { - return err - } - if _, err := io.WriteString(w, "\n"); err != nil { - return err - } - flusher.Flush() - - return nil -} - -func writeStreamDone(w http.ResponseWriter, flusher http.Flusher) error { - _, err := io.WriteString(w, "data: [DONE]") - flusher.Flush() - - return err -} - -func prepareMessages(baseSystemMessage string, userInstruction string, chainMessage ai.ChainMessage, tools []openai.Tool, withTool bool) []openai.ChatCompletionMessage { - systemInstructions := []string{"## Instructions\n"} - - // only append if there are tool calls - if withTool { - for _, t := range tools { - systemInstructions = append(systemInstructions, "- ") - systemInstructions = append(systemInstructions, t.Function.Description) - systemInstructions = append(systemInstructions, "\n") - } - systemInstructions = append(systemInstructions, "\n") - } - - SystemPrompt := fmt.Sprintf("%s\n\n%s", baseSystemMessage, strings.Join(systemInstructions, "")) - - messages := []openai.ChatCompletionMessage{} - - // 1. system message - messages = append(messages, openai.ChatCompletionMessage{Role: "system", Content: SystemPrompt}) - - // 2. previous tool calls - // Ref: Tool Message Object in Messsages - // https://platform.openai.com/docs/guides/function-calling - // https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages - - if chainMessage.PreceedingAssistantMessage != nil { - // 2.1 assistant message - // try convert type of chainMessage.PreceedingAssistantMessage to type ChatCompletionMessage - assistantMessage, ok := chainMessage.PreceedingAssistantMessage.(openai.ChatCompletionMessage) - if ok { - ylog.Debug("======== add assistantMessage", "am", fmt.Sprintf("%+v", assistantMessage)) - messages = append(messages, assistantMessage) - } - - // 2.2 tool message - for _, tool := range chainMessage.ToolMessages { - tm := openai.ChatCompletionMessage{ - Role: "tool", - Content: tool.Content, - ToolCallID: tool.ToolCallID, - } - ylog.Debug("======== add toolMessage", "tm", fmt.Sprintf("%+v", tm)) - messages = append(messages, tm) - } - } - - // 3. user instruction - messages = append(messages, openai.ChatCompletionMessage{Role: "user", Content: userInstruction}) - - return messages -} - -func mapToSliceTools(m map[int]openai.ToolCall) []openai.ToolCall { - arr := make([]openai.ToolCall, len(m)) - for k, v := range m { - arr[k] = v - } - return arr -} - -func eventFlusher(w http.ResponseWriter) http.Flusher { - h := w.Header() - h.Set("Content-Type", "text/event-stream") - h.Set("Cache-Control", "no-cache, must-revalidate") - h.Set("x-content-type-options", "nosniff") - flusher := w.(http.Flusher) - return flusher -} - -func prepareToolCalls(tcs map[uint32]openai.Tool) []openai.Tool { - // prepare tools - toolCalls := make([]openai.Tool, len(tcs)) - idx := 0 - for _, tc := range tcs { - toolCalls[idx] = tc - idx++ - } - return toolCalls +func (c *defaultComponentCreator) CreateSource(credential string) yomo.Source { + return yomo.NewSource( + "fc-source", + c.zipperAddr, + yomo.WithSourceReConnect(), yomo.WithCredential(credential)) } -func transToolMessage(msgs []openai.ChatCompletionMessage) []ai.ToolMessage { - toolMessages := make([]ai.ToolMessage, len(msgs)) - for i, msg := range msgs { - toolMessages[i] = ai.ToolMessage{ - Role: msg.Role, - Content: msg.Content, - ToolCallID: msg.ToolCallID, - } - } - return toolMessages +func (c *defaultComponentCreator) CreateReducer(credential string) yomo.StreamFunction { + return yomo.NewStreamFunction( + "fc-reducer", + c.zipperAddr, + yomo.WithSfnReConnect(), yomo.WithSfnCredential(credential), yomo.DisableOtelTrace()) } -func recordTTFT(ctx context.Context, tracer trace.Tracer) { - _, span := tracer.Start(ctx, "TTFT") - span.End() - time.Sleep(time.Millisecond) +func (c *defaultComponentCreator) ExchangeMetadata(credential string) (metadata.M, error) { + return metadata.New(), nil } diff --git a/pkg/bridge/ai/caller_test.go b/pkg/bridge/ai/caller_test.go index e812ae49a..655311a97 100644 --- a/pkg/bridge/ai/caller_test.go +++ b/pkg/bridge/ai/caller_test.go @@ -1,499 +1,44 @@ package ai import ( - "context" - "errors" - "net/http/httptest" "testing" + "time" - "github.com/hashicorp/golang-lru/v2/expirable" - openai "github.com/sashabaranov/go-openai" "github.com/stretchr/testify/assert" - "github.com/yomorun/yomo/ai" + "github.com/yomorun/yomo" "github.com/yomorun/yomo/core/metadata" - "github.com/yomorun/yomo/pkg/bridge/ai/provider" - "github.com/yomorun/yomo/pkg/bridge/ai/register" ) -func TestCallerInvoke(t *testing.T) { - type args struct { - providerMockData []provider.MockData - mockCallReqResp map[uint32][]mockFunctionCall - systemPrompt string - userInstruction string - baseSystemMessage string - } - tests := []struct { - name string - args args - wantRequest []openai.ChatCompletionRequest - wantUsage ai.TokenUsage - }{ - { - name: "invoke with tool call", - args: args{ - providerMockData: []provider.MockData{ - provider.MockChatCompletionResponse(toolCallResp, stopResp), - }, - mockCallReqResp: map[uint32][]mockFunctionCall{ - // toolID should equal to toolCallResp's toolID - 0x33: {{toolID: "call_abc123", functionName: "get_current_weather", respContent: "temperature: 31°C"}}, - }, - systemPrompt: "this is a system prompt", - userInstruction: "hi", - baseSystemMessage: "this is a base system message", - }, - wantRequest: []openai.ChatCompletionRequest{ - { - Messages: []openai.ChatCompletionMessage{ - {Role: "system", Content: "this is a base system message\n\n## Instructions\n- \n\n"}, - {Role: "user", Content: "hi"}, - }, - Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}}, - }, - { - Messages: []openai.ChatCompletionMessage{ - {Role: "system", Content: "this is a base system message\n\n## Instructions\n"}, - {Role: "assistant", ToolCalls: []openai.ToolCall{{ID: "call_abc123", Type: openai.ToolTypeFunction, Function: openai.FunctionCall{Name: "get_current_weather", Arguments: "{\n\"location\": \"Boston, MA\"\n}"}}}}, - {Role: "tool", Content: "temperature: 31°C", ToolCallID: "call_abc123"}, - {Role: "user", Content: "hi"}, - }, - }, - }, - wantUsage: ai.TokenUsage{PromptTokens: 95, CompletionTokens: 43}, - }, - { - name: "invoke without tool call", - args: args{ - providerMockData: []provider.MockData{ - provider.MockChatCompletionResponse(stopResp), - }, - mockCallReqResp: map[uint32][]mockFunctionCall{}, - systemPrompt: "this is a system prompt", - userInstruction: "hi", - baseSystemMessage: "this is a base system message", - }, - wantRequest: []openai.ChatCompletionRequest{ - { - Messages: []openai.ChatCompletionMessage{ - {Role: "system", Content: "this is a base system message\n\n## Instructions\n\n"}, - {Role: "user", Content: "hi"}, - }, - }, - }, - wantUsage: ai.TokenUsage{PromptTokens: 13, CompletionTokens: 26}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - register.SetRegister(register.NewDefault()) +func TestCaller(t *testing.T) { + cc := &testComponentCreator{flow: newMockDataFlow(newHandler(time.Millisecond).handle)} - pd, err := provider.NewMock("mock provider", tt.args.providerMockData...) - if err != nil { - t.Fatal(err) - } + md, err := cc.ExchangeMetadata("") + assert.NoError(t, err) - cp := newMockCallerProvider() + caller, err := NewCaller(cc.CreateSource(""), cc.CreateReducer(""), md, time.Minute) + assert.NoError(t, err) - cp.provideFunc = mockCallerProvideFunc(tt.args.mockCallReqResp) + defer caller.Close() - caller, err := cp.Provide("") - assert.NoError(t, err) + assert.Equal(t, md, caller.Metadata()) - caller.SetSystemPrompt(tt.args.systemPrompt) - - resp, err := GetInvoke(context.TODO(), tt.args.userInstruction, tt.args.baseSystemMessage, "transID", true, caller, pd) - assert.NoError(t, err) - - assert.Equal(t, tt.wantUsage, resp.TokenUsage) - assert.Equal(t, tt.wantRequest, pd.RequestRecords()) - }) - } + sysPrompt := "hello system prompt" + caller.SetSystemPrompt(sysPrompt) + assert.Equal(t, sysPrompt, caller.GetSystemPrompt()) } -func TestCallerChatCompletion(t *testing.T) { - type args struct { - providerMockData []provider.MockData - mockCallReqResp map[uint32][]mockFunctionCall - systemPrompt string - request openai.ChatCompletionRequest - } - tests := []struct { - name string - args args - wantRequest []openai.ChatCompletionRequest - }{ - { - name: "chat with tool call", - args: args{ - providerMockData: []provider.MockData{ - provider.MockChatCompletionResponse(toolCallResp, stopResp), - }, - mockCallReqResp: map[uint32][]mockFunctionCall{ - // toolID should equal to toolCallResp's toolID - 0x33: {{toolID: "call_abc123", functionName: "get_current_weather", respContent: "temperature: 31°C"}}, - }, - systemPrompt: "this is a system prompt", - request: openai.ChatCompletionRequest{ - Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "How is the weather today in Boston, MA?"}}, - }, - }, - wantRequest: []openai.ChatCompletionRequest{ - { - Messages: []openai.ChatCompletionMessage{ - {Role: "user", Content: "How is the weather today in Boston, MA?"}, - {Role: "system", Content: "this is a system prompt"}, - }, - Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}}, - }, - { - Messages: []openai.ChatCompletionMessage{ - {Role: "user", Content: "How is the weather today in Boston, MA?"}, - {Role: "system", Content: "this is a system prompt"}, - {Role: "assistant", ToolCalls: []openai.ToolCall{{ID: "call_abc123", Type: openai.ToolTypeFunction, Function: openai.FunctionCall{Name: "get_current_weather", Arguments: "{\n\"location\": \"Boston, MA\"\n}"}}}}, - {Role: "tool", Content: "temperature: 31°C", ToolCallID: "call_abc123"}, - }, - }, - }, - }, - { - name: "chat without tool call", - args: args{ - providerMockData: []provider.MockData{ - provider.MockChatCompletionResponse(stopResp), - }, - mockCallReqResp: map[uint32][]mockFunctionCall{ - // toolID should equal to toolCallResp's toolID - 0x33: {{toolID: "call_abc123", functionName: "get_current_weather", respContent: "temperature: 31°C"}}, - }, - systemPrompt: "You are an assistant.", - request: openai.ChatCompletionRequest{ - Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "How are you"}}, - }, - }, - wantRequest: []openai.ChatCompletionRequest{ - { - Messages: []openai.ChatCompletionMessage{ - {Role: "user", Content: "How are you"}, - {Role: "system", Content: "You are an assistant."}, - }, - Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}}, - }, - }, - }, - { - name: "chat with tool call in stream", - args: args{ - providerMockData: []provider.MockData{ - provider.MockChatCompletionStreamResponse(toolCallStreamResp, stopStreamResp), - }, - mockCallReqResp: map[uint32][]mockFunctionCall{ - // toolID should equal to toolCallResp's toolID - 0x33: {{toolID: "call_9ctHOJqO3bYrpm2A6S7nHd5k", functionName: "get_current_weather", respContent: "temperature: 31°C"}}, - }, - systemPrompt: "You are a weather assistant", - request: openai.ChatCompletionRequest{ - Stream: true, - Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "How is the weather today in Boston, MA?"}}, - }, - }, - wantRequest: []openai.ChatCompletionRequest{ - { - Stream: true, - Messages: []openai.ChatCompletionMessage{ - {Role: "user", Content: "How is the weather today in Boston, MA?"}, - {Role: "system", Content: "You are a weather assistant"}, - }, - Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}}, - }, - { - Stream: true, - Messages: []openai.ChatCompletionMessage{ - {Role: "user", Content: "How is the weather today in Boston, MA?"}, - {Role: "system", Content: "You are a weather assistant"}, - {Role: "assistant", ToolCalls: []openai.ToolCall{{Index: toInt(0), ID: "call_9ctHOJqO3bYrpm2A6S7nHd5k", Type: openai.ToolTypeFunction, Function: openai.FunctionCall{Name: "get_current_weather", Arguments: "{\"location\":\"Boston, MA\"}"}}}}, - {Role: "tool", Content: "temperature: 31°C", ToolCallID: "call_9ctHOJqO3bYrpm2A6S7nHd5k"}, - }, - }, - }, - }, - { - name: "chat without tool call in stream", - args: args{ - providerMockData: []provider.MockData{ - provider.MockChatCompletionStreamResponse(stopStreamResp), - }, - mockCallReqResp: map[uint32][]mockFunctionCall{ - // toolID should equal to toolCallResp's toolID - 0x33: {{toolID: "call_9ctHOJqO3bYrpm2A6S7nHd5k", functionName: "get_current_weather", respContent: "temperature: 31°C"}}, - }, - systemPrompt: "You are a weather assistant", - request: openai.ChatCompletionRequest{ - Stream: true, - Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "How is the weather today in Boston, MA?"}}, - }, - }, - wantRequest: []openai.ChatCompletionRequest{ - { - Stream: true, - Messages: []openai.ChatCompletionMessage{ - {Role: "user", Content: "How is the weather today in Boston, MA?"}, - {Role: "system", Content: "You are a weather assistant"}, - }, - Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}}, - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - register.SetRegister(register.NewDefault()) - - pd, err := provider.NewMock("mock provider", tt.args.providerMockData...) - if err != nil { - t.Fatal(err) - } - - cp := newMockCallerProvider() - - cp.provideFunc = mockCallerProvideFunc(tt.args.mockCallReqResp) - - caller, err := cp.Provide("") - assert.NoError(t, err) - - caller.SetSystemPrompt(tt.args.systemPrompt) - - w := httptest.NewRecorder() - err = GetChatCompletions(context.TODO(), tt.args.request, "transID", pd, caller, w) - assert.NoError(t, err) - - assert.Equal(t, tt.wantRequest, pd.RequestRecords()) - }) - } +type testComponentCreator struct { + flow *mockDataFlow } -func newMockCallerProvider() *callerProvider { - cp := &callerProvider{ - zipperAddr: DefaultZipperAddr, - exFn: DefaultExchangeMetadataFunc, - callers: expirable.NewLRU(CallerProviderCacheSize, func(_ string, caller Caller) { caller.Close() }, CallerProviderCacheTTL), - } - return cp +func (c *testComponentCreator) CreateSource(_ string) yomo.Source { + return c.flow } -// mockCallerProvideFunc returns a mock caller provider, which is used for mockCallerProvider -// the request-response of caller be provided has been defined in advance, the request and response are defined in the `calls`. -func mockCallerProvideFunc(calls map[uint32][]mockFunctionCall) provideFunc { - // register function to register - for tag, call := range calls { - for _, c := range call { - register.RegisterFunction(tag, &openai.FunctionDefinition{Name: c.functionName}, uint64(tag), nil) - } - } - - return func(credential, _ string, _ ExchangeMetadataFunc) (Caller, error) { - caller := &caller{ - credential: credential, - md: metadata.M{"hello": "llm bridge"}, - } - - caller.SetSystemPrompt("") - caller.CallSyncer = &mockCallSyncer{calls: calls} - - return caller, nil - } -} - -type mockFunctionCall struct { - toolID string - functionName string - respContent string -} - -type mockCallSyncer struct { - calls map[uint32][]mockFunctionCall +func (c *testComponentCreator) CreateReducer(_ string) yomo.StreamFunction { + return c.flow } -// Call implements CallSyncer, it returns the mock response defined in advance. -func (m *mockCallSyncer) Call(ctx context.Context, transID string, reqID string, toolCalls map[uint32][]*openai.ToolCall) ([]openai.ChatCompletionMessage, error) { - res := []openai.ChatCompletionMessage{} - for tag, calls := range toolCalls { - mcs, ok := m.calls[tag] - if !ok { - return nil, errors.New("call not found") - } - mcm := make(map[string]mockFunctionCall, len(mcs)) - for _, mc := range mcs { - mcm[mc.toolID] = mc - } - for _, call := range calls { - mc, ok := mcm[call.ID] - if !ok { - return nil, errors.New("call not found") - } - res = append(res, openai.ChatCompletionMessage{ - ToolCallID: mc.toolID, - Role: openai.ChatMessageRoleTool, - Content: mc.respContent, - }) - } - } - return res, nil +func (c *testComponentCreator) ExchangeMetadata(_ string) (metadata.M, error) { + return metadata.M{"hello": "llm bridge"}, nil } - -func (m *mockCallSyncer) Close() error { return nil } - -func toInt(val int) *int { return &val } - -var stopStreamResp = `data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" I'm"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" just"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" a"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" computer"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" program"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":","},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" so"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" don't"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" have"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" feelings"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":","},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" but"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" I'm"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" here"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" and"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" ready"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" help"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" with"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" whatever"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" need"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} - -data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[],"usage":{"prompt_tokens":13,"completion_tokens":34,"total_tokens":47}} - -data: [DONE]` - -var stopResp = `{ - "id": "chatcmpl-9blYknv9rHvr2dvCQKMeW21hlBpCX", - "object": "chat.completion", - "created": 1718787982, - "model": "gpt-4o-2024-05-13", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "Hello! I'm just a computer program, so I don't have feelings, but thanks for asking. How can I assist you today?" - }, - "logprobs": null, - "finish_reason": "stop" - } - ], - "usage": { - "prompt_tokens": 13, - "completion_tokens": 26, - "total_tokens": 39 - }, - "system_fingerprint": "fp_f4e629d0a5" -}` - -var toolCallStreamResp = `data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_9ctHOJqO3bYrpm2A6S7nHd5k","type":"function","function":{"name":"get_current_weather","arguments":""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Boston"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":","}}]},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" MA"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} - -data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null} - -data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[],"usage":{"prompt_tokens":83,"completion_tokens":17,"total_tokens":100}}` - -var toolCallResp = `{ - "id": "chatcmpl-abc123", - "object": "chat.completion", - "created": 1699896916, - "model": "gpt-4-turbo-2024-04-09", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": null, - "tool_calls": [ - { - "id": "call_abc123", - "type": "function", - "function": { - "name": "get_current_weather", - "arguments": "{\n\"location\": \"Boston, MA\"\n}" - } - } - ] - }, - "logprobs": null, - "finish_reason": "tool_calls" - } - ], - "usage": { - "prompt_tokens": 82, - "completion_tokens": 17, - "total_tokens": 99 - } -}` diff --git a/pkg/bridge/ai/service.go b/pkg/bridge/ai/service.go new file mode 100644 index 000000000..985a07858 --- /dev/null +++ b/pkg/bridge/ai/service.go @@ -0,0 +1,613 @@ +package ai + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "strings" + "time" + + "github.com/hashicorp/golang-lru/v2/expirable" + openai "github.com/sashabaranov/go-openai" + "github.com/yomorun/yomo" + "github.com/yomorun/yomo/ai" + "github.com/yomorun/yomo/core/metadata" + "github.com/yomorun/yomo/core/ylog" + "github.com/yomorun/yomo/pkg/bridge/ai/provider" + "github.com/yomorun/yomo/pkg/bridge/ai/register" + "github.com/yomorun/yomo/pkg/id" + "go.opentelemetry.io/otel/trace" + "go.opentelemetry.io/otel/trace/noop" +) + +// Service is the service layer for llm bridge server. +// service is responsible for handling the logic from handler layer. +type Service struct { + provider provider.LLMProvider + componentCreator ComponentCreator + newCallerFunc newCallerFunc + callers *expirable.LRU[string, *Caller] + option *ServiceOption + logger *slog.Logger +} + +// ServiceOption is the option for creating service +type ServiceOption struct { + // Logger is the logger for the service + Logger *slog.Logger + // Tracer is the tracer for the service + Tracer trace.Tracer + // CredentialFunc is the function for getting the credential from the request + CredentialFunc func(r *http.Request) (string, error) + // CallerCacheSize is the size of the caller's cache + CallerCacheSize int + // CallerCacheTTL is the time to live of the callers cache + CallerCacheTTL time.Duration + // CallerCallTimeout is the timeout for awaiting the function response. + CallerCallTimeout time.Duration +} + +// NewService creates a new service for handling the logic from handler layer. +func NewService(provider provider.LLMProvider, cc ComponentCreator, opt *ServiceOption) *Service { + return newService(provider, cc, NewCaller, opt) +} + +func initOption(opt *ServiceOption) *ServiceOption { + if opt == nil { + opt = &ServiceOption{} + } + if opt.Tracer == nil { + opt.Tracer = noop.NewTracerProvider().Tracer("yomo-ai-bridge") + } + if opt.Logger == nil { + opt.Logger = ylog.Default() + } + if opt.CredentialFunc == nil { + opt.CredentialFunc = func(_ *http.Request) (string, error) { return "", nil } + } + if opt.CallerCacheSize == 0 { + opt.CallerCacheSize = 1 + } + if opt.CallerCallTimeout == 0 { + opt.CallerCallTimeout = 60 * time.Second + } + + return opt +} + +func newService(provider provider.LLMProvider, cct ComponentCreator, ncf newCallerFunc, opt *ServiceOption) *Service { + var onEvict = func(_ string, caller *Caller) { + caller.Close() + } + + opt = initOption(opt) + + service := &Service{ + provider: provider, + componentCreator: cct, + newCallerFunc: ncf, + callers: expirable.NewLRU(opt.CallerCacheSize, onEvict, opt.CallerCacheTTL), + option: opt, + logger: opt.Logger, + } + + return service +} + +type newCallerFunc func(yomo.Source, yomo.StreamFunction, metadata.M, time.Duration) (*Caller, error) + +// LoadOrCreateCaller loads or creates the caller according to the http request. +func (srv *Service) LoadOrCreateCaller(r *http.Request) (*Caller, error) { + credential, err := srv.option.CredentialFunc(r) + if err != nil { + return nil, err + } + return srv.loadOrCreateCaller(credential) +} + +// GetInvoke returns the invoke response +func (srv *Service) GetInvoke(ctx context.Context, userInstruction, baseSystemMessage, transID string, caller *Caller, includeCallStack bool) (*ai.InvokeResponse, error) { + md := caller.Metadata().Clone() + // read tools attached to the metadata + tcs, err := register.ListToolCalls(md) + if err != nil { + return &ai.InvokeResponse{}, err + } + // prepare tools + tools := prepareToolCalls(tcs) + + chainMessage := ai.ChainMessage{} + messages := srv.prepareMessages(baseSystemMessage, userInstruction, chainMessage, tools, true) + req := openai.ChatCompletionRequest{ + Messages: messages, + } + // with tools + if len(tools) > 0 { + req.Tools = tools + } + var ( + promptUsage int + completionUsage int + ) + _, span := srv.option.Tracer.Start(ctx, "first_call") + chatCompletionResponse, err := srv.provider.GetChatCompletions(ctx, req, md) + if err != nil { + return nil, err + } + span.End() + promptUsage = chatCompletionResponse.Usage.PromptTokens + completionUsage = chatCompletionResponse.Usage.CompletionTokens + + // convert ChatCompletionResponse to InvokeResponse + res, err := ai.ConvertToInvokeResponse(&chatCompletionResponse, tcs) + if err != nil { + return nil, err + } + // if no tool_calls fired, just return the llm text result + if res.FinishReason != string(openai.FinishReasonToolCalls) { + return res, nil + } + + // run llm function calls + srv.logger.Debug(">>>> start 1st call response", + "res_toolcalls", fmt.Sprintf("%+v", res.ToolCalls), + "res_assistant_msgs", fmt.Sprintf("%+v", res.AssistantMessage)) + + srv.logger.Debug(">> run function calls", "transID", transID, "res.ToolCalls", fmt.Sprintf("%+v", res.ToolCalls)) + + _, span = srv.option.Tracer.Start(ctx, "run_sfn") + reqID := id.New(16) + llmCalls, err := caller.Call(ctx, transID, reqID, res.ToolCalls) + if err != nil { + return nil, err + } + span.End() + + srv.logger.Debug(">>>> start 2nd call with", "calls", fmt.Sprintf("%+v", llmCalls), "preceeding_assistant_message", fmt.Sprintf("%+v", res.AssistantMessage)) + + chainMessage.PreceedingAssistantMessage = res.AssistantMessage + chainMessage.ToolMessages = transToolMessage(llmCalls) + // do not attach toolMessage to prompt in 2nd call + messages2 := srv.prepareMessages(baseSystemMessage, userInstruction, chainMessage, tools, false) + req2 := openai.ChatCompletionRequest{ + Messages: messages2, + } + _, span = srv.option.Tracer.Start(ctx, "second_call") + chatCompletionResponse2, err := srv.provider.GetChatCompletions(ctx, req2, md) + if err != nil { + return nil, err + } + span.End() + + chatCompletionResponse2.Usage.PromptTokens += promptUsage + chatCompletionResponse2.Usage.CompletionTokens += completionUsage + + res2, err := ai.ConvertToInvokeResponse(&chatCompletionResponse2, tcs) + if err != nil { + return nil, err + } + + // INFO: call stack infomation + if includeCallStack { + res2.ToolCalls = res.ToolCalls + res2.ToolMessages = transToolMessage(llmCalls) + } + srv.logger.Debug("<<<< complete 2nd call", "res2", fmt.Sprintf("%+v", res2)) + + return res2, err +} + +// GetChatCompletions accepts openai.ChatCompletionRequest and responds to http.ResponseWriter. +func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompletionRequest, transID string, caller *Caller, w http.ResponseWriter) error { + reqCtx, reqSpan := srv.option.Tracer.Start(ctx, "completions_request") + md := caller.Metadata().Clone() + + // 1. find all hosting tool sfn + tagTools, err := register.ListToolCalls(md) + if err != nil { + return err + } + // 2. add those tools to request + req = srv.addToolsToRequest(req, tagTools) + + // 3. over write system prompt to request + req = srv.overWriteSystemPrompt(req, caller.GetSystemPrompt()) + + var ( + promptUsage = 0 + completionUsage = 0 + totalUsage = 0 + reqMessages = req.Messages + toolCallsMap = make(map[int]openai.ToolCall) + toolCalls = []openai.ToolCall{} + assistantMessage = openai.ChatCompletionMessage{} + ) + // 4. request first chat for getting tools + if req.Stream { + _, firstCallSpan := srv.option.Tracer.Start(reqCtx, "first_call_request") + var ( + flusher = eventFlusher(w) + isFunctionCall = false + ) + resStream, err := srv.provider.GetChatCompletionsStream(reqCtx, req, md) + if err != nil { + return err + } + + var ( + i int // number of chunks + j int // number of tool call chunks + firstRespSpan trace.Span + respSpan trace.Span + ) + for { + if i == 0 { + _, firstRespSpan = srv.option.Tracer.Start(reqCtx, "first_call_response_in_stream") + } + streamRes, err := resStream.Recv() + if err == io.EOF { + break + } + if err != nil { + return err + } + if len(streamRes.Choices) == 0 { + continue + } + if streamRes.Usage != nil { + promptUsage = streamRes.Usage.PromptTokens + completionUsage = streamRes.Usage.CompletionTokens + totalUsage = streamRes.Usage.TotalTokens + } + if tc := streamRes.Choices[0].Delta.ToolCalls; len(tc) > 0 { + isFunctionCall = true + if j == 0 { + firstCallSpan.End() + } + for _, t := range tc { + // this index should be toolCalls slice's index, the index field only appares in stream response + index := *t.Index + item, ok := toolCallsMap[index] + if !ok { + toolCallsMap[index] = openai.ToolCall{ + Index: t.Index, + ID: t.ID, + Type: t.Type, + Function: openai.FunctionCall{}, + } + item = toolCallsMap[index] + } + if t.Function.Arguments != "" { + item.Function.Arguments += t.Function.Arguments + } + if t.Function.Name != "" { + item.Function.Name = t.Function.Name + } + toolCallsMap[index] = item + } + j++ + } else if streamRes.Choices[0].FinishReason != openai.FinishReasonToolCalls { + _ = writeStreamEvent(w, flusher, streamRes) + } + if i == 0 && j == 0 && !isFunctionCall { + reqSpan.End() + recordTTFT(ctx, srv.option.Tracer) + _, respSpan = srv.option.Tracer.Start(ctx, "response_in_stream(TBT)") + } + i++ + } + if !isFunctionCall { + respSpan.End() + return writeStreamDone(w, flusher) + } + firstRespSpan.End() + toolCalls = mapToSliceTools(toolCallsMap) + + assistantMessage = openai.ChatCompletionMessage{ + ToolCalls: toolCalls, + Role: openai.ChatMessageRoleAssistant, + } + reqSpan.End() + flusher.Flush() + } else { + _, firstCallSpan := srv.option.Tracer.Start(reqCtx, "first_call") + resp, err := srv.provider.GetChatCompletions(ctx, req, md) + if err != nil { + return err + } + reqSpan.End() + + promptUsage = resp.Usage.PromptTokens + completionUsage = resp.Usage.CompletionTokens + totalUsage = resp.Usage.CompletionTokens + + srv.logger.Debug(" #1 first call", "response", fmt.Sprintf("%+v", resp)) + // it is a function call + if resp.Choices[0].FinishReason == openai.FinishReasonToolCalls { + toolCalls = append(toolCalls, resp.Choices[0].Message.ToolCalls...) + assistantMessage = resp.Choices[0].Message + firstCallSpan.End() + } else { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + return nil + } + } + + resCtx, resSpan := srv.option.Tracer.Start(ctx, "completions_response") + defer resSpan.End() + + _, sfnSpan := srv.option.Tracer.Start(resCtx, "run_sfn") + + // 5. find sfns that hit the function call + fnCalls := findTagTools(tagTools, toolCalls) + + // 6. run llm function calls + reqID := id.New(16) + llmCalls, err := caller.Call(ctx, transID, reqID, fnCalls) + if err != nil { + return err + } + sfnSpan.End() + + // 7. do the second call (the second call messages are from user input, first call resopnse and sfn calls result) + req.Messages = append(reqMessages, assistantMessage) + req.Messages = append(req.Messages, llmCalls...) + req.Tools = nil // reset tools field + + srv.logger.Debug(" #2 second call", "request", fmt.Sprintf("%+v", req)) + + if req.Stream { + _, secondCallSpan := srv.option.Tracer.Start(resCtx, "second_call_request") + flusher := w.(http.Flusher) + resStream, err := srv.provider.GetChatCompletionsStream(resCtx, req, md) + if err != nil { + return err + } + secondCallSpan.End() + + var ( + i int + secondRespSpan trace.Span + ) + for { + if i == 0 { + recordTTFT(resCtx, srv.option.Tracer) + _, secondRespSpan = srv.option.Tracer.Start(resCtx, "second_call_response_in_stream(TBT)") + } + i++ + streamRes, err := resStream.Recv() + if err == io.EOF { + secondRespSpan.End() + return writeStreamDone(w, flusher) + } + if err != nil { + return err + } + if streamRes.Usage != nil { + streamRes.Usage.PromptTokens += promptUsage + streamRes.Usage.CompletionTokens += completionUsage + streamRes.Usage.TotalTokens += totalUsage + } + _ = writeStreamEvent(w, flusher, streamRes) + } + } else { + _, secondCallSpan := srv.option.Tracer.Start(resCtx, "second_call") + + resp, err := srv.provider.GetChatCompletions(resCtx, req, md) + if err != nil { + return err + } + + resp.Usage.PromptTokens += promptUsage + resp.Usage.CompletionTokens += completionUsage + resp.Usage.TotalTokens += totalUsage + + secondCallSpan.End() + w.Header().Set("Content-Type", "application/json") + return json.NewEncoder(w).Encode(resp) + } +} + +func (srv *Service) loadOrCreateCaller(credential string) (*Caller, error) { + caller, ok := srv.callers.Get(credential) + if ok { + return caller, nil + } + md, err := srv.componentCreator.ExchangeMetadata(credential) + if err != nil { + return nil, err + } + caller, err = srv.newCallerFunc( + srv.componentCreator.CreateSource(credential), + srv.componentCreator.CreateReducer(credential), + md, + srv.option.CallerCallTimeout, + ) + if err != nil { + return nil, err + } + + srv.callers.Add(credential, caller) + + return caller, nil +} + +func (srv *Service) addToolsToRequest(req openai.ChatCompletionRequest, tagTools map[uint32]openai.Tool) openai.ChatCompletionRequest { + toolCalls := prepareToolCalls(tagTools) + + if len(toolCalls) > 0 { + req.Tools = toolCalls + } + + srv.logger.Debug(" #1 first call", "request", fmt.Sprintf("%+v", req)) + + return req +} + +func (srv *Service) overWriteSystemPrompt(req openai.ChatCompletionRequest, sysPrompt string) openai.ChatCompletionRequest { + // do nothing if system prompt is empty + if sysPrompt == "" { + return req + } + // over write system prompt + isOverWrite := false + for i, msg := range req.Messages { + if msg.Role != "system" { + continue + } + req.Messages[i] = openai.ChatCompletionMessage{ + Role: msg.Role, + Content: sysPrompt, + } + isOverWrite = true + } + // append system prompt + if !isOverWrite { + req.Messages = append(req.Messages, openai.ChatCompletionMessage{ + Role: "system", + Content: sysPrompt, + }) + } + + srv.logger.Debug(" #1 first call after overwrite", "request", fmt.Sprintf("%+v", req)) + + return req +} + +func findTagTools(tagTools map[uint32]openai.Tool, toolCalls []openai.ToolCall) map[uint32][]*openai.ToolCall { + fnCalls := make(map[uint32][]*openai.ToolCall) + // functions may be more than one + for _, call := range toolCalls { + for tag, tc := range tagTools { + if tc.Function.Name == call.Function.Name && tc.Type == call.Type { + currentCall := call + fnCalls[tag] = append(fnCalls[tag], ¤tCall) + } + } + } + return fnCalls +} + +func writeStreamEvent(w http.ResponseWriter, flusher http.Flusher, streamRes openai.ChatCompletionStreamResponse) error { + if _, err := io.WriteString(w, "data: "); err != nil { + return err + } + if err := json.NewEncoder(w).Encode(streamRes); err != nil { + return err + } + if _, err := io.WriteString(w, "\n"); err != nil { + return err + } + flusher.Flush() + + return nil +} + +func writeStreamDone(w http.ResponseWriter, flusher http.Flusher) error { + _, err := io.WriteString(w, "data: [DONE]") + flusher.Flush() + + return err +} + +func (srv *Service) prepareMessages(baseSystemMessage string, userInstruction string, chainMessage ai.ChainMessage, tools []openai.Tool, withTool bool) []openai.ChatCompletionMessage { + systemInstructions := []string{"## Instructions\n"} + + // only append if there are tool calls + if withTool { + for _, t := range tools { + systemInstructions = append(systemInstructions, "- ") + systemInstructions = append(systemInstructions, t.Function.Description) + systemInstructions = append(systemInstructions, "\n") + } + systemInstructions = append(systemInstructions, "\n") + } + + SystemPrompt := fmt.Sprintf("%s\n\n%s", baseSystemMessage, strings.Join(systemInstructions, "")) + + messages := []openai.ChatCompletionMessage{} + + // 1. system message + messages = append(messages, openai.ChatCompletionMessage{Role: "system", Content: SystemPrompt}) + + // 2. previous tool calls + // Ref: Tool Message Object in Messsages + // https://platform.openai.com/docs/guides/function-calling + // https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages + + if chainMessage.PreceedingAssistantMessage != nil { + // 2.1 assistant message + // try convert type of chainMessage.PreceedingAssistantMessage to type ChatCompletionMessage + assistantMessage, ok := chainMessage.PreceedingAssistantMessage.(openai.ChatCompletionMessage) + if ok { + srv.logger.Debug("======== add assistantMessage", "am", fmt.Sprintf("%+v", assistantMessage)) + messages = append(messages, assistantMessage) + } + + // 2.2 tool message + for _, tool := range chainMessage.ToolMessages { + tm := openai.ChatCompletionMessage{ + Role: "tool", + Content: tool.Content, + ToolCallID: tool.ToolCallID, + } + srv.logger.Debug("======== add toolMessage", "tm", fmt.Sprintf("%+v", tm)) + messages = append(messages, tm) + } + } + + // 3. user instruction + messages = append(messages, openai.ChatCompletionMessage{Role: "user", Content: userInstruction}) + + return messages +} + +func mapToSliceTools(m map[int]openai.ToolCall) []openai.ToolCall { + arr := make([]openai.ToolCall, len(m)) + for k, v := range m { + arr[k] = v + } + return arr +} + +func eventFlusher(w http.ResponseWriter) http.Flusher { + h := w.Header() + h.Set("Content-Type", "text/event-stream") + h.Set("Cache-Control", "no-cache, must-revalidate") + h.Set("x-content-type-options", "nosniff") + flusher := w.(http.Flusher) + return flusher +} + +func prepareToolCalls(tcs map[uint32]openai.Tool) []openai.Tool { + // prepare tools + toolCalls := make([]openai.Tool, len(tcs)) + idx := 0 + for _, tc := range tcs { + toolCalls[idx] = tc + idx++ + } + return toolCalls +} + +func transToolMessage(msgs []openai.ChatCompletionMessage) []ai.ToolMessage { + toolMessages := make([]ai.ToolMessage, len(msgs)) + for i, msg := range msgs { + toolMessages[i] = ai.ToolMessage{ + Role: msg.Role, + Content: msg.Content, + ToolCallID: msg.ToolCallID, + } + } + return toolMessages +} + +func recordTTFT(ctx context.Context, tracer trace.Tracer) { + _, span := tracer.Start(ctx, "TTFT") + span.End() + time.Sleep(time.Millisecond) +} diff --git a/pkg/bridge/ai/service_test.go b/pkg/bridge/ai/service_test.go new file mode 100644 index 000000000..9966b8c23 --- /dev/null +++ b/pkg/bridge/ai/service_test.go @@ -0,0 +1,495 @@ +package ai + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + openai "github.com/sashabaranov/go-openai" + "github.com/stretchr/testify/assert" + "github.com/yomorun/yomo" + "github.com/yomorun/yomo/ai" + "github.com/yomorun/yomo/core/metadata" + "github.com/yomorun/yomo/pkg/bridge/ai/provider" + "github.com/yomorun/yomo/pkg/bridge/ai/register" +) + +func TestServiceInvoke(t *testing.T) { + type args struct { + providerMockData []provider.MockData + mockCallReqResp map[uint32][]mockFunctionCall + systemPrompt string + userInstruction string + baseSystemMessage string + } + tests := []struct { + name string + args args + wantRequest []openai.ChatCompletionRequest + wantUsage ai.TokenUsage + }{ + { + name: "invoke with tool call", + args: args{ + providerMockData: []provider.MockData{ + provider.MockChatCompletionResponse(toolCallResp, stopResp), + }, + mockCallReqResp: map[uint32][]mockFunctionCall{ + // toolID should equal to toolCallResp's toolID + 0x33: {{toolID: "call_abc123", functionName: "get_current_weather", respContent: "temperature: 31°C"}}, + }, + systemPrompt: "this is a system prompt", + userInstruction: "hi", + baseSystemMessage: "this is a base system message", + }, + wantRequest: []openai.ChatCompletionRequest{ + { + Messages: []openai.ChatCompletionMessage{ + {Role: "system", Content: "this is a base system message\n\n## Instructions\n- \n\n"}, + {Role: "user", Content: "hi"}, + }, + Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}}, + }, + { + Messages: []openai.ChatCompletionMessage{ + {Role: "system", Content: "this is a base system message\n\n## Instructions\n"}, + {Role: "assistant", ToolCalls: []openai.ToolCall{{ID: "call_abc123", Type: openai.ToolTypeFunction, Function: openai.FunctionCall{Name: "get_current_weather", Arguments: "{\n\"location\": \"Boston, MA\"\n}"}}}}, + {Role: "tool", Content: "temperature: 31°C", ToolCallID: "call_abc123"}, + {Role: "user", Content: "hi"}, + }, + }, + }, + wantUsage: ai.TokenUsage{PromptTokens: 95, CompletionTokens: 43}, + }, + { + name: "invoke without tool call", + args: args{ + providerMockData: []provider.MockData{ + provider.MockChatCompletionResponse(stopResp), + }, + mockCallReqResp: map[uint32][]mockFunctionCall{}, + systemPrompt: "this is a system prompt", + userInstruction: "hi", + baseSystemMessage: "this is a base system message", + }, + wantRequest: []openai.ChatCompletionRequest{ + { + Messages: []openai.ChatCompletionMessage{ + {Role: "system", Content: "this is a base system message\n\n## Instructions\n\n"}, + {Role: "user", Content: "hi"}, + }, + }, + }, + wantUsage: ai.TokenUsage{PromptTokens: 13, CompletionTokens: 26}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + register.SetRegister(register.NewDefault()) + + pd, err := provider.NewMock("mock provider", tt.args.providerMockData...) + if err != nil { + t.Fatal(err) + } + + cc := &testComponentCreator{flow: newMockDataFlow(newHandler(2 * time.Hour).handle)} + + newCaller := func(_ yomo.Source, _ yomo.StreamFunction, _ metadata.M, _ time.Duration) (*Caller, error) { + return mockCaller(tt.args.mockCallReqResp), err + } + + service := newService(pd, cc, newCaller, nil) + + caller, err := service.LoadOrCreateCaller(&http.Request{}) + assert.NoError(t, err) + + caller.SetSystemPrompt(tt.args.systemPrompt) + + resp, err := service.GetInvoke(context.TODO(), tt.args.userInstruction, tt.args.baseSystemMessage, "transID", caller, true) + assert.NoError(t, err) + + assert.Equal(t, tt.wantUsage, resp.TokenUsage) + assert.Equal(t, tt.wantRequest, pd.RequestRecords()) + }) + } +} + +func TestServiceChatCompletion(t *testing.T) { + type args struct { + providerMockData []provider.MockData + mockCallReqResp map[uint32][]mockFunctionCall + systemPrompt string + request openai.ChatCompletionRequest + } + tests := []struct { + name string + args args + wantRequest []openai.ChatCompletionRequest + }{ + { + name: "chat with tool call", + args: args{ + providerMockData: []provider.MockData{ + provider.MockChatCompletionResponse(toolCallResp, stopResp), + }, + mockCallReqResp: map[uint32][]mockFunctionCall{ + // toolID should equal to toolCallResp's toolID + 0x33: {{toolID: "call_abc123", functionName: "get_current_weather", respContent: "temperature: 31°C"}}, + }, + systemPrompt: "this is a system prompt", + request: openai.ChatCompletionRequest{ + Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "How is the weather today in Boston, MA?"}}, + }, + }, + wantRequest: []openai.ChatCompletionRequest{ + { + Messages: []openai.ChatCompletionMessage{ + {Role: "user", Content: "How is the weather today in Boston, MA?"}, + {Role: "system", Content: "this is a system prompt"}, + }, + Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}}, + }, + { + Messages: []openai.ChatCompletionMessage{ + {Role: "user", Content: "How is the weather today in Boston, MA?"}, + {Role: "system", Content: "this is a system prompt"}, + {Role: "assistant", ToolCalls: []openai.ToolCall{{ID: "call_abc123", Type: openai.ToolTypeFunction, Function: openai.FunctionCall{Name: "get_current_weather", Arguments: "{\n\"location\": \"Boston, MA\"\n}"}}}}, + {Role: "tool", Content: "temperature: 31°C", ToolCallID: "call_abc123"}, + }, + }, + }, + }, + { + name: "chat without tool call", + args: args{ + providerMockData: []provider.MockData{ + provider.MockChatCompletionResponse(stopResp), + }, + mockCallReqResp: map[uint32][]mockFunctionCall{ + // toolID should equal to toolCallResp's toolID + 0x33: {{toolID: "call_abc123", functionName: "get_current_weather", respContent: "temperature: 31°C"}}, + }, + systemPrompt: "You are an assistant.", + request: openai.ChatCompletionRequest{ + Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "How are you"}}, + }, + }, + wantRequest: []openai.ChatCompletionRequest{ + { + Messages: []openai.ChatCompletionMessage{ + {Role: "user", Content: "How are you"}, + {Role: "system", Content: "You are an assistant."}, + }, + Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}}, + }, + }, + }, + { + name: "chat with tool call in stream", + args: args{ + providerMockData: []provider.MockData{ + provider.MockChatCompletionStreamResponse(toolCallStreamResp, stopStreamResp), + }, + mockCallReqResp: map[uint32][]mockFunctionCall{ + // toolID should equal to toolCallResp's toolID + 0x33: {{toolID: "call_9ctHOJqO3bYrpm2A6S7nHd5k", functionName: "get_current_weather", respContent: "temperature: 31°C"}}, + }, + systemPrompt: "You are a weather assistant", + request: openai.ChatCompletionRequest{ + Stream: true, + Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "How is the weather today in Boston, MA?"}}, + }, + }, + wantRequest: []openai.ChatCompletionRequest{ + { + Stream: true, + Messages: []openai.ChatCompletionMessage{ + {Role: "user", Content: "How is the weather today in Boston, MA?"}, + {Role: "system", Content: "You are a weather assistant"}, + }, + Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}}, + }, + { + Stream: true, + Messages: []openai.ChatCompletionMessage{ + {Role: "user", Content: "How is the weather today in Boston, MA?"}, + {Role: "system", Content: "You are a weather assistant"}, + {Role: "assistant", ToolCalls: []openai.ToolCall{{Index: toInt(0), ID: "call_9ctHOJqO3bYrpm2A6S7nHd5k", Type: openai.ToolTypeFunction, Function: openai.FunctionCall{Name: "get_current_weather", Arguments: "{\"location\":\"Boston, MA\"}"}}}}, + {Role: "tool", Content: "temperature: 31°C", ToolCallID: "call_9ctHOJqO3bYrpm2A6S7nHd5k"}, + }, + }, + }, + }, + { + name: "chat without tool call in stream", + args: args{ + providerMockData: []provider.MockData{ + provider.MockChatCompletionStreamResponse(stopStreamResp), + }, + mockCallReqResp: map[uint32][]mockFunctionCall{ + // toolID should equal to toolCallResp's toolID + 0x33: {{toolID: "call_9ctHOJqO3bYrpm2A6S7nHd5k", functionName: "get_current_weather", respContent: "temperature: 31°C"}}, + }, + systemPrompt: "You are a weather assistant", + request: openai.ChatCompletionRequest{ + Stream: true, + Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "How is the weather today in Boston, MA?"}}, + }, + }, + wantRequest: []openai.ChatCompletionRequest{ + { + Stream: true, + Messages: []openai.ChatCompletionMessage{ + {Role: "user", Content: "How is the weather today in Boston, MA?"}, + {Role: "system", Content: "You are a weather assistant"}, + }, + Tools: []openai.Tool{{Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{Name: "get_current_weather"}}}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + register.SetRegister(register.NewDefault()) + + pd, err := provider.NewMock("mock provider", tt.args.providerMockData...) + if err != nil { + t.Fatal(err) + } + + cc := &testComponentCreator{flow: newMockDataFlow(newHandler(2 * time.Hour).handle)} + + newCaller := func(_ yomo.Source, _ yomo.StreamFunction, _ metadata.M, _ time.Duration) (*Caller, error) { + return mockCaller(tt.args.mockCallReqResp), err + } + + service := newService(pd, cc, newCaller, nil) + + caller, err := service.LoadOrCreateCaller(&http.Request{}) + assert.NoError(t, err) + + caller.SetSystemPrompt(tt.args.systemPrompt) + + w := httptest.NewRecorder() + err = service.GetChatCompletions(context.TODO(), tt.args.request, "transID", caller, w) + assert.NoError(t, err) + + assert.Equal(t, tt.wantRequest, pd.RequestRecords()) + }) + } +} + +// mockCaller returns a mock caller. +// the request-response of caller has been defined in advance, the request and response are defined in the `calls`. +func mockCaller(calls map[uint32][]mockFunctionCall) *Caller { + // register function to register + for tag, call := range calls { + for _, c := range call { + register.RegisterFunction(tag, &openai.FunctionDefinition{Name: c.functionName}, uint64(tag), nil) + } + } + + caller := &Caller{ + CallSyncer: &mockCallSyncer{calls: calls}, + md: metadata.M{"hello": "llm bridge"}, + } + + return caller +} + +type mockFunctionCall struct { + toolID string + functionName string + respContent string +} + +type mockCallSyncer struct { + calls map[uint32][]mockFunctionCall +} + +// Call implements CallSyncer, it returns the mock response defined in advance. +func (m *mockCallSyncer) Call(ctx context.Context, transID string, reqID string, toolCalls map[uint32][]*openai.ToolCall) ([]openai.ChatCompletionMessage, error) { + res := []openai.ChatCompletionMessage{} + for tag, calls := range toolCalls { + mcs, ok := m.calls[tag] + if !ok { + return nil, errors.New("call not found") + } + mcm := make(map[string]mockFunctionCall, len(mcs)) + for _, mc := range mcs { + mcm[mc.toolID] = mc + } + for _, call := range calls { + mc, ok := mcm[call.ID] + if !ok { + return nil, errors.New("call not found") + } + res = append(res, openai.ChatCompletionMessage{ + ToolCallID: mc.toolID, + Role: openai.ChatMessageRoleTool, + Content: mc.respContent, + }) + } + } + return res, nil +} + +func (m *mockCallSyncer) Close() error { return nil } + +func toInt(val int) *int { return &val } + +var stopStreamResp = `data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" I'm"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" just"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" a"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" computer"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" program"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":","},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" so"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" don't"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" have"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" feelings"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":","},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" but"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" I'm"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" here"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" and"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" ready"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" help"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" with"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" whatever"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" need"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} + +data: {"id":"chatcmpl-9blY98pEJe6mXGKivCZyl61vxaUFq","object":"chat.completion.chunk","created":1718787945,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_f4e629d0a5","choices":[],"usage":{"prompt_tokens":13,"completion_tokens":34,"total_tokens":47}} + +data: [DONE]` + +var stopResp = `{ + "id": "chatcmpl-9blYknv9rHvr2dvCQKMeW21hlBpCX", + "object": "chat.completion", + "created": 1718787982, + "model": "gpt-4o-2024-05-13", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! I'm just a computer program, so I don't have feelings, but thanks for asking. How can I assist you today?" + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 13, + "completion_tokens": 26, + "total_tokens": 39 + }, + "system_fingerprint": "fp_f4e629d0a5" +}` + +var toolCallStreamResp = `data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_9ctHOJqO3bYrpm2A6S7nHd5k","type":"function","function":{"name":"get_current_weather","arguments":""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Boston"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":","}}]},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" MA"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null} + +data: {"id":"chatcmpl-9blTCqGy0TGLdK4sOYlGrNxbGGknW","object":"chat.completion.chunk","created":1718787638,"model":"gpt-4-turbo-2024-04-09","system_fingerprint":"fp_9d7f5c6195","choices":[],"usage":{"prompt_tokens":83,"completion_tokens":17,"total_tokens":100}}` + +var toolCallResp = `{ + "id": "chatcmpl-abc123", + "object": "chat.completion", + "created": 1699896916, + "model": "gpt-4-turbo-2024-04-09", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": "{\n\"location\": \"Boston, MA\"\n}" + } + } + ] + }, + "logprobs": null, + "finish_reason": "tool_calls" + } + ], + "usage": { + "prompt_tokens": 82, + "completion_tokens": 17, + "total_tokens": 99 + } +}` From ef92ece166e7564bf737dab8d7935e471993970c Mon Sep 17 00:00:00 2001 From: woorui Date: Tue, 13 Aug 2024 22:44:13 +0800 Subject: [PATCH 2/3] refactor: remove ComponentCreator interface --- pkg/bridge/ai/api_server.go | 2 +- pkg/bridge/ai/api_server_test.go | 8 +++- pkg/bridge/ai/caller.go | 40 -------------------- pkg/bridge/ai/service.go | 63 +++++++++++++++++++++++--------- pkg/bridge/ai/service_test.go | 16 ++++++-- 5 files changed, 64 insertions(+), 65 deletions(-) diff --git a/pkg/bridge/ai/api_server.go b/pkg/bridge/ai/api_server.go index ddf2bcc62..88e0fcae5 100644 --- a/pkg/bridge/ai/api_server.go +++ b/pkg/bridge/ai/api_server.go @@ -83,7 +83,7 @@ func NewBasicAPIServer(config *Config, zipperAddr, credential string, provider p logger = logger.With("component", "bridge") - service := NewService(provider, DefaultComponentCreator(zipperAddr), &ServiceOption{ + service := NewService(zipperAddr, provider, &ServiceOption{ Logger: logger, Tracer: otel.Tracer("yomo-llm-bridge"), CredentialFunc: func(r *http.Request) (string, error) { return credential, nil }, diff --git a/pkg/bridge/ai/api_server_test.go b/pkg/bridge/ai/api_server_test.go index f0e456b29..1678ee4c4 100644 --- a/pkg/bridge/ai/api_server_test.go +++ b/pkg/bridge/ai/api_server_test.go @@ -40,13 +40,17 @@ func TestServer(t *testing.T) { t.Fatal(err) } - cc := &testComponentCreator{flow: newMockDataFlow(newHandler(2 * time.Hour).handle)} + flow := newMockDataFlow(newHandler(2 * time.Hour).handle) newCaller := func(_ yomo.Source, _ yomo.StreamFunction, _ metadata.M, _ time.Duration) (*Caller, error) { return mockCaller(nil), err } - service := newService(pd, cc, newCaller, nil) + service := newService("fake_zipper_addr", pd, newCaller, &ServiceOption{ + SourceBuilder: func(_, _ string) yomo.Source { return flow }, + ReducerBuilder: func(_, _ string) yomo.StreamFunction { return flow }, + MetadataExchanger: func(_ string) (metadata.M, error) { return metadata.M{"hello": "llm bridge"}, nil }, + }) handler := DecorateHandler(NewServeMux(service), decorateReqContext(service, service.logger)) diff --git a/pkg/bridge/ai/caller.go b/pkg/bridge/ai/caller.go index ddbfa04e7..44c5b1256 100644 --- a/pkg/bridge/ai/caller.go +++ b/pkg/bridge/ai/caller.go @@ -141,43 +141,3 @@ func (c *Caller) Close() error { return err } - -// ComponentCreator creates unconnected source, unconnected reducer, and exchange metadata from credential. -type ComponentCreator interface { - // CreateSource should creates an unconnected source. - CreateSource(credential string) yomo.Source - // CreateReducer should creates an unconnected reducer. - CreateReducer(credential string) yomo.StreamFunction - // ExchangeMetadata exchanges metadata from the credential. - ExchangeMetadata(credential string) (metadata.M, error) -} - -type defaultComponentCreator struct { - zipperAddr string -} - -// DefaultComponentCreator returns a ComponentCreator that creates unconnected source, -// unconnected reducer, and exchange metadata from credential. -func DefaultComponentCreator(zipperAddr string) ComponentCreator { - return &defaultComponentCreator{ - zipperAddr: zipperAddr, - } -} - -func (c *defaultComponentCreator) CreateSource(credential string) yomo.Source { - return yomo.NewSource( - "fc-source", - c.zipperAddr, - yomo.WithSourceReConnect(), yomo.WithCredential(credential)) -} - -func (c *defaultComponentCreator) CreateReducer(credential string) yomo.StreamFunction { - return yomo.NewStreamFunction( - "fc-reducer", - c.zipperAddr, - yomo.WithSfnReConnect(), yomo.WithSfnCredential(credential), yomo.DisableOtelTrace()) -} - -func (c *defaultComponentCreator) ExchangeMetadata(credential string) (metadata.M, error) { - return metadata.New(), nil -} diff --git a/pkg/bridge/ai/service.go b/pkg/bridge/ai/service.go index 985a07858..1777ed29e 100644 --- a/pkg/bridge/ai/service.go +++ b/pkg/bridge/ai/service.go @@ -26,12 +26,12 @@ import ( // Service is the service layer for llm bridge server. // service is responsible for handling the logic from handler layer. type Service struct { - provider provider.LLMProvider - componentCreator ComponentCreator - newCallerFunc newCallerFunc - callers *expirable.LRU[string, *Caller] - option *ServiceOption - logger *slog.Logger + zipperAddr string + provider provider.LLMProvider + newCallerFunc newCallerFunc + callers *expirable.LRU[string, *Caller] + option *ServiceOption + logger *slog.Logger } // ServiceOption is the option for creating service @@ -48,11 +48,17 @@ type ServiceOption struct { CallerCacheTTL time.Duration // CallerCallTimeout is the timeout for awaiting the function response. CallerCallTimeout time.Duration + // SourceBuilder should builds an unconnected source. + SourceBuilder func(zipperAddr, credential string) yomo.Source + // ReducerBuilder should builds an unconnected reducer. + ReducerBuilder func(zipperAddr, credential string) yomo.StreamFunction + // MetadataExchanger exchanges metadata from the credential. + MetadataExchanger func(credential string) (metadata.M, error) } // NewService creates a new service for handling the logic from handler layer. -func NewService(provider provider.LLMProvider, cc ComponentCreator, opt *ServiceOption) *Service { - return newService(provider, cc, NewCaller, opt) +func NewService(zipperAddr string, provider provider.LLMProvider, opt *ServiceOption) *Service { + return newService(zipperAddr, provider, NewCaller, opt) } func initOption(opt *ServiceOption) *ServiceOption { @@ -74,11 +80,32 @@ func initOption(opt *ServiceOption) *ServiceOption { if opt.CallerCallTimeout == 0 { opt.CallerCallTimeout = 60 * time.Second } + if opt.SourceBuilder == nil { + opt.SourceBuilder = func(zipperAddr, credential string) yomo.Source { + return yomo.NewSource( + "fc-source", + zipperAddr, + yomo.WithSourceReConnect(), yomo.WithCredential(credential)) + } + } + if opt.ReducerBuilder == nil { + opt.ReducerBuilder = func(zipperAddr, credential string) yomo.StreamFunction { + return yomo.NewStreamFunction( + "fc-reducer", + zipperAddr, + yomo.WithSfnReConnect(), yomo.WithSfnCredential(credential), yomo.DisableOtelTrace()) + } + } + if opt.MetadataExchanger == nil { + opt.MetadataExchanger = func(credential string) (metadata.M, error) { + return metadata.New(), nil + } + } return opt } -func newService(provider provider.LLMProvider, cct ComponentCreator, ncf newCallerFunc, opt *ServiceOption) *Service { +func newService(zipperAddr string, provider provider.LLMProvider, ncf newCallerFunc, opt *ServiceOption) *Service { var onEvict = func(_ string, caller *Caller) { caller.Close() } @@ -86,12 +113,12 @@ func newService(provider provider.LLMProvider, cct ComponentCreator, ncf newCall opt = initOption(opt) service := &Service{ - provider: provider, - componentCreator: cct, - newCallerFunc: ncf, - callers: expirable.NewLRU(opt.CallerCacheSize, onEvict, opt.CallerCacheTTL), - option: opt, - logger: opt.Logger, + zipperAddr: zipperAddr, + provider: provider, + newCallerFunc: ncf, + callers: expirable.NewLRU(opt.CallerCacheSize, onEvict, opt.CallerCacheTTL), + option: opt, + logger: opt.Logger, } return service @@ -417,13 +444,13 @@ func (srv *Service) loadOrCreateCaller(credential string) (*Caller, error) { if ok { return caller, nil } - md, err := srv.componentCreator.ExchangeMetadata(credential) + md, err := srv.option.MetadataExchanger(credential) if err != nil { return nil, err } caller, err = srv.newCallerFunc( - srv.componentCreator.CreateSource(credential), - srv.componentCreator.CreateReducer(credential), + srv.option.SourceBuilder(srv.zipperAddr, credential), + srv.option.ReducerBuilder(srv.zipperAddr, credential), md, srv.option.CallerCallTimeout, ) diff --git a/pkg/bridge/ai/service_test.go b/pkg/bridge/ai/service_test.go index 9966b8c23..8a10cd2d2 100644 --- a/pkg/bridge/ai/service_test.go +++ b/pkg/bridge/ai/service_test.go @@ -95,13 +95,17 @@ func TestServiceInvoke(t *testing.T) { t.Fatal(err) } - cc := &testComponentCreator{flow: newMockDataFlow(newHandler(2 * time.Hour).handle)} + flow := newMockDataFlow(newHandler(2 * time.Hour).handle) newCaller := func(_ yomo.Source, _ yomo.StreamFunction, _ metadata.M, _ time.Duration) (*Caller, error) { return mockCaller(tt.args.mockCallReqResp), err } - service := newService(pd, cc, newCaller, nil) + service := newService("fake_zipper_addr", pd, newCaller, &ServiceOption{ + SourceBuilder: func(_, _ string) yomo.Source { return flow }, + ReducerBuilder: func(_, _ string) yomo.StreamFunction { return flow }, + MetadataExchanger: func(_ string) (metadata.M, error) { return metadata.M{"hello": "llm bridge"}, nil }, + }) caller, err := service.LoadOrCreateCaller(&http.Request{}) assert.NoError(t, err) @@ -260,13 +264,17 @@ func TestServiceChatCompletion(t *testing.T) { t.Fatal(err) } - cc := &testComponentCreator{flow: newMockDataFlow(newHandler(2 * time.Hour).handle)} + flow := newMockDataFlow(newHandler(2 * time.Hour).handle) newCaller := func(_ yomo.Source, _ yomo.StreamFunction, _ metadata.M, _ time.Duration) (*Caller, error) { return mockCaller(tt.args.mockCallReqResp), err } - service := newService(pd, cc, newCaller, nil) + service := newService("fake_zipper_addr", pd, newCaller, &ServiceOption{ + SourceBuilder: func(_, _ string) yomo.Source { return flow }, + ReducerBuilder: func(_, _ string) yomo.StreamFunction { return flow }, + MetadataExchanger: func(_ string) (metadata.M, error) { return metadata.M{"hello": "llm bridge"}, nil }, + }) caller, err := service.LoadOrCreateCaller(&http.Request{}) assert.NoError(t, err) From f0e222cf6349ad70bda6813d6dcaec22ae657be8 Mon Sep 17 00:00:00 2001 From: woorui Date: Wed, 14 Aug 2024 15:32:41 +0800 Subject: [PATCH 3/3] ServiceOption -> ServiceOptions --- pkg/bridge/ai/api_server.go | 2 +- pkg/bridge/ai/api_server_test.go | 2 +- pkg/bridge/ai/service.go | 14 +++++++------- pkg/bridge/ai/service_test.go | 4 ++-- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pkg/bridge/ai/api_server.go b/pkg/bridge/ai/api_server.go index 88e0fcae5..8f1275122 100644 --- a/pkg/bridge/ai/api_server.go +++ b/pkg/bridge/ai/api_server.go @@ -83,7 +83,7 @@ func NewBasicAPIServer(config *Config, zipperAddr, credential string, provider p logger = logger.With("component", "bridge") - service := NewService(zipperAddr, provider, &ServiceOption{ + service := NewService(zipperAddr, provider, &ServiceOptions{ Logger: logger, Tracer: otel.Tracer("yomo-llm-bridge"), CredentialFunc: func(r *http.Request) (string, error) { return credential, nil }, diff --git a/pkg/bridge/ai/api_server_test.go b/pkg/bridge/ai/api_server_test.go index 1678ee4c4..235d27c63 100644 --- a/pkg/bridge/ai/api_server_test.go +++ b/pkg/bridge/ai/api_server_test.go @@ -46,7 +46,7 @@ func TestServer(t *testing.T) { return mockCaller(nil), err } - service := newService("fake_zipper_addr", pd, newCaller, &ServiceOption{ + service := newService("fake_zipper_addr", pd, newCaller, &ServiceOptions{ SourceBuilder: func(_, _ string) yomo.Source { return flow }, ReducerBuilder: func(_, _ string) yomo.StreamFunction { return flow }, MetadataExchanger: func(_ string) (metadata.M, error) { return metadata.M{"hello": "llm bridge"}, nil }, diff --git a/pkg/bridge/ai/service.go b/pkg/bridge/ai/service.go index 1777ed29e..fd970465e 100644 --- a/pkg/bridge/ai/service.go +++ b/pkg/bridge/ai/service.go @@ -30,12 +30,12 @@ type Service struct { provider provider.LLMProvider newCallerFunc newCallerFunc callers *expirable.LRU[string, *Caller] - option *ServiceOption + option *ServiceOptions logger *slog.Logger } -// ServiceOption is the option for creating service -type ServiceOption struct { +// ServiceOptions is the option for creating service +type ServiceOptions struct { // Logger is the logger for the service Logger *slog.Logger // Tracer is the tracer for the service @@ -57,13 +57,13 @@ type ServiceOption struct { } // NewService creates a new service for handling the logic from handler layer. -func NewService(zipperAddr string, provider provider.LLMProvider, opt *ServiceOption) *Service { +func NewService(zipperAddr string, provider provider.LLMProvider, opt *ServiceOptions) *Service { return newService(zipperAddr, provider, NewCaller, opt) } -func initOption(opt *ServiceOption) *ServiceOption { +func initOption(opt *ServiceOptions) *ServiceOptions { if opt == nil { - opt = &ServiceOption{} + opt = &ServiceOptions{} } if opt.Tracer == nil { opt.Tracer = noop.NewTracerProvider().Tracer("yomo-ai-bridge") @@ -105,7 +105,7 @@ func initOption(opt *ServiceOption) *ServiceOption { return opt } -func newService(zipperAddr string, provider provider.LLMProvider, ncf newCallerFunc, opt *ServiceOption) *Service { +func newService(zipperAddr string, provider provider.LLMProvider, ncf newCallerFunc, opt *ServiceOptions) *Service { var onEvict = func(_ string, caller *Caller) { caller.Close() } diff --git a/pkg/bridge/ai/service_test.go b/pkg/bridge/ai/service_test.go index 8a10cd2d2..474593a31 100644 --- a/pkg/bridge/ai/service_test.go +++ b/pkg/bridge/ai/service_test.go @@ -101,7 +101,7 @@ func TestServiceInvoke(t *testing.T) { return mockCaller(tt.args.mockCallReqResp), err } - service := newService("fake_zipper_addr", pd, newCaller, &ServiceOption{ + service := newService("fake_zipper_addr", pd, newCaller, &ServiceOptions{ SourceBuilder: func(_, _ string) yomo.Source { return flow }, ReducerBuilder: func(_, _ string) yomo.StreamFunction { return flow }, MetadataExchanger: func(_ string) (metadata.M, error) { return metadata.M{"hello": "llm bridge"}, nil }, @@ -270,7 +270,7 @@ func TestServiceChatCompletion(t *testing.T) { return mockCaller(tt.args.mockCallReqResp), err } - service := newService("fake_zipper_addr", pd, newCaller, &ServiceOption{ + service := newService("fake_zipper_addr", pd, newCaller, &ServiceOptions{ SourceBuilder: func(_, _ string) yomo.Source { return flow }, ReducerBuilder: func(_, _ string) yomo.StreamFunction { return flow }, MetadataExchanger: func(_ string) (metadata.M, error) { return metadata.M{"hello": "llm bridge"}, nil },