diff --git a/spec/compiler/parser/to_s_spec.cr b/spec/compiler/parser/to_s_spec.cr index 63da166aba45..f1af8d9cb586 100644 --- a/spec/compiler/parser/to_s_spec.cr +++ b/spec/compiler/parser/to_s_spec.cr @@ -108,8 +108,18 @@ describe "ASTNode#to_s" do expect_to_s "def foo(x, @[Foo] **args)\nend" expect_to_s "def foo(x, **args, &block)\nend" expect_to_s "def foo(@[Foo] x, @[Bar] **args, @[Baz] &block)\nend" - expect_to_s "def foo(x, **args, &block : (_ -> _))\nend" - expect_to_s "def foo(& : (->))\nend" + + # 14216 + expect_to_s "def foo(x, **args, &block : _ -> _)\nend" + expect_to_s "def foo(x, **args, &block : (_ -> _))\nend", "def foo(x, **args, &block : _ -> _)\nend" + expect_to_s "def foo(& : ->)\nend" + expect_to_s "def foo(& : (->))\nend", "def foo(& : ->)\nend" + expect_to_s "def foo(x : (T -> U) -> V, *args : (T -> U) -> V, y : (T -> U) -> V, **opts : (T -> U) -> V, & : (T -> U) -> V) : ((T -> U) -> V)\nend" + expect_to_s "foo(x : (T -> U) -> V, W)" + expect_to_s "foo[x : (T -> U) -> V, W]" + expect_to_s "foo[x : (T -> U) -> V, W] = 1" + expect_to_s "lib LibFoo\n fun foo(x : (T -> U) -> V, W) : ((T -> U) -> V)\nend" + expect_to_s "macro foo(@[Foo] id)\nend" expect_to_s "macro foo(**args)\nend" expect_to_s "macro foo(@[Foo] **args)\nend" diff --git a/src/compiler/crystal/syntax/to_s.cr b/src/compiler/crystal/syntax/to_s.cr index 7735cddb3951..58c553ea18e6 100644 --- a/src/compiler/crystal/syntax/to_s.cr +++ b/src/compiler/crystal/syntax/to_s.cr @@ -18,6 +18,13 @@ module Crystal @macro_expansion_pragmas : Hash(Int32, Array(Lexer::LocPragma))? @current_arg_type : DefArgType = :none + # Inside a comma-separated list of parameters or args, this becomes true and + # the outermost pair of parentheses are removed from type restrictions that + # are `ProcNotation` nodes, so `foo(x : (T, U -> V), W)` becomes + # `foo(x : T, U -> V, W)`. This is used by defs, lib funs, and calls to deal + # with the parsing rules for `->`. See #11966 and #14216 for more details. + getter? drop_parens_for_proc_notation = false + private enum DefArgType NONE SPLAT @@ -436,24 +443,26 @@ module Crystal private def visit_args(node, exclude_last = false) printed_arg = false - node.args.each_with_index do |arg, i| - break if exclude_last && i == node.args.size - 1 + drop_parens_for_proc_notation do + node.args.each_with_index do |arg, i| + break if exclude_last && i == node.args.size - 1 - @str << ", " if printed_arg - arg.accept self - printed_arg = true - end - if named_args = node.named_args - named_args.each do |named_arg| @str << ", " if printed_arg - named_arg.accept self + arg.accept self printed_arg = true end - end - if block_arg = node.block_arg - @str << ", " if printed_arg - @str << '&' - block_arg.accept self + if named_args = node.named_args + named_args.each do |named_arg| + @str << ", " if printed_arg + named_arg.accept self + printed_arg = true + end + end + if block_arg = node.block_arg + @str << ", " if printed_arg + @str << '&' + block_arg.accept self + end end end @@ -632,25 +641,27 @@ module Crystal if node.args.size > 0 || node.block_arity || node.double_splat @str << '(' printed_arg = false - node.args.each_with_index do |arg, i| - @str << ", " if printed_arg - @current_arg_type = :splat if node.splat_index == i - arg.accept self - printed_arg = true - end - if double_splat = node.double_splat - @current_arg_type = :double_splat - @str << ", " if printed_arg - double_splat.accept self - printed_arg = true - end - if block_arg = node.block_arg - @current_arg_type = :block_arg - @str << ", " if printed_arg - block_arg.accept self - elsif node.block_arity - @str << ", " if printed_arg - @str << '&' + drop_parens_for_proc_notation do + node.args.each_with_index do |arg, i| + @str << ", " if printed_arg + @current_arg_type = :splat if node.splat_index == i + arg.accept self + printed_arg = true + end + if double_splat = node.double_splat + @current_arg_type = :double_splat + @str << ", " if printed_arg + double_splat.accept self + printed_arg = true + end + if block_arg = node.block_arg + @current_arg_type = :block_arg + @str << ", " if printed_arg + block_arg.accept self + elsif node.block_arity + @str << ", " if printed_arg + @str << '&' + end end @str << ')' end @@ -680,6 +691,8 @@ module Crystal if node.args.size > 0 || node.block_arg || node.double_splat @str << '(' printed_arg = false + # NOTE: `drop_parens_for_proc_notation` needed here if macros support + # restrictions node.args.each_with_index do |arg, i| @str << ", " if printed_arg @current_arg_type = :splat if i == node.splat_index @@ -830,17 +843,24 @@ module Crystal end def visit(node : ProcNotation) - @str << '(' - if inputs = node.inputs - inputs.join(@str, ", ", &.accept self) - @str << ' ' - end - @str << "->" - if output = node.output - @str << ' ' - output.accept self + @str << '(' unless drop_parens_for_proc_notation? + + # only drop the outermost pair of parentheses; this produces + # `foo(x : (T -> U) -> V, W)`, not + # `foo(x : ((T -> U) -> V), W)` nor `foo(x : T -> U -> V, W)` + drop_parens_for_proc_notation(false) do + if inputs = node.inputs + inputs.join(@str, ", ", &.accept self) + @str << ' ' + end + @str << "->" + if output = node.output + @str << ' ' + output.accept self + end end - @str << ')' + + @str << ')' unless drop_parens_for_proc_notation? false end @@ -1143,14 +1163,16 @@ module Crystal end if node.args.size > 0 @str << '(' - node.args.join(@str, ", ") do |arg| - if arg_name = arg.name.presence - @str << arg_name << " : " + drop_parens_for_proc_notation do + node.args.join(@str, ", ") do |arg| + if arg_name = arg.name.presence + @str << arg_name << " : " + end + arg.restriction.not_nil!.accept self + end + if node.varargs? + @str << ", ..." end - arg.restriction.not_nil!.accept self - end - if node.varargs? - @str << ", ..." end @str << ')' elsif node.varargs? @@ -1575,6 +1597,16 @@ module Crystal @inside_macro = old_inside_macro end + def drop_parens_for_proc_notation(drop : Bool = true, &) + old_drop_parens_for_proc_notation = @drop_parens_for_proc_notation + @drop_parens_for_proc_notation = drop + begin + yield + ensure + @drop_parens_for_proc_notation = old_drop_parens_for_proc_notation + end + end + def to_s : String @str.to_s end