From 89162f4e2117b353f46633ec02e778cfee9d95c8 Mon Sep 17 00:00:00 2001 From: JuHyung-Son Date: Thu, 6 Jun 2024 22:04:04 +0900 Subject: [PATCH] feat: add orgId on openai backend Signed-off-by: JuHyung-Son --- pkg/ai/azureopenai.go | 12 +++++++++--- pkg/ai/iai.go | 6 ++++++ pkg/ai/openai.go | 14 ++++++++++---- 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/pkg/ai/azureopenai.go b/pkg/ai/azureopenai.go index ffeda18d3d..292bfbd1d9 100644 --- a/pkg/ai/azureopenai.go +++ b/pkg/ai/azureopenai.go @@ -14,9 +14,10 @@ const azureAIClientName = "azureopenai" type AzureAIClient struct { nopCloser - client *openai.Client - model string - temperature float32 + client *openai.Client + model string + temperature float32 + organizationId string } func (c *AzureAIClient) Configure(config IAIConfig) error { @@ -25,6 +26,7 @@ func (c *AzureAIClient) Configure(config IAIConfig) error { engine := config.GetEngine() proxyEndpoint := config.GetProxyEndpoint() defaultConfig := openai.DefaultAzureConfig(token, baseURL) + orgId := config.GetOrganizationId() defaultConfig.AzureModelMapperFunc = func(model string) string { // If you use a deployment name different from the model name, you can customize the AzureModelMapperFunc function @@ -48,6 +50,10 @@ func (c *AzureAIClient) Configure(config IAIConfig) error { Transport: transport, } } + if orgId != "" { + defaultConfig.OrgID = orgId + } + client := openai.NewClientWithConfig(defaultConfig) if client == nil { return errors.New("error creating Azure OpenAI client") diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go index 7559749e47..08caa02079 100644 --- a/pkg/ai/iai.go +++ b/pkg/ai/iai.go @@ -78,6 +78,7 @@ type IAIConfig interface { GetMaxTokens() int GetProviderId() string GetCompartmentId() string + GetOrganizationId() string } func NewClient(provider string) IAI { @@ -111,6 +112,7 @@ type AIProvider struct { TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"` TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"` MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"` + OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"` } func (p *AIProvider) GetBaseURL() string { @@ -164,6 +166,10 @@ func (p *AIProvider) GetCompartmentId() string { return p.CompartmentId } +func (p *AIProvider) GetOrganizationId() string { + return p.OrganizationId +} + var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"} func NeedPassword(backend string) bool { diff --git a/pkg/ai/openai.go b/pkg/ai/openai.go index ff032f7391..8ed0f0ced8 100644 --- a/pkg/ai/openai.go +++ b/pkg/ai/openai.go @@ -27,10 +27,11 @@ const openAIClientName = "openai" type OpenAIClient struct { nopCloser - client *openai.Client - model string - temperature float32 - topP float32 + client *openai.Client + model string + temperature float32 + topP float32 + organizationId string } const ( @@ -43,6 +44,7 @@ const ( func (c *OpenAIClient) Configure(config IAIConfig) error { token := config.GetPassword() defaultConfig := openai.DefaultConfig(token) + orgId := config.GetOrganizationId() proxyEndpoint := config.GetProxyEndpoint() baseURL := config.GetBaseURL() @@ -64,6 +66,10 @@ func (c *OpenAIClient) Configure(config IAIConfig) error { } } + if orgId != "" { + defaultConfig.OrgID = orgId + } + client := openai.NewClientWithConfig(defaultConfig) if client == nil { return errors.New("error creating OpenAI client")