From c85203bccde094c33ef83eb728aeed2608cbc136 Mon Sep 17 00:00:00 2001 From: Peter Pan Date: Tue, 13 Jun 2023 04:14:12 +0800 Subject: [PATCH] chore: customized prompt template for integration plugins (#403) Signed-off-by: Peter Pan Co-authored-by: Alex Jones --- pkg/ai/azureopenai.go | 6 +++--- pkg/ai/iai.go | 4 ++-- pkg/ai/noopai.go | 6 +++--- pkg/ai/openai.go | 11 +++++++---- pkg/ai/prompts.go | 6 ++++++ pkg/analysis/analysis.go | 9 ++++++++- 6 files changed, 29 insertions(+), 13 deletions(-) diff --git a/pkg/ai/azureopenai.go b/pkg/ai/azureopenai.go index 21593a25f1..aff6708a86 100644 --- a/pkg/ai/azureopenai.go +++ b/pkg/ai/azureopenai.go @@ -36,7 +36,7 @@ func (c *AzureAIClient) Configure(config IAIConfig, lang string) error { return nil } -func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) { +func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) { // Create a completion request resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ Model: c.model, @@ -53,7 +53,7 @@ func (c *AzureAIClient) GetCompletion(ctx context.Context, prompt string) (strin return resp.Choices[0].Message.Content, nil } -func (a *AzureAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache) (string, error) { +func (a *AzureAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) { inputKey := strings.Join(prompt, " ") // Check for cached data cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey) @@ -74,7 +74,7 @@ func (a *AzureAIClient) Parse(ctx context.Context, prompt []string, cache cache. 
} } - response, err := a.GetCompletion(ctx, inputKey) + response, err := a.GetCompletion(ctx, inputKey, promptTmpl) if err != nil { return "", err } diff --git a/pkg/ai/iai.go b/pkg/ai/iai.go index bd4dd0675a..6b3ea1d776 100644 --- a/pkg/ai/iai.go +++ b/pkg/ai/iai.go @@ -36,8 +36,8 @@ var ( type IAI interface { Configure(config IAIConfig, language string) error - GetCompletion(ctx context.Context, prompt string) (string, error) - Parse(ctx context.Context, prompt []string, cache cache.ICache) (string, error) + GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) + Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) GetName() string } diff --git a/pkg/ai/noopai.go b/pkg/ai/noopai.go index ff0877fceb..7183e1758c 100644 --- a/pkg/ai/noopai.go +++ b/pkg/ai/noopai.go @@ -38,20 +38,20 @@ func (c *NoOpAIClient) Configure(config IAIConfig, language string) error { return nil } -func (c *NoOpAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) { +func (c *NoOpAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) { // Create a completion request response := "I am a noop response to the prompt " + prompt return response, nil } -func (a *NoOpAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache) (string, error) { +func (a *NoOpAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) { // parse the text with the AI backend inputKey := strings.Join(prompt, " ") // Check for cached data sEnc := base64.StdEncoding.EncodeToString([]byte(inputKey)) cacheKey := util.GetCacheKey(a.GetName(), a.language, sEnc) - response, err := a.GetCompletion(ctx, inputKey) + response, err := a.GetCompletion(ctx, inputKey, promptTmpl) if err != nil { color.Red("error getting completion: %v", err) return "", err diff --git a/pkg/ai/openai.go b/pkg/ai/openai.go index dca391e270..7d9e6797af 
100644 --- a/pkg/ai/openai.go +++ b/pkg/ai/openai.go @@ -53,14 +53,17 @@ func (c *OpenAIClient) Configure(config IAIConfig, language string) error { return nil } -func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) { +func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) { // Create a completion request + if len(promptTmpl) == 0 { + promptTmpl = PromptMap["default"] + } resp, err := c.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ Model: c.model, Messages: []openai.ChatCompletionMessage{ { Role: "user", - Content: fmt.Sprintf(default_prompt, c.language, prompt), + Content: fmt.Sprintf(promptTmpl, c.language, prompt), }, }, }) @@ -70,7 +73,7 @@ func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string) (string return resp.Choices[0].Message.Content, nil } -func (a *OpenAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache) (string, error) { +func (a *OpenAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) { inputKey := strings.Join(prompt, " ") // Check for cached data cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey) @@ -91,7 +94,7 @@ func (a *OpenAIClient) Parse(ctx context.Context, prompt []string, cache cache.I } } - response, err := a.GetCompletion(ctx, inputKey) + response, err := a.GetCompletion(ctx, inputKey, promptTmpl) if err != nil { return "", err } diff --git a/pkg/ai/prompts.go b/pkg/ai/prompts.go index 59a59886b1..626eab3343 100644 --- a/pkg/ai/prompts.go +++ b/pkg/ai/prompts.go @@ -6,4 +6,10 @@ const ( Error: {Explain error here} Solution: {Step by step solution here} ` + trivy_prompt = "Explain the following trivy scan result and the detail risk or root cause of the CVE ID, then provide a solution. 
Response in %s: %s" ) + +var PromptMap = map[string]string{ + "default": default_prompt, + "VulnerabilityReport": trivy_prompt, // for Trivy integration, the key should match `Result.Kind` in pkg/common/types.go +} diff --git a/pkg/analysis/analysis.go b/pkg/analysis/analysis.go index d70cf3f4ca..aab590f96c 100644 --- a/pkg/analysis/analysis.go +++ b/pkg/analysis/analysis.go @@ -261,7 +261,14 @@ func (a *Analysis) GetAIResults(output string, anonymize bool) error { } texts = append(texts, failure.Text) } - parsedText, err := a.AIClient.Parse(a.Context, texts, a.Cache) + // If the resource `Kind` comes from an "integration plugin", a customized prompt template may be involved. + var promptTemplate string + if prompt, ok := ai.PromptMap[analysis.Kind]; ok { + promptTemplate = prompt + } else { + promptTemplate = ai.PromptMap["default"] + } + parsedText, err := a.AIClient.Parse(a.Context, texts, a.Cache, promptTemplate) if err != nil { // FIXME: can we avoid checking if output is json multiple times? // maybe implement the progress bar better?