Skip to content

Commit

Permalink
fixup
Browse files (browse the repository at this point in the history)
  • Loading branch information
wsmoses committed Jul 21, 2024
1 parent d094e99 commit f24a2b7
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions ext/EnzymeCoreExt.jl
Original file line number | Diff line number | Diff line change
Expand Up @@ -202,7 +202,7 @@ function EnzymeCore.EnzymeRules.reverse(config, ofn::EnzymeCore.Const{typeof(cuf
return (nothing, nothing)
end

function meta_augf(f, tape::CuArray{TapeType}, ::Val{ModifiedBetween}, args::Vararg{Any, N}) where {N, ModifiedBetween, TapeType}
function meta_augf(f, tape::CuDeviceArray{TapeType}, ::Val{ModifiedBetween}, args::Vararg{Any, N}) where {N, ModifiedBetween, TapeType}
forward, _ = EnzymeCore.autodiff_deferred_thunk(
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)),
TapeType,
Expand All @@ -212,7 +212,7 @@ function meta_augf(f, tape::CuArray{TapeType}, ::Val{ModifiedBetween}, args::Var
)

I = (blockIdx().x, blockIdx().y, blockIdx().z, threadIdx().x, threadIdx().y, threadIdx().z)
subtape[I] = forward(Const(f), args...)[1]
tape[I] = forward(Const(f), args...)[1]
nothing
end

Expand All @@ -233,17 +233,17 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::EnzymeCore.Annotat
subtape = CuArray{TapeType}(undef, (blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z))

GC.@preserve args0 subtape, begin
args = (cudaconvert(subtape), Val(ModifiedBetween), args...)
T2 = (F, typeof(subtape), Val{ModifiedBetween}, (typeof(a) for a in args)...)
subtape2 = cudaconvert(subtape)
T2 = (F, typeof(subtape2), Val{ModifiedBetween}, (typeof(a) for a in args)...)
TT2 = Tuple{T2...}
cuf = cufunction(meta_augf, TT2)
res = cuf(ofn.val.f, args...; threads, blocks, kwargs...)
res = cuf(ofn.val.f, subtape2, Val(ModifiedBetween), args...; threads, blocks, kwargs...)
end

return AugmentedReturn{Nothing,Nothing,Any}(nothing, nothing, subtape)
end

function meta_revf(f, tape::CuArray{TapeType}, ::Val{ModifiedBetween}, args::Vararg{Any, N}) where {N, ModifiedBetween, TapeType}
function meta_revf(f, tape::CuDeviceArray{TapeType}, ::Val{ModifiedBetween}, args::Vararg{Any, N}) where {N, ModifiedBetween, TapeType}
_, reverse = EnzymeCore.autodiff_deferred_thunk(
EnzymeCore.compiler_job_from_backend(CUDABackend(), typeof(Base.identity), Tuple{Float64}),
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)),
Expand All @@ -254,7 +254,7 @@ function meta_revf(f, tape::CuArray{TapeType}, ::Val{ModifiedBetween}, args::Va
)

I = (blockIdx().x, blockIdx().y, blockIdx().z, threadIdx().x, threadIdx().y, threadIdx().z)
reverse(Const(f), args..., subtape[I])
reverse(Const(f), args..., tape[I])
nothing
end

Expand All @@ -273,11 +273,11 @@ function EnzymeCore.EnzymeRules.reverse(config, ofn::EnzymeCore.Annotation{CUDA.
blocks = CuDim3(blocks)

GC.@preserve args0 subtape, begin
args = (cudaconvert(subtape), Val(ModifiedBetween),(cudaconvert(a) for a in args)...,)
T2 = (F, typeof(subtape), Val{ModifiedBetween}, (typeof(a) for a in args)...)
subtape2 = cudaconvert(subtape)
T2 = (F, typeof(subtape2), Val{ModifiedBetween}, (typeof(a) for a in args)...)
TT2 = Tuple{T2...}
cuf = cufunction(meta_revf, TT2)
res = cuf(ofn.val.f, args...; threads, blocks, kwargs...)
res = cuf(ofn.val.f, subtape2, Val(ModifiedBetween), args...; threads, blocks, kwargs...)
end

return AugmentedReturn{Nothing,Nothing,Any}(nothing, nothing, subtape)
Expand Down

0 comments on commit f24a2b7

Please sign in to comment.