Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support openai organization Id #1133

Merged
merged 3 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cmd/auth/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ var addCmd = &cobra.Command{
TopP: topP,
TopK: topK,
MaxTokens: maxTokens,
OrganizationId: organizationId,
}

if providerIndex == -1 {
Expand Down Expand Up @@ -176,4 +177,6 @@ func init() {
addCmd.Flags().StringVarP(&providerId, "providerId", "i", "", "Provider specific ID for e.g. project (only for googlevertexai backend)")
//add flag for OCI Compartment ID
addCmd.Flags().StringVarP(&compartmentId, "compartmentId", "k", "", "Compartment ID for generative AI model (only for oci backend)")
// add flag for openai organization
addCmd.Flags().StringVarP(&organizationId, "organizationId", "o", "", "OpenAI or AzureOpenAI Organization ID (only for openai and azureopenai backend)")
}
1 change: 1 addition & 0 deletions cmd/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ var (
topP float32
topK int32
maxTokens int
organizationId string
)

var configAI ai.AIConfiguration
Expand Down
80 changes: 43 additions & 37 deletions cmd/auth/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,20 @@ var updateCmd = &cobra.Command{
Use: "update",
Short: "Update a backend provider",
Long: "The command to update an AI backend provider",
Args: cobra.ExactArgs(1),
// Args: cobra.ExactArgs(1),
PreRun: func(cmd *cobra.Command, args []string) {
backend, _ := cmd.Flags().GetString("backend")
if strings.ToLower(backend) == "azureopenai" {
_ = cmd.MarkFlagRequired("engine")
_ = cmd.MarkFlagRequired("baseurl")
}
organizationId, _ := cmd.Flags().GetString("organizationId")
if strings.ToLower(backend) != "azureopenai" && strings.ToLower(backend) != "openai" {
if organizationId != "" {
color.Red("Error: organizationId must be empty for backends other than azureopenai or openai.")
os.Exit(1)
}
}
},
Run: func(cmd *cobra.Command, args []string) {

Expand All @@ -43,50 +50,47 @@ var updateCmd = &cobra.Command{
os.Exit(1)
}

inputBackends := strings.Split(args[0], ",")
backend, _ := cmd.Flags().GetString("backend")

if len(inputBackends) == 0 {
color.Red("Error: backend must be set.")
os.Exit(1)
}
if temperature > 1.0 || temperature < 0.0 {
color.Red("Error: temperature ranges from 0 to 1.")
os.Exit(1)
}

for _, b := range inputBackends {
foundBackend := false
for i, provider := range configAI.Providers {
if b == provider.Name {
foundBackend = true
if backend != "" {
configAI.Providers[i].Name = backend
color.Blue("Backend name updated successfully")
}
if model != "" {
configAI.Providers[i].Model = model
color.Blue("Model updated successfully")
}
if password != "" {
configAI.Providers[i].Password = password
color.Blue("Password updated successfully")
}
if baseURL != "" {
configAI.Providers[i].BaseURL = baseURL
color.Blue("Base URL updated successfully")
}
if engine != "" {
configAI.Providers[i].Engine = engine
}
configAI.Providers[i].Temperature = temperature
color.Green("%s updated in the AI backend provider list", b)
foundBackend := false
for i, provider := range configAI.Providers {
if backend == provider.Name {
foundBackend = true
if backend != "" {
configAI.Providers[i].Name = backend
color.Blue("Backend name updated successfully")
}
if model != "" {
configAI.Providers[i].Model = model
color.Blue("Model updated successfully")
}
if password != "" {
configAI.Providers[i].Password = password
color.Blue("Password updated successfully")
}
if baseURL != "" {
configAI.Providers[i].BaseURL = baseURL
color.Blue("Base URL updated successfully")
}
if engine != "" {
configAI.Providers[i].Engine = engine
}
if organizationId != "" {
configAI.Providers[i].OrganizationId = organizationId
color.Blue("Organization Id updated successfully")
}
configAI.Providers[i].Temperature = temperature
color.Green("%s updated in the AI backend provider list", backend)
}
if !foundBackend {
color.Red("Error: %s does not exist in configuration file. Please use k8sgpt auth new.", args[0])
os.Exit(1)
}

}
if !foundBackend {
color.Red("Error: %s does not exist in configuration file. Please use k8sgpt auth new.", args[0])
os.Exit(1)
}

viper.Set("ai", configAI)
Expand All @@ -110,4 +114,6 @@ func init() {
updateCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
// update flag for azure open ai engine/deployment name
updateCmd.Flags().StringVarP(&engine, "engine", "e", "", "Update Azure AI deployment name")
// update flag for organizationId
updateCmd.Flags().StringVarP(&organizationId, "organizationId", "o", "", "Update OpenAI or Azure organization Id")
}
12 changes: 9 additions & 3 deletions pkg/ai/azureopenai.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
type AzureAIClient struct {
nopCloser

client *openai.Client
model string
temperature float32
client *openai.Client
model string
temperature float32
organizationId string

Check failure on line 20 in pkg/ai/azureopenai.go

View workflow job for this annotation

GitHub Actions / golangci-lint

[golangci] reported by reviewdog 🐶 field `organizationId` is unused (unused) Raw Output: pkg/ai/azureopenai.go:20:2: field `organizationId` is unused (unused) organizationId string ^
}

func (c *AzureAIClient) Configure(config IAIConfig) error {
Expand All @@ -25,6 +26,7 @@
engine := config.GetEngine()
proxyEndpoint := config.GetProxyEndpoint()
defaultConfig := openai.DefaultAzureConfig(token, baseURL)
orgId := config.GetOrganizationId()

defaultConfig.AzureModelMapperFunc = func(model string) string {
// If you use a deployment name different from the model name, you can customize the AzureModelMapperFunc function
Expand All @@ -48,6 +50,10 @@
Transport: transport,
}
}
if orgId != "" {
defaultConfig.OrgID = orgId
}

client := openai.NewClientWithConfig(defaultConfig)
if client == nil {
return errors.New("error creating Azure OpenAI client")
Expand Down
6 changes: 6 additions & 0 deletions pkg/ai/iai.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ type IAIConfig interface {
GetMaxTokens() int
GetProviderId() string
GetCompartmentId() string
GetOrganizationId() string
}

func NewClient(provider string) IAI {
Expand Down Expand Up @@ -111,6 +112,7 @@ type AIProvider struct {
TopP float32 `mapstructure:"topp" yaml:"topp,omitempty"`
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"`
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"`
}

func (p *AIProvider) GetBaseURL() string {
Expand Down Expand Up @@ -164,6 +166,10 @@ func (p *AIProvider) GetCompartmentId() string {
return p.CompartmentId
}

func (p *AIProvider) GetOrganizationId() string {
return p.OrganizationId
}

var passwordlessProviders = []string{"localai", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"}

func NeedPassword(backend string) bool {
Expand Down
14 changes: 10 additions & 4 deletions pkg/ai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@
type OpenAIClient struct {
nopCloser

client *openai.Client
model string
temperature float32
topP float32
client *openai.Client
model string
temperature float32
topP float32
organizationId string

Check failure on line 34 in pkg/ai/openai.go

View workflow job for this annotation

GitHub Actions / golangci-lint

[golangci] reported by reviewdog 🐶 field `organizationId` is unused (unused) Raw Output: pkg/ai/openai.go:34:2: field `organizationId` is unused (unused) organizationId string ^
}

const (
Expand All @@ -43,6 +44,7 @@
func (c *OpenAIClient) Configure(config IAIConfig) error {
token := config.GetPassword()
defaultConfig := openai.DefaultConfig(token)
orgId := config.GetOrganizationId()
proxyEndpoint := config.GetProxyEndpoint()

baseURL := config.GetBaseURL()
Expand All @@ -64,6 +66,10 @@
}
}

if orgId != "" {
defaultConfig.OrgID = orgId
}

client := openai.NewClientWithConfig(defaultConfig)
if client == nil {
return errors.New("error creating OpenAI client")
Expand Down
Loading