From 43d30c60fd367c595570c85bc5326b802143f7e2 Mon Sep 17 00:00:00 2001 From: Ben Murrell Date: Thu, 12 Dec 2024 12:52:54 +0100 Subject: [PATCH] Attempting gradient types fix --- src/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index a08ec1d..7bdea5e 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -691,7 +691,7 @@ function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T swapped = false original_size = size(x) x = reshape(x, size(x,1), nonfirstdims(x)) - dx = reshape(dx, size(dx,1), nonfirstdims(dx)) + dx = reshape(dx, size(x,1), nonfirstdims(dx)) first_dim, second_dim = size(x,1), size(x,2) if o.sort_dims && second_dim < first_dim