Skip to content

Commit

Permalink
feat: void provider's errors are sent directly to the end user (#889)
Browse files Browse the repository at this point in the history
  • Loading branch information
woorui committed Aug 26, 2024
1 parent 44e864c commit 4a1699f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 15 deletions.
31 changes: 21 additions & 10 deletions pkg/bridge/ai/api_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func decorateReqContext(service *Service, logger *slog.Logger) func(handler http

caller, err := service.LoadOrCreateCaller(r)
if err != nil {
RespondWithError(w, http.StatusBadRequest, err)
RespondWithError(w, http.StatusBadRequest, err, logger)
return
}
ctx = WithCallerContext(ctx, caller)
Expand Down Expand Up @@ -146,7 +146,7 @@ func (h *Handler) HandleOverview(w http.ResponseWriter, r *http.Request) {

tcs, err := register.ListToolCalls(FromCallerContext(r.Context()).Metadata())
if err != nil {
RespondWithError(w, http.StatusInternalServerError, err)
RespondWithError(w, http.StatusInternalServerError, err, h.service.logger)
return
}

Expand All @@ -169,7 +169,7 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) {
)
defer r.Body.Close()

req, err := DecodeRequest[ai.InvokeRequest](r, w)
req, err := DecodeRequest[ai.InvokeRequest](r, w, h.service.logger)
if err != nil {
return
}
Expand All @@ -181,7 +181,7 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) {

res, err := h.service.GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, FromCallerContext(ctx), req.IncludeCallStack)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, err)
RespondWithError(w, http.StatusInternalServerError, err, h.service.logger)
return
}

Expand All @@ -197,7 +197,7 @@ func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request)
)
defer r.Body.Close()

req, err := DecodeRequest[openai.ChatCompletionRequest](r, w)
req, err := DecodeRequest[openai.ChatCompletionRequest](r, w, h.service.logger)
if err != nil {
return
}
Expand All @@ -206,28 +206,39 @@ func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request)
defer cancel()

if err := h.service.GetChatCompletions(ctx, req, transID, FromCallerContext(ctx), w); err != nil {
RespondWithError(w, http.StatusBadRequest, err)
RespondWithError(w, http.StatusBadRequest, err, h.service.logger)
return
}
}

// DecodeRequest decodes the request body into given type.
func DecodeRequest[T any](r *http.Request, w http.ResponseWriter) (T, error) {
func DecodeRequest[T any](r *http.Request, w http.ResponseWriter, logger *slog.Logger) (T, error) {
var req T
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
w.Header().Set("Content-Type", "application/json")
RespondWithError(w, http.StatusBadRequest, err)
RespondWithError(w, http.StatusBadRequest, err, logger)
return req, err
}

return req, nil
}

// RespondWithError writes an error to response according to the OpenAI API spec.
func RespondWithError(w http.ResponseWriter, code int, err error) {
func RespondWithError(w http.ResponseWriter, code int, err error, logger *slog.Logger) {
logger.Error("bridge server error", "error", err)

errString := err.Error()
oerr, ok := err.(*openai.APIError)
if ok {
if oerr.HTTPStatusCode >= 400 {
code = http.StatusInternalServerError
errString = "Internal Server Error, Please Try Again Later."
}
}

w.WriteHeader(code)
w.Write([]byte(fmt.Sprintf(`{"error":{"code":"%d","message":"%s"}}`, code, err.Error())))
w.Write([]byte(fmt.Sprintf(`{"error":{"code":"%d","message":"%s"}}`, code, errString)))
}

func getLocalIP() (string, error) {
Expand Down
10 changes: 5 additions & 5 deletions pkg/bridge/ai/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,15 +255,15 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
// 4. request first chat for getting tools
if req.Stream {
_, firstCallSpan := srv.option.Tracer.Start(reqCtx, "first_call_request")
var (
flusher = eventFlusher(w)
isFunctionCall = false
)

resStream, err := srv.provider.GetChatCompletionsStream(reqCtx, req, md)
if err != nil {
return err
}

var (
flusher = eventFlusher(w)
isFunctionCall = false
)
var (
i int // number of chunks
j int // number of tool call chunks
Expand Down

0 comments on commit 4a1699f

Please sign in to comment.