Skip to content

Commit

Permalink
fix other solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich committed Mar 31, 2023
1 parent ef68e4d commit f5d932a
Showing 1 changed file with 69 additions and 58 deletions.
127 changes: 69 additions & 58 deletions lib/mps/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,17 +197,16 @@ end


function LinearAlgebra.ldiv!(A::LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
orig = size(B)
M,N = size(B)[1], ndims(B) > 1 ? size(B)[2] : 1
M,N = size(B,1), size(B,2)
dev = current_device()
queue = global_queue(dev)

B = reshape(B, (N,M))
Bt = reshape(B, (N,M))
P = reshape((A.ipiv .- UInt32(1)), (1,M))
X = similar(B)

mps_a = MPSMatrix(A.factors)
mps_b = MPSMatrix(B)
mps_b = MPSMatrix(Bt)
mps_p = MPSMatrix(P)
mps_x = MPSMatrix(X)

Expand All @@ -216,86 +215,98 @@ function LinearAlgebra.ldiv!(A::LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}}, B::M
encode!(cmdbuf, kernel, mps_a, mps_b, mps_p, mps_x)
end

B .= X
B = reshape(B, orig)
Bt .= X
return B
end

function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T}
M,N = size(B)

function LinearAlgebra.ldiv!(A::UpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
M,N = size(B,1), size(B,2)
dev = current_device()
queue = global_queue(dev)
cmdbuf = MTLCommandBuffer(queue)
enqueue!(cmdbuf)

Bh = reshape(B, )
X = MtlMatrix{T}(undef, size(B))
Ad = MtlMatrix(A; storage=Private)
Bt = reshape(B, (N,M))
X = similar(B)

mps_a = MPSMatrix(A)
mps_b = MPSMatrix(Bh) # TODO reshape to matrix if B is a vector
mps_a = MPSMatrix(Ad)
mps_b = MPSMatrix(Bt)
mps_x = MPSMatrix(X)

solve_kernel = MPSMatrixSolveTriangular(dev, false, false, false, true, M, N, 1.0)
encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
commit!(cmdbuf)
MTLCommandBuffer(queue) do cmdbuf
kernel = MPSMatrixSolveTriangular(dev, false, false, false, false, M, N, 1.0)
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
end

return X
Bt .= X
return B
end

function LinearAlgebra.ldiv!(A::LowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T}
M,N = size(B)

function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
M,N = size(B,1), size(B,2)
dev = current_device()
queue = global_queue(dev)
cmdbuf = MTLCommandBuffer(queue)
enqueue!(cmdbuf)

X = MtlMatrix{T}(undef, size(B))
Ad = MtlMatrix(A; storage=Private)
Bt = reshape(B, (N,M))
X = similar(B)

mps_a = MPSMatrix(A)
mps_b = MPSMatrix(B) # TODO reshape to matrix if B is a vector
mps_a = MPSMatrix(Ad)
mps_b = MPSMatrix(Bt)
mps_x = MPSMatrix(X)

solve_kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, M, N, 1.0)
encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
commit!(cmdbuf)
MTLCommandBuffer(queue) do cmdbuf
kernel = MPSMatrixSolveTriangular(dev, false, false, false, true, M, N, 1.0)
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
end

return X
Bt .= X
return B
end

function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T}
M,N = size(B)

function LinearAlgebra.ldiv!(A::LowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
M,N = size(B,1), size(B,2)
dev = current_device()
queue = global_queue(dev)
cmdbuf = MTLCommandBuffer(queue)
enqueue!(cmdbuf)

X = MtlMatrix{T}(undef, size(B))
Ad = MtlMatrix(A; storage=Private)
Bt = reshape(B, (N,M))
X = similar(B)

mps_a = MPSMatrix(A)
mps_b = MPSMatrix(B) # TODO reshape to matrix if B is a vector
mps_a = MPSMatrix(Ad)
mps_b = MPSMatrix(Bt)
mps_x = MPSMatrix(X)

solve_kernel = MPSMatrixSolveTriangular(dev, false, true, false, true, M, N, 1.0)
encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
commit!(cmdbuf)
MTLCommandBuffer(queue) do cmdbuf
kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, M, N, 1.0)
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
end

return X
Bt .= X
return B
end

# function (\)(A::AbstractMatrix, B::AbstractVecOrMat)
# require_one_based_indexing(A, B)
# m, n = size(A)
# if m == n
# if istril(A)
# if istriu(A)
# return Diagonal(A) \ B
# else
# return LowerTriangular(A) \ B
# end
# end
# if istriu(A)
# return UpperTriangular(A) \ B
# end
# return lu(A) \ B
# end
# return qr(A, ColumnNorm()) \ B
# end

function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
M,N = size(B,1), size(B,2)
dev = current_device()
queue = global_queue(dev)

A = MtlMatrix(A; storage=Private)
Bt = reshape(B, (N,M))
X = similar(B)

mps_a = MPSMatrix(A)
mps_b = MPSMatrix(Bt)
mps_x = MPSMatrix(X)

MTLCommandBuffer(queue) do cmdbuf
kernel = MPSMatrixSolveTriangular(dev, false, true, false, true, M, N, 1.0)
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
end

Bt .= X
return B
end

0 comments on commit f5d932a

Please sign in to comment.