diff --git a/spec/compiler/parser/to_s_spec.cr b/spec/compiler/parser/to_s_spec.cr index 63da166aba45..d4838fc7945c 100644 --- a/spec/compiler/parser/to_s_spec.cr +++ b/spec/compiler/parser/to_s_spec.cr @@ -108,8 +108,28 @@ describe "ASTNode#to_s" do expect_to_s "def foo(x, @[Foo] **args)\nend" expect_to_s "def foo(x, **args, &block)\nend" expect_to_s "def foo(@[Foo] x, @[Bar] **args, @[Baz] &block)\nend" - expect_to_s "def foo(x, **args, &block : (_ -> _))\nend" - expect_to_s "def foo(& : (->))\nend" + + # 14216 + expect_to_s "def foo(x, **args, &block : _ -> _)\nend" + expect_to_s "def foo(x, **args, &block : (_ -> _))\nend", "def foo(x, **args, &block : _ -> _)\nend" + expect_to_s "def foo(& : ->)\nend" + expect_to_s "def foo(& : (->))\nend", "def foo(& : ->)\nend" + expect_to_s "def foo(x : (T -> U) -> V, *args : (T -> U) -> V, y : (T -> U) -> V, **opts : (T -> U) -> V, & : (T -> U) -> V) : ((T -> U) -> V)\nend" + expect_to_s "foo(x : (T -> U) -> V, W)" + expect_to_s "foo[x : (T -> U) -> V, W]" + expect_to_s "foo[x : (T -> U) -> V, W] = 1" + expect_to_s "lib LibFoo\n fun foo(x : (T -> U) -> V, W) : ((T -> U) -> V)\nend" + + expect_to_s "lib LibFoo\n fun foo(x : (T -> U) | V)\nend" + expect_to_s "lib LibFoo\n fun foo(x : Foo((T -> U)))\nend" + expect_to_s "lib LibFoo\n fun foo(x : (T -> U).class)\nend" + expect_to_s "def foo(x : (T -> U) | V)\nend" + expect_to_s "def foo(x : Foo((T -> U)))\nend" + expect_to_s "def foo(x : (T -> U).class)\nend" + expect_to_s "foo(x : (T -> U) | V)" + expect_to_s "foo(x : Foo((T -> U)))" + expect_to_s "foo(x : (T -> U).class)" + expect_to_s "macro foo(@[Foo] id)\nend" expect_to_s "macro foo(**args)\nend" expect_to_s "macro foo(@[Foo] **args)\nend" diff --git a/spec/compiler/semantic/restrictions_augmenter_spec.cr b/spec/compiler/semantic/restrictions_augmenter_spec.cr index 5095f8613943..2b7250658693 100644 --- a/spec/compiler/semantic/restrictions_augmenter_spec.cr +++ b/spec/compiler/semantic/restrictions_augmenter_spec.cr @@ -45,8 +45,8 @@ describe "Semantic: restrictions augmenter" do it_augments_for_ivar "Array(String)", "::Array(::String)" it_augments_for_ivar "Tuple(Int32, Char)", "::Tuple(::Int32, ::Char)" it_augments_for_ivar "NamedTuple(a: Int32, b: Char)", "::NamedTuple(a: ::Int32, b: ::Char)" - it_augments_for_ivar "Proc(Int32, Char)", "(::Int32 -> ::Char)" - it_augments_for_ivar "Proc(Int32, Nil)", "(::Int32 -> _)" + it_augments_for_ivar "Proc(Int32, Char)", "::Int32 -> ::Char" + it_augments_for_ivar "Proc(Int32, Nil)", "::Int32 -> _" it_augments_for_ivar "Pointer(Void)", "::Pointer(::Void)" it_augments_for_ivar "StaticArray(Int32, 8)", "::StaticArray(::Int32, 8)" it_augments_for_ivar "Char | Int32 | String", "::Char | ::Int32 | ::String" diff --git a/src/compiler/crystal/syntax/to_s.cr b/src/compiler/crystal/syntax/to_s.cr index 7735cddb3951..b94b3f6981f1 100644 --- a/src/compiler/crystal/syntax/to_s.cr +++ b/src/compiler/crystal/syntax/to_s.cr @@ -18,6 +18,13 @@ module Crystal @macro_expansion_pragmas : Hash(Int32, Array(Lexer::LocPragma))? @current_arg_type : DefArgType = :none + # Inside a comma-separated list of parameters or args, this becomes true and + # the outermost pair of parentheses are removed from type restrictions that + # are `ProcNotation` nodes, so `foo(x : (T, U -> V), W)` becomes + # `foo(x : T, U -> V, W)`. This is used by defs, lib funs, and calls to deal + # with the parsing rules for `->`. See #11966 and #14216 for more details. + getter? drop_parens_for_proc_notation = false + private enum DefArgType NONE SPLAT @@ -440,20 +447,20 @@ module Crystal break if exclude_last && i == node.args.size - 1 @str << ", " if printed_arg - arg.accept self + drop_parens_for_proc_notation(arg, &.accept(self)) printed_arg = true end if named_args = node.named_args named_args.each do |named_arg| @str << ", " if printed_arg - named_arg.accept self + drop_parens_for_proc_notation(named_arg, &.accept(self)) printed_arg = true end end if block_arg = node.block_arg @str << ", " if printed_arg @str << '&' - block_arg.accept self + drop_parens_for_proc_notation(block_arg, &.accept(self)) end end @@ -635,19 +642,19 @@ module Crystal node.args.each_with_index do |arg, i| @str << ", " if printed_arg @current_arg_type = :splat if node.splat_index == i - arg.accept self + drop_parens_for_proc_notation(arg, &.accept(self)) printed_arg = true end if double_splat = node.double_splat @current_arg_type = :double_splat @str << ", " if printed_arg - double_splat.accept self + drop_parens_for_proc_notation(double_splat, &.accept(self)) printed_arg = true end if block_arg = node.block_arg @current_arg_type = :block_arg @str << ", " if printed_arg - block_arg.accept self + drop_parens_for_proc_notation(block_arg, &.accept(self)) elsif node.block_arity @str << ", " if printed_arg @str << '&' @@ -680,6 +687,8 @@ module Crystal if node.args.size > 0 || node.block_arg || node.double_splat @str << '(' printed_arg = false + # NOTE: `drop_parens_for_proc_notation` needed here if macros support + # restrictions node.args.each_with_index do |arg, i| @str << ", " if printed_arg @current_arg_type = :splat if i == node.splat_index @@ -830,17 +839,24 @@ module Crystal end def visit(node : ProcNotation) - @str << '(' - if inputs = node.inputs - inputs.join(@str, ", ", &.accept self) - @str << ' ' - end - @str << "->" - if output = node.output - @str << ' ' - output.accept self + @str << '(' unless drop_parens_for_proc_notation? + + # only drop the outermost pair of parentheses; this produces + # `foo(x : (T -> U) -> V, W)`, not + # `foo(x : ((T -> U) -> V), W)` nor `foo(x : T -> U -> V, W)` + drop_parens_for_proc_notation(false) do + if inputs = node.inputs + inputs.join(@str, ", ", &.accept self) + @str << ' ' + end + @str << "->" + if output = node.output + @str << ' ' + output.accept self + end end - @str << ')' + + @str << ')' unless drop_parens_for_proc_notation? false end @@ -1147,7 +1163,9 @@ module Crystal if arg_name = arg.name.presence @str << arg_name << " : " end - arg.restriction.not_nil!.accept self + drop_parens_for_proc_notation(arg) do + arg.restriction.not_nil!.accept self + end end if node.varargs? @str << ", ..." @@ -1575,6 +1593,32 @@ module Crystal @inside_macro = old_inside_macro end + def drop_parens_for_proc_notation(drop : Bool = true, &) + old_drop_parens_for_proc_notation = @drop_parens_for_proc_notation + @drop_parens_for_proc_notation = drop + begin + yield + ensure + @drop_parens_for_proc_notation = old_drop_parens_for_proc_notation + end + end + + def drop_parens_for_proc_notation(node : ASTNode, &) + outermost_type_is_proc_notation = + case node + when Arg + # def / fun parameters + node.restriction.is_a?(ProcNotation) + when TypeDeclaration + # call arguments + node.declared_type.is_a?(ProcNotation) + else + false + end + + drop_parens_for_proc_notation(outermost_type_is_proc_notation) { yield node } + end + def to_s : String @str.to_s end