-
-
Notifications
You must be signed in to change notification settings - Fork 40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Type unification for rows, new records, and variants #1646
Changes from all commits
fed21d2
cdb8748
00dd40c
406a13f
f7895a2
d0bf6e5
2f930ec
5ee543f
a1689cd
0cea2be
4a728c8
1e9aa7f
684d3f8
e15664b
c88eb19
1e1aa8d
849b7a4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ package at.forsyte.apalache.tla.typecheck.etc | |
|
||
import at.forsyte.apalache.tla.lir._ | ||
|
||
import scala.annotation.tailrec | ||
import scala.collection.immutable.SortedMap | ||
|
||
/** | ||
|
@@ -15,10 +16,12 @@ import scala.collection.immutable.SortedMap | |
* | ||
* <p>This class is not designed for concurrency. Use different instances in different threads.</p> | ||
* | ||
* @param varPool | ||
* variable pool that is used to create fresh variables | ||
* @author | ||
* Igor Konnov | ||
*/ | ||
class TypeUnifier { | ||
class TypeUnifier(varPool: TypeVarPool) { | ||
// A variable is mapped to its equivalence class. By default, a variable sits in the singleton equivalence class | ||
// of its own. When two variables are unified, they are merged in the same equivalence class. | ||
private var varToClass: Map[Int, EqClass] = Map.empty | ||
|
@@ -87,32 +90,33 @@ class TypeUnifier { | |
} | ||
} | ||
|
||
private def compute(lhs: TlaType1, rhs: TlaType1): Option[TlaType1] = { | ||
// Try to unify a variable with a non-variable term `typeTerm`. | ||
// If `typeTerm` refers to a variable in the equivalence class of `typeVar`, then this is a cyclic reference, | ||
// and there should be no unifier. | ||
def unifyVarWithNonVarTerm(typeVar: Int, typeTerm: TlaType1): Option[TlaType1] = { | ||
// Note that `typeTerm` is not a variable. | ||
val varClass = varToClass(typeVar) | ||
if (doesUseClass(typeTerm, varClass)) { | ||
// No unifier: `typeTerm` refers to a variable in the equivalence class of `typeVar`. | ||
None | ||
} else { | ||
// this variable is associated with an equivalence class, unify the class with `typeTerm` | ||
solution(varClass) match { | ||
case VarT1(_) => | ||
// an equivalence class of free variables, just assign `typeTerm` to this class | ||
solution += varClass -> typeTerm | ||
Some(typeTerm) | ||
|
||
case _ => | ||
// unify `typeTerm` with the term assigned to the equivalence class, if possible | ||
val unifier = compute(solution(varClass), typeTerm) | ||
unifier.foreach { t => solution += varClass -> t } | ||
unifier | ||
} | ||
// Try to unify a variable with a non-variable term `typeTerm`. | ||
// If `typeTerm` refers to a variable in the equivalence class of `typeVar`, then this is a cyclic reference, | ||
// and there should be no unifier. | ||
private def unifyVarWithNonVarTerm(typeVar: Int, typeTerm: TlaType1): Option[TlaType1] = { | ||
// Note that `typeTerm` is not a variable. | ||
val varClass = varToClass(typeVar) | ||
if (doesUseClass(typeTerm, varClass)) { | ||
// No unifier: `typeTerm` refers to a variable in the equivalence class of `typeVar`. | ||
None | ||
} else { | ||
// this variable is associated with an equivalence class, unify the class with `typeTerm` | ||
solution(varClass) match { | ||
case VarT1(_) => | ||
// an equivalence class of free variables, just assign `typeTerm` to this class | ||
solution += varClass -> typeTerm | ||
Some(typeTerm) | ||
|
||
case nonVar => | ||
// unify `typeTerm` with the term assigned to the equivalence class, if possible | ||
val unifier = compute(nonVar, typeTerm) | ||
unifier.foreach { t => solution += varClass -> t } | ||
unifier | ||
} | ||
} | ||
} | ||
|
||
private def compute(lhs: TlaType1, rhs: TlaType1): Option[TlaType1] = { | ||
|
||
// unify types as terms | ||
(lhs, rhs) match { | ||
|
@@ -201,7 +205,8 @@ class TypeUnifier { | |
case (l @ TupT1(_ @_*), r @ SparseTupT1(_)) => | ||
compute(r, l) | ||
|
||
// records join their keys, but the values for the intersecting keys should unify | ||
// Records join their keys, but the values for the intersecting keys should unify. | ||
// This is the old unification rule for the records. For the new records, see the rule for RecRowT1. | ||
case (RecT1(lfields), RecT1(rfields)) => | ||
val jointKeys = (lfields.keySet ++ rfields.keySet).toSeq | ||
val pairs = jointKeys.map(key => (key, computeFields(key, lfields, rfields))) | ||
|
@@ -212,12 +217,106 @@ class TypeUnifier { | |
Some(unifiedTuple) | ||
} | ||
|
||
case (RowT1(lfields, lv), RowT1(rfields, rv)) => | ||
unifyRows(lfields, rfields, lv, rv) | ||
|
||
case (RecRowT1(RowT1(lfields, lv)), RecRowT1(RowT1(rfields, rv))) => | ||
unifyRows(lfields, rfields, lv, rv).map(t => RecRowT1(t)) | ||
|
||
case (VariantT1(RowT1(lfields, lv)), VariantT1(RowT1(rfields, rv))) => | ||
unifyRows(lfields, rfields, lv, rv).map(t => VariantT1(t)) | ||
|
||
// everything else does not unify | ||
case _ => | ||
None // no unifier | ||
} | ||
} | ||
|
||
// unify two rows | ||
@tailrec | ||
private def unifyRows( | ||
lfields: SortedMap[String, TlaType1], | ||
rfields: SortedMap[String, TlaType1], | ||
lvar: Option[VarT1], | ||
rvar: Option[VarT1]): Option[RowT1] = { | ||
// assuming that a type is either a row, or a variable, make it a row type | ||
def asRow(rowOpt: Option[TlaType1]): Option[RowT1] = rowOpt.map { | ||
case r: RowT1 => r | ||
case v: VarT1 => RowT1(v) | ||
case tp => throw new IllegalStateException("Expected RowT1(_, _) or VarT1(_), found: " + tp) | ||
} | ||
|
||
// consider four cases | ||
if (lfields.isEmpty) { | ||
// the base case | ||
(lvar, rvar) match { | ||
case (None, None) => | ||
if (rfields.nonEmpty) None else Some(RowT1()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As an aside, I really don't like how |
||
|
||
case (Some(lv), Some(rv)) => | ||
if (rfields.isEmpty) { | ||
asRow(compute(lv, rv)) | ||
} else { | ||
asRow(unifyVarWithNonVarTerm(lv.no, RowT1(rfields, rvar))) | ||
} | ||
|
||
case (Some(lv), None) => | ||
asRow(unifyVarWithNonVarTerm(lv.no, RowT1(rfields, None))) | ||
|
||
case (None, Some(rv)) => | ||
if (rfields.isEmpty) { | ||
// the only way to match is to make the right variable equal to the empty row | ||
asRow(unifyVarWithNonVarTerm(rv.no, RowT1())) | ||
} else { | ||
// the left row is empty, whereas the right row is non-empty | ||
None | ||
} | ||
} | ||
} else if (rfields.isEmpty) { | ||
// the symmetric case above | ||
unifyRows(rfields, lfields, rvar, lvar) | ||
} else { | ||
val sharedFieldNames = lfields.keySet.intersect(rfields.keySet) | ||
if (sharedFieldNames.isEmpty) { | ||
// The easy case: no shared fields. | ||
// The left row is (| lfields | lvar |). | ||
// The right row is (| rfields | rvar |). | ||
// Introduce a fresh type variable to contain the common tail. | ||
val tailVar = freshVar() | ||
// Unify lvar with (| rfields | tailVar |). | ||
// Unify rvar with (| lfields | tailVar |). | ||
// If both unifiers exist, the result is (| lfields | rfields | tailVar |). | ||
if ( | ||
compute(lvar.getOrElse(RowT1()), RowT1(rfields, Some(tailVar))).isEmpty | ||
|| compute(rvar.getOrElse(RowT1()), RowT1(lfields, Some(tailVar))).isEmpty | ||
) { | ||
None | ||
} else { | ||
// apply the computed substitution to obtain the whole row | ||
asRow(Some(Substitution(solution).sub(RowT1(lfields, lvar))._1)) | ||
Comment on lines
+289
to
+296
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The fact that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well, if you don't need the substitution, it does unify things. But since we need the resulting substitution, |
||
} | ||
} else { | ||
// the general case: some fields are shared | ||
val lfieldsUniq = lfields.filter(p => !sharedFieldNames.keySet.contains(p._1)) | ||
val rfieldsUniq = rfields.filter(p => !sharedFieldNames.keySet.contains(p._1)) | ||
// Unify the disjoint fields and tail variables, see the above case | ||
compute(RowT1(lfieldsUniq, lvar), RowT1(rfieldsUniq, rvar)) match { | ||
case Some(RowT1(disjointFields, tailVar)) => | ||
// unify the shared fields, if possible | ||
val unifiedSharedFields = sharedFieldNames.map(key => (key, compute(lfields(key), rfields(key)))) | ||
if (unifiedSharedFields.exists(_._2.isEmpty)) { | ||
None | ||
} else { | ||
val finalSharedFields = SortedMap(unifiedSharedFields.map(p => (p._1, p._2.get)).toSeq: _*) | ||
Some(RowT1(finalSharedFields ++ disjointFields, tailVar)) | ||
} | ||
|
||
case _ => None | ||
} | ||
} | ||
} | ||
} | ||
|
||
// unify two sequences | ||
private def unifySeqs(ls: Seq[TlaType1], rs: Seq[TlaType1]): Option[Seq[TlaType1]] = { | ||
val len = ls.length | ||
|
@@ -289,6 +388,15 @@ class TypeUnifier { | |
} | ||
new Substitution(Map[EqClass, TlaType1](mapping: _*)) | ||
} | ||
|
||
// introduce a fresh variable | ||
private def freshVar(): VarT1 = { | ||
val fresh = varPool.fresh | ||
val cls = EqClass(fresh.no) | ||
varToClass += (fresh.no -> cls) | ||
solution += (cls -> fresh) | ||
fresh | ||
} | ||
} | ||
|
||
object TypeUnifier { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Huh, was there a case where type-var name IDs were causing problems?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. We had it a month or two ago.