From 4867d39c66a6c16906cd769a2055dea9f66f1ccb Mon Sep 17 00:00:00 2001 From: JuHyung Son Date: Fri, 14 Jun 2024 16:39:56 +0900 Subject: [PATCH] feat: support openai organization Id (#1133) * feat: add organization flag Signed-off-by: JuHyung-Son * feat: add orgId on openai backend Signed-off-by: JuHyung-Son --------- Signed-off-by: JuHyung-Son Co-authored-by: Alex Jones --- cmd/auth/add.go | 3 ++ cmd/auth/auth.go | 1 + cmd/auth/update.go | 80 +++++++++++++++++++++++-------------------- pkg/ai/azureopenai.go | 12 +++++-- pkg/ai/iai.go | 6 ++++ pkg/ai/openai.go | 14 +++++--- 6 files changed, 72 insertions(+), 44 deletions(-) diff --git a/cmd/auth/add.go b/cmd/auth/add.go index f93f521d44..5464eadf9f 100644 --- a/cmd/auth/add.go +++ b/cmd/auth/add.go @@ -131,6 +131,7 @@ var addCmd = &cobra.Command{ TopP: topP, TopK: topK, MaxTokens: maxTokens, + OrganizationId: organizationId, } if providerIndex == -1 { @@ -176,4 +177,6 @@ func init() { addCmd.Flags().StringVarP(&providerId, "providerId", "i", "", "Provider specific ID for e.g. project (only for googlevertexai backend)") //add flag for OCI Compartment ID addCmd.Flags().StringVarP(&compartmentId, "compartmentId", "k", "", "Compartment ID for generative AI model (only for oci backend)") + // add flag for openai organization + addCmd.Flags().StringVarP(&organizationId, "organizationId", "o", "", "OpenAI or AzureOpenAI Organization ID (only for openai and azureopenai backend)") } diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go index 197ee7ac12..c8f4e209e9 100644 --- a/cmd/auth/auth.go +++ b/cmd/auth/auth.go @@ -32,6 +32,7 @@ var ( topP float32 topK int32 maxTokens int + organizationId string ) var configAI ai.AIConfiguration diff --git a/cmd/auth/update.go b/cmd/auth/update.go index eb9a0e79ef..2030462d8f 100644 --- a/cmd/auth/update.go +++ b/cmd/auth/update.go @@ -26,13 +26,20 @@ var updateCmd = &cobra.Command{ Use: "update", Short: "Update a backend provider", Long: "The command to update an AI backend provider", - Args: cobra.ExactArgs(1), + // Args: cobra.ExactArgs(1), PreRun: func(cmd *cobra.Command, args []string) { backend, _ := cmd.Flags().GetString("backend") if strings.ToLower(backend) == "azureopenai" { _ = cmd.MarkFlagRequired("engine") _ = cmd.MarkFlagRequired("baseurl") } + organizationId, _ := cmd.Flags().GetString("organizationId") + if strings.ToLower(backend) != "azureopenai" && strings.ToLower(backend) != "openai" { + if organizationId != "" { + color.Red("Error: organizationId must be empty for backends other than azureopenai or openai.") + os.Exit(1) + } + } }, Run: func(cmd *cobra.Command, args []string) { @@ -43,50 +50,47 @@ var updateCmd = &cobra.Command{ os.Exit(1) } - inputBackends := strings.Split(args[0], ",") + backend, _ := cmd.Flags().GetString("backend") - if len(inputBackends) == 0 { - color.Red("Error: backend must be set.") - os.Exit(1) - } if temperature > 1.0 || temperature < 0.0 { color.Red("Error: temperature ranges from 0 to 1.") os.Exit(1) } - for _, b := range inputBackends { - foundBackend := false - for i, provider := range configAI.Providers { - if b == provider.Name { - foundBackend = true - if backend != "" { - configAI.Providers[i].Name = backend - color.Blue("Backend name updated successfully") - } - if model != "" { - configAI.Providers[i].Model = model - color.Blue("Model updated successfully") - } - if password != "" { - configAI.Providers[i].Password = password - color.Blue("Password updated successfully") - } - if baseURL != "" { - configAI.Providers[i].BaseURL = baseURL - color.Blue("Base URL updated successfully") - } - if engine != "" { - configAI.Providers[i].Engine = engine - } - configAI.Providers[i].Temperature = temperature - color.Green("%s updated in the AI backend provider list", b) + foundBackend := false + for i, provider := range configAI.Providers { + if backend == provider.Name { + foundBackend = true + if backend != "" { + configAI.Providers[i].Name = backend + color.Blue("Backend name updated successfully") } + if model != "" { + configAI.Providers[i].Model = model + color.Blue("Model updated successfully") + } + if password != "" { + configAI.Providers[i].Password = password + color.Blue("Password updated successfully") + } + if baseURL != "" { + configAI.Providers[i].BaseURL = baseURL + color.Blue("Base URL updated successfully") + } + if engine != "" { + configAI.Providers[i].Engine = engine + } + if organizationId != "" { + configAI.Providers[i].OrganizationId = organizationId + color.Blue("Organization Id updated successfully") + } + configAI.Providers[i].Temperature = temperature + color.Green("%s updated in the AI backend provider list", backend) } - if !foundBackend { - color.Red("Error: %s does not exist in configuration file. Please use k8sgpt auth new.", args[0]) - os.Exit(1) - } - + } + if !foundBackend { + color.Red("Error: %s does not exist in configuration file. Please use k8sgpt auth new.", args[0]) + os.Exit(1) } viper.Set("ai", configAI) @@ -110,4 +114,6 @@ func init() { updateCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)") // update flag for azure open ai engine/deployment name updateCmd.Flags().StringVarP(&engine, "engine", "e", "", "Update Azure AI deployment name") + // update flag for organizationId + updateCmd.Flags().StringVarP(&organizationId, "organizationId", "o", "", "Update OpenAI or Azure organization Id") } diff --git a/pkg/ai/azureopenai.go b/pkg/ai/azureopenai.go index ffeda18d3d..292bfbd1d9 100644 --- a/pkg/ai/azureopenai.go +++ b/pkg/ai/azureopenai.go @@ -14,9 +14,10 @@ const azureAIClientName = "azureopenai" type AzureAIClient struct { nopCloser - client *openai.Client - model string - temperature float32 + client *openai.Client + model string + temperature float32 + organizationId string } func (c *AzureAIClient) Configure(config IAIConfig) error { @@ -25,6 +26,7 @@ func (c *AzureAIClient) Configure(config IAIConfig) error { engine := config.GetEngine() proxyEndpoint := config.GetProxyEndpoint() defaultConfig := openai.DefaultAzureConfig(token, baseURL) + orgId := config.GetOrganizationId() defaultConfig.AzureModelMapperFunc = func(model string) string { // If you use a deployment name different from the model name, you can customize the AzureModelMapperFunc function @@ -48,6 +50,10 @@ func (c *AzureAIClient) Configure(config IAIConfig) error { Transport: transport, } } + if orgId != "" { + defaultConfig.OrgID = orgId + } + client := openai.NewClientWithConfig(defaultConfig) if client == nil { return errors.New("error creating Azure OpenAI client") diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go index 7559749e47..08caa02079 100644 --- a/pkg/ai/iai.go +++ b/pkg/ai/iai.go @@ -78,6 +78,7 @@ type IAIConfig interface { GetMaxTokens() int GetProviderId() string GetCompartmentId() string + GetOrganizationId() string } func NewClient(provider string) IAI { @@ -111,6 +112,7 @@ type AIProvider struct { TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"` TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"` MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"` + OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"` } func (p *AIProvider) GetBaseURL() string { @@ -164,6 +166,10 @@ func (p *AIProvider) GetCompartmentId() string { return p.CompartmentId } +func (p *AIProvider) GetOrganizationId() string { + return p.OrganizationId +} + var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"} func NeedPassword(backend string) bool { diff --git a/pkg/ai/openai.go b/pkg/ai/openai.go index ff032f7391..8ed0f0ced8 100644 --- a/pkg/ai/openai.go +++ b/pkg/ai/openai.go @@ -27,10 +27,11 @@ const openAIClientName = "openai" type OpenAIClient struct { nopCloser - client *openai.Client - model string - temperature float32 - topP float32 + client *openai.Client + model string + temperature float32 + topP float32 + organizationId string } const ( @@ -43,6 +44,7 @@ const ( func (c *OpenAIClient) Configure(config IAIConfig) error { token := config.GetPassword() defaultConfig := openai.DefaultConfig(token) + orgId := config.GetOrganizationId() proxyEndpoint := config.GetProxyEndpoint() baseURL := config.GetBaseURL() @@ -64,6 +66,10 @@ func (c *OpenAIClient) Configure(config IAIConfig) error { } } + if orgId != "" { + defaultConfig.OrgID = orgId + } + client := openai.NewClientWithConfig(defaultConfig) if client == nil { return errors.New("error creating OpenAI client")