Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

surface: Fix #3454 - methods in generic classes - for Scala 3 #3455

Merged
merged 3 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -305,13 +305,13 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q):
}
}

// Build a table for resolving type parameters, e.g., class MyClass[A, B] -> Map("A" -> TypeRepr, "B" -> TypeRepr)
private def typeMappingTable(t: TypeRepr, method: Symbol): Map[String, TypeRepr] =
val classTypeParams = t.typeSymbol.typeMembers.filter(_.isTypeParam)
val classTypeArgs: List[TypeRepr] = t match
case a: AppliedType => a.args
case _ => List.empty[TypeRepr]

// Build a table for resolving type parameters, e.g., class MyClass[A, B] -> Map("A" -> TypeRepr, "B" -> TypeRepr)
(classTypeParams zip classTypeArgs)
.map { (paramType, argType) =>
paramType.name -> argType
Expand Down Expand Up @@ -492,15 +492,33 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q):
isSecret: Boolean
)

private def resolveType(t: TypeRepr, typeArgTable: Map[String, TypeRepr]): TypeRepr =
t match
case a: AppliedType =>
// println(s"=== a.args ${a.args}")
// println(s"=== typeArgTable ${typeArgTable}")
val resolvedTypeArgs = a.args.map {
case p if p.typeSymbol.isTypeParam && typeArgTable.contains(p.typeSymbol.name) =>
typeArgTable(p.typeSymbol.name)
case other =>
resolveType(other, typeArgTable)
}
// println(s"=== resolvedTypeArgs ${resolvedTypeArgs}")
// Need to use the base type of the applied type to replace the type parameters
a.tycon.appliedTo(resolvedTypeArgs)
case TypeRef(_, name) if typeArgTable.contains(name) =>
typeArgTable(name)
case other =>
other

private def methodArgsOf(t: TypeRepr, method: Symbol): List[List[MethodArg]] =
// println(s"==== method args of ${fullTypeNameOf(t)}")

val defaultValueMethods = t.typeSymbol.companionClass.declaredMethods.filter { m =>
m.name.startsWith("apply$default$") || m.name.startsWith("$lessinit$greater$default$")
}

// Build a table for resolving type parameters, e.g., class MyClass[A, B] -> Map("A" -> TypeRepr, "B" -> TypeRepr)
val typeArgTable: Map[String, TypeRepr] = typeMappingTable(t, method)
val typeArgTable = typeMappingTable(t, method)

val paramss: List[List[Symbol]] = method.paramSymss.filter { lst =>
// Empty arg is allowed
Expand All @@ -517,26 +535,7 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q):
// println(s"=== ${v.show} ${s.flags.show} ${s.flags.is(Flags.Implicit)}")
// Substitute type param to actual types

def resolveType(t: TypeRepr): TypeRepr =
t match
case a: AppliedType =>
// println(s"=== a.args ${a.args}")
// println(s"=== typeArgTable ${typeArgTable}")
val resolvedTypeArgs = a.args.map {
case p if p.typeSymbol.isTypeParam && typeArgTable.contains(p.typeSymbol.name) =>
typeArgTable(p.typeSymbol.name)
case other =>
resolveType(other)
}
// println(s"=== resolvedTypeArgs ${resolvedTypeArgs}")
// Need to use the base type of the applied type to replace the type parameters
a.tycon.appliedTo(resolvedTypeArgs)
case TypeRef(_, name) if typeArgTable.contains(name) =>
typeArgTable(name)
case other =>
other

val resolved: TypeRepr = resolveType(v)
val resolved = resolveType(v, typeArgTable)

val isSecret = hasSecretAnnotation(s)
val isRequired = hasRequiredAnnotation(s)
Expand Down Expand Up @@ -753,11 +752,13 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q):
seenMethodParent += targetType
val localMethods = localMethodsOf(targetType).distinct.sortBy(_.name)
val methodSurfaces = localMethods.map(m => (m, m.tree)).collect { case (m, df: DefDef) =>
val mod = Expr(modifierBitMaskOf(m))
val owner = surfaceOf(targetType)
val name = Expr(m.name)
// println(s"======= ${df.returnTpt.show}")
val ret = surfaceOf(df.returnTpt.tpe)
val mod = Expr(modifierBitMaskOf(m))
val owner = surfaceOf(targetType)
val name = Expr(m.name)
val typeArgTable = typeMappingTable(targetType, m)
val returnType = resolveType(df.returnTpt.tpe, typeArgTable)
// println(s"======= ${returnType.show}")
val ret = surfaceOf(returnType)
// println(s"==== method of: def ${m.name}")
val params = methodParametersOf(targetType, m)
val args = methodArgsOf(targetType, m)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@ package wvlet.airframe.surface

import wvlet.airspec.AirSpec

object GenericMethodTest extends AirSpec {
class A {
object GenericMethodTest extends AirSpec:
class A:
def helloX[X](v: X): String = "hello"
}

test("generic method") {
val methods = Surface.methodsOf[A]
Expand All @@ -29,4 +28,21 @@ object GenericMethodTest extends AirSpec {
m.call(obj, "dummy") shouldBe "hello"
}

}
case class Gen[X](value: X):
def pass(x: X): X = x
def myself: Gen[X] = this
def wrap(x: X): Gen[X] = Gen[X](value)
def unwrap(x: Gen[X]): X = x.value

test("Methods of generic type") {
val typeSurface = Surface.of[Gen[String]]
val methods = Surface.methodsOf[Gen[String]]
val pass = methods.find(_.name == "pass").get
pass.returnType shouldBe Surface.of[String]
val myself = methods.find(_.name == "myself").get
myself.returnType shouldBe typeSurface
val wrap = methods.find(_.name == "wrap").get
wrap.returnType shouldBe typeSurface
val unwrap = methods.find(_.name == "unwrap").get
unwrap.returnType shouldBe Surface.of[String]
}
Loading