From 4e204d15b9567b0acb4228f879f13e1767f0cec0 Mon Sep 17 00:00:00 2001 From: Sertac Ozercan Date: Mon, 8 Apr 2024 17:45:31 +0000 Subject: [PATCH] update Signed-off-by: Sertac Ozercan --- demo/llm/README.md | 7 +-- pkg/gator/verify/runner.go | 2 +- .../pkg/client/drivers/llm/driver.go | 46 ++++++++----------- 3 files changed, 24 insertions(+), 31 deletions(-) diff --git a/demo/llm/README.md b/demo/llm/README.md index fec9c0b5f90..554306c79e4 100644 --- a/demo/llm/README.md +++ b/demo/llm/README.md @@ -1,9 +1,10 @@ > [!WARNING] > This is a demo of a prototype-stage feature and is subject to change. Feedback is welcome! + +> [!NOTE] +> LLM engine can be used in addition to Rego and CEL/K8s Native Validation drivers. It is not a replacement for Rego or CEL. > -> Do not use in production. -> -> LLM engine can be used in addition to Rego and CEL/K8s Native Validation drivers. It is not a replacement for them. +> Depending on your provider, LLM engine may have additional costs. Please refer to your provider's pricing details for more information. ## Pre-requisites diff --git a/pkg/gator/verify/runner.go b/pkg/gator/verify/runner.go index bfa85041dee..7efeba75fc2 100644 --- a/pkg/gator/verify/runner.go +++ b/pkg/gator/verify/runner.go @@ -38,7 +38,7 @@ type Runner struct { includeTrace bool useK8sCEL bool - useLLM bool + useLLM bool } func NewRunner(filesystem fs.FS, newClient func(includeTrace bool, useK8sCEL bool, useLLM bool) (gator.Client, error), opts ...RunnerOptions) (*Runner, error) { diff --git a/vendor/github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers/llm/driver.go b/vendor/github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers/llm/driver.go index 17579cb790a..6d0ffc7d76f 100644 --- a/vendor/github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers/llm/driver.go +++ b/vendor/github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers/llm/driver.go @@ -25,9 +25,11 @@ import ( ) const ( - maxRetries = 10 + maxRetries = 10 + // need minimum of 2023-12-01-preview for JSON mode azureOpenAIAPIVersion = "2024-03-01-preview" azureOpenAIURL = "openai.azure.com" + systemPrompt = "You are a policy engine for Kubernetes designed to output JSON. Input will be a policy definition, Kubernetes AdmissionRequest object, and parameters to apply to the policy if applicable. Output JSON should only have a 'decision' field with a boolean value and a 'reason' field with a string value explaining the decision, only if decision is false. Only output valid JSON." ) var ( @@ -43,20 +45,19 @@ type Driver struct { prompts map[string]string } -type LLMDecision struct { - Decision bool +var _ drivers.Driver = &Driver{} + +type Decision struct { Name string Constraint *unstructured.Unstructured + Decision bool Reason string } -type Decision struct { - Decision bool `json:"decision"` - Reason string `json:"reason"` +type ARGetter interface { + GetAdmissionRequest() *admissionv1.AdmissionRequest } -var _ drivers.Driver = &Driver{} - // Name returns the name of the driver. func (d *Driver) Name() string { return llmSchema.Name @@ -93,7 +94,6 @@ func (d *Driver) AddConstraint(_ context.Context, constraint *unstructured.Unstr if !found { return fmt.Errorf("no promptName with name: %q", promptName) } - return nil } @@ -110,7 +110,7 @@ func (d *Driver) RemoveData(_ context.Context, _ string, _ storage.Path) error { } func (d *Driver) Query(ctx context.Context, _ string, constraints []*unstructured.Unstructured, review interface{}, _ ...drivers.QueryOpt) (*drivers.QueryResponse, error) { - oaic, err := newOAIClients() + llmc, err := newLLMClients() if err != nil { return nil, err } @@ -121,7 +121,7 @@ func (d *Driver) Query(ctx context.Context, _ string, constraints []*unstructure } aRequest := arGetter.GetAdmissionRequest() - var allDecisions []*LLMDecision + var allDecisions []*Decision for _, constraint := range constraints { promptName := strings.ToLower(constraint.GetKind()) prompt, found := d.prompts[promptName] @@ -139,12 +139,12 @@ func (d *Driver) Query(ctx context.Context, _ string, constraints []*unstructure return nil, err } - llmPrompt := fmt.Sprintf("prompt: %s\nadmission request: %s\nparameters: %s", prompt, string(aRequest.Object.Raw), string(params)) + llmPrompt := fmt.Sprintf("policy: %s\nadmission request: %s\nparameters: %s", prompt, string(aRequest.Object.Raw), string(params)) var resp string r := retry.WithMaxRetries(maxRetries, retry.NewExponential(1*time.Second)) if err := retry.Do(ctx, r, func(ctx context.Context) error { - resp, err = oaic.openaiGptChatCompletion(ctx, llmPrompt) + resp, err = llmc.openaiGptChatCompletion(ctx, llmPrompt) requestErr := &openai.APIError{} if errors.As(err, &requestErr) { switch requestErr.HTTPStatusCode { @@ -164,7 +164,7 @@ func (d *Driver) Query(ctx context.Context, _ string, constraints []*unstructure } if !decision.Decision { - llmDecision := &LLMDecision{ + llmDecision := &Decision{ Decision: decision.Decision, Name: constraint.GetName(), Constraint: constraint, @@ -207,11 +207,11 @@ func (d *Driver) GetDescriptionForStat(_ string) (string, error) { panic("implement me") } -type oaiClients struct { +type llmClients struct { openAIClient openai.Client } -func newOAIClients() (oaiClients, error) { +func newLLMClients() (llmClients, error) { var config openai.ClientConfig // default to OpenAI API config = openai.DefaultConfig(*openAIAPIKey) @@ -232,19 +232,19 @@ func newOAIClients() (oaiClients, error) { config.APIVersion = azureOpenAIAPIVersion } - clients := oaiClients{ + clients := llmClients{ openAIClient: *openai.NewClientWithConfig(config), } return clients, nil } -func (c *oaiClients) openaiGptChatCompletion(ctx context.Context, prompt string) (string, error) { +func (c *llmClients) openaiGptChatCompletion(ctx context.Context, prompt string) (string, error) { req := openai.ChatCompletionRequest{ Model: *openAIDeploymentName, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleSystem, - Content: "You are a policy engine for Kubernetes designed to output JSON. Input will be a policy definition, Kubernetes AdmissionRequest object, and parameters to apply to the policy if applicable. Output JSON should only have a 'decision' field with a boolean value and a 'reason' field with a string value explaining the decision, only if decision is false.", + Content: systemPrompt, }, { Role: openai.ChatMessageRoleUser, @@ -270,11 +270,3 @@ func (c *oaiClients) openaiGptChatCompletion(ctx context.Context, prompt string) result := resp.Choices[0].Message.Content return result, nil } - -type ARGetter interface { - GetAdmissionRequest() *admissionv1.AdmissionRequest -} - -type IsAdmissionGetter interface { - IsAdmissionRequest() bool -}