From e245d50a1ae56ce46fc8c1f0fe9b925964f1146e Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Wed, 18 Sep 2024 11:11:35 -0400 Subject: [PATCH] =?UTF-8?q?Fix=20computing=20of=20array=20element=20type?= =?UTF-8?q?=20for=20`=E2=88=87eachslice`=20(#808)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/format.yml | 5 +++++ Project.toml | 4 ++-- src/rulesets/Base/indexing.jl | 5 ++--- test/rulesets/Base/indexing.jl | 36 +++++++++++++++++++++++++++------- 4 files changed, 38 insertions(+), 12 deletions(-) diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index f80377a24..e8334abca 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -12,6 +12,10 @@ concurrency: jobs: format: runs-on: ubuntu-latest + permissions: + contents: read + checks: write + pull-requests: write steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@latest @@ -22,6 +26,7 @@ jobs: julia -e 'using JuliaFormatter; format("."; verbose=true)' - uses: reviewdog/action-suggester@v1 with: + github_token: ${{ secrets.GITHUB_TOKEN }} tool_name: JuliaFormatter fail_on_error: true filter_mode: added diff --git a/Project.toml b/Project.toml index a679c914d..5c3a5607e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.70.0" +version = "1.71.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -20,7 +20,7 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [compat] Adapt = "3.4.0, 4" -ChainRulesCore = "1.20" +ChainRulesCore = "1.25" ChainRulesTestUtils = "1.5" Compat = "3.46, 4.2" Distributed = "1" diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 830571ecd..61216bda2 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -267,7 +267,7 @@ function ∇eachslice(dys_raw, x::AbstractArray, vd::Val{dim}) where {dim} if i1 === nothing # all slices are Zero! return _zero_fill!(similar(x, float(eltype(x)), axes(x))) end - T = promote_type(eltype(dys[i1]), eltype(x)) + T = Base.promote_eltype(dys...) # The whole point of this gradient is that we can allocate one `dx` array: dx = similar(x, T, axes(x)) for i in axes(x, dim) @@ -282,8 +282,7 @@ function ∇eachslice(dys_raw, x::AbstractArray, vd::Val{dim}) where {dim} end ∇eachslice(dys::AbstractZero, x::AbstractArray, vd::Val{dim}) where {dim} = dys -_zero_fill!(dx::AbstractArray{<:Number}) = fill!(dx, zero(eltype(dx))) -_zero_fill!(dx::AbstractArray) = map!(zero, dx, dx) +_zero_fill!(dx::AbstractArray) = fill!(dx, zero(eltype(dx))) function rrule(::typeof(∇eachslice), dys, x, vd::Val) function ∇∇eachslice(dz_raw) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index e878dd061..f80a37048 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -217,16 +217,24 @@ end # DimensionMismatch("second dimension of A, 6, does not match length of x, 5") # Probably similar to https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/234 (about Broadcasted not Generator) - test_rrule(collect∘eachrow, rand(5)) - test_rrule(collect∘eachrow, rand(3, 4)) + # Inference on 1.6 sometimes fails, so don't enforce there. + test_rrule(collect ∘ eachrow, rand(5); check_inferred=(VERSION >= v"1.7")) + test_rrule(collect ∘ eachrow, rand(3, 4); check_inferred=(VERSION >= v"1.7")) - test_rrule(collect∘eachcol, rand(3, 4)) - @test_skip test_rrule(collect∘eachcol, Diagonal(rand(5))) # works locally! + test_rrule(collect ∘ eachcol, rand(3, 4); check_inferred=(VERSION >= v"1.7")) + @test_skip test_rrule(collect ∘ eachcol, Diagonal(rand(5))) # works locally! if VERSION >= v"1.7" # On 1.6, ComposedFunction doesn't take keywords. Only affects this testing strategy, not real use. - test_rrule(collect∘eachslice, rand(3, 4, 5); fkwargs = (; dims = 3)) - test_rrule(collect∘eachslice, rand(3, 4, 5); fkwargs = (; dims = (2,))) + test_rrule(collect ∘ eachslice, rand(3, 4, 5); fkwargs=(; dims=3)) + test_rrule(collect ∘ eachslice, rand(3, 4, 5); fkwargs=(; dims=(2,))) + + test_rrule( + collect ∘ eachslice, + FooTwoField.(rand(3, 4, 5), rand(3, 4, 5)); + check_inferred=false, + fkwargs=(; dims=3), + ) end # Make sure pulling back an array that mixes some AbstractZeros in works right @@ -235,8 +243,22 @@ end @test back([1:3, ZeroTangent(), 7:9, NoTangent()])[2] isa Matrix{Float64} @test back([ZeroTangent(), ZeroTangent(), NoTangent(), NoTangent()]) == (NoTangent(), [0 0 0 0; 0 0 0 0; 0 0 0 0]) + _, back = ChainRules.rrule( + eachslice, FooTwoField.(rand(2, 3, 2), rand(2, 3, 2)); dims=3 + ) + @test back([fill(Tangent{Any}(; x=0.0, y=1.0), 2, 3), fill(ZeroTangent(), 2, 3)]) == ( + NoTangent(), + cat(fill(Tangent{Any}(; x=0.0, y=1.0), 2, 3), fill(ZeroTangent(), 2, 3); dims=3), + ) + # Second derivative rule test_rrule(ChainRules.∇eachslice, [rand(4) for _ in 1:3], rand(3, 4), Val(1)) test_rrule(ChainRules.∇eachslice, [rand(3) for _ in 1:4], rand(3, 4), Val(2)) - test_rrule(ChainRules.∇eachslice, [rand(2, 3) for _ in 1:4], rand(2, 3, 4), Val(3), check_inferred=false) + test_rrule( + ChainRules.∇eachslice, + [rand(2, 3) for _ in 1:4], + rand(2, 3, 4), + Val(3); + check_inferred=(VERSION >= v"1.7"), + ) end