diff --git a/ast/builtins.go b/ast/builtins.go index 4b91b1e167..6d6576d901 100644 --- a/ast/builtins.go +++ b/ast/builtins.go @@ -125,12 +125,15 @@ var DefaultBuiltins = [...]*Builtin{ YAMLMarshal, YAMLUnmarshal, + // Object Manipulation + ObjectUnion, + ObjectRemove, + ObjectFilter, + ObjectGet, + // JSON Object Manipulation JSONFilter, - // Other object functions - ObjectGet, - // Tokens JWTDecode, JWTVerifyRS256, @@ -966,6 +969,62 @@ var JSONFilter = &Builtin{ ), } +// ObjectUnion creates a new object that is the asymmetric union of two objects +var ObjectUnion = &Builtin{ + Name: "object.union", + Decl: types.NewFunction( + types.Args( + types.NewObject( + nil, + types.NewDynamicProperty(types.A, types.A), + ), + types.NewObject( + nil, + types.NewDynamicProperty(types.A, types.A), + ), + ), + types.A, + ), +} + +// ObjectRemove Removes specified keys from an object +var ObjectRemove = &Builtin{ + Name: "object.remove", + Decl: types.NewFunction( + types.Args( + types.NewObject( + nil, + types.NewDynamicProperty(types.A, types.A), + ), + types.NewAny( + types.NewArray(nil, types.A), + types.NewSet(types.A), + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), + ), + ), + types.A, + ), +} + +// ObjectFilter filters the object by keeping only specified keys +var ObjectFilter = &Builtin{ + Name: "object.filter", + Decl: types.NewFunction( + types.Args( + types.NewObject( + nil, + types.NewDynamicProperty(types.A, types.A), + ), + types.NewAny( + types.NewArray(nil, types.A), + types.NewSet(types.A), + types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), + ), + ), + types.A, + ), +} + // Base64Encode serializes the input string into base64 encoding. var Base64Encode = &Builtin{ Name: "base64.encode", diff --git a/ast/term.go b/ast/term.go index ec5cabd827..b6137d2c09 100644 --- a/ast/term.go +++ b/ast/term.go @@ -1571,6 +1571,7 @@ type Object interface { Diff(other Object) Object Intersect(other Object) [][3]*Term Merge(other Object) (Object, bool) + MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool) Filter(filter Object) (Object, error) Keys() []*Term } @@ -1808,28 +1809,48 @@ func (obj *object) MarshalJSON() ([]byte, error) { // overlapping keys between obj and other, the values of associated with the keys are merged. Only // objects can be merged with other objects. If the values cannot be merged, the second turn value // will be false. -func (obj object) Merge(other Object) (result Object, ok bool) { - result = NewObject() +func (obj object) Merge(other Object) (Object, bool) { + return obj.MergeWith(other, func(v1, v2 *Term) (*Term, bool) { + obj1, ok1 := v1.Value.(Object) + obj2, ok2 := v2.Value.(Object) + if !ok1 || !ok2 { + return nil, true + } + obj3, ok := obj1.Merge(obj2) + if !ok { + return nil, true + } + return NewTerm(obj3), false + }) +} + +// MergeWith returns a new Object containing the merged keys of obj and other. +// If there are overlapping keys between obj and other, the conflictResolver +// is called. The conflictResolver can return a merged value and a boolean +// indicating if the merge has failed and should stop. +func (obj object) MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool) { + result := NewObject() stop := obj.Until(func(k, v *Term) bool { - if v2 := other.Get(k); v2 == nil { + v2 := other.Get(k) + // The key didn't exist in other, keep the original value + if v2 == nil { result.Insert(k, v) - } else { - obj1, ok1 := v.Value.(Object) - obj2, ok2 := v2.Value.(Object) - if !ok1 || !ok2 { - return true - } - obj3, ok := obj1.Merge(obj2) - if !ok { - return true - } - result.Insert(k, NewTerm(obj3)) + return false } - return false + + // The key exists in both, resolve the conflict if possible + merged, stop := conflictResolver(v, v2) + if !stop { + result.Insert(k, merged) + } + return stop }) + if stop { return nil, false } + + // Copy in any values from other for keys that don't exist in obj other.Foreach(func(k, v *Term) { if v2 := obj.Get(k); v2 == nil { result.Insert(k, v) diff --git a/docs/content/policy-cheatsheet.md b/docs/content/policy-cheatsheet.md index 24b4292684..67490066df 100644 --- a/docs/content/policy-cheatsheet.md +++ b/docs/content/policy-cheatsheet.md @@ -259,6 +259,15 @@ merge_objects(a, b) = c { } ``` +> Note: use the `object.union` builtin function unless custom behavior is required! + +```live:rules/merge_builtin:query:read_only +x := {"a": true, "b": false} +y := {"b": "foo", "c": 4} +z := {"a": true, "b": "foo", "c": 4} +object.union(y, x) == z +``` + ## Tests ```live:tests:module:read_only diff --git a/docs/content/policy-reference.md b/docs/content/policy-reference.md index eccd4dd313..ffb82bc0b0 100644 --- a/docs/content/policy-reference.md +++ b/docs/content/policy-reference.md @@ -69,8 +69,15 @@ complex types. ### Objects | Built-in | Description | | -------- | ----------- | -| `filtered := json.filter(object, paths)` | `filtered` is the remaining data from `object` with only keys specified in `paths` which is an array or set of key paths. Each path may be a JSON string path or an array of path segments. For example: `json.filter({"a": {"b": "x", "c": "y"}}, ["a/b"]` will result in `{"a": {"b": "x"}}`). The `json` string `paths` may reference into array values by using index numbers. For example with the object `{"a": ["x", "y", "z"]}` the path `a/1` references `y` | | `value := object.get(object, key, default)` | `value` is the value stored by the `object` at `key`. If no value is found, `default` is returned. | +| `output := object.remove(object, keys)` | `output` is a new object which is the result of removing the specified `keys` from `object`. `keys` must be either an array, object, or set of keys. | +| `output := object.union(objectA, objectB)` | `output` is a new object which is the result of an asymmetric recursive union of two objects where conflicts are resolved by choosing the key from the right-hand object (`objectB`). For example: `object.union({"a": 1, "b": 2, "c": {"d": 3}}, {"a": 7, "c": {"d": 4, "e": 5}})` will result in `{"a": 7, "b": 2, "c": {"d": 4, "e": 5}}` | +| `filtered := object.filter(object, keys)` | `filtered` is a new object with the remaining data from `object` with only keys specified in `keys` which is an array, object, or set of keys. For example: `object.filter({"a": {"b": "x", "c": "y"}, "d": "z"}, ["a"])` will result in `{"a": {"b": "x", "c": "y"}}`). | +| `filtered := json.filter(object, paths)` | `filtered` is the remaining data from `object` with only keys specified in `paths` which is an array or set of JSON string paths. For example: `json.filter({"a": {"b": "x", "c": "y"}}, ["a/b"])` will result in `{"a": {"b": "x"}}`). | + +> The `json` string `paths` may reference into array values by using index numbers. For example with the object `{"a": ["x", "y", "z"]}` the path `a[1]` references `y` + +> When `keys` are provided as an object only the top level keys on the object will be used, values are ignored. For example: `object.remove({"a": {"b": {"c": 2}}, "x": 123}, {"a": 1}) == {"x": 123}` regardless of the value for key `a` in the keys object, the following `keys` object gives the same result `object.remove({"a": {"b": {"c": 2}}, "x": 123}, {"a": {"b": {"foo": "bar"}}}) == {"x": 123}` ### Strings diff --git a/topdown/array_test.go b/topdown/array_test.go index b984e09edc..094505c4c5 100644 --- a/topdown/array_test.go +++ b/topdown/array_test.go @@ -5,7 +5,6 @@ package topdown import ( - "fmt" "testing" ) @@ -17,8 +16,8 @@ func TestTopDownArray(t *testing.T) { expected interface{} }{ {"concat", []string{`p = x { x = array.concat([1,2], [3,4]) }`}, "[1,2,3,4]"}, - {"concat: err", []string{`p = x { x = array.concat(data.b, [3,4]) }`}, fmt.Errorf("operand 1 must be array")}, - {"concat: err rhs", []string{`p = x { x = array.concat([1,2], data.b) }`}, fmt.Errorf("operand 2 must be array")}, + {"concat: err", []string{`p = x { x = array.concat(data.b, [3,4]) }`}, &Error{Code: TypeErr, Message: "array.concat: operand 1 must be array but got object"}}, + {"concat: err rhs", []string{`p = x { x = array.concat([1,2], data.b) }`}, &Error{Code: TypeErr, Message: "array.concat: operand 2 must be array but got object"}}, {"slice", []string{`p = x { x = array.slice([1,2,3,4,5], 1, 3) }`}, "[2,3]"}, {"slice: empty slice", []string{`p = x { x = array.slice([1,2,3], 0, 0) }`}, "[]"}, diff --git a/topdown/casts_test.go b/topdown/casts_test.go index 31cf536f8b..d70eec0f7b 100644 --- a/topdown/casts_test.go +++ b/topdown/casts_test.go @@ -5,7 +5,6 @@ package topdown import ( - "fmt" "testing" "github.com/open-policy-agent/opa/ast" @@ -20,7 +19,7 @@ func TestToArray(t *testing.T) { panic(err) } - typeErr := fmt.Errorf("type") + typeErr := &Error{Code: TypeErr, Message: "operand 1 must be one of {array, set}"} tests := []struct { note string @@ -41,7 +40,7 @@ func TestToArray(t *testing.T) { func TestToSet(t *testing.T) { - typeErr := fmt.Errorf("type") + typeErr := &Error{Code: TypeErr, Message: "operand 1 must be one of {array, set}"} tests := []struct { note string @@ -61,7 +60,7 @@ func TestToSet(t *testing.T) { } func TestCasts(t *testing.T) { - typeErr := fmt.Errorf("type") + typeErr := &Error{Code: TypeErr} tests := []struct { note string diff --git a/topdown/cidr_test.go b/topdown/cidr_test.go index 4bf6815878..d7158c724f 100644 --- a/topdown/cidr_test.go +++ b/topdown/cidr_test.go @@ -2,7 +2,7 @@ package topdown import ( "context" - "errors" + "net" "testing" "time" @@ -120,7 +120,7 @@ func TestNetCIDRExpand(t *testing.T) { rules: []string{ `p = x { net.cidr_expand("192.168.1.1/33", x) }`, }, - expected: errors.New("invalid CIDR address"), + expected: &net.ParseError{Type: "CIDR address", Text: "192.168.1.1/33"}, }, } diff --git a/topdown/crypto_test.go b/topdown/crypto_test.go index cce93aee02..133dda8990 100644 --- a/topdown/crypto_test.go +++ b/topdown/crypto_test.go @@ -40,7 +40,7 @@ func TestCryptoX509ParseCertificates(t *testing.T) { note: "bad", certs: `YmFkc3RyaW5n`, rule: rule, - expected: fmt.Errorf("asn1: structure error"), + expected: &Error{Code: BuiltinErr, Message: "asn1: structure error"}, }, } diff --git a/topdown/http_test.go b/topdown/http_test.go index 4e8a968aa8..ae99eb0de6 100644 --- a/topdown/http_test.go +++ b/topdown/http_test.go @@ -9,7 +9,6 @@ import ( "crypto/tls" "crypto/x509" "encoding/json" - "errors" "fmt" "io/ioutil" "net/http" @@ -183,8 +182,8 @@ func TestHTTPCustomHeaders(t *testing.T) { } } -// TestHTTPostRequest adds a new person -func TestHTTPostRequest(t *testing.T) { +// TestHTTPPostRequest adds a new person +func TestHTTPPostRequest(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -258,7 +257,7 @@ func TestHTTPostRequest(t *testing.T) { "headers": {"Content-Type": "application/x-www-form-encoded"}, "raw_body": {"bar": "bar"} }`, - expected: errors.New("raw_body must be a string"), + expected: &Error{Code: BuiltinErr, Message: "raw_body must be a string"}, }, } @@ -365,8 +364,8 @@ func TestInvalidKeyError(t *testing.T) { rules []string expected interface{} }{ - {"invalid keys", []string{`p = x { http.send({"method": "get", "url": "http://127.0.0.1:51113", "bad_key": "bad_value"}, x) }`}, fmt.Errorf(`invalid request parameters(s): {"bad_key"}`)}, - {"missing keys", []string{`p = x { http.send({"method": "get"}, x) }`}, fmt.Errorf(`missing required request parameters(s): {"url"}`)}, + {"invalid keys", []string{`p = x { http.send({"method": "get", "url": "http://127.0.0.1:51113", "bad_key": "bad_value"}, x) }`}, &Error{Code: TypeErr, Message: `invalid request parameters(s): {"bad_key"}`}}, + {"missing keys", []string{`p = x { http.send({"method": "get"}, x) }`}, &Error{Code: TypeErr, Message: `missing required request parameters(s): {"url"}`}}, } data := loadSmallTestData() @@ -644,7 +643,7 @@ func TestHTTPSClient(t *testing.T) { t.Run("Negative Test: No Root Ca", func(t *testing.T) { - expectedResult := Error{Code: BuiltinErr, Message: "x509: certificate signed by unknown authority", Location: nil} + expectedResult := &Error{Code: BuiltinErr, Message: "x509: certificate signed by unknown authority", Location: nil} data := loadSmallTestData() rule := []string{fmt.Sprintf( `p = x { http.send({"method": "get", "url": "%s", "tls_client_cert_file": "%s", "tls_client_key_file": "%s"}, x) }`, s.URL, localClientCertFile, localClientKeyFile)} @@ -655,7 +654,7 @@ func TestHTTPSClient(t *testing.T) { t.Run("Negative Test: Wrong Cert/Key Pair", func(t *testing.T) { - expectedResult := Error{Code: BuiltinErr, Message: "tls: private key does not match public key", Location: nil} + expectedResult := &Error{Code: BuiltinErr, Message: "tls: private key does not match public key", Location: nil} data := loadSmallTestData() rule := []string{fmt.Sprintf( `p = x { http.send({"method": "get", "url": "%s", "tls_ca_cert_file": "%s", "tls_client_cert_file": "%s", "tls_client_key_file": "%s"}, x) }`, s.URL, localCaFile, localClientCert2File, localClientKeyFile)} @@ -666,7 +665,7 @@ func TestHTTPSClient(t *testing.T) { t.Run("Negative Test: System Certs do not include local rootCA", func(t *testing.T) { - expectedResult := Error{Code: BuiltinErr, Message: "x509: certificate signed by unknown authority", Location: nil} + expectedResult := &Error{Code: BuiltinErr, Message: "x509: certificate signed by unknown authority", Location: nil} data := loadSmallTestData() rule := []string{fmt.Sprintf( `p = x { http.send({"method": "get", "url": "%s", "tls_client_cert_file": "%s", "tls_client_key_file": "%s", "tls_use_system_certs": true}, x) }`, s.URL, localClientCertFile, localClientKeyFile)} diff --git a/topdown/object.go b/topdown/object.go new file mode 100644 index 0000000000..5fb3ee8470 --- /dev/null +++ b/topdown/object.go @@ -0,0 +1,139 @@ +// Copyright 2020 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package topdown + +import ( + "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/topdown/builtins" + "github.com/open-policy-agent/opa/types" +) + +func builtinObjectUnion(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { + objA, err := builtins.ObjectOperand(operands[0].Value, 1) + if err != nil { + return err + } + + objB, err := builtins.ObjectOperand(operands[1].Value, 2) + if err != nil { + return err + } + + r := mergeWithOverwrite(objA, objB) + + return iter(ast.NewTerm(r)) +} + +func builtinObjectRemove(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { + // Expect an object and an array/set/object of keys + obj, err := builtins.ObjectOperand(operands[0].Value, 1) + if err != nil { + return err + } + + // Build a set of keys to remove + keysToRemove, err := getObjectKeysParam(operands[1].Value) + if err != nil { + return err + } + r := ast.NewObject() + obj.Foreach(func(key *ast.Term, value *ast.Term) { + if !keysToRemove.Contains(key) { + r.Insert(key, value) + } + }) + + return iter(ast.NewTerm(r)) +} + +func builtinObjectFilter(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { + // Expect an object and an array/set/object of keys + obj, err := builtins.ObjectOperand(operands[0].Value, 1) + if err != nil { + return err + } + + // Build a new object from the supplied filter keys + keys, err := getObjectKeysParam(operands[1].Value) + if err != nil { + return err + } + + filterObj := ast.NewObject() + keys.Foreach(func(key *ast.Term) { + filterObj.Insert(key, ast.NullTerm()) + }) + + // Actually do the filtering + r, err := obj.Filter(filterObj) + if err != nil { + return err + } + + return iter(ast.NewTerm(r)) +} + +func builtinObjectGet(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { + object, err := builtins.ObjectOperand(operands[0].Value, 1) + if err != nil { + return err + } + + if ret := object.Get(operands[1]); ret != nil { + return iter(ret) + } + + return iter(operands[2]) +} + +// getObjectKeysParam returns a set of key values +// from a supplied ast array, object, set value +func getObjectKeysParam(arrayOrSet ast.Value) (ast.Set, error) { + keys := ast.NewSet() + + switch v := arrayOrSet.(type) { + case ast.Array: + for _, f := range v { + keys.Add(f) + } + case ast.Set: + _ = v.Iter(func(f *ast.Term) error { + keys.Add(f) + return nil + }) + case ast.Object: + _ = v.Iter(func(k *ast.Term, _ *ast.Term) error { + keys.Add(k) + return nil + }) + default: + return nil, builtins.NewOperandTypeErr(2, arrayOrSet, ast.TypeName(types.Object{}), ast.TypeName(types.S), ast.TypeName(types.Array{})) + } + + return keys, nil +} + +func mergeWithOverwrite(objA, objB ast.Object) ast.Object { + merged, _ := objA.MergeWith(objB, func(v1, v2 *ast.Term) (*ast.Term, bool) { + originalValueObj, ok2 := v1.Value.(ast.Object) + updateValueObj, ok1 := v2.Value.(ast.Object) + if !ok1 || !ok2 { + // If we can't merge, stick with the right-hand value + return v2, false + } + + // Recursively update the existing value + merged := mergeWithOverwrite(originalValueObj, updateValueObj) + return ast.NewTerm(merged), false + }) + return merged +} + +func init() { + RegisterBuiltinFunc(ast.ObjectUnion.Name, builtinObjectUnion) + RegisterBuiltinFunc(ast.ObjectRemove.Name, builtinObjectRemove) + RegisterBuiltinFunc(ast.ObjectFilter.Name, builtinObjectFilter) + RegisterBuiltinFunc(ast.ObjectGet.Name, builtinObjectGet) +} diff --git a/topdown/object_get.go b/topdown/object_get.go deleted file mode 100644 index 7d21ae5e4e..0000000000 --- a/topdown/object_get.go +++ /dev/null @@ -1,23 +0,0 @@ -package topdown - -import ( - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/topdown/builtins" -) - -func builtinObjectGet(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { - object, err := builtins.ObjectOperand(operands[0].Value, 1) - if err != nil { - return err - } - - if ret := object.Get(operands[1]); ret != nil { - return iter(ret) - } - - return iter(operands[2]) -} - -func init() { - RegisterBuiltinFunc(ast.ObjectGet.Name, builtinObjectGet) -} diff --git a/topdown/object_get_test.go b/topdown/object_get_test.go deleted file mode 100644 index 1b4ca15ae0..0000000000 --- a/topdown/object_get_test.go +++ /dev/null @@ -1,67 +0,0 @@ -package topdown - -import ( - "fmt" - "testing" -) - -func TestObjectGet(t *testing.T) { - cases := []struct { - note string - object string - key interface{} - fallback interface{} - expected interface{} - }{ - { - note: "basic case . found", - object: `{"a": "b"}`, - key: `"a"`, - fallback: `"c"`, - expected: `"b"`, - }, - { - note: "basic case . not found", - object: `{"a": "b"}`, - key: `"c"`, - fallback: `"c"`, - expected: `"c"`, - }, - { - - note: "integer key . found", - object: "{1: 2}", - key: "1", - fallback: "3", - expected: "2", - }, - { - note: "integer key . not found", - object: "{1: 2}", - key: "2", - fallback: "3", - expected: "3", - }, - { - note: "complex value . found", - object: `{"a": {"b": "c"}}`, - key: `"a"`, - fallback: "true", - expected: `{"b": "c"}`, - }, - { - note: "complex value . not found", - object: `{"a": {"b": "c"}}`, - key: `"b"`, - fallback: "true", - expected: "true", - }, - } - - for _, tc := range cases { - rules := []string{ - fmt.Sprintf("p = x { x := object.get(%s, %s, %s) }", tc.object, tc.key, tc.fallback), - } - runTopDownTestCase(t, map[string]interface{}{}, tc.note, rules, tc.expected) - } -} diff --git a/topdown/object_test.go b/topdown/object_test.go new file mode 100644 index 0000000000..2b907773b7 --- /dev/null +++ b/topdown/object_test.go @@ -0,0 +1,544 @@ +// Copyright 2020 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package topdown + +import ( + "fmt" + "testing" + + "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/topdown/builtins" +) + +func TestObjectGet(t *testing.T) { + cases := []struct { + note string + object string + key interface{} + fallback interface{} + expected interface{} + }{ + { + note: "basic case . found", + object: `{"a": "b"}`, + key: `"a"`, + fallback: `"c"`, + expected: `"b"`, + }, + { + note: "basic case . not found", + object: `{"a": "b"}`, + key: `"c"`, + fallback: `"c"`, + expected: `"c"`, + }, + { + + note: "integer key . found", + object: "{1: 2}", + key: "1", + fallback: "3", + expected: "2", + }, + { + note: "integer key . not found", + object: "{1: 2}", + key: "2", + fallback: "3", + expected: "3", + }, + { + note: "complex value . found", + object: `{"a": {"b": "c"}}`, + key: `"a"`, + fallback: "true", + expected: `{"b": "c"}`, + }, + { + note: "complex value . not found", + object: `{"a": {"b": "c"}}`, + key: `"b"`, + fallback: "true", + expected: "true", + }, + } + + for _, tc := range cases { + rules := []string{ + fmt.Sprintf("p = x { x := object.get(%s, %s, %s) }", tc.object, tc.key, tc.fallback), + } + runTopDownTestCase(t, map[string]interface{}{}, tc.note, rules, tc.expected) + } +} + +func TestBuiltinObjectUnion(t *testing.T) { + cases := []struct { + note string + objectA string + objectB string + input string + expected interface{} + }{ + { + note: "both empty", + objectA: `{}`, + objectB: `{}`, + expected: `{}`, + }, + { + note: "left empty", + objectA: `{}`, + objectB: `{"a": 1}`, + expected: `{"a": 1}`, + }, + { + note: "right empty", + objectA: `{"a": 1}`, + objectB: `{}`, + expected: `{"a": 1}`, + }, + { + note: "base", + objectA: `{"a": 1}`, + objectB: `{"b": 2}`, + expected: `{"a": 1, "b": 2}`, + }, + { + note: "nested", + objectA: `{"a": {"b": {"c": 1}}}`, + objectB: `{"b": 2}`, + expected: `{"a": {"b": {"c": 1}}, "b": 2}`, + }, + { + note: "nested reverse", + objectA: `{"b": 2}`, + objectB: `{"a": {"b": {"c": 1}}}`, + expected: `{"a": {"b": {"c": 1}}, "b": 2}`, + }, + { + note: "conflict simple", + objectA: `{"a": 1}`, + objectB: `{"a": 2}`, + expected: `{"a": 2}`, + }, + { + note: "conflict nested and extra field", + objectA: `{"a": 1}`, + objectB: `{"a": {"b": {"c": 1}}, "d": 7}`, + expected: `{"a": {"b": {"c": 1}}, "d": 7}`, + }, + { + note: "conflict multiple", + objectA: `{"a": {"b": {"c": 1}}, "e": 1}`, + objectB: `{"a": {"b": "foo", "b1": "bar"}, "d": 7, "e": 17}`, + expected: `{"a": {"b": "foo", "b1": "bar"}, "d": 7, "e": 17}`, + }, + { + note: "error wrong lhs type", + objectA: `[1, 2, 3]`, + objectB: `{"b": 2}`, + expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.union: invalid argument(s)")}, + }, + { + note: "error wrong lhs type input", + objectA: `input.a`, + objectB: `{"b": 2}`, + input: `{"a": [1, 2, 3]}`, + expected: builtins.NewOperandErr(1, "must be object but got array"), + }, + { + note: "error wrong rhs type", + objectA: `{"a": 1}`, + objectB: `[1, 2, 3]`, + expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.union: invalid argument(s)")}, + }, + { + note: "error wrong rhs type input", + objectA: `{"a": 1}`, + objectB: `input.b`, + input: `{"b": [1, 2, 3]}`, + expected: builtins.NewOperandErr(2, "must be object but got array"), + }, + { + note: "error wrong both params", + objectA: `"foo"`, + objectB: `[1, 2, 3]`, + expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.union: invalid argument(s)")}, + }, + } + + for _, tc := range cases { + rules := []string{ + fmt.Sprintf("p = x { x := object.union(%s, %s) }", tc.objectA, tc.objectB), + } + runTopDownTestCaseWithModules(t, map[string]interface{}{}, tc.note, rules, nil, tc.input, tc.expected) + } +} + +func TestBuiltinObjectRemove(t *testing.T) { + cases := []struct { + note string + object string + keys string + input string + expected interface{} + }{ + { + note: "base", + object: `{"a": 1, "b": {"c": 3}}`, + keys: `{"a"}`, + expected: `{"b": {"c": 3}}`, + }, + { + note: "multiple keys set", + object: `{"a": 1, "b": {"c": 3}, "d": 4}`, + keys: `{"d", "b"}`, + expected: `{"a": 1}`, + }, + { + note: "multiple keys array", + object: `{"a": 1, "b": {"c": 3}, "d": 4}`, + keys: `["d", "b"]`, + expected: `{"a": 1}`, + }, + { + note: "multiple keys object", + object: `{"a": 1, "b": {"c": 3}, "d": 4}`, + keys: `{"d": "", "b": 1}`, + expected: `{"a": 1}`, + }, + { + note: "multiple keys object nested", + object: `{"a": {"b": {"c": 2}}, "x": 123}`, + keys: `{"a": {"b": {"foo": "bar"}}}`, + expected: `{"x": 123}`, + }, + { + note: "empty object", + object: `{}`, + keys: `{"a", "b"}`, + expected: `{}`, + }, + { + note: "empty keys set", + object: `{"a": 1, "b": {"c": 3}}`, + keys: `set()`, + expected: `{"a": 1, "b": {"c": 3}}`, + }, + { + note: "empty keys array", + object: `{"a": 1, "b": {"c": 3}}`, + keys: `[]`, + expected: `{"a": 1, "b": {"c": 3}}`, + }, + { + note: "empty keys obj", + object: `{"a": 1, "b": {"c": 3}}`, + keys: `{}`, + expected: `{"a": 1, "b": {"c": 3}}`, + }, + { + note: "key doesnt exist", + object: `{"a": 1, "b": {"c": 3}}`, + keys: `{"z"}`, + expected: `{"a": 1, "b": {"c": 3}}`, + }, + { + note: "error invalid object param type set", + object: `{"a"}`, + keys: `{"a"}`, + expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.remove: invalid argument(s)")}, + }, + { + note: "error invalid object param type bool", + object: `false`, + keys: `{"a"}`, + expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.remove: invalid argument(s)")}, + }, + { + note: "error invalid object param type array input", + object: `input.x`, + keys: `{"a"}`, + input: `{"x": ["a"]}`, + expected: builtins.NewOperandErr(1, "must be object but got array"), + }, + { + note: "error invalid object param type bool input", + object: `input.x`, + keys: `{"a"}`, + input: `{"x": false}`, + expected: builtins.NewOperandErr(1, "must be object but got boolean"), + }, + { + note: "error invalid object param type number input", + object: `input.x`, + keys: `{"a"}`, + input: `{"x": 123}`, + expected: builtins.NewOperandErr(1, "must be object but got number"), + }, + { + note: "error invalid object param type string input", + object: `input.x`, + keys: `{"a"}`, + input: `{"x": "foo"}`, + expected: builtins.NewOperandErr(1, "must be object but got string"), + }, + { + note: "error invalid object param type nil input", + object: `input.x`, + keys: `{"a"}`, + input: `{"x": none}`, + expected: builtins.NewOperandErr(1, "must be object but got var"), + }, + { + note: "error invalid key param type string", + object: `{"a": 1}`, + keys: `"a"`, + expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.remove: invalid argument(s)")}, + }, + { + note: "error invalid key param type boolean", + object: `{"a": 1}`, + keys: `false`, + expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.remove: invalid argument(s)")}, + }, + { + note: "error invalid key param type string input", + object: `{"a": 1}`, + keys: `input.x`, + input: `{"x": "foo"}`, + expected: builtins.NewOperandErr(2, "must be one of {object, string, array} but got string"), + }, + { + note: "error invalid key param type boolean input", + object: `{"a": 1}`, + keys: `input.x`, + input: `{"x": true}`, + expected: builtins.NewOperandErr(2, "must be one of {object, string, array} but got boolean"), + }, + { + note: "error invalid key param type number input", + object: `{"a": 1}`, + keys: `input.x`, + input: `{"x": 22}`, + expected: builtins.NewOperandErr(2, "must be one of {object, string, array} but got number"), + }, + { + note: "error invalid key param type nil input", + object: `{"a": 1}`, + keys: `input.x`, + input: `{"x": none}`, + expected: builtins.NewOperandErr(2, "must be one of {object, string, array} but got var"), + }, + } + + for _, tc := range cases { + rules := []string{ + fmt.Sprintf("p = x { x := object.remove(%s, %s) }", tc.object, tc.keys), + } + runTopDownTestCaseWithModules(t, map[string]interface{}{}, tc.note, rules, nil, tc.input, tc.expected) + } +} + +func TestBuiltinObjectRemoveIdempotent(t *testing.T) { + rule := ` + p { + # "base" should never be mutated + base := {"a": 1, "b": 2, "c": 3} + object.remove(base, {"a"}) == {"b": 2, "c": 3} + object.remove(base, {"b"}) == {"a": 1, "c": 3} + object.remove(base, {"c"}) == {"a": 1, "b": 2} + base == {"a": 1, "b": 2, "c": 3} + } + ` + runTopDownTestCase(t, map[string]interface{}{}, t.Name(), []string{rule}, "true") +} + +func TestBuiltinObjectRemoveNonStringKey(t *testing.T) { + rules := []string{ + `p { x := object.remove({"a": 1, [[7]]: 2}, {[[7]]}); x == {"a": 1} }`, + } + runTopDownTestCase(t, map[string]interface{}{}, "non string root", rules, "true") +} + +func TestBuiltinObjectFilter(t *testing.T) { + cases := []struct { + note string + object string + filters string + input string + expected interface{} + }{ + { + note: "base", + object: `{"a": {"b": {"c": 7, "d": 8}}, "e": 9}`, + filters: `{"a"}`, + expected: `{"a": {"b": {"c": 7, "d": 8}}}`, + }, + { + note: "multiple roots set", + object: `{"a": 1, "b": 2, "c": 3, "e": 9}`, + filters: `{"a", "e"}`, + expected: `{"a": 1, "e": 9}`, + }, + { + note: "multiple roots array", + object: `{"a": 1, "b": 2, "c": 3, "e": 9}`, + filters: `["a", "e"]`, + expected: `{"a": 1, "e": 9}`, + }, + { + note: "multiple roots object", + object: `{"a": 1, "b": 2, "c": 3, "e": 9}`, + filters: `{"a": "foo", "e": ""}`, + expected: `{"a": 1, "e": 9}`, + }, + { + note: "duplicate roots", + object: `{"a": {"b": {"c": 7, "d": 8}}, "e": 9}`, + filters: `{"a", "a"}`, + expected: `{"a": {"b": {"c": 7, "d": 8}}}`, + }, + { + note: "empty roots set", + object: `{"a": 7}`, + filters: `set()`, + expected: `{}`, + }, + { + note: "empty roots array", + object: `{"a": 7}`, + filters: `[]`, + expected: `{}`, + }, + { + note: "empty roots object", + object: `{"a": 7}`, + filters: `{}`, + expected: `{}`, + }, + { + note: "empty object", + object: `{}`, + filters: `{"a"}`, + expected: `{}`, + }, + { + note: "error invalid object param type set", + object: `{"a"}`, + filters: `{"a"}`, + expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.filter: invalid argument(s)")}, + }, + { + note: "error invalid object param type bool", + object: `false`, + filters: `{"a"}`, + expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.filter: invalid argument(s)")}, + }, + { + note: "error invalid object param type array input", + object: `input.x`, + filters: `{"a"}`, + input: `{"x": ["a"]}`, + expected: builtins.NewOperandErr(1, "must be object but got array"), + }, + { + note: "error invalid object param type bool input", + object: `input.x`, + filters: `{"a"}`, + input: `{"x": false}`, + expected: builtins.NewOperandErr(1, "must be object but got boolean"), + }, + { + note: "error invalid object param type number input", + object: `input.x`, + filters: `{"a"}`, + input: `{"x": 123}`, + expected: builtins.NewOperandErr(1, "must be object but got number"), + }, + { + note: "error invalid object param type string input", + object: `input.x`, + filters: `{"a"}`, + input: `{"x": "foo"}`, + expected: builtins.NewOperandErr(1, "must be object but got string"), + }, + { + note: "error invalid object param type nil input", + object: `input.x`, + filters: `{"a"}`, + input: `{"x": none}`, + expected: builtins.NewOperandErr(1, "must be object but got var"), + }, + { + note: "error invalid key param type string", + object: `{"a": 1}`, + filters: `"a"`, + expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.filter: invalid argument(s)")}, + }, + { + note: "error invalid key param type boolean", + object: `{"a": 1}`, + filters: `false`, + expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.filter: invalid argument(s)")}, + }, + { + note: "error invalid key param type string input", + object: `{"a": 1}`, + filters: `input.x`, + input: `{"x": "foo"}`, + expected: builtins.NewOperandErr(2, "must be one of {object, string, array} but got string"), + }, + { + note: "error invalid key param type boolean input", + object: `{"a": 1}`, + filters: `input.x`, + input: `{"x": true}`, + expected: builtins.NewOperandErr(2, "must be one of {object, string, array} but got boolean"), + }, + { + note: "error invalid key param type number input", + object: `{"a": 1}`, + filters: `input.x`, + input: `{"x": 22}`, + expected: builtins.NewOperandErr(2, "must be one of {object, string, array} but got number"), + }, + { + note: "error invalid key param type nil input", + object: `{"a": 1}`, + filters: `input.x`, + input: `{"x": none}`, + expected: builtins.NewOperandErr(2, "must be one of {object, string, array} but got var"), + }, + } + + for _, tc := range cases { + rules := []string{ + fmt.Sprintf("p = x { x := object.filter(%s, %s) }", tc.object, tc.filters), + } + runTopDownTestCaseWithModules(t, map[string]interface{}{}, tc.note, rules, nil, tc.input, tc.expected) + } +} + +func TestBuiltinObjectFilterNonStringKey(t *testing.T) { + rules := []string{ + `p { x := object.filter({"a": 1, [[7]]: 2}, {[[7]]}); x == {[[7]]: 2} }`, + } + runTopDownTestCase(t, map[string]interface{}{}, "non string root", rules, "true") +} + +func TestBuiltinObjectFilterIdempotent(t *testing.T) { + rule := ` + p { + # "base" should never be mutated + base := {"a": 1, "b": 2, "c": 3} + object.filter(base, {"a"}) == {"a": 1} + object.filter(base, {"b"}) == {"b": 2} + object.filter(base, {"c"}) == {"c": 3} + base == {"a": 1, "b": 2, "c": 3} + } + ` + runTopDownTestCase(t, map[string]interface{}{}, t.Name(), []string{rule}, "true") +} diff --git a/topdown/parse_test.go b/topdown/parse_test.go index 525c21cddb..16a23762bd 100644 --- a/topdown/parse_test.go +++ b/topdown/parse_test.go @@ -5,7 +5,6 @@ package topdown import ( - "fmt" "testing" ) @@ -24,6 +23,6 @@ func TestRegoParseModule(t *testing.T) { `p = x { rego.parse_module("x.rego", data.ok, module); x = module["package"].path[1].value }`}, `"foo"`) runTopDownTestCase(t, data, "error", []string{ - `p = x { rego.parse_module("x.rego", data.err, x) }`}, fmt.Errorf("rego_parse_error: no match found")) + `p = x { rego.parse_module("x.rego", data.err, x) }`}, &Error{Code: BuiltinErr, Message: "rego_parse_error: no match found"}) } diff --git a/topdown/tokens_test.go b/topdown/tokens_test.go index fe97ba67bd..acf74a3f0e 100644 --- a/topdown/tokens_test.go +++ b/topdown/tokens_test.go @@ -250,7 +250,7 @@ func TestTopDownJWTEncodeSignPayloadErrors(t *testing.T) { var exp interface{} exp = fmt.Sprintf(`%s`, p.result) if p.err != "" { - exp = errors.New(p.err) + exp = &Error{Code: BuiltinErr, Message: p.err} } tests = append(tests, test{ @@ -353,7 +353,7 @@ func TestTopDownJWTEncodeSignHeaderErrors(t *testing.T) { var exp interface{} exp = fmt.Sprintf(`%s`, p.result) if p.err != "" { - exp = errors.New(p.err) + exp = &Error{Code: BuiltinErr, Message: p.err} } tests = append(tests, test{ @@ -456,7 +456,7 @@ func TestTopDownJWTEncodeSignRaw(t *testing.T) { var exp interface{} exp = fmt.Sprintf(`%s`, p.result) if p.err != "" { - exp = errors.New(p.err) + exp = &Error{Code: BuiltinErr, Message: p.err} } rawTests = append(rawTests, test{ @@ -844,7 +844,7 @@ func TestTopDownJWTBuiltins(t *testing.T) { var exp interface{} exp = fmt.Sprintf(`[%s, %s, "%s"]`, p.header, p.payload, p.signature) if p.err != "" { - exp = errors.New(p.err) + exp = &Error{Code: BuiltinErr, Message: p.err} } tests = append(tests, test{ @@ -1029,7 +1029,7 @@ func TestTopDownJWTVerifyRSA(t *testing.T) { var exp interface{} exp = fmt.Sprintf(`%t`, p.result) if p.err != "" { - exp = errors.New(p.err) + exp = &Error{Code: BuiltinErr, Message: p.err} } tests = append(tests, test{ @@ -1088,7 +1088,7 @@ func TestTopDownJWTVerifyHS256(t *testing.T) { var exp interface{} exp = fmt.Sprintf(`%t`, p.result) if p.err != "" { - exp = errors.New(p.err) + exp = &Error{Code: BuiltinErr, Message: p.err} } tests = append(tests, test{ diff --git a/topdown/topdown_test.go b/topdown/topdown_test.go index a9b2789276..4a6a314e58 100644 --- a/topdown/topdown_test.go +++ b/topdown/topdown_test.go @@ -757,7 +757,7 @@ p[x] = y { data.enum_errors.a[x] = y }`, assertTopDownWithPath(t, compiler, store, "base/virtual: missing input value", []string{"topdown", "u"}, "{}", "{}") assertTopDownWithPath(t, compiler, store, "iterate ground", []string{"topdown", "iterate_ground"}, "{}", `["p", "r"]`) assertTopDownWithPath(t, compiler, store, "base/virtual: conflicts", []string{"topdown.conflicts"}, "{}", `{"k": "foo"}`) - assertTopDownWithPath(t, compiler, store, "enumerate virtual errors", []string{"enum_errors", "caller", "p"}, `{}`, fmt.Errorf("divide by zero")) + assertTopDownWithPath(t, compiler, store, "enumerate virtual errors", []string{"enum_errors", "caller", "p"}, `{}`, &Error{Code: BuiltinErr, Message: "divide by zero"}) } func TestTopDownFix1863(t *testing.T) { @@ -1097,11 +1097,11 @@ func TestTopDownArithmetic(t *testing.T) { {"minus", []string{`p[y] { a[i] = x; y = i - x }`}, "[-1]"}, {"multiply", []string{`p[y] { a[i] = x; y = i * x }`}, "[0,2,6,12]"}, {"divide+round", []string{`p[z] { a[i] = x; y = i / x; round(y, z) }`}, "[0, 1]"}, - {"divide+error", []string{`p[y] { a[i] = x; y = x / i }`}, fmt.Errorf("divide by zero")}, + {"divide+error", []string{`p[y] { a[i] = x; y = x / i }`}, &Error{Code: BuiltinErr, Message: "divide by zero"}}, {"abs", []string{`p = true { abs(-10, x); x = 10 }`}, "true"}, {"remainder", []string{`p = x { x = 7 % 4 }`}, "3"}, - {"remainder+error", []string{`p = x { x = 7 % 0 }`}, fmt.Errorf("modulo by zero")}, - {"remainder+error+floating", []string{`p = x { x = 1.1 % 1 }`}, fmt.Errorf("modulo on floating-point number")}, + {"remainder+error", []string{`p = x { x = 7 % 0 }`}, &Error{Code: BuiltinErr, Message: "modulo by zero"}}, + {"remainder+error+floating", []string{`p = x { x = 1.1 % 1 }`}, &Error{Code: BuiltinErr, Message: "modulo on floating-point number"}}, {"arity 1 ref dest", []string{`p = true { abs(-4, a[3]) }`}, "true"}, {"arity 1 ref dest (2)", []string{`p = true { not abs(-5, a[3]) }`}, "true"}, {"arity 2 ref dest", []string{`p = true { a[2] = 1 + 2 }`}, "true"}, @@ -1127,7 +1127,7 @@ func TestTopDownCasts(t *testing.T) { "[-42.0, 0, 100.1, 0, 1]"}, {"to_number ref dest", []string{`p = true { to_number("3", a[2]) }`}, "true"}, {"to_number ref dest", []string{`p = true { not to_number("-1", a[2]) }`}, "true"}, - {"to_number: bad input", []string{`p { to_number("broken", x) }`}, fmt.Errorf("invalid syntax")}, + {"to_number: bad input", []string{`p { to_number("broken", x) }`}, &Error{Code: BuiltinErr, Message: "invalid syntax"}}, } data := loadSmallTestData() @@ -1269,7 +1269,7 @@ func TestTopDownRegexMatch(t *testing.T) { }{ {"re_match", []string{`p = true { re_match("^[a-z]+\\[[0-9]+\\]$", "foo[1]") }`}, "true"}, {"re_match: undefined", []string{`p = true { re_match("^[a-z]+\\[[0-9]+\\]$", "foo[\"bar\"]") }`}, ""}, - {"re_match: bad pattern err", []string{`p = true { re_match("][", "foo[\"bar\"]") }`}, fmt.Errorf("re_match: error parsing regexp: missing closing ]: `[`")}, + {"re_match: bad pattern err", []string{`p = true { re_match("][", "foo[\"bar\"]") }`}, &Error{Code: BuiltinErr, Message: "re_match: error parsing regexp: missing closing ]: `[`"}}, {"re_match: ref", []string{`p[x] { re_match("^b.*$", d.e[x]) }`}, "[0,1]"}, {"re_match: raw", []string{fmt.Sprintf(`p = true { re_match(%s, "foo[1]") }`, "`^[a-z]+\\[[0-9]+\\]$`")}, "true"}, @@ -1309,7 +1309,7 @@ func TestTopDownGlobsMatch(t *testing.T) { }{ {"regex.globs_match", []string{`p = true { regex.globs_match("a.a.[0-9]+z", ".b.b2359825792*594823z") }`}, "true"}, {"regex.globs_match", []string{`p = true { regex.globs_match("[a-z]+", "[0-9]*") }`}, ""}, - {"regex.globs_match: bad pattern err", []string{`p = true { regex.globs_match("pqrs]", "[a-b]+") }`}, fmt.Errorf("input:pqrs], pos:5, set-close ']' with no preceding '[': the input provided is invalid")}, + {"regex.globs_match: bad pattern err", []string{`p = true { regex.globs_match("pqrs]", "[a-b]+") }`}, &Error{Code: BuiltinErr, Message: "input:pqrs], pos:5, set-close ']' with no preceding '[': the input provided is invalid"}}, {"regex.globs_match: ref", []string{`p[x] { regex.globs_match("b.*", d.e[x]) }`}, "[0,1]"}, {"regex.globs_match: raw", []string{fmt.Sprintf(`p = true { regex.globs_match(%s, "foo\\[1\\]") }`, "`[a-z]+\\[[0-9]+\\]`")}, "true"}, @@ -1354,7 +1354,7 @@ func TestTopDownStrings(t *testing.T) { {"format_int: undefined", []string{`p = true { format_int(15.5, 16, "10000") }`}, ""}, {"format_int: ref dest", []string{`p = true { format_int(3.1, 10, numbers[2]) }`}, "true"}, {"format_int: ref dest (2)", []string{`p = true { not format_int(4.1, 10, numbers[2]) }`}, "true"}, - {"format_int: err: bad base", []string{`p = true { format_int(4.1, 199, x) }`}, fmt.Errorf("operand 2 must be one of {2, 8, 10, 16}")}, + {"format_int: err: bad base", []string{`p = true { format_int(4.1, 199, x) }`}, &Error{Code: TypeErr, Message: "operand 2 must be one of {2, 8, 10, 16}"}}, {"concat", []string{`p = x { concat("/", ["", "foo", "bar", "0", "baz"], x) }`}, `"/foo/bar/0/baz"`}, {"concat: set", []string{`p = x { concat(",", {"1", "2", "3"}, x) }`}, `"1,2,3"`}, {"concat: undefined", []string{`p = true { concat("/", ["a", "b"], "deadbeef") }`}, ""}, @@ -1365,7 +1365,7 @@ func TestTopDownStrings(t *testing.T) { {"substring", []string{`p = x { substring("abcdefgh", 2, 3, x) }`}, `"cde"`}, {"substring: remainder", []string{`p = x { substring("abcdefgh", 2, -1, x) }`}, `"cdefgh"`}, {"substring: too long", []string{`p = x { substring("abcdefgh", 2, 10000, x) }`}, `"cdefgh"`}, - {"substring: offset negative", []string{`p = x { substring("aaa", -1, -1, x) }`}, fmt.Errorf("negative offset")}, + {"substring: offset negative", []string{`p = x { substring("aaa", -1, -1, x) }`}, &Error{Code: BuiltinErr, Message: "negative offset"}}, {"substring: offset too long", []string{`p = x { substring("aaa", 3, -1, x) }`}, `""`}, {"substring: offset too long 2", []string{`p = x { substring("aaa", 4, -1, x) }`}, `""`}, {"contains", []string{`p = true { contains("abcdefgh", "defg") }`}, "true"}, @@ -1415,9 +1415,9 @@ func TestTopDownJSONBuiltins(t *testing.T) { }{ {"marshal", []string{`p = x { json.marshal([{"foo": {1,2,3}}], x) }`}, `"[{\"foo\":[1,2,3]}]"`}, {"unmarshal", []string{`p = x { json.unmarshal("[{\"foo\":[1,2,3]}]", x) }`}, `[{"foo": [1,2,3]}]"`}, - {"unmarshal-non-string", []string{`p = x { json.unmarshal(data.a[0], x) }`}, fmt.Errorf("operand 1 must be string but got number")}, + {"unmarshal-non-string", []string{`p = x { json.unmarshal(data.a[0], x) }`}, &Error{Code: TypeErr, Message: "operand 1 must be string but got number"}}, {"yaml round-trip", []string{`p = y { yaml.marshal([{"foo": {1,2,3}}], x); yaml.unmarshal(x, y) }`}, `[{"foo": [1,2,3]}]`}, - {"yaml unmarshal error", []string{`p { yaml.unmarshal("[1,2,3", _) } `}, fmt.Errorf("yaml: line 1: did not find")}, + {"yaml unmarshal error", []string{`p { yaml.unmarshal("[1,2,3", _) } `}, &Error{Code: BuiltinErr, Message: "yaml: line 1: did not find"}}, } data := loadSmallTestData() @@ -1524,7 +1524,7 @@ func TestTopDownTime(t *testing.T) { p = [year, month, day] { [year, month, day] := time.date(1582977600*1000*1000*1000) }`}, "[2020, 2, 29]") runTopDownTestCase(t, data, "date too big", []string{` - p = [year, month, day] { [year, month, day] := time.date(1582977600*1000*1000*1000*1000) }`}, fmt.Errorf("timestamp too big")) + p = [year, month, day] { [year, month, day] := time.date(1582977600*1000*1000*1000*1000) }`}, &Error{Code: BuiltinErr, Message: "timestamp too big"}) runTopDownTestCase(t, data, "clock", []string{` p = [hour, minute, second] { [hour, minute, second] := time.clock(1517832000*1000*1000*1000) }`}, "[12, 0, 0]") @@ -1536,7 +1536,7 @@ func TestTopDownTime(t *testing.T) { p = [hour, minute, second] { [hour, minute, second] := time.clock(1582977600*1000*1000*1000) }`}, "[12, 0, 0]") runTopDownTestCase(t, data, "clock too big", []string{` - p = [hour, minute, second] { [hour, minute, second] := time.clock(1582977600*1000*1000*1000*1000) }`}, fmt.Errorf("timestamp too big")) + p = [hour, minute, second] { [hour, minute, second] := time.clock(1582977600*1000*1000*1000*1000) }`}, &Error{Code: BuiltinErr, Message: "timestamp too big"}) for i, day := range []string{"Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"} { ts := 1517832000*1000*1000*1000 + i*24*int(time.Hour) @@ -1545,7 +1545,7 @@ func TestTopDownTime(t *testing.T) { } runTopDownTestCase(t, data, "weekday too big", []string{` - p = weekday { weekday := time.weekday(1582977600*1000*1000*1000*1000) }`}, fmt.Errorf("timestamp too big")) + p = weekday { weekday := time.weekday(1582977600*1000*1000*1000*1000) }`}, &Error{Code: BuiltinErr, Message: "timestamp too big"}) } func TestTopDownWalkBuiltin(t *testing.T) { @@ -2879,7 +2879,11 @@ func runTopDownTestCaseWithModules(t *testing.T, data map[string]interface{}, no compiler, err := compileRules(imports, rules, modules) if err != nil { - t.Errorf("%v: Compiler error: %v", note, err) + if _, ok := expected.(error); ok { + assertError(t, expected, err) + } else { + t.Errorf("%v: Compiler error: %v", note, err) + } return } @@ -2941,27 +2945,9 @@ func assertTopDownWithPath(t *testing.T, compiler *ast.Compiler, store storage.S testutil.Subtest(t, note, func(t *testing.T) { switch e := expected.(type) { - case Error: - result, err := query.Run(ctx) - if err == nil { - t.Errorf("Expected error but got: %v", result) - return - } - errString := err.Error() - if !strings.Contains(errString, e.Code) || !strings.Contains(errString, e.Message) { - t.Errorf("Expected error %v but got: %v", e, err) - } - case error: - result, err := query.Run(ctx) - if err == nil { - t.Errorf("Expected error but got: %v", result) - return - } - - if !strings.Contains(err.Error(), e.Error()) { - t.Errorf("Expected error %v but got: %v", e, err) - } - + case Error, error: + _, err := query.Run(ctx) + assertError(t, expected, err) case string: qrs, err := query.Run(ctx) @@ -3197,3 +3183,43 @@ func dump(note string, modules map[string]*ast.Module, data interface{}, docpath } } + +func assertError(t *testing.T, expected interface{}, actual error) { + t.Helper() + if actual == nil { + t.Errorf("Expected error but got: %v", actual) + return + } + + errString := actual.Error() + + if reflect.TypeOf(expected) != reflect.TypeOf(actual) { + t.Errorf("Expected error of type '%T', got '%T'", expected, actual) + } + + switch e := expected.(type) { + case Error: + assertErrorContains(t, errString, e.Code) + assertErrorContains(t, errString, e.Message) + case *Error: + assertErrorContains(t, errString, e.Code) + assertErrorContains(t, errString, e.Message) + case *ast.Error: + assertErrorContains(t, errString, e.Code) + assertErrorContains(t, errString, e.Message) + case ast.Errors: + for _, astErr := range e { + assertErrorContains(t, errString, astErr.Code) + assertErrorContains(t, errString, astErr.Message) + } + case error: + assertErrorContains(t, errString, e.Error()) + } +} + +func assertErrorContains(t *testing.T, actualErrMsg string, expected string) { + t.Helper() + if !strings.Contains(actualErrMsg, expected) { + t.Errorf("Expected error '%v' but got: '%v'", expected, actualErrMsg) + } +}