From 1508425368171c6d6b1da98d50095da5b8e7e42a Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Fri, 6 Jan 2023 17:50:26 -0500 Subject: [PATCH] Slightly generalize _compute_sparam elision (#48144) To catch a case that occurs in FuncPipelines.jl and was causing precision issues in #48066. --- base/compiler/ssair/passes.jl | 25 ++++++++++++++++++------- test/compiler/irpasses.jl | 8 ++++++++ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 24293586e0629..82ba6ddf062d7 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -796,7 +796,6 @@ end end && return nothing arg = sig.parameters[i] - isa(arg, DataType) || return nothing rarg = def.args[2 + i] isa(rarg, SSAValue) || return nothing @@ -805,6 +804,10 @@ end rarg = argdef.args[1] isa(rarg, SSAValue) || return nothing argdef = compact[rarg][:inst] + else + isa(arg, DataType) || return nothing + isType(arg) || return nothing + arg = arg.parameters[1] end is_known_call(argdef, Core.apply_type, compact) || return nothing @@ -815,15 +818,23 @@ end applyT = applyT.val isa(applyT, UnionAll) || return nothing + # N.B.: At the moment we only lift the valI == 1 case, so we + # only need to look at the outermost tvar. applyTvar = applyT.var applyTbody = applyT.body - isa(applyTbody, DataType) || return nothing - applyTbody.name == arg.name || return nothing - length(applyTbody.parameters) == length(arg.parameters) == 1 || return nothing - applyTbody.parameters[1] === applyTvar || return nothing - arg.parameters[1] === tvar || return nothing - return LiftedValue(argdef.args[3]) + arg = unwrap_unionall(arg) + applyTbody = unwrap_unionall(applyTbody) + + (isa(arg, DataType) && isa(applyTbody, DataType)) || return nothing + applyTbody.name === arg.name || return nothing + length(applyTbody.parameters) == length(arg.parameters) || return nothing + for i = 1:length(applyTbody.parameters) + if applyTbody.parameters[i] === applyTvar && arg.parameters[i] === tvar + return LiftedValue(argdef.args[3]) + end + end + return nothing end # NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining, diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index 2db01c4b85444..bc2cb0d3507f3 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -1221,3 +1221,11 @@ function a47180(b; stdout ) c end @test isa(a47180(``; stdout), Base.AbstractCmd) + +# Test that _compute_sparams can be eliminated for NamedTuple +named_tuple_elim(name::Symbol, result) = NamedTuple{(name,)}(result) +let src = code_typed1(named_tuple_elim, Tuple{Symbol, Tuple}) + @test count(iscall((src, Core._compute_sparams)), src.code) == 0 && + count(iscall((src, Core._svec_ref)), src.code) == 0 && + count(iscall(x->!isa(argextype(x, src).val, Core.Builtin)), src.code) == 0 +end