Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: openAI explicit value for maxToken and temperature #659

Merged
merged 2 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ To start the API server, follow the instruction in [LocalAI](https://github.com/
To run k8sgpt, run `k8sgpt auth add` with the `localai` backend:

```
k8sgpt auth add --backend localai --model <model_name> --baseurl http://localhost:8080/v1
k8sgpt auth add --backend localai --model <model_name> --baseurl http://localhost:8080/v1 --temperature 0.7
```

Now you can analyze with the `localai` backend:
Expand Down
17 changes: 12 additions & 5 deletions cmd/auth/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ var addCmd = &cobra.Command{
color.Red("Error: Model cannot be empty.")
os.Exit(1)
}
if temperature > 1.0 || temperature < 0.0 {
color.Red("Error: temperature ranges from 0 to 1.")
os.Exit(1)
}

if ai.NeedPassword(backend) && password == "" {
fmt.Printf("Enter %s Key: ", backend)
Expand All @@ -89,11 +93,12 @@ var addCmd = &cobra.Command{

// create new provider object
newProvider := ai.AIProvider{
Name: backend,
Model: model,
Password: password,
BaseURL: baseURL,
Engine: engine,
Name: backend,
Model: model,
Password: password,
BaseURL: baseURL,
Engine: engine,
Temperature: temperature,
}

if providerIndex == -1 {
Expand Down Expand Up @@ -121,6 +126,8 @@ func init() {
addCmd.Flags().StringVarP(&password, "password", "p", "", "Backend AI password")
// add flag for url
addCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "URL AI provider, (e.g `http://localhost:8080/v1`)")
// add flag for temperature
addCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
// add flag for azure open ai engine/deployment name
addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name")
}
11 changes: 6 additions & 5 deletions cmd/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ import (
)

var (
backend string
password string
baseURL string
model string
engine string
backend string
password string
baseURL string
model string
engine string
temperature float32
)

var configAI ai.AIConfiguration
Expand Down
7 changes: 7 additions & 0 deletions cmd/auth/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ var updateCmd = &cobra.Command{
color.Red("Error: backend must be set.")
os.Exit(1)
}
if temperature > 1.0 || temperature < 0.0 {
color.Red("Error: temperature ranges from 0 to 1.")
os.Exit(1)
}

for _, b := range inputBackends {
foundBackend := false
Expand All @@ -74,6 +78,7 @@ var updateCmd = &cobra.Command{
if engine != "" {
configAI.Providers[i].Engine = engine
}
configAI.Providers[i].Temperature = temperature
color.Green("%s updated in the AI backend provider list", b)
}
}
Expand Down Expand Up @@ -101,6 +106,8 @@ func init() {
updateCmd.Flags().StringVarP(&password, "password", "p", "", "Update backend AI password")
// update flag for url
updateCmd.Flags().StringVarP(&baseURL, "baseurl", "u", "", "Update URL AI provider, (e.g `http://localhost:8080/v1`)")
// add flag for temperature
updateCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
// update flag for azure open ai engine/deployment name
updateCmd.Flags().StringVarP(&engine, "engine", "e", "", "Update Azure AI deployment name")
}
9 changes: 6 additions & 3 deletions pkg/ai/azureopenai.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ import (
)

type AzureAIClient struct {
client *openai.Client
language string
model string
client *openai.Client
language string
model string
temperature float32
}

func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
Expand All @@ -42,6 +43,7 @@ func (c *AzureAIClient) Configure(config IAIConfig, lang string) error {
c.language = lang
c.client = client
c.model = config.GetModel()
c.temperature = config.GetTemperature()
return nil
}

Expand All @@ -55,6 +57,7 @@ func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string, prompt
Content: fmt.Sprintf(default_prompt, c.language, prompt),
},
},
Temperature: c.temperature,
})
if err != nil {
return "", err
Expand Down
10 changes: 6 additions & 4 deletions pkg/ai/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ import (
)

type CohereClient struct {
client *cohere.Client
language string
model string
client *cohere.Client
language string
model string
temperature float32
}

func (c *CohereClient) Configure(config IAIConfig, language string) error {
Expand All @@ -52,6 +53,7 @@ func (c *CohereClient) Configure(config IAIConfig, language string) error {
c.language = language
c.client = client
c.model = config.GetModel()
c.temperature = config.GetTemperature()
return nil
}

Expand All @@ -64,7 +66,7 @@ func (c *CohereClient) GetCompletion(ctx context.Context, prompt, promptTmpl str
Model: c.model,
Prompt: fmt.Sprintf(strings.TrimSpace(promptTmpl), c.language, prompt),
MaxTokens: cohere.Uint(2048),
Temperature: cohere.Float64(0.75),
Temperature: cohere.Float64(float64(c.temperature)),
K: cohere.Int(0),
StopSequences: []string{},
ReturnLikelihoods: "NONE",
Expand Down
15 changes: 10 additions & 5 deletions pkg/ai/iai.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type IAIConfig interface {
GetModel() string
GetBaseURL() string
GetEngine() string
GetTemperature() float32
}

func NewClient(provider string) IAI {
Expand All @@ -66,11 +67,12 @@ type AIConfiguration struct {
}

type AIProvider struct {
Name string `mapstructure:"name"`
Model string `mapstructure:"model"`
Password string `mapstructure:"password" yaml:"password,omitempty"`
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
Name string `mapstructure:"name"`
Model string `mapstructure:"model"`
Password string `mapstructure:"password" yaml:"password,omitempty"`
BaseURL string `mapstructure:"baseurl" yaml:"baseurl,omitempty"`
Engine string `mapstructure:"engine" yaml:"engine,omitempty"`
Temperature float32 `mapstructure:"temperature" yaml:"temperature,omitempty"`
}

func (p *AIProvider) GetBaseURL() string {
Expand All @@ -88,6 +90,9 @@ func (p *AIProvider) GetModel() string {
func (p *AIProvider) GetEngine() string {
return p.Engine
}
func (p *AIProvider) GetTemperature() float32 {
return p.Temperature
}

func NeedPassword(backend string) bool {
return backend != "localai"
Expand Down
21 changes: 18 additions & 3 deletions pkg/ai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,20 @@ import (
)

type OpenAIClient struct {
client *openai.Client
language string
model string
client *openai.Client
language string
model string
temperature float32
}

const (
// OpenAI completion parameters
maxToken = 2048
presencePenalty = 0.0
frequencyPenalty = 0.0
topP = 1.0
)

func (c *OpenAIClient) Configure(config IAIConfig, language string) error {
token := config.GetPassword()
defaultConfig := openai.DefaultConfig(token)
Expand All @@ -50,6 +59,7 @@ func (c *OpenAIClient) Configure(config IAIConfig, language string) error {
c.language = language
c.client = client
c.model = config.GetModel()
c.temperature = config.GetTemperature()
return nil
}

Expand All @@ -66,6 +76,11 @@ func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string, promptT
Content: fmt.Sprintf(promptTmpl, c.language, prompt),
},
},
Temperature: c.temperature,
MaxTokens: maxToken,
PresencePenalty: presencePenalty,
FrequencyPenalty: frequencyPenalty,
TopP: topP,
})
if err != nil {
return "", err
Expand Down