From 39e8c966680275f3222dc59ae63f9ce835a7494b Mon Sep 17 00:00:00 2001 From: Jeff Bezanson Date: Thu, 3 Jan 2019 16:41:48 -0500 Subject: [PATCH] add `splatnew` form; like `new` but accepts a tuple This allows the core NamedTuple constructor to be defined without generated functions. With a bit of work in the front end it will also allow splatting inside `new` in general. --- base/boot.jl | 29 +----------- base/compiler/abstractinterpretation.jl | 3 ++ base/compiler/ssair/inlining.jl | 25 ++++++++++ base/compiler/ssair/ir.jl | 2 +- base/compiler/tfuncs.jl | 20 ++++---- base/compiler/validation.jl | 5 +- base/show.jl | 4 +- doc/src/devdocs/ast.md | 5 ++ src/ast.c | 2 + src/codegen.cpp | 21 +++++++++ src/datatype.c | 62 ++++++++++++++++++++----- src/interpreter.c | 10 ++++ src/julia-syntax.scm | 6 +-- src/julia.h | 4 +- src/julia_internal.h | 1 + 15 files changed, 140 insertions(+), 59 deletions(-) diff --git a/base/boot.jl b/base/boot.jl index 63c7f4b5b4821..a134c0bacf8c6 100644 --- a/base/boot.jl +++ b/base/boot.jl @@ -548,33 +548,8 @@ NamedTuple{names}(args::Tuple) where {names} = NamedTuple{names,typeof(args)}(ar using .Intrinsics: sle_int, add_int -macro generated() - return Expr(:generated) -end - -function NamedTuple{names,T}(args::T) where {names, T <: Tuple} - if @generated - N = nfields(names) - flds = Array{Any,1}(undef, N) - i = 1 - while sle_int(i, N) - arrayset(false, flds, :(getfield(args, $i)), i) - i = add_int(i, 1) - end - Expr(:new, :(NamedTuple{names,T}), flds...) - else - N = nfields(names) - NT = NamedTuple{names,T} - flds = Array{Any,1}(undef, N) - i = 1 - while sle_int(i, N) - arrayset(false, flds, getfield(args, i), i) - i = add_int(i, 1) - end - ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), NT, - ccall(:jl_array_ptr, Ptr{Cvoid}, (Any,), flds), toUInt32(N))::NT - end -end +eval(Core, :(NamedTuple{names,T}(args::T) where {names, T <: Tuple} = + $(Expr(:splatnew, :(NamedTuple{names,T}), :args)))) # constructors for built-in types diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 8486f9ded91b9..284f93641b38d 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -922,6 +922,9 @@ function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState) t = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), t, args, length(args))) end end + elseif e.head === :splatnew + t = instanceof_tfunc(abstract_eval(e.args[1], vtypes, sv))[1] + # TODO: improve elseif e.head === :& abstract_eval(e.args[1], vtypes, sv) t = Any diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index b7844ab745e60..17c09d53a7785 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -782,6 +782,31 @@ function assemble_inline_todo!(ir::IRCode, linetable::Vector{LineInfoNode}, sv:: todo = Any[] for idx in 1:length(ir.stmts) stmt = ir.stmts[idx] + + if isexpr(stmt, :splatnew) + ty = ir.types[idx] + nf = nfields_tfunc(ty) + if nf isa Const + eargs = stmt.args + tup = eargs[2] + tt = argextype(tup, ir, sv.sp) + tnf = nfields_tfunc(tt) + if tnf isa Const && tnf.val <= nf.val + n = tnf.val + new_argexprs = Any[eargs[1]] + for j = 1:n + atype = getfield_tfunc(tt, Const(j)) + new_call = Expr(:call, Core.getfield, tup, j) + new_argexpr = insert_node!(ir, idx, atype, new_call) + push!(new_argexprs, new_argexpr) + end + stmt.head = :new + stmt.args = new_argexprs + end + end + continue + end + isexpr(stmt, :call) || continue eargs = stmt.args isempty(eargs) && continue diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index 024b3500fae40..19e18a3bd5f70 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -325,7 +325,7 @@ function getindex(x::UseRef) end function is_relevant_expr(e::Expr) - return e.head in (:call, :invoke, :new, :(=), :(&), + return e.head in (:call, :invoke, :new, :splatnew, :(=), :(&), :gc_preserve_begin, :gc_preserve_end, :foreigncall, :isdefined, :copyast, :undefcheck, :throw_undef_if_not, diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index 927f2ae30e553..a25323de36a59 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -329,17 +329,17 @@ function sizeof_tfunc(@nospecialize(x),) return Int end add_tfunc(Core.sizeof, 1, 1, sizeof_tfunc, 0) -add_tfunc(nfields, 1, 1, - function (@nospecialize(x),) - isa(x, Const) && return Const(nfields(x.val)) - isa(x, Conditional) && return Const(0) - if isa(x, DataType) && !x.abstract && !(x.name === Tuple.name && isvatuple(x)) - if !(x.name === _NAMEDTUPLE_NAME && !isconcretetype(x)) - return Const(length(x.types)) - end +function nfields_tfunc(@nospecialize(x)) + isa(x, Const) && return Const(nfields(x.val)) + isa(x, Conditional) && return Const(0) + if isa(x, DataType) && !x.abstract && !(x.name === Tuple.name && isvatuple(x)) + if !(x.name === _NAMEDTUPLE_NAME && !isconcretetype(x)) + return Const(length(x.types)) end - return Int - end, 0) + end + return Int +end +add_tfunc(nfields, 1, 1, nfields_tfunc, 0) add_tfunc(Core._expr, 1, INT_INF, (@nospecialize args...)->Expr, 100) function typevar_tfunc(@nospecialize(n), @nospecialize(lb_arg), @nospecialize(ub_arg)) lb = Union{} diff --git a/base/compiler/validation.jl b/base/compiler/validation.jl index 567d06861ff1f..1fcbf624e4256 100644 --- a/base/compiler/validation.jl +++ b/base/compiler/validation.jl @@ -11,6 +11,7 @@ const VALID_EXPR_HEADS = IdDict{Any,Any}( :method => 1:4, :const => 1:1, :new => 1:typemax(Int), + :splatnew => 2:2, :return => 1:1, :unreachable => 0:0, :the_exception => 0:0, @@ -142,7 +143,7 @@ function validate_code!(errors::Vector{>:InvalidCodeError}, c::CodeInfo, is_top_ head === :inbounds || head === :foreigncall || head === :cfunction || head === :const || head === :enter || head === :leave || head == :pop_exception || head === :method || head === :global || head === :static_parameter || - head === :new || head === :thunk || head === :simdloop || + head === :new || head === :splatnew || head === :thunk || head === :simdloop || head === :throw_undef_if_not || head === :unreachable validate_val!(x) else @@ -224,7 +225,7 @@ end function is_valid_rvalue(@nospecialize(x)) is_valid_argument(x) && return true - if isa(x, Expr) && x.head in (:new, :the_exception, :isdefined, :call, :invoke, :foreigncall, :cfunction, :gc_preserve_begin, :copyast) + if isa(x, Expr) && x.head in (:new, :splatnew, :the_exception, :isdefined, :call, :invoke, :foreigncall, :cfunction, :gc_preserve_begin, :copyast) return true end return false diff --git a/base/show.jl b/base/show.jl index d1fb68186076d..8f7c4f0e2ba96 100644 --- a/base/show.jl +++ b/base/show.jl @@ -1208,8 +1208,8 @@ function show_unquoted(io::IO, ex::Expr, indent::Int, prec::Int) end # new expr - elseif head === :new - show_enclosed_list(io, "%new(", args, ", ", ")", indent) + elseif head === :new || head === :splatnew + show_enclosed_list(io, "%$head(", args, ", ", ")", indent) # other call-like expressions ("A[1,2]", "T{X,Y}", "f.(X,Y)") elseif haskey(expr_calls, head) && nargs >= 1 # :ref/:curly/:calldecl/:(.) diff --git a/doc/src/devdocs/ast.md b/doc/src/devdocs/ast.md index f81ccad54c270..2633efdabc23c 100644 --- a/doc/src/devdocs/ast.md +++ b/doc/src/devdocs/ast.md @@ -359,6 +359,11 @@ These symbols appear in the `head` field of `Expr`s in lowered form. to this, and the type is always inserted by the compiler. This is very much an internal-only feature, and does no checking. Evaluating arbitrary `new` expressions can easily segfault. + * `splatnew` + + Similar to `new`, except field values are passed as a single tuple. Works similarly to + `Base.splat(new)` if `new` were a first-class function, hence the name. + * `return` Returns its argument as the value of the enclosing function. diff --git a/src/ast.c b/src/ast.c index 385aa55c913d6..0d41f5ea5ed72 100644 --- a/src/ast.c +++ b/src/ast.c @@ -41,6 +41,7 @@ jl_sym_t *enter_sym; jl_sym_t *leave_sym; jl_sym_t *pop_exception_sym; jl_sym_t *exc_sym; jl_sym_t *error_sym; jl_sym_t *new_sym; jl_sym_t *using_sym; +jl_sym_t *splatnew_sym; jl_sym_t *const_sym; jl_sym_t *thunk_sym; jl_sym_t *abstracttype_sym; jl_sym_t *primtype_sym; jl_sym_t *structtype_sym; jl_sym_t *foreigncall_sym; @@ -325,6 +326,7 @@ void jl_init_frontend(void) leave_sym = jl_symbol("leave"); pop_exception_sym = jl_symbol("pop_exception"); new_sym = jl_symbol("new"); + splatnew_sym = jl_symbol("splatnew"); const_sym = jl_symbol("const"); global_sym = jl_symbol("global"); thunk_sym = jl_symbol("thunk"); diff --git a/src/codegen.cpp b/src/codegen.cpp index 83bfcbe9071bc..63bfe69062474 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -264,6 +264,7 @@ static Function *jltls_states_func; // important functions static Function *jlnew_func; +static Function *jlsplatnew_func; static Function *jlthrow_func; static Function *jlerror_func; static Function *jltypeerror_func; @@ -4069,6 +4070,15 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaval) // it to the inferred type. return mark_julia_type(ctx, val, true, (jl_value_t*)jl_any_type); } + else if (head == splatnew_sym) { + jl_cgval_t argv[2]; + argv[0] = emit_expr(ctx, args[0]); + argv[1] = emit_expr(ctx, args[1]); + Value *typ = boxed(ctx, argv[0]); + Value *tup = boxed(ctx, argv[1]); + Value *val = ctx.builder.CreateCall(prepare_call(jlsplatnew_func), { typ, tup }); + return mark_julia_type(ctx, val, true, (jl_value_t*)jl_any_type); + } else if (head == exc_sym) { return mark_julia_type(ctx, ctx.builder.CreateCall(prepare_call(jl_current_exception_func)), @@ -6981,6 +6991,17 @@ static void init_julia_llvm_env(Module *m) jlnew_func->addFnAttr(Thunk); add_named_global(jlnew_func, &jl_new_structv); + std::vector args_2rptrs_(0); + args_2rptrs_.push_back(T_prjlvalue); + args_2rptrs_.push_back(T_prjlvalue); + jlsplatnew_func = + Function::Create(FunctionType::get(T_prjlvalue, args_2rptrs_, false), + Function::ExternalLinkage, + "jl_new_structt", m); + add_return_attr(jlsplatnew_func, Attribute::NonNull); + jlsplatnew_func->addFnAttr(Thunk); + add_named_global(jlsplatnew_func, &jl_new_structt); + std::vector args2(0); args2.push_back(T_pint8); #ifndef _OS_WINDOWS_ diff --git a/src/datatype.c b/src/datatype.c index 8a9a883889905..1d21881e3917d 100644 --- a/src/datatype.c +++ b/src/datatype.c @@ -797,8 +797,24 @@ JL_DLLEXPORT jl_value_t *jl_new_struct(jl_datatype_t *type, ...) return jv; } -JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args, - uint32_t na) +static void init_struct_tail(jl_datatype_t *type, jl_value_t *jv, size_t na) +{ + size_t nf = jl_datatype_nfields(type); + for(size_t i=na; i < nf; i++) { + if (jl_field_isptr(type, i)) { + *(jl_value_t**)((char*)jl_data_ptr(jv)+jl_field_offset(type,i)) = NULL; + } + else { + jl_value_t *ft = jl_field_type(type, i); + if (jl_is_uniontype(ft)) { + uint8_t *psel = &((uint8_t *)jv)[jl_field_offset(type, i) + jl_field_size(type, i) - 1]; + *psel = 0; + } + } + } +} + +JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args, uint32_t na) { jl_ptls_t ptls = jl_get_ptls_states(); if (type->instance != NULL) { @@ -811,7 +827,6 @@ JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args, } if (type->layout == NULL) jl_type_error("new", (jl_value_t*)jl_datatype_type, (jl_value_t*)type); - size_t nf = jl_datatype_nfields(type); jl_value_t *jv = jl_gc_alloc(ptls, jl_datatype_size(type), type); JL_GC_PUSH1(&jv); for (size_t i = 0; i < na; i++) { @@ -820,18 +835,41 @@ JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args, jl_type_error("new", ft, args[i]); jl_set_nth_field(jv, i, args[i]); } - for(size_t i=na; i < nf; i++) { - if (jl_field_isptr(type, i)) { - *(jl_value_t**)((char*)jl_data_ptr(jv)+jl_field_offset(type,i)) = NULL; - } - else { + init_struct_tail(type, jv, na); + JL_GC_POP(); + return jv; +} + +JL_DLLEXPORT jl_value_t *jl_new_structt(jl_datatype_t *type, jl_value_t *tup) +{ + jl_ptls_t ptls = jl_get_ptls_states(); + if (!jl_is_tuple(tup)) + jl_type_error("new", (jl_value_t*)jl_tuple_type, tup); + size_t na = jl_nfields(tup); + size_t nf = jl_datatype_nfields(type); + if (na > nf) + jl_too_many_args("new", nf); + if (type->instance != NULL) { + for (size_t i = 0; i < na; i++) { jl_value_t *ft = jl_field_type(type, i); - if (jl_is_uniontype(ft)) { - uint8_t *psel = &((uint8_t *)jv)[jl_field_offset(type, i) + jl_field_size(type, i) - 1]; - *psel = 0; - } + jl_value_t *fi = jl_get_nth_field(tup, i); + if (!jl_isa(fi, ft)) + jl_type_error("new", ft, fi); } + return type->instance; + } + if (type->layout == NULL) + jl_type_error("new", (jl_value_t*)jl_datatype_type, (jl_value_t*)type); + jl_value_t *jv = jl_gc_alloc(ptls, jl_datatype_size(type), type); + JL_GC_PUSH1(&jv); + for (size_t i = 0; i < na; i++) { + jl_value_t *ft = jl_field_type(type, i); + jl_value_t *fi = jl_get_nth_field(tup, i); + if (!jl_isa(fi, ft)) + jl_type_error("new", ft, fi); + jl_set_nth_field(jv, i, fi); } + init_struct_tail(type, jv, na); JL_GC_POP(); return jv; } diff --git a/src/interpreter.c b/src/interpreter.c index ca96b8d38c728..c54caf2e3c323 100644 --- a/src/interpreter.c +++ b/src/interpreter.c @@ -467,6 +467,16 @@ SECT_INTERP static jl_value_t *eval_value(jl_value_t *e, interpreter_state *s) JL_GC_POP(); return v; } + else if (head == splatnew_sym) { + jl_value_t **argv; + JL_GC_PUSHARGS(argv, 2); + argv[0] = eval_value(args[0], s); + argv[1] = eval_value(args[1], s); + assert(jl_is_structtype(argv[0])); + jl_value_t *v = jl_new_structt((jl_datatype_t*)argv[0], argv[1]); + JL_GC_POP(); + return v; + } else if (head == static_parameter_sym) { ssize_t n = jl_unbox_long(args[0]); assert(n > 0); diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index 8c4232e484607..bee5ac2528ef2 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -3047,7 +3047,7 @@ f(x) = yt(x) (del! unused (cadr e))) ;; in all other cases there's nothing to do except assert that ;; all expression heads have been handled. - #;(assert (memq (car e) '(= method new call foreigncall cfunction |::|))))))) + #;(assert (memq (car e) '(= method new splatnew call foreigncall cfunction |::|))))))) (visit (lam:body lam)) ;; Finally, variables can be marked never-undef if they were set in the first block, ;; or are currently live, or are back in the unused set (because we've left the only @@ -3414,7 +3414,7 @@ f(x) = yt(x) (or (ssavalue? lhs) (valid-ir-argument? e) (and (symbol? lhs) (pair? e) - (memq (car e) '(new the_exception isdefined call invoke foreigncall cfunction gc_preserve_begin copyast))))) + (memq (car e) '(new splatnew the_exception isdefined call invoke foreigncall cfunction gc_preserve_begin copyast))))) (define (valid-ir-return? e) ;; returning lambda directly is needed for @generated @@ -3604,7 +3604,7 @@ f(x) = yt(x) ((and (pair? e1) (eq? (car e1) 'globalref)) (emit e1) #f) ;; keep globals for undefined-var checking (else #f))) (case (car e) - ((call new foreigncall cfunction) + ((call new splatnew foreigncall cfunction) (let* ((args (cond ((eq? (car e) 'foreigncall) ;; NOTE: 2nd to 5th arguments of ccall must be left in place diff --git a/src/julia.h b/src/julia.h index 79f343cb75465..a9040eaee70e6 100644 --- a/src/julia.h +++ b/src/julia.h @@ -1146,8 +1146,8 @@ jl_datatype_t *jl_new_abstracttype(jl_value_t *name, jl_module_t *module, // constructors JL_DLLEXPORT jl_value_t *jl_new_bits(jl_value_t *bt, void *data); JL_DLLEXPORT jl_value_t *jl_new_struct(jl_datatype_t *type, ...); -JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args, - uint32_t na); +JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args, uint32_t na); +JL_DLLEXPORT jl_value_t *jl_new_structt(jl_datatype_t *type, jl_value_t *tup); JL_DLLEXPORT jl_value_t *jl_new_struct_uninit(jl_datatype_t *type); JL_DLLEXPORT jl_method_instance_t *jl_new_method_instance_uninit(void); JL_DLLEXPORT jl_svec_t *jl_svec(size_t n, ...) JL_MAYBE_UNROOTED; diff --git a/src/julia_internal.h b/src/julia_internal.h index 86d0689f27819..0b2b3cabe903c 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -991,6 +991,7 @@ extern jl_sym_t *method_sym; extern jl_sym_t *core_sym; extern jl_sym_t *enter_sym; extern jl_sym_t *leave_sym; extern jl_sym_t *exc_sym; extern jl_sym_t *error_sym; extern jl_sym_t *new_sym; extern jl_sym_t *using_sym; +extern jl_sym_t *splatnew_sym; extern jl_sym_t *pop_exception_sym; extern jl_sym_t *const_sym; extern jl_sym_t *thunk_sym; extern jl_sym_t *abstracttype_sym; extern jl_sym_t *primtype_sym;