Skip to content

Commit

Permalink
unthunk in multigate rrule
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott authored Sep 4, 2022
1 parent 3e12946 commit e4c650f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ multigate(x::AbstractArray, h, ::Val{N}) where N = ntuple(n -> gate(x,h,n), N)
function ChainRulesCore.rrule(::typeof(multigate), x::AbstractArray, h, c)
function multigate_pullback(dy)
dx = map!(zero, similar(x, float(eltype(x)), axes(x)), x)
foreach(multigate(dx, h, c), dy) do dxᵢ, dyᵢ
foreach(multigate(dx, h, c), unthunk(dy)) do dxᵢ, dyᵢ
dyᵢ isa AbstractZero && return
@. dxᵢ += dyᵢ
end
Expand Down

0 comments on commit e4c650f

Please sign in to comment.