diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 65085b39e7bc..768af3bf6474 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -878,7 +878,15 @@ object desugar { tparams = ext.tparams ++ mdef.tparams, vparamss = mdef.vparamss match case vparams1 :: vparamss1 if mdef.name.isRightAssocOperatorName => - vparams1 :: ext.vparamss ::: vparamss1 + def badRightAssoc(problem: String) = + report.error(i"right-associative extension method $problem", mdef.srcPos) + ext.vparamss ::: vparamss1 + vparams1 match + case vparam :: Nil => + if !vparam.mods.is(Given) then vparams1 :: ext.vparamss ::: vparamss1 + else badRightAssoc("cannot start with using clause") + case _ => + badRightAssoc("must start with a single parameter") case _ => ext.vparamss ++ mdef.vparamss ).withMods(mdef.mods | ExtensionMethod) diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index f29bde70744d..5e88aba081fe 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -245,6 +245,14 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] => /** Is this case guarded? */ def isGuardedCase(cdef: CaseDef): Boolean = cdef.guard ne EmptyTree + /** Is this parameter list a using clause? */ + def isUsingClause(vparams: List[ValDef])(using Context): Boolean = vparams match + case vparam :: _ => + val sym = vparam.symbol + if sym.exists then sym.is(Given) else vparam.mods.is(Given) + case _ => + false + /** The underlying pattern ignoring any bindings */ def unbind(x: Tree): Tree = unsplice(x) match { case Bind(_, y) => unbind(y) diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 9849400da403..24b81f32cd5b 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -2997,7 +2997,7 @@ object Parsers { if in.token == RPAREN && !prefix && !impliedMods.is(Given) then Nil else val clause = - if prefix then param() :: Nil + if prefix && !isIdent(nme.using) then param() :: Nil else paramMods() if givenOnly && !impliedMods.is(Given) then @@ -3389,7 +3389,6 @@ object Parsers { * | [‘case’] ‘object’ ObjectDef * | ‘enum’ EnumDef * | ‘given’ GivenDef - * | ‘extension’ ExtensionDef */ def tmplDef(start: Int, mods: Modifiers): Tree = in.token match { @@ -3566,8 +3565,15 @@ object Parsers { def extension(): ExtMethods = val start = in.skipToken() val tparams = typeParamClauseOpt(ParamOwner.Def) - val extParams = paramClause(0, prefix = true) - val givenParamss = paramClauses(givenOnly = true) + val leadParamss = ListBuffer[List[ValDef]]() + var nparams = 0 + while + val extParams = paramClause(nparams, prefix = true) + leadParamss += extParams + nparams += extParams.length + isUsingClause(extParams) + do () + leadParamss ++= paramClauses(givenOnly = true) if in.token == COLON then syntaxError("no `:` expected here") in.nextToken() @@ -3579,7 +3585,7 @@ object Parsers { newLineOptWhenFollowedBy(LBRACE) if in.isNestedStart then inDefScopeBraces(extMethods()) else { syntaxError("Extension without extension methods"); Nil } - val result = atSpan(start)(ExtMethods(tparams, extParams :: givenParamss, methods)) + val result = atSpan(start)(ExtMethods(tparams, leadParamss.toList, methods)) val comment = in.getDocComment(start) if comment.isDefined then for meth <- methods do diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index dabc9d25acf1..2cc7cbd6933e 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -794,18 +794,27 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { protected def defDefToText[T >: Untyped](tree: DefDef[T]): Text = { import untpd._ + + def splitParams(paramss: List[List[ValDef]]): (List[List[ValDef]], List[List[ValDef]]) = + paramss match + case params1 :: (rest @ (_ :: _)) if tree.name.isRightAssocOperatorName => + val (leading, trailing) = splitParams(rest) + (leading, params1 :: trailing) + case _ => + val trailing = paramss + .dropWhile(isUsingClause) + .drop(1) + .dropWhile(isUsingClause) + (paramss.take(paramss.length - trailing.length), trailing) + dclTextOr(tree) { val defKeyword = modText(tree.mods, tree.symbol, keywordStr("def"), isType = false) val isExtension = tree.hasType && tree.symbol.is(ExtensionMethod) withEnclosingDef(tree) { val (prefix, vparamss) = if isExtension then - val (leadingParams, otherParamss) = (tree.vparamss: @unchecked) match - case vparams1 :: vparams2 :: rest if tree.name.isRightAssocOperatorName => - (vparams2, vparams1 :: rest) - case vparams1 :: rest => - (vparams1, rest) - (keywordStr("extension") ~~ paramsText(leadingParams) + val (leadingParamss, otherParamss) = splitParams(tree.vparamss) + (addVparamssText(keywordStr("extension "), leadingParamss) ~~ (defKeyword ~~ valDefText(nameIdText(tree))).close, otherParamss) else (defKeyword ~~ valDefText(nameIdText(tree)), tree.vparamss) diff --git a/tests/neg/rightassoc-extmethod.check b/tests/neg/rightassoc-extmethod.check new file mode 100644 index 000000000000..a1d2328ed2ff --- /dev/null +++ b/tests/neg/rightassoc-extmethod.check @@ -0,0 +1,8 @@ +-- Error: tests/neg/rightassoc-extmethod.scala:1:23 -------------------------------------------------------------------- +1 |extension (x: Int) def +: (using String): Int = x // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | right-associative extension method cannot start with using clause +-- Error: tests/neg/rightassoc-extmethod.scala:2:23 -------------------------------------------------------------------- +2 |extension (x: Int) def *: (y: Int, z: Int) = x // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | right-associative extension method must start with a single parameter diff --git a/tests/neg/rightassoc-extmethod.scala b/tests/neg/rightassoc-extmethod.scala new file mode 100644 index 000000000000..4a136ca6eac3 --- /dev/null +++ b/tests/neg/rightassoc-extmethod.scala @@ -0,0 +1,3 @@ +extension (x: Int) def +: (using String): Int = x // error +extension (x: Int) def *: (y: Int, z: Int) = x // error + diff --git a/tests/pos/i9530.scala b/tests/pos/i9530.scala new file mode 100644 index 000000000000..20aee0bf4940 --- /dev/null +++ b/tests/pos/i9530.scala @@ -0,0 +1,35 @@ +trait Scope: + type Expr + type Value + def expr(x: String): Expr + def value(e: Expr): Value + def combine(e: Expr, str: String): Expr + +extension (using s: Scope)(expr: s.Expr) + def show = expr.toString + def eval = s.value(expr) + def *: (str: String) = s.combine(expr, str) + +def f(using s: Scope)(x: s.Expr): (String, s.Value) = + (x.show, x.eval) + +given scope: Scope with + case class Expr(str: String) + type Value = Int + def expr(x: String) = Expr(x) + def value(e: Expr) = e.str.toInt + def combine(e: Expr, str: String) = Expr(e.str ++ str) + +@main def Test = + val e = scope.Expr("123") + val (s, v) = f(e) + println(s) + println(v) + val ss = e.show + println(ss) + val vv = e.eval + println(vv) + val e2 = e *: "4" + println(e2.show) + println(e2.eval) +