Skip to content

Commit

Permalink
lattice overhaul step 3: clean up implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Jan 22, 2022
1 parent 6538da0 commit 9a46952
Show file tree
Hide file tree
Showing 16 changed files with 1,022 additions and 1,010 deletions.
144 changes: 71 additions & 73 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

83 changes: 11 additions & 72 deletions base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,17 @@ include("ordering.jl")
using .Order
include("sort.jl")
using .Sort
# required by sort/sort! functions
function extrema(x::Array)
isempty(x) && throw(ArgumentError("collection must be non-empty"))
vmin = vmax = x[1]
for i in 2:length(x)
xi = x[i]
vmax = max(vmax, xi)
vmin = min(vmin, xi)
end
return vmin, vmax
end

# We don't include some.jl, but this definition is still useful.
something(x::Nothing, y...) = something(y...)
Expand Down Expand Up @@ -138,78 +149,6 @@ include("compiler/abstractinterpretation.jl")
include("compiler/typeinfer.jl")
include("compiler/optimize.jl") # TODO: break this up further + extract utilities

# required for bootstrap
# TODO: find why this is needed and remove it.
function extrema(x::Array)
isempty(x) && throw(ArgumentError("collection must be non-empty"))
vmin = vmax = x[1]
for i in 2:length(x)
xi = x[i]
vmax = max(vmax, xi)
vmin = min(vmin, xi)
end
return vmin, vmax
end

# function show(io::IO, xs::Vector)
# print(io, eltype(xs), '[')
# show_itr(io, xs)
# print(io, ']')
# end
# function show(io::IO, xs::Tuple)
# print(io, '(')
# show_itr(io, xs)
# print(io, ')')
# end
# function show_itr(io::IO, xs)
# n = length(xs)
# for i in 1:n
# show(io, xs[i])
# i == n || print(io, ", ")
# end
# end
# function show(io::IO, typ′::LatticeElement)
# function name(x)
# if isLimitedAccuracy(typ′)
# return (nameof(x), '′',)
# else
# return (nameof(x),)
# end
# end
# typ = ignorelimited(typ′)
# if isConditional(typ)
# show(io, conditional(typ))
# elseif isConst(typ)
# print(io, name(Const)..., '(', constant(typ), ')')
# elseif isPartialStruct(typ)
# print(io, name(PartialStruct)..., '(', widenconst(typ), ", [")
# n = length(partialfields(typ))
# for i in 1:n
# show(io, partialfields(typ)[i])
# i == n || print(io, ", ")
# end
# print(io, "])")
# elseif isPartialTypeVar(typ)
# print(io, name(PartialTypeVar)..., '(')
# show(io, typ.partialtypevar.tv)
# print(io, ')')
# else
# print(io, name(NativeType)..., '(', widenconst(typ), ')')
# end
# end
# function show(io::IO, typ::ConditionalInfo)
# if typ === __NULL_CONDITIONAL__
# return print(io, "__NULL_CONDITIONAL__")
# end
# print(io, nameof(Conditional), '(')
# show(io, typ.var)
# print(io, ", ")
# show(io, typ.vtype)
# print(io, ", ")
# show(io, typ.elsetype)
# print(io, ')')
# end

include("compiler/bootstrap.jl")
ccall(:jl_set_typeinf_func, Cvoid, (Any,), typeinf_ext_toplevel)

Expand Down
14 changes: 7 additions & 7 deletions base/compiler/inferenceresult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ function matching_cache_argtypes(
if slotid !== nothing
# using union-split signature, we may be able to narrow down `Conditional`
sigt = widenconst(slotid > nargs ? argtypes[slotid] : cache_argtypes[slotid])
vtype = tmeet(cnd.vtype, sigt)
elsetype = tmeet(cnd.elsetype, sigt)
vtype = cnd.vtype sigt
elsetype = cnd.elsetype sigt
if vtype === Bottom && elsetype === Bottom
# we accidentally proved this method match is impossible
# TODO bail out here immediately rather than just propagating Bottom ?
Expand Down Expand Up @@ -67,7 +67,7 @@ function matching_cache_argtypes(
else
last = nargs
end
isva_given_argtypes[nargs] = LatticeElement(tuple_tfunc(given_argtypes[last:end]))
isva_given_argtypes[nargs] = tuple_tfunc(given_argtypes[last:end])
# invalidate `Conditional` imposed on varargs
if condargs !== nothing
for (slotid, i) in condargs
Expand Down Expand Up @@ -166,13 +166,13 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
elseif isconstType(atyp)
atyp = Const(atyp.parameters[1])
else
atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes))
atyp = NativeType(elim_free_typevars(rewrap_unionall(atyp, specTypes)))
end
i == n && (lastatype = atyp)
cache_argtypes[i] = LatticeElement(atyp)
cache_argtypes[i] = atyp
end
for i = (tail_index + 1):nargs
cache_argtypes[i] = LatticeElement(lastatype)
cache_argtypes[i] = lastatype
end
else
@assert nargs == 0 "invalid specialization of method" # wrong number of arguments
Expand Down Expand Up @@ -219,7 +219,7 @@ function cache_lookup(linfo::MethodInstance, given_argtypes::Argtypes, cache::Ve
end
end
if method.isva && cache_match
cache_match = is_argtype_match(LatticeElement(tuple_tfunc(anymap(unwraptype, given_argtypes[(nargs + 1):end]))),
cache_match = is_argtype_match(tuple_tfunc(given_argtypes[(nargs + 1):end]),
cache_argtypes[end],
cache_overridden_by_const[end])
end
Expand Down
11 changes: 5 additions & 6 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ function stmt_effect_free(@nospecialize(stmt), @nospecialize(rt), src::Union{IRC
rt ===&& return false
return _builtin_nothrow(f, LatticeElement[argextype(args[i], src) for i = 2:length(args)], rt)
elseif head === :new
typ = unwraptype(argextype(args[1], src))
typ = argextype(args[1], src)
# `Expr(:new)` of unknown type could raise arbitrary TypeError.
typ, isexact = instanceof_tfunc(typ)
isexact || return false
Expand All @@ -240,7 +240,7 @@ function stmt_effect_free(@nospecialize(stmt), @nospecialize(rt), src::Union{IRC
return foreigncall_effect_free(stmt, src)
elseif head === :new_opaque_closure
length(args) < 5 && return false
typ = unwraptype(argextype(args[1], src))
typ = argextype(args[1], src)
typ, isexact = instanceof_tfunc(typ)
isexact || return false
typ Tuple || return false
Expand Down Expand Up @@ -350,7 +350,7 @@ function argextype(
if x.head === :static_parameter
return sptypes[x.args[1]::Int]
elseif x.head === :boundscheck
return NativeType(Bool)
return LBool
elseif x.head === :copyast
return argextype(x.args[1], src, sptypes, slottypes)
end
Expand Down Expand Up @@ -543,7 +543,7 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
# insert a side-effect instruction before the current instruction in the same basic block
insert!(code, idx, Expr(:code_coverage_effect))
insert!(codelocs, idx, codeloc)
insert!(ssavaluetypes, idx, NativeType(Nothing))
insert!(ssavaluetypes, idx, LNothing)
insert!(stmtinfo, idx, nothing)
insert!(ssaflags, idx, IR_FLAG_NULL)
changemap[oldidx] += 1
Expand Down Expand Up @@ -631,9 +631,8 @@ intrinsic_effect_free_if_nothrow(f) = f === Intrinsics.pointerref ||
# saturating sum (inputs are nonnegative), prevents overflow with typemax(Int) below
plus_saturate(x::Int, y::Int) = max(x, y, x+y)

# TODO (lattice overhaul) T::LatticeElement
# known return type
isknowntype(@nospecialize T) = (T === ⊥) || isConst(T) || isconcretetype(widenconst(T))
isknowntype(T::LatticeElement) = (T === ⊥) || isConst(T) || isconcretetype(widenconst(T))

function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptypes::Argtypes,
union_penalties::Bool, params::OptimizationParams, error_path::Bool = false)
Expand Down
50 changes: 25 additions & 25 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
push!(linetable, newentry)
end
if coverage && spec.ir.stmts[1][:line] + linetable_offset != topline
insert_node_here!(compact, NewInstruction(Expr(:code_coverage_effect), Nothing, topline))
insert_node_here!(compact, NewInstruction(Expr(:code_coverage_effect), LNothing, topline))
end
if def.isva
nargs_def = Int(def.nargs::Int32)
Expand Down Expand Up @@ -400,7 +400,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
inline_compact.result[idx′][:type] =
argextype(val, isa(val, Expr) ? compact : inline_compact)
insert_node_here!(inline_compact, NewInstruction(GotoNode(post_bb_id),
Any, compact.result[idx′][:line]),
, compact.result[idx′][:line]),
true)
push!(pn.values, SSAValue(idx′))
else
Expand Down Expand Up @@ -443,11 +443,11 @@ function fix_va_argexprs!(compact::IncrementalCompact,
argexprs::Vector{Any}, nargs_def::Int, line_idx::Int32)
newargexprs = argexprs[1:(nargs_def-1)]
tuple_call = Expr(:call, TOP_TUPLE)
tuple_typs = Any[]
tuple_typs = LatticeElement[]
for i in nargs_def:length(argexprs)
arg = argexprs[i]
push!(tuple_call.args, arg)
push!(tuple_typs, unwraptype(argextype(arg, compact)))
push!(tuple_typs, argextype(arg, compact))
end
tuple_typ = tuple_tfunc(tuple_typs)
push!(newargexprs, insert_node_here!(compact, NewInstruction(tuple_call, tuple_typ, line_idx)))
Expand Down Expand Up @@ -480,15 +480,15 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
a <: m && continue
# Generate isa check
isa_expr = Expr(:call, isa, argexprs[i], m)
ssa = insert_node_here!(compact, NewInstruction(isa_expr, Bool, line))
ssa = insert_node_here!(compact, NewInstruction(isa_expr, LBool, line))
if cond === true
cond = ssa
else
and_expr = Expr(:call, and_int, cond, ssa)
cond = insert_node_here!(compact, NewInstruction(and_expr, Bool, line))
cond = insert_node_here!(compact, NewInstruction(and_expr, LBool, line))
end
end
insert_node_here!(compact, NewInstruction(GotoIfNot(cond, next_cond_bb), Union{}, line))
insert_node_here!(compact, NewInstruction(GotoIfNot(cond, next_cond_bb), , line))
bb = next_cond_bb - 1
finish_current_bb!(compact, 0)
argexprs′ = argexprs
Expand All @@ -500,7 +500,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
a, m = aparams[i], mparams[i]
if !(a <: m)
argexprs′[i] = insert_node_here!(compact,
NewInstruction(PiNode(argex, m), m, line))
NewInstruction(PiNode(argex, m), NativeType(m), line))
end
end
end
Expand All @@ -517,25 +517,25 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
push!(pn.edges, bb)
push!(pn.values, val)
insert_node_here!(compact,
NewInstruction(GotoNode(join_bb), Union{}, line))
NewInstruction(GotoNode(join_bb), , line))
else
insert_node_here!(compact,
NewInstruction(ReturnNode(), Union{}, line))
NewInstruction(ReturnNode(), , line))
end
finish_current_bb!(compact, 0)
end
bb += 1
# We're now in the fall through block, decide what to do
if fully_covered
e = Expr(:call, GlobalRef(Core, :throw), FATAL_TYPE_BOUND_ERROR)
insert_node_here!(compact, NewInstruction(e, Union{}, line))
insert_node_here!(compact, NewInstruction(ReturnNode(), Union{}, line))
insert_node_here!(compact, NewInstruction(e, , line))
insert_node_here!(compact, NewInstruction(ReturnNode(), , line))
finish_current_bb!(compact, 0)
else
ssa = insert_node_here!(compact, NewInstruction(stmt, typ, line))
push!(pn.edges, bb)
push!(pn.values, ssa)
insert_node_here!(compact, NewInstruction(GotoNode(join_bb), Union{}, line))
insert_node_here!(compact, NewInstruction(GotoNode(join_bb), , line))
finish_current_bb!(compact, 0)
end

Expand Down Expand Up @@ -642,7 +642,7 @@ function rewrite_apply_exprargs!(
if thisarginfo === nothing
if isPartialStruct(def_type)
# def_type.typ <: Tuple is assumed
def_argtypes = LatticeElement[LatticeElement(t) for t in partialfields(def_type)]
def_argtypes = partialfields(def_type)
else
def_argtypes = LatticeElement[]
if isConst(def_type) # && isa(constant(def_type), Union{Tuple, SimpleVector}) is implied
Expand Down Expand Up @@ -699,15 +699,15 @@ function rewrite_apply_exprargs!(
new_sig, istate, todo)
end
if i != length(thisarginfo.each)
valT = getfield_tfunc(unwraptype(call.rt), Const(1))
valT = getfield_tfunc(call.rt, Const(1))
val_extracted = insert_node!(ir, idx, NewInstruction(
Expr(:call, GlobalRef(Core, :getfield), state1, 1),
valT))
push!(new_argexprs, val_extracted)
push!(new_argtypes, LatticeElement(valT))
push!(new_argtypes, valT)
state_extracted = insert_node!(ir, idx, NewInstruction(
Expr(:call, GlobalRef(Core, :getfield), state1, 2),
getfield_tfunc(unwraptype(call.rt), Const(2))))
getfield_tfunc(call.rt, Const(2))))
state = Core.svec(state_extracted)
end
end
Expand Down Expand Up @@ -892,19 +892,19 @@ function is_valid_type_for_apply_rewrite(typ::LatticeElement, params::Optimizati
end

function inline_splatnew!(ir::IRCode, idx::Int, stmt::Expr, rt::LatticeElement)
nf = nfields_tfunc(unwraptype(rt))
nf = nfields_tfunc(rt)
if isConst(nf)
eargs = stmt.args
tup = eargs[2]
tt = argextype(tup, ir)
tnf = nfields_tfunc(unwraptype(tt))
tnf = nfields_tfunc(tt)
# TODO: hoisting this constant(tnf) === constant(nf) check into codegen
# would enable us to almost always do this transform
if isConst(tnf) && constant(tnf) === constant(nf)
n = constant(tnf)::Int
new_argexprs = Any[eargs[1]]
for j = 1:n
atype = getfield_tfunc(unwraptype(tt), Const(j))
atype = getfield_tfunc(tt, Const(j))
new_call = Expr(:call, Core.getfield, tup, j)
new_argexpr = insert_node!(ir, idx, NewInstruction(new_call, atype))
push!(new_argexprs, new_argexpr)
Expand Down Expand Up @@ -1036,14 +1036,14 @@ end
function narrow_opaque_closure!(ir::IRCode, stmt::Expr, @nospecialize(info), state::InliningState)
if isa(info, OpaqueClosureCreateInfo)
lbt = argextype(stmt.args[3], ir)
lb, exact = instanceof_tfunc(unwraptype(lbt))
lb, exact = instanceof_tfunc(lbt)
exact || return
ubt = argextype(stmt.args[4], ir)
ub, exact = instanceof_tfunc(unwraptype(ubt))
ub, exact = instanceof_tfunc(ubt)
exact || return
# Narrow opaque closure type
newT = widenconst(tmeet(tmerge(lb, info.unspec.rt), ub))
if newT != ub
newT = widenconst(tmerge(lb, info.unspec.rt) ub)
if newT !== ub
# N.B.: Narrowing the ub requires a backdge on the mi whose type
# information we're using, since a change in that function may
# invalidate ub result.
Expand Down Expand Up @@ -1434,7 +1434,7 @@ function late_inline_special_case!(
return SomeCase(quoted(constant(type)))
end
cmp_call = Expr(:call, GlobalRef(Core, :(===)), stmt.args[2], stmt.args[3])
cmp_call_ssa = insert_node!(ir, idx, effect_free(NewInstruction(cmp_call, Bool)))
cmp_call_ssa = insert_node!(ir, idx, effect_free(NewInstruction(cmp_call, LBool)))
not_call = Expr(:call, GlobalRef(Core.Intrinsics, :not_int), cmp_call_ssa)
return SomeCase(not_call)
elseif isinlining && length(argtypes) == 3 && istopfunction(f, :(>:))
Expand Down
8 changes: 4 additions & 4 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,17 +166,17 @@ struct NewInstruction
# Don't bother redoing so on insertion.
effect_free_computed::Bool

function NewInstruction(@nospecialize(stmt), @nospecialize(type), @nospecialize(info),
function NewInstruction(@nospecialize(stmt), type::LatticeElement, @nospecialize(info),
line::Union{Int32, Nothing}, flag::UInt8, effect_free_computed::Bool)
if isa(type, Type)
type = NativeType(type)
end
return new(stmt, type, info, line, flag, effect_free_computed)
end
end
NewInstruction(@nospecialize(stmt), @nospecialize(type)) =
NewInstruction(@nospecialize(stmt), type::LatticeElement) =
NewInstruction(stmt, type, nothing)
NewInstruction(@nospecialize(stmt), @nospecialize(type), line::Union{Nothing, Int32}) =
NewInstruction(@nospecialize(stmt), type::LatticeElement, line::Union{Nothing, Int32}) =
NewInstruction(stmt, type, nothing, line, IR_FLAG_NULL, false)

effect_free(inst::NewInstruction) =
Expand Down Expand Up @@ -1163,7 +1163,7 @@ function finish_current_bb!(compact::IncrementalCompact, active_bb, old_result_i
if unreachable
node[:inst], node[:type], node[:line] = ReturnNode(), ⊥, 0
else
node[:inst], node[:type], node[:line] = nothing, NativeType(Nothing), 0
node[:inst], node[:type], node[:line] = nothing, LNothing, 0
end
compact.result_idx = old_result_idx + 1
elseif compact.cfg_transforms_enabled && compact.result_idx - 1 == first(bb.stmts)
Expand Down
Loading

0 comments on commit 9a46952

Please sign in to comment.