From bd3b61be6a30b181607525e0e95b5cc4aa4ad381 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 26 Aug 2024 01:30:50 -0500 Subject: [PATCH] Enzyme: Reversemode cudaconvert (#2476) --- ext/EnzymeCoreExt.jl | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/ext/EnzymeCoreExt.jl b/ext/EnzymeCoreExt.jl index 2602e45b39..ea52c27ca3 100644 --- a/ext/EnzymeCoreExt.jl +++ b/ext/EnzymeCoreExt.jl @@ -55,7 +55,7 @@ end function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cudaconvert)}, ::Type{RT}, x::IT) where {RT, IT} if RT <: Duplicated - RT(ofn.val(x.val), ofn.val(x.dval)) + Duplicated(ofn.val(x.val), ofn.val(x.dval)) elseif RT <: Const ofn.val(x.val)::eltype(RT) elseif RT <: DuplicatedNoNeed @@ -73,6 +73,33 @@ function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cudaconvert)}, end end +function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{typeof(cudaconvert)}, ::Type{RT}, x::IT) where {RT, IT} + primal = if EnzymeRules.needs_primal(config) + ofn.val(x.val) + else + nothing + end + + shadow = if EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + ofn.val(x.dval) + else + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + ofn.val(x.dval[i]) + end + end + else + nothing + end + return EnzymeRules.AugmentedReturn{(EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing), Nothing}(primal, shadow, nothing) +end + +function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{typeof(cudaconvert)}, ::Type{RT}, tape, x::IT) where {RT, IT} + (nothing,) +end + + function EnzymeCore.EnzymeRules.forward(ofn::Const{Type{CT}}, ::Type{RT}, uval::EnzymeCore.Annotation{UndefInitializer}, args...) where {CT <: CuArray, RT} primargs = ntuple(Val(length(args))) do i