From cf797a6eb67efba957704077b4b04ed3ee166c24 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Fri, 21 Apr 2023 21:04:34 +0200 Subject: [PATCH] feat: allow to set a baseurl (#310) * feat: allow to set a baseURL for OpenAI providers This allows to run local models that have a compatible OpenAI API, or for instance use a proxy. Signed-off-by: mudler * feat: allow to set baseURL in the auth subcommand Signed-off-by: mudler --------- Signed-off-by: mudler Co-authored-by: Alex Jones Co-authored-by: Matthis <99146727+matthisholleville@users.noreply.github.com> --- cmd/auth/auth.go | 4 ++++ cmd/serve/serve.go | 9 +++++++-- pkg/ai/iai.go | 6 ++++++ pkg/ai/openai.go | 6 ++++++ 4 files changed, 23 insertions(+), 2 deletions(-) diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go index b0c35cc649..33cba9bcad 100644 --- a/cmd/auth/auth.go +++ b/cmd/auth/auth.go @@ -29,6 +29,7 @@ import ( var ( backend string password string + baseURL string model string ) @@ -86,6 +87,7 @@ var AuthCmd = &cobra.Command{ Name: backend, Model: model, Password: password, + BaseURL: baseURL, } if providerIndex == -1 { @@ -113,4 +115,6 @@ func init() { AuthCmd.Flags().StringVarP(&model, "model", "m", "gpt-3.5-turbo", "Backend AI model") // add flag for password AuthCmd.Flags().StringVarP(&password, "password", "p", "", "Backend AI password") + // add flag for url + AuthCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "URL AI provider, (e.g `http://localhost:8080/v1`)") } diff --git a/cmd/serve/serve.go b/cmd/serve/serve.go index 9d6351177a..ff2747bcda 100644 --- a/cmd/serve/serve.go +++ b/cmd/serve/serve.go @@ -47,13 +47,18 @@ var ServeCmd = &cobra.Command{ backend = os.Getenv("K8SGPT_BACKEND") password := os.Getenv("K8SGPT_PASSWORD") model := os.Getenv("K8SGPT_MODEL") - // If the envs are set, alocate in place to the aiProvider + baseURL := os.Getenv("K8SGPT_BASEURL") + + // If the envs are set, allocate in place to the aiProvider // else exit with error - if backend != "" || password != "" || model != "" { + envIsSet := backend != "" || password != "" || model != "" || baseURL != "" + + if envIsSet { aiProvider = &ai.AIProvider{ Name: backend, Password: password, Model: model, + BaseURL: baseURL, } configAI.Providers = append(configAI.Providers, *aiProvider) diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go index c9899d84e2..5e5365b4db 100644 --- a/pkg/ai/iai.go +++ b/pkg/ai/iai.go @@ -27,6 +27,7 @@ type IAI interface { type IAIConfig interface { GetPassword() string GetModel() string + GetBaseURL() string } func NewClient(provider string) IAI { @@ -48,6 +49,11 @@ type AIProvider struct { Name string `mapstructure:"name"` Model string `mapstructure:"model"` Password string `mapstructure:"password"` + BaseURL string `mapstructure:"base_url"` +} + +func (p *AIProvider) GetBaseURL() string { + return p.BaseURL } func (p *AIProvider) GetPassword() string { diff --git a/pkg/ai/openai.go b/pkg/ai/openai.go index 62c33b7baf..42d3d6f987 100644 --- a/pkg/ai/openai.go +++ b/pkg/ai/openai.go @@ -37,6 +37,12 @@ type OpenAIClient struct { func (c *OpenAIClient) Configure(config IAIConfig, language string) error { token := config.GetPassword() defaultConfig := openai.DefaultConfig(token) + + baseURL := config.GetBaseURL() + if baseURL != "" { + defaultConfig.BaseURL = baseURL + } + client := openai.NewClientWithConfig(defaultConfig) if client == nil { return errors.New("error creating OpenAI client")