Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add recursive map generalizing the make_zero mechanism #1852

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
14 changes: 9 additions & 5 deletions ext/EnzymeStaticArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,15 @@ end
end
end

@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:SArray}
return Base.zero(x)
end
@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:MArray}
return Base.zero(x)
# SArrays and MArrays don't need special treatment for `make_zero(!)` to work or be correct,
# but in case their dedicated `zero` and `fill!` methods are more efficient than
# `make_zero(!)`s generic recursion, we opt into treating them as leaves when they have
# isbits eltypes (non-isbits eltypes excluded as the dedicated `zero` and `fill!` methods
# don't support those).
@inline function Enzyme.EnzymeCore.isvectortype(
::Type{<:Union{SArray{S,T},MArray{S,T}}}
) where {S,T}
return isbitstype(T) && Enzyme.EnzymeCore.isscalartype(T)
end

end
93 changes: 84 additions & 9 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -503,28 +503,103 @@ function autodiff_thunk end
function autodiff_deferred_thunk end

"""
make_zero(prev::T, ::Val{copy_if_inactive}=Val(false))::T
make_zero(::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}=Val(false))::T

Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies
what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value.
Recursively make a copy of the value `prev` of type `T` in which all differentiable values
are set to zero. The argument `copy_if_inactive` specifies what to do if the type `T` or any
of its constituent parts is guaranteed to be inactive: use `prev`s instance (the default) or
make a copy.

Extending this method for custom types is rarely needed. For new types that shouldn't be
recursed into, such as a GPU array type, extending [`isvectortype`](@ref) is sufficient as
long as the type implements `Base.zero`.
"""
function make_zero end

"""
make_zero!(val::T, seen::IdSet{Any}=IdSet())::Nothing
make_zero!(val::T, seen::IdDict=IdDict())::Nothing

Recursively set a variable's differentiable fields to zero. Only applicable for types `T`
that are mutable or hold all differentiable values in mutable containers (e.g.,
`Tuple{Vector{Float64}}`).

Recursively set a variables differentiable fields to zero. Only applicable for mutable types `T`.
Extending this method for custom types is rarely needed. For new mutable types that
shouldn't be recursed into, such as a GPU array type, extending [`isvectortype`](@ref) is
sufficient as long as the type implements `Base.zero` and `Base.fill!`.
"""
function make_zero! end

"""
make_zero(prev::T)
isvectortype(::Type{T})::Bool

Trait defining types whose values should be considered leaf nodes when [`make_zero`](@ref)
and [`make_zero!`](@ref) recurse through an object.

By default, `isvectortype(T) == true` for `T` such that `isscalartype(T) == true` or
`T <: Union{Array{U},GenericMemory{_,U}}` where `isscalartype(U) == true`.

A new leaf type, such as example a GPU array type, may extend this as follows:

```julia
@inline EnzymeCore.isvectortype(::Type{<:NewArray{U}}) where {U} = EnzymeCore.isscalartype(U)
```

Such a type should implement `Base.zero` and, if mutable, `Base.fill!`. (If this is not
feasible, an alternative is to add methods `EnzymeCore.make_zero(arr::T)::T` and, if
mutable, `EnzymeCore.make_zero!(arr::T)::Nothing`; such methods will also be picked up by
recursive calls.)

Helper function to recursively make zero.
Such extensions are mostly relevant for the lowest-level of abstraction of memory at which
vector space operations like addition and scalar multiplication are supported, the
prototypical case being `Array`. Regular Julia structs with vector space-like semantics
should normally not extend `isvectorspace`; `make_zero(!)` will recurse into them and act
directly on their backing arrays, just like how Enzyme treats them when differentiating. For
example, structured matrix wrappers and sparse array types that are backed by `Array` should
not extend `isvectortype`.

If a vector type `T` is also non-differentiable, `isvectortype` takes precedence, that is,
`make_zero(!)` will attempt to zero its values rather than share/copy them (out-of-place) or
skip them (in-place). This is for performance reasons, but should almost never be relevant
for behavior, as the two traits should be mutually exclusive.

See also [`isscalartype`](@ref).
"""
@inline function make_zero(prev::T, ::Val{copy_if_inactive}=Val(false)) where {T, copy_if_inactive}
make_zero(Core.Typeof(prev), IdDict(), prev, Val(copy_if_inactive))
end
function isvectortype end

"""
isscalartype(::Type{T})::Bool

Trait defining a subset of [`isvectortype`](@ref) types that should not be considered
composite, such that even if the type is mutable, [`make_zero!`](@ref) will not try to zero
values of the type in-place. For example, `BigFloat` is a mutable type but does not support
in-place mutation through any Julia API; `isscalartype(BigFloat) == true` ensures that
`make_zero!` will not try to mutate `BigFloat` values.[^BigFloat]

By default, `isscalartype(T) == true` for `T <: AbstractFloat` and
`T <: Complex{<:AbstractFloat}`.

A hypothetical new real number type with Enzyme support should in most cases simply subtype
`AbstractFloat` and inherit the `isscalartype` trait that way. If this is not appropriate,
the function can be extended as follows:

```julia
@inline EnzymeCore.isscalartype(::Type{<:NewReal}) = true
@inline EnzymeCore.isscalartype(::Type{<:Complex{<:NewReal}}) = true
```

In either case, the type should implement `Base.zero`. (If this is not feasible, an
alternative is to add a method `EnzymeCore.make_zero(x::T)::T`; such a method will also be
picked up by recursive calls.)

See also [`isvectortype`](@ref).

[^BigFloat]: Enzyme does not support differentiating `BigFloat` as of this writing; it is
mentioned here only to illustrate that it would be inappropriate to use traits like
`ismutable` or `isbitstype` to choose between in-place and out-of-place zeroing,
demonstrating the need for a dedicated `isscalartype` trait.
"""
function isscalartype end

function tape_type end

Expand Down
4 changes: 2 additions & 2 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -451,10 +451,10 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
# compute the correct complex derivative in reverse mode by propagating the conjugate return values
# then subtracting twice the imaginary component to get the correct result

for (k, v) in seen
for (k, (v,)) in seen
Compiler.recursive_accumulate(k, v, refn_seed)
end
for (k, v) in seen2
for (k, (v,)) in seen2
Compiler.recursive_accumulate(k, v, imfn_seed)
end

Expand Down
3 changes: 1 addition & 2 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1200,8 +1200,6 @@ struct Tape{TapeTy,ShadowTy,ResT}
shadow_return::ShadowTy
end

include("make_zero.jl")

function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, f, tt, world)
funcspec = my_methodinstance(typeof(f), tt, world)
nested_codegen!(mode, mod, funcspec, world)
Expand Down Expand Up @@ -7813,6 +7811,7 @@ end
end

# Recursively return x + f(y), where y is active, otherwise x
include("recursive_map.jl")

@inline function recursive_add(
x::T,
Expand Down
Loading
Loading