From a8bc71a7364a95ca8871a2724b01e52c3601bbbe Mon Sep 17 00:00:00 2001 From: Jeff Bezanson Date: Mon, 7 Aug 2017 16:12:11 -0400 Subject: [PATCH] add `if @generated ... else ... end` inside functions to provide optional optimizers use meta nodes instead of `stagedfunction` expression head --- NEWS.md | 4 + base/boot.jl | 28 +++++ base/docs/Docs.jl | 2 +- base/expr.jl | 19 +++- base/linalg/bidiag.jl | 16 ++- base/methodshow.jl | 12 +- base/multidimensional.jl | 52 +++++---- base/reflection.jl | 5 +- base/sysimg.jl | 32 +++--- doc/src/manual/metaprogramming.md | 67 +++++++++-- src/ast.c | 5 +- src/ast.scm | 14 +++ src/codegen.cpp | 8 +- src/dump.c | 2 +- src/interpreter.c | 2 +- src/jltypes.c | 3 +- src/julia-syntax.scm | 147 ++++++++++++++++--------- src/julia.h | 4 +- src/julia_internal.h | 2 + src/macroexpand.scm | 5 - src/method.c | 177 +++++++++++++----------------- src/utils.scm | 13 ++- test/staged.jl | 22 ++++ 23 files changed, 398 insertions(+), 243 deletions(-) diff --git a/NEWS.md b/NEWS.md index 11a56e3b90408..8974bfcce80f0 100644 --- a/NEWS.md +++ b/NEWS.md @@ -19,6 +19,10 @@ New language features * The macro call syntax `@macroname[args]` is now available and is parsed as `@macroname([args])` ([#23519]). + * The construct `if @generated ...; else ...; end` can be used to provide both + `@generated` and normal implementations of part of a function. Surrounding code + will be common to both versions ([#23168]). + Language changes ---------------- diff --git a/base/boot.jl b/base/boot.jl index 178f77d0d1857..263c59c5c8614 100644 --- a/base/boot.jl +++ b/base/boot.jl @@ -432,4 +432,32 @@ show(@nospecialize a) = show(STDOUT, a) print(@nospecialize a...) = print(STDOUT, a...) println(@nospecialize a...) = println(STDOUT, a...) +struct GeneratedFunctionStub + gen + argnames + spnames + line + file +end + +# invoke and wrap the results of @generated +function (g::GeneratedFunctionStub)(args...) + body = g.gen(args...) + if body isa CodeInfo + return body + end + lam = Expr(:lambda, g.argnames, + Expr(Symbol("scope-block"), + Expr(:block, + LineNumberNode(g.line, g.file), + Expr(:meta, :push_loc, g.file, Symbol("@generated body")), + Expr(:return, body), + Expr(:meta, :pop_loc)))) + if g.spnames === nothing + return lam + else + return Expr(Symbol("with-static-parameters"), lam, g.spnames...) + end +end + ccall(:jl_set_istopmod, Void, (Any, Bool), Core, true) diff --git a/base/docs/Docs.jl b/base/docs/Docs.jl index 94f0336926424..ed2d8b71ee6d1 100644 --- a/base/docs/Docs.jl +++ b/base/docs/Docs.jl @@ -642,7 +642,7 @@ finddoc(λ, def) = false # Predicates and helpers for `docm` expression selection: -const FUNC_HEADS = [:function, :stagedfunction, :macro, :(=)] +const FUNC_HEADS = [:function, :macro, :(=)] const BINDING_HEADS = [:typealias, :const, :global, :(=)] # deprecation: remove `typealias` post-0.6 # For the special `:@mac` / `:(Base.@mac)` syntax for documenting a macro after definition. isquotedmacrocall(x) = diff --git a/base/expr.jl b/base/expr.jl index d3138e1a0ac23..9035fdfeed04b 100644 --- a/base/expr.jl +++ b/base/expr.jl @@ -332,10 +332,23 @@ function remove_linenums!(ex::Expr) return ex end +macro generated() + return Expr(:generated) +end + macro generated(f) - if isa(f, Expr) && (f.head === :function || is_short_function_def(f)) - f.head = :stagedfunction - return Expr(:escape, f) + if isa(f, Expr) && (f.head === :function || is_short_function_def(f)) + body = f.args[2] + lno = body.args[1] + return Expr(:escape, + Expr(f.head, f.args[1], + Expr(:block, + lno, + Expr(:if, Expr(:generated), + body, + Expr(:block, + Expr(:meta, :generated_only), + Expr(:return, nothing)))))) else error("invalid syntax; @generated must be used with a function definition") end diff --git a/base/linalg/bidiag.jl b/base/linalg/bidiag.jl index 46241dc6ef1fd..e28e781ef0a8c 100644 --- a/base/linalg/bidiag.jl +++ b/base/linalg/bidiag.jl @@ -573,12 +573,18 @@ _valuefields(::Type{<:AbstractTriangular}) = [:data] const SpecialArrays = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal,AbstractTriangular} -@generated function fillslots!(A::SpecialArrays, x) - ex = :(xT = convert(eltype(A), x)) - for field in _valuefields(A) - ex = :($ex; fill!(A.$field, xT)) +function fillslots!(A::SpecialArrays, x) + xT = convert(eltype(A), x) + if @generated + quote + $([ :(fill!(A.$field, xT)) for field in _valuefields(A) ]...) + end + else + for field in _valuefields(A) + fill!(getfield(A, field), xT) + end end - :($ex;return A) + return A end # for historical reasons: diff --git a/base/methodshow.jl b/base/methodshow.jl index ed09e3da62d9e..c31c6879d0783 100644 --- a/base/methodshow.jl +++ b/base/methodshow.jl @@ -42,6 +42,15 @@ function argtype_decl(env, n, sig::DataType, i::Int, nargs, isva::Bool) # -> (ar return s, string_with_env(env, t) end +function method_argnames(m::Method) + if !isdefined(m, :source) && isdefined(m, :generator) + return m.generator.argnames + end + argnames = Vector{Any}(m.nargs) + ccall(:jl_fill_argnames, Void, (Any, Any), m.source, argnames) + return argnames +end + function arg_decl_parts(m::Method) tv = Any[] sig = m.sig @@ -52,8 +61,7 @@ function arg_decl_parts(m::Method) file = m.file line = m.line if isdefined(m, :source) || isdefined(m, :generator) - argnames = Vector{Any}(m.nargs) - ccall(:jl_fill_argnames, Void, (Any, Any), isdefined(m, :source) ? m.source : m.generator.inferred, argnames) + argnames = method_argnames(m) show_env = ImmutableDict{Symbol, Any}() for t in tv show_env = ImmutableDict(show_env, :unionall_env => t) diff --git a/base/multidimensional.jl b/base/multidimensional.jl index 9a25e4223272b..7534cd5480a3e 100644 --- a/base/multidimensional.jl +++ b/base/multidimensional.jl @@ -549,14 +549,11 @@ end @noinline throw_checksize_error(A, sz) = throw(DimensionMismatch("output array is the wrong size; expected $sz, got $(size(A))")) ## setindex! ## -@generated function _setindex!(l::IndexStyle, A::AbstractArray, x, I::Union{Real, AbstractArray}...) - N = length(I) - quote - @_inline_meta - @boundscheck checkbounds(A, I...) - _unsafe_setindex!(l, _maybe_reshape(l, A, I...), x, I...) - A - end +function _setindex!(l::IndexStyle, A::AbstractArray, x, I::Union{Real, AbstractArray}...) + @_inline_meta + @boundscheck checkbounds(A, I...) + _unsafe_setindex!(l, _maybe_reshape(l, A, I...), x, I...) + A end _iterable(v::AbstractArray) = v @@ -916,28 +913,29 @@ function copy!(dest::AbstractArray{T,N}, src::AbstractArray{T,N}) where {T,N} dest end -@generated function copy!(dest::AbstractArray{T1,N}, - Rdest::CartesianRange{N}, - src::AbstractArray{T2,N}, - Rsrc::CartesianRange{N}) where {T1,T2,N} - quote - isempty(Rdest) && return dest - if size(Rdest) != size(Rsrc) - throw(ArgumentError("source and destination must have same size (got $(size(Rsrc)) and $(size(Rdest)))")) +function copy!(dest::AbstractArray{T1,N}, Rdest::CartesianRange{N}, + src::AbstractArray{T2,N}, Rsrc::CartesianRange{N}) where {T1,T2,N} + isempty(Rdest) && return dest + if size(Rdest) != size(Rsrc) + throw(ArgumentError("source and destination must have same size (got $(size(Rsrc)) and $(size(Rdest)))")) + end + @boundscheck checkbounds(dest, first(Rdest)) + @boundscheck checkbounds(dest, last(Rdest)) + @boundscheck checkbounds(src, first(Rsrc)) + @boundscheck checkbounds(src, last(Rsrc)) + ΔI = first(Rdest) - first(Rsrc) + if @generated + quote + @nloops $N i (n->Rsrc.indices[n]) begin + @inbounds @nref($N,dest,n->i_n+ΔI[n]) = @nref($N,src,i) + end end - @boundscheck checkbounds(dest, first(Rdest)) - @boundscheck checkbounds(dest, last(Rdest)) - @boundscheck checkbounds(src, first(Rsrc)) - @boundscheck checkbounds(src, last(Rsrc)) - ΔI = first(Rdest) - first(Rsrc) - # TODO: restore when #9080 is fixed - # for I in Rsrc - # @inbounds dest[I+ΔI] = src[I] - @nloops $N i (n->Rsrc.indices[n]) begin - @inbounds @nref($N,dest,n->i_n+ΔI[n]) = @nref($N,src,i) + else + for I in Rsrc + @inbounds dest[I + ΔI] = src[I] end - dest end + dest end """ diff --git a/base/reflection.jl b/base/reflection.jl index fae06328abab4..b82451dc3a0e4 100644 --- a/base/reflection.jl +++ b/base/reflection.jl @@ -738,7 +738,8 @@ function length(mt::MethodTable) end isempty(mt::MethodTable) = (mt.defs === nothing) -uncompressed_ast(m::Method) = uncompressed_ast(m, isdefined(m, :source) ? m.source : m.generator.inferred) +uncompressed_ast(m::Method) = isdefined(m,:source) ? uncompressed_ast(m, m.source) : + error("Method is @generated; try `code_lowered` instead.") uncompressed_ast(m::Method, s::CodeInfo) = s uncompressed_ast(m::Method, s::Array{UInt8,1}) = ccall(:jl_uncompress_ast, Any, (Any, Any), m, s)::CodeInfo uncompressed_ast(m::Core.MethodInstance) = uncompressed_ast(m.def) @@ -852,7 +853,7 @@ code_native(::IO, ::Any, ::Symbol) = error("illegal code_native call") # resolve # give a decent error message if we try to instantiate a staged function on non-leaf types function func_for_method_checked(m::Method, @nospecialize types) - if isdefined(m,:generator) && !isdefined(m,:source) && !_isleaftype(types) + if isdefined(m,:generator) && !_isleaftype(types) error("cannot call @generated function `", m, "` ", "with abstract argument types: ", types) end diff --git a/base/sysimg.jl b/base/sysimg.jl index f845cbb31196f..0f1a6d2cd514f 100644 --- a/base/sysimg.jl +++ b/base/sysimg.jl @@ -236,21 +236,27 @@ include("broadcast.jl") using .Broadcast # define the real ntuple functions -@generated function ntuple(f::F, ::Val{N}) where {F,N} - Core.typeassert(N, Int) - (N >= 0) || return :(throw($(ArgumentError(string("tuple length should be ≥0, got ", N))))) - return quote - $(Expr(:meta, :inline)) - @nexprs $N i -> t_i = f(i) - @ncall $N tuple t +@inline function ntuple(f::F, ::Val{N}) where {F,N} + N::Int + (N >= 0) || throw(ArgumentError(string("tuple length should be ≥0, got ", N))) + if @generated + quote + @nexprs $N i -> t_i = f(i) + @ncall $N tuple t + end + else + Tuple(f(i) for i = 1:N) end end -@generated function fill_to_length(t::Tuple, val, ::Val{N}) where {N} - M = length(t.parameters) - M > N && return :(throw($(ArgumentError("input tuple of length $M, requested $N")))) - return quote - $(Expr(:meta, :inline)) - (t..., $(Any[ :val for i = (M + 1):N ]...)) +@inline function fill_to_length(t::Tuple, val, ::Val{N}) where {N} + M = length(t) + M > N && throw(ArgumentError("input tuple of length $M, requested $N")) + if @generated + quote + (t..., $(fill(:val, N-length(t.parameters))...)) + end + else + (t..., fill(val, N-M)...) end end diff --git a/doc/src/manual/metaprogramming.md b/doc/src/manual/metaprogramming.md index 743881b9a5d86..ae9a3914bbc1f 100644 --- a/doc/src/manual/metaprogramming.md +++ b/doc/src/manual/metaprogramming.md @@ -1012,17 +1012,16 @@ syntax tree. A very special macro is `@generated`, which allows you to define so-called *generated functions*. These have the capability to generate specialized code depending on the types of their arguments with more flexibility and/or less code than what can be achieved with multiple dispatch. While -macros work with expressions at parsing-time and cannot access the types of their inputs, a generated +macros work with expressions at parse time and cannot access the types of their inputs, a generated function gets expanded at a time when the types of the arguments are known, but the function is not yet compiled. Instead of performing some calculation or action, a generated function declaration returns a quoted expression which then forms the body for the method corresponding to the types of the arguments. -When called, the body expression is first evaluated and compiled, then the returned expression -is compiled and run. To make this efficient, the result is often cached. And to make this inferable, -only a limited subset of the language is usable. Thus, generated functions provide a flexible -framework to move work from run-time to compile-time, at the expense of greater restrictions on -the allowable constructs. +When a generated function is called, the expression it returns is compiled and then run. +To make this efficient, the result is usually cached. And to make this inferable, only a limited +subset of the language is usable. Thus, generated functions provide a flexible way to move work from +run time to compile time, at the expense of greater restrictions on allowed constructs. When defining generated functions, there are four main differences to ordinary functions: @@ -1038,7 +1037,7 @@ When defining generated functions, there are four main differences to ordinary f This means they can only read global constants, and cannot have any side effects. In other words, they must be completely pure. Due to an implementation limitation, this also means that they currently cannot define a closure - or untyped generator. + or generator. It's easiest to illustrate this with an example. We can declare a generated function `foo` as @@ -1053,9 +1052,8 @@ foo (generic function with 1 method) Note that the body returns a quoted expression, namely `:(x * x)`, rather than just the value of `x * x`. -From the caller's perspective, they are very similar to regular functions; in fact, you don't -have to know if you're calling a regular or generated function - the syntax and result of the -call is just the same. Let's see how `foo` behaves: +From the caller's perspective, this is identical to a regular function; in fact, you don't +have to know whether you're calling a regular or generated function. Let's see how `foo` behaves: ```jldoctest generated julia> x = foo(2); # note: output is from println() statement in the body @@ -1199,7 +1197,7 @@ end and at the call site; however, *don't copy them*, for the following reasons: when, how often or how many times these side-effects will occur * the `bar` function solves a problem that is better solved with multiple dispatch - defining `bar(x) = x` and `bar(x::Integer) = x ^ 2` will do the same thing, but it is both simpler and faster. - * the `baz` function is pathologically insane + * the `baz` function is pathological Note that the set of operations that should not be attempted in a generated function is unbounded, and the runtime system can currently only detect a subset of the invalid operations. There are @@ -1317,3 +1315,50 @@ the two tuples, multiplication and addition/subtraction. All the looping is perf and we avoid looping during execution entirely. Thus, we only loop *once per type*, in this case once per `N` (except in edge cases where the function is generated more than once - see disclaimer above). + +### Optionally-generated functions + +Generated functions can achieve high efficiency at run time, but come with a compile time cost: +a new function body must be generated for every combination of concrete argument types. +Typically, Julia is able to compile "generic" versions of functions that will work for any +arguments, but with generated functions this is impossible. +This means that programs making heavy use of generated functions might be impossible to +statically compile. + +To solve this problem, the language provides syntax for writing normal, non-generated +alternative implementations of generated functions. +Applied to the `sub2ind` example above, it would look like this: + +```julia +function sub2ind_gen(dims::NTuple{N}, I::Integer...) where N + if N != length(I) + throw(ArgumentError("Number of dimensions must match number of indices.")) + end + if @generated + ex = :(I[$N] - 1) + for i = (N - 1):-1:1 + ex = :(I[$i] - 1 + dims[$i] * $ex) + end + return :($ex + 1) + else + ind = I[N] - 1 + for i = (N - 1):-1:1 + ind = I[i] - 1 + dims[i]*ind + end + return ind + 1 + end +end +``` + +Internally, this code creates two implementations of the function: a generated one where +the first block in `if @generated` is used, and a normal one where the `else` block is used. +Notice that we added an error check to the top of the function. +This code will be common to both versions, and is run-time code in both versions +(in other words, it will be quoted and returned as an expression from the generated version). +Inside the `then` part of the `if @generated` block, code has the same semantics as other +generated functions: argument names refer to types, and the code should return an expression. + +In this style of definition, the code generation feature is essentially an optional +optimization. +The compiler will use it if convenient, but otherwise may choose to use the normal +implementation instead. diff --git a/src/ast.c b/src/ast.c index 8b9117ff3b25c..c2a7f90500cd5 100644 --- a/src/ast.c +++ b/src/ast.c @@ -55,7 +55,8 @@ jl_sym_t *meta_sym; jl_sym_t *compiler_temp_sym; jl_sym_t *inert_sym; jl_sym_t *vararg_sym; jl_sym_t *unused_sym; jl_sym_t *static_parameter_sym; jl_sym_t *polly_sym; jl_sym_t *inline_sym; -jl_sym_t *propagate_inbounds_sym; +jl_sym_t *propagate_inbounds_sym; jl_sym_t *generated_sym; +jl_sym_t *generated_only_sym; jl_sym_t *isdefined_sym; jl_sym_t *nospecialize_sym; jl_sym_t *macrocall_sym; jl_sym_t *hygienicscope_sym; @@ -343,6 +344,8 @@ void jl_init_frontend(void) hygienicscope_sym = jl_symbol("hygienic-scope"); gc_preserve_begin_sym = jl_symbol("gc_preserve_begin"); gc_preserve_end_sym = jl_symbol("gc_preserve_end"); + generated_sym = jl_symbol("generated"); + generated_only_sym = jl_symbol("generated_only"); } JL_DLLEXPORT void jl_lisp_prompt(void) diff --git a/src/ast.scm b/src/ast.scm index 18d389cc46037..c8799364a93e2 100644 --- a/src/ast.scm +++ b/src/ast.scm @@ -358,6 +358,20 @@ (and (if one (length= e 3) (length> e 2)) (eq? (car e) 'meta) (eq? (cadr e) 'nospecialize))) +(define (if-generated? e) + (and (length= e 4) (eq? (car e) 'if) (equal? (cadr e) '(generated)))) + +(define (generated-meta? e) + (and (length= e 3) (eq? (car e) 'meta) (eq? (cadr e) 'generated))) + +(define (generated_only-meta? e) + (and (length= e 2) (eq? (car e) 'meta) (eq? (cadr e) 'generated_only))) + +(define (function-def? e) + (and (pair? e) (or (eq? (car e) 'function) (eq? (car e) '->) + (and (eq? (car e) '=) (length= e 3) + (eventually-call? (cadr e)))))) + ;; flatten nested expressions with the given head ;; (op (op a b) c) => (op a b c) (define (flatten-ex head e) diff --git a/src/codegen.cpp b/src/codegen.cpp index 3ae37d1b3225a..b90a53d43652f 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -1200,8 +1200,6 @@ jl_llvm_functions_t jl_compile_linfo(jl_method_instance_t **pli, jl_code_info_t li->inferred && // and there is something to delete (test this before calling jl_ast_flag_inlineable) li->inferred != jl_nothing && - // don't delete the code for the generator - li != li->def.method->generator && // don't delete inlineable code, unless it is constant (li->jlcall_api == 2 || !jl_ast_flag_inlineable((jl_array_t*)li->inferred)) && // don't delete code when generating a precompile file @@ -3849,11 +3847,10 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr) } Value *a1 = boxed(ctx, emit_expr(ctx, args[1])); Value *a2 = boxed(ctx, emit_expr(ctx, args[2])); - Value *mdargs[4] = { + Value *mdargs[3] = { /*argdata*/a1, /*code*/a2, - /*module*/literal_pointer_val(ctx, (jl_value_t*)ctx.module), - /*isstaged*/literal_pointer_val(ctx, args[3]) + /*module*/literal_pointer_val(ctx, (jl_value_t*)ctx.module) }; ctx.builder.CreateCall(prepare_call(jlmethod_func), makeArrayRef(mdargs)); return ghostValue(jl_void_type); @@ -6401,7 +6398,6 @@ static void init_julia_llvm_env(Module *m) mdargs.push_back(T_prjlvalue); mdargs.push_back(T_prjlvalue); mdargs.push_back(T_pjlvalue); - mdargs.push_back(T_pjlvalue); jlmethod_func = Function::Create(FunctionType::get(T_void, mdargs, false), Function::ExternalLinkage, diff --git a/src/dump.c b/src/dump.c index 6e223ab4d6017..e45087a82bd25 100644 --- a/src/dump.c +++ b/src/dump.c @@ -1454,7 +1454,7 @@ static jl_value_t *jl_deserialize_value_method(jl_serializer_state *s, jl_value_ m->unspecialized = (jl_method_instance_t*)jl_deserialize_value(s, (jl_value_t**)&m->unspecialized); if (m->unspecialized) jl_gc_wb(m, m->unspecialized); - m->generator = (jl_method_instance_t*)jl_deserialize_value(s, (jl_value_t**)&m->generator); + m->generator = jl_deserialize_value(s, (jl_value_t**)&m->generator); if (m->generator) jl_gc_wb(m, m->generator); m->invokes.unknown = jl_deserialize_value(s, (jl_value_t**)&m->invokes); diff --git a/src/interpreter.c b/src/interpreter.c index 7faa061f405c3..8c2e1d0e521b0 100644 --- a/src/interpreter.c +++ b/src/interpreter.c @@ -322,7 +322,7 @@ static jl_value_t *eval(jl_value_t *e, interpreter_state *s) JL_GC_PUSH2(&atypes, &meth); atypes = eval(args[1], s); meth = eval(args[2], s); - jl_method_def((jl_svec_t*)atypes, (jl_code_info_t*)meth, s->module, args[3]); + jl_method_def((jl_svec_t*)atypes, (jl_code_info_t*)meth, s->module); JL_GC_POP(); return jl_nothing; } diff --git a/src/jltypes.c b/src/jltypes.c index 320c12b969936..56209a26aedef 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -2061,7 +2061,7 @@ void jl_init_types(void) jl_simplevector_type, jl_any_type, jl_any_type, // jl_method_instance_type - jl_any_type, // jl_method_instance_type + jl_any_type, jl_array_any_type, jl_any_type, jl_int32_type, @@ -2174,7 +2174,6 @@ void jl_init_types(void) #endif jl_svecset(jl_methtable_type->types, 8, jl_int32_type); // uint32_t jl_svecset(jl_method_type->types, 10, jl_method_instance_type); - jl_svecset(jl_method_type->types, 11, jl_method_instance_type); jl_svecset(jl_method_instance_type->types, 12, jl_voidpointer_type); jl_svecset(jl_method_instance_type->types, 13, jl_voidpointer_type); jl_svecset(jl_method_instance_type->types, 14, jl_voidpointer_type); diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index 4d668dc861222..4076d9a7ac41e 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -286,9 +286,35 @@ (map (lambda (x) (replace-outer-vars x renames)) (cdr e)))))) +(define (make-generator-function name sp-names arg-names body) + (let ((arg-names (append sp-names + (map (lambda (n) + (if (eq? n '|#self#|) (gensy) n)) + arg-names)))) + (let ((body (insert-after-meta body ;; don't specialize on generator arguments + `((meta nospecialize ,@arg-names))))) + `(block + (global ,name) + (function (call ,name ,@arg-names) ,body))))) + +;; select the `then` or `else` part of `if @generated` based on flag `genpart` +(define (generated-part- x genpart) + (cond ((or (atom? x) (quoted? x) (function-def? x)) x) + ((if-generated? x) + (if genpart `($ ,(caddr x)) (cadddr x))) + (else (cons (car x) + (map (lambda (e) (generated-part- e genpart)) (cdr x)))))) + +(define (generated-version body) + `(block + ,(julia-bq-macro (generated-part- body #t)))) + +(define (non-generated-version body) + (generated-part- body #f)) + ;; construct the (method ...) expression for one primitive method definition, ;; assuming optional and keyword args are already handled -(define (method-def-expr- name sparams argl body isstaged (rett '(core Any))) +(define (method-def-expr- name sparams argl body (rett '(core Any))) (if (any kwarg? argl) ;; has optional positional args @@ -307,20 +333,39 @@ (dfl (map caddr kws))) (receive (vararg req) (separate vararg? argl) - (optional-positional-defs name sparams req opt dfl body isstaged + (optional-positional-defs name sparams req opt dfl body (append req opt vararg) rett))))) ;; no optional positional args - (let ((names (map car sparams))) - (let ((anames (llist-vars argl))) - (if (has-dups (filter (lambda (x) (not (eq? x UNUSED))) anames)) - (error "function argument names not unique")) - (if (has-dups names) - (error "function static parameter names not unique")) - (if (any (lambda (x) (and (not (eq? x UNUSED)) (memq x names))) anames) - (error "function argument and static parameter names must be distinct"))) + (let ((names (map car sparams)) + (anames (llist-vars argl))) + (if (has-dups (filter (lambda (x) (not (eq? x UNUSED))) anames)) + (error "function argument names not unique")) + (if (has-dups names) + (error "function static parameter names not unique")) + (if (any (lambda (x) (and (not (eq? x UNUSED)) (memq x names))) anames) + (error "function argument and static parameter names must be distinct")) (if (or (and name (not (sym-ref? name))) (eq? name 'true) (eq? name 'false)) (error (string "invalid function name \"" (deparse name) "\""))) - (let* ((types (llist-types argl)) + (let* ((generator (if (expr-contains-p if-generated? body (lambda (x) (not (function-def? x)))) + (let* ((gen (generated-version body)) + (nongen (non-generated-version body)) + (gname (symbol (string (gensy) "#" (current-julia-module-counter)))) + (gf (make-generator-function gname names (llist-vars argl) gen)) + (loc (function-body-lineno body))) + (set! body (insert-after-meta + nongen + `((meta generated + (new (core GeneratedFunctionStub) + ,gname + ,(cons 'list anames) + ,(if (null? sparams) + 'nothing + (cons 'list (map car sparams))) + ,(if (null? loc) 0 (cadr loc)) + (inert ,(if (null? loc) 'none (caddr loc)))))))) + (list gf)) + '())) + (types (llist-types argl)) (body (method-lambda-expr argl body rett)) ;; HACK: the typevars need to be bound to ssavalues, since this code ;; might be moved to a different scope by closure-convert. @@ -329,7 +374,7 @@ (mdef (if (null? sparams) `(method ,name (call (core svec) (call (core svec) ,@(dots->vararg types)) (call (core svec))) - ,body ,isstaged) + ,body) `(method ,name (block ,@(let loop ((n names) @@ -350,10 +395,12 @@ (replace-vars ty renames)) types))) (call (core svec) ,@temps))) - ,body ,isstaged)))) + ,body)))) (if (or (symbol? name) (globalref? name)) - `(block (method ,name) ,mdef (unnecessary ,name)) ;; return the function - mdef))))) + `(block ,@generator (method ,name) ,mdef (unnecessary ,name)) ;; return the function + (if (not (null? generator)) + `(block ,@generator ,mdef) + mdef)))))) ;; wrap expr in nested scopes assigning names to vals (define (scopenest names vals expr) @@ -365,10 +412,8 @@ (define empty-vector-any '(call (core AnyVector) 0)) -(define (keywords-method-def-expr name sparams argl body isstaged rett) +(define (keywords-method-def-expr name sparams argl body rett) (let* ((kargl (cdar argl)) ;; keyword expressions (= k v) - (annotations (map (lambda (a) `(meta nospecialize ,(arg-name (cadr (caddr a))))) - (filter nospecialize-meta? kargl))) (kargl (map (lambda (a) (if (nospecialize-meta? a) (caddr a) a)) kargl)) @@ -403,6 +448,8 @@ keynames)) ;; list of function's initial line number and meta nodes (empty if none) (prologue (extract-method-prologue body)) + (annotations (map (lambda (a) `(meta nospecialize ,(arg-name (cadr (caddr a))))) + (filter nospecialize-meta? kargl))) ;; body statements (stmts (cdr body)) (positional-sparams @@ -427,7 +474,7 @@ ,(method-def-expr- name positional-sparams (append pargl vararg) `(block - ,@prologue + ,@(without-generated prologue) ,(let (;; call mangled(vals..., [rest_kw,] pargs..., [vararg]...) (ret `(return (call ,mangled ,@(if ordered-defaults keynames vals) @@ -437,8 +484,7 @@ (list `(... ,(arg-name (car vararg))))))))) (if ordered-defaults (scopenest keynames vals ret) - ret))) - #f) + ret)))) ;; call with keyword args pre-sorted - original method code goes here ,(method-def-expr- @@ -457,7 +503,7 @@ (insert-after-meta `(block ,@stmts) annotations) - isstaged rett) + rett) ;; call with unsorted keyword args. this sorts and re-dispatches. ,(method-def-expr- @@ -539,8 +585,7 @@ ,@(if (null? restkw) '() (list rkw)) ,@(map arg-name pargl) ,@(if (null? vararg) '() - (list `(... ,(arg-name (car vararg))))))))) - #f) + (list `(... ,(arg-name (car vararg)))))))))) ;; return primary function ,(if (not (symbol? name)) '(null) name))))) @@ -553,6 +598,11 @@ (cdr body)) '())) +(define (without-generated stmts) + (filter (lambda (x) (not (or (generated-meta? x) + (generated_only-meta? x)))) + stmts)) + ;; keep only sparams used by `expr` or other sparams (define (filter-sparams expr sparams) (let loop ((filtered '()) @@ -566,8 +616,8 @@ (else (loop filtered (cdr params)))))) -(define (optional-positional-defs name sparams req opt dfl body isstaged overall-argl rett) - (let ((prologue (extract-method-prologue body))) +(define (optional-positional-defs name sparams req opt dfl body overall-argl rett) + (let ((prologue (without-generated (extract-method-prologue body)))) `(block ,@(map (lambda (n) (let* ((passed (append req (list-head opt n))) @@ -596,9 +646,9 @@ `(block ,@prologue (call ,(arg-name (car req)) ,@(map arg-name (cdr passed)) ,@vals))))) - (method-def-expr- name sp passed body #f))) + (method-def-expr- name sp passed body))) (iota (length opt))) - ,(method-def-expr- name sparams overall-argl body isstaged rett)))) + ,(method-def-expr- name sparams overall-argl body rett)))) ;; strip empty (parameters ...), normalizing `f(x;)` to `f(x)`. (define (remove-empty-parameters argl) @@ -627,14 +677,14 @@ ;; definitions without keyword arguments are passed to method-def-expr-, ;; which handles optional positional arguments by adding the needed small ;; boilerplate definitions. -(define (method-def-expr name sparams argl body isstaged rett) +(define (method-def-expr name sparams argl body rett) (let ((argl (remove-empty-parameters argl))) (if (has-parameters? argl) ;; has keywords (begin (check-kw-args (cdar argl)) - (keywords-method-def-expr name sparams argl body isstaged rett)) + (keywords-method-def-expr name sparams argl body rett)) ;; no keywords - (method-def-expr- name sparams argl body isstaged rett)))) + (method-def-expr- name sparams argl body rett)))) (define (struct-def-expr name params super fields mut) (receive @@ -763,12 +813,12 @@ ,@sig) new-params))))) -(define (ctor-def keyword name Tname params bounds sig ctor-body body wheres) +(define (ctor-def name Tname params bounds sig ctor-body body wheres) (let* ((curly? (and (pair? name) (eq? (car name) 'curly))) (curlyargs (if curly? (cddr name) '())) (name (if curly? (cadr name) name))) (cond ((not (eq? name Tname)) - `(,keyword ,(with-wheres `(call ,(if curly? + `(function ,(with-wheres `(call ,(if curly? `(curly ,name ,@curlyargs) name) ,@sig) @@ -777,7 +827,7 @@ ;; new{...} inside a non-ctor inner definition. ,(ctor-body body '()))) (wheres - `(,keyword ,(with-wheres `(call ,(if curly? + `(function ,(with-wheres `(call ,(if curly? `(curly ,name ,@curlyargs) name) ,@sig) @@ -791,7 +841,7 @@ (syntax-deprecation #f (string "inner constructor " name "(...)" (linenode-string (function-body-lineno body))) (deparse `(where (call (curly ,name ,@params) ...) ,@params)))) - `(,keyword ,sig ,(ctor-body body params))))))) + `(function ,sig ,(ctor-body body params))))))) (define (function-body-lineno body) (let ((lnos (filter (lambda (e) (and (pair? e) (eq? (car e) 'line))) @@ -818,18 +868,14 @@ (pattern-set ;; definitions without `where` (pattern-lambda (function (-$ (call name . sig) (|::| (call name . sig) _t)) body) - (ctor-def (car __) name Tname params bounds sig ctor-body body #f)) - (pattern-lambda (stagedfunction (-$ (call name . sig) (|::| (call name . sig) _t)) body) - (ctor-def (car __) name Tname params bounds sig ctor-body body #f)) + (ctor-def name Tname params bounds sig ctor-body body #f)) (pattern-lambda (= (-$ (call name . sig) (|::| (call name . sig) _t)) body) - (ctor-def 'function name Tname params bounds sig ctor-body body #f)) + (ctor-def name Tname params bounds sig ctor-body body #f)) ;; definitions with `where` (pattern-lambda (function (where (-$ (call name . sig) (|::| (call name . sig) _t)) . wheres) body) - (ctor-def (car __) name Tname params bounds sig ctor-body body wheres)) - (pattern-lambda (stagedfunction (where (-$ (call name . sig) (|::| (call name . sig) _t)) . wheres) body) - (ctor-def (car __) name Tname params bounds sig ctor-body body wheres)) + (ctor-def name Tname params bounds sig ctor-body body wheres)) (pattern-lambda (= (where (-$ (call name . sig) (|::| (call name . sig) _t)) . wheres) body) - (ctor-def 'function name Tname params bounds sig ctor-body body wheres))) + (ctor-def name Tname params bounds sig ctor-body body wheres))) ;; flatten `where`s first (pattern-replace @@ -970,7 +1016,7 @@ (loop (if isseq F (cdr F)) (cdr A) stmts (list* (if isamp `(& ,ca) ca) C) (list* g GC)))))))) -(define (expand-function-def e) ;; handle function or stagedfunction +(define (expand-function-def e) ;; handle function definitions (define (just-arglist? ex) (and (pair? ex) (or (memq (car ex) '(tuple block)) @@ -1054,7 +1100,6 @@ (where where) (else '()))) (sparams (map analyze-typevar raw-typevars)) - (isstaged (eq? (car e) 'stagedfunction)) (adj-decl (lambda (n) (if (and (decl? n) (length= n 2)) `(|::| |#self#| ,(cadr n)) n))) @@ -1083,7 +1128,7 @@ (cdr argl))) ,@raw-typevars)))) (expand-forms - (method-def-expr name sparams argl body isstaged rett)))) + (method-def-expr name sparams argl body rett)))) (else (error (string "invalid assignment location \"" (deparse name) "\"")))))) @@ -1888,7 +1933,6 @@ (define expand-table (table 'function expand-function-def - 'stagedfunction expand-function-def '-> expand-arrow 'let expand-let 'macro expand-macro-def @@ -3225,8 +3269,7 @@ f(x) = yt(x) ,@top-stmts (block ,@sp-inits (method ,name ,(cl-convert sig fname lam namemap toplevel interp) - ,(julia-bq-macro newlam) - ,(last e))))))) + ,(julia-bq-macro newlam))))))) ;; local case - lift to a new type at top level (let* ((exists (get namemap name #f)) (type-name (or exists @@ -3303,8 +3346,7 @@ f(x) = yt(x) (if iskw (caddr (lam:args lam2)) (car (lam:args lam2))) - #f closure-param-names) - ,(last e))))))) + #f closure-param-names))))))) (mk-closure ;; expression to make the closure (let* ((var-exprs (map (lambda (v) (let ((cv (assq v (cadr (lam:vinfo lam))))) @@ -3750,8 +3792,7 @@ f(x) = yt(x) (if (length> e 2) (begin (emit `(method ,(or (cadr e) 'false) ,(compile (caddr e) break-labels #t #f) - ,(linearize (cadddr e)) - ,(if (car (cddddr e)) 'true 'false))) + ,(linearize (cadddr e)))) (if value (compile '(null) break-labels value tail))) (cond (tail (emit-return e)) (value e) diff --git a/src/julia.h b/src/julia.h index 95c091affb5ee..32a7d2a62f4b8 100644 --- a/src/julia.h +++ b/src/julia.h @@ -248,7 +248,7 @@ typedef struct _jl_method_t { jl_svec_t *sparam_syms; // symbols giving static parameter names jl_value_t *source; // original code template (jl_code_info_t, but may be compressed), null for builtins struct _jl_method_instance_t *unspecialized; // unspecialized executable method instance, or null - struct _jl_method_instance_t *generator; // executable code-generating function if available + jl_value_t *generator; // executable code-generating function if available jl_array_t *roots; // pointers in generated code (shared to reduce memory), or null // cache of specializations of this method for invoke(), i.e. @@ -1055,7 +1055,7 @@ JL_DLLEXPORT jl_value_t *jl_generic_function_def(jl_sym_t *name, jl_module_t *module, jl_value_t **bp, jl_value_t *bp_owner, jl_binding_t *bnd); -JL_DLLEXPORT void jl_method_def(jl_svec_t *argdata, jl_code_info_t *f, jl_module_t *module, jl_value_t *isstaged); +JL_DLLEXPORT void jl_method_def(jl_svec_t *argdata, jl_code_info_t *f, jl_module_t *module); JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo); JL_DLLEXPORT jl_code_info_t *jl_copy_code_info(jl_code_info_t *src); JL_DLLEXPORT size_t jl_get_world_counter(void); diff --git a/src/julia_internal.h b/src/julia_internal.h index 4093a9ffe050d..afb14cd16f190 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -1000,6 +1000,8 @@ extern jl_sym_t *isdefined_sym; extern jl_sym_t *nospecialize_sym; extern jl_sym_t *boundscheck_sym; extern jl_sym_t *gc_preserve_begin_sym; extern jl_sym_t *gc_preserve_end_sym; +extern jl_sym_t *generated_sym; +extern jl_sym_t *generated_only_sym; struct _jl_sysimg_fptrs_t; diff --git a/src/macroexpand.scm b/src/macroexpand.scm index bbf72338d2a98..79c157f8eb581 100644 --- a/src/macroexpand.scm +++ b/src/macroexpand.scm @@ -400,11 +400,6 @@ (apply append (map decl-vars* (cdr e))) (list (decl-var* e)))) -(define (function-def? e) - (and (pair? e) (or (eq? (car e) 'function) (eq? (car e) '->) - (and (eq? (car e) '=) (length= e 3) - (eventually-call? (cadr e)))))) - ;; count hygienic / escape pairs ;; and fold together a list resulting from applying the function to ;; any block at the same hygienic scope diff --git a/src/method.c b/src/method.c index 40c027254743d..a984cb1ad9645 100644 --- a/src/method.c +++ b/src/method.c @@ -247,24 +247,23 @@ jl_code_info_t *jl_new_code_info_from_ast(jl_expr_t *ast) } // invoke (compiling if necessary) the jlcall function pointer for a method template -STATIC_INLINE jl_value_t *jl_call_staged(jl_svec_t *sparam_vals, jl_method_instance_t *generator, +STATIC_INLINE jl_value_t *jl_call_staged(jl_method_t *def, jl_value_t *generator, jl_svec_t *sparam_vals, jl_value_t **args, uint32_t nargs) { - jl_generic_fptr_t fptr; - fptr.fptr = generator->fptr; - fptr.jlcall_api = generator->jlcall_api; - if (__unlikely(fptr.fptr == NULL || fptr.jlcall_api == 0)) { - size_t world = generator->def.method->min_world; - const char *F = jl_compile_linfo(&generator, (jl_code_info_t*)generator->inferred, world, &jl_default_cgparams).functionObject; - fptr = jl_generate_fptr(generator, F, world); + size_t n_sparams = jl_svec_len(sparam_vals); + jl_value_t **gargs; + size_t totargs = 1 + n_sparams + nargs + def->isva; + JL_GC_PUSHARGS(gargs, totargs); + gargs[0] = generator; + memcpy(&gargs[1], jl_svec_data(sparam_vals), n_sparams * sizeof(void*)); + memcpy(&gargs[1 + n_sparams], args, nargs * sizeof(void*)); + if (def->isva) { + gargs[totargs-1] = jl_f_tuple(NULL, &gargs[1 + n_sparams + def->nargs - 1], nargs - (def->nargs - 1)); + gargs[1 + n_sparams + def->nargs - 1] = gargs[totargs - 1]; } - assert(jl_svec_len(generator->def.method->sparam_syms) == jl_svec_len(sparam_vals)); - if (fptr.jlcall_api == 1) - return fptr.fptr1(args[0], &args[1], nargs-1); - else if (fptr.jlcall_api == 3) - return fptr.fptr3(sparam_vals, args[0], &args[1], nargs-1); - else - abort(); // shouldn't have inferred any other calling convention + jl_value_t *code = jl_apply(gargs, 1 + n_sparams + def->nargs); + JL_GC_POP(); + return code; } // return a newly allocated CodeInfo for the function signature @@ -273,71 +272,34 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo) { JL_TIMING(STAGED_FUNCTION); jl_tupletype_t *tt = (jl_tupletype_t*)linfo->specTypes; - jl_svec_t *env = linfo->sparam_vals; - jl_expr_t *ex = NULL; - jl_value_t *linenum = NULL; - jl_svec_t *sparam_vals = env; - jl_method_instance_t *generator = linfo->def.method->generator; + jl_method_t *def = linfo->def.method; + jl_value_t *generator = def->generator; assert(generator != NULL); - assert(linfo != generator); + assert(jl_is_method(def)); jl_code_info_t *func = NULL; - JL_GC_PUSH4(&ex, &linenum, &sparam_vals, &func); + jl_value_t *ex = NULL; + JL_GC_PUSH2(&ex, &func); jl_ptls_t ptls = jl_get_ptls_states(); int last_lineno = jl_lineno; int last_in = ptls->in_pure_callback; jl_module_t *last_m = ptls->current_module; jl_module_t *task_last_m = ptls->current_task->current_module; size_t last_age = jl_get_ptls_states()->world_age; - assert(jl_svec_len(linfo->def.method->sparam_syms) == jl_svec_len(sparam_vals)); + JL_TRY { ptls->in_pure_callback = 1; // need to eval macros in the right module ptls->current_task->current_module = ptls->current_module = linfo->def.method->module; // and the right world - ptls->world_age = generator->def.method->min_world; - - ex = jl_exprn(lambda_sym, 2); - - jl_array_t *argnames = jl_alloc_vec_any(linfo->def.method->nargs); - jl_array_ptr_set(ex->args, 0, argnames); - jl_fill_argnames((jl_array_t*)generator->inferred, argnames); - - // build the rest of the body to pass to expand - jl_expr_t *scopeblock = jl_exprn(jl_symbol("scope-block"), 1); - jl_array_ptr_set(ex->args, 1, scopeblock); - jl_expr_t *body = jl_exprn(jl_symbol("block"), 3); - jl_array_ptr_set(((jl_expr_t*)jl_exprarg(ex, 1))->args, 0, body); - - // add location meta - linenum = jl_box_long(linfo->def.method->line); - jl_value_t *linenode = jl_new_struct(jl_linenumbernode_type, linenum, linfo->def.method->file); - jl_array_ptr_set(body->args, 0, linenode); - jl_expr_t *pushloc = jl_exprn(meta_sym, 3); - jl_array_ptr_set(body->args, 1, pushloc); - jl_array_ptr_set(pushloc->args, 0, jl_symbol("push_loc")); - jl_array_ptr_set(pushloc->args, 1, linfo->def.method->file); // file - jl_array_ptr_set(pushloc->args, 2, jl_symbol("@generated body")); // function + ptls->world_age = def->min_world; // invoke code generator - assert(jl_nparams(tt) == jl_array_len(argnames) || - (linfo->def.method->isva && (jl_nparams(tt) >= jl_array_len(argnames) - 1))); - jl_value_t *generated_body = jl_call_staged(sparam_vals, generator, jl_svec_data(tt->parameters), jl_nparams(tt)); - jl_array_ptr_set(body->args, 2, generated_body); - - if (jl_is_code_info(generated_body)) { - func = (jl_code_info_t*)generated_body; - } else { - if (linfo->def.method->sparam_syms != jl_emptysvec) { - // mark this function as having the same static parameters as the generator - size_t i, nsp = jl_svec_len(linfo->def.method->sparam_syms); - jl_expr_t *newast = jl_exprn(jl_symbol("with-static-parameters"), nsp + 1); - jl_exprarg(newast, 0) = (jl_value_t*)ex; - // (with-static-parameters func_expr sp_1 sp_2 ...) - for (i = 0; i < nsp; i++) - jl_exprarg(newast, i+1) = jl_svecref(linfo->def.method->sparam_syms, i); - ex = newast; - } + jl_value_t *ex = jl_call_staged(linfo->def.method, generator, linfo->sparam_vals, jl_svec_data(tt->parameters), jl_nparams(tt)); + if (jl_is_code_info(ex)) { + func = (jl_code_info_t*)ex; + } + else { func = (jl_code_info_t*)jl_expand((jl_value_t*)ex, linfo->def.method->module); if (!jl_is_code_info(func)) { if (jl_is_expr(func) && ((jl_expr_t*)func)->head == error_sym) @@ -349,15 +311,9 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo) size_t i, l; for (i = 0, l = jl_array_len(stmts); i < l; i++) { jl_value_t *stmt = jl_array_ptr_ref(stmts, i); - stmt = jl_resolve_globals(stmt, linfo->def.method->module, env); + stmt = jl_resolve_globals(stmt, linfo->def.method->module, linfo->sparam_vals); jl_array_ptr_set(stmts, i, stmt); } - - // add pop_loc meta - jl_array_ptr_1d_push(stmts, jl_nothing); - jl_expr_t *poploc = jl_exprn(meta_sym, 1); - jl_array_ptr_set(stmts, jl_array_len(stmts) - 1, poploc); - jl_array_ptr_set(poploc->args, 0, jl_symbol("pop_loc")); } ptls->in_pure_callback = last_in; @@ -404,6 +360,7 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src) { uint8_t j; uint8_t called = 0; + int gen_only = 0; for (j = 1; j < m->nargs && j <= 8; j++) { jl_value_t *ai = jl_array_ptr_ref(src->slotnames, j); if (ai == (jl_value_t*)unused_sym) @@ -434,28 +391,50 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src) set_lineno = 1; } } - else if (jl_is_expr(st) && ((jl_expr_t*)st)->head == meta_sym && - jl_expr_nargs(st) > 1 && jl_exprarg(st, 0) == (jl_value_t*)nospecialize_sym) { - for (size_t j=1; j < jl_expr_nargs(st); j++) { - jl_value_t *aj = jl_exprarg(st, j); - if (jl_is_slot(aj)) { - int sn = (int)jl_slot_number(aj) - 2; - if (sn >= 0) { // @nospecialize on self is valid but currently ignored - if (sn > (m->nargs - 2)) { - jl_error("@nospecialize annotation applied to a non-argument"); - } - else if (sn >= sizeof(m->nospecialize) * 8) { - jl_printf(JL_STDERR, - "WARNING: @nospecialize annotation only supported on the first %d arguments.\n", - (int)(sizeof(m->nospecialize) * 8)); - } - else { - m->nospecialize |= (1 << sn); + else if (jl_is_expr(st) && ((jl_expr_t*)st)->head == meta_sym) { + if (jl_expr_nargs(st) > 1 && jl_exprarg(st, 0) == (jl_value_t*)nospecialize_sym) { + for (size_t j=1; j < jl_expr_nargs(st); j++) { + jl_value_t *aj = jl_exprarg(st, j); + if (jl_is_slot(aj)) { + int sn = (int)jl_slot_number(aj) - 2; + if (sn >= 0) { // @nospecialize on self is valid but currently ignored + if (sn > (m->nargs - 2)) { + jl_error("@nospecialize annotation applied to a non-argument"); + } + else if (sn >= sizeof(m->nospecialize) * 8) { + jl_printf(JL_STDERR, + "WARNING: @nospecialize annotation only supported on the first %d arguments.\n", + (int)(sizeof(m->nospecialize) * 8)); + } + else { + m->nospecialize |= (1 << sn); + } } } } + st = jl_nothing; + } + else if (jl_expr_nargs(st) == 2 && jl_exprarg(st, 0) == (jl_value_t*)generated_sym) { + m->generator = NULL; + jl_value_t *gexpr = jl_exprarg(st, 1); + if (jl_expr_nargs(gexpr) == 6) { + // expects (new (core GeneratedFunctionStub) funcname argnames sp line file) + jl_value_t *funcname = jl_exprarg(gexpr, 1); + assert(jl_is_symbol(funcname)); + if (jl_get_global(m->module, (jl_sym_t*)funcname) != NULL) { + m->generator = jl_toplevel_eval(m->module, gexpr); + jl_gc_wb(m, m->generator); + } + } + if (m->generator == NULL) { + jl_error("invalid @generated function; try placing it in global scope"); + } + st = jl_nothing; + } + else if (jl_expr_nargs(st) == 1 && jl_exprarg(st, 0) == (jl_value_t*)generated_only_sym) { + gen_only = 1; + st = jl_nothing; } - st = jl_nothing; } else { st = jl_resolve_globals(st, m->module, sparam_vars); @@ -465,7 +444,10 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src) src = jl_copy_code_info(src); src->code = copy; jl_gc_wb(src, copy); - m->source = (jl_value_t*)jl_compress_ast(m, src); + if (gen_only) + m->source = NULL; + else + m->source = (jl_value_t*)jl_compress_ast(m, src); jl_gc_wb(m, m->source); JL_GC_POP(); } @@ -506,8 +488,7 @@ static jl_method_t *jl_new_method( jl_tupletype_t *sig, size_t nargs, int isva, - jl_svec_t *tvars, - int isstaged) + jl_svec_t *tvars) { size_t i, l = jl_svec_len(tvars); jl_svec_t *sparam_syms = jl_alloc_svec_uninit(l); @@ -527,13 +508,6 @@ static jl_method_t *jl_new_method( m->isva = isva; m->nargs = nargs; jl_method_set_source(m, definition); - if (isstaged) { - // create and store generator for generated functions - m->generator = jl_get_specialized(m, (jl_value_t*)jl_anytuple_type, jl_emptysvec); - jl_gc_wb(m, m->generator); - m->generator->inferred = (jl_value_t*)m->source; - m->source = NULL; - } #ifdef RECORD_METHOD_ORDER if (jl_all_methods == NULL) @@ -653,8 +627,7 @@ extern tracer_cb jl_newmeth_tracer; JL_DLLEXPORT void jl_method_def(jl_svec_t *argdata, jl_code_info_t *f, - jl_module_t *module, - jl_value_t *isstaged) + jl_module_t *module) { // argdata is svec(svec(types...), svec(typevars...)) jl_svec_t *atypes = (jl_svec_t*)jl_svecref(argdata, 0); @@ -711,7 +684,7 @@ JL_DLLEXPORT void jl_method_def(jl_svec_t *argdata, // the result is that the closure variables get interpolated directly into the AST f = jl_new_code_info_from_ast((jl_expr_t*)f); } - m = jl_new_method(f, name, module, (jl_tupletype_t*)argtype, nargs, isva, tvars, isstaged == jl_true); + m = jl_new_method(f, name, module, (jl_tupletype_t*)argtype, nargs, isva, tvars); m->nospecialize |= nospec; if (jl_has_free_typevars(argtype)) { diff --git a/src/utils.scm b/src/utils.scm index 97842a387b544..211d79ffef7b5 100644 --- a/src/utils.scm +++ b/src/utils.scm @@ -40,12 +40,13 @@ (cdr expr))))) ;; same as above, with predicate -(define (expr-contains-p p expr) - (or (p expr) - (and (pair? expr) - (not (quoted? expr)) - (any (lambda (y) (expr-contains-p p y)) - (cdr expr))))) +(define (expr-contains-p p expr (filt (lambda (x) #t))) + (and (filt expr) + (or (p expr) + (and (pair? expr) + (not (quoted? expr)) + (any (lambda (y) (expr-contains-p p y filt)) + (cdr expr)))))) ;; find all subexprs satisfying `p`, applying `key` to each one (define (expr-find-all p expr key (filt (lambda (x) #t))) diff --git a/test/staged.jl b/test/staged.jl index 7bcd41a478ced..336594d271443 100644 --- a/test/staged.jl +++ b/test/staged.jl @@ -250,3 +250,25 @@ end @test f22440(0.0) === f22440kernel(0.0) @test f22440(0.0f0) === f22440kernel(0.0f0) @test f22440(0) === f22440kernel(0) + +# PR #23168 + +function f23168(a, x) + push!(a, 1) + if @generated + :(y = (x + x, $x)) + else + y = (2x, typeof(x)) + end + push!(a, 2) + return y +end + +let a = Any[] + @test f23168(a, 3) == (6, Int) + @test a == [1, 2] + @test contains(string(code_lowered(f23168, (Vector{Any},Int))), "x + x") + @test contains(string(Base.uncompressed_ast(first(methods(f23168)))), "2 * x") + @test contains(string(code_lowered(f23168, (Vector{Any},Int), false)), "2 * x") + @test contains(string(code_typed(f23168, (Vector{Any},Int))), "(Base.add_int)(x, x)") +end