Take is_inplaceable_destination seriously #577

Merged
merged 5 commits into from
Aug 28, 2022
Changes from 4 commits
2 changes: 1 addition & 1 deletion src/ChainRulesCore.jl
@@ -12,7 +12,7 @@ export frule_via_ad, rrule_via_ad
# definition helper macros
export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented
export ProjectTo, canonicalize, unthunk # tangent operations
export add!! # gradient accumulation operations
export add!!, is_inplaceable_destination # gradient accumulation operations
export ignore_derivatives, @ignore_derivatives
# tangents
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
43 changes: 27 additions & 16 deletions src/accumulation.jl
@@ -38,30 +38,41 @@ end
is_inplaceable_destination(x) -> Bool

Returns true if `x` is suitable for storing inplace accumulation of gradients.
For arrays this boils down to whether `x .= y` will work to mutate `x`, if `y` is an appropriate
tangent.
Wrapper array types do not need to overload this if they overload `Base.parent`, and are
`is_inplaceable_destination` if and only if their parent array is.
Other types should overload this, as it defaults to `false`.
For arrays this means `x .= y` will mutate `x`, if `y` is an appropriate tangent.

Here "appropriate" means that both are real or both are complex,
and that for structured matrices like `x isa Diagonal`, `y` shares this structure.

!!! note "history"
    Wrapper array types should overload this function if they can be written into.
    Before ChainRulesCore 1.16, it would guess `true` for most wrappers based on `parent`,
    but this is not safe, e.g. it will lead to an error with ReadOnlyArrays.jl.

There must always be a correct non-mutating path, so in uncertain cases,
this function returns `false`.
"""
is_inplaceable_destination(::Any) = false
is_inplaceable_destination(::Array) = true

is_inplaceable_destination(::Array)= true
is_inplaceable_destination(:: Array{<:Integer}) = false

is_inplaceable_destination(::SparseVector) = true
is_inplaceable_destination(::SparseMatrixCSC) = true
is_inplaceable_destination(::BitArray) = true
function is_inplaceable_destination(x::AbstractArray)
    p = parent(x)
    p === x && return false # no parent
    # basically all wrapper types delegate `setindex!` to their `parent` after some
    # processing and so are mutable if their `parent` is.
    return is_inplaceable_destination(p)

function is_inplaceable_destination(x::SubArray)
    alpha = is_inplaceable_destination(parent(x))
    beta = x.indices isa Tuple{Vararg{ Union{Integer, Base.Slice, UnitRange}}}
    return alpha && beta
Member

Is it clearer to do

Suggested change
alpha = is_inplaceable_destination(parent(x))
beta = x.indices isa Tuple{Vararg{ Union{Integer, Base.Slice, UnitRange}}}
return alpha && beta
is_inplaceable_destination(parent(x)) || return false
return x.indices isa Tuple{Vararg{Union{Integer,Base.Slice,UnitRange}}}

I will trust your judgement

Member Author

I had a few variants. Maybe I don't like that this seems to imply you ought to check one before the other, but in fact it's just dealing with line length.
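
To make the index condition under discussion concrete, here is a minimal sketch of which views it accepts. The helper `has_safe_indices` and the alias `SafeIndices` are hypothetical names introduced only for this illustration; the `isa` test itself is the one from the diff. It needs nothing beyond Base.

```julia
# Hypothetical names for illustration; only the `isa` test comes from the diff.
const SafeIndices = Tuple{Vararg{Union{Integer,Base.Slice,UnitRange}}}
has_safe_indices(x::SubArray) = x.indices isa SafeIndices

v = [1.0, 2.0, 3.0]
m = [1.0 2.0; 3.0 4.0]

col = view(m, :, 2)            # Colon (stored as Base.Slice) plus an Integer
rng = view(v, 2:3)             # UnitRange: contiguous, no repeated elements
rep = view(v, :, fill(1, 10))  # Vector{Int} index: the same column ten times
stp = view(v, 1:2:3)           # StepRange: no repeats, but not on the allowed list

for (name, x) in [("col", col), ("rng", rng), ("rep", rep), ("stp", stp)]
    println(name, " => ", has_safe_indices(x))
end
```

The check is conservative: `stp` has no repeated elements, yet it is rejected because `StepRange` is not in the union. That fits the docstring's rule of returning `false` in uncertain cases.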

end

for T in [:PermutedDimsArray, :ReshapedArray]
@eval is_inplaceable_destination(x::Base.$T) = is_inplaceable_destination(parent(x))
end
for T in [:Adjoint, :Transpose, :Diagonal, :UpperTriangular, :LowerTriangular]
    @eval is_inplaceable_destination(x::LinearAlgebra.$T) = is_inplaceable_destination(parent(x))
Member

@oxinabox, Aug 26, 2022

A bit sad that this won't catch NamedDimsArrays or KeyedArrays for free anymore.
But those need to overload ProjectTo anyway, so it's fine to make them do this.

The wrapper types in Base are most of what we need to deal with.

Member Author

Besides having something you have to opt into, the other approach I can think of would be to ask type inference. But... this doesn't seem like the package's style:

julia> Core.Compiler._return_type(setindex!, typeof(([1,2,3.0], 4.0, 2)))
Vector{Float64} (alias for Array{Float64, 1})

julia> Core.Compiler._return_type(setindex!, typeof((SA[1,2,3.0], 4.0, 2)))
Union{}

julia> Core.Compiler._return_type(setindex!, typeof((1:3.0, 4.0, 2)))
Union{}

julia> which(setindex!, typeof((SA[1,2,3.0], 4.0, 2)))  # their own error
setindex!(a::StaticArray, value, i::Int64)
     @ StaticArrays ~/.julia/packages/StaticArrays/6QFsp/src/indexing.jl:3

julia> which(setindex!, typeof((1:3.0, 4.0, 2)))  # not useful
setindex!(A::AbstractArray, v, I...)
     @ Base abstractarray.jl:1374
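
As a rough sketch of the opt-in route discussed in this thread, a package with its own wrapper array could forward the query to its parent. Everything named here is hypothetical: `LabelledVector` is a made-up wrapper, and `inplaceable` is a local stand-in for `ChainRulesCore.is_inplaceable_destination` so that the example runs without the package. In real code the last method would presumably be added to `ChainRulesCore.is_inplaceable_destination` instead.

```julia
# Local stand-in mirroring three methods from the diff (an assumption for
# illustration, not the package's function).
inplaceable(::Any) = false
inplaceable(::Array) = true
inplaceable(::Array{<:Integer}) = false

# A made-up wrapper type that stores labels alongside a parent vector.
struct LabelledVector{T,A<:AbstractVector{T}} <: AbstractVector{T}
    data::A
    labels::Vector{Symbol}
end
Base.parent(x::LabelledVector) = x.data
Base.size(x::LabelledVector) = size(x.data)
Base.getindex(x::LabelledVector, i::Int) = x.data[i]
Base.setindex!(x::LabelledVector, v, i::Int) = (x.data[i] = v; x)

# Without this method the wrapper falls through to `inplaceable(::Any)`,
# i.e. `false`. The explicit opt-in delegates the decision to the parent.
inplaceable(x::LabelledVector) = inplaceable(parent(x))

println(inplaceable(LabelledVector([1.0, 2.0], [:a, :b])))  # parent is Vector{Float64}
println(inplaceable(LabelledVector(1.0:2.0, [:a, :b])))     # parent is an immutable range
println(inplaceable(LabelledVector([1, 2], [:a, :b])))      # parent is Vector{Int}
```

The design trade-off is the one raised above: wrappers outside Base no longer get a `true` for free, but a read-only wrapper can no longer be misreported as writable either.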

end
# Hermitian and Symmetric are too fussy to deal with right now
# https://github.com/JuliaLang/julia/issues/38056
# TODO: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/236
is_inplaceable_destination(::LinearAlgebra.Hermitian) = false
is_inplaceable_destination(::LinearAlgebra.Symmetric) = false

function debug_add!(accumuland, t::InplaceableThunk)
returned_value = t.add!(accumuland)
25 changes: 14 additions & 11 deletions test/accumulation.jl
@@ -2,23 +2,26 @@
@testset "is_inplaceable_destination" begin
is_inplaceable_destination = ChainRulesCore.is_inplaceable_destination

@test is_inplaceable_destination([1, 2, 3, 4])
@test !is_inplaceable_destination(1:4)
@test is_inplaceable_destination([1.0, 2.0, 3.0])
@test !is_inplaceable_destination([1, 2, 3, 4]) # gradients cannot reliably be written into integer arrays
@test !is_inplaceable_destination(1:4.0)

@test is_inplaceable_destination(Diagonal([1, 2, 3, 4]))
@test !is_inplaceable_destination(Diagonal(1:4))
@test is_inplaceable_destination(Diagonal([1.0, 2.0, 3.0]))
@test !is_inplaceable_destination(Diagonal(1:4.0))

@test is_inplaceable_destination(view([1, 2, 3, 4], :, :))
@test !is_inplaceable_destination(view(1:4, :, :))
@test is_inplaceable_destination(view([1.0, 2.0, 3.0], :, :))
@test is_inplaceable_destination(view([1.0 2.0; 3.0 4.0], :, 2))
@test !is_inplaceable_destination(view(1:4.0, :, :))
mat = view([1.0, 2.0, 3.0], :, fill(1, 10))
@test !is_inplaceable_destination(mat) # The concern is that `mat .+= x` is unsafe on GPU / parallel.

@test is_inplaceable_destination(falses(4))
@test !is_inplaceable_destination(falses(4)) # gradients can never be written into boolean
@test is_inplaceable_destination(spzeros(4))
@test is_inplaceable_destination(spzeros(2, 2))

@test !is_inplaceable_destination(1.3)
@test !is_inplaceable_destination(@SVector [1, 2, 3])
@test !is_inplaceable_destination(Hermitian([1 2; 2 4]))
@test !is_inplaceable_destination(Symmetric([1 2; 2 4]))
@test !is_inplaceable_destination(1:3.0)
@test !is_inplaceable_destination(@SVector [1.0, 2.0, 3.0])
@test !is_inplaceable_destination(Hermitian([1.0 2.0; 2.0 4.0]))
end

@testset "add!!" begin
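
The comments in the updated tests say that gradients cannot reliably be written into integer or boolean arrays. A minimal illustration in plain Julia, not taken from the PR, of what goes wrong when a floating-point tangent is broadcast into such arrays (`try_accumulate!` is a hypothetical helper):

```julia
# Illustration only: plain Base Julia, no ChainRulesCore involved.
# Returns the exception thrown by `dest .+= delta`, or `nothing` on success.
function try_accumulate!(dest, delta)
    try
        dest .+= delta
        return nothing
    catch err
        return err
    end
end

floats = [1.0, 2.0, 3.0]
ints = [1, 2, 3]
bools = falses(3)
grad = [0.5, 0.5, 0.5]   # a typical non-integer gradient

println(try_accumulate!(floats, grad))         # accumulation succeeds
println(typeof(try_accumulate!(ints, grad)))   # 1.5 cannot be stored as an Int
println(typeof(try_accumulate!(bools, grad)))  # 0.5 cannot be stored as a Bool
```

Because a correct non-mutating path always exists, returning `false` for these element types costs an allocation but avoids the conversion error.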
24 changes: 14 additions & 10 deletions test/tangent_types/tangent.jl
@@ -109,11 +109,11 @@ end
@test NoTangent() === @inferred Base.tail(ntang1)

# TODO: uncomment this once https://github.com/JuliaLang/julia/issues/35516
if VERSION >= v"1.8-"
@test haskey(Tangent{Tuple{Float64}}(2.0), 1) == true
else
@test_broken haskey(Tangent{Tuple{Float64}}(2.0), 1) == true
end
# if VERSION >= v"1.8-"
# @test haskey(Tangent{Tuple{Float64}}(2.0), 1) == true
# else
# @test_broken haskey(Tangent{Tuple{Float64}}(2.0), 1) == true
# end
@test_broken hasproperty(Tangent{Tuple{Float64}}(2.0), 2) == false
Comment on lines 111 to 117
Member Author

This is unrelated, but was needed to make tests pass on 1.9. I don't know why it says "uncomment" but now that things are commented, it may make sense.


@test length(Tangent{Foo}(; x=2.5)) == 1
@@ -148,12 +148,16 @@ end
cr = Tangent{Tuple{String,Int,Int}}("something", 2, 1)
@test reverse(c) === cr

# can't reverse a named tuple or a dict
@test_throws MethodError reverse(Tangent{Foo}(; x=1.0, y=2.0))
if VERSION < v"1.9-"
# can't reverse a named tuple or a dict
@test_throws MethodError reverse(Tangent{Foo}(; x=1.0, y=2.0))
Member Author

Likewise needed for 1.9

This tests that you can't reverse a Tangent, but do we care that you can't? We know that it's meaningless, like the order of fields in a struct. But if it used to be an error and now works, why is that a failure?


d = Dict(:x => 1, :y => 2.0)
cdict = Tangent{typeof(d),typeof(d)}(d)
@test_throws MethodError reverse(Tangent{Foo}())
d = Dict(:x => 1, :y => 2.0)
cdict = Tangent{typeof(d),typeof(d)}(d)
@test_throws MethodError reverse(Tangent{Foo}())
else
# These now work but do we care?
end
end

@testset "unset properties" begin