diff --git a/internal/ir/ir.go b/internal/ir/ir.go index 81c7608a3d..b680f97847 100644 --- a/internal/ir/ir.go +++ b/internal/ir/ir.go @@ -28,7 +28,7 @@ type ( // Funcs represents a collection of planned functions to include in the // policy. Funcs struct { - Funcs map[string]*Func + Funcs []*Func } // Func represents a named plan (function) that can be invoked. Functions diff --git a/internal/ir/walk.go b/internal/ir/walk.go index 818ee0d986..7c62e11b81 100644 --- a/internal/ir/walk.go +++ b/internal/ir/walk.go @@ -4,8 +4,6 @@ package ir -import "sort" - // Visitor defines the interface for visiting IR nodes. type Visitor interface { Before(x interface{}) @@ -54,13 +52,8 @@ func (w *walkerImpl) walk(x interface{}) { w.walk(s) } case *Funcs: - keys := make([]string, 0, len(x.Funcs)) - for k := range x.Funcs { - keys = append(keys, k) - } - sort.Strings(keys) - for _, k := range keys { - w.walk(x.Funcs[k]) + for _, fn := range x.Funcs { + w.walk(fn) } case *Func: for _, b := range x.Blocks { diff --git a/internal/planner/functrie.go b/internal/planner/functrie.go deleted file mode 100644 index 882175c994..0000000000 --- a/internal/planner/functrie.go +++ /dev/null @@ -1,96 +0,0 @@ -package planner - -import ( - "sort" - - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/internal/ir" -) - -// functrie implements a simple trie structure for organizing planned functions. -// The functions are organized to facilitate access when planning references -// against the data document. -type functrie struct { - children map[ast.Value]*functrie - val *functrieValue -} - -type functrieValue struct { - Fn *ir.Func - Rules []*ast.Rule -} - -func (val *functrieValue) Arity() int { - return len(val.Rules[0].Head.Args) -} - -func newFunctrie() *functrie { - return &functrie{ - children: map[ast.Value]*functrie{}, - } -} - -func (t *functrie) Insert(key ast.Ref, val *functrieValue) { - node := t - for _, elem := range key { - child, ok := node.children[elem.Value] - if !ok { - child = newFunctrie() - node.children[elem.Value] = child - } - node = child - } - node.val = val -} - -func (t *functrie) Lookup(key ast.Ref) *functrieValue { - node := t - for _, elem := range key { - var ok bool - if node == nil { - return nil - } else if node, ok = node.children[elem.Value]; !ok { - return nil - } - } - return node.val -} - -func (t *functrie) LookupOrInsert(key ast.Ref, orElse *functrieValue) *functrieValue { - if val := t.Lookup(key); val != nil { - return val - } - t.Insert(key, orElse) - return orElse -} - -func (t *functrie) FuncMap() map[string]*ir.Func { - result := map[string]*ir.Func{} - t.toMap(result) - return result -} - -func (t *functrie) Children() []ast.Value { - - sorted := make([]ast.Value, 0, len(t.children)) - - for key := range t.children { - sorted = append(sorted, key) - } - - sort.Slice(sorted, func(i, j int) bool { - return sorted[i].Compare(sorted[j]) < 0 - }) - - return sorted -} - -func (t *functrie) toMap(result map[string]*ir.Func) { - if t.val != nil { - result[t.val.Fn.Name] = t.val.Fn - return - } - for _, node := range t.children { - node.toMap(result) - } -} diff --git a/internal/planner/planner.go b/internal/planner/planner.go index 37833f28e4..c1456992cf 100644 --- a/internal/planner/planner.go +++ b/internal/planner/planner.go @@ -18,13 +18,13 @@ type binaryiter func(ir.Local, ir.Local) error // Planner implements a query planner for Rego queries. type Planner struct { + policy *ir.Policy // result of planning queries []ast.Body // input query to plan modules []*ast.Module // input modules to support queries rewritten map[ast.Var]ast.Var // rewritten query vars - strindex map[string]int // global string constant indices - strings []*ir.StringConst // planned (global) string constants - blocks []*ir.Block // planned blocks - funcs *functrie // planned functions to support blocks + strings map[string]int // global string constant indices + rules *ruletrie // rules that may be planned + funcs *funcstack // functions that have been planned curr *ir.Block // in-progress query block vars *varstack // in-scope variables ltarget ir.Local // target variable of last planned statement @@ -34,13 +34,19 @@ type Planner struct { // New returns a new Planner object. func New() *Planner { return &Planner{ - strindex: map[string]int{}, - lcurr: ir.Unused, + policy: &ir.Policy{ + Static: &ir.Static{}, + Plan: &ir.Plan{}, + Funcs: &ir.Funcs{}, + }, + strings: map[string]int{}, + lcurr: ir.Unused, vars: newVarstack(map[ast.Var]ir.Local{ ast.InputRootDocument.Value.(ast.Var): ir.Input, ast.DefaultRootDocument.Value.(ast.Var): ir.Data, }), - funcs: newFunctrie(), + rules: newRuletrie(), + funcs: newFuncstack(), } } @@ -67,7 +73,7 @@ func (p *Planner) WithRewrittenVars(vs map[ast.Var]ast.Var) *Planner { // Plan returns a IR plan for the policy query. func (p *Planner) Plan() (*ir.Policy, error) { - if err := p.planModules(); err != nil { + if err := p.buildFunctrie(); err != nil { return nil, err } @@ -75,25 +81,10 @@ func (p *Planner) Plan() (*ir.Policy, error) { return nil, err } - policy := &ir.Policy{ - Static: &ir.Static{ - Strings: p.strings, - }, - Plan: &ir.Plan{ - Blocks: p.blocks, - }, - Funcs: &ir.Funcs{ - Funcs: p.funcs.FuncMap(), - }, - } - - return policy, nil + return p.policy, nil } -func (p *Planner) planModules() error { - - // Build a set of all the rulesets to plan. - funcs := map[*functrieValue]struct{}{} +func (p *Planner) buildFunctrie() error { for _, module := range p.modules { @@ -106,33 +97,30 @@ func (p *Planner) planModules() error { // // Expected result: {"y": {}} if len(module.Rules) == 0 { - _ = p.funcs.LookupOrInsert(module.Package.Path, nil) + _ = p.rules.LookupOrInsert(module.Package.Path) continue } for _, rule := range module.Rules { - val := p.funcs.LookupOrInsert(rule.Path(), &functrieValue{}) - val.Rules = append(val.Rules, rule) - funcs[val] = struct{}{} - } - } - - for val := range funcs { - if err := p.planRules(val); err != nil { - return err + val := p.rules.LookupOrInsert(rule.Path()) + val.rules = append(val.rules, rule) } } return nil } -func (p *Planner) planRules(trieNode *functrieValue) error { +func (p *Planner) planRules(rules []*ast.Rule) (string, error) { - rules := trieNode.Rules + path := rules[0].Path().String() + + if funcName, ok := p.funcs.Get(path); ok { + return funcName, nil + } // Create function definition for rules. fn := &ir.Func{ - Name: rules[0].Path().String(), + Name: fmt.Sprintf("g%d.%s", p.funcs.gen, path), Params: []ir.Local{ p.newLocal(), p.newLocal(), @@ -140,8 +128,6 @@ func (p *Planner) planRules(trieNode *functrieValue) error { Return: p.newLocal(), } - trieNode.Fn = fn - // Initialize parameters for functions. for i := 0; i < len(rules[0].Head.Args); i++ { fn.Params = append(fn.Params, p.newLocal()) @@ -273,7 +259,7 @@ func (p *Planner) planRules(trieNode *functrieValue) error { }) if err != nil { - return err + return "", err } } } @@ -296,7 +282,7 @@ func (p *Planner) planRules(trieNode *functrieValue) error { }) return nil }); err != nil { - return err + return "", err } } @@ -309,11 +295,14 @@ func (p *Planner) planRules(trieNode *functrieValue) error { }, }) + p.appendFunc(fn) + p.funcs.Add(path, fn.Name) + // Restore the state of the planner. p.vars = currVars p.curr = currBlock - return nil + return fn.Name, nil } func (p *Planner) planFuncParams(params []ir.Local, args ast.Args, idx int, iter planiter) error { @@ -347,7 +336,7 @@ func (p *Planner) planQueries() error { qv = p.rewrittenVar(qv) if !qv.IsGenerated() && !qv.IsWildcard() { stmt := &ir.MakeStringStmt{ - Index: p.appendStringConst(string(qv)), + Index: p.getStringConst(string(qv)), Target: p.newLocal(), } p.appendStmt(stmt) @@ -355,7 +344,7 @@ func (p *Planner) planQueries() error { } } - p.blocks = append(p.blocks, p.curr) + p.appendBlock(p.curr) for _, q := range p.queries { p.curr = &ir.Block{} @@ -397,11 +386,11 @@ func (p *Planner) planQueries() error { p.vars.Pop() if defined { - p.blocks = append(p.blocks, p.curr) + p.appendBlock(p.curr) } } - p.blocks = append(p.blocks, &ir.Block{ + p.appendBlock(&ir.Block{ Stmts: []ir.Stmt{ &ir.ReturnLocalStmt{ Source: lresultset, @@ -431,7 +420,7 @@ func (p *Planner) planExpr(e *ast.Expr, iter planiter) error { } if len(e.With) > 0 { - return fmt.Errorf("with keyword not implemented") + return p.planWith(e, iter) } if e.IsCall() { @@ -462,6 +451,140 @@ func (p *Planner) planNot(e *ast.Expr, iter planiter) error { return iter() } +func (p *Planner) planWith(e *ast.Expr, iter planiter) error { + + // Plan the values that will be applied by the with modifiers. All values + // must be defined for the overall expression to evaluate. + values := make([]*ast.Term, len(e.With)) + + for i := range e.With { + values[i] = e.With[i].Value + } + + return p.planTermSlice(values, func(locals []ir.Local) error { + + // Save the current values of the input and cached data documents. + linput := p.vars.GetOrEmpty(ast.InputRootDocument.Value.(ast.Var)) + ldata := p.vars.GetOrEmpty(ast.DefaultRootDocument.Value.(ast.Var)) + lprev := p.planSaveLocals(linput, ldata) + datarefs := []ast.Ref{} + + // Apply with modifiers to the input and cached data documents. + for i := range e.With { + + target := e.With[i].Target.Value.(ast.Ref) + head := target[0].Value.(ast.Var) + lhead := p.vars.GetOrEmpty(head) + lpath := make([]int, len(target)-1) + + for i := 1; i < len(target); i++ { + if s, ok := target[i].Value.(ast.String); ok { + lpath[i-1] = p.getStringConst(string(s)) + } else { + return errors.New("invalid with target") + } + } + + p.appendStmt(&ir.WithStmt{ + Source: lhead, + Path: lpath, + Value: locals[i], + Target: lhead, + }) + + if target[0].Equal(ast.DefaultRootDocument) { + datarefs = append(datarefs, target) + } + } + + // Save the new values of the input and cached data documents so they + // can be re-applied. + // + // NOTE(tsandall): If either of the documents have not been modified + // these statements are no-ops. We could safely elide them in the future + // if which heads had been touched. The same holds for the undo and + // re-apply operations below. + lmodified := p.planSaveLocals(linput, ldata) + prev := p.curr + p.curr = &ir.Block{} + + // If any of the with statements targeted the data document we shadow + // the existing planned functions during expression planning. This + // causes the planner to re-plan any rules that may be required during + // planning of this expression (transitively). + if len(datarefs) > 0 { + p.funcs.Push(map[string]string{}) + for _, ref := range datarefs { + p.rules.Push(ref) + } + } + + err := p.planExpr(e.NoWith(), func() error { + + var funcs map[string]string + + if len(datarefs) > 0 { + funcs = p.funcs.Pop() + for i := len(datarefs) - 1; i >= 0; i-- { + p.rules.Pop(datarefs[i]) + } + } + + // Undo the with modifiers by restoring the saved input and cached + // data document values. + p.appendStmt(&ir.AssignVarStmt{Source: lprev[0], Target: linput}) + p.appendStmt(&ir.AssignVarStmt{Source: lprev[1], Target: ldata}) + prev := p.curr + p.curr = &ir.Block{} + + if err := iter(); err != nil { + return err + } + + block := p.curr + p.curr = prev + p.appendStmt(&ir.BlockStmt{Blocks: []*ir.Block{block}}) + + // Re-apply the modified input and cached data documents in case we + // re-enter. + p.appendStmt(&ir.AssignVarStmt{Source: lmodified[0], Target: linput}) + p.appendStmt(&ir.AssignVarStmt{Source: lmodified[1], Target: ldata}) + + if len(datarefs) > 0 { + p.funcs.Push(funcs) + for _, ref := range datarefs { + p.rules.Push(ref) + } + } + + return nil + }) + + if err != nil { + return err + } + + block := p.curr + p.curr = prev + + p.appendStmt(&ir.BlockStmt{Blocks: []*ir.Block{block}}) + + // Restore the original input and cahced data document values after + // generating the plan for the rest of the query. + p.appendStmt(&ir.AssignVarStmt{Source: lprev[0], Target: linput}) + p.appendStmt(&ir.AssignVarStmt{Source: lprev[1], Target: ldata}) + + if len(datarefs) > 0 { + _ = p.funcs.Pop() + for i := len(datarefs) - 1; i >= 0; i-- { + p.rules.Pop(datarefs[i]) + } + } + + return nil + }) +} + func (p *Planner) planExprTerm(e *ast.Expr, iter planiter) error { return p.planTerm(e.Terms.(*ast.Term), func() error { falsy := p.newLocal() @@ -531,12 +654,17 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error { return iter() }) default: - trieNode := p.funcs.Lookup(e.Operator()) - if trieNode == nil { + node := p.rules.Lookup(e.Operator()) + if node == nil { return fmt.Errorf("illegal call: unknown operator %v", operator) } - arity := trieNode.Arity() + funcName, err := p.planRules(node.Rules()) + if err != nil { + return err + } + + arity := node.Arity() operands := e.Operands() args := []ir.Local{ @@ -550,7 +678,7 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error { return p.planCallArgs(operands, 0, args, func(args []ir.Local) error { p.ltarget = p.newLocal() p.appendStmt(&ir.CallStmt{ - Func: operator, + Func: funcName, Args: args, Result: p.ltarget, }) @@ -562,7 +690,7 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error { return p.planCallArgs(operands[:len(operands)-1], 0, args, func(args []ir.Local) error { result := p.newLocal() p.appendStmt(&ir.CallStmt{ - Func: operator, + Func: funcName, Args: args, Result: result, }) @@ -922,7 +1050,7 @@ func (p *Planner) planNumberInt(i int64, iter planiter) error { func (p *Planner) planString(str ast.String, iter planiter) error { - index := p.appendStringConst(string(str)) + index := p.getStringConst(string(str)) target := p.newLocal() p.appendStmt(&ir.MakeStringStmt{ @@ -1122,7 +1250,7 @@ func (p *Planner) planRef(ref ast.Ref, iter planiter) error { } if head.Compare(ast.DefaultRootDocument.Value) == 0 { - virtual := p.funcs.children[ref[0].Value] + virtual := p.rules.Get(ref[0].Value) base := &baseptr{local: p.vars.GetOrEmpty(ast.DefaultRootDocument.Value.(ast.Var))} return p.planRefData(virtual, base, ref, 1, iter) } @@ -1172,7 +1300,7 @@ type baseptr struct { // planRefData implements the virtual document model by generating the value of // the ref parameter and invoking the iterator with the planner target set to // the virtual document and all variables in the reference assigned. -func (p *Planner) planRefData(virtual *functrie, base *baseptr, ref ast.Ref, index int, iter planiter) error { +func (p *Planner) planRefData(virtual *ruletrie, base *baseptr, ref ast.Ref, index int, iter planiter) error { // Early-exit if the end of the reference has been reached. In this case the // plan has to materialize the full extent of the referenced value. @@ -1184,22 +1312,31 @@ func (p *Planner) planRefData(virtual *functrie, base *baseptr, ref ast.Ref, ind // operand or invoke the function for the rule referred to by this operand. if ref[index].IsGround() { - var vchild *functrie + var vchild *ruletrie if virtual != nil { - vchild = virtual.children[ref[index].Value] + vchild = virtual.Get(ref[index].Value) } - if vchild != nil && vchild.val != nil { + rules := vchild.Rules() + + if len(rules) > 0 { p.ltarget = p.newLocal() + + funcName, err := p.planRules(rules) + if err != nil { + return err + } + p.appendStmt(&ir.CallStmt{ - Func: vchild.val.Rules[0].Path().String(), + Func: funcName, Args: []ir.Local{ p.vars.GetOrEmpty(ast.InputRootDocument.Value.(ast.Var)), p.vars.GetOrEmpty(ast.DefaultRootDocument.Value.(ast.Var)), }, Result: p.ltarget, }) + return p.planRefRec(ref, index+1, iter) } @@ -1299,7 +1436,7 @@ func (p *Planner) planRefData(virtual *functrie, base *baseptr, ref ast.Ref, ind // planRefDataExtent generates the full extent (combined) of the base and // virtual nodes and then invokes the iterator with the planner target set to // the full extent. -func (p *Planner) planRefDataExtent(virtual *functrie, base *baseptr, iter planiter) error { +func (p *Planner) planRefDataExtent(virtual *ruletrie, base *baseptr, iter planiter) error { vtarget := p.newLocal() @@ -1312,22 +1449,25 @@ func (p *Planner) planRefDataExtent(virtual *functrie, base *baseptr, iter plani Target: vtarget, }) - for key, child := range virtual.children { + for _, key := range virtual.Children() { + child := virtual.Get(key) // Skip functions. - if child.val != nil && child.val.Arity() > 0 { + if child.Arity() > 0 { continue } lkey := p.newLocal() - idx := p.appendStringConst(string(key.(ast.String))) + idx := p.getStringConst(string(key.(ast.String))) p.appendStmt(&ir.MakeStringStmt{ Index: idx, Target: lkey, }) + rules := child.Rules() + // Build object hierarchy depth-first. - if child.val == nil { + if len(rules) == 0 { err := p.planRefDataExtent(child, nil, func() error { p.appendStmt(&ir.ObjectInsertStmt{ Object: vtarget, @@ -1345,13 +1485,18 @@ func (p *Planner) planRefDataExtent(virtual *functrie, base *baseptr, iter plani // Generate virtual document for leaf. lvalue := p.newLocal() + funcName, err := p.planRules(rules) + if err != nil { + return err + } + // Add leaf to object if defined. p.appendStmt(&ir.BlockStmt{ Blocks: []*ir.Block{ &ir.Block{ Stmts: []ir.Stmt{ &ir.CallStmt{ - Func: child.val.Rules[0].Path().String(), + Func: funcName, Args: []ir.Local{ p.vars.GetOrEmpty(ast.InputRootDocument.Value.(ast.Var)), p.vars.GetOrEmpty(ast.DefaultRootDocument.Value.(ast.Var)), @@ -1500,22 +1645,66 @@ func (p *Planner) planScan(key *ast.Term, iter scaniter) error { } -func (p *Planner) appendStmt(s ir.Stmt) { - p.curr.Stmts = append(p.curr.Stmts, s) +// planSaveLocals returns a slice of locals holding temporary variables that +// have been assigned from the supplied vars. +func (p *Planner) planSaveLocals(vars ...ir.Local) []ir.Local { + + lsaved := make([]ir.Local, len(vars)) + + for i := range vars { + + lsaved[i] = p.newLocal() + + p.appendStmt(&ir.AssignVarStmt{ + Source: vars[i], + Target: lsaved[i], + }) + } + + return lsaved +} + +type termsliceiter func([]ir.Local) error + +func (p *Planner) planTermSlice(terms []*ast.Term, iter termsliceiter) error { + return p.planTermSliceRec(terms, make([]ir.Local, len(terms)), 0, iter) } -func (p *Planner) appendStringConst(s string) int { - index, ok := p.strindex[s] +func (p *Planner) planTermSliceRec(terms []*ast.Term, locals []ir.Local, index int, iter termsliceiter) error { + if index >= len(terms) { + return iter(locals) + } + + return p.planTerm(terms[index], func() error { + locals[index] = p.ltarget + return p.planTermSliceRec(terms, locals, index+1, iter) + }) +} + +func (p *Planner) getStringConst(s string) int { + index, ok := p.strings[s] if !ok { - index = len(p.strings) - p.strings = append(p.strings, &ir.StringConst{ + index = len(p.policy.Static.Strings) + p.policy.Static.Strings = append(p.policy.Static.Strings, &ir.StringConst{ Value: s, }) - p.strindex[s] = index + p.strings[s] = index } return index } +func (p *Planner) appendStmt(s ir.Stmt) { + p.curr.Stmts = append(p.curr.Stmts, s) +} + +func (p *Planner) appendFunc(f *ir.Func) { + p.policy.Funcs.Funcs = append(p.policy.Funcs.Funcs, f) +} + +func (p *Planner) appendBlock(b *ir.Block) { + p.policy.Plan.Blocks = append(p.policy.Plan.Blocks, b) +} + func (p *Planner) newLocal() ir.Local { x := p.lcurr p.lcurr++ diff --git a/internal/planner/planner_test.go b/internal/planner/planner_test.go index 03245f715a..88b1154f3f 100644 --- a/internal/planner/planner_test.go +++ b/internal/planner/planner_test.go @@ -177,6 +177,24 @@ func TestPlannerHelloWorld(t *testing.T) { note: "variables in query", queries: []string{"x = 1", "y = 2", "x = 1; y = 2"}, }, + { + note: "with keyword", + queries: []string{ + `input[i] = 1 with input as [1]; i > 1`, + }, + }, + { + note: "with keyword data", + queries: []string{`data = x with data.foo as 1 with data.bar.r as 3`}, + modules: []string{ + `package foo + + p = 1`, + `package bar + + q = 2`, + }, + }, } for _, tc := range tests { diff --git a/internal/planner/rules.go b/internal/planner/rules.go new file mode 100644 index 0000000000..e2d9d15619 --- /dev/null +++ b/internal/planner/rules.go @@ -0,0 +1,156 @@ +package planner + +import ( + "sort" + + "github.com/open-policy-agent/opa/ast" +) + +// funcstack implements a simple map structure used to keep track of virtual +// document => planned function names. The structure supports Push and Pop +// operations so that the planner can shadow planned functions when 'with' +// statements are found. +type funcstack struct { + stack []map[string]string + gen int +} + +func newFuncstack() *funcstack { + return &funcstack{ + stack: []map[string]string{ + map[string]string{}, + }, + gen: 0, + } +} + +func (p funcstack) Add(key, value string) { + p.stack[len(p.stack)-1][key] = value +} + +func (p funcstack) Get(key string) (string, bool) { + value, ok := p.stack[len(p.stack)-1][key] + return value, ok +} + +func (p *funcstack) Push(funcs map[string]string) { + p.stack = append(p.stack, funcs) + p.gen++ +} + +func (p *funcstack) Pop() map[string]string { + last := p.stack[len(p.stack)-1] + p.stack = p.stack[:len(p.stack)-1] + p.gen++ + return last +} + +// ruletrie implements a simple trie structure for organizing rules that may be +// planned. The trie nodes are keyed by the rule path. The ruletrie supports +// Push and Pop operations that allow the planner to shadow subtrees when 'with' +// statements are found. +type ruletrie struct { + children map[ast.Value][]*ruletrie + rules []*ast.Rule +} + +func newRuletrie() *ruletrie { + return &ruletrie{ + children: map[ast.Value][]*ruletrie{}, + } +} + +func (t *ruletrie) Arity() int { + rules := t.Rules() + if len(rules) > 0 { + return len(rules[0].Head.Args) + } + return 0 +} + +func (t *ruletrie) Rules() []*ast.Rule { + if t != nil { + return t.rules + } + return nil +} + +func (t *ruletrie) Push(key ast.Ref) { + node := t + for i := 0; i < len(key)-1; i++ { + node = node.Get(key[i].Value) + if node == nil { + return + } + } + elem := key[len(key)-1] + node.children[elem.Value] = append(node.children[elem.Value], nil) +} + +func (t *ruletrie) Pop(key ast.Ref) { + node := t + for i := 0; i < len(key)-1; i++ { + node = node.Get(key[i].Value) + if node == nil { + return + } + } + elem := key[len(key)-1] + sl := node.children[elem.Value] + node.children[elem.Value] = sl[:len(sl)-1] +} + +func (t *ruletrie) Insert(key ast.Ref) *ruletrie { + node := t + for _, elem := range key { + child := node.Get(elem.Value) + if child == nil { + child = newRuletrie() + node.children[elem.Value] = append(node.children[elem.Value], child) + } + node = child + } + return node +} + +func (t *ruletrie) Lookup(key ast.Ref) *ruletrie { + node := t + for _, elem := range key { + node = node.Get(elem.Value) + if node == nil { + return nil + } + } + return node +} + +func (t *ruletrie) LookupOrInsert(key ast.Ref) *ruletrie { + if val := t.Lookup(key); val != nil { + return val + } + return t.Insert(key) +} + +func (t *ruletrie) Children() []ast.Value { + sorted := make([]ast.Value, 0, len(t.children)) + for key := range t.children { + if t.Get(key) != nil { + sorted = append(sorted, key) + } + } + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].Compare(sorted[j]) < 0 + }) + return sorted +} + +func (t *ruletrie) Get(k ast.Value) *ruletrie { + if t == nil { + return nil + } + nodes := t.children[k] + if len(nodes) == 0 { + return nil + } + return nodes[len(nodes)-1] +} diff --git a/test/wasm/assets/016_with.yaml b/test/wasm/assets/016_with.yaml new file mode 100644 index 0000000000..618bde33ee --- /dev/null +++ b/test/wasm/assets/016_with.yaml @@ -0,0 +1,220 @@ +cases: + - note: assignment + query: | + input = {"a": 1} with input as {"a": 1} + want_defined: true + - note: assignment (negative) + query: | + input = {"a": 1} with input as {"a": "deadbeef"} + want_defined: false + - note: assignment override + query: | + input = {"a": 2} with input as {"a": 2} + input: {"a": 1} + want_defined: true + - note: assignment override (negative) + query: | + input = {"a": 2} with input as {"a": "deadbeef"} + input: {"a": 2} + want_defined: false + - note: assignment undo + query: | + input = {"a": 1} with input as {"a": 1} + input = {"a": 2} + input: {"a": 2} + want_defined: true + - note: assignment iteration + query: | + input[i] = 1 with input as [1,2,1] + want_result: [{'i': 0}, {'i': 2}] + - note: assignment transitive + query: | + data.test.p = x with input as "p" + modules: + - | + package test + + p = x { + q = x with input as ["q", input] + } + + q = x { + r = x # intentionally unmodified, with keyword applies transitively + } + + r = input + want_result: [{'x': ["q", "p"]}] + - note: assignment undo across queries + query: | + data.test.p[x] + modules: + - | + package test + + p[x] { + x = input.a with input as {"a": 1} + } + + p[y] { + y = input.b with input as {"b": 2} + } + + p[t] { + t = input.b # expected to be undefined + } + + p[u] { + u = input.a # expected to be undefined + } + want_result: [{'x': 1}, {'x': 2}] + - note: upsert + query: | + input = x with input.foo as 1 + want_result: [{'x': {'foo': 1}}] + - note: upsert make intermediate nodes + query: | + input = x with input.foo.bar.baz as [1,2,3] + want_result: [{'x': {'foo': {'bar': {'baz': [1,2,3]}}}}] + - note: upsert merge top-level + query: | + input = x with input.foo as 1 + input: {'bar': 2} + want_result: [{'x': {'foo': 1, 'bar': 2}}] + - note: upsert merge top-level make intermediate nodes + query: | + input = x with input.foo.bar as 1 + input: {'baz': 2} + want_result: [{'x': {'foo': {'bar': 1}, 'baz': 2}}] + - note: upsert merge intermediate nodes + query: | + input = x with input.foo.bar as 1 + input: {'foo': {'baz': 2}} + want_result: [{'x': {'foo': {'bar': 1, 'baz': 2}}}] + - note: upsert merge intermediate nodes with new node + query: | + input = x with input.foo.bar.qux as 1 + input: { 'foo': {'baz': 2}} + want_result: [{'x': {'foo': {'bar': {'qux': 1}, 'baz': 2}}}] + - note: upsert merge top-level multiple + query: | + input = x with input.foo as 1 with input.bar as 2 + want_result: [{'x': {'foo': 1, 'bar': 2}}] + - note: upsert merge intermediate multiple + query: | + input = x with input.foo.bar as 1 with input.foo.baz as 2 + want_result: [{'x': {'foo': {'bar': 1, 'baz': 2}}}] + - note: upsert iteration + query: | + input.foo[x] = y with input.foo.bar as 1 with input.foo.baz as 2 + want_result: [ + { + 'x': 'bar', + 'y': 1, + }, + { + 'x': 'baz', + 'y': 2, + }, + ] + - note: shadow rules + query: | + data = x with data.foo as 1 + want_result: [ + { + 'x': { + 'foo': 1 + } + } + ] + modules: + - | + package foo + + p = 1 + - note: shadow rules and merge + query: | + data = x with data.foo as 1 with data.bar.r as 3 + want_result: [ + { + 'x': { + 'foo': 1, + 'bar': { + 'q': 2, + 'r': 3, + } + } + } + ] + modules: + - | + package foo + + p = 1 + - | + package bar + q = 2 + - note: shadow cached data + query: | + data = x with data.foo as 1 + data: {'foo': 2} + want_result: [ + { + 'x': { + 'foo': 1 + } + } + ] + - note: shadow cached data and merge + query: | + data = x with data.foo as 1 with data.bar.qux as 4 + data: {'foo': 2, 'bar': {'baz': 3}} + want_result: [ + { + 'x': { + 'foo': 1, + 'bar': { + 'baz': 3, + 'qux': 4, + } + } + } + ] + - note: undo rule shadow + query: | + data.test.p = x; data.test.q = y with data.test.r as 2; data.test.r = z; data.test.q = t + modules: + - | + package test + + p = r + q = [r] + r = 1 + want_result: [ + { + 'x': 1, + 'y': [2], + 'z': 1, + 't': [1], + } + ] + - note: undo data shadow + query: | + data.test.p = x; data.test.q = y with data.test.r as 2; data.test.r = z; data.test.q = t + modules: + - | + package test + p = data.test.r + q = [data.test.r] + data: { + 'test': { + 'r': 1, + } + } + want_result: [ + { + 'x': 1, + 'y': [2], + 'z': 1, + 't': [1], + } + ]