Skip to content

Commit

Permalink
AccumThunk
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Sep 17, 2022
1 parent e4ff1ab commit 4d48df8
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export ProjectTo, canonicalize, unthunk # tangent operations
export add!!, is_inplaceable_destination # gradient accumulation operations
export ignore_derivatives, @ignore_derivatives
# tangents
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk, AccumThunk

include("compat.jl")
include("debug_mode.jl")
Expand Down
11 changes: 11 additions & 0 deletions src/accumulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ end

# Accumulate a `Thunk` into an array by unthunking it first and re-dispatching on the result.
function add!!(x::AbstractArray, y::Thunk)
    return add!!(x, unthunk(y))
end

# add!!(x::AbstractArray, y::AccumThunk) = add!!(x, unthunk(y)) # not sure! This may be less efficient than fallback

# Accumulate an `AccumThunk` into the array `x`.
# When `x` can be written to, the sum is stored in `x` and `x` itself is returned.
# Otherwise the roles are swapped and `x` is accumulated into the thunk's own contents.
function add!!(x::AbstractArray, y::AccumThunk)
    if is_inplaceable_destination(x)
        # Broadcasting over the thunk goes through `broadcastable`, which unthunks it.
        x .+= y
        return x
    end
    # We are free to mutate the other way...
    return add!!(y.value, x)
end

function add!!(x::AbstractArray{<:Any,N}, y::AbstractArray{<:Any,N}) where {N}
return if is_inplaceable_destination(x)
if !debug_mode()
Expand Down
46 changes: 45 additions & 1 deletion src/tangent_arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ Base.complex(::ZeroTangent, ::ZeroTangent) = ZeroTangent()
Base.complex(::ZeroTangent, i::Real) = complex(oftype(i, 0), i)
Base.complex(r::Real, ::ZeroTangent) = complex(r)

Base.:+(a::AbstractThunk, b::AbstractThunk) = unthunk(a) + unthunk(b)
Base.:*(a::AbstractThunk, b::AbstractThunk) = unthunk(a) * unthunk(b)
for T in (:Tangent, :Any)
@eval Base.:+(a::AbstractThunk, b::$T) = unthunk(a) + b
Expand Down Expand Up @@ -154,3 +153,48 @@ for T in (:Number,)
@eval Base.:*(s::$T, tangent::Tangent) = map(x -> s * x, tangent)
@eval Base.:*(tangent::Tangent, s::$T) = map(x -> x * s, tangent)
end

# Accumulation
# While many weird operations above may never be called, accumulation of gradients is one of
# the big sources of memory allocation in AD, and is the entire reason InplaceableThunks exist.
# Here we try to mark any array known to be safe-to-mutate by wrapping it with AccumThunk.

# Adding two thunks forces both of them. The sum is only marked as safe to mutate when it
# is a new object: `+` may hand back one of its operands unchanged (e.g. when the other one
# is an `AbstractZero`), and that operand may still be referenced elsewhere, so wrapping it
# would let later accumulation silently overwrite somebody else's array.
function Base.:+(a::AbstractThunk, b::AbstractThunk)
    x = unthunk(a)
    y = unthunk(b)
    total = x + y
    return (total === x || total === y) ? total : maybe_accumthunk(total)
end

# Try not to put this wrapper on non-arrays:
# only values which `is_inplaceable_destination` accepts are wrapped, all others pass through.
maybe_accumthunk(a) = is_inplaceable_destination(a) ? AccumThunk(a) : a

# thunk + array (either order): the sum is wrapped in `AccumThunk` so that further additions
# may accumulate into it. If `+` returned one of its operands unchanged (e.g. the thunk held
# an `AbstractZero`, so the "sum" is the caller's own array), that object is returned without
# the wrapper, because it is not known to be safe to mutate.
function Base.:+(a::AbstractThunk, b::AbstractArray)
    x = unthunk(a)
    total = x + b
    return (total === x || total === b) ? total : AccumThunk(total)
end
function Base.:+(a::AbstractArray, b::AbstractThunk)
    y = unthunk(b)
    total = a + y
    return (total === a || total === y) ? total : AccumThunk(total)
end

# AccumThunk + array (either order): accumulate the array into the thunk's contents with
# `add!!`, and wrap whatever `add!!` returns in a new `AccumThunk`.
function Base.:+(a::AccumThunk, b::AbstractArray)
    return AccumThunk(add!!(a.value, b))
end
function Base.:+(a::AbstractArray, b::AccumThunk)
    return AccumThunk(add!!(b.value, a))
end

# AccumThunk + any other thunk (either order): accumulate the other thunk into the
# `AccumThunk`'s contents; the wrapper is re-applied only if `maybe_accumthunk` allows it.
function Base.:+(a::AccumThunk, b::AbstractThunk)
    return maybe_accumthunk(add!!(a.value, b))
end
function Base.:+(a::AbstractThunk, b::AccumThunk)
    return maybe_accumthunk(add!!(b.value, a))
end

# AccumThunk + AccumThunk: accumulate into whichever side is an in-placeable destination,
# trying the left operand first. If neither is, fall back to an ordinary sum.
function Base.:+(a::AccumThunk, b::AccumThunk)
    left, right = a.value, b.value
    if is_inplaceable_destination(left)
        return AccumThunk(add!!(left, right))
    end
    if is_inplaceable_destination(right)
        return AccumThunk(add!!(right, left))
    end
    # no point keeping this type:
    return left + right
end


#=
# You could go further and assume any result of unthunk is safe to mutate,
# something like this:
# Base.:+(a::AbstractThunk, b::AbstractThunk) = maybe_accumthunk(add!!(unthunk(a), b))
Base.:+(a::InplaceableThunk, b::AbstractThunk) = AccumThunk(add!!(unthunk(b), a))
Base.:+(a::AbstractThunk, b::InplaceableThunk) = AccumThunk(add!!(unthunk(a), b))
Base.:+(a::InplaceableThunk, b::InplaceableThunk) = AccumThunk(add!!(unthunk(a), b))
Base.:+(a::AccumThunk, b::InplaceableThunk) = maybe_accumthunk(add!!(a.value, b))
Base.:+(a::InplaceableThunk, b::AccumThunk) = maybe_accumthunk(add!!(b.value, a))
=#
111 changes: 111 additions & 0 deletions src/tangent_types/thunks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@ function Base.showerror(io::IO, e::MutateThunkException)
return nothing
end

#####
##### Operations which un-thunk automatically
#####

# Note that if you use an object which might be thunked in two places,
# you should *always* call `unthunk` manually first, once, to avoid un-thunking twice.

# Maybe the docs should have a list of exactly what operations do un-thunk automatically...
# do we really need so many?

# Broadcasting over a thunk un-thunks it and broadcasts over the result.
function Base.Broadcast.broadcastable(x::AbstractThunk)
    return broadcastable(unthunk(x))
end

@inline function Base.iterate(x::AbstractThunk)
Expand Down Expand Up @@ -138,6 +148,11 @@ macro thunk(body)
func = Expr(:->, Expr(:tuple), Expr(:block, __source__, body))
return :(Thunk($(esc(func))))
end
# macro thunk(s::Symbol)
# @warn "Applying `@thunk` to a single symbol does nothing, as there is no calculation to defer."
# # But should it perhaps do something, if we also regard thunks as marking safe-to-mutate?
# return esc(s)
# end

"""
unthunk(x)
Expand All @@ -157,6 +172,7 @@ Base.transpose(x::AbstractThunk) = @thunk(transpose(unthunk(x)))

"""
Thunk(()->v)
A thunk is a deferred computation.
It wraps a zero argument closure that when invoked returns a tangent.
`@thunk(v)` is a macro that expands into `Thunk(()->v)`.
Expand Down Expand Up @@ -212,6 +228,10 @@ end

# Converting an `AbstractZero` to a `Thunk` wraps it with `@thunk`.
function Base.convert(::Type{<:Thunk}, a::AbstractZero)
    return @thunk(a)
end

#####
##### `InplaceableThunk`
#####

"""
InplaceableThunk(add!::Function, val::Thunk)
Expand Down Expand Up @@ -244,3 +264,94 @@ function Base.show(io::IO, x::InplaceableThunk)
show(io, x.val)
print(io, ")")
end


#####
##### `AccumThunk`
#####

"""
    AccumThunk(value) <: AbstractThunk

This isn't a delayed computation, but is instead a marker that its contents are known to be
safe to mutate during gradient accumulation. At present it is produced by adding two thunks,
allowing any further addition to keep mutating. Anything downstream which wants an array must
already know to `unthunk`, which is why this is `<: AbstractThunk`.

Ideally it would be produced by adding two Arrays too, but that's impossible in CR's design.
It might be good for many rules which produce a known-safe Array to wrap it in this.

If we may assume/demand that the result of `@thunk` is always a new array, too,
then more cases can mutate. And then it would make sense for `@thunk A` on one Symbol
to produce an `AccumThunk`, promoting `@thunk` to have two meanings. But this is not yet done.
"""
struct AccumThunk{T} <: AbstractThunk
value::T
end

# Un-thunking an `AccumThunk` returns the wrapped value itself; nothing is computed.
@inline function unthunk(x::AccumThunk)
    return x.value
end

# Print as `AccumThunk(<value>)`. The value is rendered with `io` as context; a rendering of
# 80 or more characters is cut to its first 70 characters followed by "...".
function Base.show(io::IO, x::AccumThunk)
    rendered = sprint(show, x.value; context=io)
    shown = length(rendered) < 80 ? rendered : string(first(rendered, 70), "...")
    print(io, "AccumThunk(", shown, ")")
    return nothing
end


#=
julia> using ChainRules, ChainRulesCore, Diffractor
julia> _getindex(x...) = getindex(x...); # use CR's rule:
julia> function ChainRules.rrule(::typeof(_getindex), x::AbstractArray, inds...)
function getindex_pullback(dy)
nots = map(Returns(NoTangent()), inds)
return (NoTangent(), ChainRules.thunked_∇getindex(x, dy, inds...), nots...)
end
return x[inds...], getindex_pullback
end
julia> Diffractor.gradient(x -> _getindex(x,1), [1,2,3.0]) # calls unthunk on final answer
([1.0, 0.0, 0.0],)
julia> @btime Diffractor.gradient(x -> _getindex(x,1), $(rand(128 * 100)));
min 1.012 μs, mean 11.103 μs (2 allocations, 100.05 KiB)
julia> @btime Diffractor.gradient(x -> _getindex(x,1)+_getindex(x,2), $(rand(128 * 100)));
min 7.625 μs, mean 46.941 μs (6 allocations, 300.14 KiB) # unthunk, unthunk, add -- unchanged
julia> @btime Diffractor.gradient(x -> _getindex(x,1)+_getindex(x,2)+_getindex(x,3), $(rand(128 * 100)));
min 16.791 μs, mean 67.720 μs (10 allocations, 500.23 KiB) # before
min 8.625 μs, mean 44.642 μs (6 allocations, 300.14 KiB) # after
min 1.036 μs, mean 12.684 μs (2 allocations, 100.05 KiB) # with stronger assumption, overwrite any thunk
# Same example as https://github.com/FluxML/Zygote.jl/pull/981#issuecomment-861079488
# originally https://github.com/FluxML/Zygote.jl/issues/644
julia> function _evalpoly(x, p)
N = length(p)
ex = _getindex(p, length(p))
for i in N-1:-1:1
ex = muladd(x, ex, _getindex(p, i))
end
ex
end
_evalpoly (generic function with 1 method)
julia> x, p = rand(), randn(10000);
julia> @btime _evalpoly(x, p);
min 20.375 μs, mean 20.553 μs (1 allocation, 16 bytes)
julia> @btime Diffractor.gradient(_evalpoly, x, p);
min 566.669 ms, mean 585.185 ms (1174329 allocations, 2.44 GiB) # before
min 376.376 ms, mean 384.314 ms (1144338 allocations, 975.62 MiB) # after
=#

30 changes: 28 additions & 2 deletions test/accumulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,10 @@
end
end

@testset "AbstractThunk $(typeof(thunk))" for thunk in (
@testset "add!!(array, $(typeof(thunk)))" for thunk in (
@thunk(-1.0 * ones(2, 2)),
InplaceableThunk(x -> x .-= ones(2, 2), @thunk(-1.0 * ones(2, 2))),
AccumThunk(-ones(2, 2))
)
@testset "in place" begin
accumuland = [1.0 2.0; 3.0 4.0]
Expand All @@ -101,14 +102,18 @@
@test ret === accumuland # must be same object
end

@test unthunk(thunk) == -ones(2, 2) # AccumThunk has not been mutated

@testset "out of place" begin
accumuland = @SMatrix [1.0 2.0; 3.0 4.0]

ret = add!!(accumuland, thunk)
@test ret == [0.0 1.0; 2.0 3.0] # must return right answer
@test ret !== accumuland # must not be same object
@test accumuland == [1.0 2.0; 3.0 4.0] # must not have mutated
@test accumuland == [1.0 2.0; 3.0 4.0] # cannot ever be mutated
end

unthunk(thunk) # AccumThunk may have been mutated, test has no opinion?
end

@testset "not actually inplace but said it was" begin
Expand Down Expand Up @@ -137,4 +142,25 @@
msg_equal = sprint(showerror, BadInplaceException(ithunk, [22], [22]))
@test occursin("equal", msg_equal)
end

@testset "thunk + thunk" begin
    # Adding thunks repeatedly should give the right sum, and mark it as safe to mutate.
    s1 = @thunk([1.0]) + @thunk([2.0]) + @thunk([3.0])
    @test unthunk(s1) == [6]
    @test s1 isa AccumThunk

    # Every pairing of array-valued summands. The accumulator of the `InplaceableThunk`
    # must really mutate its argument (`.+=`, not `.+`) to honour the `add!` contract.
    list = [
        [1.0],
        @thunk([1.0]),
        InplaceableThunk(x -> x .+= 1, @thunk([1.0])),
        AccumThunk([1.0]),
    ]
    for x in list, y in list
        # `+` may mutate the contents of an `AccumThunk`, so give each pairing fresh copies.
        z = deepcopy(x) + deepcopy(y)
        @test unthunk(z) == [2]
        @test z isa AccumThunk || (x isa Array && y isa Array)
    end

    # Scalar summands: there is nothing to mutate, so ideally no wrapper is applied.
    triv = [1.0, @thunk(1.0), AccumThunk(1.0)]
    for x in triv, y in triv
        z = x + y
        @test unthunk(z) === 2.0
        @test z isa Float64 || (x isa AccumThunk && y isa AccumThunk)
        # How much do we care about not applying these wrappers when not useful?
    end
end
end

0 comments on commit 4d48df8

Please sign in to comment.