Skip to content

Commit

Permalink
Enzyme: Forward mode sync (JuliaGPU#2369)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed May 14, 2024
1 parent 7862981 commit f92b045
Showing 1 changed file with 29 additions and 2 deletions.
31 changes: 29 additions & 2 deletions ext/EnzymeCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,39 @@ function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cudaconvert)},
elseif RT <: DuplicatedNoNeed
return ofn.val(x.val)
else
tup = ntuple(Val(EnzymeRules.batch_width(RT))) do i
tup = ntuple(Val(EnzymeCore.batch_size(RT))) do i
Base.@_inline_meta
ofn.val(x.dval[i])
end
if RT <: BatchDuplicated
return BatchDuplicated(ofv.val(x.val, tup))
return BatchDuplicated(ofv.val(x.val), tup)
else
return tup
end
end
end

function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(synchronize)},
::Type{RT}, args::NTuple{N, EnzymeCore.Annotation}; kwargs...) where {RT, N}
pargs = ntuple(Val(N)) do i
Base.@_inline_meta
args.val
end
res = ofn.val(pargs...; kwargs...)

if RT <: Duplicated
return Duplicated(res, res)
elseif RT <: Const
return res
elseif RT <: DuplicatedNoNeed
return res
else
tup = ntuple(Val(EnzymeCore.batch_size(RT))) do i
Base.@_inline_meta
res
end
if RT <: BatchDuplicated
return BatchDuplicated(res, tup)
else
return tup
end
Expand Down

0 comments on commit f92b045

Please sign in to comment.