Skip to content

Commit

Permalink
feat: make temperature a flag
Browse files Browse the repository at this point in the history
Signed-off-by: Peter Pan <Peter.Pan@daocloud.io>
  • Loading branch information
panpan0000 committed Sep 18, 2023
1 parent 30f066d commit dcec8f4
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 28 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ To start the API server, follow the instruction in [LocalAI](https://github.com/
To run k8sgpt, run `k8sgpt auth add` with the `localai` backend:

```
k8sgpt auth add --backend localai --model <model_name> --baseurl http://localhost:8080/v1
k8sgpt auth add --backend localai --model <model_name> --baseurl http://localhost:8080/v1 --temperature 0.7
```

Now you can analyze with the `localai` backend:
Expand Down
17 changes: 12 additions & 5 deletions cmd/auth/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ var addCmd = &cobra.Command{
color.Red("Error: Model cannot be empty.")
os.Exit(1)
}
if temperature > 1.0 || temperature < 0.0 {
color.Red("Error: temperature ranges from 0 to 1.")
os.Exit(1)
}

if ai.NeedPassword(backend) && password == "" {
fmt.Printf("Enter %s Key: ", backend)
Expand All @@ -89,11 +93,12 @@ var addCmd = &cobra.Command{

// create new provider object
newProvider := ai.AIProvider{
Name: backend,
Model: model,
Password: password,
BaseURL: baseURL,
Engine: engine,
Name: backend,
Model: model,
Password: password,
BaseURL: baseURL,
Engine: engine,
Temperature: temperature,
}

if providerIndex == -1 {
Expand Down Expand Up @@ -121,6 +126,8 @@ func init() {
addCmd.Flags().StringVarP(&password, "password", "p", "", "Backend AI password")
// add flag for url
addCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "URL AI provider, (e.g `http://localhost:8080/v1`)")
// add flag for temperature
addCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
// add flag for azure open ai engine/deployment name
addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name")
}
11 changes: 6 additions & 5 deletions cmd/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ import (
)

var (
backend string
password string
baseURL string
model string
engine string
backend string
password string
baseURL string
model string
engine string
temperature float32
)

var configAI ai.AIConfiguration
Expand Down
7 changes: 7 additions & 0 deletions cmd/auth/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ var updateCmd = &cobra.Command{
color.Red("Error: backend must be set.")
os.Exit(1)
}
if temperature > 1.0 || temperature < 0.0 {
color.Red("Error: temperature ranges from 0 to 1.")
os.Exit(1)
}

for _, b := range inputBackends {
foundBackend := false
Expand All @@ -74,6 +78,7 @@ var updateCmd = &cobra.Command{
if engine != "" {
configAI.Providers[i].Engine = engine
}
configAI.Providers[i].Temperature = temperature
color.Green("%s updated in the AI backend provider list", b)
}
}
Expand Down Expand Up @@ -101,6 +106,8 @@ func init() {
updateCmd.Flags().StringVarP(&password, "password", "p", "", "Update backend AI password")
// update flag for url
updateCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "Update URL AI provider, (e.g `http://localhost:8080/v1`)")
// add flag for temperature
updateCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
// update flag for azure open ai engine/deployment name
updateCmd.Flags().StringVarP(&engine, "engine", "e", "", "Update Azure AI deployment name")
}
9 changes: 6 additions & 3 deletions pkg/ai/azureopenai.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ import (
)

type AzureAIClient struct {
client *openai.Client
language string
model string
client *openai.Client
language string
model string
temperature float32
}

func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
Expand All @@ -42,6 +43,7 @@ func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
c.language = lang
c.client = client
c.model = config.GetModel()
c.temperature = config.GetTemperature()
return nil
}

Expand All @@ -55,6 +57,7 @@ func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string, prompt
Content: fmt.Sprintf(default_prompt, c.language, prompt),
},
},
Temperature: c.temperature,
})
if err != nil {
return "", err
Expand Down
10 changes: 6 additions & 4 deletions pkg/ai/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ import (
)

type CohereClient struct {
client *cohere.Client
language string
model string
client *cohere.Client
language string
model string
temperature float32
}

func (c *CohereClient) Configure(config IAIConfig, language string) error {
Expand All @@ -52,6 +53,7 @@ func (c *CohereClient) Configure(config IAIConfig, language string) error {
c.language = language
c.client = client
c.model = config.GetModel()
c.temperature = config.GetTemperature()
return nil
}

Expand All @@ -64,7 +66,7 @@ func (c *CohereClient) GetCompletion(ctx context.Context, prompt, promptTmpl str
Model: c.model,
Prompt: fmt.Sprintf(strings.TrimSpace(promptTmpl), c.language, prompt),
MaxTokens: cohere.Uint(2048),
Temperature: cohere.Float64(0.75),
Temperature: cohere.Float64(float64(c.temperature)),
K: cohere.Int(0),
StopSequences: []string{},
ReturnLikelihoods: "NONE",
Expand Down
15 changes: 10 additions & 5 deletions pkg/ai/iai.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type IAIConfig interface {
GetModel() string
GetBaseURL() string
GetEngine() string
GetTemperature() float32
}

func NewClient(provider string) IAI {
Expand All @@ -66,11 +67,12 @@ type AIConfiguration struct {
}

type AIProvider struct {
Name string `mapstructure:"name"`
Model string `mapstructure:"model"`
Password string `mapstructure:"password" yaml:"password,omitempty"`
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
Name string `mapstructure:"name"`
Model string `mapstructure:"model"`
Password string `mapstructure:"password" yaml:"password,omitempty"`
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
}

func (p *AIProvider) GetBaseURL() string {
Expand All @@ -88,6 +90,9 @@ func (p *AIProvider) GetModel() string {
func (p *AIProvider) GetEngine() string {
return p.Engine
}
func (p *AIProvider) GetTemperature() float32 {
return p.Temperature
}

func NeedPassword(backend string) bool {
return backend != "localai"
Expand Down
11 changes: 6 additions & 5 deletions pkg/ai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ import (
)

type OpenAIClient struct {
client *openai.Client
language string
model string
client *openai.Client
language string
model string
temperature float32
}

const (
// OpenAI completion parameters
maxToken = 2048
temperature = 0.7
presencePenalty = 0.0
frequencyPenalty = 0.0
topP = 1.0
Expand All @@ -59,6 +59,7 @@ func (c *OpenAIClient) Configure(config IAIConfig, language string) error {
c.language = language
c.client = client
c.model = config.GetModel()
c.temperature = config.GetTemperature()
return nil
}

Expand All @@ -75,8 +76,8 @@ func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string, promptT
Content: fmt.Sprintf(promptTmpl, c.language, prompt),
},
},
Temperature: c.temperature,
MaxTokens: maxToken,
Temperature: temperature,
PresencePenalty: presencePenalty,
FrequencyPenalty: frequencyPenalty,
TopP: topP,
Expand Down

0 comments on commit dcec8f4

Please sign in to comment.