Skip to content

Commit

Permalink
Merge pull request #42017 from JuliaLang/jn/optimize-atomics-3
Browse files Browse the repository at this point in the history
atomics: optimize atomic modify operations (mostly)
  • Loading branch information
vtjnash authored Sep 2, 2021
2 parents c53669f + 85518c8 commit 1b80634
Show file tree
Hide file tree
Showing 15 changed files with 303 additions and 152 deletions.
12 changes: 6 additions & 6 deletions base/atomics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -356,13 +356,13 @@ for typ in atomictypes
rt = "$lt, $lt*"
irt = "$ilt, $ilt*"
@eval getindex(x::Atomic{$typ}) =
llvmcall($"""
GC.@preserve x llvmcall($"""
%ptr = inttoptr i$WORD_SIZE %0 to $lt*
%rv = load atomic $rt %ptr acquire, align $(gc_alignment(typ))
ret $lt %rv
""", $typ, Tuple{Ptr{$typ}}, unsafe_convert(Ptr{$typ}, x))
@eval setindex!(x::Atomic{$typ}, v::$typ) =
llvmcall($"""
GC.@preserve x llvmcall($"""
%ptr = inttoptr i$WORD_SIZE %0 to $lt*
store atomic $lt %1, $lt* %ptr release, align $(gc_alignment(typ))
ret void
Expand All @@ -371,7 +371,7 @@ for typ in atomictypes
# Note: atomic_cas! succeeded (i.e. it stored "new") if and only if the result is "cmp"
if typ <: Integer
@eval atomic_cas!(x::Atomic{$typ}, cmp::$typ, new::$typ) =
llvmcall($"""
GC.@preserve x llvmcall($"""
%ptr = inttoptr i$WORD_SIZE %0 to $lt*
%rs = cmpxchg $lt* %ptr, $lt %1, $lt %2 acq_rel acquire
%rv = extractvalue { $lt, i1 } %rs, 0
Expand All @@ -380,7 +380,7 @@ for typ in atomictypes
unsafe_convert(Ptr{$typ}, x), cmp, new)
else
@eval atomic_cas!(x::Atomic{$typ}, cmp::$typ, new::$typ) =
llvmcall($"""
GC.@preserve x llvmcall($"""
%iptr = inttoptr i$WORD_SIZE %0 to $ilt*
%icmp = bitcast $lt %1 to $ilt
%inew = bitcast $lt %2 to $ilt
Expand All @@ -403,15 +403,15 @@ for typ in atomictypes
if rmwop in arithmetic_ops && !(typ <: ArithmeticTypes) continue end
if typ <: Integer
@eval $fn(x::Atomic{$typ}, v::$typ) =
llvmcall($"""
GC.@preserve x llvmcall($"""
%ptr = inttoptr i$WORD_SIZE %0 to $lt*
%rv = atomicrmw $rmw $lt* %ptr, $lt %1 acq_rel
ret $lt %rv
""", $typ, Tuple{Ptr{$typ}, $typ}, unsafe_convert(Ptr{$typ}, x), v)
else
rmwop === :xchg || continue
@eval $fn(x::Atomic{$typ}, v::$typ) =
llvmcall($"""
GC.@preserve x llvmcall($"""
%iptr = inttoptr i$WORD_SIZE %0 to $ilt*
%ival = bitcast $lt %1 to $ilt
%irv = atomicrmw $rmw $ilt* %iptr, $ilt %ival acq_rel
Expand Down
23 changes: 13 additions & 10 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,8 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
return abstract_apply(interp, argtypes, sv, max_methods)
elseif f === invoke
return abstract_invoke(interp, argtypes, sv)
elseif f === modifyfield!
return abstract_modifyfield!(interp, argtypes, sv)
end
return CallMeta(abstract_call_builtin(interp, f, fargs, argtypes, sv, max_methods), false)
elseif f === Core.kwfunc
Expand Down Expand Up @@ -1515,7 +1517,8 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
return abstract_eval_special_value(interp, e, vtypes, sv)
end
e = e::Expr
if e.head === :call
ehead = e.head
if ehead === :call
ea = e.args
argtypes = collect_argtypes(interp, ea, vtypes, sv)
if argtypes === nothing
Expand All @@ -1525,7 +1528,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
sv.stmt_info[sv.currpc] = callinfo.info
t = callinfo.rt
end
elseif e.head === :new
elseif ehead === :new
t = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))[1]
if isconcretetype(t) && !ismutabletype(t)
args = Vector{Any}(undef, length(e.args)-1)
Expand Down Expand Up @@ -1562,7 +1565,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
end
end
end
elseif e.head === :splatnew
elseif ehead === :splatnew
t = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))[1]
if length(e.args) == 2 && isconcretetype(t) && !ismutabletype(t)
at = abstract_eval_value(interp, e.args[2], vtypes, sv)
Expand All @@ -1575,7 +1578,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
t = PartialStruct(t, at.fields::Vector{Any})
end
end
elseif e.head === :new_opaque_closure
elseif ehead === :new_opaque_closure
t = Union{}
if length(e.args) >= 5
ea = e.args
Expand All @@ -1594,29 +1597,29 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
end
end
end
elseif e.head === :foreigncall
elseif ehead === :foreigncall
abstract_eval_value(interp, e.args[1], vtypes, sv)
t = sp_type_rewrap(e.args[2], sv.linfo, true)
for i = 3:length(e.args)
if abstract_eval_value(interp, e.args[i], vtypes, sv) === Bottom
t = Bottom
end
end
elseif e.head === :cfunction
elseif ehead === :cfunction
t = e.args[1]
isa(t, Type) || (t = Any)
abstract_eval_cfunction(interp, e, vtypes, sv)
elseif e.head === :method
elseif ehead === :method
t = (length(e.args) == 1) ? Any : Nothing
elseif e.head === :copyast
elseif ehead === :copyast
t = abstract_eval_value(interp, e.args[1], vtypes, sv)
if t isa Const && t.val isa Expr
# `copyast` makes copies of Exprs
t = Expr
end
elseif e.head === :invoke
elseif ehead === :invoke || ehead === :invoke_modify
error("type inference data-flow error: tried to double infer a function")
elseif e.head === :isdefined
elseif ehead === :isdefined
sym = e.args[1]
t = Bool
if isa(sym, SlotNumber)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
return 0
end
return error_path ? params.inline_error_path_cost : params.inline_nonleaf_penalty
elseif head === :foreigncall || head === :invoke
elseif head === :foreigncall || head === :invoke || head == :invoke_modify
# Calls whose "return type" is Union{} do not actually return:
# they are errors. Since these are not part of the typical
# run-time of the function, we omit them from
Expand Down
16 changes: 16 additions & 0 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1141,6 +1141,22 @@ function process_simple!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, sta
ir.stmts[idx][:inst] = res
return nothing
end
if (sig.f === modifyfield! || sig.ft typeof(modifyfield!)) && 5 <= length(stmt.args) <= 6
let info = ir.stmts[idx][:info]
info isa MethodResultPure && (info = info.info)
info isa ConstCallInfo && (info = info.call)
info isa MethodMatchInfo || return nothing
length(info.results) == 1 || return nothing
match = info.results[1]::MethodMatch
match.fully_covers || return nothing
case = compileable_specialization(state.et, match)
case === nothing && return nothing
stmt.head = :invoke_modify
pushfirst!(stmt.args, case)
ir.stmts[idx][:inst] = stmt
end
return nothing
end

check_effect_free!(ir, stmt, calltype, idx)

Expand Down
3 changes: 2 additions & 1 deletion base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,8 @@ function getindex(x::UseRef)
end

function is_relevant_expr(e::Expr)
return e.head in (:call, :invoke, :new, :splatnew, :(=), :(&),
return e.head in (:call, :invoke, :invoke_modify,
:new, :splatnew, :(=), :(&),
:gc_preserve_begin, :gc_preserve_end,
:foreigncall, :isdefined, :copyast,
:undefcheck, :throw_undef_if_not,
Expand Down
32 changes: 31 additions & 1 deletion base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -939,10 +939,40 @@ function modifyfield!_tfunc(o, f, op, v)
@nospecialize
T = _fieldtype_tfunc(o, isconcretetype(o), f)
T === Bottom && return Bottom
# note: we could sometimes refine this to a PartialStruct if we analyzed `op(o.f, v)::T`
PT = Const(Pair)
return instanceof_tfunc(apply_type_tfunc(PT, T, T))[1]
end
function abstract_modifyfield!(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState)
nargs = length(argtypes)
if !isempty(argtypes) && isvarargtype(argtypes[nargs])
nargs - 1 <= 6 || return CallMeta(Bottom, false)
nargs > 3 || return CallMeta(Any, false)
else
5 <= nargs <= 6 || return CallMeta(Bottom, false)
end
o = unwrapva(argtypes[2])
f = unwrapva(argtypes[3])
RT = modifyfield!_tfunc(o, f, Any, Any)
info = false
if nargs >= 5 && RT !== Bottom
# we may be able to refine this to a PartialStruct by analyzing `op(o.f, v)::T`
# as well as compute the info for the method matches
op = unwrapva(argtypes[4])
v = unwrapva(argtypes[5])
TF = getfield_tfunc(o, f)
push!(sv.ssavalue_uses[sv.currpc], sv.currpc) # temporarily disable `call_result_unused` check for this call
callinfo = abstract_call(interp, nothing, Any[op, TF, v], sv, #=max_methods=# 1)
pop!(sv.ssavalue_uses[sv.currpc], sv.currpc)
TF2 = tmeet(callinfo.rt, widenconst(TF))
if TF2 === Bottom
RT = Bottom
elseif isconcretetype(RT) && has_nontrivial_const_info(TF2) # isconcrete condition required to form a PartialStruct
RT = PartialStruct(RT, Any[TF, TF2])
end
info = callinfo.info
end
return CallMeta(RT, info)
end
replacefield!_tfunc(o, f, x, v, success_order, failure_order) = (@nospecialize; replacefield!_tfunc(o, f, x, v))
replacefield!_tfunc(o, f, x, v, success_order) = (@nospecialize; replacefield!_tfunc(o, f, x, v))
function replacefield!_tfunc(o, f, x, v)
Expand Down
8 changes: 5 additions & 3 deletions base/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
const VALID_EXPR_HEADS = IdDict{Symbol,UnitRange{Int}}(
:call => 1:typemax(Int),
:invoke => 2:typemax(Int),
:invoke_modify => 3:typemax(Int),
:static_parameter => 1:1,
:(&) => 1:1,
:(=) => 2:2,
Expand Down Expand Up @@ -78,7 +79,7 @@ end

function _validate_val!(@nospecialize(x), errors, ssavals::BitSet)
if isa(x, Expr)
if x.head === :call || x.head === :invoke
if x.head === :call || x.head === :invoke || x.head === :invoke_modify
f = x.args[1]
if f isa GlobalRef && (f.name === :cglobal) && x.head === :call
# TODO: these are not yet linearized
Expand Down Expand Up @@ -138,7 +139,8 @@ function validate_code!(errors::Vector{>:InvalidCodeError}, c::CodeInfo, is_top_
end
validate_val!(lhs)
validate_val!(rhs)
elseif head === :call || head === :invoke || head === :gc_preserve_end || head === :meta ||
elseif head === :call || head === :invoke || x.head === :invoke_modify ||
head === :gc_preserve_end || head === :meta ||
head === :inbounds || head === :foreigncall || head === :cfunction ||
head === :const || head === :enter || head === :leave || head === :pop_exception ||
head === :method || head === :global || head === :static_parameter ||
Expand Down Expand Up @@ -238,7 +240,7 @@ end

function is_valid_rvalue(@nospecialize(x))
is_valid_argument(x) && return true
if isa(x, Expr) && x.head in (:new, :splatnew, :the_exception, :isdefined, :call, :invoke, :foreigncall, :cfunction, :gc_preserve_begin, :copyast)
if isa(x, Expr) && x.head in (:new, :splatnew, :the_exception, :isdefined, :call, :invoke, :invoke_modify, :foreigncall, :cfunction, :gc_preserve_begin, :copyast)
return true
end
return false
Expand Down
2 changes: 2 additions & 0 deletions src/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ extern "C" {

// head symbols for each expression type
jl_sym_t *call_sym; jl_sym_t *invoke_sym;
jl_sym_t *invoke_modify_sym;
jl_sym_t *empty_sym; jl_sym_t *top_sym;
jl_sym_t *module_sym; jl_sym_t *slot_sym;
jl_sym_t *export_sym; jl_sym_t *import_sym;
Expand Down Expand Up @@ -345,6 +346,7 @@ void jl_init_common_symbols(void)
empty_sym = jl_symbol("");
call_sym = jl_symbol("call");
invoke_sym = jl_symbol("invoke");
invoke_modify_sym = jl_symbol("invoke_modify");
foreigncall_sym = jl_symbol("foreigncall");
cfunction_sym = jl_symbol("cfunction");
quote_sym = jl_symbol("quote");
Expand Down
49 changes: 30 additions & 19 deletions src/cgutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1547,17 +1547,23 @@ static jl_cgval_t typed_store(jl_codectx_t &ctx,
Value *parent, // for the write barrier, NULL if no barrier needed
bool isboxed, AtomicOrdering Order, AtomicOrdering FailOrder, unsigned alignment,
bool needlock, bool issetfield, bool isreplacefield, bool isswapfield, bool ismodifyfield,
bool maybe_null_if_boxed, const std::string &fname)
bool maybe_null_if_boxed, const jl_cgval_t *modifyop, const std::string &fname)
{
auto newval = [&](const jl_cgval_t &lhs) {
jl_cgval_t argv[3] = { cmp, lhs, rhs };
Value *callval = emit_jlcall(ctx, jlapplygeneric_func, nullptr, argv, 3, JLCALL_F_CC);
argv[0] = mark_julia_type(ctx, callval, true, jl_any_type);
if (!jl_subtype(argv[0].typ, jltype)) {
emit_typecheck(ctx, argv[0], jltype, fname + "typed_store");
argv[0] = update_julia_type(ctx, argv[0], jltype);
}
return argv[0];
const jl_cgval_t argv[3] = { cmp, lhs, rhs };
jl_cgval_t ret;
if (modifyop) {
ret = emit_invoke(ctx, *modifyop, argv, 3, (jl_value_t*)jl_any_type);
}
else {
Value *callval = emit_jlcall(ctx, jlapplygeneric_func, nullptr, argv, 3, JLCALL_F_CC);
ret = mark_julia_type(ctx, callval, true, jl_any_type);
}
if (!jl_subtype(ret.typ, jltype)) {
emit_typecheck(ctx, ret, jltype, fname + "typed_store");
ret = update_julia_type(ctx, ret, jltype);
}
return ret;
};
assert(!needlock || parent != nullptr);
Type *elty = isboxed ? T_prjlvalue : julia_type_to_llvm(ctx, jltype);
Expand All @@ -1570,7 +1576,7 @@ static jl_cgval_t typed_store(jl_codectx_t &ctx,
else if (isreplacefield) {
Value *Success = emit_f_is(ctx, cmp, ghostValue(jltype));
Success = ctx.builder.CreateZExt(Success, T_int8);
jl_cgval_t argv[2] = {ghostValue(jltype), mark_julia_type(ctx, Success, false, jl_bool_type)};
const jl_cgval_t argv[2] = {ghostValue(jltype), mark_julia_type(ctx, Success, false, jl_bool_type)};
jl_datatype_t *rettyp = jl_apply_cmpswap_type(jltype);
return emit_new_struct(ctx, (jl_value_t*)rettyp, 2, argv);
}
Expand All @@ -1579,7 +1585,7 @@ static jl_cgval_t typed_store(jl_codectx_t &ctx,
}
else { // modifyfield
jl_cgval_t oldval = ghostValue(jltype);
jl_cgval_t argv[2] = { oldval, newval(oldval) };
const jl_cgval_t argv[2] = { oldval, newval(oldval) };
jl_datatype_t *rettyp = jl_apply_modify_type(jltype);
return emit_new_struct(ctx, (jl_value_t*)rettyp, 2, argv);
}
Expand Down Expand Up @@ -1862,7 +1868,7 @@ static jl_cgval_t typed_store(jl_codectx_t &ctx,
}
}
if (ismodifyfield) {
jl_cgval_t argv[2] = { oldval, rhs };
const jl_cgval_t argv[2] = { oldval, rhs };
jl_datatype_t *rettyp = jl_apply_modify_type(jltype);
oldval = emit_new_struct(ctx, (jl_value_t*)rettyp, 2, argv);
}
Expand All @@ -1881,7 +1887,7 @@ static jl_cgval_t typed_store(jl_codectx_t &ctx,
oldval = mark_julia_type(ctx, instr, isboxed, jltype);
if (isreplacefield) {
Success = ctx.builder.CreateZExt(Success, T_int8);
jl_cgval_t argv[2] = {oldval, mark_julia_type(ctx, Success, false, jl_bool_type)};
const jl_cgval_t argv[2] = {oldval, mark_julia_type(ctx, Success, false, jl_bool_type)};
jl_datatype_t *rettyp = jl_apply_cmpswap_type(jltype);
oldval = emit_new_struct(ctx, (jl_value_t*)rettyp, 2, argv);
}
Expand Down Expand Up @@ -3269,7 +3275,7 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
jl_cgval_t rhs, jl_cgval_t cmp,
bool checked, bool wb, AtomicOrdering Order, AtomicOrdering FailOrder,
bool needlock, bool issetfield, bool isreplacefield, bool isswapfield, bool ismodifyfield,
const std::string &fname)
const jl_cgval_t *modifyop, const std::string &fname)
{
if (!sty->name->mutabl && checked) {
std::string msg = fname + "immutable struct of type "
Expand Down Expand Up @@ -3309,9 +3315,14 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
if (ismodifyfield) {
if (needlock)
emit_lockstate_value(ctx, strct, false);
jl_cgval_t argv[3] = { cmp, oldval, rhs };
Value *callval = emit_jlcall(ctx, jlapplygeneric_func, nullptr, argv, 3, JLCALL_F_CC);
rhs = mark_julia_type(ctx, callval, true, jl_any_type);
const jl_cgval_t argv[3] = { cmp, oldval, rhs };
if (modifyop) {
rhs = emit_invoke(ctx, *modifyop, argv, 3, (jl_value_t*)jl_any_type);
}
else {
Value *callval = emit_jlcall(ctx, jlapplygeneric_func, nullptr, argv, 3, JLCALL_F_CC);
rhs = mark_julia_type(ctx, callval, true, jl_any_type);
}
if (!jl_subtype(rhs.typ, jfty)) {
emit_typecheck(ctx, rhs, jfty, fname);
rhs = update_julia_type(ctx, rhs, jfty);
Expand Down Expand Up @@ -3364,7 +3375,7 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
return typed_store(ctx, addr, NULL, rhs, cmp, jfty, strct.tbaa, nullptr,
wb ? maybe_bitcast(ctx, data_pointer(ctx, strct), T_pjlvalue) : nullptr,
isboxed, Order, FailOrder, align,
needlock, issetfield, isreplacefield, isswapfield, ismodifyfield, maybe_null, fname);
needlock, issetfield, isreplacefield, isswapfield, ismodifyfield, maybe_null, modifyop, fname);
}
}

Expand Down Expand Up @@ -3543,7 +3554,7 @@ static jl_cgval_t emit_new_struct(jl_codectx_t &ctx, jl_value_t *ty, size_t narg
else
need_wb = false;
emit_typecheck(ctx, rhs, jl_svecref(sty->types, i), "new");
emit_setfield(ctx, sty, strctinfo, i, rhs, jl_cgval_t(), false, need_wb, AtomicOrdering::NotAtomic, AtomicOrdering::NotAtomic, false, true, false, false, false, "");
emit_setfield(ctx, sty, strctinfo, i, rhs, jl_cgval_t(), false, need_wb, AtomicOrdering::NotAtomic, AtomicOrdering::NotAtomic, false, true, false, false, false, nullptr, "");
}
return strctinfo;
}
Expand Down
Loading

0 comments on commit 1b80634

Please sign in to comment.