Skip to content

Commit

Permalink
Attempting gradient types fix
Browse files Browse the repository at this point in the history
  • Loading branch information
murrellb authored Dec 12, 2024
1 parent e39add7 commit 43d30c6
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T
swapped = false
original_size = size(x)
x = reshape(x, size(x,1), nonfirstdims(x))
dx = reshape(dx, size(dx,1), nonfirstdims(dx))
dx = reshape(dx, size(x,1), nonfirstdims(dx))

first_dim, second_dim = size(x,1), size(x,2)
if o.sort_dims && second_dim < first_dim
Expand Down

0 comments on commit 43d30c6

Please sign in to comment.