Skip to content

Commit

Permalink
Add ParamClause to allow multiple type param clauses
Browse files Browse the repository at this point in the history
Reflects the changes done in scala#10940
  • Loading branch information
nicolasstucki committed Jan 19, 2021
1 parent 23ba6c3 commit c50aa96
Show file tree
Hide file tree
Showing 32 changed files with 367 additions and 152 deletions.
2 changes: 1 addition & 1 deletion community-build/community-projects/sourcecode
2 changes: 1 addition & 1 deletion community-build/community-projects/upickle
Original file line number Diff line number Diff line change
Expand Up @@ -256,21 +256,22 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
end DefDefTypeTest

object DefDef extends DefDefModule:
def apply(symbol: Symbol, rhsFn: List[TypeRepr] => List[List[Term]] => Option[Term]): DefDef =
withDefaultPos(tpd.DefDef(symbol.asTerm, prefss => {
val (tparams, vparamss) = tpd.splitArgs(prefss)
yCheckedOwners(rhsFn(tparams.map(_.tpe))(vparamss), symbol).getOrElse(tpd.EmptyTree)
}))
def copy(original: Tree)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term]): DefDef =
tpd.cpy.DefDef(original)(name.toTermName, tpd.joinParams(typeParams, paramss), tpt, yCheckedOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
def unapply(ddef: DefDef): (String, List[TypeDef], List[List[ValDef]], TypeTree, Option[Term]) =
(ddef.name.toString, ddef.typeParams, ddef.termParamss, ddef.tpt, optional(ddef.rhs))
def apply(symbol: Symbol, rhsFn: List[List[Tree]] => Option[Term]): DefDef =
withDefaultPos(tpd.DefDef(symbol.asTerm, prefss =>
yCheckedOwners(rhsFn(prefss), symbol).getOrElse(tpd.EmptyTree)
))
def copy(original: Tree)(name: String, paramss: List[ParamClause], tpt: TypeTree, rhs: Option[Term]): DefDef =
tpd.cpy.DefDef(original)(name.toTermName, paramss, tpt, yCheckedOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
def unapply(ddef: DefDef): (String, List[ParamClause], TypeTree, Option[Term]) =
(ddef.name.toString, ddef.paramss, ddef.tpt, optional(ddef.rhs))
end DefDef

given DefDefMethods: DefDefMethods with
extension (self: DefDef)
def typeParams: List[TypeDef] = self.leadingTypeParams // TODO: adapt to multiple type parameter clauses
def paramss: List[List[ValDef]] = self.termParamss
def paramss: List[ParamClause] = self.paramss
def leadingTypeParams: List[TypeDef] = self.leadingTypeParams
def trailingParamss: List[ParamClause] = self.trailingParamss
def termParamss: List[TermParamClause] = self.termParamss
def returnTpt: TypeTree = self.tpt
def rhs: Option[Term] = optional(self.rhs)
end extension
Expand Down Expand Up @@ -750,7 +751,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
tpd.Closure(meth, tss => yCheckedOwners(rhsFn(meth, tss.head.map(withDefaultPos)), meth))

def unapply(tree: Block): Option[(List[ValDef], Term)] = tree match {
case Block((ddef @ DefDef(_, _, params :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
case Block((ddef @ DefDef(_, TermParamClause(params) :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
if ddef.symbol == meth.symbol =>
Some((params, body))
case _ => None
Expand Down Expand Up @@ -1481,6 +1482,59 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
end extension
end AlternativesMethods

type ParamClause = tpd.ParamClause

object ParamClause extends ParamClauseModule

given ParamClauseMethods: ParamClauseMethods with
extension (self: ParamClause)
def params: List[ValDef] | List[TypeDef] = self
end ParamClauseMethods

type TermParamClause = List[tpd.ValDef]

given TermParamClauseTypeTest: TypeTest[ParamClause, TermParamClause] with
def unapply(x: ParamClause): Option[TermParamClause & x.type] = x match
case tpd.ValDefs(_) => Some(x.asInstanceOf[TermParamClause & x.type])
case _ => None
end TermParamClauseTypeTest

object TermParamClause extends TermParamClauseModule:
def apply(params: List[ValDef]): TermParamClause =
if yCheck then
val implicitParams = params.count(_.symbol.is(dotc.core.Flags.Implicit))
assert(implicitParams == 0 || implicitParams == params.size, "Expected all or non of parameters to be implicit")
params
def unapply(x: TermParamClause): Some[List[ValDef]] = Some(x)
end TermParamClause

given TermParamClauseMethods: TermParamClauseMethods with
extension (self: TermParamClause)
def params: List[ValDef] = self
def isImplicit: Boolean =
self.nonEmpty && self.head.symbol.is(dotc.core.Flags.Implicit)
end TermParamClauseMethods

type TypeParamClause = List[tpd.TypeDef]

given TypeParamClauseTypeTest: TypeTest[ParamClause, TypeParamClause] with
def unapply(x: ParamClause): Option[TypeParamClause & x.type] = x match
case tpd.TypeDefs(_) => Some(x.asInstanceOf[TypeParamClause & x.type])
case _ => None
end TypeParamClauseTypeTest

object TypeParamClause extends TypeParamClauseModule:
def apply(params: List[TypeDef]): TypeParamClause =
if params.isEmpty then throw IllegalArgumentException("Empty type parameters")
params
def unapply(x: TypeParamClause): Some[List[TypeDef]] = Some(x)
end TypeParamClause

given TypeParamClauseMethods: TypeParamClauseMethods with
extension (self: TypeParamClause)
def params: List[TypeDef] = self
end TypeParamClauseMethods

type Selector = untpd.ImportSelector

object Selector extends SelectorModule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,14 +261,14 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
}))
def copy(original: Tree)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term]): DefDef =
tpd.cpy.DefDef(original)(name.toTermName, tpd.joinParams(typeParams, paramss), tpt, yCheckedOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
def unapply(ddef: DefDef): (String, List[TypeDef], List[List[ValDef]], TypeTree, Option[Term]) =
(ddef.name.toString, ddef.typeParams, ddef.termParamss, ddef.tpt, optional(ddef.rhs))
// def unapply(ddef: DefDef): (String, List[TypeDef], List[List[ValDef]], TypeTree, Option[Term]) =
// (ddef.name.toString, ddef.typeParams, ddef.termParamss, ddef.tpt, optional(ddef.rhs))
end DefDef

given DefDefMethods: DefDefMethods with
extension (self: DefDef)
def typeParams: List[TypeDef] = self.leadingTypeParams // TODO: adapt to multiple type parameter clauses
def paramss: List[List[ValDef]] = self.termParamss
// def paramss: List[List[ValDef]] = self.termParamss
def returnTpt: TypeTree = self.tpt
def rhs: Option[Term] = optional(self.rhs)
end extension
Expand Down Expand Up @@ -747,12 +747,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
val meth = dotc.core.Symbols.newSymbol(owner, nme.ANON_FUN, Synthetic | Method, tpe)
tpd.Closure(meth, tss => yCheckedOwners(rhsFn(meth, tss.head), meth))

def unapply(tree: Block): Option[(List[ValDef], Term)] = tree match {
case Block((ddef @ DefDef(_, _, params :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
if ddef.symbol == meth.symbol =>
Some((params, body))
case _ => None
}
def unapply(tree: Block): Option[(List[ValDef], Term)] = ???
end Lambda

type If = tpd.If
Expand Down
27 changes: 19 additions & 8 deletions compiler/src/scala/quoted/runtime/impl/Matcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ object Matcher {
}.transformTree(scrutinee)(Symbol.spliceOwner)
}
val names = args.map {
case Block(List(DefDef("$anonfun", _, _, _, Some(Apply(Ident(name), _)))), _) => name
case Block(List(DefDef("$anonfun", _, _, Some(Apply(Ident(name), _)))), _) => name
case arg => arg.symbol.name
}
val argTypes = args.map(x => x.tpe.widenTermRefByName)
Expand Down Expand Up @@ -302,16 +302,19 @@ object Matcher {
tpt1 =?= tpt2 &&& treeOptMatches(rhs1, rhs2)(using rhsEnv)

/* Match def */
case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) =>
case (DefDef(_, paramss1, tpt1, Some(rhs1)), DefDef(_, paramss2, tpt2, Some(rhs2))) =>
def rhsEnv =
val paramSyms: List[(Symbol, Symbol)] =
for
(clause1, clause2) <- paramss1.zip(paramss2)
(param1, param2) <- clause1.params.zip(clause2.params)
yield
param1.symbol -> param2.symbol
val oldEnv: Env = summon[Env]
val newEnv: List[(Symbol, Symbol)] =
(scrutinee.symbol -> pattern.symbol) :: typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) :::
paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
val newEnv: List[(Symbol, Symbol)] = (scrutinee.symbol -> pattern.symbol) :: paramSyms
oldEnv ++ newEnv

typeParams1 =?= typeParams2
&&& matchLists(paramss1, paramss2)(_ =?= _)
matchLists(paramss1, paramss2)(_ =?= _)
&&& tpt1 =?= tpt2
&&& withEnv(rhsEnv)(rhs1 =?= rhs2)

Expand Down Expand Up @@ -343,6 +346,14 @@ object Matcher {
}
end extension

extension (scrutinee: ParamClause)
/** Check that all parameters in the clauses clauses match with =?= and concatenate the results with &&& */
private def =?= (pattern: ParamClause)(using Env)(using DummyImplicit): Matching =
(scrutinee, pattern) match
case (TermParamClause(params1), TermParamClause(params2)) => matchLists(params1, params2)(_ =?= _)
case (TypeParamClause(params1), TypeParamClause(params2)) => matchLists(params1, params2)(_ =?= _)
case _ => notMatched

/** Does the scrutenne symbol match the pattern symbol? It matches if:
* - They are the same symbol
* - The scrutinee has is in the environment and they are equivalent
Expand Down Expand Up @@ -382,7 +393,7 @@ object Matcher {
def unapply(args: List[Term]): Option[List[Ident]] =
args.foldRight(Option(List.empty[Ident])) {
case (id: Ident, Some(acc)) => Some(id :: acc)
case (Block(List(DefDef("$anonfun", Nil, List(params), Inferred(), Some(Apply(id: Ident, args)))), Closure(Ident("$anonfun"), None)), Some(acc))
case (Block(List(DefDef("$anonfun", TermParamClause(params) :: Nil, Inferred(), Some(Apply(id: Ident, args)))), Closure(Ident("$anonfun"), None)), Some(acc))
if params.zip(args).forall(_.symbol == _.symbol) =>
Some(id :: acc)
case _ => None
Expand Down
13 changes: 11 additions & 2 deletions compiler/src/scala/quoted/runtime/impl/printers/Extractors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ object Extractors {
this += ", " ++= bindings += ", " += expansion += ")"
case ValDef(name, tpt, rhs) =>
this += "ValDef(\"" += name += "\", " += tpt += ", " += rhs += ")"
case DefDef(name, typeParams, paramss, returnTpt, rhs) =>
this += "DefDef(\"" += name += "\", " ++= typeParams += ", " +++= paramss += ", " += returnTpt += ", " += rhs += ")"
case DefDef(name, paramsClauses, returnTpt, rhs) =>
this += "DefDef(\"" += name += "\", " ++= paramsClauses += ", " += returnTpt += ", " += rhs += ")"
case TypeDef(name, rhs) =>
this += "TypeDef(\"" += name += "\", " += rhs += ")"
case ClassDef(name, constr, parents, derived, self, body) =>
Expand Down Expand Up @@ -256,6 +256,11 @@ object Extractors {
else if x.isTypeDef then this += "IsTypeDefSymbol(<" += x.fullName += ">)"
else { assert(x.isNoSymbol); this += "NoSymbol()" }

def visitParamClause(x: ParamClause): this.type =
x match
case TermParamClause(params) => this += "TermParamClause(" ++= params += ")"
case TypeParamClause(params) => this += "TypeParamClause(" ++= params += ")"

def +=(x: Boolean): this.type = { sb.append(x); this }
def +=(x: Byte): this.type = { sb.append(x); this }
def +=(x: Short): this.type = { sb.append(x); this }
Expand Down Expand Up @@ -301,6 +306,10 @@ object Extractors {
def +=(x: Symbol): self.type = { visitSymbol(x); buff }
}

private implicit class ParamClauseOps(buff: self.type) {
def ++=(x: List[ParamClause]): self.type = { visitList(x, visitParamClause); buff }
}

private def visitOption[U](opt: Option[U], visit: U => this.type): this.type = opt match {
case Some(x) =>
this += "Some("
Expand Down
29 changes: 15 additions & 14 deletions compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ object SourceCode {
this += "."
printSelectors(selectors)

case cdef @ ClassDef(name, DefDef(_, targs, argss, _, _), parents, derived, self, stats) =>
case cdef @ ClassDef(name, DefDef(_, paramss, _, _), parents, derived, self, stats) =>
printDefAnnotations(cdef)

val flags = cdef.symbol.flags
Expand All @@ -155,12 +155,13 @@ object SourceCode {
else if (flags.is(Flags.Abstract)) this += highlightKeyword("abstract class ") += highlightTypeDef(name)
else this += highlightKeyword("class ") += highlightTypeDef(name)

val typeParams = stats.collect { case targ: TypeDef => targ }.filter(_.symbol.isTypeParam).zip(targs)
if (!flags.is(Flags.Module)) {
printTargsDefs(typeParams)
val it = argss.iterator
while (it.hasNext)
printArgsDefs(it.next())
for paramClause <- paramss do
paramClause match
case TermParamClause(params) =>
printArgsDefs(params)
case TypeParamClause(params) =>
printTargsDefs(stats.collect { case targ: TypeDef => targ }.filter(_.symbol.isTypeParam).zip(params))
}

val parents1 = parents.filter {
Expand Down Expand Up @@ -212,8 +213,8 @@ object SourceCode {
// Currently the compiler does not allow overriding some of the methods generated for case classes
d.symbol.flags.is(Flags.Synthetic) &&
(d match {
case DefDef("apply" | "unapply" | "writeReplace", _, _, _, _) if d.symbol.owner.flags.is(Flags.Module) => true
case DefDef(n, _, _, _, _) if d.symbol.owner.flags.is(Flags.Case) =>
case DefDef("apply" | "unapply" | "writeReplace", _, _, _) if d.symbol.owner.flags.is(Flags.Module) => true
case DefDef(n, _, _, _) if d.symbol.owner.flags.is(Flags.Case) =>
n == "copy" ||
n.matches("copy\\$default\\$[1-9][0-9]*") || // default parameters for the copy method
n.matches("_[1-9][0-9]*") || // Getters from Product
Expand Down Expand Up @@ -301,7 +302,7 @@ object SourceCode {
printTree(body)
}

case ddef @ DefDef(name, targs, argss, tpt, rhs) =>
case ddef @ DefDef(name, paramss, tpt, rhs) =>
printDefAnnotations(ddef)

val isConstructor = name == "<init>"
Expand All @@ -316,10 +317,10 @@ object SourceCode {

val name1: String = if (isConstructor) "this" else splicedName(ddef.symbol).getOrElse(name)
this += highlightKeyword("def ") += highlightValDef(name1)
printTargsDefs(targs.zip(targs))
val it = argss.iterator
while (it.hasNext)
printArgsDefs(it.next())
for clause <- paramss do
clause match
case TermParamClause(params) => printArgsDefs(params)
case TypeParamClause(params) => printTargsDefs(params.zip(params))
if (!isConstructor) {
this += ": "
printTypeTree(tpt)
Expand Down Expand Up @@ -1251,7 +1252,7 @@ object SourceCode {

private def printDefinitionName(tree: Definition): this.type = tree match {
case ValDef(name, _, _) => this += highlightValDef(name)
case DefDef(name, _, _, _, _) => this += highlightValDef(name)
case DefDef(name, _, _, _) => this += highlightValDef(name)
case ClassDef(name, _, _, _, _, _) => this += highlightTypeDef(name.stripSuffix("$"))
case TypeDef(name, _) => this += highlightTypeDef(name)
}
Expand Down
Loading

0 comments on commit c50aa96

Please sign in to comment.