Skip to content

Commit

Permalink
sdk: allow use of StrictBuiltinErrors (#5438)
Browse files Browse the repository at this point in the history
This allows the struct builtin errors functionality to be used in the SDK by passing the value in DecisionOptions & PartialOptions.

Related to #5176

Signed-off-by: Charlie Egan <charlieegan3@users.noreply.github.com>
  • Loading branch information
charlieegan3 authored Dec 6, 2022
1 parent a4ed72e commit 8a0c080
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 53 deletions.
114 changes: 61 additions & 53 deletions sdk/opa.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,18 +241,19 @@ func (opa *OPA) Decision(ctx context.Context, options DecisionOptions) (*Decisio
&record,
func(s state, result *DecisionResult) {
result.Result, record.InputAST, record.Bundles, record.Error = evaluate(ctx, evalArgs{
runtime: s.manager.Info,
printHook: s.manager.PrintHook(),
compiler: s.manager.GetCompiler(),
store: s.manager.Store,
queryCache: s.queryCache,
interQueryCache: s.interQueryBuiltinCache,
ndbcache: ndbc,
txn: record.Txn,
now: record.Timestamp,
path: record.Path,
input: *record.Input,
m: record.Metrics,
runtime: s.manager.Info,
printHook: s.manager.PrintHook(),
compiler: s.manager.GetCompiler(),
store: s.manager.Store,
queryCache: s.queryCache,
interQueryCache: s.interQueryBuiltinCache,
ndbcache: ndbc,
txn: record.Txn,
now: record.Timestamp,
path: record.Path,
input: *record.Input,
m: record.Metrics,
strictBuiltinErrors: options.StrictBuiltinErrors,
})
if record.Error == nil {
record.Results = &result.Result
Expand All @@ -268,10 +269,11 @@ func (opa *OPA) Decision(ctx context.Context, options DecisionOptions) (*Decisio

// DecisionOptions contains parameters for query evaluation.
type DecisionOptions struct {
Now time.Time // specifies wallclock time used for time.now_ns(), decision log timestamp, etc.
Path string // specifies name of policy decision to evaluate (e.g., example/allow)
Input interface{} // specifies value of the input document to evaluate policy with
NDBCache interface{} // specifies the non-deterministic builtins cache to use for evaluation.
Now time.Time // specifies wallclock time used for time.now_ns(), decision log timestamp, etc.
Path string // specifies name of policy decision to evaluate (e.g., example/allow)
Input interface{} // specifies value of the input document to evaluate policy with
NDBCache interface{} // specifies the non-deterministic builtins cache to use for evaluation.
StrictBuiltinErrors bool // treat built-in function errors as fatal
}

// DecisionResult contains the output of query evaluation.
Expand Down Expand Up @@ -351,16 +353,17 @@ func (opa *OPA) Partial(ctx context.Context, options PartialOptions) (*PartialRe
&record,
func(s state, result *DecisionResult) {
pq, record.InputAST, record.Bundles, record.Error = partial(ctx, partialEvalArgs{
runtime: s.manager.Info,
printHook: s.manager.PrintHook(),
compiler: s.manager.GetCompiler(),
store: s.manager.Store,
txn: record.Txn,
now: record.Timestamp,
query: record.Query,
unknowns: options.Unknowns,
input: *record.Input,
m: record.Metrics,
runtime: s.manager.Info,
printHook: s.manager.PrintHook(),
compiler: s.manager.GetCompiler(),
store: s.manager.Store,
txn: record.Txn,
now: record.Timestamp,
query: record.Query,
unknowns: options.Unknowns,
input: *record.Input,
m: record.Metrics,
strictBuiltinErrors: options.StrictBuiltinErrors,
})
if record.Error == nil {
result.Result, record.Error = options.Mapper.MapResults(pq)
Expand Down Expand Up @@ -395,11 +398,12 @@ type PartialQueryMapper interface {

// PartialOptions contains parameters for partial query evaluation.
type PartialOptions struct {
Now time.Time // specifies wallclock time used for time.now_ns(), decision log timestamp, etc.
Input interface{} // specifies value of the input document to evaluate policy with
Query string // specifies the query to be partially evaluated
Unknowns []string // specifies the unknown elements of the policy
Mapper PartialQueryMapper // specifies the mapper to use when processing results
Now time.Time // specifies wallclock time used for time.now_ns(), decision log timestamp, etc.
Input interface{} // specifies value of the input document to evaluate policy with
Query string // specifies the query to be partially evaluated
Unknowns []string // specifies the unknown elements of the policy
Mapper PartialQueryMapper // specifies the mapper to use when processing results
StrictBuiltinErrors bool // treat built-in function errors as fatal
}

type PartialResult struct {
Expand Down Expand Up @@ -437,18 +441,19 @@ func IsUndefinedErr(err error) bool {
}

type evalArgs struct {
runtime *ast.Term
printHook print.Hook
compiler *ast.Compiler
store storage.Store
txn storage.Transaction
queryCache *queryCache
interQueryCache cache.InterQueryCache
now time.Time
path string
input interface{}
ndbcache builtins.NDBCache
m metrics.Metrics
runtime *ast.Term
printHook print.Hook
compiler *ast.Compiler
store storage.Store
txn storage.Transaction
queryCache *queryCache
interQueryCache cache.InterQueryCache
now time.Time
path string
input interface{}
ndbcache builtins.NDBCache
m metrics.Metrics
strictBuiltinErrors bool
}

func evaluate(ctx context.Context, args evalArgs) (interface{}, ast.Value, map[string]server.BundleInfo, error) {
Expand All @@ -472,6 +477,7 @@ func evaluate(ctx context.Context, args evalArgs) (interface{}, ast.Value, map[s
rego.Store(args.store),
rego.Transaction(args.txn),
rego.PrintHook(args.printHook),
rego.StrictBuiltinErrors(args.strictBuiltinErrors),
rego.Runtime(args.runtime)).PrepareForEval(ctx)
if err != nil {
return nil, err
Expand Down Expand Up @@ -506,16 +512,17 @@ func evaluate(ctx context.Context, args evalArgs) (interface{}, ast.Value, map[s
}

type partialEvalArgs struct {
runtime *ast.Term
compiler *ast.Compiler
printHook print.Hook
store storage.Store
txn storage.Transaction
unknowns []string
query string
now time.Time
input interface{}
m metrics.Metrics
runtime *ast.Term
compiler *ast.Compiler
printHook print.Hook
store storage.Store
txn storage.Transaction
unknowns []string
query string
now time.Time
input interface{}
m metrics.Metrics
strictBuiltinErrors bool
}

func partial(ctx context.Context, args partialEvalArgs) (*rego.PartialQueries, ast.Value, map[string]server.BundleInfo, error) {
Expand All @@ -540,6 +547,7 @@ func partial(ctx context.Context, args partialEvalArgs) (*rego.PartialQueries, a
rego.Query(args.query),
rego.Unknowns(args.unknowns),
rego.PrintHook(args.printHook),
rego.StrictBuiltinErrors(args.strictBuiltinErrors),
)

pq, err := re.Partial(ctx)
Expand Down
135 changes: 135 additions & 0 deletions sdk/opa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/logging"
"github.com/open-policy-agent/opa/rego"
"github.com/open-policy-agent/opa/topdown"
"github.com/open-policy-agent/opa/topdown/builtins"

"github.com/fortytw2/leaktest"
Expand Down Expand Up @@ -215,6 +216,68 @@ func TestDecision(t *testing.T) {
}
}

func TestDecisionWithStrictBuiltinErrors(t *testing.T) {

ctx := context.Background()

server := sdktest.MustNewServer(
sdktest.MockBundle("/bundles/bundle.tar.gz", map[string]string{
"main.rego": `
package example
erroring_function(number) = output {
output := number / 0
}
allow {
erroring_function(1)
}`,
}),
)

defer server.Stop()

config := fmt.Sprintf(`{
"services": {
"test": {
"url": %q
}
},
"bundles": {
"test": {
"resource": "/bundles/bundle.tar.gz"
}
}
}`, server.URL())

opa, err := sdk.New(ctx, sdk.Options{
Config: strings.NewReader(config),
})
if err != nil {
t.Fatal(err)
}

defer opa.Stop(ctx)

_, err = opa.Decision(ctx, sdk.DecisionOptions{
StrictBuiltinErrors: true,
Path: "/example/allow",
})
if err == nil {
t.Fatal("expected error but got nil")
}

actual, ok := err.(*topdown.Error)
if !ok || actual.Code != "eval_builtin_error" {
t.Fatalf("expected eval_builtin_error but got %v", actual)
}

expectedMessage := "div: divide by zero"
if actual.Message != expectedMessage {
t.Fatalf("expected %v but got %v", expectedMessage, actual.Message)
}
}

func TestPartial(t *testing.T) {

ctx := context.Background()
Expand Down Expand Up @@ -293,6 +356,78 @@ func TestPartial(t *testing.T) {

}

func TestPartialWithStrictBuiltinErrors(t *testing.T) {

ctx := context.Background()

server := sdktest.MustNewServer(
sdktest.MockBundle("/bundles/bundle.tar.gz", map[string]string{
"main.rego": `
package example
erroring_function(number) = output {
output := number / 0
}
allow {
erroring_function(1)
}`,
}),
)

defer server.Stop()

config := fmt.Sprintf(`{
"services": {
"test": {
"url": %q
}
},
"bundles": {
"test": {
"resource": "/bundles/bundle.tar.gz"
}
},
"decision_logs": {
"console": true
}
}`, server.URL())

testLogger := loggingtest.New()
opa, err := sdk.New(ctx, sdk.Options{
Config: strings.NewReader(config),
ConsoleLogger: testLogger,
})
if err != nil {
t.Fatal(err)
}

defer opa.Stop(ctx)

_, err = opa.Partial(ctx, sdk.PartialOptions{
Input: map[string]interface{}{},
Query: "data.example.allow",
Unknowns: []string{},
Mapper: &sdk.RawMapper{},
Now: time.Unix(0, 1619868194450288000).UTC(),
StrictBuiltinErrors: true,
})
if err == nil {
t.Fatal("expected error but got nil")
}

actual, ok := err.(*topdown.Error)
if !ok || actual.Code != "eval_builtin_error" {
t.Fatalf("expected eval_builtin_error but got %v", actual)
}

expectedMessage := "div: divide by zero"
if actual.Message != expectedMessage {
t.Fatalf("expected %v but got %v", expectedMessage, actual.Message)
}

}

func TestUndefinedError(t *testing.T) {

ctx := context.Background()
Expand Down

0 comments on commit 8a0c080

Please sign in to comment.