From c162cc22ee468070e0602d3fd684b022fa585c4f Mon Sep 17 00:00:00 2001 From: Guido Muscioni <32247226+muscionig@users.noreply.github.com> Date: Fri, 19 Apr 2024 10:38:52 -0500 Subject: [PATCH] fix: set topP from config (#1053) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: set topP from config Signed-off-by: “Guido * style: correct format of openai ai provider Signed-off-by: “Guido * feat: set topP from the environment Signed-off-by: “Guido --------- Signed-off-by: “Guido --- cmd/serve/serve.go | 18 ++++++++++++++++++ pkg/ai/openai.go | 5 +++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/cmd/serve/serve.go b/cmd/serve/serve.go index ab8fb748fb..69e29ed1e7 100644 --- a/cmd/serve/serve.go +++ b/cmd/serve/serve.go @@ -27,6 +27,7 @@ import ( const ( defaultTemperature float32 = 0.7 + defaultTopP float32 = 1.0 ) var ( @@ -67,6 +68,22 @@ var ServeCmd = &cobra.Command{ } return float32(temperature) } + topP := func() float32 { + env := os.Getenv("K8SGPT_TOP_P") + if env == "" { + return defaultTopP + } + topP, err := strconv.ParseFloat(env, 32) + if err != nil { + color.Red("Unable to convert topP value: %v", err) + os.Exit(1) + } + if topP > 1.0 || topP < 0.0 { + color.Red("Error: topP ranges from 0 to 1.") + os.Exit(1) + } + return float32(topP) + } // Check for env injection backend = os.Getenv("K8SGPT_BACKEND") password := os.Getenv("K8SGPT_PASSWORD") @@ -86,6 +103,7 @@ var ServeCmd = &cobra.Command{ Engine: engine, ProxyEndpoint: proxyEndpoint, Temperature: temperature(), + TopP: topP(), } configAI.Providers = append(configAI.Providers, *aiProvider) diff --git a/pkg/ai/openai.go b/pkg/ai/openai.go index d047722a1b..ff032f7391 100644 --- a/pkg/ai/openai.go +++ b/pkg/ai/openai.go @@ -30,6 +30,7 @@ type OpenAIClient struct { client *openai.Client model string temperature float32 + topP float32 } const ( @@ -37,7 +38,6 @@ const ( maxToken = 2048 presencePenalty = 0.0 frequencyPenalty = 0.0 - topP = 1.0 ) func (c *OpenAIClient) Configure(config IAIConfig) error { @@ -71,6 +71,7 @@ func (c *OpenAIClient) Configure(config IAIConfig) error { c.client = client c.model = config.GetModel() c.temperature = config.GetTemperature() + c.topP = config.GetTopP() return nil } @@ -88,7 +89,7 @@ func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string) (string MaxTokens: maxToken, PresencePenalty: presencePenalty, FrequencyPenalty: frequencyPenalty, - TopP: topP, + TopP: c.topP, }) if err != nil { return "", err