Skip to content

Commit

Permalink
Fix bug in applying MixBounds during type solver.
Browse files Browse the repository at this point in the history
  • Loading branch information
CuppoJava committed May 20, 2024
2 parents 8fb4241 + bb63a53 commit 63dccbe
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 66 deletions.
2 changes: 1 addition & 1 deletion compiler/params.stanza
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public defn compiler-flags () :
to-tuple(COMPILE-FLAGS)

;========= Stanza Configuration ========
public val STANZA-VERSION = [0 18 72]
public val STANZA-VERSION = [0 18 73]
public var STANZA-INSTALL-DIR:String = ""
public var OUTPUT-PLATFORM:Symbol = `platform
public var STANZA-PKG-DIRS:List<String> = List()
Expand Down
2 changes: 1 addition & 1 deletion compiler/type-equations.stanza
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ public defn format (f:TypeFormatter, e:TypeEqn) :
val fs = seq(fmt, xs(e))
"overload(%_, sel: %_) (%_) <: %_" % [fmt(n(e)), sel(e), indented-field-list(fs), fmt(y(e))]
(e:OverloadCallEqn) :
"overload(%_, exp-ns:%,, sel: %_) call (%,) with (%,)" % [fmt(n(e)), fmt(exp-ns(e)), sel(e), fmt(xs(e)), fmt(args(e))]
"overload(%_, exp-ns:[%,], sel: %_) call (%,) with (%,)" % [fmt(n(e)), fmt(exp-ns(e)), sel(e), fmt(xs(e)), fmt(args(e))]
(e:LSOverloadCallEqn) :
val fs = seq(fmt, xs(e))
"ls-overload(%_, exp-ns:%,, sel: %_) call (%_) with (%,)" % [fmt(n(e)), fmt(exp-ns(e)), sel(e), indented-field-list(fs), seq(fmt,args(e))]
Expand Down
6 changes: 5 additions & 1 deletion compiler/type-formatter.stanza
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public defn TypeFormatter (name:Int -> Symbol) :
join-exp(name-atom, paren(prefix("<:", pexp(type(b)))))
(b:MixBound) :
val name-atom = Atom(written(name(n(t))))
join-exp(name-atom, paren(prefix("mix", seq(pexp,types(b)))))
join-exp(name-atom, brackets(prefix("mix: ", seq(pexp,types(b)))))

;Convert a LoStanza type.
defn pexp (t:LSType) -> PrintExp :
Expand Down Expand Up @@ -156,6 +156,10 @@ defn paren (es:Seqable<PrintExp>) -> PrintExp :
defn paren (e:PrintExp) : paren([e])
defn braces (es:Seqable<PrintExp>) -> PrintExp :
Surround(to-tuple(es), Brace)
defn braces (e:PrintExp) : braces([e])
defn brackets (es:Seqable<PrintExp>) -> PrintExp :
Surround(to-tuple(es), Bracket)
defn brackets (e:PrintExp) : brackets([e])

;Surround the given expressions with parentheses if it isn't
;exactly 1 exp.
Expand Down
22 changes: 15 additions & 7 deletions compiler/type-solver.stanza
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ public defn select-overload (a:TArrow, b:Type, hier:TypeHierarchy) -> PredResult
;Compute whether the given arrow is appropriate to be called with the given
;arguments.
public defn select-overload-call (a:TArrow, args:Tuple<Type>, hier:TypeHierarchy) -> PredResult :
eval-gradual-match-result(select-overload-search(a, TArrow(args, TTop()), hier))
if same-length?(a1(a), args) :
val result = SAnd(seq(partof-search{_, _, hier}, args, a1(a)))
eval-gradual-match-result(result)

;Compute whether the given Fnt is appropriate to be called with the given
;arguments.
Expand Down Expand Up @@ -1191,18 +1193,24 @@ defn select-overload-search (x:TArrow, t:Type, hier:TypeHierarchy) -> SearchResu
;Fall through.
(t) : SFail()

;The argument type 'y' (as part of a mixed function) is going to be passed values of type 'x'.
;Return success if this is allowed.
; Passing: a|other -> b|other
; Expecting: x|y -> c
; So: x|y <: a|other
;Suppose we are calling a component of mixed function with argument 'y', e.g. 'Int'.
;And we are pass it 'x', e.g. 'Int|String'.
;Even though normally this is not a valid call because 'Int|String <: Int' is not true,
;in this case we allow it because we're testing only a single component of a mixed function.
;
;Example: Suppose we're attempting to call the mixed function:
; mix{Int -> False, String -> False, Char -> False}
;with argument types:
; Int|String.
defn partof-search (x:Type, y:Type, hier:TypeHierarchy) -> SearchResult :
;Sanity check.
if not overload-type?(y) :
fatal("Right-hand type not a function argument type.")

;Helper: Shorthand for calling partof-search
defn po (x:Type, y:Type) : partof-search(x, y, hier)
;Helper: Shorthand for calling select-overload-search
defn so (x:Type, y:Type) : select-overload-search(x as TArrow, y, hier)

match(x, y) :
;Bottom and top type.
Expand All @@ -1212,7 +1220,7 @@ defn partof-search (x:Type, y:Type, hier:TypeHierarchy) -> SearchResult :
(x:TOr|TAnd, y) : search-tor-tand(po, x, y, DisjunctiveLeftOr)
(x, y:TOr|TAnd) : search-tor-tand(po, x, y, DisjunctiveLeftOr)
;Unsolved types.
(x:TUVar, y) : search-tuvar(po, x, y)
(x:TUVar, y) : search-tuvar(so, x, y)
(x, y:TUVar) : search-tuvar(po, x, y)
;Gradual types
(x:TGradual, y:TGradual) : SSatisfied()
Expand Down
46 changes: 18 additions & 28 deletions tests/dev-types.stanza
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ defpackage stz-test-suite/dev-types :
import stz/type-locs
import stz/type-formatter
import stz/type-equation-solver
import stz/type-capside
import stz/type-fargs
import stz-test-suite/test-tools
import stz-test-suite/type-reader
import stz-test-suite/type-test-ir
Expand All @@ -25,47 +27,35 @@ public defn execute (prog:TestTypeProgram) :
defn fmt (v:SolverValue) : format(type-formatter(prog),v)
defn fmt (e:KeyValue<Int,SolverValue>) : fmt(key(e)) => fmt(value(e))
defn fmt (es:Tuple<KeyValue<Int,SolverValue>>) : seq(fmt,es)
defn fmt-r (r:CaptureResult) : fmt(r as Type) when r is Type else r
set-debug-formatter(type-formatter(prog))

val hier = make-hierarchy(prog)
for stmt in stmts(prog) do :
match(stmt) :
(stmt:Capture) :
println("Capture Test: %_" % [fmt(stmt)])
val captures = capture-variance(args(stmt) as Tuple<Int>, b(stmt))
for capture in captures do :
val arg = key(capture)
val variance = value(capture)
val sresult = capture-search(arg, CapRight, a(stmt), b(stmt), hier)
val simp = simplify(sresult, true)
val type? = match(simp:Solved) : evaluate-constraint(constraint(simp), variance)
val captures = capture-variance(args(stmt) as Tuple<Int>, [b(stmt)], CapRight)
for entry in captures do :
val arg = key(entry)
val variance = value(entry)
val result = capture(arg, variance, CapRight, [a(stmt)], [b(stmt)], hier)
println("Capturing %_" % [fmt(arg)])
println(Indented(fmt(sresult)))
println("Simplifies to:")
println(Indented(fmt(simp)))
println("Evaluates to:")
println(Indented(fmt(type?)))
println(Indented(fmt-r(result)))
println("")
(stmt:Subtype) :
val sresult = subtype-search(a(stmt), b(stmt), hier)
val simp = simplify(sresult, false)
val result = subtype(a(stmt), b(stmt), hier)
println("Subtype Test: %_" % [fmt(stmt)])
println(Indented(fmt(sresult)))
println("Simplifies to:")
println(Indented(fmt(simp)))
println(Indented(result))
println("")
(stmt:Infer) :
val sresult = capture-search(a(stmt), b(stmt), hier)
val simp = simplify(sresult, true)
val type? = match(simp:Solved) :
val variance = capture-variance(a(stmt), b(stmt))
evaluate-constraint-conservative(constraint(simp), variance)
val result = infer(a(stmt), b(stmt), false, hier)
println("Inference Test: %_" % [fmt(stmt)])
println(Indented(fmt(sresult)))
println("Simplifies to:")
println(Indented(fmt(simp)))
println("Evaluates to:")
println(Indented(fmt(type?)))
println(Indented(fmt-r(result)))
(stmt:SelectOverloadCall) :
val result = select-overload-call(a(stmt), args(stmt), hier)
println("SelectOverloadCall Test: %_" % [fmt(stmt)])
println(Indented(result))
(stmt:Solve) :
println("Solver Test:\n%_" % [Indented(fmt(stmt))])
val state = SolverState(eqns(stmt), hier)
Expand All @@ -88,7 +78,7 @@ defn make-hierarchy (prog:TestTypeProgram) -> TypeHierarchy :
val relations = for r in hierarchy(prog) map :
val child = child(r) as TOf
val child-args = map(n, args(child) as Tuple<TVar>)
TypeRelation(n(child), get?(special-table,n(child)), child-args, parent(r))
TypeRelation(n(child), get?(special-table,n(child)), child-args, parent(r), None())

;Return the hierarchy.
TypeHierarchy(relations)
Expand Down
40 changes: 28 additions & 12 deletions tests/type-reader.stanza
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ defn resolve (prog:TestTypeProgram) -> TestTypeProgram :
val args* = map(resolve{_, scope}, args(t))
val n = lookup-class-id(name(t) as Symbol)
TOf(n, args*)
(t:TMixBoundVar) :
if not prefix?(name(t), "$") :
fatal("Bad mix bound var.")
val n = match(lookup?(scope,name(t))) :
(n:Int) : n
(f:False) : make-id(scope,name(t))
val types* = map({resolve(_, scope) as TArrow}, types(t))
TUVar(n, MixBound(types*))
(t:TCap) :
val n = lookup(scope, name(t) as Symbol)
TCap(n, name(t))
Expand Down Expand Up @@ -145,6 +153,10 @@ defn resolve (prog:TestTypeProgram) -> TestTypeProgram :
defn resolve? (l:InferLoc) : l
defn resolve? (t:Type) : resolve(t,scope)
Infer(resolve?(a(s)), resolve?(b(s)))
(s:SelectOverloadCall) :
val scope = NameScope()
SelectOverloadCall(resolve(a(s),scope) as TArrow,
map(resolve{_,scope},args(s)))
(s:Solve) :
val scope = NameScope()
val eqns = eqns(s) as Tuple<NamedTypeEqn>
Expand Down Expand Up @@ -231,6 +243,9 @@ defsyntax type-syntax :
defrule type-stmt = (infer ?x:#type! <: ?y:#loc) :
Infer(x, y)

defrule type-stmt = (select-overload-call ?x:#type! (?ys:#type! ...)) :
SelectOverloadCall(x as TArrow, to-tuple(ys))

defrule type-stmt = (solve : (?es:#type-eqn! ...)) :
Solve(to-tuple(es))

Expand Down Expand Up @@ -265,7 +280,8 @@ defsyntax type-syntax :
defrule type1 = ((~ @cap ?x:#id)) : TCap(0,x)
defrule type1 = ([?xs:#types!]) : TTuple(to-tuple(xs))
defrule type1 = ((?x:#type))
defrule type1 = (?x:#id<?ys:#types!>) : TNamedOf(x,to-tuple(ys))
defrule type1 = (?x:#id{?ts:#types!}) : TMixBoundVar(x,to-tuple(ts) as Tuple<TArrow>)
defrule type1 = (?x:#id<?ys:#types!>) : TNamedOf(x,to-tuple(ys))
defrule type1 = (Void) : TBot()
defrule type1 = (?) : TGradual()
defrule type1 = (?x:#id) : TVar(0,x)
Expand Down Expand Up @@ -294,17 +310,17 @@ defsyntax type-syntax :
defrule type-eqn = (var ?v:#id = ?t:#type!) :
NamedEqualEqn(v,t)

defrule type-eqn = (overload(?v:#id, sel: ?sel:#id) (?xs:#type ...) <: ?y:#type) :
for x in xs do :
if x is-not TArrow :
throw(TSE(closest-info(), "Overloaded functions are expected to be arrows."))
NamedOverloadExpEqn(v,sel,to-tuple(xs) as Tuple<TArrow>,y)

defrule type-eqn = (overload(?v:#id, sel: ?sel:#id) call (?xs:#type ...) with ?y:#type) :
for x in xs do :
if x is-not TArrow :
throw(TSE(closest-info(), "Overloaded functions are expected to be arrows."))
NamedOverloadExpEqn(v,sel,to-tuple(xs) as Tuple<TArrow>,y)
;defrule type-eqn = (overload(?v:#id, sel: ?sel:#id, inst: ?inst:#id) (?xs:#type ...) <: ?y:#type) :
; for x in xs do :
; if x is-not TArrow :
; throw(TSE(closest-info(), "Overloaded functions are expected to be arrows."))
; NamedOverloadExpEqn(v,sel,inst,to-tuple(xs) as Tuple<TArrow>,y)
;
;defrule type-eqn = (overload(?v:#id, sel: ?sel:#id) call (?xs:#type ...) with ?y:#type) :
; for x in xs do :
; if x is-not TArrow :
; throw(TSE(closest-info(), "Overloaded functions are expected to be arrows."))
; NamedOverloadExpEqn(v,sel,to-tuple(xs) as Tuple<TArrow>,y)

defproduction sub-entry: KeyValue<Symbol,Type>
defrule sub-entry = (?v:#id => ?t:#type!) : v => t
Expand Down
60 changes: 44 additions & 16 deletions tests/type-test-ir.stanza
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ defpackage stz-test-suite/type-test-ir :
import collections
import stz/types
import stz/type-locs
import stz/type-fargs
import stz/type-capside
import stz/printing-utils
import stz/type-equations
import stz/type-formatter
Expand Down Expand Up @@ -43,13 +45,19 @@ public defn type-formatter (prog:TestTypeProgram) -> TypeFormatter :
public defstruct TypeNames :
table:IntTable<Symbol> with:
init => IntTable<Symbol>()
id-counter:Seq<Int> with:
init => to-seq(0 to false)

;Generate a new integer id for the given name.
public defn make-id (names:TypeNames, name:Symbol) -> Int :
val id = length(table(names))
val id = next(id-counter(names))
table(names)[id] = name
id

;Generate a new integer id.
public defn make-id (names:TypeNames) -> Int :
next(id-counter(names))

;Retrieve name of a given identifier.
public defn get (names:TypeNames, id:Int) -> Symbol :
table(names)[id]
Expand All @@ -73,6 +81,10 @@ public defstruct Infer <: TypeStmt :
a:InferLoc|Type
b:InferLoc|Type

public defstruct SelectOverloadCall <: TypeStmt :
a:TArrow
args:Tuple<Type>

public defstruct Solve <: TypeStmt :
eqns:Tuple<TypeEqn>

Expand All @@ -86,6 +98,12 @@ public defstruct TNamedOf <: Type :
with:
printer => true

public defstruct TMixBoundVar <: Type :
name:Symbol
types:Tuple<TArrow>
with:
printer => true

;============================================================
;=================== Equations with Names ===================
;============================================================
Expand All @@ -109,16 +127,19 @@ public defstruct NamedInferEqn <: NamedTypeEqn :
public defstruct NamedOverloadExpEqn <: NamedTypeEqn :
n:Symbol
sel:Symbol
xs:Tuple<TArrow>
inst:Symbol
functions:Tuple<KeyValue<Symbol,TFunction>>
y:Type

;Infer the type of an overloaded function in call
;position.
public defstruct NamedOverloadCallEqn <: NamedTypeEqn :
n:Symbol
exp-ns:Tuple<Symbol>
sel:Symbol
xs:Tuple<TFunction>
args:Tuple<FArg>
xs:Tuple<TArrow>
exp-xs:Tuple<Tuple<Type>>
args:Tuple<Type>

;Perform the given substitutions in the given
;type.
Expand Down Expand Up @@ -154,11 +175,14 @@ public defn map (f:Type -> Type, e:NamedTypeEqn) -> NamedTypeEqn :
defn g2 (x:TArrow) : f(x) as TArrow
defn g2 (x:Tuple<TFunction>) : map(g2,x)
defn g2 (x:Tuple<TArrow>) : map(g2,x)
defn g2 (x:KeyValue<Symbol,TFunction>) : key(x) => f(value(x)) as TFunction
defn g2 (x:Tuple<KeyValue<Symbol,TFunction>>) : map(g2,x)
defn g2 (x:Tuple<Tuple<Type>>) : map(g,x)
match(e) :
(e:NamedCaptureAllEqn) : NamedCaptureAllEqn(ns(e), g(a(e)), g(b(e)))
(e:NamedInferEqn) : NamedInferEqn(n(e), g(a(e)), g(b(e)))
(e:NamedOverloadExpEqn) : NamedOverloadExpEqn(n(e), sel(e), g2(xs(e)), g(y(e)))
(e:NamedOverloadCallEqn) : NamedOverloadCallEqn(n(e), sel(e), g2(xs(e)), g(args(e)))
(e:NamedOverloadExpEqn) : NamedOverloadExpEqn(n(e), sel(e), inst(e), g2(functions(e)), g(y(e)))
(e:NamedOverloadCallEqn) : NamedOverloadCallEqn(n(e), exp-ns(e), sel(e), g2(xs(e)), g2(exp-xs(e)), g(args(e)))
(e:NamedSubEqn) : NamedSubEqn(n(e), g(x(e)), g(ys(e)))
(e:NamedSuperEqn) : NamedSuperEqn(n(e), g(x(e)))
(e:NamedEqualEqn) : NamedEqualEqn(n(e), g(type(e)))
Expand All @@ -168,23 +192,23 @@ public defn do-named-vars (f:Symbol -> False, e:NamedTypeEqn) :
match(e) :
(e:NamedCaptureAllEqn) : do(f,ns(e))
(e:NamedInferEqn) : f(n(e))
(e:NamedOverloadExpEqn) : (f(n(e)), f(sel(e)))
(e:NamedOverloadCallEqn) : (f(n(e)), f(sel(e)))
(e:NamedOverloadExpEqn) : (f(n(e)), f(sel(e)), f(inst(e)))
(e:NamedOverloadCallEqn) : (f(n(e)), do(f,exp-ns(e)), f(sel(e)))
(e:NamedSubEqn) : f(n(e))
(e:NamedSuperEqn) : f(n(e))
(e:NamedEqualEqn) : f(n(e))

;Replace the names with their ids in the given equation.
public defn replace-names (f:Symbol -> Int, e:NamedTypeEqn) -> TypeEqn :
defn g (ns:Tuple<Symbol>) : map(f,ns)
defn g (xs:Tuple<KeyValue<Symbol,Type>>) : for x in xs map : f(key(x)) => value(x)
defn g2 (xs:Tuple<KeyValue<Symbol,TFunction>>) : for x in xs map : f(key(x)) => value(x)
match(e) :
(e:NamedCaptureAllEqn) : CaptureAllEqn(map(f,ns(e)), a(e), b(e))
(e:NamedInferEqn) : InferEqn(f(n(e)), a(e), b(e))
(e:NamedOverloadExpEqn) : OverloadExpEqn(f(n(e)), f(sel(e)), xs(e), y(e))
(e:NamedOverloadCallEqn) : OverloadCallEqn(f(n(e)), f(sel(e)), xs(e), args(e))
(e:NamedSubEqn) :
val new-ys = for y in ys(e) map :
f(key(y)) => value(y)
SubEqn(f(n(e)), x(e), new-ys)
(e:NamedCaptureAllEqn) : CaptureAllEqn(g(ns(e)), CapRight, [a(e)], [b(e)])
(e:NamedInferEqn) : InferEqn(f(n(e)), a(e), b(e), false)
(e:NamedOverloadExpEqn) : OverloadExpEqn(f(n(e)), f(sel(e)), f(inst(e)), g2(functions(e)), None(), y(e))
(e:NamedOverloadCallEqn) : OverloadCallEqn(f(n(e)), g(exp-ns(e)), f(sel(e)), xs(e), exp-xs(e), args(e))
(e:NamedSubEqn) : SubEqn(f(n(e)), x(e), g(ys(e)))
(e:NamedSuperEqn) : SuperEqn(f(n(e)), x(e))
(e:NamedEqualEqn) : EqualEqn(f(n(e)), type(e))

Expand All @@ -210,6 +234,9 @@ defmethod print (o:OutputStream, s:Subtype) :
defmethod print (o:OutputStream, s:Infer) :
print(o, "infer %_ <: %_" % [a(s), b(s)])

defmethod print (o:OutputStream, s:SelectOverloadCall) :
print(o, "select-overload-call %_ (%,)" % [a(s), args(s)])

defmethod print (o:OutputStream, s:Solve) :
print(o, "Solve%_" % [colon-field-list(eqns(s))])

Expand Down Expand Up @@ -240,6 +267,7 @@ public defn format (f:TypeFormatter, s:TypeStmt) :
(s:Capture) : "capture(%,) %_ <: %_" % [fmt(args(s)), fmt(a(s)), fmt(b(s))]
(s:Subtype) : "subtype %_ <: %_" % [format(f,a(s)), fmt(b(s))]
(s:Infer) : "infer %_ <: %_" % [fmt(a(s)), fmt(b(s))]
(s:SelectOverloadCall) : "select-overload-call %_ (%,)" % [fmt(a(s)), seq(fmt,args(s))]
(s:Solve) : "solve%_" % [colon-field-list(seq(fmt,eqns(s)))]

defn wrap-printable (f:OutputStream -> False) :
Expand Down

0 comments on commit 63dccbe

Please sign in to comment.