Skip to content

Commit

Permalink
Allow leading context parameters in extension methods
Browse files Browse the repository at this point in the history
  • Loading branch information
odersky authored and nicolasstucki committed Jan 5, 2021
1 parent 90b56b5 commit 913386c
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 12 deletions.
10 changes: 9 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,15 @@ object desugar {
tparams = ext.tparams ++ mdef.tparams,
vparamss = mdef.vparamss match
case vparams1 :: vparamss1 if mdef.name.isRightAssocOperatorName =>
vparams1 :: ext.vparamss ::: vparamss1
def badRightAssoc(problem: String) =
report.error(i"right-associative extension method $problem", mdef.srcPos)
ext.vparamss ::: vparamss1
vparams1 match
case vparam :: Nil =>
if !vparam.mods.is(Given) then vparams1 :: ext.vparamss ::: vparamss1
else badRightAssoc("cannot start with using clause")
case _ =>
badRightAssoc("must start with a single parameter")
case _ =>
ext.vparamss ++ mdef.vparamss
).withMods(mdef.mods | ExtensionMethod)
Expand Down
8 changes: 8 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,14 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] =>
/** Is this case guarded? */
def isGuardedCase(cdef: CaseDef): Boolean = cdef.guard ne EmptyTree

/** Is this parameter list a using clause? */
def isUsingClause(vparams: List[ValDef])(using Context): Boolean = vparams match
case vparam :: _ =>
val sym = vparam.symbol
if sym.exists then sym.is(Given) else vparam.mods.is(Given)
case _ =>
false

/** The underlying pattern ignoring any bindings */
def unbind(x: Tree): Tree = unsplice(x) match {
case Bind(_, y) => unbind(y)
Expand Down
16 changes: 11 additions & 5 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2997,7 +2997,7 @@ object Parsers {
if in.token == RPAREN && !prefix && !impliedMods.is(Given) then Nil
else
val clause =
if prefix then param() :: Nil
if prefix && !isIdent(nme.using) then param() :: Nil
else
paramMods()
if givenOnly && !impliedMods.is(Given) then
Expand Down Expand Up @@ -3389,7 +3389,6 @@ object Parsers {
* | [‘case’] ‘object’ ObjectDef
* | ‘enum’ EnumDef
* | ‘given’ GivenDef
* | ‘extension’ ExtensionDef
*/
def tmplDef(start: Int, mods: Modifiers): Tree =
in.token match {
Expand Down Expand Up @@ -3566,8 +3565,15 @@ object Parsers {
def extension(): ExtMethods =
val start = in.skipToken()
val tparams = typeParamClauseOpt(ParamOwner.Def)
val extParams = paramClause(0, prefix = true)
val givenParamss = paramClauses(givenOnly = true)
val leadParamss = ListBuffer[List[ValDef]]()
var nparams = 0
while
val extParams = paramClause(nparams, prefix = true)
leadParamss += extParams
nparams += extParams.length
isUsingClause(extParams)
do ()
leadParamss ++= paramClauses(givenOnly = true)
if in.token == COLON then
syntaxError("no `:` expected here")
in.nextToken()
Expand All @@ -3579,7 +3585,7 @@ object Parsers {
newLineOptWhenFollowedBy(LBRACE)
if in.isNestedStart then inDefScopeBraces(extMethods())
else { syntaxError("Extension without extension methods"); Nil }
val result = atSpan(start)(ExtMethods(tparams, extParams :: givenParamss, methods))
val result = atSpan(start)(ExtMethods(tparams, leadParamss.toList, methods))
val comment = in.getDocComment(start)
if comment.isDefined then
for meth <- methods do
Expand Down
21 changes: 15 additions & 6 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -794,18 +794,27 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {

protected def defDefToText[T >: Untyped](tree: DefDef[T]): Text = {
import untpd._

def splitParams(paramss: List[List[ValDef]]): (List[List[ValDef]], List[List[ValDef]]) =
paramss match
case params1 :: (rest @ (_ :: _)) if tree.name.isRightAssocOperatorName =>
val (leading, trailing) = splitParams(rest)
(leading, params1 :: trailing)
case _ =>
val trailing = paramss
.dropWhile(isUsingClause)
.drop(1)
.dropWhile(isUsingClause)
(paramss.take(paramss.length - trailing.length), trailing)

dclTextOr(tree) {
val defKeyword = modText(tree.mods, tree.symbol, keywordStr("def"), isType = false)
val isExtension = tree.hasType && tree.symbol.is(ExtensionMethod)
withEnclosingDef(tree) {
val (prefix, vparamss) =
if isExtension then
val (leadingParams, otherParamss) = (tree.vparamss: @unchecked) match
case vparams1 :: vparams2 :: rest if tree.name.isRightAssocOperatorName =>
(vparams2, vparams1 :: rest)
case vparams1 :: rest =>
(vparams1, rest)
(keywordStr("extension") ~~ paramsText(leadingParams)
val (leadingParamss, otherParamss) = splitParams(tree.vparamss)
(addVparamssText(keywordStr("extension "), leadingParamss)
~~ (defKeyword ~~ valDefText(nameIdText(tree))).close,
otherParamss)
else (defKeyword ~~ valDefText(nameIdText(tree)), tree.vparamss)
Expand Down
8 changes: 8 additions & 0 deletions tests/neg/rightassoc-extmethod.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Error: tests/neg/rightassoc-extmethod.scala:1:23 --------------------------------------------------------------------
1 |extension (x: Int) def +: (using String): Int = x // error
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| right-associative extension method cannot start with using clause
-- Error: tests/neg/rightassoc-extmethod.scala:2:23 --------------------------------------------------------------------
2 |extension (x: Int) def *: (y: Int, z: Int) = x // error
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
| right-associative extension method must start with a single parameter
3 changes: 3 additions & 0 deletions tests/neg/rightassoc-extmethod.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
extension (x: Int) def +: (using String): Int = x // error
extension (x: Int) def *: (y: Int, z: Int) = x // error

35 changes: 35 additions & 0 deletions tests/pos/i9530.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
trait Scope:
type Expr
type Value
def expr(x: String): Expr
def value(e: Expr): Value
def combine(e: Expr, str: String): Expr

extension (using s: Scope)(expr: s.Expr)
def show = expr.toString
def eval = s.value(expr)
def *: (str: String) = s.combine(expr, str)

def f(using s: Scope)(x: s.Expr): (String, s.Value) =
(x.show, x.eval)

given scope: Scope with
case class Expr(str: String)
type Value = Int
def expr(x: String) = Expr(x)
def value(e: Expr) = e.str.toInt
def combine(e: Expr, str: String) = Expr(e.str ++ str)

@main def Test =
val e = scope.Expr("123")
val (s, v) = f(e)
println(s)
println(v)
val ss = e.show
println(ss)
val vv = e.eval
println(vv)
val e2 = e *: "4"
println(e2.show)
println(e2.eval)

0 comments on commit 913386c

Please sign in to comment.