Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: new caller from interfaces & new service from its option #883

Merged
merged 4 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 48 additions & 57 deletions pkg/bridge/ai/api_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
zipperAddr string
credential string
httpHandler http.Handler
logger *slog.Logger
}

// Serve starts the Basic API Server
Expand All @@ -44,19 +43,20 @@
if err != nil {
return err
}
srv, err := NewBasicAPIServer(config, zipperListenAddr, provider, credential, logger)
srv, err := NewBasicAPIServer(config, zipperListenAddr, credential, provider, logger)

Check warning on line 46 in pkg/bridge/ai/api_server.go

View check run for this annotation

Codecov / codecov/patch

pkg/bridge/ai/api_server.go#L46

Added line #L46 was not covered by tests
if err != nil {
return err
}

logger.Info("start bridge server", "addr", config.Server.Addr, "provider", provider.Name())
return srv.ServeAddr(config.Server.Addr)
return http.ListenAndServe(config.Server.Addr, srv.httpHandler)

Check warning on line 52 in pkg/bridge/ai/api_server.go

View check run for this annotation

Codecov / codecov/patch

pkg/bridge/ai/api_server.go#L52

Added line #L52 was not covered by tests
}

func BridgeHTTPHanlder(provider provider.LLMProvider, decorater func(http.Handler) http.Handler) http.Handler {
// NewServeMux creates a new http.ServeMux for the llm bridge server.
func NewServeMux(service *Service) *http.ServeMux {
var (
h = &Handler{service}
mux = http.NewServeMux()
h = NewHandler(provider)
)
// GET /overview
mux.HandleFunc("/overview", h.HandleOverview)
Expand All @@ -65,57 +65,59 @@
// POST /v1/chat/completions (OpenAI compatible interface)
mux.HandleFunc("/v1/chat/completions", h.HandleChatCompletions)

return decorater(mux)
return mux
}

// DecorateHandler decorates the http.Handler.
func DecorateHandler(h http.Handler, decorates ...func(handler http.Handler) http.Handler) http.Handler {
// decorate the http.Handler
for i := len(decorates) - 1; i >= 0; i-- {
h = decorates[i](h)
}
return h
}

// NewBasicAPIServer creates a new restful service
func NewBasicAPIServer(config *Config, zipperAddr string, provider provider.LLMProvider, credential string, logger *slog.Logger) (*BasicAPIServer, error) {
func NewBasicAPIServer(config *Config, zipperAddr, credential string, provider provider.LLMProvider, logger *slog.Logger) (*BasicAPIServer, error) {

Check warning on line 81 in pkg/bridge/ai/api_server.go

View check run for this annotation

Codecov / codecov/patch

pkg/bridge/ai/api_server.go#L81

Added line #L81 was not covered by tests
zipperAddr = parseZipperAddr(zipperAddr)

cp := NewCallerProvider(zipperAddr, DefaultExchangeMetadataFunc)
logger = logger.With("component", "bridge")

Check warning on line 84 in pkg/bridge/ai/api_server.go

View check run for this annotation

Codecov / codecov/patch

pkg/bridge/ai/api_server.go#L84

Added line #L84 was not covered by tests

service := NewService(zipperAddr, provider, &ServiceOption{
Logger: logger,
Tracer: otel.Tracer("yomo-llm-bridge"),
CredentialFunc: func(r *http.Request) (string, error) { return credential, nil },

Check warning on line 89 in pkg/bridge/ai/api_server.go

View check run for this annotation

Codecov / codecov/patch

pkg/bridge/ai/api_server.go#L86-L89

Added lines #L86 - L89 were not covered by tests
})

mux := NewServeMux(service)

Check warning on line 92 in pkg/bridge/ai/api_server.go

View check run for this annotation

Codecov / codecov/patch

pkg/bridge/ai/api_server.go#L92

Added line #L92 was not covered by tests

server := &BasicAPIServer{
zipperAddr: zipperAddr,
credential: credential,
httpHandler: BridgeHTTPHanlder(provider, decorateReqContext(cp, logger, credential)),
logger: logger.With("component", "bridge"),
httpHandler: DecorateHandler(mux, decorateReqContext(service, logger)),

Check warning on line 97 in pkg/bridge/ai/api_server.go

View check run for this annotation

Codecov / codecov/patch

pkg/bridge/ai/api_server.go#L97

Added line #L97 was not covered by tests
}

return server, nil
}

// ServeAddr starts a http server that provides some endpoints to bridge up the http server and YoMo.
// User can chat to the http server and interact with the YoMo's stream function.
func (a *BasicAPIServer) ServeAddr(addr string) error {
return http.ListenAndServe(addr, a.httpHandler)
}

// decorateReqContext decorates the context of the request, it injects a transID and a caller into the context.
func decorateReqContext(cp CallerProvider, logger *slog.Logger, credential string) func(handler http.Handler) http.Handler {
tracer := otel.Tracer("yomo-llm-bridge")

caller, err := cp.Provide(credential)
if err != nil {
logger.Info("can't load caller", "err", err)

return func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
RespondWithError(w, http.StatusInternalServerError, err)
})
}
}

caller.SetTracer(tracer)

// decorateReqContext decorates the context of the request, it injects a transID into the request's context,
// log the request information and start tracing the request.
func decorateReqContext(service *Service, logger *slog.Logger) func(handler http.Handler) http.Handler {
host, _ := os.Hostname()

return func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

caller, err := service.LoadOrCreateCaller(r)
if err != nil {
RespondWithError(w, http.StatusBadRequest, err)
return

Check warning on line 115 in pkg/bridge/ai/api_server.go

View check run for this annotation

Codecov / codecov/patch

pkg/bridge/ai/api_server.go#L114-L115

Added lines #L114 - L115 were not covered by tests
}
ctx = WithCallerContext(ctx, caller)

// trace every request
ctx, span := tracer.Start(
ctx, span := service.option.Tracer.Start(
ctx,
r.URL.Path,
trace.WithSpanKind(trace.SpanKindServer),
Expand All @@ -125,7 +127,6 @@

transID := id.New(32)
ctx = WithTransIDContext(ctx, transID)
ctx = WithCallerContext(ctx, caller)

logger.Info("request", "method", r.Method, "path", r.URL.Path, "transID", transID)

Expand All @@ -136,24 +137,16 @@

// Handler handles the http request.
type Handler struct {
provider provider.LLMProvider
}

// NewHandler returns a new Handler.
func NewHandler(provider provider.LLMProvider) *Handler {
return &Handler{provider}
service *Service
}

// HandleOverview is the handler for GET /overview
func (h *Handler) HandleOverview(w http.ResponseWriter, r *http.Request) {
caller := FromCallerContext(r.Context())

w.Header().Set("Content-Type", "application/json")

tcs, err := register.ListToolCalls(caller.Metadata())
tcs, err := register.ListToolCalls(FromCallerContext(r.Context()).Metadata())
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
RespondWithError(w, http.StatusInternalServerError, err)

Check warning on line 149 in pkg/bridge/ai/api_server.go

View check run for this annotation

Codecov / codecov/patch

pkg/bridge/ai/api_server.go#L149

Added line #L149 was not covered by tests
return
}

Expand All @@ -172,7 +165,6 @@
func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
caller = FromCallerContext(ctx)
transID = FromTransIDContext(ctx)
)
defer r.Body.Close()
Expand All @@ -185,14 +177,14 @@
ctx, cancel := context.WithTimeout(r.Context(), RequestTimeout)
defer cancel()

res, err := GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, req.IncludeCallStack, caller, h.provider)
w.Header().Set("Content-Type", "application/json")

res, err := h.service.GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, FromCallerContext(ctx), req.IncludeCallStack)
if err != nil {
w.Header().Set("Content-Type", "application/json")
RespondWithError(w, http.StatusInternalServerError, err)
return
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(res)
}
Expand All @@ -201,7 +193,6 @@
func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
caller = FromCallerContext(ctx)
transID = FromTransIDContext(ctx)
)
defer r.Body.Close()
Expand All @@ -214,7 +205,7 @@
ctx, cancel := context.WithTimeout(r.Context(), RequestTimeout)
defer cancel()

if err := GetChatCompletions(ctx, req, transID, h.provider, caller, w); err != nil {
if err := h.service.GetChatCompletions(ctx, req, transID, FromCallerContext(ctx), w); err != nil {
RespondWithError(w, http.StatusBadRequest, err)
return
}
Expand Down Expand Up @@ -258,17 +249,17 @@
type callerContextKey struct{}

// WithCallerContext adds the caller to the request context
func WithCallerContext(ctx context.Context, caller Caller) context.Context {
func WithCallerContext(ctx context.Context, caller *Caller) context.Context {
return context.WithValue(ctx, callerContextKey{}, caller)
}

// FromCallerContext returns the caller from the request context
func FromCallerContext(ctx context.Context) Caller {
service, ok := ctx.Value(callerContextKey{}).(Caller)
func FromCallerContext(ctx context.Context) *Caller {
caller, ok := ctx.Value(callerContextKey{}).(*Caller)
if !ok {
return nil
}
return service
return caller
}

type transIDContextKey struct{}
Expand Down
18 changes: 14 additions & 4 deletions pkg/bridge/ai/api_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ import (
"bytes"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/yomorun/yomo"
"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/core/metadata"
"github.com/yomorun/yomo/pkg/bridge/ai/provider"
"github.com/yomorun/yomo/pkg/bridge/ai/register"
)
Expand Down Expand Up @@ -38,11 +40,19 @@ func TestServer(t *testing.T) {
t.Fatal(err)
}

cp := newMockCallerProvider()
flow := newMockDataFlow(newHandler(2 * time.Hour).handle)

cp.provideFunc = mockCallerProvideFunc(map[uint32][]mockFunctionCall{})
newCaller := func(_ yomo.Source, _ yomo.StreamFunction, _ metadata.M, _ time.Duration) (*Caller, error) {
return mockCaller(nil), err
}

service := newService("fake_zipper_addr", pd, newCaller, &ServiceOption{
SourceBuilder: func(_, _ string) yomo.Source { return flow },
ReducerBuilder: func(_, _ string) yomo.StreamFunction { return flow },
MetadataExchanger: func(_ string) (metadata.M, error) { return metadata.M{"hello": "llm bridge"}, nil },
})

handler := BridgeHTTPHanlder(pd, decorateReqContext(cp, slog.Default(), ""))
handler := DecorateHandler(NewServeMux(service), decorateReqContext(service, service.logger))

// create a test server
server := httptest.NewServer(handler)
Expand Down
38 changes: 0 additions & 38 deletions pkg/bridge/ai/call_syncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ import (
"time"

openai "github.com/sashabaranov/go-openai"
"github.com/yomorun/yomo"
"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/serverless"
)

// CallSyncer fires a bunch of function callings, and wait the result of these function callings.
Expand Down Expand Up @@ -223,39 +221,3 @@ func (f *callSyncer) background() {
}
}
}

// ToReducer converts a stream function to a reducer that can reduce the function calling result.
func ToReducer(sfn yomo.StreamFunction, logger *slog.Logger, ch chan ReduceMessage) {
// set observe data tags
sfn.SetObserveDataTags(ai.ReducerTag)
// set reduce handler
sfn.SetHandler(func(ctx serverless.Context) {
invoke, err := ctx.LLMFunctionCall()
if err != nil {
ch <- ReduceMessage{ReqID: ""}
logger.Error("parse function calling invoke", "err", err.Error())
return
}
logger.Debug("sfn-reducer", "req_id", invoke.ReqID, "tool_call_id", invoke.ToolCallID, "result", string(invoke.Result))

message := openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleTool,
Content: invoke.Result,
ToolCallID: invoke.ToolCallID,
}

ch <- ReduceMessage{ReqID: invoke.ReqID, Message: message}
})
}

// ToSource convert a yomo source to the source that can send function calling body to the llm function.
func ToSource(source yomo.Source, logger *slog.Logger, ch chan TagFunctionCall) {
go func() {
for c := range ch {
buf, _ := c.FunctionCall.Bytes()
if err := source.Write(c.Tag, buf); err != nil {
logger.Error("send data to zipper", "err", err.Error())
}
}
}()
}
22 changes: 8 additions & 14 deletions pkg/bridge/ai/call_syncer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,10 @@ func TestTimeoutCallSyncer(t *testing.T) {
flow := newMockDataFlow(h.handle)
defer flow.Close()

reqs := make(chan TagFunctionCall)
ToSource(flow, slog.Default(), reqs)
req, _ := sourceWriteToChan(flow, slog.Default())
res, _ := reduceToChan(flow, slog.Default())

messages := make(chan ReduceMessage)
ToReducer(flow, slog.Default(), messages)

syncer := NewCallSyncer(slog.Default(), reqs, messages, time.Millisecond)
syncer := NewCallSyncer(slog.Default(), req, res, time.Millisecond)
go flow.run()

var (
Expand Down Expand Up @@ -61,13 +58,10 @@ func TestCallSyncer(t *testing.T) {
flow := newMockDataFlow(h.handle)
defer flow.Close()

reqs := make(chan TagFunctionCall)
ToSource(flow, slog.Default(), reqs)

messages := make(chan ReduceMessage)
ToReducer(flow, slog.Default(), messages)
req, _ := sourceWriteToChan(flow, slog.Default())
res, _ := reduceToChan(flow, slog.Default())

syncer := NewCallSyncer(slog.Default(), reqs, messages, 0)
syncer := NewCallSyncer(slog.Default(), req, res, 0)
go flow.run()

var (
Expand Down Expand Up @@ -118,7 +112,7 @@ func (h *handler) result() []openai.ChatCompletionMessage {
return want
}

// mockDataFlow mocks the data flow of ai bridge.
// mockDataFlow mocks the data flow of llm bridge.
// The data flow is: source -> hander -> reducer,
// It is `Write() -> handler() -> reducer()` in this mock implementation.
type mockDataFlow struct {
Expand Down Expand Up @@ -160,11 +154,11 @@ var _ yomo.StreamFunction = (*mockDataFlow)(nil)

// The test will not use blowing function in this mock implementation.
func (t *mockDataFlow) SetObserveDataTags(tag ...uint32) {}
func (t *mockDataFlow) Connect() error { return nil }
func (t *mockDataFlow) Init(fn func() error) error { panic("unimplemented") }
func (t *mockDataFlow) SetCronHandler(spec string, fn core.CronHandler) error { panic("unimplemented") }
func (t *mockDataFlow) SetPipeHandler(fn core.PipeHandler) error { panic("unimplemented") }
func (t *mockDataFlow) SetWantedTarget(string) { panic("unimplemented") }
func (t *mockDataFlow) Wait() { panic("unimplemented") }
func (t *mockDataFlow) Connect() error { panic("unimplemented") }
func (t *mockDataFlow) SetErrorHandler(fn func(err error)) { panic("unimplemented") }
func (t *mockDataFlow) WriteWithTarget(_ uint32, _ []byte, _ string) error { panic("unimplemented") }
Loading