From 307710eddc1c3f96f40a674f7dda786510e9c4cc Mon Sep 17 00:00:00 2001 From: Tanuj Dwivedi Date: Wed, 28 Feb 2024 21:40:42 +0530 Subject: [PATCH] feat: add proxysettings for azureopenai and openai (#987) Signed-off-by: tanujd11 Co-authored-by: Aris Boutselis Co-authored-by: Alex Jones --- cmd/serve/serve.go | 2 ++ pkg/ai/azureopenai.go | 17 +++++++++++++++++ pkg/ai/iai.go | 7 +++++++ pkg/ai/openai.go | 17 +++++++++++++++++ 4 files changed, 43 insertions(+) diff --git a/cmd/serve/serve.go b/cmd/serve/serve.go index 95ad986594..bf26461499 100644 --- a/cmd/serve/serve.go +++ b/cmd/serve/serve.go @@ -73,6 +73,7 @@ var ServeCmd = &cobra.Command{ model := os.Getenv("K8SGPT_MODEL") baseURL := os.Getenv("K8SGPT_BASEURL") engine := os.Getenv("K8SGPT_ENGINE") + proxyEndpoint := os.Getenv("K8SGPT_PROXY_ENDPOINT") // If the envs are set, allocate in place to the aiProvider // else exit with error envIsSet := backend != "" || password != "" || model != "" @@ -83,6 +84,7 @@ var ServeCmd = &cobra.Command{ Model: model, BaseURL: baseURL, Engine: engine, + ProxyEndpoint: proxyEndpoint, Temperature: temperature(), } diff --git a/pkg/ai/azureopenai.go b/pkg/ai/azureopenai.go index 7def5ea430..34cd2a46c8 100644 --- a/pkg/ai/azureopenai.go +++ b/pkg/ai/azureopenai.go @@ -3,6 +3,8 @@ package ai import ( "context" "errors" + "net/http" + "net/url" "github.com/sashabaranov/go-openai" ) @@ -21,6 +23,7 @@ func (c *AzureAIClient) Configure(config IAIConfig) error { token := config.GetPassword() baseURL := config.GetBaseURL() engine := config.GetEngine() + proxyEndpoint := config.GetProxyEndpoint() defaultConfig := openai.DefaultAzureConfig(token, baseURL) defaultConfig.AzureModelMapperFunc = func(model string) string { @@ -31,6 +34,20 @@ func (c *AzureAIClient) Configure(config IAIConfig) error { return azureModelMapping[model] } + + if proxyEndpoint != "" { + proxyUrl, err := url.Parse(proxyEndpoint) + if err != nil { + return err + } + transport := &http.Transport{ + Proxy: http.ProxyURL(proxyUrl), + } + + defaultConfig.HTTPClient = &http.Client{ + Transport: transport, + } + } client := openai.NewClientWithConfig(defaultConfig) if client == nil { return errors.New("error creating Azure OpenAI client") diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go index 99de8e3a40..083d2f5c73 100644 --- a/pkg/ai/iai.go +++ b/pkg/ai/iai.go @@ -64,6 +64,7 @@ type IAIConfig interface { GetPassword() string GetModel() string GetBaseURL() string + GetProxyEndpoint() string GetEndpointName() string GetEngine() string GetTemperature() float32 @@ -92,6 +93,8 @@ type AIProvider struct { Model string `mapstructure:"model"` Password string `mapstructure:"password" yaml:"password,omitempty"` BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"` + ProxyEndpoint string `mapstructure:"proxyEndpoint" yaml:"proxyEndpoint,omitempty"` + ProxyPort string `mapstructure:"proxyPort" yaml:"proxyPort,omitempty"` EndpointName string `mapstructure:"endpointname" yaml:"endpointname,omitempty"` Engine string `mapstructure:"engine" yaml:"engine,omitempty"` Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"` @@ -104,6 +107,10 @@ func (p *AIProvider) GetBaseURL() string { return p.BaseURL } +func (p *AIProvider) GetProxyEndpoint() string { + return p.ProxyEndpoint +} + func (p *AIProvider) GetEndpointName() string { return p.EndpointName } diff --git a/pkg/ai/openai.go b/pkg/ai/openai.go index e60e734f39..b0e38e86a4 100644 --- a/pkg/ai/openai.go +++ b/pkg/ai/openai.go @@ -16,6 +16,8 @@ package ai import ( "context" "errors" + "net/http" + "net/url" "github.com/sashabaranov/go-openai" ) @@ -41,12 +43,27 @@ const ( func (c *OpenAIClient) Configure(config IAIConfig) error { token := config.GetPassword() defaultConfig := openai.DefaultConfig(token) + proxyEndpoint := config.GetProxyEndpoint() baseURL := config.GetBaseURL() if baseURL != "" { defaultConfig.BaseURL = baseURL } + if proxyEndpoint != "" { + proxyUrl, err := url.Parse(proxyEndpoint) + if err != nil { + return err + } + transport := &http.Transport{ + Proxy: http.ProxyURL(proxyUrl), + } + + defaultConfig.HTTPClient = &http.Client{ + Transport: transport, + } + } + client := openai.NewClientWithConfig(defaultConfig) if client == nil { return errors.New("error creating OpenAI client")