From 02356ec87f719248dc2df27d3db1127414614657 Mon Sep 17 00:00:00 2001 From: venjiang Date: Fri, 6 Sep 2024 20:13:27 +0800 Subject: [PATCH] refactor: ollama provider (#866) --- cli/serve.go | 2 +- example/10-ai/zipper.yaml | 12 +- pkg/bridge/ai/provider/ollama/Readme.md | 8 +- pkg/bridge/ai/provider/ollama/provider.go | 359 ++-------------------- pkg/bridge/ai/service.go | 78 +++++ 5 files changed, 121 insertions(+), 338 deletions(-) diff --git a/cli/serve.go b/cli/serve.go index 7a1bb66ae..7cafff3a5 100644 --- a/cli/serve.go +++ b/cli/serve.go @@ -150,7 +150,7 @@ func registerAIProvider(aiConfig *ai.Config) error { provider["model"], )) case "ollama": - providerpkg.RegisterProvider(ollama.NewProvider(provider["api_endpoint"])) + providerpkg.RegisterProvider(ollama.NewProvider(provider["api_endpoint"], provider["model"])) case "gemini": providerpkg.RegisterProvider(gemini.NewProvider(provider["api_key"])) case "githubmodels": diff --git a/example/10-ai/zipper.yaml b/example/10-ai/zipper.yaml index bbd3503dc..ac4903af1 100644 --- a/example/10-ai/zipper.yaml +++ b/example/10-ai/zipper.yaml @@ -2,15 +2,15 @@ name: ai-zipper host: 0.0.0.0 port: 9000 -auth: - type: token - token: Happy New Year +# auth: +# type: token +# token: Happy New Year bridge: ai: server: addr: localhost:8000 - provider: anthropic + provider: ollama providers: azopenai: @@ -33,6 +33,10 @@ bridge: deployment_id: api_version: + ollama: + api_endpoint: http://localhost:11434/v1 + model: llama3.1 + cerebras: api_key: model: diff --git a/pkg/bridge/ai/provider/ollama/Readme.md b/pkg/bridge/ai/provider/ollama/Readme.md index 4f1e25113..e681c79fd 100644 --- a/pkg/bridge/ai/provider/ollama/Readme.md +++ b/pkg/bridge/ai/provider/ollama/Readme.md @@ -8,12 +8,10 @@ Follow the Ollama doc: -## 2. Run the Mistral model - -Notice that only the Mistral v0.3+ models are supported currently. +## 2. Run the model ```sh -ollama run mistral:7b +ollama run llama3.1 ``` ## 3. Start YoMo Zipper @@ -34,7 +32,7 @@ bridge: provider: ollama providers: ollama: - api_endpoint: "http://localhost:11434/" + api_endpoint: "http://localhost:11434/v1" ``` ```sh diff --git a/pkg/bridge/ai/provider/ollama/provider.go b/pkg/bridge/ai/provider/ollama/provider.go index 2ce783ddb..8861ba647 100644 --- a/pkg/bridge/ai/provider/ollama/provider.go +++ b/pkg/bridge/ai/provider/ollama/provider.go @@ -2,349 +2,40 @@ package ollama import ( - "bytes" "context" - _ "embed" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "net/url" "os" - "strings" - "text/template" - "time" openai "github.com/sashabaranov/go-openai" "github.com/yomorun/yomo/core/metadata" - "github.com/yomorun/yomo/core/ylog" "github.com/yomorun/yomo/pkg/bridge/ai/provider" - "github.com/yomorun/yomo/pkg/id" ) +// check if implements ai.Provider +var _ provider.LLMProvider = &Provider{} + // Provider is the provider for Ollama type Provider struct { - Endpoint string -} - -type ollamaRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - Raw bool `json:"raw"` - Stream bool `json:"stream"` -} - -type ollamaResponse struct { - Response string `json:"response"` - Done bool `json:"done"` - PromptEvalCount int `json:"prompt_eval_count"` - EvalCount int `json:"eval_count"` -} - -type templateRequest struct { - Tools string - System string - Prompt string -} - -type mistralFunction struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments"` -} - -const ( - defaultSystem = "You are a very helpful assistant. 
Your job is to choose the best possible action to solve the user question or task." - - systemToolExtra = "If the question of the user matched the description of a tool, the tool will be called, and only the function description JSON object should be returned. Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous." - - mistralTmpl = "{{ if .Tools }}[AVAILABLE_TOOLS] {{.Tools}} [/AVAILABLE_TOOLS] {{ end }}[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]" -) - -func makeOllamaRequestBody(req openai.ChatCompletionRequest) (io.Reader, error) { - if req.Model == "" { - req.Model = "mistral" - } - - if !strings.HasPrefix(req.Model, "mistral") { - return nil, errors.New("currently only Mistral models are supported, see https://ollama.com/library/mistral") - } - - t := &templateRequest{ - System: defaultSystem, - Tools: "[]", - } - for _, msg := range req.Messages { - switch strings.ToLower(msg.Role) { - case openai.ChatMessageRoleSystem: - t.System = msg.Content - case openai.ChatMessageRoleUser: - t.Prompt += msg.Content + " " - case openai.ChatMessageRoleTool: - t.Prompt += msg.Content + " " - } - } - - if len(req.Tools) > 0 { - t.System += systemToolExtra - - req.Stream = false - - tools, err := json.Marshal(req.Tools) - if err != nil { - return nil, err - } - t.Tools = string(tools) - } - - ylog.Debug("ollama chat request", "model", req.Model, "system", t.System, "prompt", t.Prompt, "tools", t.Tools) - - tmpl, err := template.New("ollama").Parse(mistralTmpl) - if err != nil { - return nil, err - } - - prompt := bytes.NewBufferString("") - err = tmpl.Execute(prompt, t) - if err != nil { - return nil, err - } - - body, err := json.Marshal(&ollamaRequest{ - Model: req.Model, - Prompt: prompt.String(), - Raw: true, - Stream: req.Stream, - }) - if err != nil { - return nil, err - } - - return bytes.NewBuffer(body), nil -} - -func parseToolCallsFromResponse(response string) []openai.ToolCall { - toolCalls := make([]openai.ToolCall, 0) - - response = strings.TrimPrefix(response, "[TOOL_CALLS]") - for _, v := range strings.Split(response, "\n") { - var functions []mistralFunction - if json.Unmarshal([]byte(v), &functions) == nil { - for _, f := range functions { - arguments, _ := json.Marshal(f.Arguments) - toolCalls = append(toolCalls, openai.ToolCall{ - ID: id.New(), - Type: openai.ToolTypeFunction, - Function: openai.FunctionCall{ - Name: f.Name, - Arguments: string(arguments), - }, - }) - } - } - } - - return toolCalls + // ollama OpenAI compatibility api endpoint + APIEndpoint string + // Model is the default model for Ollama + Model string + client *openai.Client } // GetChatCompletions implements ai.LLMProvider. 
func (p *Provider) GetChatCompletions(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (openai.ChatCompletionResponse, error) { - res := openai.ChatCompletionResponse{ - ID: "chatcmpl-" + id.New(29), - Model: req.Model, - Created: time.Now().Unix(), - Choices: []openai.ChatCompletionChoice{ - { - Index: 0, - Message: openai.ChatCompletionMessage{ - Role: openai.ChatMessageRoleAssistant, - Content: "error occured during inference period", - }, - FinishReason: openai.FinishReasonStop, - }, - }, - } - - urlPath, err := url.JoinPath(p.Endpoint, "api/generate") - if err != nil { - return res, err - } - - body, err := makeOllamaRequestBody(req) - if err != nil { - return res, err - } - - client := http.Client{} - request, err := http.NewRequestWithContext(ctx, http.MethodPost, urlPath, body) - if err != nil { - return res, err - } - - request.Header.Set("Content-Type", "application/json") - - resp, err := client.Do(request) - if err != nil { - return res, err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return res, fmt.Errorf("ollama inference error: %s", resp.Status) - } - - buf, err := io.ReadAll(resp.Body) - if err != nil { - return res, err - } - - var o ollamaResponse - err = json.Unmarshal(buf, &o) - if err != nil { - return res, err - } - - ylog.Debug("ollama chat response", "response", o.Response) - ylog.Debug("ollama chat usage", "prompt_tokens", o.PromptEvalCount, "completion_tokens", o.EvalCount) - - if o.Response != "" { - res.Choices[0].Message.Content = o.Response - - if len(req.Tools) > 0 { - toolCalls := parseToolCallsFromResponse(o.Response) - if len(toolCalls) > 0 { - res.Choices[0].FinishReason = openai.FinishReasonToolCalls - res.Choices[0].Message.ToolCalls = toolCalls - } - } - - res.Usage = openai.Usage{ - PromptTokens: o.PromptEvalCount, - CompletionTokens: o.EvalCount, - TotalTokens: o.PromptEvalCount + o.EvalCount, - } - } - - return res, nil -} - -type streamResponse struct { - reader io.ReadCloser - withTools bool - index int - res openai.ChatCompletionStreamResponse -} - -func (s *streamResponse) Recv() (openai.ChatCompletionStreamResponse, error) { - var buf []byte - var err error - - if s.withTools { - buf, err = io.ReadAll(s.reader) - if err != nil { - return s.res, err - } - } else { - buf = make([]byte, 1024) - n, err := s.reader.Read(buf) - if err != nil { - return s.res, err - } - buf = buf[:n] - } - - ylog.Debug("ollama chat stream", "delta", string(buf)) - if len(buf) == 0 { - s.reader.Close() - return s.res, io.EOF - } - - var o ollamaResponse - err = json.Unmarshal(buf, &o) - if err != nil { - return s.res, err - } - - ylog.Debug("ollama chat stream response", "response", o.Response, "done", o.Done) - - s.res.Choices[0].Index++ - s.res.Choices[0].Delta.Content = o.Response - if o.Done { - ylog.Debug("ollama chat stream usage", "prompt_tokens", o.PromptEvalCount, "completion_tokens", o.EvalCount) - - s.res.Choices[0].FinishReason = openai.FinishReasonStop - s.res.Usage = &openai.Usage{ - PromptTokens: o.PromptEvalCount, - CompletionTokens: o.EvalCount, - TotalTokens: o.PromptEvalCount + o.EvalCount, - } - - if s.withTools { - toolCalls := parseToolCallsFromResponse(o.Response) - if len(toolCalls) > 0 { - for i := 0; i < len(toolCalls); i++ { - index := i - toolCalls[index].Index = &index - } - s.res.Choices[0].FinishReason = openai.FinishReasonToolCalls - s.res.Choices[0].Delta.ToolCalls = toolCalls - } - } + if req.Model == "" { + req.Model = p.Model } - - return s.res, nil + return 
p.client.CreateChatCompletion(ctx, req) } // GetChatCompletionsStream implements ai.LLMProvider. func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (provider.ResponseRecver, error) { - urlPath, err := url.JoinPath(p.Endpoint, "api/generate") - if err != nil { - return nil, err - } - - body, err := makeOllamaRequestBody(req) - if err != nil { - return nil, err - } - - client := http.Client{} - request, err := http.NewRequestWithContext(ctx, http.MethodPost, urlPath, body) - if err != nil { - return nil, err - } - - request.Header.Set("Content-Type", "application/json") - request.Header.Set("Accept", "text/event-stream") - request.Header.Set("Cache-Control", "no-cache") - request.Header.Set("Connection", "keep-alive") - - resp, err := client.Do(request) - if err != nil { - return nil, err - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("ollama inference error: %s", resp.Status) + if req.Model == "" { + req.Model = p.Model } - - return &streamResponse{ - reader: resp.Body, - withTools: len(req.Tools) > 0, - index: 0, - res: openai.ChatCompletionStreamResponse{ - ID: "chatcmpl-" + id.New(29), - Model: req.Model, - Created: time.Now().Unix(), - Choices: []openai.ChatCompletionStreamChoice{ - { - Index: -1, - Delta: openai.ChatCompletionStreamChoiceDelta{ - Role: openai.ChatMessageRoleAssistant, - }, - }, - }, - }, - }, nil + return p.client.CreateChatCompletionStream(ctx, req) } // Name implements ai.LLMProvider. @@ -353,14 +44,26 @@ func (p *Provider) Name() string { } // NewProvider creates a new OllamaProvider -func NewProvider(endpoint string) *Provider { - if endpoint == "" { +func NewProvider(apiEndpoint string, model string) *Provider { + if apiEndpoint == "" { v, ok := os.LookupEnv("OLLAMA_API_ENDPOINT") if ok { - endpoint = v + apiEndpoint = v } else { - endpoint = "http://localhost:11434/" + apiEndpoint = "http://localhost:11434/v1" + } + } + if model == "" { + v, ok := os.LookupEnv("OLLAMA_MODEL") + if ok { + model = v } } - return &Provider{endpoint} + config := openai.DefaultConfig("ollama") + config.BaseURL = apiEndpoint + return &Provider{ + APIEndpoint: apiEndpoint, + Model: model, + client: openai.NewClientWithConfig(config), + } } diff --git a/pkg/bridge/ai/service.go b/pkg/bridge/ai/service.go index 3dd456bff..a288962d6 100644 --- a/pkg/bridge/ai/service.go +++ b/pkg/bridge/ai/service.go @@ -252,6 +252,11 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl toolCalls = []openai.ToolCall{} assistantMessage = openai.ChatCompletionMessage{} ) + rawReq := req + // ollama request patch + // WARN: this is a temporary solution for ollama provider + req = srv.patchOllamaRequest(req) + // 4. request first chat for getting tools if req.Stream { _, firstCallSpan := srv.option.Tracer.Start(reqCtx, "first_call_request") @@ -357,6 +362,56 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl toolCalls = append(toolCalls, resp.Choices[0].Message.ToolCalls...) 
 			assistantMessage = resp.Choices[0].Message
 			firstCallSpan.End()
+		} else if rawReq.Stream {
+			// if the raw request is in stream mode, return a stream response
+			// WARN: this is a temporary solution for the ollama provider
+			flusher := eventFlusher(w)
+			// choices
+			choices := make([]openai.ChatCompletionStreamChoice, 0)
+			for _, choice := range resp.Choices {
+				delta := openai.ChatCompletionStreamChoiceDelta{
+					Content:      choice.Message.Content,
+					Role:         choice.Message.Role,
+					FunctionCall: choice.Message.FunctionCall,
+					ToolCalls:    choice.Message.ToolCalls,
+				}
+				choices = append(choices, openai.ChatCompletionStreamChoice{
+					Index:        choice.Index,
+					Delta:        delta,
+					FinishReason: choice.FinishReason,
+					// ContentFilterResults
+				})
+			}
+			// chunk response
+			streamRes := openai.ChatCompletionStreamResponse{
+				ID:                resp.ID,
+				Object:            "chat.completion.chunk",
+				Created:           resp.Created,
+				Model:             resp.Model,
+				Choices:           choices,
+				SystemFingerprint: resp.SystemFingerprint,
+				// PromptAnnotations:
+				// PromptFilterResults:
+			}
+			writeStreamEvent(w, flusher, streamRes)
+			// usage
+			if req.StreamOptions != nil && req.StreamOptions.IncludeUsage {
+				streamRes = openai.ChatCompletionStreamResponse{
+					ID:                resp.ID,
+					Object:            "chat.completion.chunk",
+					Created:           resp.Created,
+					Model:             resp.Model,
+					SystemFingerprint: resp.SystemFingerprint,
+					Usage: &openai.Usage{
+						PromptTokens:     promptUsage,
+						CompletionTokens: completionUsage,
+						TotalTokens:      totalUsage,
+					},
+				}
+				writeStreamEvent(w, flusher, streamRes)
+			}
+			// done
+			return writeStreamDone(w, flusher)
 		} else {
 			w.Header().Set("Content-Type", "application/json")
 			json.NewEncoder(w).Encode(resp)
@@ -387,6 +442,8 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
 		if srv.provider.Name() != "anthropic" {
 			req.Tools = nil // reset tools field
 		}
+		// restore the original request stream field
+		req.Stream = rawReq.Stream
 
 		srv.logger.Debug(" #2 second call", "request", fmt.Sprintf("%+v", req))
@@ -643,3 +700,24 @@ func recordTTFT(ctx context.Context, tracer trace.Tracer) {
 	span.End()
 	time.Sleep(time.Millisecond)
 }
+
+// patchOllamaRequest patches the request for the ollama provider (ollama does not support function calling in stream mode)
+func (srv *Service) patchOllamaRequest(req openai.ChatCompletionRequest) openai.ChatCompletionRequest {
+	ylog.Debug("before request",
+		"stream", req.Stream,
+		"provider", srv.provider.Name(),
+		fmt.Sprintf("tools[%d]", len(req.Tools)), req.Tools,
+	)
+	if !req.Stream {
+		return req
+	}
+	if srv.provider.Name() == "ollama" && len(req.Tools) > 0 {
+		req.Stream = false
+	}
+	ylog.Debug("patch request",
+		"stream", req.Stream,
+		"provider", srv.provider.Name(),
+		fmt.Sprintf("tools[%d]", len(req.Tools)), req.Tools,
+	)
+	return req
+}
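
Note: the snippet below is not part of the patch. It is a minimal sketch of how the refactored provider can be exercised on its own, assuming a local Ollama instance exposing the OpenAI-compatible API at `http://localhost:11434/v1` with the `llama3.1` model already pulled (`ollama run llama3.1`), i.e. the same defaults used in `example/10-ai/zipper.yaml`. It drives the `NewProvider` / `GetChatCompletions` path that `cli/serve.go` wires up for the zipper.

```go
package main

import (
	"context"
	"fmt"

	openai "github.com/sashabaranov/go-openai"

	"github.com/yomorun/yomo/pkg/bridge/ai/provider/ollama"
)

func main() {
	// Endpoint and model are assumptions mirroring the zipper.yaml example;
	// both fall back to OLLAMA_API_ENDPOINT / OLLAMA_MODEL when left empty.
	p := ollama.NewProvider("http://localhost:11434/v1", "llama3.1")

	// req.Model is left empty on purpose: the provider fills in its default model.
	req := openai.ChatCompletionRequest{
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Say hello in one sentence."},
		},
	}

	// The third argument (metadata.M) is unused by this provider, so nil is fine here.
	resp, err := p.GetChatCompletions(context.Background(), req, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(resp.Choices[0].Message.Content)
}
```

Because the provider now defers to the stock go-openai client against Ollama's `/v1` compatibility API, the hand-rolled Mistral prompt templating and response parsing could be removed wholesale, which is where most of the 338 deleted lines come from.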