diff --git a/README.md b/README.md
index 37b588c013..b256238771 100644
--- a/README.md
+++ b/README.md
@@ -302,7 +302,7 @@ To start the API server, follow the instruction in [LocalAI](https://github.com/
 To run k8sgpt, run `k8sgpt auth add` with the `localai` backend:
 
 ```
-k8sgpt auth add --backend localai --model --baseurl http://localhost:8080/v1
+k8sgpt auth add --backend localai --model --baseurl http://localhost:8080/v1 --temperature 0.7
 ```
 
 Now you can analyze with the `localai` backend:
diff --git a/cmd/auth/add.go b/cmd/auth/add.go
index 8ced8adb86..6aa6271433 100644
--- a/cmd/auth/add.go
+++ b/cmd/auth/add.go
@@ -75,6 +75,10 @@ var addCmd = &cobra.Command{
 			color.Red("Error: Model cannot be empty.")
 			os.Exit(1)
 		}
+		if temperature > 1.0 || temperature < 0.0 {
+			color.Red("Error: temperature ranges from 0 to 1.")
+			os.Exit(1)
+		}
 
 		if ai.NeedPassword(backend) && password == "" {
 			fmt.Printf("Enter %s Key: ", backend)
@@ -89,11 +93,12 @@ var addCmd = &cobra.Command{
 
 		// create new provider object
 		newProvider := ai.AIProvider{
-			Name:     backend,
-			Model:    model,
-			Password: password,
-			BaseURL:  baseURL,
-			Engine:   engine,
+			Name:        backend,
+			Model:       model,
+			Password:    password,
+			BaseURL:     baseURL,
+			Engine:      engine,
+			Temperature: temperature,
 		}
 
 		if providerIndex == -1 {
@@ -121,6 +126,8 @@ func init() {
 	addCmd.Flags().StringVarP(&password, "password", "p", "", "Backend AI password")
 	// add flag for url
 	addCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "URL AI provider, (e.g `http://localhost:8080/v1`)")
+	// add flag for temperature
+	addCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 (more deterministic output) and 1 (more random)")
 	// add flag for azure open ai engine/deployment name
 	addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name")
 }
diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go
index 75b4b4bebb..b64f54f5c6 100644
--- a/cmd/auth/auth.go
+++ b/cmd/auth/auth.go
@@ -19,11 +19,12 @@ import (
 )
 
 var (
-	backend  string
-	password string
-	baseURL  string
-	model    string
-	engine   string
+	backend     string
+	password    string
+	baseURL     string
+	model       string
+	engine      string
+	temperature float32
 )
 
 var configAI ai.AIConfiguration
diff --git a/cmd/auth/update.go b/cmd/auth/update.go
index be0561ba63..eb9a0e79ef 100644
--- a/cmd/auth/update.go
+++ b/cmd/auth/update.go
@@ -49,6 +49,10 @@ var updateCmd = &cobra.Command{
 			color.Red("Error: backend must be set.")
 			os.Exit(1)
 		}
+		if temperature > 1.0 || temperature < 0.0 {
+			color.Red("Error: temperature ranges from 0 to 1.")
+			os.Exit(1)
+		}
 
 		for _, b := range inputBackends {
 			foundBackend := false
@@ -74,6 +78,7 @@ var updateCmd = &cobra.Command{
 				if engine != "" {
 					configAI.Providers[i].Engine = engine
 				}
+				configAI.Providers[i].Temperature = temperature
 				color.Green("%s updated in the AI backend provider list", b)
 			}
 		}
@@ -101,6 +106,8 @@ func init() {
 	updateCmd.Flags().StringVarP(&password, "password", "p", "", "Update backend AI password")
 	// update flag for url
 	updateCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "Update URL AI provider, (e.g `http://localhost:8080/v1`)")
+	// update flag for temperature
+	updateCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 (more deterministic output) and 1 (more random)")
 	// update flag for azure open ai engine/deployment name
 	updateCmd.Flags().StringVarP(&engine, "engine", "e", "", "Update Azure AI deployment name")
 }
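Both `auth add` and `auth update` now repeat the same inline bounds check on `temperature`. A minimal sketch of how that check could be shared between the two commands (the `validateTemperature` helper is hypothetical and not part of this change):

```go
package auth

import "fmt"

// validateTemperature mirrors the inline checks added to addCmd and updateCmd:
// any value outside [0, 1] is rejected before the provider config is written.
// Hypothetical helper, shown only to illustrate the accepted range.
func validateTemperature(t float32) error {
	if t < 0.0 || t > 1.0 {
		return fmt.Errorf("temperature must be between 0 and 1, got %.2f", t)
	}
	return nil
}
```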
diff --git a/pkg/ai/azureopenai.go b/pkg/ai/azureopenai.go
index 0573caf72e..cd5e073959 100644
--- a/pkg/ai/azureopenai.go
+++ b/pkg/ai/azureopenai.go
@@ -16,9 +16,10 @@ import (
 )
 
 type AzureAIClient struct {
-	client   *openai.Client
-	language string
-	model    string
+	client      *openai.Client
+	language    string
+	model       string
+	temperature float32
 }
 
 func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
@@ -42,6 +43,7 @@ func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
 	c.language = lang
 	c.client = client
 	c.model = config.GetModel()
+	c.temperature = config.GetTemperature()
 	return nil
 }
 
@@ -55,6 +57,7 @@ func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string, prompt
 				Content: fmt.Sprintf(default_prompt, c.language, prompt),
 			},
 		},
+		Temperature: c.temperature,
 	})
 	if err != nil {
 		return "", err
diff --git a/pkg/ai/cohere.go b/pkg/ai/cohere.go
index a09963c544..64a48c54ff 100644
--- a/pkg/ai/cohere.go
+++ b/pkg/ai/cohere.go
@@ -28,9 +28,10 @@ import (
 )
 
 type CohereClient struct {
-	client   *cohere.Client
-	language string
-	model    string
+	client      *cohere.Client
+	language    string
+	model       string
+	temperature float32
 }
 
 func (c *CohereClient) Configure(config IAIConfig, language string) error {
@@ -52,6 +53,7 @@ func (c *CohereClient) Configure(config IAIConfig, language string) error {
 	c.language = language
 	c.client = client
 	c.model = config.GetModel()
+	c.temperature = config.GetTemperature()
 	return nil
 }
 
@@ -64,7 +66,7 @@ func (c *CohereClient) GetCompletion(ctx context.Context, prompt, promptTmpl str
 		Model:             c.model,
 		Prompt:            fmt.Sprintf(strings.TrimSpace(promptTmpl), c.language, prompt),
 		MaxTokens:         cohere.Uint(2048),
-		Temperature:       cohere.Float64(0.75),
+		Temperature:       cohere.Float64(float64(c.temperature)),
 		K:                 cohere.Int(0),
 		StopSequences:     []string{},
 		ReturnLikelihoods: "NONE",
diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go
index 2d09e41231..b8172d161b 100644
--- a/pkg/ai/iai.go
+++ b/pkg/ai/iai.go
@@ -48,6 +48,7 @@ type IAIConfig interface {
 	GetModel() string
 	GetBaseURL() string
 	GetEngine() string
+	GetTemperature() float32
 }
 
 func NewClient(provider string) IAI {
@@ -66,11 +67,12 @@ type AIConfiguration struct {
 }
 
 type AIProvider struct {
-	Name     string `mapstructure:"name"`
-	Model    string `mapstructure:"model"`
-	Password string `mapstructure:"password" yaml:"password,omitempty"`
-	BaseURL  string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
-	Engine   string `mapstructure:"engine" yaml:"engine,omitempty"`
+	Name        string  `mapstructure:"name"`
+	Model       string  `mapstructure:"model"`
+	Password    string  `mapstructure:"password" yaml:"password,omitempty"`
+	BaseURL     string  `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
+	Engine      string  `mapstructure:"engine" yaml:"engine,omitempty"`
+	Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
 }
 
 func (p *AIProvider) GetBaseURL() string {
@@ -88,6 +90,9 @@ func (p *AIProvider) GetModel() string {
 func (p *AIProvider) GetEngine() string {
 	return p.Engine
 }
+func (p *AIProvider) GetTemperature() float32 {
+	return p.Temperature
+}
 
 func NeedPassword(backend string) bool {
 	return backend != "localai"
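With `GetTemperature()` on `IAIConfig` and the new `Temperature` field on `AIProvider`, every backend's `Configure` reads the value the same way and caches it for later requests. A rough, self-contained sketch of that flow, with the real types trimmed down to the fields relevant here (not the actual package layout):

```go
package main

import "fmt"

// Trimmed stand-ins for ai.IAIConfig and ai.AIProvider.
type IAIConfig interface {
	GetTemperature() float32
}

type AIProvider struct {
	Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
}

func (p *AIProvider) GetTemperature() float32 { return p.Temperature }

// client mimics how AzureAIClient, CohereClient, and OpenAIClient cache the
// configured temperature in Configure and reuse it on every completion call.
type client struct {
	temperature float32
}

func (c *client) Configure(cfg IAIConfig) {
	c.temperature = cfg.GetTemperature()
}

func main() {
	p := &AIProvider{Temperature: 0.7} // value persisted by `k8sgpt auth add`
	var c client
	c.Configure(p)
	fmt.Println(c.temperature) // 0.7
}
```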
diff --git a/pkg/ai/openai.go b/pkg/ai/openai.go
index 7d9e6797af..b37f2aec08 100644
--- a/pkg/ai/openai.go
+++ b/pkg/ai/openai.go
@@ -29,11 +29,20 @@ import (
 )
 
 type OpenAIClient struct {
-	client   *openai.Client
-	language string
-	model    string
+	client      *openai.Client
+	language    string
+	model       string
+	temperature float32
 }
 
+const (
+	// OpenAI completion parameters
+	maxToken         = 2048
+	presencePenalty  = 0.0
+	frequencyPenalty = 0.0
+	topP             = 1.0
+)
+
 func (c *OpenAIClient) Configure(config IAIConfig, language string) error {
 	token := config.GetPassword()
 	defaultConfig := openai.DefaultConfig(token)
@@ -50,6 +59,7 @@ func (c *OpenAIClient) Configure(config IAIConfig, language string) error {
 	c.language = language
 	c.client = client
 	c.model = config.GetModel()
+	c.temperature = config.GetTemperature()
 	return nil
 }
 
@@ -66,6 +76,11 @@ func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string, promptT
 				Content: fmt.Sprintf(promptTmpl, c.language, prompt),
 			},
 		},
+		Temperature:      c.temperature,
+		MaxTokens:        maxToken,
+		PresencePenalty:  presencePenalty,
+		FrequencyPenalty: frequencyPenalty,
+		TopP:             topP,
 	})
 	if err != nil {
 		return "", err
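For the OpenAI backend, the configured temperature lands on the go-openai chat completion request together with the newly pinned defaults. A standalone sketch of the resulting call, assuming the values introduced above; the token, model, and prompt are placeholders:

```go
package main

import (
	"context"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("sk-...") // placeholder API key
	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model: openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Explain this Kubernetes error: ..."},
		},
		Temperature:      0.7,  // set via `k8sgpt auth add --temperature 0.7`
		MaxTokens:        2048, // maxToken
		PresencePenalty:  0.0,
		FrequencyPenalty: 0.0,
		TopP:             1.0,
	})
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(resp.Choices[0].Message.Content)
}
```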