diff --git a/examples/lang/states0.mcl b/examples/lang/states0.mcl index eb3d015578..84aacf4e7a 100644 --- a/examples/lang/states0.mcl +++ b/examples/lang/states0.mcl @@ -2,7 +2,6 @@ import "world" $ns = "estate" $exchanged = world.kvlookup($ns) - $state = maplookup($exchanged, $hostname, "default") if $state == "one" || $state == "default" { diff --git a/lang/gapi.go b/lang/gapi.go index d78241d2df..bed43fe197 100644 --- a/lang/gapi.go +++ b/lang/gapi.go @@ -293,7 +293,13 @@ func (obj *GAPI) Cli(cliInfo *gapi.CliInfo) (*gapi.Deploy, error) { } } logf("running type unification...") - if err := unification.Unify(interpolated, unification.SimpleInvariantSolverLogger(unificationLogf)); err != nil { + unifier := &unification.Unifier{ + AST: interpolated, + Solver: unification.SimpleInvariantSolverLogger(unificationLogf), + Debug: debug, + Logf: unificationLogf, + } + if err := unifier.Unify(); err != nil { return nil, errwrap.Wrapf(err, "could not unify types") } diff --git a/lang/interfaces/ast.go b/lang/interfaces/ast.go index 44bf112712..6779cf90b2 100644 --- a/lang/interfaces/ast.go +++ b/lang/interfaces/ast.go @@ -32,6 +32,7 @@ import ( // often since we usually know which kind of node we want. type Node interface { Apply(fn func(Node) error) error + //Parent() Node // TODO: should we implement this? } // Stmt represents a statement node in the language. A stmt could be a resource, diff --git a/lang/interfaces/unification.go b/lang/interfaces/unification.go index 6290869ffa..a2e25b4277 100644 --- a/lang/interfaces/unification.go +++ b/lang/interfaces/unification.go @@ -19,6 +19,8 @@ package interfaces import ( "fmt" + + "github.com/purpleidea/mgmt/lang/types" ) // Invariant represents a constraint that is described by the Expr's and Stmt's, @@ -27,4 +29,11 @@ import ( type Invariant interface { // TODO: should we add any other methods to this type? fmt.Stringer + + // ExprList returns the list of valid expressions in this invariant. + ExprList() []Expr + + // Matches returns whether an invariant matches the existing solution. + // If it is inconsistent, then it errors. + Matches(solved map[Expr]*types.Type) (bool, error) } diff --git a/lang/interpret_test.go b/lang/interpret_test.go index 46e64e4f6f..3b00382efa 100644 --- a/lang/interpret_test.go +++ b/lang/interpret_test.go @@ -467,7 +467,13 @@ func TestAstFunc0(t *testing.T) { logf := func(format string, v ...interface{}) { t.Logf(fmt.Sprintf("test #%d", index)+": unification: "+format, v...) } - err = unification.Unify(iast, unification.SimpleInvariantSolverLogger(logf)) + unifier := &unification.Unifier{ + AST: iast, + Solver: unification.SimpleInvariantSolverLogger(logf), + Debug: testing.Verbose(), + Logf: logf, + } + err = unifier.Unify() if !fail && err != nil { t.Errorf("test #%d: FAIL", index) t.Errorf("test #%d: could not unify types: %+v", index, err) @@ -822,7 +828,13 @@ func TestAstFunc1(t *testing.T) { xlogf := func(format string, v ...interface{}) { logf("unification: "+format, v...) } - err = unification.Unify(iast, unification.SimpleInvariantSolverLogger(xlogf)) + unifier := &unification.Unifier{ + AST: iast, + Solver: unification.SimpleInvariantSolverLogger(xlogf), + Debug: testing.Verbose(), + Logf: xlogf, + } + err = unifier.Unify() if !fail && err != nil { t.Errorf("test #%d: FAIL", index) t.Errorf("test #%d: could not unify types: %+v", index, err) @@ -1216,7 +1228,13 @@ func TestAstFunc2(t *testing.T) { xlogf := func(format string, v ...interface{}) { logf("unification: "+format, v...) } - err = unification.Unify(iast, unification.SimpleInvariantSolverLogger(xlogf)) + unifier := &unification.Unifier{ + AST: iast, + Solver: unification.SimpleInvariantSolverLogger(xlogf), + Debug: testing.Verbose(), + Logf: xlogf, + } + err = unifier.Unify() if !fail && err != nil { t.Errorf("test #%d: FAIL", index) t.Errorf("test #%d: could not unify types: %+v", index, err) diff --git a/lang/interpret_test/TestAstFunc1/doubleinclude.graph b/lang/interpret_test/TestAstFunc1/doubleinclude.graph new file mode 100644 index 0000000000..b1095033bc --- /dev/null +++ b/lang/interpret_test/TestAstFunc1/doubleinclude.graph @@ -0,0 +1,11 @@ +Edge: str("hey") -> var(foo) # foo +Edge: str("hey") -> var(foo) # foo +Edge: str("t1") -> var(a) # a +Edge: str("t2") -> var(a) # a +Vertex: str("hey") +Vertex: str("t1") +Vertex: str("t2") +Vertex: var(a) +Vertex: var(a) +Vertex: var(foo) +Vertex: var(foo) diff --git a/lang/interpret_test/TestAstFunc1/doubleinclude/main.mcl b/lang/interpret_test/TestAstFunc1/doubleinclude/main.mcl new file mode 100644 index 0000000000..af51a44606 --- /dev/null +++ b/lang/interpret_test/TestAstFunc1/doubleinclude/main.mcl @@ -0,0 +1,8 @@ +include c1("t1") +include c1("t2") +class c1($a) { + test $a { + stringptr => $foo, + } +} +$foo = "hey" diff --git a/lang/interpret_test/TestAstFunc1/polydoubleinclude.graph b/lang/interpret_test/TestAstFunc1/polydoubleinclude.graph new file mode 100644 index 0000000000..2e6936d83e --- /dev/null +++ b/lang/interpret_test/TestAstFunc1/polydoubleinclude.graph @@ -0,0 +1,32 @@ +Edge: call:len(var(b)) -> call:fmt.printf(str("len is: %d"), call:len(var(b))) # b +Edge: call:len(var(b)) -> call:fmt.printf(str("len is: %d"), call:len(var(b))) # b +Edge: int(-37) -> list(int(13), int(42), int(0), int(-37)) # 3 +Edge: int(0) -> list(int(13), int(42), int(0), int(-37)) # 2 +Edge: int(13) -> list(int(13), int(42), int(0), int(-37)) # 0 +Edge: int(42) -> list(int(13), int(42), int(0), int(-37)) # 1 +Edge: list(int(13), int(42), int(0), int(-37)) -> var(b) # b +Edge: str("hello") -> var(b) # b +Edge: str("len is: %d") -> call:fmt.printf(str("len is: %d"), call:len(var(b))) # a +Edge: str("len is: %d") -> call:fmt.printf(str("len is: %d"), call:len(var(b))) # a +Edge: str("t1") -> var(a) # a +Edge: str("t2") -> var(a) # a +Edge: var(b) -> call:len(var(b)) # 0 +Edge: var(b) -> call:len(var(b)) # 0 +Vertex: call:fmt.printf(str("len is: %d"), call:len(var(b))) +Vertex: call:fmt.printf(str("len is: %d"), call:len(var(b))) +Vertex: call:len(var(b)) +Vertex: call:len(var(b)) +Vertex: int(-37) +Vertex: int(0) +Vertex: int(13) +Vertex: int(42) +Vertex: list(int(13), int(42), int(0), int(-37)) +Vertex: str("hello") +Vertex: str("len is: %d") +Vertex: str("len is: %d") +Vertex: str("t1") +Vertex: str("t2") +Vertex: var(a) +Vertex: var(a) +Vertex: var(b) +Vertex: var(b) diff --git a/lang/interpret_test/TestAstFunc1/polydoubleinclude/main.mcl b/lang/interpret_test/TestAstFunc1/polydoubleinclude/main.mcl new file mode 100644 index 0000000000..c314e262c4 --- /dev/null +++ b/lang/interpret_test/TestAstFunc1/polydoubleinclude/main.mcl @@ -0,0 +1,10 @@ +import "fmt" + +# note that the class can have two separate types for $b +include c1("t1", "hello") # len is 5 +include c1("t2", [13, 42, 0, -37,]) # len is 4 +class c1($a, $b) { + test $a { + anotherstr => fmt.printf("len is: %d", len($b)), + } +} diff --git a/lang/interpret_test/TestAstFunc1/slow_unification0.graph b/lang/interpret_test/TestAstFunc1/slow_unification0.graph new file mode 100644 index 0000000000..4710d2dac8 --- /dev/null +++ b/lang/interpret_test/TestAstFunc1/slow_unification0.graph @@ -0,0 +1,88 @@ +Edge: call:_operator(str("=="), var(state), str("default")) -> call:_operator(str("||"), call:_operator(str("=="), var(state), str("one")), call:_operator(str("=="), var(state), str("default"))) # b +Edge: call:_operator(str("=="), var(state), str("one")) -> call:_operator(str("||"), call:_operator(str("=="), var(state), str("one")), call:_operator(str("=="), var(state), str("default"))) # a +Edge: call:maplookup(var(exchanged), var(hostname), str("default")) -> var(state) # state +Edge: call:maplookup(var(exchanged), var(hostname), str("default")) -> var(state) # state +Edge: call:maplookup(var(exchanged), var(hostname), str("default")) -> var(state) # state +Edge: call:maplookup(var(exchanged), var(hostname), str("default")) -> var(state) # state +Edge: call:world.kvlookup(var(ns)) -> var(exchanged) # exchanged +Edge: str("") -> var(hostname) # hostname +Edge: str("==") -> call:_operator(str("=="), var(state), str("default")) # x +Edge: str("==") -> call:_operator(str("=="), var(state), str("one")) # x +Edge: str("==") -> call:_operator(str("=="), var(state), str("three")) # x +Edge: str("==") -> call:_operator(str("=="), var(state), str("two")) # x +Edge: str("default") -> call:_operator(str("=="), var(state), str("default")) # b +Edge: str("default") -> call:maplookup(var(exchanged), var(hostname), str("default")) # default +Edge: str("estate") -> var(ns) # ns +Edge: str("estate") -> var(ns) # ns +Edge: str("estate") -> var(ns) # ns +Edge: str("estate") -> var(ns) # ns +Edge: str("estate") -> var(ns) # ns +Edge: str("estate") -> var(ns) # ns +Edge: str("estate") -> var(ns) # ns +Edge: str("estate") -> var(ns) # ns +Edge: str("estate") -> var(ns) # ns +Edge: str("estate") -> var(ns) # ns +Edge: str("one") -> call:_operator(str("=="), var(state), str("one")) # b +Edge: str("three") -> call:_operator(str("=="), var(state), str("three")) # b +Edge: str("two") -> call:_operator(str("=="), var(state), str("two")) # b +Edge: str("||") -> call:_operator(str("||"), call:_operator(str("=="), var(state), str("one")), call:_operator(str("=="), var(state), str("default"))) # x +Edge: var(exchanged) -> call:maplookup(var(exchanged), var(hostname), str("default")) # map +Edge: var(hostname) -> call:maplookup(var(exchanged), var(hostname), str("default")) # key +Edge: var(ns) -> call:world.kvlookup(var(ns)) # namespace +Edge: var(state) -> call:_operator(str("=="), var(state), str("default")) # a +Edge: var(state) -> call:_operator(str("=="), var(state), str("one")) # a +Edge: var(state) -> call:_operator(str("=="), var(state), str("three")) # a +Edge: var(state) -> call:_operator(str("=="), var(state), str("two")) # a +Vertex: call:_operator(str("=="), var(state), str("default")) +Vertex: call:_operator(str("=="), var(state), str("one")) +Vertex: call:_operator(str("=="), var(state), str("three")) +Vertex: call:_operator(str("=="), var(state), str("two")) +Vertex: call:_operator(str("||"), call:_operator(str("=="), var(state), str("one")), call:_operator(str("=="), var(state), str("default"))) +Vertex: call:maplookup(var(exchanged), var(hostname), str("default")) +Vertex: call:world.kvlookup(var(ns)) +Vertex: str("") +Vertex: str("/tmp/mgmt/state") +Vertex: str("/tmp/mgmt/state") +Vertex: str("/tmp/mgmt/state") +Vertex: str("/usr/bin/sleep 1s") +Vertex: str("/usr/bin/sleep 1s") +Vertex: str("/usr/bin/sleep 1s") +Vertex: str("==") +Vertex: str("==") +Vertex: str("==") +Vertex: str("==") +Vertex: str("default") +Vertex: str("default") +Vertex: str("estate") +Vertex: str("one") +Vertex: str("one") +Vertex: str("state: one\n") +Vertex: str("state: three\n") +Vertex: str("state: two\n") +Vertex: str("three") +Vertex: str("three") +Vertex: str("timer") +Vertex: str("timer") +Vertex: str("timer") +Vertex: str("timer") +Vertex: str("timer") +Vertex: str("timer") +Vertex: str("two") +Vertex: str("two") +Vertex: str("||") +Vertex: var(exchanged) +Vertex: var(hostname) +Vertex: var(ns) +Vertex: var(ns) +Vertex: var(ns) +Vertex: var(ns) +Vertex: var(ns) +Vertex: var(ns) +Vertex: var(ns) +Vertex: var(ns) +Vertex: var(ns) +Vertex: var(ns) +Vertex: var(state) +Vertex: var(state) +Vertex: var(state) +Vertex: var(state) diff --git a/lang/interpret_test/TestAstFunc1/slow_unification0/main.mcl b/lang/interpret_test/TestAstFunc1/slow_unification0/main.mcl new file mode 100644 index 0000000000..76aac18874 --- /dev/null +++ b/lang/interpret_test/TestAstFunc1/slow_unification0/main.mcl @@ -0,0 +1,52 @@ +# state machine that previously experienced unusable slow type unification +import "world" + +$ns = "estate" +$exchanged = world.kvlookup($ns) +$state = maplookup($exchanged, $hostname, "default") + +if $state == "one" || $state == "default" { + + file "/tmp/mgmt/state" { + content => "state: one\n", + } + + exec "timer" { + cmd => "/usr/bin/sleep 1s", + } + kv "${ns}" { + key => $ns, + value => "two", + } + Exec["timer"] -> Kv["${ns}"] +} +if $state == "two" { + + file "/tmp/mgmt/state" { + content => "state: two\n", + } + + exec "timer" { + cmd => "/usr/bin/sleep 1s", + } + kv "${ns}" { + key => $ns, + value => "three", + } + Exec["timer"] -> Kv["${ns}"] +} +if $state == "three" { + + file "/tmp/mgmt/state" { + content => "state: three\n", + } + + exec "timer" { + cmd => "/usr/bin/sleep 1s", + } + kv "${ns}" { + key => $ns, + value => "one", + } + Exec["timer"] -> Kv["${ns}"] +} diff --git a/lang/lang.go b/lang/lang.go index acee229da2..5310f4d0a8 100644 --- a/lang/lang.go +++ b/lang/lang.go @@ -185,7 +185,13 @@ func (obj *Lang) Init() error { } } obj.Logf("running type unification...") - if err := unification.Unify(obj.ast, unification.SimpleInvariantSolverLogger(logf)); err != nil { + unifier := &unification.Unifier{ + AST: obj.ast, + Solver: unification.SimpleInvariantSolverLogger(logf), + Debug: obj.Debug, + Logf: logf, + } + if err := unifier.Unify(); err != nil { return errwrap.Wrapf(err, "could not unify types") } diff --git a/lang/structs.go b/lang/structs.go index 4273664471..cc679e8f63 100644 --- a/lang/structs.go +++ b/lang/structs.go @@ -2977,6 +2977,15 @@ type StmtInclude struct { // Nevertheless, it is a useful facility for operations that might only apply to // a select number of node types, since they won't need extra noop iterators... func (obj *StmtInclude) Apply(fn func(interfaces.Node) error) error { + // If the class exists, then descend into it, because at this point, the + // copy of the original class that is stored here, is the effective + // class that we care about for type unification, and everything else... + // It's not clear if this is needed, but it's probably nor harmful atm. + if obj.class != nil { + if err := obj.class.Apply(fn); err != nil { + return err + } + } if obj.Args != nil { for _, x := range obj.Args { if err := x.Apply(fn); err != nil { @@ -4890,7 +4899,11 @@ func (obj *ExprFunc) String() string { if obj.Return != nil { s += fmt.Sprintf(" %s", obj.Return.String()) } - s += fmt.Sprintf(" { %s }", obj.Body.String()) + if obj.Body == nil { + s += fmt.Sprintf(" { ??? }") // TODO: why does this happen? + } else { + s += fmt.Sprintf(" { %s }", obj.Body.String()) + } return s } diff --git a/lang/unification/simplesolver.go b/lang/unification/simplesolver.go index 50c722089a..1154a8f188 100644 --- a/lang/unification/simplesolver.go +++ b/lang/unification/simplesolver.go @@ -34,15 +34,16 @@ const ( // SimpleInvariantSolver with the log parameter of your choice specified. The // result satisfies the correct signature for the solver parameter of the // Unification function. -func SimpleInvariantSolverLogger(logf func(format string, v ...interface{})) func([]interfaces.Invariant) (*InvariantSolution, error) { - return func(invariants []interfaces.Invariant) (*InvariantSolution, error) { - return SimpleInvariantSolver(invariants, logf) +func SimpleInvariantSolverLogger(logf func(format string, v ...interface{})) func([]interfaces.Invariant, []interfaces.Expr) (*InvariantSolution, error) { + return func(invariants []interfaces.Invariant, expected []interfaces.Expr) (*InvariantSolution, error) { + return SimpleInvariantSolver(invariants, expected, logf) } } // SimpleInvariantSolver is an iterative invariant solver for AST expressions. // It is intended to be very simple, even if it's computationally inefficient. -func SimpleInvariantSolver(invariants []interfaces.Invariant, logf func(format string, v ...interface{})) (*InvariantSolution, error) { +func SimpleInvariantSolver(invariants []interfaces.Invariant, expected []interfaces.Expr, logf func(format string, v ...interface{})) (*InvariantSolution, error) { + debug := false // XXX: add to interface logf("%s: invariants:", Name) for i, x := range invariants { logf("invariant(%d): %T: %s", i, x, x) @@ -112,8 +113,18 @@ func SimpleInvariantSolver(invariants []interfaces.Invariant, logf func(format s structPartials := make(map[interfaces.Expr]map[interfaces.Expr]*types.Type) funcPartials := make(map[interfaces.Expr]map[interfaces.Expr]*types.Type) + isSolved := func(solved map[interfaces.Expr]*types.Type) bool { + for _, x := range expected { + if typ, exists := solved[x]; !exists || typ == nil { + return false + } + } + return true + } + logf("%s: starting loop with %d equalities", Name, len(equalities)) // run until we're solved, stop consuming equalities, or type clash +Loop: for { logf("%s: iterate...", Name) if len(equalities) == 0 && len(exclusives) == 0 { @@ -498,11 +509,71 @@ func SimpleInvariantSolver(invariants []interfaces.Invariant, logf func(format s } } // end inner for loop if len(used) == 0 { - // looks like we're now ambiguous, but if we have any + // Looks like we're now ambiguous, but if we have any // exclusives, recurse into each possibility to see if - // one of them can help solve this! first one wins. add + // one of them can help solve this! first one wins. Add // in the exclusive to the current set of equalities! + // To decrease the problem space, first check if we have + // enough solutions to solve everything. If so, then we + // don't need to solve any exclusives, and instead we + // only need to verify that they don't conflict with the + // found solution, which reduces the search space... + + // Another optimization that can be done before we run + // the combinatorial exclusive solver, is we can look at + // each exclusive, and remove the ones that already + // match, because they don't tell us any new information + // that we don't already know. We can also fail early + // if anything proves we're already inconsistent. + + // These two optimizations turn out to use the exact + // same algorithm and code, so they're combined here... + if isSolved(solved) { + logf("%s: solved early with %d exclusives left!", Name, len(exclusives)) + } else { + logf("%s: unsolved with %d exclusives left!", Name, len(exclusives)) + } + // check for consistency against remaining invariants + done := []int{} + for i, invar := range exclusives { + // test each one to see if at least one works + match, err := invar.Matches(solved) + if err != nil { + if debug { + logf("exclusive invar failed: %+v", invar) + } + return nil, errwrap.Wrapf(err, "inconsistent exclusive") + } + if !match { + continue + } + done = append(done, i) + } + + // remove exclusives that matched correctly + for i := len(done) - 1; i >= 0; i-- { + ix := done[i] // delete index that was marked as done! + exclusives = append(exclusives[:ix], exclusives[ix+1:]...) + } + + if len(exclusives) == 0 { + break Loop + } + + // TODO: Lastly, we could loop through each exclusive + // and see if it only has a single, easy solution. For + // example, if we know that an exclusive is A or B or C + // and that B and C are inconsistent, then we can + // replace the exclusive with a single invariant and + // then run that through our solver. We can do this + // iteratively (recursively in our case) so that if + // we're lucky, we rarely need to run the raw exclusive + // combinatorial solver which is slow. + + // TODO: We could try and replace our combinatorial + // exclusive solver with a real SAT solver algorithm. + // what have we learned for sure so far? partialSolutions := []interfaces.Invariant{} logf("%s: %d solved, %d unsolved, and %d exclusives left", Name, len(solved), len(equalities), len(exclusives)) @@ -535,7 +606,7 @@ func SimpleInvariantSolver(invariants []interfaces.Invariant, logf func(format s recursiveInvariants = append(recursiveInvariants, partialSolutions...) recursiveInvariants = append(recursiveInvariants, ex...) logf("%s: recursing...", Name) - solution, err := SimpleInvariantSolver(recursiveInvariants, logf) + solution, err := SimpleInvariantSolver(recursiveInvariants, expected, logf) if err != nil { logf("%s: recursive solution failed: %+v", Name, err) continue // no solution found here... diff --git a/lang/unification/unification.go b/lang/unification/unification.go index 987c19632b..a1fe2bdd51 100644 --- a/lang/unification/unification.go +++ b/lang/unification/unification.go @@ -25,6 +25,18 @@ import ( "github.com/purpleidea/mgmt/lang/types" ) +// Unifier holds all the data that the Unify function will need for it to run. +type Unifier struct { + // AST is the input abstract syntax tree to unify. + AST interfaces.Stmt + + // Solver is the solver algorithm implementation to use. + Solver func([]interfaces.Invariant, []interfaces.Expr) (*InvariantSolution, error) + + Debug bool + Logf func(format string, v ...interface{}) +} + // Unify takes an AST expression tree and attempts to assign types to every node // using the specified solver. The expression tree returns a list of invariants // (or constraints) which must be met in order to find a unique value for the @@ -37,32 +49,77 @@ import ( // type. This function and logic was invented after the author could not find // any proper literature or examples describing a well-known implementation of // this process. Improvements and polite recommendations are welcome. -func Unify(ast interfaces.Stmt, solver func([]interfaces.Invariant) (*InvariantSolution, error)) error { - //log.Printf("unification: tree: %+v", ast) // debug - if ast == nil { - return fmt.Errorf("AST is nil") +func (obj *Unifier) Unify() error { + if obj.AST == nil { + return fmt.Errorf("the AST is nil") + } + if obj.Solver == nil { + return fmt.Errorf("the Solver is missing") + } + if obj.Logf == nil { + return fmt.Errorf("the Logf function is missing") } - invariants, err := ast.Unify() + if obj.Debug { + obj.Logf("tree: %+v", obj.AST) + } + invariants, err := obj.AST.Unify() if err != nil { return err } - solved, err := solver(invariants) + // build a list of what we think we need to solve for to succeed + exprs := []interfaces.Expr{} + for _, x := range invariants { + exprs = append(exprs, x.ExprList()...) + } + exprMap := ExprListToExprMap(exprs) // makes searching faster + exprList := ExprMapToExprList(exprMap) // makes it unique (no duplicates) + + solved, err := obj.Solver(invariants, exprList) if err != nil { return err } - // TODO: ideally we would know how many different expressions need their - // types set in the AST and then ensure we have this many unique - // solutions, and if not, then fail. This would ensure we don't have an - // AST that is only partially populated with the correct types. + // determine what expr's we need to solve for + if obj.Debug { + obj.Logf("expr count: %d", len(exprList)) + //for _, x := range exprList { + // obj.Logf("> %p (%+v)", x, x) + //} + } + + // XXX: why doesn't `len(exprList)` always == `len(solved.Solutions)` ? + // XXX: is it due to the extra ExprAny ??? I see an extra function sometimes... - //log.Printf("unification: found a solution!") // TODO: get a logf function passed in... + if obj.Debug { + obj.Logf("solutions count: %d", len(solved.Solutions)) + //for _, x := range solved.Solutions { + // obj.Logf("> %p (%+v) -- %s", x.Expr, x.Type, x.Expr.String()) + //} + } + + // Determine that our solver produced a solution for every expr that + // we're interested in. If it didn't, and it didn't error, then it's a + // bug. We check for this because we care about safety, this ensures + // that our AST will get fully populated with the correct types! + for _, x := range solved.Solutions { + delete(exprMap, x.Expr) // remove everything we know about + } + if c := len(exprMap); c > 0 { // if there's anything left, it's bad... + // programming error! + return fmt.Errorf("got %d unbound expr's", c) + } + + if obj.Debug { + obj.Logf("found a solution!") + } // solver has found a solution, apply it... // we're modifying the AST, so code can't error now... for _, x := range solved.Solutions { - //log.Printf("unification: solution: %p => %+v\t(%+v)", x.Expr, x.Type, x.Expr.String()) // debug + if obj.Debug { + obj.Logf("solution: %p => %+v\t(%+v)", x.Expr, x.Type, x.Expr.String()) + } // apply this to each AST node if err := x.Expr.SetType(x.Type); err != nil { // programming error! @@ -85,6 +142,24 @@ func (obj *EqualsInvariant) String() string { return fmt.Sprintf("%p == %s", obj.Expr, obj.Type) } +// ExprList returns the list of valid expressions in this invariant. +func (obj *EqualsInvariant) ExprList() []interfaces.Expr { + return []interfaces.Expr{obj.Expr} +} + +// Matches returns whether an invariant matches the existing solution. If it is +// inconsistent, then it errors. +func (obj *EqualsInvariant) Matches(solved map[interfaces.Expr]*types.Type) (bool, error) { + typ, exists := solved[obj.Expr] + if !exists { + return false, nil + } + if err := typ.Cmp(obj.Type); err != nil { + return false, err + } + return true, nil +} + // EqualityInvariant is an invariant that symbolizes that the two expressions // must have equivalent types. // TODO: is there a better name than EqualityInvariant @@ -98,6 +173,26 @@ func (obj *EqualityInvariant) String() string { return fmt.Sprintf("%p == %p", obj.Expr1, obj.Expr2) } +// ExprList returns the list of valid expressions in this invariant. +func (obj *EqualityInvariant) ExprList() []interfaces.Expr { + return []interfaces.Expr{obj.Expr1, obj.Expr2} +} + +// Matches returns whether an invariant matches the existing solution. If it is +// inconsistent, then it errors. +func (obj *EqualityInvariant) Matches(solved map[interfaces.Expr]*types.Type) (bool, error) { + t1, exists1 := solved[obj.Expr1] + t2, exists2 := solved[obj.Expr2] + if !exists1 || !exists2 { + return false, nil // not matched yet + } + if err := t1.Cmp(t2); err != nil { + return false, err + } + + return true, nil // matched! +} + // EqualityInvariantList is an invariant that symbolizes that all the // expressions listed must have equivalent types. type EqualityInvariantList struct { @@ -113,6 +208,32 @@ func (obj *EqualityInvariantList) String() string { return fmt.Sprintf("[%s]", strings.Join(a, ", ")) } +// ExprList returns the list of valid expressions in this invariant. +func (obj *EqualityInvariantList) ExprList() []interfaces.Expr { + return obj.Exprs +} + +// Matches returns whether an invariant matches the existing solution. If it is +// inconsistent, then it errors. +func (obj *EqualityInvariantList) Matches(solved map[interfaces.Expr]*types.Type) (bool, error) { + found := true // assume true + var typ *types.Type + for _, x := range obj.Exprs { + t, exists := solved[x] + if !exists { + found = false + continue + } + if typ == nil { // set the first time + typ = t + } + if err := typ.Cmp(t); err != nil { + return false, err + } + } + return found, nil +} + // EqualityWrapListInvariant expresses that a list in Expr1 must have elements // that have the same type as the expression in Expr2Val. type EqualityWrapListInvariant struct { @@ -125,6 +246,28 @@ func (obj *EqualityWrapListInvariant) String() string { return fmt.Sprintf("%p == [%p]", obj.Expr1, obj.Expr2Val) } +// ExprList returns the list of valid expressions in this invariant. +func (obj *EqualityWrapListInvariant) ExprList() []interfaces.Expr { + return []interfaces.Expr{obj.Expr1, obj.Expr2Val} +} + +// Matches returns whether an invariant matches the existing solution. If it is +// inconsistent, then it errors. +func (obj *EqualityWrapListInvariant) Matches(solved map[interfaces.Expr]*types.Type) (bool, error) { + t1, exists1 := solved[obj.Expr1] // list type + t2, exists2 := solved[obj.Expr2Val] + if !exists1 || !exists2 { + return false, nil // not matched yet + } + if t1.Kind != types.KindList { + return false, fmt.Errorf("expected list kind") + } + if err := t1.Val.Cmp(t2); err != nil { + return false, err // inconsistent! + } + return true, nil // matched! +} + // EqualityWrapMapInvariant expresses that a map in Expr1 must have keys that // match the type of the expression in Expr2Key and values that match the type // of the expression in Expr2Val. @@ -139,6 +282,32 @@ func (obj *EqualityWrapMapInvariant) String() string { return fmt.Sprintf("%p == {%p: %p}", obj.Expr1, obj.Expr2Key, obj.Expr2Val) } +// ExprList returns the list of valid expressions in this invariant. +func (obj *EqualityWrapMapInvariant) ExprList() []interfaces.Expr { + return []interfaces.Expr{obj.Expr1, obj.Expr2Key, obj.Expr2Val} +} + +// Matches returns whether an invariant matches the existing solution. If it is +// inconsistent, then it errors. +func (obj *EqualityWrapMapInvariant) Matches(solved map[interfaces.Expr]*types.Type) (bool, error) { + t1, exists1 := solved[obj.Expr1] // list type + t2, exists2 := solved[obj.Expr2Key] + t3, exists3 := solved[obj.Expr2Val] + if !exists1 || !exists2 || !exists3 { + return false, nil // not matched yet + } + if t1.Kind != types.KindMap { + return false, fmt.Errorf("expected map kind") + } + if err := t1.Key.Cmp(t2); err != nil { + return false, err // inconsistent! + } + if err := t1.Val.Cmp(t3); err != nil { + return false, err // inconsistent! + } + return true, nil // matched! +} + // EqualityWrapStructInvariant expresses that a struct in Expr1 must have fields // that match the type of the expressions listed in Expr2Map. type EqualityWrapStructInvariant struct { @@ -163,6 +332,49 @@ func (obj *EqualityWrapStructInvariant) String() string { return fmt.Sprintf("%p == struct{%s}", obj.Expr1, strings.Join(s, "; ")) } +// ExprList returns the list of valid expressions in this invariant. +func (obj *EqualityWrapStructInvariant) ExprList() []interfaces.Expr { + exprs := []interfaces.Expr{obj.Expr1} + for _, x := range obj.Expr2Map { + exprs = append(exprs, x) + } + return exprs +} + +// Matches returns whether an invariant matches the existing solution. If it is +// inconsistent, then it errors. +func (obj *EqualityWrapStructInvariant) Matches(solved map[interfaces.Expr]*types.Type) (bool, error) { + t1, exists1 := solved[obj.Expr1] // list type + if !exists1 { + return false, nil // not matched yet + } + if t1.Kind != types.KindStruct { + return false, fmt.Errorf("expected struct kind") + } + + found := true // assume true + for _, key := range obj.Expr2Ord { + _, exists := t1.Map[key] + if !exists { + return false, fmt.Errorf("missing invariant struct key of: `%s`", key) + } + e, exists := obj.Expr2Map[key] + if !exists { + return false, fmt.Errorf("missing matched struct key of: `%s`", key) + } + t, exists := solved[e] + if !exists { + found = false + continue + } + if err := t1.Map[key].Cmp(t); err != nil { + return false, err // inconsistent! + } + } + + return found, nil // matched! +} + // EqualityWrapFuncInvariant expresses that a func in Expr1 must have args that // match the type of the expressions listed in Expr2Map and a return value that // matches the type of the expression in Expr2Out. @@ -190,6 +402,58 @@ func (obj *EqualityWrapFuncInvariant) String() string { return fmt.Sprintf("%p == func{%s} %p", obj.Expr1, strings.Join(s, "; "), obj.Expr2Out) } +// ExprList returns the list of valid expressions in this invariant. +func (obj *EqualityWrapFuncInvariant) ExprList() []interfaces.Expr { + exprs := []interfaces.Expr{obj.Expr1} + for _, x := range obj.Expr2Map { + exprs = append(exprs, x) + } + exprs = append(exprs, obj.Expr2Out) + return exprs +} + +// Matches returns whether an invariant matches the existing solution. If it is +// inconsistent, then it errors. +func (obj *EqualityWrapFuncInvariant) Matches(solved map[interfaces.Expr]*types.Type) (bool, error) { + t1, exists1 := solved[obj.Expr1] // list type + if !exists1 { + return false, nil // not matched yet + } + if t1.Kind != types.KindFunc { + return false, fmt.Errorf("expected func kind") + } + + found := true // assume true + for _, key := range obj.Expr2Ord { + _, exists := t1.Map[key] + if !exists { + return false, fmt.Errorf("missing invariant struct key of: `%s`", key) + } + e, exists := obj.Expr2Map[key] + if !exists { + return false, fmt.Errorf("missing matched struct key of: `%s`", key) + } + t, exists := solved[e] + if !exists { + found = false + continue + } + if err := t1.Map[key].Cmp(t); err != nil { + return false, err // inconsistent! + } + } + + t, exists := solved[obj.Expr2Out] + if !exists { + return false, nil + } + if err := t1.Out.Cmp(t); err != nil { + return false, err // inconsistent! + } + + return found, nil // matched! +} + // ConjunctionInvariant represents a list of invariants which must all be true // together. In other words, it's a grouping construct for a set of invariants. type ConjunctionInvariant struct { @@ -206,6 +470,31 @@ func (obj *ConjunctionInvariant) String() string { return fmt.Sprintf("[%s]", strings.Join(a, ", ")) } +// ExprList returns the list of valid expressions in this invariant. +func (obj *ConjunctionInvariant) ExprList() []interfaces.Expr { + exprs := []interfaces.Expr{} + for _, x := range obj.Invariants { + exprs = append(exprs, x.ExprList()...) + } + return exprs +} + +// Matches returns whether an invariant matches the existing solution. If it is +// inconsistent, then it errors. +func (obj *ConjunctionInvariant) Matches(solved map[interfaces.Expr]*types.Type) (bool, error) { + found := true // assume true + for _, invar := range obj.Invariants { + match, err := invar.Matches(solved) + if err != nil { + return false, nil + } + if !match { + found = false + } + } + return found, nil +} + // ExclusiveInvariant represents a list of invariants where one and *only* one // should hold true. To combine multiple invariants in one of the list elements, // you can group multiple invariants together using a ConjunctionInvariant. Do @@ -226,6 +515,54 @@ func (obj *ExclusiveInvariant) String() string { return fmt.Sprintf("[%s]", strings.Join(a, ", ")) } +// ExprList returns the list of valid expressions in this invariant. +func (obj *ExclusiveInvariant) ExprList() []interfaces.Expr { + // XXX: We should do this if we assume that exclusives don't have some + // sort of transient expr to satisfy that doesn't disappear depending on + // which choice in the exclusive is chosen... + //exprs := []interfaces.Expr{} + //for _, x := range obj.Invariants { + // exprs = append(exprs, x.ExprList()...) + //} + //return exprs + // XXX: But if we ever specify an expr in this exclusive that isn't + // referenced anywhere else, then we'd need to use the above so that our + // type unification algorithm knows not to stop too early. + return []interfaces.Expr{} // XXX: Do we want to the set instead? +} + +// Matches returns whether an invariant matches the existing solution. If it is +// inconsistent, then it errors. Because this partial invariant requires only +// one to be true, it will mask children errors, since it's normal for only one +// to be consistent. +func (obj *ExclusiveInvariant) Matches(solved map[interfaces.Expr]*types.Type) (bool, error) { + found := false + reterr := fmt.Errorf("all exclusives errored") + for _, invar := range obj.Invariants { + match, err := invar.Matches(solved) + if err != nil { + continue + } + if !match { + // at least one was false, so we're not done here yet... + // we don't want to error yet, since we can't know there + // won't be a conflict once we get more data about this! + reterr = nil // clear the error + continue + } + if found { // we already found one + return false, fmt.Errorf("more than one exclusive solution") + } + found = true + } + + if found { // we got exactly one valid solution + return true, nil + } + + return false, reterr +} + // exclusivesProduct returns a list of different products produced from the // combinatorial product of the list of exclusives. Each ExclusiveInvariant // must contain between one and more Invariants. This takes every combination of @@ -278,8 +615,30 @@ func (obj *AnyInvariant) String() string { return fmt.Sprintf("%p == *", obj.Expr) } +// ExprList returns the list of valid expressions in this invariant. +func (obj *AnyInvariant) ExprList() []interfaces.Expr { + return []interfaces.Expr{obj.Expr} +} + +// Matches returns whether an invariant matches the existing solution. If it is +// inconsistent, then it errors. +func (obj *AnyInvariant) Matches(solved map[interfaces.Expr]*types.Type) (bool, error) { + _, exists := solved[obj.Expr] // we only care that it is found. + return exists, nil +} + // InvariantSolution lists a trivial set of EqualsInvariant mappings so that you // can populate your AST with SetType calls in a simple loop. type InvariantSolution struct { Solutions []*EqualsInvariant // list of trivial solutions for each node } + +// ExprList returns the list of valid expressions. This struct is not part of +// the invariant interface, but it implements this anyways. +func (obj *InvariantSolution) ExprList() []interfaces.Expr { + exprs := []interfaces.Expr{} + for _, x := range obj.Solutions { + exprs = append(exprs, x.ExprList()...) + } + return exprs +} diff --git a/lang/unification/util.go b/lang/unification/util.go new file mode 100644 index 0000000000..28dc45fb12 --- /dev/null +++ b/lang/unification/util.go @@ -0,0 +1,54 @@ +// Mgmt +// Copyright (C) 2013-2019+ James Shubin and the project contributors +// Written by James Shubin and the project contributors +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +package unification + +import ( + "github.com/purpleidea/mgmt/lang/interfaces" +) + +// ExprListToExprMap converts a list of expressions to a map that has the unique +// expr pointers as the keys. This is just an alternate representation of the +// same data structure. If you have any duplicate values in your list, they'll +// get removed when stored as a map. +func ExprListToExprMap(exprList []interfaces.Expr) map[interfaces.Expr]struct{} { + exprMap := make(map[interfaces.Expr]struct{}) + for _, x := range exprList { + exprMap[x] = struct{}{} + } + return exprMap +} + +// ExprMapToExprList converts a map of expressions to a list that has the unique +// expr pointers as the values. This is just an alternate representation of the +// same data structure. +func ExprMapToExprList(exprMap map[interfaces.Expr]struct{}) []interfaces.Expr { + exprList := []interfaces.Expr{} + // TODO: sort by pointer address for determinism ? + for x := range exprMap { + exprList = append(exprList, x) + } + return exprList +} + +// UniqueExprList returns a unique list of expressions with no duplicates. It +// does this my converting it to a map and then back. This isn't necessarily the +// most efficient way, and doesn't preserve list ordering. +func UniqueExprList(exprList []interfaces.Expr) []interfaces.Expr { + exprMap := ExprListToExprMap(exprList) + return ExprMapToExprList(exprMap) +} diff --git a/lang/unification_test.go b/lang/unification_test.go index db5ee95747..77908591c1 100644 --- a/lang/unification_test.go +++ b/lang/unification_test.go @@ -819,7 +819,13 @@ func TestUnification1(t *testing.T) { logf := func(format string, v ...interface{}) { t.Logf(fmt.Sprintf("test #%d", index)+": unification: "+format, v...) } - err := unification.Unify(ast, unification.SimpleInvariantSolverLogger(logf)) + unifier := &unification.Unifier{ + AST: ast, + Solver: unification.SimpleInvariantSolverLogger(logf), + Debug: testing.Verbose(), + Logf: logf, + } + err := unifier.Unify() // TODO: print out the AST's so that we can see the types t.Logf("\n\ntest #%d: AST (after): %+v\n", index, ast)