From b16b878e3ff88e087748fdfc49d3b79ba4bb1ffd Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 24 Aug 2022 20:58:37 -0400 Subject: [PATCH 1/5] take is_inplaceable_destination seriously --- src/ChainRulesCore.jl | 2 +- src/accumulation.jl | 44 +++++++++++++++++++++++++------------------ test/accumulation.jl | 25 +++++++++++++----------- 3 files changed, 41 insertions(+), 30 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index f9eaf59f6..b75d8eff5 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -12,7 +12,7 @@ export frule_via_ad, rrule_via_ad # definition helper macros export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented export ProjectTo, canonicalize, unthunk # tangent operations -export add!! # gradient accumulation operations +export add!!, is_inplaceable_destination # gradient accumulation operations export ignore_derivatives, @ignore_derivatives # tangents export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk diff --git a/src/accumulation.jl b/src/accumulation.jl index 6e186546e..cc6a4d051 100644 --- a/src/accumulation.jl +++ b/src/accumulation.jl @@ -38,30 +38,38 @@ end is_inplaceable_destination(x) -> Bool Returns true if `x` is suitable for for storing inplace accumulation of gradients. -For arrays this boils down `x .= y` if will work to mutate `x`, if `y` is an appropriate -tangent. -Wrapper array types do not need to overload this if they overload `Base.parent`, and are -`is_inplaceable_destination` if and only if their parent array is. -Other types should overload this, as it defaults to `false`. +For arrays this means `x .= y` will mutate `x`, if `y` is an appropriate tangent. + +Here "appropriate" means that both are real or both are complex, +and that for structured matrices like `x isa Diagonal`, `y` shares this structure. + +Wrapper array types should overload this function if they can be written into. +Before ChainRulesCore 1.16, it would guess `true` for most wrappers based on `parent`, +but this is not safe, e.g. it will lead to an error with ReadOnltArrays.jl. + +There must always be a correct non-mutating path, so in uncertain cases, +this function returns `false`. """ is_inplaceable_destination(::Any) = false -is_inplaceable_destination(::Array) = true + +is_inplaceable_destination(::DenseArray)= true +is_inplaceable_destination(::DenseArray{<:Integer}) = false + is_inplaceable_destination(::SparseVector) = true is_inplaceable_destination(::SparseMatrixCSC) = true -is_inplaceable_destination(::BitArray) = true -function is_inplaceable_destination(x::AbstractArray) - p = parent(x) - p === x && return false # no parent - # basically all wrapper types delegate `setindex!` to their `parent` after some - # processing and so are mutable if their `parent` is. - return is_inplaceable_destination(p) + +function is_inplaceable_destination(x::SubArray) + alpha = is_inplaceable_destination(parent(x)) + beta = x.indices isa Tuple{Vararg{ Union{Integer, Base.Slice, UnitRange}}} + return alpha && beta end -# Hermitian and Symmetric are too fussy to deal with right now -# https://github.com/JuliaLang/julia/issues/38056 -# TODO: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/236 -is_inplaceable_destination(::LinearAlgebra.Hermitian) = false -is_inplaceable_destination(::LinearAlgebra.Symmetric) = false +for T in [:PermutedDimsArray, :ReshapedArray] + @eval is_inplaceable_destination(x::Base.$T) = is_inplaceable_destination(parent(x)) +end +for T in [:Adjoint, :Transpose, :Diagonal, :UpperTriangular, :LowerTriangular] + @eval is_inplaceable_destination(x::LinearAlgebra.$T) = is_inplaceable_destination(parent(x)) +end function debug_add!(accumuland, t::InplaceableThunk) returned_value = t.add!(accumuland) diff --git a/test/accumulation.jl b/test/accumulation.jl index a796b5289..597105d32 100644 --- a/test/accumulation.jl +++ b/test/accumulation.jl @@ -2,23 +2,26 @@ @testset "is_inplaceable_destination" begin is_inplaceable_destination = ChainRulesCore.is_inplaceable_destination - @test is_inplaceable_destination([1, 2, 3, 4]) - @test !is_inplaceable_destination(1:4) + @test is_inplaceable_destination([1.0, 2.0, 3.0]) + @test !is_inplaceable_destination([1, 2, 3, 4]) # gradients cannot reliably be written into integer arrays + @test !is_inplaceable_destination(1:4.0) - @test is_inplaceable_destination(Diagonal([1, 2, 3, 4])) - @test !is_inplaceable_destination(Diagonal(1:4)) + @test is_inplaceable_destination(Diagonal([1.0, 2.0, 3.0])) + @test !is_inplaceable_destination(Diagonal(1:4.0)) - @test is_inplaceable_destination(view([1, 2, 3, 4], :, :)) - @test !is_inplaceable_destination(view(1:4, :, :)) + @test is_inplaceable_destination(view([1.0, 2.0, 3.0], :, :)) + @test is_inplaceable_destination(view([1.0 2.0; 3.0 4.0], :, 2)) + @test !is_inplaceable_destination(view(1:4.0, :, :)) + mat = view([1.0, 2.0, 3.0], :, fill(1, 10)) + @test !is_inplaceable_destination(mat) # The concern is that `mat .+= x` is unsafe on GPU / parallel. - @test is_inplaceable_destination(falses(4)) + @test !is_inplaceable_destination(falses(4)) # gradients can never be written into boolean @test is_inplaceable_destination(spzeros(4)) @test is_inplaceable_destination(spzeros(2, 2)) - @test !is_inplaceable_destination(1.3) - @test !is_inplaceable_destination(@SVector [1, 2, 3]) - @test !is_inplaceable_destination(Hermitian([1 2; 2 4])) - @test !is_inplaceable_destination(Symmetric([1 2; 2 4])) + @test !is_inplaceable_destination(1:3.0) + @test !is_inplaceable_destination(@SVector [1.0, 2.0, 3.0]) + @test !is_inplaceable_destination(Hermitian([1.0 2.0; 2.0 4.0])) end @testset "add!!" begin From c3bf4b088a666c0e0b71da5b182d270ed4121d8b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 24 Aug 2022 20:58:53 -0400 Subject: [PATCH 2/5] fix tests on 1.9 --- test/tangent_types/tangent.jl | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/tangent.jl index 004dd71cf..5970f07f4 100644 --- a/test/tangent_types/tangent.jl +++ b/test/tangent_types/tangent.jl @@ -109,11 +109,11 @@ end @test NoTangent() === @inferred Base.tail(ntang1) # TODO: uncomment this once https://github.com/JuliaLang/julia/issues/35516 - if VERSION >= v"1.8-" - @test haskey(Tangent{Tuple{Float64}}(2.0), 1) == true - else - @test_broken haskey(Tangent{Tuple{Float64}}(2.0), 1) == true - end + # if VERSION >= v"1.8-" + # @test haskey(Tangent{Tuple{Float64}}(2.0), 1) == true + # else + # @test_broken haskey(Tangent{Tuple{Float64}}(2.0), 1) == true + # end @test_broken hasproperty(Tangent{Tuple{Float64}}(2.0), 2) == false @test length(Tangent{Foo}(; x=2.5)) == 1 @@ -148,12 +148,16 @@ end cr = Tangent{Tuple{String,Int,Int}}("something", 2, 1) @test reverse(c) === cr - # can't reverse a named tuple or a dict - @test_throws MethodError reverse(Tangent{Foo}(; x=1.0, y=2.0)) + if VERSION < v"1.9-" + # can't reverse a named tuple or a dict + @test_throws MethodError reverse(Tangent{Foo}(; x=1.0, y=2.0)) - d = Dict(:x => 1, :y => 2.0) - cdict = Tangent{typeof(d),typeof(d)}(d) - @test_throws MethodError reverse(Tangent{Foo}()) + d = Dict(:x => 1, :y => 2.0) + cdict = Tangent{typeof(d),typeof(d)}(d) + @test_throws MethodError reverse(Tangent{Foo}()) + else + # These now work but do we care? + end end @testset "unset properties" begin From 54890b9072a0e271b4a505b87ba8583e181a3cb7 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 25 Aug 2022 08:04:58 -0400 Subject: [PATCH 3/5] Update src/accumulation.jl Co-authored-by: Miha Zgubic --- src/accumulation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/accumulation.jl b/src/accumulation.jl index cc6a4d051..140dcb440 100644 --- a/src/accumulation.jl +++ b/src/accumulation.jl @@ -45,7 +45,7 @@ and that for structured matrices like `x isa Diagonal`, `y` shares this structur Wrapper array types should overload this function if they can be written into. Before ChainRulesCore 1.16, it would guess `true` for most wrappers based on `parent`, -but this is not safe, e.g. it will lead to an error with ReadOnltArrays.jl. +but this is not safe, e.g. it will lead to an error with ReadOnlyArrays.jl. There must always be a correct non-mutating path, so in uncertain cases, this function returns `false`. From 3d4acc8e70fa8a16085f8efb5ae71289b1026fd0 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 26 Aug 2022 20:46:33 -0400 Subject: [PATCH 4/5] Apply suggestions from code review Co-authored-by: Frames Catherine White --- src/accumulation.jl | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/accumulation.jl b/src/accumulation.jl index 140dcb440..0d63c80a2 100644 --- a/src/accumulation.jl +++ b/src/accumulation.jl @@ -43,17 +43,18 @@ For arrays this means `x .= y` will mutate `x`, if `y` is an appropriate tangent Here "appropriate" means that both are real or both are complex, and that for structured matrices like `x isa Diagonal`, `y` shares this structure. -Wrapper array types should overload this function if they can be written into. -Before ChainRulesCore 1.16, it would guess `true` for most wrappers based on `parent`, -but this is not safe, e.g. it will lead to an error with ReadOnlyArrays.jl. +!!! note "history" + Wrapper array types should overload this function if they can be written into. + Before ChainRulesCore 1.16, it would guess `true` for most wrappers based on `parent`, + but this is not safe, e.g. it will lead to an error with ReadOnlyArrays.jl. There must always be a correct non-mutating path, so in uncertain cases, this function returns `false`. """ is_inplaceable_destination(::Any) = false -is_inplaceable_destination(::DenseArray)= true -is_inplaceable_destination(::DenseArray{<:Integer}) = false +is_inplaceable_destination(::Array)= true +is_inplaceable_destination(:: Array{<:Integer}) = false is_inplaceable_destination(::SparseVector) = true is_inplaceable_destination(::SparseMatrixCSC) = true @@ -70,6 +71,8 @@ end for T in [:Adjoint, :Transpose, :Diagonal, :UpperTriangular, :LowerTriangular] @eval is_inplaceable_destination(x::LinearAlgebra.$T) = is_inplaceable_destination(parent(x)) end +# Hermitian and Symmetric are too fussy to deal with right now +# https://github.com/JuliaLang/julia/issues/38056 function debug_add!(accumuland, t::InplaceableThunk) returned_value = t.add!(accumuland) From acefc193273e8de44644c8f433e2acac466e2e7f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 28 Aug 2022 02:17:48 -0400 Subject: [PATCH 5/5] three tweaks --- src/accumulation.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/accumulation.jl b/src/accumulation.jl index 0d63c80a2..dc4ccd3bf 100644 --- a/src/accumulation.jl +++ b/src/accumulation.jl @@ -40,7 +40,7 @@ end Returns true if `x` is suitable for for storing inplace accumulation of gradients. For arrays this means `x .= y` will mutate `x`, if `y` is an appropriate tangent. -Here "appropriate" means that both are real or both are complex, +Here "appropriate" means that `y` cannot be complex unless `x` is too, and that for structured matrices like `x isa Diagonal`, `y` shares this structure. !!! note "history" @@ -53,7 +53,7 @@ this function returns `false`. """ is_inplaceable_destination(::Any) = false -is_inplaceable_destination(::Array)= true +is_inplaceable_destination(::Array) = true is_inplaceable_destination(:: Array{<:Integer}) = false is_inplaceable_destination(::SparseVector) = true @@ -61,7 +61,7 @@ is_inplaceable_destination(::SparseMatrixCSC) = true function is_inplaceable_destination(x::SubArray) alpha = is_inplaceable_destination(parent(x)) - beta = x.indices isa Tuple{Vararg{ Union{Integer, Base.Slice, UnitRange}}} + beta = x.indices isa Tuple{Vararg{Union{Integer, Base.Slice, UnitRange}}} return alpha && beta end