diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go index f5e37d34b9..6fd6e9b773 100644 --- a/pkg/ai/iai.go +++ b/pkg/ai/iai.go @@ -5,12 +5,17 @@ import ( ) type IAI interface { - Configure(token string, model string, language string) error + Configure(config IAIConfig, language string) error GetCompletion(ctx context.Context, prompt string) (string, error) Parse(ctx context.Context, prompt []string, nocache bool) (string, error) GetName() string } +type IAIConfig interface { + GetPassword() string + GetModel() string +} + func NewClient(provider string) IAI { switch provider { case "openai": @@ -31,3 +36,11 @@ type AIProvider struct { Model string `mapstructure:"model"` Password string `mapstructure:"password"` } + +func (p *AIProvider) GetPassword() string { + return p.Password +} + +func (p *AIProvider) GetModel() string { + return p.Model +} diff --git a/pkg/ai/noopai.go b/pkg/ai/noopai.go index 9f17e569d9..10950d426c 100644 --- a/pkg/ai/noopai.go +++ b/pkg/ai/noopai.go @@ -17,10 +17,11 @@ type NoOpAIClient struct { model string } -func (c *NoOpAIClient) Configure(token string, model string, language string) error { +func (c *NoOpAIClient) Configure(config IAIConfig, language string) error { + token := config.GetPassword() c.language = language c.client = fmt.Sprintf("I am a noop client with the token %s ", token) - c.model = model + c.model = config.GetModel() return nil } diff --git a/pkg/ai/openai.go b/pkg/ai/openai.go index f35bcbad01..cdf5919506 100644 --- a/pkg/ai/openai.go +++ b/pkg/ai/openai.go @@ -15,27 +15,22 @@ import ( "github.com/sashabaranov/go-openai" ) -const ( - default_prompt = "Simplify the following Kubernetes error message and provide a solution in %s: %s" - prompt_a = "Read the following input %s and provide possible scenarios for remediation in %s" - prompt_b = "Considering the following input from the Kubernetes resource %s and the error message %s, provide possible scenarios for remediation in %s" - prompt_c = "Reading the following %s error message and it's accompanying log message %s, how would you simplify this message?" -) - type OpenAIClient struct { client *openai.Client language string model string } -func (c *OpenAIClient) Configure(token string, model string, language string) error { - client := openai.NewClient(token) +func (c *OpenAIClient) Configure(config IAIConfig, language string) error { + token := config.GetPassword() + defaultConfig := openai.DefaultConfig(token) + client := openai.NewClientWithConfig(defaultConfig) if client == nil { return errors.New("error creating OpenAI client") } c.language = language c.client = client - c.model = model + c.model = config.GetModel() return nil } diff --git a/pkg/ai/prompts.go b/pkg/ai/prompts.go new file mode 100644 index 0000000000..e5b40db05b --- /dev/null +++ b/pkg/ai/prompts.go @@ -0,0 +1,8 @@ +package ai + +const ( + default_prompt = "Simplify the following Kubernetes error message and provide a solution in %s: %s" + prompt_a = "Read the following input %s and provide possible scenarios for remediation in %s" + prompt_b = "Considering the following input from the Kubernetes resource %s and the error message %s, provide possible scenarios for remediation in %s" + prompt_c = "Reading the following %s error message and it's accompanying log message %s, how would you simplify this message?" +) diff --git a/pkg/analysis/analysis.go b/pkg/analysis/analysis.go index c8fcadf9aa..2461f759c0 100644 --- a/pkg/analysis/analysis.go +++ b/pkg/analysis/analysis.go @@ -69,7 +69,7 @@ func NewAnalysis(backend string, language string, filters []string, namespace st } aiClient := ai.NewClient(aiProvider.Name) - if err := aiClient.Configure(aiProvider.Password, aiProvider.Model, language); err != nil { + if err := aiClient.Configure(&aiProvider, language); err != nil { color.Red("Error: %v", err) return nil, err }