Skip to content

Commit

Permalink
feat: add OpenAI provider (#735)
Browse files Browse the repository at this point in the history
## Related Docs

- https://platform.openai.com/docs/api-reference/chat/create

## How to use this provider

```yaml
bridge:
  ai:
    server:
      addr: localhost:8000
      provider: openai

    providers:
      openai:
        api_key: <your-api-key>
        model: <gpt-3.5-turbo-1106>
```

Co-authored-by: C.C <fanweixiao@gmail.com>
  • Loading branch information
venjiang and fanweixiao authored Feb 28, 2024
1 parent 8d8e24a commit cad974b
Show file tree
Hide file tree
Showing 4 changed files with 277 additions and 2 deletions.
2 changes: 1 addition & 1 deletion ai/function_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func (fco *FunctionCall) Bytes() ([]byte, error) {

// FromBytes deserialize the FunctionCallObject from the given []byte
func (fco *FunctionCall) FromBytes(b []byte) error {
var obj = &FunctionCall{}
obj := &FunctionCall{}
err := json.Unmarshal(b, &obj)
if err != nil {
return err
Expand Down
5 changes: 5 additions & 0 deletions cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (

"github.com/yomorun/yomo/pkg/bridge/ai"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/azopenai"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/openai"
)

// serveCmd represents the serve command
Expand Down Expand Up @@ -133,6 +134,10 @@ func registerAIProvider(aiConfig *ai.Config) {
log.InfoStatusEvent(os.Stdout, "register [%s] AI provider", name)
// TODO: register other providers
}
// register the OpenAI provider
if name == "openai" {
ai.RegisterProvider(openai.NewProvider(provider["api_key"], provider["model"]))
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion example/10-ai/zipper.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ bridge:

openai:
api_key:
api_endpoint:
model:

gemini:
api_key:
270 changes: 270 additions & 0 deletions pkg/bridge/ai/provider/openai/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
package openai

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"sync"

_ "github.com/joho/godotenv/autoload"
"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/core/ylog"
)

const APIEndpoint = "https://api.openai.com/v1/chat/completions"

var fns sync.Map

// Message
type ChatCompletionMessage struct {
Role string `json:"role"`
Content string `json:"content"`
// - https://github.com/openai/openai-python/blob/main/chatml.md
// - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
Name string `json:"name,omitempty"`
// MultiContent []ChatMessagePart
// For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls.
ToolCalls []ai.ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}

// RequestBody is the request body
type ReqBody struct {
Model string `json:"model"`
Messages []ChatCompletionMessage `json:"messages"`
Tools []ai.ToolCall `json:"tools"` // chatCompletionTool
// ToolChoice string `json:"tool_choice"` // chatCompletionFunction
}

// Resp is the response body
type RespBody struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
Model string `json:"model"`
Choices []RespChoice `json:"choices"`
Usage RespUsage `json:"usage"`
SystemFingerprint string `json:"system_fingerprint"`
}

// RespMessage is the message in Response
type RespMessage struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ai.ToolCall `json:"tool_calls"`
}

// RespChoice is used to indicate the choice in Response by `FinishReason`
type RespChoice struct {
FinishReason string `json:"finish_reason"`
Index int `json:"index"`
Message ChatCompletionMessage `json:"message"`
}

// RespUsage is the token usage in Response
type RespUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}

// OpenAIProvider is the provider for OpenAI
type OpenAIProvider struct {
// APIKey is the API key for OpenAI
APIKey string
// Model is the model for OpenAI
// eg. "gpt-3.5-turbo-1106", "gpt-4-turbo-preview", "gpt-4-vision-preview", "gpt-4"
Model string
}

type connectedFn struct {
connID uint64
tag uint32
tc ai.ToolCall
}

func init() {
fns = sync.Map{}
}

// NewProvider creates a new OpenAIProvider
func NewProvider(apiKey string, model string) *OpenAIProvider {
if apiKey == "" {
apiKey = os.Getenv("OPENAI_API_KEY")
}
if model == "" {
model = os.Getenv("OPENAI_MODEL")
}
ylog.Debug("new openai provider", "api_endpoint", APIEndpoint, "api_key", apiKey, "model", model)
return &OpenAIProvider{
APIKey: apiKey,
Model: model,
}
}

// Name returns the name of the provider
func (p *OpenAIProvider) Name() string {
return "openai"
}

// GetChatCompletions get chat completions for ai service
func (p *OpenAIProvider) GetChatCompletions(userInstruction string) (*ai.InvokeResponse, error) {
toolCalls, ok := hasToolCalls()
if !ok {
ylog.Error(ai.ErrNoFunctionCall.Error())
return &ai.InvokeResponse{Content: "no toolcalls"}, ai.ErrNoFunctionCall
}

// messages
messages := []ChatCompletionMessage{
{Role: "system", Content: `You are a very helpful assistant. Your job is to choose the best possible action to solve the user question or task. Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous. If you don't know the answer, stop the conversation by saying "no func call".`},
{Role: "user", Content: userInstruction},
}

body := ReqBody{Model: p.Model, Messages: messages, Tools: toolCalls}
ylog.Debug("request", "tools", len(toolCalls), "messages", messages)

jsonBody, err := json.Marshal(body)
if err != nil {
return nil, err
}

req, err := http.NewRequest("POST", APIEndpoint, bytes.NewBuffer(jsonBody))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
// OpenAI authentication
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", p.APIKey))

client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
ylog.Debug("response", "body", respBody)

// ylog.Info("response body", "body", string(respBody))
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("ai response status code is %d", resp.StatusCode)
}

var respBodyStruct RespBody
err = json.Unmarshal(respBody, &respBodyStruct)
if err != nil {
return nil, err
}
// TODO: record usage
// usage := respBodyStruct.Usage
// log.Printf("Token Usage: %+v\n", usage)

choice := respBodyStruct.Choices[0]
ylog.Debug(">>finish_reason", "reason", choice.FinishReason)

calls := respBodyStruct.Choices[0].Message.ToolCalls
content := respBodyStruct.Choices[0].Message.Content

ylog.Debug("--response calls", "calls", calls)

result := &ai.InvokeResponse{}
if len(calls) == 0 {
result.Content = content
return result, ai.ErrNoFunctionCall
}

// functions may be more than one
// slog.Info("tool calls", "calls", calls, "mapTools", mapTools)
for _, call := range calls {
fns.Range(func(_, value any) bool {
fn := value.(*connectedFn)
if fn.tc.Equal(&call) {
// Use toolCalls because tool_id is required in the following llm request
if result.ToolCalls == nil {
result.ToolCalls = make(map[uint32][]*ai.ToolCall)
}
// Create a new variable to hold the current call
currentCall := call
result.ToolCalls[fn.tag] = append(result.ToolCalls[fn.tag], &currentCall)
}
return true
})
}

// sfn maybe disconnected, so we need to check if there is any function call
if len(result.ToolCalls) == 0 {
return nil, ai.ErrNoFunctionCall
}
return result, nil
}

// RegisterFunction register function
func (p *OpenAIProvider) RegisterFunction(tag uint32, functionDefinition *ai.FunctionDefinition, connID uint64) error {
fns.Store(connID, &connectedFn{
connID: connID,
tag: tag,
tc: ai.ToolCall{
Type: "function",
Function: functionDefinition,
},
})

return nil
}

// UnregisterFunction unregister function
// Be careful: a function can have multiple instances, remove the offline instance only.
func (p *OpenAIProvider) UnregisterFunction(name string, connID uint64) error {
fns.Delete(connID)
return nil
}

// ListToolCalls list tool functions
func (p *OpenAIProvider) ListToolCalls() (map[uint32]ai.ToolCall, error) {
tmp := make(map[uint32]ai.ToolCall)
fns.Range(func(_, value any) bool {
fn := value.(*connectedFn)
tmp[fn.tag] = fn.tc
return true
})

return tmp, nil
}

// GetOverview get overview for ai service
func (p *OpenAIProvider) GetOverview() (*ai.OverviewResponse, error) {
result := &ai.OverviewResponse{
Functions: make(map[uint32]*ai.FunctionDefinition),
}
_, ok := hasToolCalls()
if !ok {
return result, nil
}

fns.Range(func(_, value any) bool {
fn := value.(*connectedFn)
result.Functions[fn.tag] = fn.tc.Function
return true
})

return result, nil
}

// hasToolCalls check if there are tool calls
func hasToolCalls() ([]ai.ToolCall, bool) {
toolCalls := make([]ai.ToolCall, 0)
fns.Range(func(_, value any) bool {
fn := value.(*connectedFn)
toolCalls = append(toolCalls, fn.tc)
return true
})
return toolCalls, len(toolCalls) > 0
}

0 comments on commit cad974b

Please sign in to comment.