diff --git a/pkg/bridge/ai/provider/azopenai/provider.go b/pkg/bridge/ai/provider/azopenai/provider.go index 03a3cc001..a61c7f14c 100644 --- a/pkg/bridge/ai/provider/azopenai/provider.go +++ b/pkg/bridge/ai/provider/azopenai/provider.go @@ -10,6 +10,7 @@ import ( "os" "sync" + // automatically load .env file _ "github.com/joho/godotenv/autoload" "github.com/yomorun/yomo/ai" @@ -78,10 +79,6 @@ type connectedFn struct { tc ai.ToolCall } -func init() { - fns = sync.Map{} -} - // NewProvider creates a new AzureOpenAIProvider func NewProvider(apiKey string, apiEndpoint string, deploymentID string, apiVersion string) *AzureOpenAIProvider { if apiKey == "" { @@ -201,7 +198,7 @@ func (p *AzureOpenAIProvider) GetChatCompletions(userInstruction string) (*ai.In // functions may be more than one // slog.Info("tool calls", "calls", calls, "mapTools", mapTools) for _, call := range calls { - fns.Range(func(key, value interface{}) bool { + fns.Range(func(_, value interface{}) bool { fn := value.(*connectedFn) if fn.tc.Equal(&call) { // Use toolCalls because tool_id is required in the following llm request diff --git a/pkg/bridge/ai/provider/azopenai/provider_test.go b/pkg/bridge/ai/provider/azopenai/provider_test.go new file mode 100644 index 000000000..5cdcfcc61 --- /dev/null +++ b/pkg/bridge/ai/provider/azopenai/provider_test.go @@ -0,0 +1,104 @@ +package azopenai + +import ( + "os" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/yomorun/yomo/ai" +) + +func TestNewProvider(t *testing.T) { + // Set environment variables for testing + os.Setenv("AZURE_OPENAI_API_KEY", "test_api_key") + os.Setenv("AZURE_OPENAI_API_ENDPOINT", "test_api_endpoint") + os.Setenv("AZURE_OPENAI_DEPLOYMENT_ID", "test_deployment_id") + os.Setenv("AZURE_OPENAI_API_VERSION", "test_api_version") + + provider := NewProvider("", "", "", "") + + assert.Equal(t, "test_api_key", provider.APIKey) + assert.Equal(t, "test_api_endpoint", provider.APIEndpoint) + assert.Equal(t, "test_deployment_id", provider.DeploymentID) + assert.Equal(t, "test_api_version", provider.APIVersion) +} + +func TestAzureOpenAIProvider_Name(t *testing.T) { + provider := &AzureOpenAIProvider{} + + name := provider.Name() + + assert.Equal(t, "azopenai", name) +} + +func TestAzureOpenAIProvider_RegisterFunction(t *testing.T) { + fns = sync.Map{} + provider := &AzureOpenAIProvider{} + tag := uint32(66) + functionDefinition := &ai.FunctionDefinition{ + Name: "TestFunction", + } + connID := uint64(88) + + err := provider.RegisterFunction(tag, functionDefinition, connID) + assert.NoError(t, err) + + fn, ok := fns.Load(connID) + assert.True(t, ok) + assert.Equal(t, connID, fn.(*connectedFn).connID) + assert.Equal(t, tag, fn.(*connectedFn).tag) + assert.Equal(t, "function", fn.(*connectedFn).tc.Type) + assert.Equal(t, functionDefinition.Name, fn.(*connectedFn).tc.Function.Name) + +} + +func TestAzureOpenAIProvider_UnregisterFunction(t *testing.T) { + provider := &AzureOpenAIProvider{} + err := provider.UnregisterFunction("", 1) + assert.NoError(t, err) + _, ok := fns.Load(1) + assert.False(t, ok) +} + +func TestAzureOpenAIProvider_ListToolCalls(t *testing.T) { + fns = sync.Map{} + provider := &AzureOpenAIProvider{} + + // Add a connectedFn to fns for testing + fns.Store(1, &connectedFn{ + tag: 0x16, + tc: ai.ToolCall{ + Type: "function", + Function: &ai.FunctionDefinition{ + Name: "TestFunction", + }, + }, + }) + + toolCalls, err := provider.ListToolCalls() + + assert.NoError(t, err) + assert.NotNil(t, toolCalls[0x16]) + assert.Equal(t, toolCalls[0x16].Function.Name, "TestFunction") +} + +func TestAzureOpenAIProvider_GetOverview(t *testing.T) { + fns = sync.Map{} + provider := &AzureOpenAIProvider{} + + // Add a connectedFn to fns for testing + fns.Store(1, &connectedFn{ + tag: 0x16, + tc: ai.ToolCall{Function: &ai.FunctionDefinition{ + Name: "TestFunction", + }}, + }) + + overview, err := provider.GetOverview() + + assert.NoError(t, err) + assert.NotNil(t, overview) + assert.NotNil(t, overview.Functions[0x16]) + assert.Equal(t, overview.Functions[0x16].Name, "TestFunction") +} diff --git a/pkg/bridge/ai/provider/openai/provider.go b/pkg/bridge/ai/provider/openai/provider.go index 017d5a4af..b06dceb48 100644 --- a/pkg/bridge/ai/provider/openai/provider.go +++ b/pkg/bridge/ai/provider/openai/provider.go @@ -1,3 +1,4 @@ +// Package openai is the OpenAI llm provider package openai import ( @@ -9,45 +10,58 @@ import ( "os" "sync" + // automatically load .env file _ "github.com/joho/godotenv/autoload" "github.com/yomorun/yomo/ai" "github.com/yomorun/yomo/core/ylog" ) +// APIEndpoint is the endpoint for OpenAI const APIEndpoint = "https://api.openai.com/v1/chat/completions" var fns sync.Map -// Message +// ChatCompletionMessage describes `messages` for /chat/completions type ChatCompletionMessage struct { - Role string `json:"role"` + // Role is the messages author + Role string `json:"role"` + // Content of the message Content string `json:"content"` - // - https://github.com/openai/openai-python/blob/main/chatml.md - // - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + // Name describes participant, provides the model information to differentiate + // between participants of the same role. Name string `json:"name,omitempty"` - // MultiContent []ChatMessagePart - // For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls. - ToolCalls []ai.ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + // ToolCalls describes the tool calls generated by the model. + ToolCalls []ai.ToolCall `json:"tool_calls,omitempty"` + // ToolCallID is the ID of the tool call + ToolCallID string `json:"tool_call_id,omitempty"` } -// RequestBody is the request body +// ReqBody is the request body type ReqBody struct { - Model string `json:"model"` + // Model describes the ID of the model to use for the completion. + Model string `json:"model"` + // Messages describes the messages in the conversation. Messages []ChatCompletionMessage `json:"messages"` - Tools []ai.ToolCall `json:"tools"` // chatCompletionTool - // ToolChoice string `json:"tool_choice"` // chatCompletionFunction + // Tools describes the tool calls generated by the model. + Tools []ai.ToolCall `json:"tools"` // chatCompletionTool } -// Resp is the response body +// RespBody is the response body type RespBody struct { - ID string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - Model string `json:"model"` - Choices []RespChoice `json:"choices"` - Usage RespUsage `json:"usage"` - SystemFingerprint string `json:"system_fingerprint"` + // ID is the unique identifier for the chat completion. + ID string `json:"id"` + // Object describes the object type, it is always "chat.completion". + Object string `json:"object"` + // Created describes the timestamp when the chat completion was created. + Created int `json:"created"` + // Model describes the model used for the chat completion. + Model string `json:"model"` + // Choices describes the choices made by the model, can more than one if `n`>1 + Choices []RespChoice `json:"choices"` + // Usage describes the token usage statistics for the chat completion request. + Usage RespUsage `json:"usage"` + // SystemFingerprint describes the system fingerprint of the chat completion. + SystemFingerprint string `json:"system_fingerprint"` } // RespMessage is the message in Response @@ -222,7 +236,7 @@ func (p *OpenAIProvider) RegisterFunction(tag uint32, functionDefinition *ai.Fun // UnregisterFunction unregister function // Be careful: a function can have multiple instances, remove the offline instance only. -func (p *OpenAIProvider) UnregisterFunction(name string, connID uint64) error { +func (p *OpenAIProvider) UnregisterFunction(_ string, connID uint64) error { fns.Delete(connID) return nil } diff --git a/pkg/bridge/ai/provider/openai/provider_test.go b/pkg/bridge/ai/provider/openai/provider_test.go new file mode 100644 index 000000000..2c92f3bef --- /dev/null +++ b/pkg/bridge/ai/provider/openai/provider_test.go @@ -0,0 +1,92 @@ +package openai + +import ( + "os" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/yomorun/yomo/ai" +) + +func TestOpenAIProvider_RegisterFunction(t *testing.T) { + fns = sync.Map{} + provider := &OpenAIProvider{} + tag := uint32(66) + functionDefinition := &ai.FunctionDefinition{ + Name: "TestFunction", + } + connID := uint64(88) + + err := provider.RegisterFunction(tag, functionDefinition, connID) + assert.NoError(t, err) + + fn, ok := fns.Load(connID) + assert.True(t, ok) + assert.Equal(t, connID, fn.(*connectedFn).connID) + assert.Equal(t, tag, fn.(*connectedFn).tag) + assert.Equal(t, "function", fn.(*connectedFn).tc.Type) + assert.Equal(t, functionDefinition.Name, fn.(*connectedFn).tc.Function.Name) +} + +func TestOpenAIProvider_UnregisterFunction(t *testing.T) { + provider := &OpenAIProvider{} + connID := uint64(1) + + // Assuming a function is already registered with connID + err := provider.UnregisterFunction("", connID) + assert.NoError(t, err) + + _, ok := fns.Load(connID) + assert.False(t, ok) +} + +func TestOpenAIProvider_ListToolCalls(t *testing.T) { + provider := &OpenAIProvider{} + + // Assuming some functions are already registered + toolCalls, err := provider.ListToolCalls() + assert.NoError(t, err) + + // Replace with your own checks + assert.NotEmpty(t, toolCalls) +} + +func TestOpenAIProvider_GetOverview(t *testing.T) { + provider := &OpenAIProvider{} + + // Assuming some functions are already registered + overview, err := provider.GetOverview() + assert.NoError(t, err) + + // Replace with your own checks + assert.NotEmpty(t, overview.Functions) +} + +func TestHasToolCalls(t *testing.T) { + // Assuming some functions are already registered + toolCalls, hasCalls := hasToolCalls() + + // Replace with your own checks + assert.True(t, hasCalls) + assert.NotEmpty(t, toolCalls) +} + +func TestNewProvider(t *testing.T) { + // Set environment variables for testing + os.Setenv("OPENAI_API_KEY", "test_api_key") + os.Setenv("OPENAI_MODEL", "test_model") + + provider := NewProvider("", "") + + assert.Equal(t, "test_api_key", provider.APIKey) + assert.Equal(t, "test_model", provider.Model) +} + +func TestOpenAIProvider_Name(t *testing.T) { + provider := &OpenAIProvider{} + + name := provider.Name() + + assert.Equal(t, "openai", name) +}