Skip to content

Commit

Permalink
Mock Enzyme plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Oct 3, 2024
1 parent 4ef6019 commit 1b8d203
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 4 deletions.
2 changes: 0 additions & 2 deletions test/native_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,6 @@ end
# smoke test
job, _ = Native.create_job(eval(kernel), (Int64,))

# TODO: Add a `kernel=true` test

ci, rt = only(GPUCompiler.code_typed(job))
@test rt === Ptr{Cvoid}

Expand Down
177 changes: 175 additions & 2 deletions test/plugin_testsetup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ struct NeverInlineMeta <: InlineStateMeta end
import GPUCompiler: abstract_call_known, GPUInterpreter
import Core.Compiler: CallMeta, Effects, NoCallInfo, ArgInfo,
StmtInfo, AbsIntState, EFFECTS_TOTAL,
MethodResultPure
MethodResultPure, CallInfo, IRCode

function abstract_call_known(meta::InlineStateMeta, interp::GPUInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int)
Expand Down Expand Up @@ -69,5 +69,178 @@ function inlining_handler(meta::InlineStateMeta, interp::GPUInterpreter, @nospec
return nothing
end

struct MockEnzymeMeta end

end
# Having to define this function is annoying
# introduce `abstract type InferenceMeta`
function inlining_handler(meta::MockEnzymeMeta, interp::GPUInterpreter, @nospecialize(atype), callinfo)
return nothing
end

function autodiff end

import GPUCompiler: DeferredCallInfo
struct AutodiffCallInfo <: CallInfo
rt
info::DeferredCallInfo
end

function abstract_call_known(meta::Nothing, interp::GPUInterpreter, f::typeof(autodiff),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int)
(; fargs, argtypes) = arginfo

@assert f === autodiff
if length(argtypes) <= 1
@static if VERSION < v"1.11.0-"
return CallMeta(Union{}, Effects(), NoCallInfo())
else
return CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
end
end

other_fargs = fargs === nothing ? nothing : fargs[2:end]
other_arginfo = ArgInfo(other_fargs, argtypes[2:end])
# TODO: Ought we not change absint to use MockEnzymeMeta(), otherwise we fill the cache for nothing.
call = Core.Compiler.abstract_call(interp, other_arginfo, si, sv, max_methods)
callinfo = DeferredCallInfo(MockEnzymeMeta(), call.rt, call.info)

# Real Enzyme must compute `rt` and `exct` according to enzyme semantics
# and likely perform a unwrapping of fargs...
rt = call.rt

# TODO: Edges? Effects?
@static if VERSION < v"1.11.0-"
# Can't use call.effects since otherwise this call might be just replaced with rt
return CallMeta(rt, Effects(), AutodiffCallInfo(rt, callinfo))
else
return CallMeta(rt, call.exct, Effects(), AutodiffCallInfo(rt, callinfo))
end
end

function abstract_call_known(meta::MockEnzymeMeta, interp::GPUInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int)
return nothing
end

import Core.Compiler: insert_node!, NewInstruction, ReturnNode, Instruction, InliningState, Signature

# We really need a Compiler stdlib
Base.getindex(ir::IRCode, i) = Core.Compiler.getindex(ir, i)
Base.setindex!(inst::Instruction, val, i) = Core.Compiler.setindex!(inst, val, i)

const FlagType = VERSION >= v"1.11.0-" ? UInt32 : UInt8
function Core.Compiler.handle_call!(todo::Vector{Pair{Int,Any}}, ir::IRCode, stmt_idx::Int,
stmt::Expr, info::AutodiffCallInfo, flag::FlagType,
sig::Signature, state::InliningState)

# Goal:
# The IR we want to inline here is:
# unpack the args ..
# ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...)
# ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid, args...}}, ptr, adjoint_args...)

# 0. Obtain primal mi from DeferredCallInfo
# TODO: remove this code duplication
deferred_info = info.info
minfo = deferred_info.info
results = minfo.results
if length(results.matches) != 1
return nothing
end
match = only(results.matches)

# lookup the target mi with correct edge tracking
# TODO: Effects?
case = Core.Compiler.compileable_specialization(
match, Core.Compiler.Effects(), Core.Compiler.InliningEdgeTracker(state), info)
@assert case isa Core.Compiler.InvokeCase
@assert stmt.head === :call

# Now create the IR we want to inline
ir = Core.Compiler.IRCode() # contains a placeholder
args = [Core.Compiler.Argument(i) for i in 2:length(stmt.args)] # f, args...
idx = 0

# 0. Enzyme proper: Desugar args
primal_args = args
primal_argtypes = match.spec_types.parameters[2:end]

adjoint_rt = info.rt
adjoint_args = args # TODO
adjoint_argtypes = primal_argtypes

# 1: Since Julia's inliner goes bottom up we need to pretend that we inlined the deferred call
expr = Expr(:foreigncall,
"extern gpuc.lookup",
Ptr{Cvoid},
Core.svec(#=meta=# Any, #=mi=# Any, #=f=# Any, primal_argtypes...), # Must use Any for MethodInstance or ftype
0,
QuoteNode(:llvmcall),
deferred_info.meta,
case.invoke,
primal_args...
)
ptr = insert_node!(ir, (idx += 1), NewInstruction(expr, Ptr{Cvoid}))

# 2. Call to magic `__autodiff`
expr = Expr(:foreigncall,
"extern __autodiff",
adjoint_rt,
Core.svec(Ptr{Cvoid}, Any, adjoint_argtypes...),
0,
QuoteNode(:llvmcall),
ptr,
adjoint_args...
)
ret = insert_node!(ir, idx, NewInstruction(expr, adjoint_rt))

# Finally replace placeholder return
ir[Core.SSAValue(1)][:inst] = Core.ReturnNode(ret)
ir[Core.SSAValue(1)][:type] = Ptr{Cvoid}

ir = Core.Compiler.compact!(ir)

# which mi to use here?
# push inlining todos
# TODO: Effects
# aviatesk mentioned using inlining_policy instead...
itodo = Core.Compiler.InliningTodo(case.invoke, ir, Core.Compiler.Effects())
@assert itodo.linear_inline_eligible
push!(todo, (stmt_idx=>itodo))

return nothing
end

function mock_enzyme!(@nospecialize(job), intrinsic, mod::LLVM.Module)
changed = false

for use in LLVM.uses(intrinsic)
call = LLVM.user(use)
LLVM.@dispose builder=LLVM.IRBuilder() begin
LLVM.position!(builder, call)
ops = LLVM.operands(call)
target = ops[1]
if target isa LLVM.ConstantExpr && (LLVM.opcode(target) == LLVM.API.LLVMPtrToInt ||
LLVM.opcode(target) == LLVM.API.LLVMBitCast)
target = first(LLVM.operands(target))
end
funcT = LLVM.called_type(call)
funcT = LLVM.FunctionType(LLVM.return_type(funcT), LLVM.parameters(funcT)[3:end])
direct_call = LLVM.call!(builder, funcT, target, ops[3:end - 1]) # why is the -1 necessary

LLVM.replace_uses!(call, direct_call)
end
if isempty(LLVM.uses(call))
LLVM.erase!(call)
changed = true
else
# the validator will detect this
end
end

return changed
end

GPUCompiler.register_plugin!("__autodiff", mock_enzyme!)

end #module
24 changes: 24 additions & 0 deletions test/ptx_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -504,4 +504,28 @@ end
ir = sprint(io->PTX.code_llvm(io, kernel_inline, Tuple{Ptr{Int64}, Int64}, meta=Plugin.NeverInlineMeta()))
@test occursin("call fastcc i64 @julia_inline", ir)
end

@testset "Mock Enzyme" begin
function f(x)
x^2
end

function kernel(a, x)
y = Plugin.autodiff(f, x)
unsafe_store!(a, y)
nothing
end

# This tests deferred_codegen with kernel=true
@show PTX.code_typed(kernel, Tuple{Ptr{Float64}, Float64})

ir = sprint(io->PTX.code_llvm(io, kernel, Tuple{Ptr{Float64}, Float64}, optimize=false))
@test occursin("call double @__autodiff", ir)
@test !occursin("call fastcc double @julia_f", ir)

ir = sprint(io->PTX.code_llvm(io, kernel, Tuple{Ptr{Float64}, Float64}, optimize=true))
@test !occursin("call double @__autodiff", ir)
@test occursin("call fastcc double @julia_f", ir)
end

end #testitem

0 comments on commit 1b8d203

Please sign in to comment.