Skip to content

Commit

Permalink
Merge branch 'master' into chore/docker-platform
Browse files Browse the repository at this point in the history
  • Loading branch information
woorui committed Aug 16, 2024
2 parents 56c6ebd + c68e252 commit 6ff9755
Show file tree
Hide file tree
Showing 8 changed files with 1,288 additions and 1,228 deletions.
105 changes: 48 additions & 57 deletions pkg/bridge/ai/api_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ type BasicAPIServer struct {
zipperAddr string
credential string
httpHandler http.Handler
logger *slog.Logger
}

// Serve starts the Basic API Server
Expand All @@ -44,19 +43,20 @@ func Serve(config *Config, zipperListenAddr string, credential string, logger *s
if err != nil {
return err
}
srv, err := NewBasicAPIServer(config, zipperListenAddr, provider, credential, logger)
srv, err := NewBasicAPIServer(config, zipperListenAddr, credential, provider, logger)
if err != nil {
return err
}

logger.Info("start AI Bridge service", "addr", config.Server.Addr, "provider", provider.Name())
return srv.ServeAddr(config.Server.Addr)
return http.ListenAndServe(config.Server.Addr, srv.httpHandler)
}

func BridgeHTTPHanlder(provider provider.LLMProvider, decorater func(http.Handler) http.Handler) http.Handler {
// NewServeMux creates a new http.ServeMux for the llm bridge server.
func NewServeMux(service *Service) *http.ServeMux {
var (
h = &Handler{service}
mux = http.NewServeMux()
h = NewHandler(provider)
)
// GET /overview
mux.HandleFunc("/overview", h.HandleOverview)
Expand All @@ -65,57 +65,59 @@ func BridgeHTTPHanlder(provider provider.LLMProvider, decorater func(http.Handle
// POST /v1/chat/completions (OpenAI compatible interface)
mux.HandleFunc("/v1/chat/completions", h.HandleChatCompletions)

return decorater(mux)
return mux
}

// DecorateHandler decorates the http.Handler.
func DecorateHandler(h http.Handler, decorates ...func(handler http.Handler) http.Handler) http.Handler {
// decorate the http.Handler
for i := len(decorates) - 1; i >= 0; i-- {
h = decorates[i](h)
}
return h
}

// NewBasicAPIServer creates a new restful service
func NewBasicAPIServer(config *Config, zipperAddr string, provider provider.LLMProvider, credential string, logger *slog.Logger) (*BasicAPIServer, error) {
func NewBasicAPIServer(config *Config, zipperAddr, credential string, provider provider.LLMProvider, logger *slog.Logger) (*BasicAPIServer, error) {
zipperAddr = parseZipperAddr(zipperAddr)

cp := NewCallerProvider(zipperAddr, DefaultExchangeMetadataFunc)
logger = logger.With("component", "bridge")

service := NewService(zipperAddr, provider, &ServiceOptions{
Logger: logger,
Tracer: otel.Tracer("yomo-llm-bridge"),
CredentialFunc: func(r *http.Request) (string, error) { return credential, nil },
})

mux := NewServeMux(service)

server := &BasicAPIServer{
zipperAddr: zipperAddr,
credential: credential,
httpHandler: BridgeHTTPHanlder(provider, decorateReqContext(cp, logger, credential)),
logger: logger.With("component", "bridge"),
httpHandler: DecorateHandler(mux, decorateReqContext(service, logger)),
}

return server, nil
}

// ServeAddr starts a http server that provides some endpoints to bridge up the http server and YoMo.
// User can chat to the http server and interact with the YoMo's stream function.
func (a *BasicAPIServer) ServeAddr(addr string) error {
return http.ListenAndServe(addr, a.httpHandler)
}

// decorateReqContext decorates the context of the request, it injects a transID and a caller into the context.
func decorateReqContext(cp CallerProvider, logger *slog.Logger, credential string) func(handler http.Handler) http.Handler {
tracer := otel.Tracer("yomo-llm-bridge")

caller, err := cp.Provide(credential)
if err != nil {
logger.Info("can't load caller", "err", err)

return func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
RespondWithError(w, http.StatusInternalServerError, err)
})
}
}

caller.SetTracer(tracer)

// decorateReqContext decorates the context of the request, it injects a transID into the request's context,
// log the request information and start tracing the request.
func decorateReqContext(service *Service, logger *slog.Logger) func(handler http.Handler) http.Handler {
host, _ := os.Hostname()

return func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

caller, err := service.LoadOrCreateCaller(r)
if err != nil {
RespondWithError(w, http.StatusBadRequest, err)
return
}
ctx = WithCallerContext(ctx, caller)

// trace every request
ctx, span := tracer.Start(
ctx, span := service.option.Tracer.Start(
ctx,
r.URL.Path,
trace.WithSpanKind(trace.SpanKindServer),
Expand All @@ -125,7 +127,6 @@ func decorateReqContext(cp CallerProvider, logger *slog.Logger, credential strin

transID := id.New(32)
ctx = WithTransIDContext(ctx, transID)
ctx = WithCallerContext(ctx, caller)

logger.Info("request", "method", r.Method, "path", r.URL.Path, "transID", transID)

Expand All @@ -136,24 +137,16 @@ func decorateReqContext(cp CallerProvider, logger *slog.Logger, credential strin

// Handler handles the http request.
type Handler struct {
provider provider.LLMProvider
}

// NewHandler returns a new Handler.
func NewHandler(provider provider.LLMProvider) *Handler {
return &Handler{provider}
service *Service
}

// HandleOverview is the handler for GET /overview
func (h *Handler) HandleOverview(w http.ResponseWriter, r *http.Request) {
caller := FromCallerContext(r.Context())

w.Header().Set("Content-Type", "application/json")

tcs, err := register.ListToolCalls(caller.Metadata())
tcs, err := register.ListToolCalls(FromCallerContext(r.Context()).Metadata())
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
RespondWithError(w, http.StatusInternalServerError, err)
return
}

Expand All @@ -172,7 +165,6 @@ var baseSystemMessage = `You are a very helpful assistant. Your job is to choose
func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
caller = FromCallerContext(ctx)
transID = FromTransIDContext(ctx)
)
defer r.Body.Close()
Expand All @@ -185,14 +177,14 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), RequestTimeout)
defer cancel()

res, err := GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, req.IncludeCallStack, caller, h.provider)
w.Header().Set("Content-Type", "application/json")

res, err := h.service.GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, FromCallerContext(ctx), req.IncludeCallStack)
if err != nil {
w.Header().Set("Content-Type", "application/json")
RespondWithError(w, http.StatusInternalServerError, err)
return
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(res)
}
Expand All @@ -201,7 +193,6 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) {
func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
caller = FromCallerContext(ctx)
transID = FromTransIDContext(ctx)
)
defer r.Body.Close()
Expand All @@ -214,7 +205,7 @@ func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request)
ctx, cancel := context.WithTimeout(r.Context(), RequestTimeout)
defer cancel()

if err := GetChatCompletions(ctx, req, transID, h.provider, caller, w); err != nil {
if err := h.service.GetChatCompletions(ctx, req, transID, FromCallerContext(ctx), w); err != nil {
RespondWithError(w, http.StatusBadRequest, err)
return
}
Expand Down Expand Up @@ -258,17 +249,17 @@ func getLocalIP() (string, error) {
type callerContextKey struct{}

// WithCallerContext adds the caller to the request context
func WithCallerContext(ctx context.Context, caller Caller) context.Context {
func WithCallerContext(ctx context.Context, caller *Caller) context.Context {
return context.WithValue(ctx, callerContextKey{}, caller)
}

// FromCallerContext returns the caller from the request context
func FromCallerContext(ctx context.Context) Caller {
service, ok := ctx.Value(callerContextKey{}).(Caller)
func FromCallerContext(ctx context.Context) *Caller {
caller, ok := ctx.Value(callerContextKey{}).(*Caller)
if !ok {
return nil
}
return service
return caller
}

type transIDContextKey struct{}
Expand Down
18 changes: 14 additions & 4 deletions pkg/bridge/ai/api_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ import (
"bytes"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/yomorun/yomo"
"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/core/metadata"
"github.com/yomorun/yomo/pkg/bridge/ai/provider"
"github.com/yomorun/yomo/pkg/bridge/ai/register"
)
Expand Down Expand Up @@ -38,11 +40,19 @@ func TestServer(t *testing.T) {
t.Fatal(err)
}

cp := newMockCallerProvider()
flow := newMockDataFlow(newHandler(2 * time.Hour).handle)

cp.provideFunc = mockCallerProvideFunc(map[uint32][]mockFunctionCall{})
newCaller := func(_ yomo.Source, _ yomo.StreamFunction, _ metadata.M, _ time.Duration) (*Caller, error) {
return mockCaller(nil), err
}

service := newService("fake_zipper_addr", pd, newCaller, &ServiceOptions{
SourceBuilder: func(_, _ string) yomo.Source { return flow },
ReducerBuilder: func(_, _ string) yomo.StreamFunction { return flow },
MetadataExchanger: func(_ string) (metadata.M, error) { return metadata.M{"hello": "llm bridge"}, nil },
})

handler := BridgeHTTPHanlder(pd, decorateReqContext(cp, slog.Default(), ""))
handler := DecorateHandler(NewServeMux(service), decorateReqContext(service, service.logger))

// create a test server
server := httptest.NewServer(handler)
Expand Down
38 changes: 0 additions & 38 deletions pkg/bridge/ai/call_syncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ import (
"time"

openai "github.com/sashabaranov/go-openai"
"github.com/yomorun/yomo"
"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/serverless"
)

// CallSyncer fires a bunch of function callings, and wait the result of these function callings.
Expand Down Expand Up @@ -223,39 +221,3 @@ func (f *callSyncer) background() {
}
}
}

// ToReducer converts a stream function to a reducer that can reduce the function calling result.
func ToReducer(sfn yomo.StreamFunction, logger *slog.Logger, ch chan ReduceMessage) {
// set observe data tags
sfn.SetObserveDataTags(ai.ReducerTag)
// set reduce handler
sfn.SetHandler(func(ctx serverless.Context) {
invoke, err := ctx.LLMFunctionCall()
if err != nil {
ch <- ReduceMessage{ReqID: ""}
logger.Error("parse function calling invoke", "err", err.Error())
return
}
logger.Debug("sfn-reducer", "req_id", invoke.ReqID, "tool_call_id", invoke.ToolCallID, "result", string(invoke.Result))

message := openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleTool,
Content: invoke.Result,
ToolCallID: invoke.ToolCallID,
}

ch <- ReduceMessage{ReqID: invoke.ReqID, Message: message}
})
}

// ToSource convert a yomo source to the source that can send function calling body to the llm function.
func ToSource(source yomo.Source, logger *slog.Logger, ch chan TagFunctionCall) {
go func() {
for c := range ch {
buf, _ := c.FunctionCall.Bytes()
if err := source.Write(c.Tag, buf); err != nil {
logger.Error("send data to zipper", "err", err.Error())
}
}
}()
}
22 changes: 8 additions & 14 deletions pkg/bridge/ai/call_syncer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,10 @@ func TestTimeoutCallSyncer(t *testing.T) {
flow := newMockDataFlow(h.handle)
defer flow.Close()

reqs := make(chan TagFunctionCall)
ToSource(flow, slog.Default(), reqs)
req, _ := sourceWriteToChan(flow, slog.Default())
res, _ := reduceToChan(flow, slog.Default())

messages := make(chan ReduceMessage)
ToReducer(flow, slog.Default(), messages)

syncer := NewCallSyncer(slog.Default(), reqs, messages, time.Millisecond)
syncer := NewCallSyncer(slog.Default(), req, res, time.Millisecond)
go flow.run()

var (
Expand Down Expand Up @@ -61,13 +58,10 @@ func TestCallSyncer(t *testing.T) {
flow := newMockDataFlow(h.handle)
defer flow.Close()

reqs := make(chan TagFunctionCall)
ToSource(flow, slog.Default(), reqs)

messages := make(chan ReduceMessage)
ToReducer(flow, slog.Default(), messages)
req, _ := sourceWriteToChan(flow, slog.Default())
res, _ := reduceToChan(flow, slog.Default())

syncer := NewCallSyncer(slog.Default(), reqs, messages, 0)
syncer := NewCallSyncer(slog.Default(), req, res, 0)
go flow.run()

var (
Expand Down Expand Up @@ -118,7 +112,7 @@ func (h *handler) result() []openai.ChatCompletionMessage {
return want
}

// mockDataFlow mocks the data flow of ai bridge.
// mockDataFlow mocks the data flow of llm bridge.
// The data flow is: source -> hander -> reducer,
// It is `Write() -> handler() -> reducer()` in this mock implementation.
type mockDataFlow struct {
Expand Down Expand Up @@ -160,11 +154,11 @@ var _ yomo.StreamFunction = (*mockDataFlow)(nil)

// The test will not use blowing function in this mock implementation.
func (t *mockDataFlow) SetObserveDataTags(tag ...uint32) {}
func (t *mockDataFlow) Connect() error { return nil }
func (t *mockDataFlow) Init(fn func() error) error { panic("unimplemented") }
func (t *mockDataFlow) SetCronHandler(spec string, fn core.CronHandler) error { panic("unimplemented") }
func (t *mockDataFlow) SetPipeHandler(fn core.PipeHandler) error { panic("unimplemented") }
func (t *mockDataFlow) SetWantedTarget(string) { panic("unimplemented") }
func (t *mockDataFlow) Wait() { panic("unimplemented") }
func (t *mockDataFlow) Connect() error { panic("unimplemented") }
func (t *mockDataFlow) SetErrorHandler(fn func(err error)) { panic("unimplemented") }
func (t *mockDataFlow) WriteWithTarget(_ uint32, _ []byte, _ string) error { panic("unimplemented") }
Loading

0 comments on commit 6ff9755

Please sign in to comment.