From 146657f2665fe9745e352801952d90675c7917de Mon Sep 17 00:00:00 2001 From: Andrei Zhabinski Date: Thu, 3 Dec 2020 23:04:51 +0300 Subject: [PATCH] Fix derivative for sum() with keywords --- Project.toml | 4 ++-- src/diffrules/basic.jl | 2 ++ test/test_grad.jl | 2 ++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 8bca4da..a481b04 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Yota" uuid = "cd998857-8626-517d-b929-70ad188a48f0" authors = ["Andrei Zhabinski "] -version = "0.4.1" +version = "0.4.2" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" @@ -17,7 +17,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -CUDA = "1.2.1" +CUDA = "1.2, 2.3" Cassette = "0.2.6, 0.3" ChainRulesCore = "0.9.5" Distributions = "0.23.2" diff --git a/src/diffrules/basic.jl b/src/diffrules/basic.jl index 00d6af5..2f9839b 100644 --- a/src/diffrules/basic.jl +++ b/src/diffrules/basic.jl @@ -217,6 +217,8 @@ end @diffrule Base._sum(u::AbstractArray, v::Int) u sum_grad(u, dy) @diffrule Base._sum(u::AbstractArray, v::Int) v zero(eltype(u)) @diffrule Core.kwfunc(sum)(_dims, _, u::AbstractArray) u sum_grad(u, dy) +@nodiff Core.kwfunc(sum)(_dims, _, u::AbstractArray) _dims +@nodiff Core.kwfunc(sum)(_dims, _, u::AbstractArray) _ # special sums @diffrule sum(_fn::typeof(log), u::AbstractArray) u sum_grad(u, dy) ./ u diff --git a/test/test_grad.jl b/test/test_grad.jl index 2a1042d..fcc0a84 100644 --- a/test/test_grad.jl +++ b/test/test_grad.jl @@ -13,6 +13,8 @@ loss_kw_mean(W, b, x) = Statistics.mean(W * x .+ b; dims=1)[1] val, g = grad(loss_kw_mean, args...) @test val == loss_kw_mean(args...) @test gradcheck(loss_kw_mean, args...) + + @test gradcheck(x -> sum(sum(x, dims=1)), rand(2, 3)) end