Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type unification for rows, new records, and variants #1646

Merged
merged 17 commits into from
Apr 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions UNRELEASED.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@
DO NOT LEAVE A BLANK LINE BELOW THIS PREAMBLE -->
### Features

* Experimental type unification over rows, new records, and variants, see #1646

### Breaking changes

* Add the option `--features` to enable experimental features, see #1648

### Bug fixes

* Fix references to `--tune-here` (actually `--tuning-options`), see #1579
* Not failing when assignment and `UNCHANGED` appear in invariants, see #1664
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ object DefaultType1Parser extends Parsers with Type1Parser {
private def noFunExpr: Parser[TlaType1] = {
(INT() | REAL() | BOOL() | STR() | typeVar | typeConst
| set | seq | tuple | row | sparseTuple
| record | parametricRecord | recordFromRow | recordVar
| record | recordFromRow
| variant | variantVar | parenExpr) ^^ {
case INT() => IntT1()
case REAL() => RealT1()
Expand Down Expand Up @@ -146,7 +146,8 @@ object DefaultType1Parser extends Parsers with Type1Parser {

case _ ~ list ~ None ~ _ =>
RowT1(list: _*)
}
} | // the degenerate case of (| var |)
LROW() ~> typeVar <~ RROW() ^^ { v => RowT1(v) }
}

// a sparse tuple type like <| 3: Int, 5: Bool |>
Expand Down Expand Up @@ -178,13 +179,6 @@ object DefaultType1Parser extends Parsers with Type1Parser {
}
}

private def parametricRecord: Parser[TlaType1] = {
// special rule for a record that is completely underspecified, that is, { a }
LCURLY() ~ typeVar ~ RCURLY() ^^ { case _ ~ VarT1(v) ~ _ =>
RecRowT1(RowT1(VarT1(v)))
}
}

private def findDups(list: List[String]): Option[String] = {
// we could use list.groupBy(identity) to count the number of occurrences,
// but that would introduce an unnecessary map
Expand Down Expand Up @@ -215,14 +209,8 @@ object DefaultType1Parser extends Parsers with Type1Parser {

case list ~ None =>
RecRowT1(RowT1(list: _*))
}
}

// the general record constructor which may be used in conjunction with a row variable
private def recordVar: Parser[TlaType1] = {
RECORD() ~ LPAREN() ~ typeVar ~ RPAREN() ^^ { case _ ~ _ ~ VarT1(v) ~ _ =>
RecRowT1(RowT1(VarT1(v)))
}
} | // the degenerate case of a single variable
(LCURLY() ~> typeVar <~ RCURLY()) ^^ (v => RecRowT1(RowT1(v)))
}

// An option in the variant type that is constructed from a row.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import at.forsyte.apalache.tla.lir.transformations.standard.{
}
import at.forsyte.apalache.tla.lir.transformations.{TlaExTransformation, TransformationTracker}
import at.forsyte.apalache.tla.pp.Inliner.FilterFun
import at.forsyte.apalache.tla.typecheck.etc.{Substitution, TypeUnifier}
import at.forsyte.apalache.tla.typecheck.etc.{Substitution, TypeUnifier, TypeVarPool}

/**
* Given a module m, with global operators F1,...,Fn, Inliner performs the following transformation:
Expand Down Expand Up @@ -82,7 +82,8 @@ class Inliner(
// a substitution of the two. A substitution is assumed to exist, otherwise TypingException is thrown.
private def getSubstitution(targetType: TlaType1, decl: TlaOperDecl): (Substitution, TlaType1) = {
val genericType = decl.typeTag.asTlaType1()
new TypeUnifier().unify(Substitution.empty, genericType, targetType) match {
val maxUsedVar = Math.max(genericType.usedNames.foldLeft(0)(Math.max), targetType.usedNames.foldLeft(0)(Math.max))
new TypeUnifier(new TypeVarPool(maxUsedVar + 1)).unify(Substitution.empty, genericType, targetType) match {
Comment on lines +85 to +86
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh, was there a case where type-var name IDs were causing problems?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. We had it a month or two ago.

case None =>
throw new TypingException(
s"Inliner: Unable to unify generic signature $genericType of ${decl.name} with the concrete type $targetType",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import at.forsyte.apalache.tla.lir.TlaType1
* @author
* Igor Konnov
*/
class ConstraintSolver(approximateSolution: Substitution = Substitution.empty) {
class ConstraintSolver(varPool: TypeVarPool, approximateSolution: Substitution = Substitution.empty) {
private var solution: Substitution = approximateSolution
private var constraints: List[Clause] = List.empty
private var typesToReport: List[(Clause, TlaType1)] = List.empty
Expand Down Expand Up @@ -98,7 +98,7 @@ class ConstraintSolver(approximateSolution: Substitution = Substitution.empty) {
constraint match {
case EqClause(unknown, term) =>
// If there is a solution, we return it. We ignore the type, as it should be bound to `unknown`.
new TypeUnifier().unify(solution, unknown, term)
new TypeUnifier(varPool).unify(solution, unknown, term)

case OrClause(eqs @ _*) =>
// try to solve a disjunctive clause
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class EtcTypeChecker(varPool: TypeVarPool, inferPolytypes: Boolean = true) exten

// The types are computed in operator applications, add extra tests and listener calls for non-operators
try {
val rootSolver = new ConstraintSolver
val rootSolver = new ConstraintSolver(varPool)
// The whole expression has been processed. Compute the type of the expression.
val rootType = computeRec(rootCtx, rootSolver, rootEx)
rootSolver.solve() match {
Expand Down Expand Up @@ -240,7 +240,7 @@ class EtcTypeChecker(varPool: TypeVarPool, inferPolytypes: Boolean = true) exten
val approxSolution = solver.solvePartially().getOrElse(throw new UnwindException)

// introduce a new instance of the constraint solver for the operator definition
val letInSolver = new ConstraintSolver()
val letInSolver = new ConstraintSolver(varPool)
val operScheme =
ctx.types.get(name) match {
case Some(scheme @ TlaType1Scheme(OperT1(_, _), _)) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package at.forsyte.apalache.tla.typecheck.etc

import at.forsyte.apalache.tla.lir._

import scala.annotation.tailrec
import scala.collection.immutable.SortedMap

/**
Expand All @@ -15,10 +16,12 @@ import scala.collection.immutable.SortedMap
*
* <p>This class is not designed for concurrency. Use different instances in different threads.</p>
*
* @param varPool
* variable pool that is used to create fresh variables
* @author
* Igor Konnov
*/
class TypeUnifier {
class TypeUnifier(varPool: TypeVarPool) {
// A variable is mapped to its equivalence class. By default, a variable sits in the singleton equivalence class
// of its own. When two variables are unified, they are merged in the same equivalence class.
private var varToClass: Map[Int, EqClass] = Map.empty
Expand Down Expand Up @@ -87,32 +90,33 @@ class TypeUnifier {
}
}

private def compute(lhs: TlaType1, rhs: TlaType1): Option[TlaType1] = {
// Try to unify a variable with a non-variable term `typeTerm`.
// If `typeTerm` refers to a variable in the equivalence class of `typeVar`, then this is a cyclic reference,
// and there should be no unifier.
def unifyVarWithNonVarTerm(typeVar: Int, typeTerm: TlaType1): Option[TlaType1] = {
// Note that `typeTerm` is not a variable.
val varClass = varToClass(typeVar)
if (doesUseClass(typeTerm, varClass)) {
// No unifier: `typeTerm` refers to a variable in the equivalence class of `typeVar`.
None
} else {
// this variable is associated with an equivalence class, unify the class with `typeTerm`
solution(varClass) match {
case VarT1(_) =>
// an equivalence class of free variables, just assign `typeTerm` to this class
solution += varClass -> typeTerm
Some(typeTerm)

case _ =>
// unify `typeTerm` with the term assigned to the equivalence class, if possible
val unifier = compute(solution(varClass), typeTerm)
unifier.foreach { t => solution += varClass -> t }
unifier
}
// Try to unify a variable with a non-variable term `typeTerm`.
// If `typeTerm` refers to a variable in the equivalence class of `typeVar`, then this is a cyclic reference,
// and there should be no unifier.
private def unifyVarWithNonVarTerm(typeVar: Int, typeTerm: TlaType1): Option[TlaType1] = {
// Note that `typeTerm` is not a variable.
val varClass = varToClass(typeVar)
if (doesUseClass(typeTerm, varClass)) {
// No unifier: `typeTerm` refers to a variable in the equivalence class of `typeVar`.
None
} else {
// this variable is associated with an equivalence class, unify the class with `typeTerm`
solution(varClass) match {
case VarT1(_) =>
// an equivalence class of free variables, just assign `typeTerm` to this class
solution += varClass -> typeTerm
Some(typeTerm)

case nonVar =>
// unify `typeTerm` with the term assigned to the equivalence class, if possible
val unifier = compute(nonVar, typeTerm)
unifier.foreach { t => solution += varClass -> t }
unifier
}
}
}

private def compute(lhs: TlaType1, rhs: TlaType1): Option[TlaType1] = {

// unify types as terms
(lhs, rhs) match {
Expand Down Expand Up @@ -201,7 +205,8 @@ class TypeUnifier {
case (l @ TupT1(_ @_*), r @ SparseTupT1(_)) =>
compute(r, l)

// records join their keys, but the values for the intersecting keys should unify
// Records join their keys, but the values for the intersecting keys should unify.
// This is the old unification rule for the records. For the new records, see the rule for RecRowT1.
case (RecT1(lfields), RecT1(rfields)) =>
val jointKeys = (lfields.keySet ++ rfields.keySet).toSeq
val pairs = jointKeys.map(key => (key, computeFields(key, lfields, rfields)))
Expand All @@ -212,12 +217,106 @@ class TypeUnifier {
Some(unifiedTuple)
}

case (RowT1(lfields, lv), RowT1(rfields, rv)) =>
unifyRows(lfields, rfields, lv, rv)

case (RecRowT1(RowT1(lfields, lv)), RecRowT1(RowT1(rfields, rv))) =>
unifyRows(lfields, rfields, lv, rv).map(t => RecRowT1(t))

case (VariantT1(RowT1(lfields, lv)), VariantT1(RowT1(rfields, rv))) =>
unifyRows(lfields, rfields, lv, rv).map(t => VariantT1(t))

// everything else does not unify
case _ =>
None // no unifier
}
}

// unify two rows
@tailrec
private def unifyRows(
lfields: SortedMap[String, TlaType1],
rfields: SortedMap[String, TlaType1],
lvar: Option[VarT1],
rvar: Option[VarT1]): Option[RowT1] = {
// assuming that a type is either a row, or a variable, make it a row type
def asRow(rowOpt: Option[TlaType1]): Option[RowT1] = rowOpt.map {
case r: RowT1 => r
case v: VarT1 => RowT1(v)
case tp => throw new IllegalStateException("Expected RowT1(_, _) or VarT1(_), found: " + tp)
}

// consider four cases
if (lfields.isEmpty) {
// the base case
(lvar, rvar) match {
case (None, None) =>
if (rfields.nonEmpty) None else Some(RowT1())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As an aside, I really don't like how object RowT1 methods are all apply, because it makes it hard to track what is actually being constructed, since RowT1(), RowT1(_), RowT1(_,_) are all different methods that return a RowT1.


case (Some(lv), Some(rv)) =>
if (rfields.isEmpty) {
asRow(compute(lv, rv))
} else {
asRow(unifyVarWithNonVarTerm(lv.no, RowT1(rfields, rvar)))
}

case (Some(lv), None) =>
asRow(unifyVarWithNonVarTerm(lv.no, RowT1(rfields, None)))

case (None, Some(rv)) =>
if (rfields.isEmpty) {
// the only way to match is to make the right variable equal to the empty row
asRow(unifyVarWithNonVarTerm(rv.no, RowT1()))
} else {
// the left row is empty, whereas the right row is non-empty
None
}
}
} else if (rfields.isEmpty) {
// the symmetric case above
unifyRows(rfields, lfields, rvar, lvar)
} else {
val sharedFieldNames = lfields.keySet.intersect(rfields.keySet)
if (sharedFieldNames.isEmpty) {
// The easy case: no shared fields.
// The left row is (| lfields | lvar |).
// The right row is (| rfields | rvar |).
// Introduce a fresh type variable to contain the common tail.
val tailVar = freshVar()
// Unify lvar with (| rfields | tailVar |).
// Unify rvar with (| lfields | tailVar |).
// If both unifiers exist, the result is (| lfields | rfields | tailVar |).
if (
compute(lvar.getOrElse(RowT1()), RowT1(rfields, Some(tailVar))).isEmpty
|| compute(rvar.getOrElse(RowT1()), RowT1(lfields, Some(tailVar))).isEmpty
) {
None
} else {
// apply the computed substitution to obtain the whole row
asRow(Some(Substitution(solution).sub(RowT1(lfields, lvar))._1))
Comment on lines +289 to +296
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fact that compute mutates solution internally makes this really hard to understand. I'd at least appreciate a comment on that.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, if you don't need the substitution, it does unify things. But since we need the resulting substitution, compute mutates the substitution pretty much everywhere.

}
} else {
// the general case: some fields are shared
val lfieldsUniq = lfields.filter(p => !sharedFieldNames.keySet.contains(p._1))
val rfieldsUniq = rfields.filter(p => !sharedFieldNames.keySet.contains(p._1))
// Unify the disjoint fields and tail variables, see the above case
compute(RowT1(lfieldsUniq, lvar), RowT1(rfieldsUniq, rvar)) match {
case Some(RowT1(disjointFields, tailVar)) =>
// unify the shared fields, if possible
val unifiedSharedFields = sharedFieldNames.map(key => (key, compute(lfields(key), rfields(key))))
if (unifiedSharedFields.exists(_._2.isEmpty)) {
None
} else {
val finalSharedFields = SortedMap(unifiedSharedFields.map(p => (p._1, p._2.get)).toSeq: _*)
Some(RowT1(finalSharedFields ++ disjointFields, tailVar))
}

case _ => None
}
}
}
}

// unify two sequences
private def unifySeqs(ls: Seq[TlaType1], rs: Seq[TlaType1]): Option[Seq[TlaType1]] = {
val len = ls.length
Expand Down Expand Up @@ -289,6 +388,15 @@ class TypeUnifier {
}
new Substitution(Map[EqClass, TlaType1](mapping: _*))
}

// introduce a fresh variable
private def freshVar(): VarT1 = {
val fresh = varPool.fresh
val cls = EqClass(fresh.no)
varToClass += (fresh.no -> cls)
solution += (cls -> fresh)
fresh
}
}

object TypeUnifier {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,12 @@ class TestDefaultType1Parser extends AnyFunSuite with Checkers with TlaType1Gen
assert(RowT1() == result)
}

test("single-variable row") {
val text = """(| c |)"""
val result = DefaultType1Parser.parseType(text)
assert(RowT1(VarT1("c")) == result)
}

test("concrete row") {
val text = """(| f: Int | g: c |)"""
val result = DefaultType1Parser.parseType(text)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ import org.scalatestplus.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TestConstraintSolver extends AnyFunSuite with EasyMockSugar with EtcBuilder {
private val FIRST_VAR: Int = 100
private val parser: Type1Parser = DefaultType1Parser

test("unique solution") {
val solver = new ConstraintSolver
val solver = new ConstraintSolver(new TypeVarPool(FIRST_VAR))
// a disjunctive constraint that comes from a tuple constructor
// either a == (b, c) => <<b, c>>
val option1 = EqClause(VarT1("a"), OperT1(Seq(VarT1("b"), VarT1("c")), parser("<<b, c>>")))
Expand All @@ -30,7 +31,7 @@ class TestConstraintSolver extends AnyFunSuite with EasyMockSugar with EtcBuilde
}

test("multiple solutions") {
val solver = new ConstraintSolver
val solver = new ConstraintSolver(new TypeVarPool(FIRST_VAR))
// a disjunctive constraint that comes from a tuple constructor
// either a == (b, c) => <<b, c>>
val option1 = EqClause(VarT1("a"), OperT1(Seq(VarT1("b"), VarT1("c")), parser("<<b, c>>")))
Expand All @@ -47,7 +48,7 @@ class TestConstraintSolver extends AnyFunSuite with EasyMockSugar with EtcBuilde
}

test("constraints in the reverse order") {
val solver = new ConstraintSolver
val solver = new ConstraintSolver(new TypeVarPool(FIRST_VAR))
// The following constraints come in the order that is reverse to the one that is required to solve the constraints.
// These constraints are made up, they do not come from any real constraints that are produced by TLA+ operators.
val eq1 = EqClause(VarT1("a"), parser("(b, c) => b"))
Expand Down
Loading