Simplest prod(x; dims) gradient #112
Conversation
Force-pushed from 340bcb0 to 6b1b80f.
Bump? Just cleaned up the mess that GitHub's merge tool sometimes creates with unicode.
With a cursory look, it looks about right to me. Hopping on a flight now, but in the meantime: bors try
try: Build succeeded
bors r+
@maleadt is bors down?
bors r+
https://app.bors.tech/
bors r+
🔒 Permission denied. Existing reviewers: click here to make CarloLucibello a reviewer
@MikeInnes @dhairyagandhi96 @maleadt could someone add me to bors' reviewers?
Could you try now?
bors r+
🔒 Permission denied. Existing reviewers: click here to make CarloLucibello a reviewer
no luck
bors delegate=CarloLucibello
✌️ CarloLucibello can now approve this pull request. To approve and merge a pull request, simply reply with bors r+.
Should do for now, will look at the dashboard in my morning.
bors r+
I think I am in Flux's list now but not in Zygote's.
That's what caught me off guard; bors-ng/bors-ng#517 suggests the fix too.
Build succeeded
523: fix prod with tuple arg r=CarloLucibello a=CarloLucibello

This fixes the following problem caused by a relaxation of the signature in #112:

```julia
julia> gradient(x -> prod((1,2,3)), 1)
ERROR: MethodError: no method matching prod(::Tuple{Int64,Int64,Int64}; dims=Colon())
Closest candidates are:
  prod(::Tuple{Any,Vararg{Any,N} where N}) at tuple.jl:385 got unsupported keyword argument "dims"
  prod(::Any) at reduce.jl:448 got unsupported keyword argument "dims"
  prod(::Any, ::StaticArrays.StaticArray{#s160,T,N} where N where #s160<:Tuple; dims) where T at /home/carlo/.julia/packages/StaticArrays/1g9bq/src/mapreduce.jl:234
  ...
Stacktrace:
 [1] #adjoint#3920 at /home/carlo/.julia/packages/Zygote/XCgv1/src/lib/array.jl:220 [inlined]
 [2] adjoint at ./none:0 [inlined]
 [3] _pullback at /home/carlo/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
 [4] #17 at ./REPL[18]:1 [inlined]
 [5] _pullback(::Zygote.Context, ::var"#17#18", ::Int64) at /home/carlo/.julia/packages/Zygote/XCgv1/src/compiler/interface2.jl:?
 [6] _pullback(::Function, ::Int64) at /home/carlo/.julia/packages/Zygote/XCgv1/src/compiler/interface.jl:31
 [7] pullback(::Function, ::Int64) at /home/carlo/.julia/packages/Zygote/XCgv1/src/compiler/interface.jl:37
 [8] gradient(::Function, ::Int64, ::Vararg{Int64,N} where N) at /home/carlo/.julia/packages/Zygote/XCgv1/src/compiler/interface.jl:46
 [9] top-level scope at REPL[18]:1
```

Co-authored-by: CarloLucibello <carlo.lucibello@gmail.com>
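For context, a hedged sketch of why restricting the signature fixes this (an assumption about the shape of the adjoint, not the actual Zygote source): an adjoint that accepts any `x` forwards `dims = :` to the primal call, but `prod(::Tuple)` takes no keyword arguments, so constraining the argument to `AbstractArray` keeps tuples on the generic differentiation path.

```julia
using Zygote
using Zygote: @adjoint

# Hedged sketch, not the library's actual definition: restricting the
# adjoint to arrays means `prod(::Tuple)` is never called with a `dims` keyword.
@adjoint function prod(x::AbstractArray; dims = :)
    y = prod(x; dims = dims)
    # Naive rule: the gradient at each entry is the product of the other
    # entries along the reduced dimension, i.e. y ./ x
    # (incorrect when `x` contains zeros, as discussed in this thread).
    return y, Δ -> (y ./ x .* Δ,)
end

gradient(x -> prod((1, 2, 3)), 1)   # no MethodError
```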
The current gradient for `prod(x; dims)` gives incorrect results; this PR fixes it (parallel to FluxML/Tracker.jl#1). This does not handle zeros in the array correctly -- see FluxML/Flux.jl#524 for attempts to do that. The `circshift(...)` operation deleted here was a correct (but slow) gradient for `prod(x)`, but is clearly independent of `dims`. The example above is almost the same as the one in the tests, which strangely passes without this PR. Perhaps something is wrong with `gradtest`?
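As a small, hedged check of the rule in question (the array values here are arbitrary, chosen only for illustration): for nonzero entries, the gradient of `prod(x; dims)` at each entry is the product of the other entries along the reduced dimension, which equals `prod(x; dims) ./ x`.

```julia
using Zygote

x = [2.0 3.0; 4.0 5.0]

# Reduce over dims = 1 and sum to get a scalar loss; each entry's gradient is
# the product of the remaining entries in its column.
g = gradient(x -> sum(prod(x; dims = 1)), x)[1]
g ≈ prod(x; dims = 1) ./ x      # true: [4.0 5.0; 2.0 3.0]

# With a zero entry, the `prod ./ x` form divides by zero; handling that case
# is what FluxML/Flux.jl#524 is about.
```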