From 8b4ba6dcafaaefec912341398174606e8ac0d5d2 Mon Sep 17 00:00:00 2001 From: Quinton Miller Date: Thu, 18 Apr 2024 17:38:16 +0800 Subject: [PATCH 1/3] Drop parentheses around `->` inside certain comma-separated lists --- spec/compiler/parser/to_s_spec.cr | 14 ++- src/compiler/crystal/syntax/to_s.cr | 132 +++++++++++++++++----------- 2 files changed, 94 insertions(+), 52 deletions(-) diff --git a/spec/compiler/parser/to_s_spec.cr b/spec/compiler/parser/to_s_spec.cr index 63da166aba45..f1af8d9cb586 100644 --- a/spec/compiler/parser/to_s_spec.cr +++ b/spec/compiler/parser/to_s_spec.cr @@ -108,8 +108,18 @@ describe "ASTNode#to_s" do expect_to_s "def foo(x, @[Foo] **args)\nend" expect_to_s "def foo(x, **args, &block)\nend" expect_to_s "def foo(@[Foo] x, @[Bar] **args, @[Baz] &block)\nend" - expect_to_s "def foo(x, **args, &block : (_ -> _))\nend" - expect_to_s "def foo(& : (->))\nend" + + # 14216 + expect_to_s "def foo(x, **args, &block : _ -> _)\nend" + expect_to_s "def foo(x, **args, &block : (_ -> _))\nend", "def foo(x, **args, &block : _ -> _)\nend" + expect_to_s "def foo(& : ->)\nend" + expect_to_s "def foo(& : (->))\nend", "def foo(& : ->)\nend" + expect_to_s "def foo(x : (T -> U) -> V, *args : (T -> U) -> V, y : (T -> U) -> V, **opts : (T -> U) -> V, & : (T -> U) -> V) : ((T -> U) -> V)\nend" + expect_to_s "foo(x : (T -> U) -> V, W)" + expect_to_s "foo[x : (T -> U) -> V, W]" + expect_to_s "foo[x : (T -> U) -> V, W] = 1" + expect_to_s "lib LibFoo\n fun foo(x : (T -> U) -> V, W) : ((T -> U) -> V)\nend" + expect_to_s "macro foo(@[Foo] id)\nend" expect_to_s "macro foo(**args)\nend" expect_to_s "macro foo(@[Foo] **args)\nend" diff --git a/src/compiler/crystal/syntax/to_s.cr b/src/compiler/crystal/syntax/to_s.cr index 7735cddb3951..58c553ea18e6 100644 --- a/src/compiler/crystal/syntax/to_s.cr +++ b/src/compiler/crystal/syntax/to_s.cr @@ -18,6 +18,13 @@ module Crystal @macro_expansion_pragmas : Hash(Int32, Array(Lexer::LocPragma))? @current_arg_type : DefArgType = :none + # Inside a comma-separated list of parameters or args, this becomes true and + # the outermost pair of parentheses are removed from type restrictions that + # are `ProcNotation` nodes, so `foo(x : (T, U -> V), W)` becomes + # `foo(x : T, U -> V, W)`. This is used by defs, lib funs, and calls to deal + # with the parsing rules for `->`. See #11966 and #14216 for more details. + getter? drop_parens_for_proc_notation = false + private enum DefArgType NONE SPLAT @@ -436,24 +443,26 @@ module Crystal private def visit_args(node, exclude_last = false) printed_arg = false - node.args.each_with_index do |arg, i| - break if exclude_last && i == node.args.size - 1 + drop_parens_for_proc_notation do + node.args.each_with_index do |arg, i| + break if exclude_last && i == node.args.size - 1 - @str << ", " if printed_arg - arg.accept self - printed_arg = true - end - if named_args = node.named_args - named_args.each do |named_arg| @str << ", " if printed_arg - named_arg.accept self + arg.accept self printed_arg = true end - end - if block_arg = node.block_arg - @str << ", " if printed_arg - @str << '&' - block_arg.accept self + if named_args = node.named_args + named_args.each do |named_arg| + @str << ", " if printed_arg + named_arg.accept self + printed_arg = true + end + end + if block_arg = node.block_arg + @str << ", " if printed_arg + @str << '&' + block_arg.accept self + end end end @@ -632,25 +641,27 @@ module Crystal if node.args.size > 0 || node.block_arity || node.double_splat @str << '(' printed_arg = false - node.args.each_with_index do |arg, i| - @str << ", " if printed_arg - @current_arg_type = :splat if node.splat_index == i - arg.accept self - printed_arg = true - end - if double_splat = node.double_splat - @current_arg_type = :double_splat - @str << ", " if printed_arg - double_splat.accept self - printed_arg = true - end - if block_arg = node.block_arg - @current_arg_type = :block_arg - @str << ", " if printed_arg - block_arg.accept self - elsif node.block_arity - @str << ", " if printed_arg - @str << '&' + drop_parens_for_proc_notation do + node.args.each_with_index do |arg, i| + @str << ", " if printed_arg + @current_arg_type = :splat if node.splat_index == i + arg.accept self + printed_arg = true + end + if double_splat = node.double_splat + @current_arg_type = :double_splat + @str << ", " if printed_arg + double_splat.accept self + printed_arg = true + end + if block_arg = node.block_arg + @current_arg_type = :block_arg + @str << ", " if printed_arg + block_arg.accept self + elsif node.block_arity + @str << ", " if printed_arg + @str << '&' + end end @str << ')' end @@ -680,6 +691,8 @@ module Crystal if node.args.size > 0 || node.block_arg || node.double_splat @str << '(' printed_arg = false + # NOTE: `drop_parens_for_proc_notation` needed here if macros support + # restrictions node.args.each_with_index do |arg, i| @str << ", " if printed_arg @current_arg_type = :splat if i == node.splat_index @@ -830,17 +843,24 @@ module Crystal end def visit(node : ProcNotation) - @str << '(' - if inputs = node.inputs - inputs.join(@str, ", ", &.accept self) - @str << ' ' - end - @str << "->" - if output = node.output - @str << ' ' - output.accept self + @str << '(' unless drop_parens_for_proc_notation? + + # only drop the outermost pair of parentheses; this produces + # `foo(x : (T -> U) -> V, W)`, not + # `foo(x : ((T -> U) -> V), W)` nor `foo(x : T -> U -> V, W)` + drop_parens_for_proc_notation(false) do + if inputs = node.inputs + inputs.join(@str, ", ", &.accept self) + @str << ' ' + end + @str << "->" + if output = node.output + @str << ' ' + output.accept self + end end - @str << ')' + + @str << ')' unless drop_parens_for_proc_notation? false end @@ -1143,14 +1163,16 @@ module Crystal end if node.args.size > 0 @str << '(' - node.args.join(@str, ", ") do |arg| - if arg_name = arg.name.presence - @str << arg_name << " : " + drop_parens_for_proc_notation do + node.args.join(@str, ", ") do |arg| + if arg_name = arg.name.presence + @str << arg_name << " : " + end + arg.restriction.not_nil!.accept self + end + if node.varargs? + @str << ", ..." end - arg.restriction.not_nil!.accept self - end - if node.varargs? - @str << ", ..." end @str << ')' elsif node.varargs? @@ -1575,6 +1597,16 @@ module Crystal @inside_macro = old_inside_macro end + def drop_parens_for_proc_notation(drop : Bool = true, &) + old_drop_parens_for_proc_notation = @drop_parens_for_proc_notation + @drop_parens_for_proc_notation = drop + begin + yield + ensure + @drop_parens_for_proc_notation = old_drop_parens_for_proc_notation + end + end + def to_s : String @str.to_s end From 83e33aebec51a6898c70a8158b9885d5f4975d28 Mon Sep 17 00:00:00 2001 From: Quinton Miller Date: Thu, 18 Apr 2024 19:30:20 +0800 Subject: [PATCH 2/3] fixup --- spec/compiler/semantic/restrictions_augmenter_spec.cr | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spec/compiler/semantic/restrictions_augmenter_spec.cr b/spec/compiler/semantic/restrictions_augmenter_spec.cr index 5095f8613943..2b7250658693 100644 --- a/spec/compiler/semantic/restrictions_augmenter_spec.cr +++ b/spec/compiler/semantic/restrictions_augmenter_spec.cr @@ -45,8 +45,8 @@ describe "Semantic: restrictions augmenter" do it_augments_for_ivar "Array(String)", "::Array(::String)" it_augments_for_ivar "Tuple(Int32, Char)", "::Tuple(::Int32, ::Char)" it_augments_for_ivar "NamedTuple(a: Int32, b: Char)", "::NamedTuple(a: ::Int32, b: ::Char)" - it_augments_for_ivar "Proc(Int32, Char)", "(::Int32 -> ::Char)" - it_augments_for_ivar "Proc(Int32, Nil)", "(::Int32 -> _)" + it_augments_for_ivar "Proc(Int32, Char)", "::Int32 -> ::Char" + it_augments_for_ivar "Proc(Int32, Nil)", "::Int32 -> _" it_augments_for_ivar "Pointer(Void)", "::Pointer(::Void)" it_augments_for_ivar "StaticArray(Int32, 8)", "::StaticArray(::Int32, 8)" it_augments_for_ivar "Char | Int32 | String", "::Char | ::Int32 | ::String" From 0d5cad94a673e4e456ec862fdc3ec02d8360051b Mon Sep 17 00:00:00 2001 From: Quinton Miller Date: Fri, 19 Apr 2024 00:27:33 +0800 Subject: [PATCH 3/3] ensure only outermost `->` is stripped --- spec/compiler/parser/to_s_spec.cr | 10 +++ src/compiler/crystal/syntax/to_s.cr | 102 ++++++++++++++++------------ 2 files changed, 67 insertions(+), 45 deletions(-) diff --git a/spec/compiler/parser/to_s_spec.cr b/spec/compiler/parser/to_s_spec.cr index f1af8d9cb586..d4838fc7945c 100644 --- a/spec/compiler/parser/to_s_spec.cr +++ b/spec/compiler/parser/to_s_spec.cr @@ -120,6 +120,16 @@ describe "ASTNode#to_s" do expect_to_s "foo[x : (T -> U) -> V, W] = 1" expect_to_s "lib LibFoo\n fun foo(x : (T -> U) -> V, W) : ((T -> U) -> V)\nend" + expect_to_s "lib LibFoo\n fun foo(x : (T -> U) | V)\nend" + expect_to_s "lib LibFoo\n fun foo(x : Foo((T -> U)))\nend" + expect_to_s "lib LibFoo\n fun foo(x : (T -> U).class)\nend" + expect_to_s "def foo(x : (T -> U) | V)\nend" + expect_to_s "def foo(x : Foo((T -> U)))\nend" + expect_to_s "def foo(x : (T -> U).class)\nend" + expect_to_s "foo(x : (T -> U) | V)" + expect_to_s "foo(x : Foo((T -> U)))" + expect_to_s "foo(x : (T -> U).class)" + expect_to_s "macro foo(@[Foo] id)\nend" expect_to_s "macro foo(**args)\nend" expect_to_s "macro foo(@[Foo] **args)\nend" diff --git a/src/compiler/crystal/syntax/to_s.cr b/src/compiler/crystal/syntax/to_s.cr index 58c553ea18e6..b94b3f6981f1 100644 --- a/src/compiler/crystal/syntax/to_s.cr +++ b/src/compiler/crystal/syntax/to_s.cr @@ -443,26 +443,24 @@ module Crystal private def visit_args(node, exclude_last = false) printed_arg = false - drop_parens_for_proc_notation do - node.args.each_with_index do |arg, i| - break if exclude_last && i == node.args.size - 1 + node.args.each_with_index do |arg, i| + break if exclude_last && i == node.args.size - 1 + @str << ", " if printed_arg + drop_parens_for_proc_notation(arg, &.accept(self)) + printed_arg = true + end + if named_args = node.named_args + named_args.each do |named_arg| @str << ", " if printed_arg - arg.accept self + drop_parens_for_proc_notation(named_arg, &.accept(self)) printed_arg = true end - if named_args = node.named_args - named_args.each do |named_arg| - @str << ", " if printed_arg - named_arg.accept self - printed_arg = true - end - end - if block_arg = node.block_arg - @str << ", " if printed_arg - @str << '&' - block_arg.accept self - end + end + if block_arg = node.block_arg + @str << ", " if printed_arg + @str << '&' + drop_parens_for_proc_notation(block_arg, &.accept(self)) end end @@ -641,27 +639,25 @@ module Crystal if node.args.size > 0 || node.block_arity || node.double_splat @str << '(' printed_arg = false - drop_parens_for_proc_notation do - node.args.each_with_index do |arg, i| - @str << ", " if printed_arg - @current_arg_type = :splat if node.splat_index == i - arg.accept self - printed_arg = true - end - if double_splat = node.double_splat - @current_arg_type = :double_splat - @str << ", " if printed_arg - double_splat.accept self - printed_arg = true - end - if block_arg = node.block_arg - @current_arg_type = :block_arg - @str << ", " if printed_arg - block_arg.accept self - elsif node.block_arity - @str << ", " if printed_arg - @str << '&' - end + node.args.each_with_index do |arg, i| + @str << ", " if printed_arg + @current_arg_type = :splat if node.splat_index == i + drop_parens_for_proc_notation(arg, &.accept(self)) + printed_arg = true + end + if double_splat = node.double_splat + @current_arg_type = :double_splat + @str << ", " if printed_arg + drop_parens_for_proc_notation(double_splat, &.accept(self)) + printed_arg = true + end + if block_arg = node.block_arg + @current_arg_type = :block_arg + @str << ", " if printed_arg + drop_parens_for_proc_notation(block_arg, &.accept(self)) + elsif node.block_arity + @str << ", " if printed_arg + @str << '&' end @str << ')' end @@ -1163,17 +1159,17 @@ module Crystal end if node.args.size > 0 @str << '(' - drop_parens_for_proc_notation do - node.args.join(@str, ", ") do |arg| - if arg_name = arg.name.presence - @str << arg_name << " : " - end - arg.restriction.not_nil!.accept self + node.args.join(@str, ", ") do |arg| + if arg_name = arg.name.presence + @str << arg_name << " : " end - if node.varargs? - @str << ", ..." + drop_parens_for_proc_notation(arg) do + arg.restriction.not_nil!.accept self end end + if node.varargs? + @str << ", ..." + end @str << ')' elsif node.varargs? @str << "(...)" @@ -1607,6 +1603,22 @@ module Crystal end end + def drop_parens_for_proc_notation(node : ASTNode, &) + outermost_type_is_proc_notation = + case node + when Arg + # def / fun parameters + node.restriction.is_a?(ProcNotation) + when TypeDeclaration + # call arguments + node.declared_type.is_a?(ProcNotation) + else + false + end + + drop_parens_for_proc_notation(outermost_type_is_proc_notation) { yield node } + end + def to_s : String @str.to_s end