diff --git a/cmd/auth/add.go b/cmd/auth/add.go
index ed455d9add..e9158651b5 100644
--- a/cmd/auth/add.go
+++ b/cmd/auth/add.go
@@ -100,6 +100,10 @@ var addCmd = &cobra.Command{
 			color.Red("Error: topP ranges from 0 to 1.")
 			os.Exit(1)
 		}
+		if topK < 1 || topK > 100 {
+			color.Red("Error: topK ranges from 1 to 100.")
+			os.Exit(1)
+		}
 
 		if ai.NeedPassword(backend) && password == "" {
 			fmt.Printf("Enter %s Key: ", backend)
@@ -124,6 +128,7 @@ var addCmd = &cobra.Command{
 			ProviderRegion: providerRegion,
 			ProviderId:     providerId,
 			TopP:           topP,
+			TopK:           topK,
 			MaxTokens:      maxTokens,
 		}
 
@@ -156,6 +161,8 @@ func init() {
 	addCmd.Flags().StringVarP(&endpointName, "endpointname", "n", "", "Endpoint Name, e.g. `endpoint-xxxxxxxxxxxx` (only for amazonbedrock, amazonsagemaker backends)")
 	// add flag for topP
 	addCmd.Flags().Float32VarP(&topP, "topp", "c", 0.5, "Probability Cutoff: Set a threshold (0.0-1.0) to limit word choices. Higher values add randomness, lower values increase predictability.")
+	// add flag for topK
+	addCmd.Flags().Int32VarP(&topK, "topk", "k", 50, "Sampling Cutoff: Set a threshold (1-100) to restrict the sampling process to the top K most probable words at each step. Higher values lead to greater variability, lower values increase predictability.")
 	// max tokens
 	addCmd.Flags().IntVarP(&maxTokens, "maxtokens", "l", 2048, "Specify a maximum output length. Adjust (1-...) to control text length. Higher values produce longer output, lower values limit length")
 	// add flag for temperature
diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go
index 31f424842b..9da9ccefb7 100644
--- a/cmd/auth/auth.go
+++ b/cmd/auth/auth.go
@@ -29,6 +29,7 @@ var (
 	providerRegion string
 	providerId     string
 	topP           float32
+	topK           int32
 	maxTokens      int
 )
 
diff --git a/cmd/serve/serve.go b/cmd/serve/serve.go
index 69e29ed1e7..1984ac9805 100644
--- a/cmd/serve/serve.go
+++ b/cmd/serve/serve.go
@@ -28,6 +28,7 @@ import (
 const (
 	defaultTemperature float32 = 0.7
 	defaultTopP        float32 = 1.0
+	defaultTopK        int32   = 50
 )
 
 var (
@@ -84,6 +85,22 @@ var ServeCmd = &cobra.Command{
 			}
 			return float32(topP)
 		}
+		topK := func() int32 {
+			env := os.Getenv("K8SGPT_TOP_K")
+			if env == "" {
+				return defaultTopK
+			}
+			topK, err := strconv.ParseInt(env, 10, 32)
+			if err != nil {
+				color.Red("Unable to convert topK value: %v", err)
+				os.Exit(1)
+			}
+			if topK < 1 || topK > 100 {
+				color.Red("Error: topK ranges from 1 to 100.")
+				os.Exit(1)
+			}
+			return int32(topK)
+		}
 		// Check for env injection
 		backend = os.Getenv("K8SGPT_BACKEND")
 		password := os.Getenv("K8SGPT_PASSWORD")
@@ -104,6 +121,7 @@ var ServeCmd = &cobra.Command{
 			ProxyEndpoint: proxyEndpoint,
 			Temperature:   temperature(),
 			TopP:          topP(),
+			TopK:          topK(),
 		}
 
 		configAI.Providers = append(configAI.Providers, *aiProvider)
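
With the above, the serve path honors a K8SGPT_TOP_K environment variable under the same 1-100 bounds as the CLI flag. Below is a minimal, standalone sketch of that parsing logic, standard-library only; parseTopK is a hypothetical helper mirroring the closure in cmd/serve/serve.go, returning an error instead of exiting:

package main

import (
	"fmt"
	"os"
	"strconv"
)

const defaultTopK int32 = 50

// parseTopK mirrors the closure in cmd/serve/serve.go: read K8SGPT_TOP_K,
// fall back to the default when unset, and reject values outside 1-100.
func parseTopK() (int32, error) {
	env := os.Getenv("K8SGPT_TOP_K")
	if env == "" {
		return defaultTopK, nil
	}
	v, err := strconv.ParseInt(env, 10, 32)
	if err != nil {
		return 0, fmt.Errorf("unable to convert topK value: %w", err)
	}
	if v < 1 || v > 100 {
		return 0, fmt.Errorf("topK ranges from 1 to 100, got %d", v)
	}
	return int32(v), nil
}

func main() {
	os.Setenv("K8SGPT_TOP_K", "80") // simulate env injection
	topK, err := parseTopK()
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	fmt.Println("topK =", topK) // topK = 80; unset would yield 50
}
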
diff --git a/pkg/ai/amazonsagemaker.go b/pkg/ai/amazonsagemaker.go
index 23a6e6dac5..05bcdd97e8 100644
--- a/pkg/ai/amazonsagemaker.go
+++ b/pkg/ai/amazonsagemaker.go
@@ -33,6 +33,7 @@ type SageMakerAIClient struct {
 	temperature float32
 	endpoint    string
 	topP        float32
+	topK        int32
 	maxTokens   int
 }
 
@@ -56,6 +57,7 @@ type Message struct {
 type Parameters struct {
 	MaxNewTokens int     `json:"max_new_tokens"`
 	TopP         float64 `json:"top_p"`
+	TopK         float64 `json:"top_k"`
 	Temperature  float64 `json:"temperature"`
 }
 
@@ -74,6 +76,7 @@ func (c *SageMakerAIClient) Configure(config IAIConfig) error {
 	c.temperature = config.GetTemperature()
 	c.maxTokens = config.GetMaxTokens()
 	c.topP = config.GetTopP()
+	c.topK = config.GetTopK()
 	return nil
 }
 
@@ -90,6 +93,7 @@ func (c *SageMakerAIClient) GetCompletion(_ context.Context, prompt string) (string, error) {
 		Parameters: Parameters{
 			MaxNewTokens: int(c.maxTokens),
 			TopP:         float64(c.topP),
+			TopK:         float64(c.topK),
 			Temperature:  float64(c.temperature),
 		},
 	}
diff --git a/pkg/ai/googlegenai.go b/pkg/ai/googlegenai.go
index 1b439dc87c..48244245ae 100644
--- a/pkg/ai/googlegenai.go
+++ b/pkg/ai/googlegenai.go
@@ -31,6 +31,7 @@ type GoogleGenAIClient struct {
 	model       string
 	temperature float32
 	topP        float32
+	topK        int32
 	maxTokens   int
 }
 
@@ -53,6 +54,7 @@ func (c *GoogleGenAIClient) Configure(config IAIConfig) error {
 	c.model = config.GetModel()
 	c.temperature = config.GetTemperature()
 	c.topP = config.GetTopP()
+	c.topK = config.GetTopK()
 	c.maxTokens = config.GetMaxTokens()
 	return nil
 }
@@ -62,6 +64,7 @@ func (c *GoogleGenAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
 	model := c.client.GenerativeModel(c.model)
 	model.SetTemperature(c.temperature)
 	model.SetTopP(c.topP)
+	model.SetTopK(c.topK)
 	model.SetMaxOutputTokens(int32(c.maxTokens))
 
 	// Google AI SDK is capable of different inputs than just text, for now set explicit text prompt type.
diff --git a/pkg/ai/googlevertexai.go b/pkg/ai/googlevertexai.go
index b46e90e95a..093392b7c1 100644
--- a/pkg/ai/googlevertexai.go
+++ b/pkg/ai/googlevertexai.go
@@ -30,6 +30,7 @@ type GoogleVertexAIClient struct {
 	model       string
 	temperature float32
 	topP        float32
+	topK        int32
 	maxTokens   int
 }
 
@@ -111,6 +112,7 @@ func (g *GoogleVertexAIClient) Configure(config IAIConfig) error {
 	g.model = GetVertexAIModelOrDefault(config.GetModel())
 	g.temperature = config.GetTemperature()
 	g.topP = config.GetTopP()
+	g.topK = config.GetTopK()
 	g.maxTokens = config.GetMaxTokens()
 
 	return nil
@@ -121,6 +123,7 @@ func (g *GoogleVertexAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
 	model := g.client.GenerativeModel(g.model)
 	model.SetTemperature(g.temperature)
 	model.SetTopP(g.topP)
+	model.SetTopK(float32(g.topK))
 	model.SetMaxOutputTokens(int32(g.maxTokens))
 
 	// Google AI SDK is capable of different inputs than just text, for now set explicit text prompt type.
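
For the SageMaker backend, top_k now rides along in the inference payload next to top_p. A minimal sketch of what the Parameters struct above serializes to, using only encoding/json; the values are the defaults from this change and purely illustrative:

package main

import (
	"encoding/json"
	"fmt"
)

// Parameters matches the struct in pkg/ai/amazonsagemaker.go after this
// change; top_k now travels alongside top_p in the request body.
type Parameters struct {
	MaxNewTokens int     `json:"max_new_tokens"`
	TopP         float64 `json:"top_p"`
	TopK         float64 `json:"top_k"`
	Temperature  float64 `json:"temperature"`
}

func main() {
	p := Parameters{MaxNewTokens: 2048, TopP: 0.5, TopK: 50, Temperature: 0.7}
	out, _ := json.Marshal(p) // cannot fail for this struct
	// Prints: {"max_new_tokens":2048,"top_p":0.5,"top_k":50,"temperature":0.7}
	fmt.Println(string(out))
}
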
diff --git a/pkg/ai/huggingface.go b/pkg/ai/huggingface.go
index 68dcc8ccb4..f79f8a5912 100644
--- a/pkg/ai/huggingface.go
+++ b/pkg/ai/huggingface.go
@@ -2,6 +2,7 @@ package ai
 
 import (
 	"context"
 
+	"github.com/hupe1980/go-huggingface"
 	"k8s.io/utils/ptr"
 )
@@ -14,6 +15,7 @@ type HuggingfaceClient struct {
 	client      *huggingface.InferenceClient
 	model       string
 	topP        float32
+	topK        int32
 	temperature float32
 	maxTokens   int
 }
@@ -26,6 +28,7 @@ func (c *HuggingfaceClient) Configure(config IAIConfig) error {
 	c.client = client
 	c.model = config.GetModel()
 	c.topP = config.GetTopP()
+	c.topK = config.GetTopK()
 	c.temperature = config.GetTemperature()
 	if config.GetMaxTokens() > 500 {
 		c.maxTokens = 500
@@ -43,6 +46,7 @@ func (c *HuggingfaceClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
 		Model: c.model,
 		Parameters: huggingface.ConversationalParameters{
 			TopP:        ptr.To[float64](float64(c.topP)),
+			TopK:        ptr.To[int](int(c.topK)),
 			Temperature: ptr.To[float64](float64(c.temperature)),
 			MaxLength:   &c.maxTokens,
 		},
diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go
index 6d51664e46..87fba49650 100644
--- a/pkg/ai/iai.go
+++ b/pkg/ai/iai.go
@@ -72,6 +72,7 @@ type IAIConfig interface {
 	GetTemperature() float32
 	GetProviderRegion() string
 	GetTopP() float32
+	GetTopK() int32
 	GetMaxTokens() int
 	GetProviderId() string
 }
@@ -104,6 +105,7 @@ type AIProvider struct {
 	ProviderRegion string  `mapstructure:"providerregion" yaml:"providerregion,omitempty"`
 	ProviderId     string  `mapstructure:"providerid" yaml:"providerid,omitempty"`
 	TopP           float32 `mapstructure:"topp" yaml:"topp,omitempty"`
+	TopK           int32   `mapstructure:"topk" yaml:"topk,omitempty"`
 	MaxTokens      int     `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
 }
 
@@ -123,6 +125,10 @@ func (p *AIProvider) GetTopP() float32 {
 	return p.TopP
 }
 
+func (p *AIProvider) GetTopK() int32 {
+	return p.TopK
+}
+
 func (p *AIProvider) GetMaxTokens() int {
 	return p.MaxTokens
 }
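
Downstream of the interface change, every backend reads the knob through GetTopK(). A minimal sketch of that contract, with the interface trimmed to the one accessor; effectiveTopK and its range fallback are hypothetical, since the clients in this change pass the value through unchanged:

package main

import "fmt"

// IAIConfig is trimmed to the accessor added in pkg/ai/iai.go; the real
// interface carries the other getters as well.
type IAIConfig interface {
	GetTopK() int32
}

// AIProvider mirrors the struct in the diff: an omitted topk decodes to
// the zero value 0, which falls outside the validated 1-100 range.
type AIProvider struct {
	TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"`
}

func (p *AIProvider) GetTopK() int32 { return p.TopK }

// effectiveTopK is a hypothetical consumer that falls back to the CLI
// default when the stored value is unset or out of range.
func effectiveTopK(cfg IAIConfig) int32 {
	if k := cfg.GetTopK(); k >= 1 && k <= 100 {
		return k
	}
	return 50
}

func main() {
	fmt.Println(effectiveTopK(&AIProvider{}))         // 50 (zero value falls back)
	fmt.Println(effectiveTopK(&AIProvider{TopK: 40})) // 40
}
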