Skip to content

Commit

Permalink
Update batched_utils.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOresten authored Sep 23, 2024
1 parent 1490f5e commit b90d62a
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/batched/batched_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ function batched_mul_T2(x::AbstractArray, y::AbstractArray)
return reshape(z, size(z, 1), size(z, 2), size(x)[3:end]...)
end

function batched_mul_large_small(A::AbstractArray, x::AbstractVecOrMat)
A= reshape(A, size(A, 1), size(A, 2), :)
y= batched_mul(A′, reshape(x, size(x, 1), size(x, 2)))
y = reshape(y′, size(A, 1), size(x, 2))
return y
function batched_mul_large_small(x::AbstractArray, y::AbstractVecOrMat)
x= reshape(x, size(x, 1), size(x, 2), :)
z= batched_mul(x′, reshape(y, size(y, 1), size(y, 2)))
z = reshape(z′, size(x, 1), size(y, 2), size(x)[3:end]...)
return z
end

# might need custom chain rule
Expand Down

0 comments on commit b90d62a

Please sign in to comment.