Skip to content

Commit

Permalink
Enzyme: Reversemode cudaconvert (#2476)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Aug 26, 2024
1 parent 4a215e3 commit bd3b61b
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion ext/EnzymeCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ end
function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cudaconvert)},
::Type{RT}, x::IT) where {RT, IT}
if RT <: Duplicated
RT(ofn.val(x.val), ofn.val(x.dval))
Duplicated(ofn.val(x.val), ofn.val(x.dval))
elseif RT <: Const
ofn.val(x.val)::eltype(RT)
elseif RT <: DuplicatedNoNeed
Expand All @@ -73,6 +73,33 @@ function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cudaconvert)},
end
end

function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{typeof(cudaconvert)}, ::Type{RT}, x::IT) where {RT, IT}
primal = if EnzymeRules.needs_primal(config)
ofn.val(x.val)
else
nothing
end

shadow = if EnzymeRules.needs_shadow(config)
if EnzymeRules.width(config) == 1
ofn.val(x.dval)
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
ofn.val(x.dval[i])
end
end
else
nothing
end
return EnzymeRules.AugmentedReturn{(EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing), Nothing}(primal, shadow, nothing)
end

function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{typeof(cudaconvert)}, ::Type{RT}, tape, x::IT) where {RT, IT}
(nothing,)
end


function EnzymeCore.EnzymeRules.forward(ofn::Const{Type{CT}},
::Type{RT}, uval::EnzymeCore.Annotation{UndefInitializer}, args...) where {CT <: CuArray, RT}
primargs = ntuple(Val(length(args))) do i
Expand Down

0 comments on commit bd3b61b

Please sign in to comment.