diff --git a/src/Diffractor.jl b/src/Diffractor.jl index dae4e1e3..e4fef34b 100644 --- a/src/Diffractor.jl +++ b/src/Diffractor.jl @@ -1,11 +1,12 @@ module Diffractor +export ∂⃖, gradient + using StructArrays using PrecompileTools -export ∂⃖, gradient - const CC = Core.Compiler +using Core.IR @static if VERSION ≥ v"1.11.0-DEV.1498" import .CC: get_inference_world @@ -33,7 +34,6 @@ end include("stage2/tfuncs.jl") include("stage2/forward.jl") - include("codegen/forward.jl") include("analysis/forward.jl") include("codegen/forward_demand.jl") include("codegen/reverse.jl") diff --git a/src/codegen/forward.jl b/src/codegen/forward.jl deleted file mode 100644 index 9a310379..00000000 --- a/src/codegen/forward.jl +++ /dev/null @@ -1,117 +0,0 @@ -function fwd_transform(ci, args...) - newci = copy(ci) - fwd_transform!(newci, args...) - return newci -end - -function fwd_transform!(ci, mi, nargs, N) - new_code = Any[] - new_codelocs = Any[] - ssa_mapping = Int[] - loc_mapping = Int[] - - emit!(@nospecialize stmt) = stmt - function emit!(stmt::Expr) - stmt.head ∈ (:call, :(=), :new, :isdefined) || return stmt - push!(new_code, stmt) - push!(new_codelocs, isempty(new_codelocs) ? 0 : new_codelocs[end]) - return SSAValue(length(new_code)) - end - - function mapstmt!(@nospecialize stmt) - if isexpr(stmt, :(=)) - return Expr(stmt.head, emit!(mapstmt!(stmt.args[1])), emit!(mapstmt!(stmt.args[2]))) - elseif isexpr(stmt, :call) - args = map(stmt.args) do stmt - emit!(mapstmt!(stmt)) - end - return Expr(:call, ∂☆{N}(), args...) - elseif isexpr(stmt, :new) - args = map(stmt.args) do stmt - emit!(mapstmt!(stmt)) - end - return Expr(:call, ∂☆new{N}(), args...) - elseif isexpr(stmt, :splatnew) - args = map(stmt.args) do stmt - emit!(mapstmt!(stmt)) - end - return Expr(:call, Core._apply_iterate, FwdIterate(DNEBundle{N}(iterate)), ∂☆new{N}(), emit!(Expr(:call, tuple, args[1])), args[2:end]...) - elseif isa(stmt, SSAValue) - return SSAValue(ssa_mapping[stmt.id]) - elseif isa(stmt, Core.SlotNumber) - return SlotNumber(2 + stmt.id) - elseif isa(stmt, Argument) - return SlotNumber(2 + stmt.n) - elseif isa(stmt, NewvarNode) - return NewvarNode(SlotNumber(2 + stmt.slot.id)) - elseif isa(stmt, ReturnNode) - return ReturnNode(emit!(mapstmt!(stmt.val))) - elseif isa(stmt, GotoNode) - return stmt - elseif isa(stmt, GotoIfNot) - return GotoIfNot(emit!(Expr(:call, primal, emit!(mapstmt!(stmt.cond)))), stmt.dest) - elseif isexpr(stmt, :static_parameter) - return ZeroBundle{N}(mi.sparam_vals[stmt.args[1]::Int]) - elseif isexpr(stmt, :foreigncall) - return Expr(:call, error, "Attempted to AD a foreigncall. Missing rule?") - elseif isexpr(stmt, :meta) || isexpr(stmt, :inbounds) || isexpr(stmt, :loopinfo) || - isexpr(stmt, :code_coverage_effect) - # Can't trust that meta annotations are still valid in the AD'd - # version. - return nothing - elseif isexpr(stmt, :isdefined) - return Expr(:call, zero_bundle{N}(), emit!(stmt)) - # Always disable `@inbounds`, as we don't actually know if the AD'd - # code is truly `@inbounds` or not. - elseif isexpr(stmt, :boundscheck) - return DNEBundle{N}(true) - else - # Fallback case, for literals. - # If it is an Expr, then it is not a literal - if isa(stmt, Expr) - error("Unexprected statement encountered. This is a bug in Diffractor. stmt=$stmt") - end - return Expr(:call, zero_bundle{N}(), stmt) - end - end - - meth = mi.def::Method - for i = 1:meth.nargs - if meth.isva && i == meth.nargs - args = map(i:(nargs+1)) do j::Int - emit!(Expr(:call, getfield, SlotNumber(2), j)) - end - emit!(Expr(:(=), SlotNumber(2 + i), Expr(:call, ∂vararg{N}(), args...))) - else - emit!(Expr(:(=), SlotNumber(2 + i), Expr(:call, getfield, SlotNumber(2), i))) - end - end - - for (stmt, codeloc) in zip(ci.code, ci.codelocs) - push!(loc_mapping, length(new_code)+1) - push!(new_codelocs, codeloc) - push!(new_code, mapstmt!(stmt)) - push!(ssa_mapping, length(new_code)) - end - - # Rewrite control flow - for (i, stmt) in enumerate(new_code) - if isa(stmt, GotoNode) - new_code[i] = GotoNode(loc_mapping[stmt.label]) - elseif isa(stmt, GotoIfNot) - new_code[i] = GotoIfNot(stmt.cond, loc_mapping[stmt.dest]) - end - end - - ci.slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...] - ci.slotflags = UInt8[0x00, 0x00, ci.slotflags...] - ci.slottypes = ci.slottypes === nothing ? nothing : Any[Any, Any, ci.slottypes...] - ci.code = new_code - ci.codelocs = new_codelocs - ci.ssavaluetypes = length(new_code) - ci.ssaflags = UInt8[0 for i=1:length(new_code)] - ci.method_for_inference_limit_heuristics = meth - ci.edges = MethodInstance[mi] - - return ci -end diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index fa8a99fe..20c42f6c 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -73,6 +73,135 @@ function ∂☆builtin((f_bundle, args...)) throw(DomainError(f, "No `ChainRulesCore.frule` found for the built-in function `$sig`")) end +function fwd_transform(ci::CodeInfo, args...) + newci = copy(ci) + fwd_transform!(newci, args...) + return newci +end + +function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int) + new_code = Any[] + @static if VERSION ≥ v"1.12.0-DEV.173" + debuginfo = Core.Compiler.DebugInfoStream(mi, ci.debuginfo, length(ci.code)) + new_codelocs = Int32[] + else + new_codelocs = Any[] + end + ssa_mapping = Int[] + loc_mapping = Int[] + + emit!(@nospecialize stmt) = stmt + function emit!(stmt::Expr) + stmt.head ∈ (:call, :(=), :new, :isdefined) || return stmt + push!(new_code, stmt) + push!(new_codelocs, isempty(new_codelocs) ? 0 : new_codelocs[end]) + return SSAValue(length(new_code)) + end + + function mapstmt!(@nospecialize stmt) + if isexpr(stmt, :(=)) + return Expr(stmt.head, emit!(mapstmt!(stmt.args[1])), emit!(mapstmt!(stmt.args[2]))) + elseif isexpr(stmt, :call) + args = map(stmt.args) do stmt + emit!(mapstmt!(stmt)) + end + return Expr(:call, ∂☆{N}(), args...) + elseif isexpr(stmt, :new) + args = map(stmt.args) do stmt + emit!(mapstmt!(stmt)) + end + return Expr(:call, ∂☆new{N}(), args...) + elseif isexpr(stmt, :splatnew) + args = map(stmt.args) do stmt + emit!(mapstmt!(stmt)) + end + return Expr(:call, Core._apply_iterate, FwdIterate(DNEBundle{N}(iterate)), ∂☆new{N}(), emit!(Expr(:call, tuple, args[1])), args[2:end]...) + elseif isa(stmt, SSAValue) + return SSAValue(ssa_mapping[stmt.id]) + elseif isa(stmt, Core.SlotNumber) + return SlotNumber(2 + stmt.id) + elseif isa(stmt, Argument) + return SlotNumber(2 + stmt.n) + elseif isa(stmt, NewvarNode) + return NewvarNode(SlotNumber(2 + stmt.slot.id)) + elseif isa(stmt, ReturnNode) + return ReturnNode(emit!(mapstmt!(stmt.val))) + elseif isa(stmt, GotoNode) + return stmt + elseif isa(stmt, GotoIfNot) + return GotoIfNot(emit!(Expr(:call, primal, emit!(mapstmt!(stmt.cond)))), stmt.dest) + elseif isexpr(stmt, :static_parameter) + return ZeroBundle{N}(mi.sparam_vals[stmt.args[1]::Int]) + elseif isexpr(stmt, :foreigncall) + return Expr(:call, error, "Attempted to AD a foreigncall. Missing rule?") + elseif isexpr(stmt, :meta) || isexpr(stmt, :inbounds) || isexpr(stmt, :loopinfo) || + isexpr(stmt, :code_coverage_effect) + # Can't trust that meta annotations are still valid in the AD'd + # version. + return nothing + elseif isexpr(stmt, :isdefined) + return Expr(:call, zero_bundle{N}(), emit!(stmt)) + # Always disable `@inbounds`, as we don't actually know if the AD'd + # code is truly `@inbounds` or not. + elseif isexpr(stmt, :boundscheck) + return DNEBundle{N}(true) + else + # Fallback case, for literals. + # If it is an Expr, then it is not a literal + if isa(stmt, Expr) + error("Unexprected statement encountered. This is a bug in Diffractor. stmt=$stmt") + end + return Expr(:call, zero_bundle{N}(), stmt) + end + end + + meth = mi.def::Method + for i = 1:meth.nargs + if meth.isva && i == meth.nargs + args = map(i:(nargs+1)) do j::Int + emit!(Expr(:call, getfield, SlotNumber(2), j)) + end + emit!(Expr(:(=), SlotNumber(2 + i), Expr(:call, ∂vararg{N}(), args...))) + else + emit!(Expr(:(=), SlotNumber(2 + i), Expr(:call, getfield, SlotNumber(2), i))) + end + end + + for (stmt, codeloc) in zip(ci.code, @static VERSION ≥ v"1.12.0-DEV.173" ? debuginfo.codelocs : ci.codelocs) + push!(loc_mapping, length(new_code)+1) + push!(new_codelocs, codeloc) + push!(new_code, mapstmt!(stmt)) + push!(ssa_mapping, length(new_code)) + end + + # Rewrite control flow + for (i, stmt) in enumerate(new_code) + if isa(stmt, GotoNode) + new_code[i] = GotoNode(loc_mapping[stmt.label]) + elseif isa(stmt, GotoIfNot) + new_code[i] = GotoIfNot(stmt.cond, loc_mapping[stmt.dest]) + end + end + + ci.slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...] + ci.slotflags = UInt8[0x00, 0x00, ci.slotflags...] + ci.slottypes = ci.slottypes === nothing ? nothing : Any[Any, Any, ci.slottypes...] + ci.code = new_code + @static if VERSION ≥ v"1.12.0-DEV.173" + empty!(debuginfo.codelocs) + append!(debuginfo.codelocs, new_codelocs) + ci.codelocs = Core.DebugInfo(debuginfo, length(new_code)) + else + ci.codelocs = new_codelocs + end + ci.ssavaluetypes = length(new_code) + ci.ssaflags = UInt8[0 for i=1:length(new_code)] + ci.method_for_inference_limit_heuristics = meth + ci.edges = MethodInstance[mi] + + return ci +end + function perform_fwd_transform(world::UInt, source::LineNumberNode, @nospecialize(ff::Type{∂☆recurse{N}}), @nospecialize(args)) where {N} if all(x->x <: ZeroBundle, args) diff --git a/src/stage2/interpreter.jl b/src/stage2/interpreter.jl index de09698f..dd4a3711 100644 --- a/src/stage2/interpreter.jl +++ b/src/stage2/interpreter.jl @@ -327,13 +327,9 @@ function CC.src_inlining_policy(interp::ADInterpreter, ret = diffractor_inlining_policy(src, info, stmt_flag) ret === nothing && return false ret !== missing && return true - return CC.src_inlining_policy(interp::AbstractInterpreter, + return @invoke CC.src_inlining_policy(interp::AbstractInterpreter, src::Any, info::CC.CallInfo, stmt_flag::StmtFlag) end -CC.retrieve_ir_for_inlining(cached_result::CodeInstance, src::Cthulhu.OptimizedSource) = - CC.retrieve_ir_for_inlining(cached_result.def, src.ir, true) -CC.retrieve_ir_for_inlining(mi::MethodInstance, src::Cthulhu.OptimizedSource, preserve_local_sources::Bool) = - CC.retrieve_ir_for_inlining(mi, src.ir, preserve_local_sources) else function CC.transform_result_for_cache(interp::ADInterpreter, linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult) diff --git a/src/stage2/lattice.jl b/src/stage2/lattice.jl index 8683a795..663b5ffe 100644 --- a/src/stage2/lattice.jl +++ b/src/stage2/lattice.jl @@ -1,4 +1,4 @@ -using Core.Compiler: CodeInfo, CallInfo, CallMeta +using Core.Compiler: CallInfo, CallMeta import Core.Compiler: widenconst struct CompClosure; opaque; end # TODO: Is this a YAKC?