diff --git a/ast/builtins.go b/ast/builtins.go
index 4b91b1e167..6d6576d901 100644
--- a/ast/builtins.go
+++ b/ast/builtins.go
@@ -125,12 +125,15 @@ var DefaultBuiltins = [...]*Builtin{
YAMLMarshal,
YAMLUnmarshal,
+ // Object Manipulation
+ ObjectUnion,
+ ObjectRemove,
+ ObjectFilter,
+ ObjectGet,
+
// JSON Object Manipulation
JSONFilter,
- // Other object functions
- ObjectGet,
-
// Tokens
JWTDecode,
JWTVerifyRS256,
@@ -966,6 +969,62 @@ var JSONFilter = &Builtin{
),
}
+// ObjectUnion creates a new object that is the asymmetric union of two objects
+var ObjectUnion = &Builtin{
+ Name: "object.union",
+ Decl: types.NewFunction(
+ types.Args(
+ types.NewObject(
+ nil,
+ types.NewDynamicProperty(types.A, types.A),
+ ),
+ types.NewObject(
+ nil,
+ types.NewDynamicProperty(types.A, types.A),
+ ),
+ ),
+ types.A,
+ ),
+}
+
+// ObjectRemove Removes specified keys from an object
+var ObjectRemove = &Builtin{
+ Name: "object.remove",
+ Decl: types.NewFunction(
+ types.Args(
+ types.NewObject(
+ nil,
+ types.NewDynamicProperty(types.A, types.A),
+ ),
+ types.NewAny(
+ types.NewArray(nil, types.A),
+ types.NewSet(types.A),
+ types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)),
+ ),
+ ),
+ types.A,
+ ),
+}
+
+// ObjectFilter filters the object by keeping only specified keys
+var ObjectFilter = &Builtin{
+ Name: "object.filter",
+ Decl: types.NewFunction(
+ types.Args(
+ types.NewObject(
+ nil,
+ types.NewDynamicProperty(types.A, types.A),
+ ),
+ types.NewAny(
+ types.NewArray(nil, types.A),
+ types.NewSet(types.A),
+ types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)),
+ ),
+ ),
+ types.A,
+ ),
+}
+
// Base64Encode serializes the input string into base64 encoding.
var Base64Encode = &Builtin{
Name: "base64.encode",
diff --git a/ast/term.go b/ast/term.go
index ec5cabd827..b6137d2c09 100644
--- a/ast/term.go
+++ b/ast/term.go
@@ -1571,6 +1571,7 @@ type Object interface {
Diff(other Object) Object
Intersect(other Object) [][3]*Term
Merge(other Object) (Object, bool)
+ MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool)
Filter(filter Object) (Object, error)
Keys() []*Term
}
@@ -1808,28 +1809,48 @@ func (obj *object) MarshalJSON() ([]byte, error) {
// overlapping keys between obj and other, the values of associated with the keys are merged. Only
// objects can be merged with other objects. If the values cannot be merged, the second turn value
// will be false.
-func (obj object) Merge(other Object) (result Object, ok bool) {
- result = NewObject()
+func (obj object) Merge(other Object) (Object, bool) {
+ return obj.MergeWith(other, func(v1, v2 *Term) (*Term, bool) {
+ obj1, ok1 := v1.Value.(Object)
+ obj2, ok2 := v2.Value.(Object)
+ if !ok1 || !ok2 {
+ return nil, true
+ }
+ obj3, ok := obj1.Merge(obj2)
+ if !ok {
+ return nil, true
+ }
+ return NewTerm(obj3), false
+ })
+}
+
+// MergeWith returns a new Object containing the merged keys of obj and other.
+// If there are overlapping keys between obj and other, the conflictResolver
+// is called. The conflictResolver can return a merged value and a boolean
+// indicating if the merge has failed and should stop.
+func (obj object) MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool) {
+ result := NewObject()
stop := obj.Until(func(k, v *Term) bool {
- if v2 := other.Get(k); v2 == nil {
+ v2 := other.Get(k)
+ // The key didn't exist in other, keep the original value
+ if v2 == nil {
result.Insert(k, v)
- } else {
- obj1, ok1 := v.Value.(Object)
- obj2, ok2 := v2.Value.(Object)
- if !ok1 || !ok2 {
- return true
- }
- obj3, ok := obj1.Merge(obj2)
- if !ok {
- return true
- }
- result.Insert(k, NewTerm(obj3))
+ return false
}
- return false
+
+ // The key exists in both, resolve the conflict if possible
+ merged, stop := conflictResolver(v, v2)
+ if !stop {
+ result.Insert(k, merged)
+ }
+ return stop
})
+
if stop {
return nil, false
}
+
+ // Copy in any values from other for keys that don't exist in obj
other.Foreach(func(k, v *Term) {
if v2 := obj.Get(k); v2 == nil {
result.Insert(k, v)
diff --git a/docs/content/policy-cheatsheet.md b/docs/content/policy-cheatsheet.md
index 24b4292684..67490066df 100644
--- a/docs/content/policy-cheatsheet.md
+++ b/docs/content/policy-cheatsheet.md
@@ -259,6 +259,15 @@ merge_objects(a, b) = c {
}
```
+> Note: use the `object.union` builtin function unless custom behavior is required!
+
+```live:rules/merge_builtin:query:read_only
+x := {"a": true, "b": false}
+y := {"b": "foo", "c": 4}
+z := {"a": true, "b": "foo", "c": 4}
+object.union(y, x) == z
+```
+
## Tests
```live:tests:module:read_only
diff --git a/docs/content/policy-reference.md b/docs/content/policy-reference.md
index eccd4dd313..ffb82bc0b0 100644
--- a/docs/content/policy-reference.md
+++ b/docs/content/policy-reference.md
@@ -69,8 +69,15 @@ complex types.
### Objects
| Built-in | Description |
| -------- | ----------- |
-| `filtered := json.filter(object, paths)` | `filtered` is the remaining data from `object` with only keys specified in `paths` which is an array or set of key paths. Each path may be a JSON string path or an array of path segments. For example: `json.filter({"a": {"b": "x", "c": "y"}}, ["a/b"]` will result in `{"a": {"b": "x"}}`). The `json` string `paths` may reference into array values by using index numbers. For example with the object `{"a": ["x", "y", "z"]}` the path `a/1` references `y` |
| `value := object.get(object, key, default)` | `value` is the value stored by the `object` at `key`. If no value is found, `default` is returned. |
+| `output := object.remove(object, keys)` | `output` is a new object which is the result of removing the specified `keys` from `object`. `keys` must be either an array, object, or set of keys. |
+| `output := object.union(objectA, objectB)` | `output` is a new object which is the result of an asymmetric recursive union of two objects where conflicts are resolved by choosing the key from the right-hand object (`objectB`). For example: `object.union({"a": 1, "b": 2, "c": {"d": 3}}, {"a": 7, "c": {"d": 4, "e": 5}})` will result in `{"a": 7, "b": 2, "c": {"d": 4, "e": 5}}` |
+| `filtered := object.filter(object, keys)` | `filtered` is a new object with the remaining data from `object` with only keys specified in `keys` which is an array, object, or set of keys. For example: `object.filter({"a": {"b": "x", "c": "y"}, "d": "z"}, ["a"])` will result in `{"a": {"b": "x", "c": "y"}}`). |
+| `filtered := json.filter(object, paths)` | `filtered` is the remaining data from `object` with only keys specified in `paths` which is an array or set of JSON string paths. For example: `json.filter({"a": {"b": "x", "c": "y"}}, ["a/b"])` will result in `{"a": {"b": "x"}}`). |
+
+> The `json` string `paths` may reference into array values by using index numbers. For example with the object `{"a": ["x", "y", "z"]}` the path `a[1]` references `y`
+
+> When `keys` are provided as an object only the top level keys on the object will be used, values are ignored. For example: `object.remove({"a": {"b": {"c": 2}}, "x": 123}, {"a": 1}) == {"x": 123}` regardless of the value for key `a` in the keys object, the following `keys` object gives the same result `object.remove({"a": {"b": {"c": 2}}, "x": 123}, {"a": {"b": {"foo": "bar"}}}) == {"x": 123}`
### Strings
diff --git a/topdown/array_test.go b/topdown/array_test.go
index b984e09edc..094505c4c5 100644
--- a/topdown/array_test.go
+++ b/topdown/array_test.go
@@ -5,7 +5,6 @@
package topdown
import (
- "fmt"
"testing"
)
@@ -17,8 +16,8 @@ func TestTopDownArray(t *testing.T) {
expected interface{}
}{
{"concat", []string{`p = x { x = array.concat([1,2], [3,4]) }`}, "[1,2,3,4]"},
- {"concat: err", []string{`p = x { x = array.concat(data.b, [3,4]) }`}, fmt.Errorf("operand 1 must be array")},
- {"concat: err rhs", []string{`p = x { x = array.concat([1,2], data.b) }`}, fmt.Errorf("operand 2 must be array")},
+ {"concat: err", []string{`p = x { x = array.concat(data.b, [3,4]) }`}, &Error{Code: TypeErr, Message: "array.concat: operand 1 must be array but got object"}},
+ {"concat: err rhs", []string{`p = x { x = array.concat([1,2], data.b) }`}, &Error{Code: TypeErr, Message: "array.concat: operand 2 must be array but got object"}},
{"slice", []string{`p = x { x = array.slice([1,2,3,4,5], 1, 3) }`}, "[2,3]"},
{"slice: empty slice", []string{`p = x { x = array.slice([1,2,3], 0, 0) }`}, "[]"},
diff --git a/topdown/casts_test.go b/topdown/casts_test.go
index 31cf536f8b..d70eec0f7b 100644
--- a/topdown/casts_test.go
+++ b/topdown/casts_test.go
@@ -5,7 +5,6 @@
package topdown
import (
- "fmt"
"testing"
"github.com/open-policy-agent/opa/ast"
@@ -20,7 +19,7 @@ func TestToArray(t *testing.T) {
panic(err)
}
- typeErr := fmt.Errorf("type")
+ typeErr := &Error{Code: TypeErr, Message: "operand 1 must be one of {array, set}"}
tests := []struct {
note string
@@ -41,7 +40,7 @@ func TestToArray(t *testing.T) {
func TestToSet(t *testing.T) {
- typeErr := fmt.Errorf("type")
+ typeErr := &Error{Code: TypeErr, Message: "operand 1 must be one of {array, set}"}
tests := []struct {
note string
@@ -61,7 +60,7 @@ func TestToSet(t *testing.T) {
}
func TestCasts(t *testing.T) {
- typeErr := fmt.Errorf("type")
+ typeErr := &Error{Code: TypeErr}
tests := []struct {
note string
diff --git a/topdown/cidr_test.go b/topdown/cidr_test.go
index 4bf6815878..d7158c724f 100644
--- a/topdown/cidr_test.go
+++ b/topdown/cidr_test.go
@@ -2,7 +2,7 @@ package topdown
import (
"context"
- "errors"
+ "net"
"testing"
"time"
@@ -120,7 +120,7 @@ func TestNetCIDRExpand(t *testing.T) {
rules: []string{
`p = x { net.cidr_expand("192.168.1.1/33", x) }`,
},
- expected: errors.New("invalid CIDR address"),
+ expected: &net.ParseError{Type: "CIDR address", Text: "192.168.1.1/33"},
},
}
diff --git a/topdown/crypto_test.go b/topdown/crypto_test.go
index cce93aee02..133dda8990 100644
--- a/topdown/crypto_test.go
+++ b/topdown/crypto_test.go
@@ -40,7 +40,7 @@ func TestCryptoX509ParseCertificates(t *testing.T) {
note: "bad",
certs: `YmFkc3RyaW5n`,
rule: rule,
- expected: fmt.Errorf("asn1: structure error"),
+ expected: &Error{Code: BuiltinErr, Message: "asn1: structure error"},
},
}
diff --git a/topdown/http_test.go b/topdown/http_test.go
index 4e8a968aa8..ae99eb0de6 100644
--- a/topdown/http_test.go
+++ b/topdown/http_test.go
@@ -9,7 +9,6 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/json"
- "errors"
"fmt"
"io/ioutil"
"net/http"
@@ -183,8 +182,8 @@ func TestHTTPCustomHeaders(t *testing.T) {
}
}
-// TestHTTPostRequest adds a new person
-func TestHTTPostRequest(t *testing.T) {
+// TestHTTPPostRequest adds a new person
+func TestHTTPPostRequest(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -258,7 +257,7 @@ func TestHTTPostRequest(t *testing.T) {
"headers": {"Content-Type": "application/x-www-form-encoded"},
"raw_body": {"bar": "bar"}
}`,
- expected: errors.New("raw_body must be a string"),
+ expected: &Error{Code: BuiltinErr, Message: "raw_body must be a string"},
},
}
@@ -365,8 +364,8 @@ func TestInvalidKeyError(t *testing.T) {
rules []string
expected interface{}
}{
- {"invalid keys", []string{`p = x { http.send({"method": "get", "url": "http://127.0.0.1:51113", "bad_key": "bad_value"}, x) }`}, fmt.Errorf(`invalid request parameters(s): {"bad_key"}`)},
- {"missing keys", []string{`p = x { http.send({"method": "get"}, x) }`}, fmt.Errorf(`missing required request parameters(s): {"url"}`)},
+ {"invalid keys", []string{`p = x { http.send({"method": "get", "url": "http://127.0.0.1:51113", "bad_key": "bad_value"}, x) }`}, &Error{Code: TypeErr, Message: `invalid request parameters(s): {"bad_key"}`}},
+ {"missing keys", []string{`p = x { http.send({"method": "get"}, x) }`}, &Error{Code: TypeErr, Message: `missing required request parameters(s): {"url"}`}},
}
data := loadSmallTestData()
@@ -644,7 +643,7 @@ func TestHTTPSClient(t *testing.T) {
t.Run("Negative Test: No Root Ca", func(t *testing.T) {
- expectedResult := Error{Code: BuiltinErr, Message: "x509: certificate signed by unknown authority", Location: nil}
+ expectedResult := &Error{Code: BuiltinErr, Message: "x509: certificate signed by unknown authority", Location: nil}
data := loadSmallTestData()
rule := []string{fmt.Sprintf(
`p = x { http.send({"method": "get", "url": "%s", "tls_client_cert_file": "%s", "tls_client_key_file": "%s"}, x) }`, s.URL, localClientCertFile, localClientKeyFile)}
@@ -655,7 +654,7 @@ func TestHTTPSClient(t *testing.T) {
t.Run("Negative Test: Wrong Cert/Key Pair", func(t *testing.T) {
- expectedResult := Error{Code: BuiltinErr, Message: "tls: private key does not match public key", Location: nil}
+ expectedResult := &Error{Code: BuiltinErr, Message: "tls: private key does not match public key", Location: nil}
data := loadSmallTestData()
rule := []string{fmt.Sprintf(
`p = x { http.send({"method": "get", "url": "%s", "tls_ca_cert_file": "%s", "tls_client_cert_file": "%s", "tls_client_key_file": "%s"}, x) }`, s.URL, localCaFile, localClientCert2File, localClientKeyFile)}
@@ -666,7 +665,7 @@ func TestHTTPSClient(t *testing.T) {
t.Run("Negative Test: System Certs do not include local rootCA", func(t *testing.T) {
- expectedResult := Error{Code: BuiltinErr, Message: "x509: certificate signed by unknown authority", Location: nil}
+ expectedResult := &Error{Code: BuiltinErr, Message: "x509: certificate signed by unknown authority", Location: nil}
data := loadSmallTestData()
rule := []string{fmt.Sprintf(
`p = x { http.send({"method": "get", "url": "%s", "tls_client_cert_file": "%s", "tls_client_key_file": "%s", "tls_use_system_certs": true}, x) }`, s.URL, localClientCertFile, localClientKeyFile)}
diff --git a/topdown/object.go b/topdown/object.go
new file mode 100644
index 0000000000..5fb3ee8470
--- /dev/null
+++ b/topdown/object.go
@@ -0,0 +1,139 @@
+// Copyright 2020 The OPA Authors. All rights reserved.
+// Use of this source code is governed by an Apache2
+// license that can be found in the LICENSE file.
+
+package topdown
+
+import (
+ "github.com/open-policy-agent/opa/ast"
+ "github.com/open-policy-agent/opa/topdown/builtins"
+ "github.com/open-policy-agent/opa/types"
+)
+
+func builtinObjectUnion(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
+ objA, err := builtins.ObjectOperand(operands[0].Value, 1)
+ if err != nil {
+ return err
+ }
+
+ objB, err := builtins.ObjectOperand(operands[1].Value, 2)
+ if err != nil {
+ return err
+ }
+
+ r := mergeWithOverwrite(objA, objB)
+
+ return iter(ast.NewTerm(r))
+}
+
+func builtinObjectRemove(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
+ // Expect an object and an array/set/object of keys
+ obj, err := builtins.ObjectOperand(operands[0].Value, 1)
+ if err != nil {
+ return err
+ }
+
+ // Build a set of keys to remove
+ keysToRemove, err := getObjectKeysParam(operands[1].Value)
+ if err != nil {
+ return err
+ }
+ r := ast.NewObject()
+ obj.Foreach(func(key *ast.Term, value *ast.Term) {
+ if !keysToRemove.Contains(key) {
+ r.Insert(key, value)
+ }
+ })
+
+ return iter(ast.NewTerm(r))
+}
+
+func builtinObjectFilter(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
+ // Expect an object and an array/set/object of keys
+ obj, err := builtins.ObjectOperand(operands[0].Value, 1)
+ if err != nil {
+ return err
+ }
+
+ // Build a new object from the supplied filter keys
+ keys, err := getObjectKeysParam(operands[1].Value)
+ if err != nil {
+ return err
+ }
+
+ filterObj := ast.NewObject()
+ keys.Foreach(func(key *ast.Term) {
+ filterObj.Insert(key, ast.NullTerm())
+ })
+
+ // Actually do the filtering
+ r, err := obj.Filter(filterObj)
+ if err != nil {
+ return err
+ }
+
+ return iter(ast.NewTerm(r))
+}
+
+func builtinObjectGet(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
+ object, err := builtins.ObjectOperand(operands[0].Value, 1)
+ if err != nil {
+ return err
+ }
+
+ if ret := object.Get(operands[1]); ret != nil {
+ return iter(ret)
+ }
+
+ return iter(operands[2])
+}
+
+// getObjectKeysParam returns a set of key values
+// from a supplied ast array, object, set value
+func getObjectKeysParam(arrayOrSet ast.Value) (ast.Set, error) {
+ keys := ast.NewSet()
+
+ switch v := arrayOrSet.(type) {
+ case ast.Array:
+ for _, f := range v {
+ keys.Add(f)
+ }
+ case ast.Set:
+ _ = v.Iter(func(f *ast.Term) error {
+ keys.Add(f)
+ return nil
+ })
+ case ast.Object:
+ _ = v.Iter(func(k *ast.Term, _ *ast.Term) error {
+ keys.Add(k)
+ return nil
+ })
+ default:
+ return nil, builtins.NewOperandTypeErr(2, arrayOrSet, ast.TypeName(types.Object{}), ast.TypeName(types.S), ast.TypeName(types.Array{}))
+ }
+
+ return keys, nil
+}
+
+func mergeWithOverwrite(objA, objB ast.Object) ast.Object {
+ merged, _ := objA.MergeWith(objB, func(v1, v2 *ast.Term) (*ast.Term, bool) {
+ originalValueObj, ok2 := v1.Value.(ast.Object)
+ updateValueObj, ok1 := v2.Value.(ast.Object)
+ if !ok1 || !ok2 {
+ // If we can't merge, stick with the right-hand value
+ return v2, false
+ }
+
+ // Recursively update the existing value
+ merged := mergeWithOverwrite(originalValueObj, updateValueObj)
+ return ast.NewTerm(merged), false
+ })
+ return merged
+}
+
+func init() {
+ RegisterBuiltinFunc(ast.ObjectUnion.Name, builtinObjectUnion)
+ RegisterBuiltinFunc(ast.ObjectRemove.Name, builtinObjectRemove)
+ RegisterBuiltinFunc(ast.ObjectFilter.Name, builtinObjectFilter)
+ RegisterBuiltinFunc(ast.ObjectGet.Name, builtinObjectGet)
+}
diff --git a/topdown/object_get.go b/topdown/object_get.go
deleted file mode 100644
index 7d21ae5e4e..0000000000
--- a/topdown/object_get.go
+++ /dev/null
@@ -1,23 +0,0 @@
-package topdown
-
-import (
- "github.com/open-policy-agent/opa/ast"
- "github.com/open-policy-agent/opa/topdown/builtins"
-)
-
-func builtinObjectGet(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
- object, err := builtins.ObjectOperand(operands[0].Value, 1)
- if err != nil {
- return err
- }
-
- if ret := object.Get(operands[1]); ret != nil {
- return iter(ret)
- }
-
- return iter(operands[2])
-}
-
-func init() {
- RegisterBuiltinFunc(ast.ObjectGet.Name, builtinObjectGet)
-}
diff --git a/topdown/object_get_test.go b/topdown/object_get_test.go
deleted file mode 100644
index 1b4ca15ae0..0000000000
--- a/topdown/object_get_test.go
+++ /dev/null
@@ -1,67 +0,0 @@
-package topdown
-
-import (
- "fmt"
- "testing"
-)
-
-func TestObjectGet(t *testing.T) {
- cases := []struct {
- note string
- object string
- key interface{}
- fallback interface{}
- expected interface{}
- }{
- {
- note: "basic case . found",
- object: `{"a": "b"}`,
- key: `"a"`,
- fallback: `"c"`,
- expected: `"b"`,
- },
- {
- note: "basic case . not found",
- object: `{"a": "b"}`,
- key: `"c"`,
- fallback: `"c"`,
- expected: `"c"`,
- },
- {
-
- note: "integer key . found",
- object: "{1: 2}",
- key: "1",
- fallback: "3",
- expected: "2",
- },
- {
- note: "integer key . not found",
- object: "{1: 2}",
- key: "2",
- fallback: "3",
- expected: "3",
- },
- {
- note: "complex value . found",
- object: `{"a": {"b": "c"}}`,
- key: `"a"`,
- fallback: "true",
- expected: `{"b": "c"}`,
- },
- {
- note: "complex value . not found",
- object: `{"a": {"b": "c"}}`,
- key: `"b"`,
- fallback: "true",
- expected: "true",
- },
- }
-
- for _, tc := range cases {
- rules := []string{
- fmt.Sprintf("p = x { x := object.get(%s, %s, %s) }", tc.object, tc.key, tc.fallback),
- }
- runTopDownTestCase(t, map[string]interface{}{}, tc.note, rules, tc.expected)
- }
-}
diff --git a/topdown/object_test.go b/topdown/object_test.go
new file mode 100644
index 0000000000..2b907773b7
--- /dev/null
+++ b/topdown/object_test.go
@@ -0,0 +1,544 @@
+// Copyright 2020 The OPA Authors. All rights reserved.
+// Use of this source code is governed by an Apache2
+// license that can be found in the LICENSE file.
+
+package topdown
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/open-policy-agent/opa/ast"
+ "github.com/open-policy-agent/opa/topdown/builtins"
+)
+
+func TestObjectGet(t *testing.T) {
+ cases := []struct {
+ note string
+ object string
+ key interface{}
+ fallback interface{}
+ expected interface{}
+ }{
+ {
+ note: "basic case . found",
+ object: `{"a": "b"}`,
+ key: `"a"`,
+ fallback: `"c"`,
+ expected: `"b"`,
+ },
+ {
+ note: "basic case . not found",
+ object: `{"a": "b"}`,
+ key: `"c"`,
+ fallback: `"c"`,
+ expected: `"c"`,
+ },
+ {
+
+ note: "integer key . found",
+ object: "{1: 2}",
+ key: "1",
+ fallback: "3",
+ expected: "2",
+ },
+ {
+ note: "integer key . not found",
+ object: "{1: 2}",
+ key: "2",
+ fallback: "3",
+ expected: "3",
+ },
+ {
+ note: "complex value . found",
+ object: `{"a": {"b": "c"}}`,
+ key: `"a"`,
+ fallback: "true",
+ expected: `{"b": "c"}`,
+ },
+ {
+ note: "complex value . not found",
+ object: `{"a": {"b": "c"}}`,
+ key: `"b"`,
+ fallback: "true",
+ expected: "true",
+ },
+ }
+
+ for _, tc := range cases {
+ rules := []string{
+ fmt.Sprintf("p = x { x := object.get(%s, %s, %s) }", tc.object, tc.key, tc.fallback),
+ }
+ runTopDownTestCase(t, map[string]interface{}{}, tc.note, rules, tc.expected)
+ }
+}
+
+func TestBuiltinObjectUnion(t *testing.T) {
+ cases := []struct {
+ note string
+ objectA string
+ objectB string
+ input string
+ expected interface{}
+ }{
+ {
+ note: "both empty",
+ objectA: `{}`,
+ objectB: `{}`,
+ expected: `{}`,
+ },
+ {
+ note: "left empty",
+ objectA: `{}`,
+ objectB: `{"a": 1}`,
+ expected: `{"a": 1}`,
+ },
+ {
+ note: "right empty",
+ objectA: `{"a": 1}`,
+ objectB: `{}`,
+ expected: `{"a": 1}`,
+ },
+ {
+ note: "base",
+ objectA: `{"a": 1}`,
+ objectB: `{"b": 2}`,
+ expected: `{"a": 1, "b": 2}`,
+ },
+ {
+ note: "nested",
+ objectA: `{"a": {"b": {"c": 1}}}`,
+ objectB: `{"b": 2}`,
+ expected: `{"a": {"b": {"c": 1}}, "b": 2}`,
+ },
+ {
+ note: "nested reverse",
+ objectA: `{"b": 2}`,
+ objectB: `{"a": {"b": {"c": 1}}}`,
+ expected: `{"a": {"b": {"c": 1}}, "b": 2}`,
+ },
+ {
+ note: "conflict simple",
+ objectA: `{"a": 1}`,
+ objectB: `{"a": 2}`,
+ expected: `{"a": 2}`,
+ },
+ {
+ note: "conflict nested and extra field",
+ objectA: `{"a": 1}`,
+ objectB: `{"a": {"b": {"c": 1}}, "d": 7}`,
+ expected: `{"a": {"b": {"c": 1}}, "d": 7}`,
+ },
+ {
+ note: "conflict multiple",
+ objectA: `{"a": {"b": {"c": 1}}, "e": 1}`,
+ objectB: `{"a": {"b": "foo", "b1": "bar"}, "d": 7, "e": 17}`,
+ expected: `{"a": {"b": "foo", "b1": "bar"}, "d": 7, "e": 17}`,
+ },
+ {
+ note: "error wrong lhs type",
+ objectA: `[1, 2, 3]`,
+ objectB: `{"b": 2}`,
+ expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.union: invalid argument(s)")},
+ },
+ {
+ note: "error wrong lhs type input",
+ objectA: `input.a`,
+ objectB: `{"b": 2}`,
+ input: `{"a": [1, 2, 3]}`,
+ expected: builtins.NewOperandErr(1, "must be object but got array"),
+ },
+ {
+ note: "error wrong rhs type",
+ objectA: `{"a": 1}`,
+ objectB: `[1, 2, 3]`,
+ expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.union: invalid argument(s)")},
+ },
+ {
+ note: "error wrong rhs type input",
+ objectA: `{"a": 1}`,
+ objectB: `input.b`,
+ input: `{"b": [1, 2, 3]}`,
+ expected: builtins.NewOperandErr(2, "must be object but got array"),
+ },
+ {
+ note: "error wrong both params",
+ objectA: `"foo"`,
+ objectB: `[1, 2, 3]`,
+ expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.union: invalid argument(s)")},
+ },
+ }
+
+ for _, tc := range cases {
+ rules := []string{
+ fmt.Sprintf("p = x { x := object.union(%s, %s) }", tc.objectA, tc.objectB),
+ }
+ runTopDownTestCaseWithModules(t, map[string]interface{}{}, tc.note, rules, nil, tc.input, tc.expected)
+ }
+}
+
+func TestBuiltinObjectRemove(t *testing.T) {
+ cases := []struct {
+ note string
+ object string
+ keys string
+ input string
+ expected interface{}
+ }{
+ {
+ note: "base",
+ object: `{"a": 1, "b": {"c": 3}}`,
+ keys: `{"a"}`,
+ expected: `{"b": {"c": 3}}`,
+ },
+ {
+ note: "multiple keys set",
+ object: `{"a": 1, "b": {"c": 3}, "d": 4}`,
+ keys: `{"d", "b"}`,
+ expected: `{"a": 1}`,
+ },
+ {
+ note: "multiple keys array",
+ object: `{"a": 1, "b": {"c": 3}, "d": 4}`,
+ keys: `["d", "b"]`,
+ expected: `{"a": 1}`,
+ },
+ {
+ note: "multiple keys object",
+ object: `{"a": 1, "b": {"c": 3}, "d": 4}`,
+ keys: `{"d": "", "b": 1}`,
+ expected: `{"a": 1}`,
+ },
+ {
+ note: "multiple keys object nested",
+ object: `{"a": {"b": {"c": 2}}, "x": 123}`,
+ keys: `{"a": {"b": {"foo": "bar"}}}`,
+ expected: `{"x": 123}`,
+ },
+ {
+ note: "empty object",
+ object: `{}`,
+ keys: `{"a", "b"}`,
+ expected: `{}`,
+ },
+ {
+ note: "empty keys set",
+ object: `{"a": 1, "b": {"c": 3}}`,
+ keys: `set()`,
+ expected: `{"a": 1, "b": {"c": 3}}`,
+ },
+ {
+ note: "empty keys array",
+ object: `{"a": 1, "b": {"c": 3}}`,
+ keys: `[]`,
+ expected: `{"a": 1, "b": {"c": 3}}`,
+ },
+ {
+ note: "empty keys obj",
+ object: `{"a": 1, "b": {"c": 3}}`,
+ keys: `{}`,
+ expected: `{"a": 1, "b": {"c": 3}}`,
+ },
+ {
+ note: "key doesnt exist",
+ object: `{"a": 1, "b": {"c": 3}}`,
+ keys: `{"z"}`,
+ expected: `{"a": 1, "b": {"c": 3}}`,
+ },
+ {
+ note: "error invalid object param type set",
+ object: `{"a"}`,
+ keys: `{"a"}`,
+ expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.remove: invalid argument(s)")},
+ },
+ {
+ note: "error invalid object param type bool",
+ object: `false`,
+ keys: `{"a"}`,
+ expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.remove: invalid argument(s)")},
+ },
+ {
+ note: "error invalid object param type array input",
+ object: `input.x`,
+ keys: `{"a"}`,
+ input: `{"x": ["a"]}`,
+ expected: builtins.NewOperandErr(1, "must be object but got array"),
+ },
+ {
+ note: "error invalid object param type bool input",
+ object: `input.x`,
+ keys: `{"a"}`,
+ input: `{"x": false}`,
+ expected: builtins.NewOperandErr(1, "must be object but got boolean"),
+ },
+ {
+ note: "error invalid object param type number input",
+ object: `input.x`,
+ keys: `{"a"}`,
+ input: `{"x": 123}`,
+ expected: builtins.NewOperandErr(1, "must be object but got number"),
+ },
+ {
+ note: "error invalid object param type string input",
+ object: `input.x`,
+ keys: `{"a"}`,
+ input: `{"x": "foo"}`,
+ expected: builtins.NewOperandErr(1, "must be object but got string"),
+ },
+ {
+ note: "error invalid object param type nil input",
+ object: `input.x`,
+ keys: `{"a"}`,
+ input: `{"x": none}`,
+ expected: builtins.NewOperandErr(1, "must be object but got var"),
+ },
+ {
+ note: "error invalid key param type string",
+ object: `{"a": 1}`,
+ keys: `"a"`,
+ expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.remove: invalid argument(s)")},
+ },
+ {
+ note: "error invalid key param type boolean",
+ object: `{"a": 1}`,
+ keys: `false`,
+ expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.remove: invalid argument(s)")},
+ },
+ {
+ note: "error invalid key param type string input",
+ object: `{"a": 1}`,
+ keys: `input.x`,
+ input: `{"x": "foo"}`,
+ expected: builtins.NewOperandErr(2, "must be one of {object, string, array} but got string"),
+ },
+ {
+ note: "error invalid key param type boolean input",
+ object: `{"a": 1}`,
+ keys: `input.x`,
+ input: `{"x": true}`,
+ expected: builtins.NewOperandErr(2, "must be one of {object, string, array} but got boolean"),
+ },
+ {
+ note: "error invalid key param type number input",
+ object: `{"a": 1}`,
+ keys: `input.x`,
+ input: `{"x": 22}`,
+ expected: builtins.NewOperandErr(2, "must be one of {object, string, array} but got number"),
+ },
+ {
+ note: "error invalid key param type nil input",
+ object: `{"a": 1}`,
+ keys: `input.x`,
+ input: `{"x": none}`,
+ expected: builtins.NewOperandErr(2, "must be one of {object, string, array} but got var"),
+ },
+ }
+
+ for _, tc := range cases {
+ rules := []string{
+ fmt.Sprintf("p = x { x := object.remove(%s, %s) }", tc.object, tc.keys),
+ }
+ runTopDownTestCaseWithModules(t, map[string]interface{}{}, tc.note, rules, nil, tc.input, tc.expected)
+ }
+}
+
+func TestBuiltinObjectRemoveIdempotent(t *testing.T) {
+ rule := `
+ p {
+ # "base" should never be mutated
+ base := {"a": 1, "b": 2, "c": 3}
+ object.remove(base, {"a"}) == {"b": 2, "c": 3}
+ object.remove(base, {"b"}) == {"a": 1, "c": 3}
+ object.remove(base, {"c"}) == {"a": 1, "b": 2}
+ base == {"a": 1, "b": 2, "c": 3}
+ }
+ `
+ runTopDownTestCase(t, map[string]interface{}{}, t.Name(), []string{rule}, "true")
+}
+
+func TestBuiltinObjectRemoveNonStringKey(t *testing.T) {
+ rules := []string{
+ `p { x := object.remove({"a": 1, [[7]]: 2}, {[[7]]}); x == {"a": 1} }`,
+ }
+ runTopDownTestCase(t, map[string]interface{}{}, "non string root", rules, "true")
+}
+
+func TestBuiltinObjectFilter(t *testing.T) {
+ cases := []struct {
+ note string
+ object string
+ filters string
+ input string
+ expected interface{}
+ }{
+ {
+ note: "base",
+ object: `{"a": {"b": {"c": 7, "d": 8}}, "e": 9}`,
+ filters: `{"a"}`,
+ expected: `{"a": {"b": {"c": 7, "d": 8}}}`,
+ },
+ {
+ note: "multiple roots set",
+ object: `{"a": 1, "b": 2, "c": 3, "e": 9}`,
+ filters: `{"a", "e"}`,
+ expected: `{"a": 1, "e": 9}`,
+ },
+ {
+ note: "multiple roots array",
+ object: `{"a": 1, "b": 2, "c": 3, "e": 9}`,
+ filters: `["a", "e"]`,
+ expected: `{"a": 1, "e": 9}`,
+ },
+ {
+ note: "multiple roots object",
+ object: `{"a": 1, "b": 2, "c": 3, "e": 9}`,
+ filters: `{"a": "foo", "e": ""}`,
+ expected: `{"a": 1, "e": 9}`,
+ },
+ {
+ note: "duplicate roots",
+ object: `{"a": {"b": {"c": 7, "d": 8}}, "e": 9}`,
+ filters: `{"a", "a"}`,
+ expected: `{"a": {"b": {"c": 7, "d": 8}}}`,
+ },
+ {
+ note: "empty roots set",
+ object: `{"a": 7}`,
+ filters: `set()`,
+ expected: `{}`,
+ },
+ {
+ note: "empty roots array",
+ object: `{"a": 7}`,
+ filters: `[]`,
+ expected: `{}`,
+ },
+ {
+ note: "empty roots object",
+ object: `{"a": 7}`,
+ filters: `{}`,
+ expected: `{}`,
+ },
+ {
+ note: "empty object",
+ object: `{}`,
+ filters: `{"a"}`,
+ expected: `{}`,
+ },
+ {
+ note: "error invalid object param type set",
+ object: `{"a"}`,
+ filters: `{"a"}`,
+ expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.filter: invalid argument(s)")},
+ },
+ {
+ note: "error invalid object param type bool",
+ object: `false`,
+ filters: `{"a"}`,
+ expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.filter: invalid argument(s)")},
+ },
+ {
+ note: "error invalid object param type array input",
+ object: `input.x`,
+ filters: `{"a"}`,
+ input: `{"x": ["a"]}`,
+ expected: builtins.NewOperandErr(1, "must be object but got array"),
+ },
+ {
+ note: "error invalid object param type bool input",
+ object: `input.x`,
+ filters: `{"a"}`,
+ input: `{"x": false}`,
+ expected: builtins.NewOperandErr(1, "must be object but got boolean"),
+ },
+ {
+ note: "error invalid object param type number input",
+ object: `input.x`,
+ filters: `{"a"}`,
+ input: `{"x": 123}`,
+ expected: builtins.NewOperandErr(1, "must be object but got number"),
+ },
+ {
+ note: "error invalid object param type string input",
+ object: `input.x`,
+ filters: `{"a"}`,
+ input: `{"x": "foo"}`,
+ expected: builtins.NewOperandErr(1, "must be object but got string"),
+ },
+ {
+ note: "error invalid object param type nil input",
+ object: `input.x`,
+ filters: `{"a"}`,
+ input: `{"x": none}`,
+ expected: builtins.NewOperandErr(1, "must be object but got var"),
+ },
+ {
+ note: "error invalid key param type string",
+ object: `{"a": 1}`,
+ filters: `"a"`,
+ expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.filter: invalid argument(s)")},
+ },
+ {
+ note: "error invalid key param type boolean",
+ object: `{"a": 1}`,
+ filters: `false`,
+ expected: ast.Errors{ast.NewError(ast.TypeErr, nil, "object.filter: invalid argument(s)")},
+ },
+ {
+ note: "error invalid key param type string input",
+ object: `{"a": 1}`,
+ filters: `input.x`,
+ input: `{"x": "foo"}`,
+ expected: builtins.NewOperandErr(2, "must be one of {object, string, array} but got string"),
+ },
+ {
+ note: "error invalid key param type boolean input",
+ object: `{"a": 1}`,
+ filters: `input.x`,
+ input: `{"x": true}`,
+ expected: builtins.NewOperandErr(2, "must be one of {object, string, array} but got boolean"),
+ },
+ {
+ note: "error invalid key param type number input",
+ object: `{"a": 1}`,
+ filters: `input.x`,
+ input: `{"x": 22}`,
+ expected: builtins.NewOperandErr(2, "must be one of {object, string, array} but got number"),
+ },
+ {
+ note: "error invalid key param type nil input",
+ object: `{"a": 1}`,
+ filters: `input.x`,
+ input: `{"x": none}`,
+ expected: builtins.NewOperandErr(2, "must be one of {object, string, array} but got var"),
+ },
+ }
+
+ for _, tc := range cases {
+ rules := []string{
+ fmt.Sprintf("p = x { x := object.filter(%s, %s) }", tc.object, tc.filters),
+ }
+ runTopDownTestCaseWithModules(t, map[string]interface{}{}, tc.note, rules, nil, tc.input, tc.expected)
+ }
+}
+
+func TestBuiltinObjectFilterNonStringKey(t *testing.T) {
+ rules := []string{
+ `p { x := object.filter({"a": 1, [[7]]: 2}, {[[7]]}); x == {[[7]]: 2} }`,
+ }
+ runTopDownTestCase(t, map[string]interface{}{}, "non string root", rules, "true")
+}
+
+func TestBuiltinObjectFilterIdempotent(t *testing.T) {
+ rule := `
+ p {
+ # "base" should never be mutated
+ base := {"a": 1, "b": 2, "c": 3}
+ object.filter(base, {"a"}) == {"a": 1}
+ object.filter(base, {"b"}) == {"b": 2}
+ object.filter(base, {"c"}) == {"c": 3}
+ base == {"a": 1, "b": 2, "c": 3}
+ }
+ `
+ runTopDownTestCase(t, map[string]interface{}{}, t.Name(), []string{rule}, "true")
+}
diff --git a/topdown/parse_test.go b/topdown/parse_test.go
index 525c21cddb..16a23762bd 100644
--- a/topdown/parse_test.go
+++ b/topdown/parse_test.go
@@ -5,7 +5,6 @@
package topdown
import (
- "fmt"
"testing"
)
@@ -24,6 +23,6 @@ func TestRegoParseModule(t *testing.T) {
`p = x { rego.parse_module("x.rego", data.ok, module); x = module["package"].path[1].value }`}, `"foo"`)
runTopDownTestCase(t, data, "error", []string{
- `p = x { rego.parse_module("x.rego", data.err, x) }`}, fmt.Errorf("rego_parse_error: no match found"))
+ `p = x { rego.parse_module("x.rego", data.err, x) }`}, &Error{Code: BuiltinErr, Message: "rego_parse_error: no match found"})
}
diff --git a/topdown/tokens_test.go b/topdown/tokens_test.go
index fe97ba67bd..acf74a3f0e 100644
--- a/topdown/tokens_test.go
+++ b/topdown/tokens_test.go
@@ -250,7 +250,7 @@ func TestTopDownJWTEncodeSignPayloadErrors(t *testing.T) {
var exp interface{}
exp = fmt.Sprintf(`%s`, p.result)
if p.err != "" {
- exp = errors.New(p.err)
+ exp = &Error{Code: BuiltinErr, Message: p.err}
}
tests = append(tests, test{
@@ -353,7 +353,7 @@ func TestTopDownJWTEncodeSignHeaderErrors(t *testing.T) {
var exp interface{}
exp = fmt.Sprintf(`%s`, p.result)
if p.err != "" {
- exp = errors.New(p.err)
+ exp = &Error{Code: BuiltinErr, Message: p.err}
}
tests = append(tests, test{
@@ -456,7 +456,7 @@ func TestTopDownJWTEncodeSignRaw(t *testing.T) {
var exp interface{}
exp = fmt.Sprintf(`%s`, p.result)
if p.err != "" {
- exp = errors.New(p.err)
+ exp = &Error{Code: BuiltinErr, Message: p.err}
}
rawTests = append(rawTests, test{
@@ -844,7 +844,7 @@ func TestTopDownJWTBuiltins(t *testing.T) {
var exp interface{}
exp = fmt.Sprintf(`[%s, %s, "%s"]`, p.header, p.payload, p.signature)
if p.err != "" {
- exp = errors.New(p.err)
+ exp = &Error{Code: BuiltinErr, Message: p.err}
}
tests = append(tests, test{
@@ -1029,7 +1029,7 @@ func TestTopDownJWTVerifyRSA(t *testing.T) {
var exp interface{}
exp = fmt.Sprintf(`%t`, p.result)
if p.err != "" {
- exp = errors.New(p.err)
+ exp = &Error{Code: BuiltinErr, Message: p.err}
}
tests = append(tests, test{
@@ -1088,7 +1088,7 @@ func TestTopDownJWTVerifyHS256(t *testing.T) {
var exp interface{}
exp = fmt.Sprintf(`%t`, p.result)
if p.err != "" {
- exp = errors.New(p.err)
+ exp = &Error{Code: BuiltinErr, Message: p.err}
}
tests = append(tests, test{
diff --git a/topdown/topdown_test.go b/topdown/topdown_test.go
index a9b2789276..4a6a314e58 100644
--- a/topdown/topdown_test.go
+++ b/topdown/topdown_test.go
@@ -757,7 +757,7 @@ p[x] = y { data.enum_errors.a[x] = y }`,
assertTopDownWithPath(t, compiler, store, "base/virtual: missing input value", []string{"topdown", "u"}, "{}", "{}")
assertTopDownWithPath(t, compiler, store, "iterate ground", []string{"topdown", "iterate_ground"}, "{}", `["p", "r"]`)
assertTopDownWithPath(t, compiler, store, "base/virtual: conflicts", []string{"topdown.conflicts"}, "{}", `{"k": "foo"}`)
- assertTopDownWithPath(t, compiler, store, "enumerate virtual errors", []string{"enum_errors", "caller", "p"}, `{}`, fmt.Errorf("divide by zero"))
+ assertTopDownWithPath(t, compiler, store, "enumerate virtual errors", []string{"enum_errors", "caller", "p"}, `{}`, &Error{Code: BuiltinErr, Message: "divide by zero"})
}
func TestTopDownFix1863(t *testing.T) {
@@ -1097,11 +1097,11 @@ func TestTopDownArithmetic(t *testing.T) {
{"minus", []string{`p[y] { a[i] = x; y = i - x }`}, "[-1]"},
{"multiply", []string{`p[y] { a[i] = x; y = i * x }`}, "[0,2,6,12]"},
{"divide+round", []string{`p[z] { a[i] = x; y = i / x; round(y, z) }`}, "[0, 1]"},
- {"divide+error", []string{`p[y] { a[i] = x; y = x / i }`}, fmt.Errorf("divide by zero")},
+ {"divide+error", []string{`p[y] { a[i] = x; y = x / i }`}, &Error{Code: BuiltinErr, Message: "divide by zero"}},
{"abs", []string{`p = true { abs(-10, x); x = 10 }`}, "true"},
{"remainder", []string{`p = x { x = 7 % 4 }`}, "3"},
- {"remainder+error", []string{`p = x { x = 7 % 0 }`}, fmt.Errorf("modulo by zero")},
- {"remainder+error+floating", []string{`p = x { x = 1.1 % 1 }`}, fmt.Errorf("modulo on floating-point number")},
+ {"remainder+error", []string{`p = x { x = 7 % 0 }`}, &Error{Code: BuiltinErr, Message: "modulo by zero"}},
+ {"remainder+error+floating", []string{`p = x { x = 1.1 % 1 }`}, &Error{Code: BuiltinErr, Message: "modulo on floating-point number"}},
{"arity 1 ref dest", []string{`p = true { abs(-4, a[3]) }`}, "true"},
{"arity 1 ref dest (2)", []string{`p = true { not abs(-5, a[3]) }`}, "true"},
{"arity 2 ref dest", []string{`p = true { a[2] = 1 + 2 }`}, "true"},
@@ -1127,7 +1127,7 @@ func TestTopDownCasts(t *testing.T) {
"[-42.0, 0, 100.1, 0, 1]"},
{"to_number ref dest", []string{`p = true { to_number("3", a[2]) }`}, "true"},
{"to_number ref dest", []string{`p = true { not to_number("-1", a[2]) }`}, "true"},
- {"to_number: bad input", []string{`p { to_number("broken", x) }`}, fmt.Errorf("invalid syntax")},
+ {"to_number: bad input", []string{`p { to_number("broken", x) }`}, &Error{Code: BuiltinErr, Message: "invalid syntax"}},
}
data := loadSmallTestData()
@@ -1269,7 +1269,7 @@ func TestTopDownRegexMatch(t *testing.T) {
}{
{"re_match", []string{`p = true { re_match("^[a-z]+\\[[0-9]+\\]$", "foo[1]") }`}, "true"},
{"re_match: undefined", []string{`p = true { re_match("^[a-z]+\\[[0-9]+\\]$", "foo[\"bar\"]") }`}, ""},
- {"re_match: bad pattern err", []string{`p = true { re_match("][", "foo[\"bar\"]") }`}, fmt.Errorf("re_match: error parsing regexp: missing closing ]: `[`")},
+ {"re_match: bad pattern err", []string{`p = true { re_match("][", "foo[\"bar\"]") }`}, &Error{Code: BuiltinErr, Message: "re_match: error parsing regexp: missing closing ]: `[`"}},
{"re_match: ref", []string{`p[x] { re_match("^b.*$", d.e[x]) }`}, "[0,1]"},
{"re_match: raw", []string{fmt.Sprintf(`p = true { re_match(%s, "foo[1]") }`, "`^[a-z]+\\[[0-9]+\\]$`")}, "true"},
@@ -1309,7 +1309,7 @@ func TestTopDownGlobsMatch(t *testing.T) {
}{
{"regex.globs_match", []string{`p = true { regex.globs_match("a.a.[0-9]+z", ".b.b2359825792*594823z") }`}, "true"},
{"regex.globs_match", []string{`p = true { regex.globs_match("[a-z]+", "[0-9]*") }`}, ""},
- {"regex.globs_match: bad pattern err", []string{`p = true { regex.globs_match("pqrs]", "[a-b]+") }`}, fmt.Errorf("input:pqrs], pos:5, set-close ']' with no preceding '[': the input provided is invalid")},
+ {"regex.globs_match: bad pattern err", []string{`p = true { regex.globs_match("pqrs]", "[a-b]+") }`}, &Error{Code: BuiltinErr, Message: "input:pqrs], pos:5, set-close ']' with no preceding '[': the input provided is invalid"}},
{"regex.globs_match: ref", []string{`p[x] { regex.globs_match("b.*", d.e[x]) }`}, "[0,1]"},
{"regex.globs_match: raw", []string{fmt.Sprintf(`p = true { regex.globs_match(%s, "foo\\[1\\]") }`, "`[a-z]+\\[[0-9]+\\]`")}, "true"},
@@ -1354,7 +1354,7 @@ func TestTopDownStrings(t *testing.T) {
{"format_int: undefined", []string{`p = true { format_int(15.5, 16, "10000") }`}, ""},
{"format_int: ref dest", []string{`p = true { format_int(3.1, 10, numbers[2]) }`}, "true"},
{"format_int: ref dest (2)", []string{`p = true { not format_int(4.1, 10, numbers[2]) }`}, "true"},
- {"format_int: err: bad base", []string{`p = true { format_int(4.1, 199, x) }`}, fmt.Errorf("operand 2 must be one of {2, 8, 10, 16}")},
+ {"format_int: err: bad base", []string{`p = true { format_int(4.1, 199, x) }`}, &Error{Code: TypeErr, Message: "operand 2 must be one of {2, 8, 10, 16}"}},
{"concat", []string{`p = x { concat("/", ["", "foo", "bar", "0", "baz"], x) }`}, `"/foo/bar/0/baz"`},
{"concat: set", []string{`p = x { concat(",", {"1", "2", "3"}, x) }`}, `"1,2,3"`},
{"concat: undefined", []string{`p = true { concat("/", ["a", "b"], "deadbeef") }`}, ""},
@@ -1365,7 +1365,7 @@ func TestTopDownStrings(t *testing.T) {
{"substring", []string{`p = x { substring("abcdefgh", 2, 3, x) }`}, `"cde"`},
{"substring: remainder", []string{`p = x { substring("abcdefgh", 2, -1, x) }`}, `"cdefgh"`},
{"substring: too long", []string{`p = x { substring("abcdefgh", 2, 10000, x) }`}, `"cdefgh"`},
- {"substring: offset negative", []string{`p = x { substring("aaa", -1, -1, x) }`}, fmt.Errorf("negative offset")},
+ {"substring: offset negative", []string{`p = x { substring("aaa", -1, -1, x) }`}, &Error{Code: BuiltinErr, Message: "negative offset"}},
{"substring: offset too long", []string{`p = x { substring("aaa", 3, -1, x) }`}, `""`},
{"substring: offset too long 2", []string{`p = x { substring("aaa", 4, -1, x) }`}, `""`},
{"contains", []string{`p = true { contains("abcdefgh", "defg") }`}, "true"},
@@ -1415,9 +1415,9 @@ func TestTopDownJSONBuiltins(t *testing.T) {
}{
{"marshal", []string{`p = x { json.marshal([{"foo": {1,2,3}}], x) }`}, `"[{\"foo\":[1,2,3]}]"`},
{"unmarshal", []string{`p = x { json.unmarshal("[{\"foo\":[1,2,3]}]", x) }`}, `[{"foo": [1,2,3]}]"`},
- {"unmarshal-non-string", []string{`p = x { json.unmarshal(data.a[0], x) }`}, fmt.Errorf("operand 1 must be string but got number")},
+ {"unmarshal-non-string", []string{`p = x { json.unmarshal(data.a[0], x) }`}, &Error{Code: TypeErr, Message: "operand 1 must be string but got number"}},
{"yaml round-trip", []string{`p = y { yaml.marshal([{"foo": {1,2,3}}], x); yaml.unmarshal(x, y) }`}, `[{"foo": [1,2,3]}]`},
- {"yaml unmarshal error", []string{`p { yaml.unmarshal("[1,2,3", _) } `}, fmt.Errorf("yaml: line 1: did not find")},
+ {"yaml unmarshal error", []string{`p { yaml.unmarshal("[1,2,3", _) } `}, &Error{Code: BuiltinErr, Message: "yaml: line 1: did not find"}},
}
data := loadSmallTestData()
@@ -1524,7 +1524,7 @@ func TestTopDownTime(t *testing.T) {
p = [year, month, day] { [year, month, day] := time.date(1582977600*1000*1000*1000) }`}, "[2020, 2, 29]")
runTopDownTestCase(t, data, "date too big", []string{`
- p = [year, month, day] { [year, month, day] := time.date(1582977600*1000*1000*1000*1000) }`}, fmt.Errorf("timestamp too big"))
+ p = [year, month, day] { [year, month, day] := time.date(1582977600*1000*1000*1000*1000) }`}, &Error{Code: BuiltinErr, Message: "timestamp too big"})
runTopDownTestCase(t, data, "clock", []string{`
p = [hour, minute, second] { [hour, minute, second] := time.clock(1517832000*1000*1000*1000) }`}, "[12, 0, 0]")
@@ -1536,7 +1536,7 @@ func TestTopDownTime(t *testing.T) {
p = [hour, minute, second] { [hour, minute, second] := time.clock(1582977600*1000*1000*1000) }`}, "[12, 0, 0]")
runTopDownTestCase(t, data, "clock too big", []string{`
- p = [hour, minute, second] { [hour, minute, second] := time.clock(1582977600*1000*1000*1000*1000) }`}, fmt.Errorf("timestamp too big"))
+ p = [hour, minute, second] { [hour, minute, second] := time.clock(1582977600*1000*1000*1000*1000) }`}, &Error{Code: BuiltinErr, Message: "timestamp too big"})
for i, day := range []string{"Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"} {
ts := 1517832000*1000*1000*1000 + i*24*int(time.Hour)
@@ -1545,7 +1545,7 @@ func TestTopDownTime(t *testing.T) {
}
runTopDownTestCase(t, data, "weekday too big", []string{`
- p = weekday { weekday := time.weekday(1582977600*1000*1000*1000*1000) }`}, fmt.Errorf("timestamp too big"))
+ p = weekday { weekday := time.weekday(1582977600*1000*1000*1000*1000) }`}, &Error{Code: BuiltinErr, Message: "timestamp too big"})
}
func TestTopDownWalkBuiltin(t *testing.T) {
@@ -2879,7 +2879,11 @@ func runTopDownTestCaseWithModules(t *testing.T, data map[string]interface{}, no
compiler, err := compileRules(imports, rules, modules)
if err != nil {
- t.Errorf("%v: Compiler error: %v", note, err)
+ if _, ok := expected.(error); ok {
+ assertError(t, expected, err)
+ } else {
+ t.Errorf("%v: Compiler error: %v", note, err)
+ }
return
}
@@ -2941,27 +2945,9 @@ func assertTopDownWithPath(t *testing.T, compiler *ast.Compiler, store storage.S
testutil.Subtest(t, note, func(t *testing.T) {
switch e := expected.(type) {
- case Error:
- result, err := query.Run(ctx)
- if err == nil {
- t.Errorf("Expected error but got: %v", result)
- return
- }
- errString := err.Error()
- if !strings.Contains(errString, e.Code) || !strings.Contains(errString, e.Message) {
- t.Errorf("Expected error %v but got: %v", e, err)
- }
- case error:
- result, err := query.Run(ctx)
- if err == nil {
- t.Errorf("Expected error but got: %v", result)
- return
- }
-
- if !strings.Contains(err.Error(), e.Error()) {
- t.Errorf("Expected error %v but got: %v", e, err)
- }
-
+ case Error, error:
+ _, err := query.Run(ctx)
+ assertError(t, expected, err)
case string:
qrs, err := query.Run(ctx)
@@ -3197,3 +3183,43 @@ func dump(note string, modules map[string]*ast.Module, data interface{}, docpath
}
}
+
+func assertError(t *testing.T, expected interface{}, actual error) {
+ t.Helper()
+ if actual == nil {
+ t.Errorf("Expected error but got: %v", actual)
+ return
+ }
+
+ errString := actual.Error()
+
+ if reflect.TypeOf(expected) != reflect.TypeOf(actual) {
+ t.Errorf("Expected error of type '%T', got '%T'", expected, actual)
+ }
+
+ switch e := expected.(type) {
+ case Error:
+ assertErrorContains(t, errString, e.Code)
+ assertErrorContains(t, errString, e.Message)
+ case *Error:
+ assertErrorContains(t, errString, e.Code)
+ assertErrorContains(t, errString, e.Message)
+ case *ast.Error:
+ assertErrorContains(t, errString, e.Code)
+ assertErrorContains(t, errString, e.Message)
+ case ast.Errors:
+ for _, astErr := range e {
+ assertErrorContains(t, errString, astErr.Code)
+ assertErrorContains(t, errString, astErr.Message)
+ }
+ case error:
+ assertErrorContains(t, errString, e.Error())
+ }
+}
+
+func assertErrorContains(t *testing.T, actualErrMsg string, expected string) {
+ t.Helper()
+ if !strings.Contains(actualErrMsg, expected) {
+ t.Errorf("Expected error '%v' but got: '%v'", expected, actualErrMsg)
+ }
+}