diff --git a/cmd/serve/serve.go b/cmd/serve/serve.go
index ab8fb748fb..69e29ed1e7 100644
--- a/cmd/serve/serve.go
+++ b/cmd/serve/serve.go
@@ -27,6 +27,7 @@ import (
 
 const (
 	defaultTemperature float32 = 0.7
+	defaultTopP        float32 = 1.0
 )
 
 var (
@@ -67,6 +68,29 @@ var ServeCmd = &cobra.Command{
 			}
 			return float32(temperature)
 		}
+		// topP resolves the top-p value from the K8SGPT_TOP_P environment
+		// variable, returning defaultTopP when the variable is unset or
+		// empty. The process exits with status 1 when the value cannot be
+		// parsed as a number or lies outside the range [0, 1].
+		topP := func() float32 {
+			env := os.Getenv("K8SGPT_TOP_P")
+			if env == "" {
+				return defaultTopP
+			}
+			value, err := strconv.ParseFloat(env, 32)
+			if err != nil {
+				color.Red("Unable to convert topP value: %v", err)
+				os.Exit(1)
+			}
+			// Written as a negated in-range test so that NaN, which
+			// ParseFloat accepts and which fails every ordered
+			// comparison, is rejected along with out-of-range values.
+			if !(value >= 0.0 && value <= 1.0) {
+				color.Red("Error: topP ranges from 0 to 1.")
+				os.Exit(1)
+			}
+			return float32(value)
+		}
 		// Check for env injection
 		backend = os.Getenv("K8SGPT_BACKEND")
 		password := os.Getenv("K8SGPT_PASSWORD")
@@ -86,6 +110,7 @@ var ServeCmd = &cobra.Command{
 				Engine:        engine,
 				ProxyEndpoint: proxyEndpoint,
 				Temperature:   temperature(),
+				TopP:          topP(),
 			}
 
 			configAI.Providers = append(configAI.Providers, *aiProvider)
diff --git a/pkg/ai/openai.go b/pkg/ai/openai.go
index d047722a1b..ff032f7391 100644
--- a/pkg/ai/openai.go
+++ b/pkg/ai/openai.go
@@ -30,6 +30,7 @@ type OpenAIClient struct {
 	client      *openai.Client
 	model       string
 	temperature float32
+	topP        float32
 }
 
 const (
@@ -37,7 +38,6 @@ const (
 	maxToken         = 2048
 	presencePenalty  = 0.0
 	frequencyPenalty = 0.0
-	topP             = 1.0
 )
 
 func (c *OpenAIClient) Configure(config IAIConfig) error {
@@ -71,6 +71,7 @@ func (c *OpenAIClient) Configure(config IAIConfig) error {
 	c.client = client
 	c.model = config.GetModel()
 	c.temperature = config.GetTemperature()
+	c.topP = config.GetTopP()
 	return nil
 }
 
@@ -88,7 +89,7 @@ func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string) (string
 		MaxTokens:        maxToken,
 		PresencePenalty:  presencePenalty,
 		FrequencyPenalty: frequencyPenalty,
-		TopP:             topP,
+		TopP:             c.topP,
 	})
 	if err != nil {
 		return "", err