From f55946d60ebc7725aba6702570ca1cb5ba978d78 Mon Sep 17 00:00:00 2001
From: Peter Pan
Date: Mon, 18 Sep 2023 20:14:43 +0800
Subject: [PATCH] feat: openAI explicit value for maxToken and temperature
 (#659)

* feat: openAI explicit value for maxToken and temp

When k8sgpt talks to vLLM, the default MaxToken is 16, which is too small.
Since most models support at least 2048 tokens (e.g. Llama 1), 2048 is used
here as a safe value.

Signed-off-by: Peter Pan

* feat: make temperature a flag

Signed-off-by: Peter Pan

---------

Signed-off-by: Peter Pan
---
 README.md             |  2 +-
 cmd/auth/add.go       | 17 ++++++++++++-----
 cmd/auth/auth.go      | 11 ++++++-----
 cmd/auth/update.go    |  7 +++++++
 pkg/ai/azureopenai.go |  9 ++++++---
 pkg/ai/cohere.go      | 10 ++++++----
 pkg/ai/iai.go         | 15 ++++++++++-----
 pkg/ai/openai.go      | 21 ++++++++++++++++++---
 8 files changed, 66 insertions(+), 26 deletions(-)

diff --git a/README.md b/README.md
index 37b588c013..b256238771 100644
--- a/README.md
+++ b/README.md
@@ -302,7 +302,7 @@ To start the API server, follow the instruction in [LocalAI](https://github.com/
 To run k8sgpt, run `k8sgpt auth add` with the `localai` backend:
 
 ```
-k8sgpt auth add --backend localai --model --baseurl http://localhost:8080/v1
+k8sgpt auth add --backend localai --model --baseurl http://localhost:8080/v1 --temperature 0.7
 ```
 
 Now you can analyze with the `localai` backend:

diff --git a/cmd/auth/add.go b/cmd/auth/add.go
index 8ced8adb86..6aa6271433 100644
--- a/cmd/auth/add.go
+++ b/cmd/auth/add.go
@@ -75,6 +75,10 @@ var addCmd = &cobra.Command{
 			color.Red("Error: Model cannot be empty.")
 			os.Exit(1)
 		}
+		if temperature > 1.0 || temperature < 0.0 {
+			color.Red("Error: temperature ranges from 0 to 1.")
+			os.Exit(1)
+		}
 
 		if ai.NeedPassword(backend) && password == "" {
 			fmt.Printf("Enter %s Key: ", backend)
@@ -89,11 +93,12 @@ var addCmd = &cobra.Command{
 
 		// create new provider object
 		newProvider := ai.AIProvider{
-			Name:     backend,
-			Model:    model,
-			Password: password,
-			BaseURL:  baseURL,
-			Engine:   engine,
+			Name:        backend,
+			Model:       model,
+			Password:    password,
+			BaseURL:     baseURL,
+			Engine:      engine,
+			Temperature: temperature,
 		}
 
 		if providerIndex == -1 {
@@ -121,6 +126,8 @@ func init() {
 	addCmd.Flags().StringVarP(&password, "password", "p", "", "Backend AI password")
 	// add flag for url
 	addCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "URL AI provider, (e.g `http://localhost:8080/v1`)")
+	// add flag for temperature
+	addCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
 	// add flag for azure open ai engine/deployment name
 	addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name")
 }

diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go
index 75b4b4bebb..b64f54f5c6 100644
--- a/cmd/auth/auth.go
+++ b/cmd/auth/auth.go
@@ -19,11 +19,12 @@ import (
 )
 
 var (
-	backend  string
-	password string
-	baseURL  string
-	model    string
-	engine   string
+	backend     string
+	password    string
+	baseURL     string
+	model       string
+	engine      string
+	temperature float32
 )
 
 var configAI ai.AIConfiguration

diff --git a/cmd/auth/update.go b/cmd/auth/update.go
index be0561ba63..eb9a0e79ef 100644
--- a/cmd/auth/update.go
+++ b/cmd/auth/update.go
@@ -49,6 +49,10 @@ var updateCmd = &cobra.Command{
 			color.Red("Error: backend must be set.")
 			os.Exit(1)
 		}
+		if temperature > 1.0 || temperature < 0.0 {
+			color.Red("Error: temperature ranges from 0 to 1.")
+			os.Exit(1)
+		}
 
 		for _, b := range inputBackends {
 			foundBackend := false
@@ -74,6 +78,7 @@ var updateCmd = &cobra.Command{
 				if engine != "" {
 					configAI.Providers[i].Engine = engine
 				}
+				configAI.Providers[i].Temperature = temperature
 				color.Green("%s updated in the AI backend provider list", b)
 			}
 		}
@@ -101,6 +106,8 @@ func init() {
 	updateCmd.Flags().StringVarP(&password, "password", "p", "", "Update backend AI password")
 	// update flag for url
 	updateCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "Update URL AI provider, (e.g `http://localhost:8080/v1`)")
+	// add flag for temperature
+	updateCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
 	// update flag for azure open ai engine/deployment name
 	updateCmd.Flags().StringVarP(&engine, "engine", "e", "", "Update Azure AI deployment name")
 }

diff --git a/pkg/ai/azureopenai.go b/pkg/ai/azureopenai.go
index 0573caf72e..cd5e073959 100644
--- a/pkg/ai/azureopenai.go
+++ b/pkg/ai/azureopenai.go
@@ -16,9 +16,10 @@ import (
 )
 
 type AzureAIClient struct {
-	client   *openai.Client
-	language string
-	model    string
+	client      *openai.Client
+	language    string
+	model       string
+	temperature float32
 }
 
 func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
@@ -42,6 +43,7 @@ func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
 	c.language = lang
 	c.client = client
 	c.model = config.GetModel()
+	c.temperature = config.GetTemperature()
 	return nil
 }
 
@@ -55,6 +57,7 @@ func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
 				Content: fmt.Sprintf(default_prompt, c.language, prompt),
 			},
 		},
+		Temperature: c.temperature,
 	})
 	if err != nil {
 		return "", err

diff --git a/pkg/ai/cohere.go b/pkg/ai/cohere.go
index a09963c544..64a48c54ff 100644
--- a/pkg/ai/cohere.go
+++ b/pkg/ai/cohere.go
@@ -28,9 +28,10 @@ import (
 )
 
 type CohereClient struct {
-	client   *cohere.Client
-	language string
-	model    string
+	client      *cohere.Client
+	language    string
+	model       string
+	temperature float32
 }
 
 func (c *CohereClient) Configure(config IAIConfig, language string) error {
@@ -52,6 +53,7 @@ func (c *CohereClient) Configure(config IAIConfig, language string) error {
 	c.language = language
 	c.client = client
 	c.model = config.GetModel()
+	c.temperature = config.GetTemperature()
 	return nil
 }
 
@@ -64,7 +66,7 @@ func (c *CohereClient) GetCompletion(ctx context.Context, prompt, promptTmpl string) (string, error) {
 		Model:             c.model,
 		Prompt:            fmt.Sprintf(strings.TrimSpace(promptTmpl), c.language, prompt),
 		MaxTokens:         cohere.Uint(2048),
-		Temperature:       cohere.Float64(0.75),
+		Temperature:       cohere.Float64(float64(c.temperature)),
 		K:                 cohere.Int(0),
 		StopSequences:     []string{},
 		ReturnLikelihoods: "NONE",

diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go
index 2d09e41231..b8172d161b 100644
--- a/pkg/ai/iai.go
+++ b/pkg/ai/iai.go
@@ -48,6 +48,7 @@ type IAIConfig interface {
 	GetModel() string
 	GetBaseURL() string
 	GetEngine() string
+	GetTemperature() float32
 }
 
 func NewClient(provider string) IAI {
@@ -66,11 +67,12 @@ type AIConfiguration struct {
 }
 
 type AIProvider struct {
-	Name     string `mapstructure:"name"`
-	Model    string `mapstructure:"model"`
-	Password string `mapstructure:"password" yaml:"password,omitempty"`
-	BaseURL  string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
-	Engine   string `mapstructure:"engine" yaml:"engine,omitempty"`
+	Name        string  `mapstructure:"name"`
+	Model       string  `mapstructure:"model"`
+	Password    string  `mapstructure:"password" yaml:"password,omitempty"`
+	BaseURL     string  `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
+	Engine      string  `mapstructure:"engine" yaml:"engine,omitempty"`
+	Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
 }
 
 func (p *AIProvider) GetBaseURL() string {
@@ -88,6 +90,9 @@ func (p *AIProvider) GetModel() string {
 func (p *AIProvider) GetEngine() string {
 	return p.Engine
 }
+func (p *AIProvider) GetTemperature() float32 {
+	return p.Temperature
+}
 
 func NeedPassword(backend string) bool {
 	return backend != "localai"

diff --git a/pkg/ai/openai.go b/pkg/ai/openai.go
index 7d9e6797af..b37f2aec08 100644
--- a/pkg/ai/openai.go
+++ b/pkg/ai/openai.go
@@ -29,11 +29,20 @@ import (
 )
 
 type OpenAIClient struct {
-	client   *openai.Client
-	language string
-	model    string
+	client      *openai.Client
+	language    string
+	model       string
+	temperature float32
 }
 
+const (
+	// OpenAI completion parameters
+	maxToken         = 2048
+	presencePenalty  = 0.0
+	frequencyPenalty = 0.0
+	topP             = 1.0
+)
+
 func (c *OpenAIClient) Configure(config IAIConfig, language string) error {
 	token := config.GetPassword()
 	defaultConfig := openai.DefaultConfig(token)
@@ -50,6 +59,7 @@ func (c *OpenAIClient) Configure(config IAIConfig, language string) error {
 	c.language = language
 	c.client = client
 	c.model = config.GetModel()
+	c.temperature = config.GetTemperature()
 	return nil
 }
 
@@ -66,6 +76,11 @@ func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) {
 			Content: fmt.Sprintf(promptTmpl, c.language, prompt),
 		},
 	},
+		Temperature:      c.temperature,
+		MaxTokens:        maxToken,
+		PresencePenalty:  presencePenalty,
+		FrequencyPenalty: frequencyPenalty,
+		TopP:             topP,
 	})
 	if err != nil {
 		return "", err
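
For context when reviewing, below is a minimal, self-contained sketch (not part of the patch) of how an explicit temperature and max-token value reach a chat completion request. It assumes the sashabaranov/go-openai client that pkg/ai/openai.go appears to wrap; the API key, model, and prompt are placeholders.

```go
// Sketch: pass an explicit Temperature and MaxTokens to a chat completion,
// mirroring what pkg/ai/openai.go does after this patch.
// Assumptions: sashabaranov/go-openai import path; key, model, and prompt are placeholders.
package main

import (
	"context"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	cfg := openai.DefaultConfig("YOUR_API_KEY") // normally supplied via `k8sgpt auth add --password`
	client := openai.NewClientWithConfig(cfg)

	temperature := float32(0.7) // default of the new --temperature flag

	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model: openai.GPT3Dot5Turbo, // placeholder model
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Explain this Kubernetes error: CrashLoopBackOff"},
		},
		Temperature:      temperature,
		MaxTokens:        2048, // the explicit maxToken constant introduced above
		PresencePenalty:  0.0,
		FrequencyPenalty: 0.0,
		TopP:             1.0,
	})
	if err != nil {
		fmt.Println("completion failed:", err)
		return
	}
	fmt.Println(resp.Choices[0].Message.Content)
}
```

With the patch applied, the same value is set per backend from the CLI, e.g. `k8sgpt auth add --backend openai --model gpt-3.5-turbo --temperature 0.7`; the backend and model names here are illustrative, while the flag itself comes from cmd/auth/add.go above.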