Skip to content

Commit

Permalink
Add ParamClause to allow multiple type param clauses
Browse files Browse the repository at this point in the history
Reflects the changes done in scala#10940
  • Loading branch information
nicolasstucki committed Jan 12, 2021
1 parent 5950f83 commit 061c5fc
Show file tree
Hide file tree
Showing 20 changed files with 300 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -256,21 +256,22 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
end DefDefTypeTest

object DefDef extends DefDefModule:
def apply(symbol: Symbol, rhsFn: List[TypeRepr] => List[List[Term]] => Option[Term]): DefDef =
withDefaultPos(tpd.DefDef(symbol.asTerm, prefss => {
val (tparams, vparamss) = tpd.splitArgs(prefss)
yCheckedOwners(rhsFn(tparams.map(_.tpe))(vparamss), symbol).getOrElse(tpd.EmptyTree)
}))
def copy(original: Tree)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term]): DefDef =
tpd.cpy.DefDef(original)(name.toTermName, tpd.joinParams(typeParams, paramss), tpt, yCheckedOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
def unapply(ddef: DefDef): (String, List[TypeDef], List[List[ValDef]], TypeTree, Option[Term]) =
(ddef.name.toString, ddef.typeParams, ddef.termParamss, ddef.tpt, optional(ddef.rhs))
def apply(symbol: Symbol, rhsFn: List[List[Tree]] => Option[Term]): DefDef =
withDefaultPos(tpd.DefDef(symbol.asTerm, prefss =>
yCheckedOwners(rhsFn(prefss), symbol).getOrElse(tpd.EmptyTree)
))
def copy(original: Tree)(name: String, paramss: List[ParamClause], tpt: TypeTree, rhs: Option[Term]): DefDef =
tpd.cpy.DefDef(original)(name.toTermName, paramss, tpt, yCheckedOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
def unapply(ddef: DefDef): (String, List[ParamClause], TypeTree, Option[Term]) =
(ddef.name.toString, ddef.paramss, ddef.tpt, optional(ddef.rhs))
end DefDef

given DefDefMethods: DefDefMethods with
extension (self: DefDef)
def typeParams: List[TypeDef] = self.leadingTypeParams // TODO: adapt to multiple type parameter clauses
def paramss: List[List[ValDef]] = self.termParamss
def paramss: List[ParamClause] = self.paramss
def leadingTypeParams: List[TypeDef] = self.leadingTypeParams
def trailingParamss: List[ParamClause] = self.trailingParamss
def termParamss: List[TermParamClause] = self.termParamss
def returnTpt: TypeTree = self.tpt
def rhs: Option[Term] = optional(self.rhs)
end extension
Expand Down Expand Up @@ -750,7 +751,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
tpd.Closure(meth, tss => yCheckedOwners(rhsFn(meth, tss.head.map(withDefaultPos)), meth))

def unapply(tree: Block): Option[(List[ValDef], Term)] = tree match {
case Block((ddef @ DefDef(_, _, params :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
case Block((ddef @ DefDef(_, TermParamClause(params) :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
if ddef.symbol == meth.symbol =>
Some((params, body))
case _ => None
Expand Down Expand Up @@ -1481,6 +1482,54 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
end extension
end AlternativesMethods

type ParamClause = tpd.ParamClause

object ParamClause extends ParamClauseModule

given ParamClauseMethods: ParamClauseMethods with
extension (self: ParamClause)
def params: List[ValDef] | List[TypeDef] = self
end ParamClauseMethods

type TermParamClause = List[tpd.ValDef]

given TermParamClauseTypeTest: TypeTest[ParamClause, TermParamClause] with
def apply(params: List[ValDef]): TermParamClause = params
def unapply(x: ParamClause): Option[TermParamClause & x.type] = x match
case tpd.ValDefs(_) => Some(x.asInstanceOf[TermParamClause & x.type])
case _ => None
end TermParamClauseTypeTest

object TermParamClause extends TermParamClauseModule:
def apply(x: List[ValDef]): TermParamClause = x
def unapply(x: TermParamClause): Some[List[ValDef]] = Some(x)
end TermParamClause

given TermParamClauseMethods: TermParamClauseMethods with
extension (self: TermParamClause)
def params: List[ValDef] = self
end TermParamClauseMethods

type TypeParamClause = List[tpd.TypeDef]

given TypeParamClauseTypeTest: TypeTest[ParamClause, TypeParamClause] with
def unapply(x: ParamClause): Option[TypeParamClause & x.type] = x match
case tpd.TypeDefs(_) => Some(x.asInstanceOf[TypeParamClause & x.type])
case _ => None
end TypeParamClauseTypeTest

object TypeParamClause extends TypeParamClauseModule:
def apply(params: List[TypeDef]): TypeParamClause =
if params.nonEmpty then throw IllegalArgumentException("Empty type parameters")
params
def unapply(x: TypeParamClause): Some[List[TypeDef]] = Some(x)
end TypeParamClause

given TypeParamClauseMethods: TypeParamClauseMethods with
extension (self: TypeParamClause)
def params: List[TypeDef] = self
end TypeParamClauseMethods

type Selector = untpd.ImportSelector

object Selector extends SelectorModule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,14 +261,14 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
}))
def copy(original: Tree)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term]): DefDef =
tpd.cpy.DefDef(original)(name.toTermName, tpd.joinParams(typeParams, paramss), tpt, yCheckedOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
def unapply(ddef: DefDef): (String, List[TypeDef], List[List[ValDef]], TypeTree, Option[Term]) =
(ddef.name.toString, ddef.typeParams, ddef.termParamss, ddef.tpt, optional(ddef.rhs))
// def unapply(ddef: DefDef): (String, List[TypeDef], List[List[ValDef]], TypeTree, Option[Term]) =
// (ddef.name.toString, ddef.typeParams, ddef.termParamss, ddef.tpt, optional(ddef.rhs))
end DefDef

given DefDefMethods: DefDefMethods with
extension (self: DefDef)
def typeParams: List[TypeDef] = self.leadingTypeParams // TODO: adapt to multiple type parameter clauses
def paramss: List[List[ValDef]] = self.termParamss
// def paramss: List[List[ValDef]] = self.termParamss
def returnTpt: TypeTree = self.tpt
def rhs: Option[Term] = optional(self.rhs)
end extension
Expand Down Expand Up @@ -747,12 +747,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
val meth = dotc.core.Symbols.newSymbol(owner, nme.ANON_FUN, Synthetic | Method, tpe)
tpd.Closure(meth, tss => yCheckedOwners(rhsFn(meth, tss.head), meth))

def unapply(tree: Block): Option[(List[ValDef], Term)] = tree match {
case Block((ddef @ DefDef(_, _, params :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
if ddef.symbol == meth.symbol =>
Some((params, body))
case _ => None
}
def unapply(tree: Block): Option[(List[ValDef], Term)] = ???
end Lambda

type If = tpd.If
Expand Down
27 changes: 19 additions & 8 deletions compiler/src/scala/quoted/runtime/impl/Matcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ object Matcher {
}.transformTree(scrutinee)(Symbol.spliceOwner)
}
val names = args.map {
case Block(List(DefDef("$anonfun", _, _, _, Some(Apply(Ident(name), _)))), _) => name
case Block(List(DefDef("$anonfun", _, _, Some(Apply(Ident(name), _)))), _) => name
case arg => arg.symbol.name
}
val argTypes = args.map(x => x.tpe.widenTermRefExpr)
Expand Down Expand Up @@ -302,16 +302,19 @@ object Matcher {
tpt1 =?= tpt2 &&& treeOptMatches(rhs1, rhs2)(using rhsEnv)

/* Match def */
case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) =>
case (DefDef(_, paramss1, tpt1, Some(rhs1)), DefDef(_, paramss2, tpt2, Some(rhs2))) =>
def rhsEnv =
val paramSyms: List[(Symbol, Symbol)] =
for
(clause1, clause2) <- paramss1.zip(paramss2)
(param1, param2) <- clause1.params.zip(clause2.params)
yield
param1.symbol -> param2.symbol
val oldEnv: Env = summon[Env]
val newEnv: List[(Symbol, Symbol)] =
(scrutinee.symbol -> pattern.symbol) :: typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) :::
paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
val newEnv: List[(Symbol, Symbol)] = (scrutinee.symbol -> pattern.symbol) :: paramSyms
oldEnv ++ newEnv

typeParams1 =?= typeParams2
&&& matchLists(paramss1, paramss2)(_ =?= _)
matchLists(paramss1, paramss2)(_ =?= _)
&&& tpt1 =?= tpt2
&&& withEnv(rhsEnv)(rhs1 =?= rhs2)

Expand Down Expand Up @@ -343,6 +346,14 @@ object Matcher {
}
end extension

extension (scrutinee: ParamClause)
/** Check that all parameters in the clauses clauses match with =?= and concatenate the results with &&& */
private def =?= (pattern: ParamClause)(using Env)(using DummyImplicit): Matching =
(scrutinee, pattern) match
case (TermParamClause(params1), TermParamClause(params2)) => matchLists(params1, params2)(_ =?= _)
case (TypeParamClause(params1), TypeParamClause(params2)) => matchLists(params1, params2)(_ =?= _)
case _ => notMatched

/** Does the scrutenne symbol match the pattern symbol? It matches if:
* - They are the same symbol
* - The scrutinee has is in the environment and they are equivalent
Expand Down Expand Up @@ -382,7 +393,7 @@ object Matcher {
def unapply(args: List[Term]): Option[List[Ident]] =
args.foldRight(Option(List.empty[Ident])) {
case (id: Ident, Some(acc)) => Some(id :: acc)
case (Block(List(DefDef("$anonfun", Nil, List(params), Inferred(), Some(Apply(id: Ident, args)))), Closure(Ident("$anonfun"), None)), Some(acc))
case (Block(List(DefDef("$anonfun", TermParamClause(params) :: Nil, Inferred(), Some(Apply(id: Ident, args)))), Closure(Ident("$anonfun"), None)), Some(acc))
if params.zip(args).forall(_.symbol == _.symbol) =>
Some(id :: acc)
case _ => None
Expand Down
13 changes: 11 additions & 2 deletions compiler/src/scala/quoted/runtime/impl/printers/Extractors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ object Extractors {
this += ", " ++= bindings += ", " += expansion += ")"
case ValDef(name, tpt, rhs) =>
this += "ValDef(\"" += name += "\", " += tpt += ", " += rhs += ")"
case DefDef(name, typeParams, paramss, returnTpt, rhs) =>
this += "DefDef(\"" += name += "\", " ++= typeParams += ", " +++= paramss += ", " += returnTpt += ", " += rhs += ")"
case DefDef(name, paramsClauses, returnTpt, rhs) =>
this += "DefDef(\"" += name += "\", " ++= paramsClauses += ", " += returnTpt += ", " += rhs += ")"
case TypeDef(name, rhs) =>
this += "TypeDef(\"" += name += "\", " += rhs += ")"
case ClassDef(name, constr, parents, derived, self, body) =>
Expand Down Expand Up @@ -256,6 +256,11 @@ object Extractors {
else if x.isTypeDef then this += "IsTypeDefSymbol(<" += x.fullName += ">)"
else { assert(x.isNoSymbol); this += "NoSymbol()" }

def visitParamClause(x: ParamClause): this.type =
x match
case TermParamClause(params) => this += "TermParamClause(" ++= params += ")"
case TypeParamClause(params) => this += "TypeParamClause(" ++= params += ")"

def +=(x: Boolean): this.type = { sb.append(x); this }
def +=(x: Byte): this.type = { sb.append(x); this }
def +=(x: Short): this.type = { sb.append(x); this }
Expand Down Expand Up @@ -301,6 +306,10 @@ object Extractors {
def +=(x: Symbol): self.type = { visitSymbol(x); buff }
}

private implicit class ParamClauseOps(buff: self.type) {
def ++=(x: List[ParamClause]): self.type = { visitList(x, visitParamClause); buff }
}

private def visitOption[U](opt: Option[U], visit: U => this.type): this.type = opt match {
case Some(x) =>
this += "Some("
Expand Down
29 changes: 15 additions & 14 deletions compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ object SourceCode {
this += "."
printSelectors(selectors)

case cdef @ ClassDef(name, DefDef(_, targs, argss, _, _), parents, derived, self, stats) =>
case cdef @ ClassDef(name, DefDef(_, paramss, _, _), parents, derived, self, stats) =>
printDefAnnotations(cdef)

val flags = cdef.symbol.flags
Expand All @@ -155,12 +155,13 @@ object SourceCode {
else if (flags.is(Flags.Abstract)) this += highlightKeyword("abstract class ") += highlightTypeDef(name)
else this += highlightKeyword("class ") += highlightTypeDef(name)

val typeParams = stats.collect { case targ: TypeDef => targ }.filter(_.symbol.isTypeParam).zip(targs)
if (!flags.is(Flags.Module)) {
printTargsDefs(typeParams)
val it = argss.iterator
while (it.hasNext)
printArgsDefs(it.next())
for paramClause <- paramss do
paramClause match
case TermParamClause(params) =>
printArgsDefs(params)
case TypeParamClause(params) =>
printTargsDefs(stats.collect { case targ: TypeDef => targ }.filter(_.symbol.isTypeParam).zip(params))
}

val parents1 = parents.filter {
Expand Down Expand Up @@ -212,8 +213,8 @@ object SourceCode {
// Currently the compiler does not allow overriding some of the methods generated for case classes
d.symbol.flags.is(Flags.Synthetic) &&
(d match {
case DefDef("apply" | "unapply" | "writeReplace", _, _, _, _) if d.symbol.owner.flags.is(Flags.Module) => true
case DefDef(n, _, _, _, _) if d.symbol.owner.flags.is(Flags.Case) =>
case DefDef("apply" | "unapply" | "writeReplace", _, _, _) if d.symbol.owner.flags.is(Flags.Module) => true
case DefDef(n, _, _, _) if d.symbol.owner.flags.is(Flags.Case) =>
n == "copy" ||
n.matches("copy\\$default\\$[1-9][0-9]*") || // default parameters for the copy method
n.matches("_[1-9][0-9]*") || // Getters from Product
Expand Down Expand Up @@ -301,7 +302,7 @@ object SourceCode {
printTree(body)
}

case ddef @ DefDef(name, targs, argss, tpt, rhs) =>
case ddef @ DefDef(name, paramss, tpt, rhs) =>
printDefAnnotations(ddef)

val isConstructor = name == "<init>"
Expand All @@ -316,10 +317,10 @@ object SourceCode {

val name1: String = if (isConstructor) "this" else splicedName(ddef.symbol).getOrElse(name)
this += highlightKeyword("def ") += highlightValDef(name1)
printTargsDefs(targs.zip(targs))
val it = argss.iterator
while (it.hasNext)
printArgsDefs(it.next())
for clause <- paramss do
clause match
case TermParamClause(params) => printArgsDefs(params)
case TypeParamClause(params) => printTargsDefs(params.zip(params))
if (!isConstructor) {
this += ": "
printTypeTree(tpt)
Expand Down Expand Up @@ -1251,7 +1252,7 @@ object SourceCode {

private def printDefinitionName(tree: Definition): this.type = tree match {
case ValDef(name, _, _) => this += highlightValDef(name)
case DefDef(name, _, _, _, _) => this += highlightValDef(name)
case DefDef(name, _, _, _) => this += highlightValDef(name)
case ClassDef(name, _, _, _, _, _) => this += highlightTypeDef(name.stripSuffix("$"))
case TypeDef(name, _) => this += highlightTypeDef(name)
}
Expand Down
Loading

0 comments on commit 061c5fc

Please sign in to comment.