Skip to content

Commit

Permalink
Fix derivative for sum() with keywords
Browse files Browse the repository at this point in the history
  • Loading branch information
dfdx committed Dec 3, 2020
1 parent ec48779 commit 146657f
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Yota"
uuid = "cd998857-8626-517d-b929-70ad188a48f0"
authors = ["Andrei Zhabinski <andrei.zhabinski@gmail.com>"]
version = "0.4.1"
version = "0.4.2"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand All @@ -17,7 +17,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
CUDA = "1.2.1"
CUDA = "1.2, 2.3"
Cassette = "0.2.6, 0.3"
ChainRulesCore = "0.9.5"
Distributions = "0.23.2"
Expand Down
2 changes: 2 additions & 0 deletions src/diffrules/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ end
@diffrule Base._sum(u::AbstractArray, v::Int) u sum_grad(u, dy)
@diffrule Base._sum(u::AbstractArray, v::Int) v zero(eltype(u))
@diffrule Core.kwfunc(sum)(_dims, _, u::AbstractArray) u sum_grad(u, dy)
@nodiff Core.kwfunc(sum)(_dims, _, u::AbstractArray) _dims
@nodiff Core.kwfunc(sum)(_dims, _, u::AbstractArray) _

# special sums
@diffrule sum(_fn::typeof(log), u::AbstractArray) u sum_grad(u, dy) ./ u
Expand Down
2 changes: 2 additions & 0 deletions test/test_grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ loss_kw_mean(W, b, x) = Statistics.mean(W * x .+ b; dims=1)[1]
val, g = grad(loss_kw_mean, args...)
@test val == loss_kw_mean(args...)
@test gradcheck(loss_kw_mean, args...)

@test gradcheck(x -> sum(sum(x, dims=1)), rand(2, 3))
end


Expand Down

0 comments on commit 146657f

Please sign in to comment.