diff --git a/ast/parser_ext.go b/ast/parser_ext.go index b2a157abad..7acfc53847 100644 --- a/ast/parser_ext.go +++ b/ast/parser_ext.go @@ -120,21 +120,64 @@ func MustParseTerm(input string) *Term { // of the form = can be converted into rules of the form = // { true }. This is a concise way of defining constants inside modules. func ParseRuleFromBody(module *Module, body Body) (*Rule, error) { + if len(body) != 1 { return nil, fmt.Errorf("multiple %vs cannot be used for %v", ExprTypeName, HeadTypeName) } expr := body[0] - if !expr.IsEquality() { - return nil, fmt.Errorf("non-equality %v cannot be used for %v", ExprTypeName, HeadTypeName) + + if expr.IsEquality() { + return parseRuleFromEquality(module, expr) + } else if !expr.IsBuiltin() { + return parseRuleFromTerm(module, expr.Terms.(*Term)) + } + + return nil, fmt.Errorf("%vs cannot be used for %v", TypeName(expr), RuleTypeName) +} + +func parseRuleFromTerm(module *Module, term *Term) (*Rule, error) { + + ref, ok := term.Value.(Ref) + if !ok { + return nil, fmt.Errorf("%vs cannot be used for %v", TypeName(term.Value), HeadTypeName) + } + + var name Var + var key *Term + + if v, ok := ref[0].Value.(Var); ok && len(ref) == 2 { + name = v + key = ref[1] + } else { + return nil, fmt.Errorf("%v cannot be used for %v", RefTypeName, RuleTypeName) + } + + rule := &Rule{ + Location: term.Location, + Head: &Head{ + Location: term.Location, + Name: name, + Key: key, + }, + Body: NewBody( + NewExpr(BooleanTerm(true).SetLocation(term.Location)).SetLocation(term.Location), + ), + Module: module, } + return rule, nil +} + +func parseRuleFromEquality(module *Module, expr *Expr) (*Rule, error) { + if len(expr.With) > 0 { return nil, fmt.Errorf("%vs using %v cannot be used for %v", ExprTypeName, WithTypeName, HeadTypeName) } terms := expr.Terms.([]*Term) var name Var + var key *Term switch v := terms[1].Value.(type) { case Var: @@ -142,6 +185,9 @@ func ParseRuleFromBody(module *Module, body Body) (*Rule, error) { case Ref: if v.Equal(InputRootRef) || v.Equal(DefaultRootRef) { name = Var(v.String()) + } else if n, ok := v[0].Value.(Var); ok && len(v) == 2 { + name = n + key = v[1] } else { return nil, fmt.Errorf("%v cannot be used for name of %v", RefTypeName, RuleTypeName) } @@ -154,10 +200,11 @@ func ParseRuleFromBody(module *Module, body Body) (*Rule, error) { Head: &Head{ Location: expr.Location, Name: name, + Key: key, Value: terms[2], }, Body: NewBody( - &Expr{Terms: BooleanTerm(true).SetLocation(expr.Location)}, + NewExpr(BooleanTerm(true).SetLocation(expr.Location)).SetLocation(expr.Location), ), Module: module, } diff --git a/ast/parser_test.go b/ast/parser_test.go index 7fe0103da7..1e57eab3a9 100644 --- a/ast/parser_test.go +++ b/ast/parser_test.go @@ -1034,17 +1034,28 @@ p[x] { x = 1 } greeting = "hello" { true } cores = [{0: 1}, {1: 2}] { true } wrapper = cores[0][1] { true } -pi = [3, 1, 4, x, y, z] { true }` +pi = [3, 1, 4, x, y, z] { true } +foo["bar"] = "buz" +foo["9"] = "10" +foo.buz = "bar" +bar[1] +bar[[{"foo":"baz"}]] +` assertParseModule(t, "rules from bodies", testModule, &Module{ Package: MustParseStatement(`package a.b.c`).(*Package), Rules: []*Rule{ - MustParseStatement(`pi = 3.14159 { true }`).(*Rule), - MustParseStatement(`p[x] { x = 1 }`).(*Rule), - MustParseStatement(`greeting = "hello" { true }`).(*Rule), - MustParseStatement(`cores = [{0: 1}, {1: 2}] { true }`).(*Rule), - MustParseStatement(`wrapper = cores[0][1] { true }`).(*Rule), - MustParseStatement(`pi = [3, 1, 4, x, y, z] { true }`).(*Rule), + MustParseRule(`pi = 3.14159 { true }`), + MustParseRule(`p[x] { x = 1 }`), + MustParseRule(`greeting = "hello" { true }`), + MustParseRule(`cores = [{0: 1}, {1: 2}] { true }`), + MustParseRule(`wrapper = cores[0][1] { true }`), + MustParseRule(`pi = [3, 1, 4, x, y, z] { true }`), + MustParseRule(`foo["bar"] = "buz" { true }`), + MustParseRule(`foo["9"] = "10" { true }`), + MustParseRule(`foo["buz"] = "bar" { true }`), + MustParseRule(`bar[1] { true }`), + MustParseRule(`bar[[{"foo":"baz"}]] { true }`), }, }) @@ -1079,23 +1090,28 @@ data = {"bar": 2} { true }` "pi" = 3 ` - refName := ` + withExpr := ` package a.b.c - input.x = true + foo = input with input as 1 ` - withExpr := ` + badRefLen1 := ` package a.b.c - foo = input with input as 1 - ` + p["x"].y = 1` + + badRefLen2 := ` + package a.b.c + + p["x"].y` assertParseModuleError(t, "multiple expressions", multipleExprs) assertParseModuleError(t, "non-equality", nonEquality) assertParseModuleError(t, "non-var name", nonVarName) - assertParseModuleError(t, "ref name", refName) assertParseModuleError(t, "with expr", withExpr) + assertParseModuleError(t, "bad ref (too long)", badRefLen1) + assertParseModuleError(t, "bad ref (too long)", badRefLen2) } func TestWildcards(t *testing.T) { diff --git a/ast/policy.go b/ast/policy.go index 36901430ce..2f34018666 100644 --- a/ast/policy.go +++ b/ast/policy.go @@ -1095,6 +1095,12 @@ func (expr *Expr) OutputVars(safe VarSet) VarSet { return VarSet{} } +// SetLocation sets the expr's location and returns the expr itself. +func (expr *Expr) SetLocation(loc *Location) *Expr { + expr.Location = loc + return expr +} + func (expr *Expr) String() string { var buf []string if expr.Negated { diff --git a/format/format.go b/format/format.go index 87988f212e..16b206bd2e 100644 --- a/format/format.go +++ b/format/format.go @@ -218,7 +218,8 @@ func (w *writer) writeRule(rule *ast.Rule, comments []*ast.Comment) []*ast.Comme // `foo = {"a": "b"} { true }` in the AST. We want to preserve that notation // in the formatted code instead of expanding the bodies into rules, so we // pretend that the rule has no body in this case. - isExpandedConst := rule.Head.DocKind() == ast.CompleteDoc && rule.Body.Equal(ast.NewBody(ast.NewExpr(ast.BooleanTerm(true)))) + isExpandedConst := rule.Body.Equal(ast.NewBody(ast.NewExpr(ast.BooleanTerm(true)))) + if len(rule.Body) == 0 || isExpandedConst { w.endLine() return comments @@ -409,7 +410,6 @@ func (w *writer) writeTerm(term *ast.Term, comments []*ast.Comment) []*ast.Comme switch x := term.Value.(type) { case ast.Ref: w.write(x.String()) - return comments case ast.Object: comments = w.writeObject(x, term.Location, comments) case ast.Array: @@ -423,9 +423,14 @@ func (w *writer) writeTerm(term *ast.Term, comments []*ast.Comment) []*ast.Comme case *ast.SetComprehension: comments = w.writeSetComprehension(x, term.Location, comments) case ast.String: - // To preserve raw strings, we need to output the original text, - // not what x.String() would give us. - w.write(string(term.Location.Text)) + if term.Location.Text[0] == '.' { + // This string was parsed from a ref, so preserve the value. + w.write(`"%s"`, string(x)) + } else { + // To preserve raw strings, we need to output the original text, + // not what x.String() would give us. + w.write(string(term.Location.Text)) + } case fmt.Stringer: w.write(x.String()) } diff --git a/format/testfiles/test.rego b/format/testfiles/test.rego index c36fce158d..85941b50e9 100644 --- a/format/testfiles/test.rego +++ b/format/testfiles/test.rego @@ -19,6 +19,16 @@ not x = g globals = {"foo": "bar", "fizz": "buzz"} +partial_obj["x"] = 1 +partial_obj.y = 2 + +partial_obj["z"] = 3 { + true +} + +partial_set["x"] +partial_set.y + # Latent comment. r = y { @@ -84,8 +94,8 @@ a = {"a": "b", "c": "d"} b = [1, 2, 3, 4] c = [1, 2, # Comment inside array -3, 4, -5, 6, 7, +3, 4, +5, 6, 7, 8, ] # Comment on closing array bracket. @@ -115,7 +125,7 @@ m = {y: x | split("foo.bar", ".", x); y = x[_]} n = {y: x | split("foo.bar", ".", x) y = x[_]} -o = {y: x | +o = {y: x | split("foo.bar", ".", x) # comment in object comprehension y = x[_] diff --git a/format/testfiles/test.rego.formatted b/format/testfiles/test.rego.formatted index 8f8e01111d..e2b6cdef22 100644 --- a/format/testfiles/test.rego.formatted +++ b/format/testfiles/test.rego.formatted @@ -22,6 +22,16 @@ globals = { "fizz": "buzz", } +partial_obj["x"] = 1 + +partial_obj["y"] = 2 + +partial_obj["z"] = 3 + +partial_set["x"] + +partial_set["y"] + # Latent comment. r = y { diff --git a/repl/repl.go b/repl/repl.go index 7ce492906c..0f2d9ef92a 100644 --- a/repl/repl.go +++ b/repl/repl.go @@ -762,6 +762,10 @@ func (r *REPL) evalStatement(ctx context.Context, stmt interface{}) error { return err } + // This will only parse rules from equality expressions that can be + // interpreted as rules defining complete docs because rules defining + // partial sets/objects would fail to compile above (due to the head of + // the ref being unsafe, e.g., p["foo"] = "bar". rule, err3 := ast.ParseRuleFromBody(r.modules[r.currentModuleID], body) if err3 == nil { return r.compileRule(ctx, rule) diff --git a/repl/repl_test.go b/repl/repl_test.go index 0ded5b1b46..7084631ae8 100644 --- a/repl/repl_test.go +++ b/repl/repl_test.go @@ -298,13 +298,9 @@ import input.xyz` + "\n\n" import data.foo as bar import input.xyz -p[1] { - true -} +p[1] -p[2] { - true -}` + "\n\n" +p[2]` + "\n\n" assertREPLText(t, buffer, expected) buffer.Reset() diff --git a/topdown/topdown_test.go b/topdown/topdown_test.go index 56c097f38f..54ee55cf02 100644 --- a/topdown/topdown_test.go +++ b/topdown/topdown_test.go @@ -1656,6 +1656,73 @@ loopback = input { true }`}) }) } +func TestTopDownPartialDocConstants(t *testing.T) { + compiler := compileModules([]string{ + `package ex + + foo["bar"] = 0 + foo["baz"] = 1 + foo["*"] = [1, 2, 3] { + input.foo = 7 + } + + bar["x"] + bar["y"] + bar["*"] { + input.foo = 7 + } + `}) + + store := inmem.NewFromObject(loadSmallTestData()) + ctx := context.Background() + txn := storage.NewTransactionOrDie(ctx, store) + defer store.Abort(ctx, txn) + + tests := []struct { + note string + path string + input string + expected string + }{ + { + note: "obj-1", + path: "ex.foo.bar", + expected: "0", + }, + { + note: "obj", + path: "ex.foo", + expected: `{"bar": 0, "baz": 1}`, + }, + { + note: "obj-all", + path: "ex.foo", + input: `{"foo": 7}`, + expected: `{"bar": 0, "baz": 1, "*": [1,2,3]}`, + }, + { + note: "set-1", + path: "ex.bar.x", + expected: `"x"`, + }, + { + note: "set", + path: "ex.bar", + expected: `["x", "y"]`, + }, + { + note: "set-all", + path: "ex.bar", + input: `{"foo": 7}`, + expected: `["x", "y", "*"]`, + }, + } + + for _, tc := range tests { + assertTopDownWithPath(t, compiler, store, tc.note, strings.Split(tc.path, "."), tc.input, tc.expected) + } +} + func TestTopDownUserFunc(t *testing.T) { compiler := compileModules([]string{ `package ex