diff --git a/lib/mps/linalg.jl b/lib/mps/linalg.jl index 091b84d95..a140dfa74 100644 --- a/lib/mps/linalg.jl +++ b/lib/mps/linalg.jl @@ -197,17 +197,16 @@ end function LinearAlgebra.ldiv!(A::LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} - orig = size(B) - M,N = size(B)[1], ndims(B) > 1 ? size(B)[2] : 1 + M,N = size(B,1), size(B,2) dev = current_device() queue = global_queue(dev) - B = reshape(B, (N,M)) + Bt = reshape(B, (N,M)) P = reshape((A.ipiv .- UInt32(1)), (1,M)) X = similar(B) mps_a = MPSMatrix(A.factors) - mps_b = MPSMatrix(B) + mps_b = MPSMatrix(Bt) mps_p = MPSMatrix(P) mps_x = MPSMatrix(X) @@ -216,86 +215,98 @@ function LinearAlgebra.ldiv!(A::LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}}, B::M encode!(cmdbuf, kernel, mps_a, mps_b, mps_p, mps_x) end - B .= X - B = reshape(B, orig) + Bt .= X + return B end -function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T} - M,N = size(B) + +function LinearAlgebra.ldiv!(A::UpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + M,N = size(B,1), size(B,2) dev = current_device() queue = global_queue(dev) - cmdbuf = MTLCommandBuffer(queue) - enqueue!(cmdbuf) - Bh = reshape(B, ) - X = MtlMatrix{T}(undef, size(B)) + Ad = MtlMatrix(A; storage=Private) + Bt = reshape(B, (N,M)) + X = similar(B) - mps_a = MPSMatrix(A) - mps_b = MPSMatrix(Bh) # TODO reshape to matrix if B is a vector + mps_a = MPSMatrix(Ad) + mps_b = MPSMatrix(Bt) mps_x = MPSMatrix(X) - solve_kernel = MPSMatrixSolveTriangular(dev, false, false, false, true, M, N, 1.0) - encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x) - commit!(cmdbuf) + MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveTriangular(dev, false, false, false, false, M, N, 1.0) + encode!(cmdbuf, kernel, mps_a, mps_b, mps_x) + end - return X + Bt .= X + return B end -function LinearAlgebra.ldiv!(A::LowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T} - M,N = size(B) + +function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + M,N = size(B,1), size(B,2) dev = current_device() queue = global_queue(dev) - cmdbuf = MTLCommandBuffer(queue) - enqueue!(cmdbuf) - X = MtlMatrix{T}(undef, size(B)) + Ad = MtlMatrix(A; storage=Private) + Bt = reshape(B, (N,M)) + X = similar(B) - mps_a = MPSMatrix(A) - mps_b = MPSMatrix(B) # TODO reshape to matrix if B is a vector + mps_a = MPSMatrix(Ad) + mps_b = MPSMatrix(Bt) mps_x = MPSMatrix(X) - solve_kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, M, N, 1.0) - encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x) - commit!(cmdbuf) + MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveTriangular(dev, false, false, false, true, M, N, 1.0) + encode!(cmdbuf, kernel, mps_a, mps_b, mps_x) + end - return X + Bt .= X + return B end -function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T} - M,N = size(B) + +function LinearAlgebra.ldiv!(A::LowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + M,N = size(B,1), size(B,2) dev = current_device() queue = global_queue(dev) - cmdbuf = MTLCommandBuffer(queue) - enqueue!(cmdbuf) - X = MtlMatrix{T}(undef, size(B)) + Ad = MtlMatrix(A; storage=Private) + Bt = reshape(B, (N,M)) + X = similar(B) - mps_a = MPSMatrix(A) - mps_b = MPSMatrix(B) # TODO reshape to matrix if B is a vector + mps_a = MPSMatrix(Ad) + mps_b = MPSMatrix(Bt) mps_x = MPSMatrix(X) - solve_kernel = MPSMatrixSolveTriangular(dev, false, true, false, true, M, N, 1.0) - encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x) - commit!(cmdbuf) + MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, M, N, 1.0) + encode!(cmdbuf, kernel, mps_a, mps_b, mps_x) + end - return X + Bt .= X + return B end -# function (\)(A::AbstractMatrix, B::AbstractVecOrMat) -# require_one_based_indexing(A, B) -# m, n = size(A) -# if m == n -# if istril(A) -# if istriu(A) -# return Diagonal(A) \ B -# else -# return LowerTriangular(A) \ B -# end -# end -# if istriu(A) -# return UpperTriangular(A) \ B -# end -# return lu(A) \ B -# end -# return qr(A, ColumnNorm()) \ B -# end \ No newline at end of file + +function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + M,N = size(B,1), size(B,2) + dev = current_device() + queue = global_queue(dev) + + A = MtlMatrix(A; storage=Private) + Bt = reshape(B, (N,M)) + X = similar(B) + + mps_a = MPSMatrix(A) + mps_b = MPSMatrix(Bt) + mps_x = MPSMatrix(X) + + MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveTriangular(dev, false, true, false, true, M, N, 1.0) + encode!(cmdbuf, kernel, mps_a, mps_b, mps_x) + end + + Bt .= X + return B +end \ No newline at end of file