diff --git a/go.mod b/go.mod index b7074e273f..1d578bb5b6 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/fatih/color v1.15.0 github.com/magiconair/properties v1.8.7 github.com/mittwald/go-helm-client v0.12.1 - github.com/sashabaranov/go-openai v1.9.3 + github.com/sashabaranov/go-openai v1.12.0 github.com/schollz/progressbar/v3 v3.13.1 github.com/spf13/cobra v1.7.0 github.com/spf13/viper v1.16.0 diff --git a/go.sum b/go.sum index cab4a0fcf7..ff8d8e35d1 100644 --- a/go.sum +++ b/go.sum @@ -1008,6 +1008,8 @@ github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= github.com/sashabaranov/go-openai v1.9.3 h1:uNak3Rn5pPsKRs9bdT7RqRZEyej/zdZOEI2/8wvrFtM= github.com/sashabaranov/go-openai v1.9.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= +github.com/sashabaranov/go-openai v1.12.0 h1:aRNHH0gtVfrpIaEolD0sWrLLRnYQNK4cH/bIAHwL8Rk= +github.com/sashabaranov/go-openai v1.12.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/schollz/progressbar/v3 v3.13.1 h1:o8rySDYiQ59Mwzy2FELeHY5ZARXZTVJC7iHD6PEFUiE= github.com/schollz/progressbar/v3 v3.13.1/go.mod h1:xvrbki8kfT1fzWzBT/UZd9L6GA+jdL7HAgq2RFnO6fQ= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= diff --git a/pkg/ai/azureopenai.go b/pkg/ai/azureopenai.go index aff6708a86..0573caf72e 100644 --- a/pkg/ai/azureopenai.go +++ b/pkg/ai/azureopenai.go @@ -25,7 +25,16 @@ func (c *AzureAIClient) Configure(config IAIConfig, lang string) error { token := config.GetPassword() baseURL := config.GetBaseURL() engine := config.GetEngine() - defaultConfig := openai.DefaultAzureConfig(token, baseURL, engine) + defaultConfig := openai.DefaultAzureConfig(token, baseURL) + + defaultConfig.AzureModelMapperFunc = func(model string) string { + // If you use a deployment name different from the model name, you can customize the AzureModelMapperFunc function + azureModelMapping := map[string]string{ + model: engine, + } + return azureModelMapping[model] + + } client := openai.NewClientWithConfig(defaultConfig) if client == nil { return errors.New("error creating Azure OpenAI client") @@ -42,7 +51,7 @@ func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string, prompt Model: c.model, Messages: []openai.ChatCompletionMessage{ { - Role: "user", + Role: openai.ChatMessageRoleUser, Content: fmt.Sprintf(default_prompt, c.language, prompt), }, },