diff --git a/tasty-query/shared/src/main/scala/tastyquery/Types.scala b/tasty-query/shared/src/main/scala/tastyquery/Types.scala index a9459a06..6ab6c446 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Types.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Types.scala @@ -198,7 +198,7 @@ object Types { ResolveMemberResult.TypeMember(rightSyms, rightBounds) ) => val syms = mergeSyms(leftSyms, rightSyms) - val bounds = leftBounds.intersect(rightBounds) + val bounds = mergeTypeMemberTypeBounds(leftBounds, rightBounds) ResolveMemberResult.TypeMember(syms, bounds) // Cases that cannot happen -- list them to preserve exhaustivity checking of every other case @@ -228,6 +228,54 @@ object Types { case _ => throw InvalidProgramStructureException(s"Cannot merge types $tp1 and $tp2") end mergeTermMemberTypes + + private def mergeTypeMemberTypeBounds(bounds1: TypeBounds, bounds2: TypeBounds)(using Context): TypeBounds = + // This implementation assumes that the program structure is valid + (bounds1, bounds2) match + case _ if bounds1 eq bounds2 => + bounds1 + case (bounds1: TypeAlias, _) => + bounds1 + case (_, bounds2: TypeAlias) => + bounds2 + + case (bounds1 @ AbstractTypeBounds(low1, high1), bounds2 @ AbstractTypeBounds(low2, high2)) => + val mergedLow = mergeTypeMemberLowBounds(low1, low2) + val mergedHigh = mergeTypeMemberHighBounds(high1, high2) + bounds1.derivedTypeBounds(mergedLow, mergedHigh) + end mergeTypeMemberTypeBounds + + private def mergeTypeMemberLowBounds(low1: Type, low2: Type)(using Context): Type = + (low1.dealias, low2.dealias) match + case (low1: TypeLambda, low2: TypeLambda) if low1.paramNames.sizeCompare(low2.paramNames) == 0 => + low1.derivedLambdaType( + low1.paramNames, + low1.paramTypeBounds, + mergeTypeMemberLowBounds(low1.resultType, low2.instantiate(low1.paramRefs)) + ) + case (_: NothingType, _) | (_, _: AnyKindType) => + low2 + case (_, _: NothingType) | (_: AnyKindType, _) => + low1 + case _ => + low1 | low2 + end mergeTypeMemberLowBounds + + private def mergeTypeMemberHighBounds(high1: Type, high2: Type)(using Context): Type = + (high1.dealias, high2.dealias) match + case (high1: TypeLambda, high2: TypeLambda) if high1.paramNames.sizeCompare(high2.paramNames) == 0 => + high1.derivedLambdaType( + high1.paramNames, + high1.paramTypeBounds, + mergeTypeMemberHighBounds(high1.resultType, high2.instantiate(high1.paramRefs)) + ) + case (_: AnyKindType, _) | (_, _: NothingType) => + high2 + case (_, _: AnyKindType) | (_: NothingType, _) => + high1 + case _ => + high1 & high2 + end mergeTypeMemberHighBounds end ResolveMemberResult /** A type parameter of a type constructor.