From 5542b4d317a01afbbb853b17cb3e264edb4f3990 Mon Sep 17 00:00:00 2001 From: Thomas Pani Date: Thu, 6 Jul 2023 14:23:31 +0200 Subject: [PATCH] Unify in-scope name types Enable construction of names with a more concrete type than the one in scope. --- .../subbuilder/LiteralAndNameBuilder.scala | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/tlair/src/main/scala/at/forsyte/apalache/tla/typecomp/subbuilder/LiteralAndNameBuilder.scala b/tlair/src/main/scala/at/forsyte/apalache/tla/typecomp/subbuilder/LiteralAndNameBuilder.scala index f05ebcac93..e03282313a 100644 --- a/tlair/src/main/scala/at/forsyte/apalache/tla/typecomp/subbuilder/LiteralAndNameBuilder.scala +++ b/tlair/src/main/scala/at/forsyte/apalache/tla/typecomp/subbuilder/LiteralAndNameBuilder.scala @@ -3,6 +3,7 @@ package at.forsyte.apalache.tla.typecomp.subbuilder import at.forsyte.apalache.tla.lir._ import at.forsyte.apalache.tla.typecomp._ import at.forsyte.apalache.tla.typecomp.unsafe.UnsafeLiteralAndNameBuilder +import at.forsyte.apalache.tla.types.{Substitution, TypeUnifier, TypeVarPool} import scalaz.Scalaz._ import scalaz._ @@ -58,16 +59,22 @@ trait LiteralAndNameBuilder { def name(exprName: String, t: TlaType1): TBuilderInstruction = State[TBuilderContext, TlaEx] { mi => val scope = mi.freeNameScope - // If already in scope, type must be the same - scope.get(exprName).foreach { tt => - if (tt != t) - throw new TBuilderScopeException( + // If already in scope, type must be unifiable + val unifT = scope.get(exprName).map { tt => + val unifOpt = new TypeUnifier(new TypeVarPool()).unify(Substitution.empty, tt, t) + unifOpt match { + case Some((_, unifiedOperT)) => unifiedOperT + case None => + throw new TBuilderScopeException( s"Name $exprName with type $t constructed in scope where expected type is $tt." - ) + ) + } } + // If not in scope, just use `t` + val finalT = unifT.getOrElse(t) - val ret = unsafeBuilder.name(exprName, t) - (mi.copy(freeNameScope = scope + (exprName -> t), usedNames = mi.usedNames + exprName), ret) + val ret = unsafeBuilder.name(exprName, finalT) + (mi.copy(freeNameScope = scope + (exprName -> finalT), usedNames = mi.usedNames + exprName), ret) } /** Attempt to infer the type from the scope. Fails if exprName is not in scope. */