Skip to content

Commit

Permalink
wip: adjustments to the latest master
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Mar 22, 2024
1 parent 9883872 commit 9ddbb2a
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 126 deletions.
6 changes: 3 additions & 3 deletions src/Diffractor.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
module Diffractor

export ∂⃖, gradient

using StructArrays
using PrecompileTools

export ∂⃖, gradient

const CC = Core.Compiler
using Core.IR

@static if VERSION v"1.11.0-DEV.1498"
import .CC: get_inference_world
Expand Down Expand Up @@ -33,7 +34,6 @@ end
include("stage2/tfuncs.jl")
include("stage2/forward.jl")

include("codegen/forward.jl")
include("analysis/forward.jl")
include("codegen/forward_demand.jl")
include("codegen/reverse.jl")
Expand Down
117 changes: 0 additions & 117 deletions src/codegen/forward.jl

This file was deleted.

129 changes: 129 additions & 0 deletions src/stage1/recurse_fwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,135 @@ function ∂☆builtin((f_bundle, args...))
throw(DomainError(f, "No `ChainRulesCore.frule` found for the built-in function `$sig`"))
end

function fwd_transform(ci::CodeInfo, args...)
newci = copy(ci)
fwd_transform!(newci, args...)
return newci
end

function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int)
new_code = Any[]
@static if VERSION v"1.12.0-DEV.173"
debuginfo = Core.Compiler.DebugInfoStream(mi, ci.debuginfo, length(ci.code))
new_codelocs = Int32[]
else
new_codelocs = Any[]
end
ssa_mapping = Int[]
loc_mapping = Int[]

emit!(@nospecialize stmt) = stmt
function emit!(stmt::Expr)
stmt.head (:call, :(=), :new, :isdefined) || return stmt
push!(new_code, stmt)
push!(new_codelocs, isempty(new_codelocs) ? 0 : new_codelocs[end])
return SSAValue(length(new_code))
end

function mapstmt!(@nospecialize stmt)
if isexpr(stmt, :(=))
return Expr(stmt.head, emit!(mapstmt!(stmt.args[1])), emit!(mapstmt!(stmt.args[2])))
elseif isexpr(stmt, :call)
args = map(stmt.args) do stmt
emit!(mapstmt!(stmt))
end
return Expr(:call, ∂☆{N}(), args...)
elseif isexpr(stmt, :new)
args = map(stmt.args) do stmt
emit!(mapstmt!(stmt))
end
return Expr(:call, ∂☆new{N}(), args...)
elseif isexpr(stmt, :splatnew)
args = map(stmt.args) do stmt
emit!(mapstmt!(stmt))
end
return Expr(:call, Core._apply_iterate, FwdIterate(DNEBundle{N}(iterate)), ∂☆new{N}(), emit!(Expr(:call, tuple, args[1])), args[2:end]...)
elseif isa(stmt, SSAValue)
return SSAValue(ssa_mapping[stmt.id])
elseif isa(stmt, Core.SlotNumber)
return SlotNumber(2 + stmt.id)
elseif isa(stmt, Argument)
return SlotNumber(2 + stmt.n)
elseif isa(stmt, NewvarNode)
return NewvarNode(SlotNumber(2 + stmt.slot.id))
elseif isa(stmt, ReturnNode)
return ReturnNode(emit!(mapstmt!(stmt.val)))
elseif isa(stmt, GotoNode)
return stmt
elseif isa(stmt, GotoIfNot)
return GotoIfNot(emit!(Expr(:call, primal, emit!(mapstmt!(stmt.cond)))), stmt.dest)
elseif isexpr(stmt, :static_parameter)
return ZeroBundle{N}(mi.sparam_vals[stmt.args[1]::Int])
elseif isexpr(stmt, :foreigncall)
return Expr(:call, error, "Attempted to AD a foreigncall. Missing rule?")
elseif isexpr(stmt, :meta) || isexpr(stmt, :inbounds) || isexpr(stmt, :loopinfo) ||
isexpr(stmt, :code_coverage_effect)
# Can't trust that meta annotations are still valid in the AD'd
# version.
return nothing
elseif isexpr(stmt, :isdefined)
return Expr(:call, zero_bundle{N}(), emit!(stmt))
# Always disable `@inbounds`, as we don't actually know if the AD'd
# code is truly `@inbounds` or not.
elseif isexpr(stmt, :boundscheck)
return DNEBundle{N}(true)
else
# Fallback case, for literals.
# If it is an Expr, then it is not a literal
if isa(stmt, Expr)
error("Unexprected statement encountered. This is a bug in Diffractor. stmt=$stmt")
end
return Expr(:call, zero_bundle{N}(), stmt)
end
end

meth = mi.def::Method
for i = 1:meth.nargs
if meth.isva && i == meth.nargs
args = map(i:(nargs+1)) do j::Int
emit!(Expr(:call, getfield, SlotNumber(2), j))
end
emit!(Expr(:(=), SlotNumber(2 + i), Expr(:call, ∂vararg{N}(), args...)))
else
emit!(Expr(:(=), SlotNumber(2 + i), Expr(:call, getfield, SlotNumber(2), i)))
end
end

for (stmt, codeloc) in zip(ci.code, @static VERSION v"1.12.0-DEV.173" ? debuginfo.codelocs : ci.codelocs)
push!(loc_mapping, length(new_code)+1)
push!(new_codelocs, codeloc)
push!(new_code, mapstmt!(stmt))
push!(ssa_mapping, length(new_code))
end

# Rewrite control flow
for (i, stmt) in enumerate(new_code)
if isa(stmt, GotoNode)
new_code[i] = GotoNode(loc_mapping[stmt.label])
elseif isa(stmt, GotoIfNot)
new_code[i] = GotoIfNot(stmt.cond, loc_mapping[stmt.dest])
end
end

ci.slotnames = Symbol[Symbol("#self#"), :args, ci.slotnames...]
ci.slotflags = UInt8[0x00, 0x00, ci.slotflags...]
ci.slottypes = ci.slottypes === nothing ? nothing : Any[Any, Any, ci.slottypes...]
ci.code = new_code
@static if VERSION v"1.12.0-DEV.173"
empty!(debuginfo.codelocs)
append!(debuginfo.codelocs, new_codelocs)
ci.codelocs = Core.DebugInfo(debuginfo, length(new_code))
else
ci.codelocs = new_codelocs
end
ci.ssavaluetypes = length(new_code)
ci.ssaflags = UInt8[0 for i=1:length(new_code)]
ci.method_for_inference_limit_heuristics = meth
ci.edges = MethodInstance[mi]

return ci
end

function perform_fwd_transform(world::UInt, source::LineNumberNode,
@nospecialize(ff::Type{∂☆recurse{N}}), @nospecialize(args)) where {N}
if all(x->x <: ZeroBundle, args)
Expand Down
6 changes: 1 addition & 5 deletions src/stage2/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -327,13 +327,9 @@ function CC.src_inlining_policy(interp::ADInterpreter,
ret = diffractor_inlining_policy(src, info, stmt_flag)
ret === nothing && return false
ret !== missing && return true
return CC.src_inlining_policy(interp::AbstractInterpreter,
return @invoke CC.src_inlining_policy(interp::AbstractInterpreter,
src::Any, info::CC.CallInfo, stmt_flag::StmtFlag)
end
CC.retrieve_ir_for_inlining(cached_result::CodeInstance, src::Cthulhu.OptimizedSource) =
CC.retrieve_ir_for_inlining(cached_result.def, src.ir, true)
CC.retrieve_ir_for_inlining(mi::MethodInstance, src::Cthulhu.OptimizedSource, preserve_local_sources::Bool) =
CC.retrieve_ir_for_inlining(mi, src.ir, preserve_local_sources)
else
function CC.transform_result_for_cache(interp::ADInterpreter,
linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult)
Expand Down
2 changes: 1 addition & 1 deletion src/stage2/lattice.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Core.Compiler: CodeInfo, CallInfo, CallMeta
using Core.Compiler: CallInfo, CallMeta
import Core.Compiler: widenconst

struct CompClosure; opaque; end # TODO: Is this a YAKC?
Expand Down

0 comments on commit 9ddbb2a

Please sign in to comment.