Skip to content

Commit

Permalink
baseline
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich committed Mar 31, 2023
1 parent f5d932a commit 5ab3330
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 34 deletions.
94 changes: 60 additions & 34 deletions lib/mps/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,30 +192,43 @@ function LinearAlgebra.transpose!(B::MtlMatrix{T}, A::MtlMatrix{T}) where {T}

commit!(cmdbuf)

wait_completed(cmdbuf)

return B
end


function LinearAlgebra.:(\)(A::LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
C = deepcopy(B)
LinearAlgebra.ldiv!(A, C)
return C
end


function LinearAlgebra.ldiv!(A::LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
M,N = size(B,1), size(B,2)
dev = current_device()
queue = global_queue(dev)

Bt = reshape(B, (N,M))
At = similar(A.factors)
Bt = similar(B, (N,M))
P = reshape((A.ipiv .- UInt32(1)), (1,M))
X = similar(B)
X = similar(B, (N,M))

transpose!(At, A.factors)
transpose!(Bt, B)

mps_a = MPSMatrix(A.factors)
mps_a = MPSMatrix(At)
mps_b = MPSMatrix(Bt)
mps_p = MPSMatrix(P)
mps_x = MPSMatrix(X)

MTLCommandBuffer(queue) do cmdbuf
kernel = MPSMatrixSolveLU(dev, true, M, N)
kernel = MPSMatrixSolveLU(dev, false, M, N)
encode!(cmdbuf, kernel, mps_a, mps_b, mps_p, mps_x)
end

Bt .= X
transpose!(B, X)
return B
end

Expand All @@ -225,20 +238,24 @@ function LinearAlgebra.ldiv!(A::UpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrM
dev = current_device()
queue = global_queue(dev)

Ad = MtlMatrix(A; storage=Private)
Bt = reshape(B, (N,M))
X = similar(B)
Ad = MtlMatrix(A')
Br = similar(B, (M,M))
X = similar(Br)

transpose!(Br, B)

mps_a = MPSMatrix(Ad)
mps_b = MPSMatrix(Bt)
mps_b = MPSMatrix(Br)
mps_x = MPSMatrix(X)

MTLCommandBuffer(queue) do cmdbuf
kernel = MPSMatrixSolveTriangular(dev, false, false, false, false, M, N, 1.0)
buf = MTLCommandBuffer(queue) do cmdbuf
kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, N, M, 1.0)
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
end

Bt .= X
wait_completed(buf)

copy!(B, X)
return B
end

Expand All @@ -248,20 +265,23 @@ function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T, <:MtlMatrix{T}}, B::MtlVe
dev = current_device()
queue = global_queue(dev)

Ad = MtlMatrix(A; storage=Private)
Bt = reshape(B, (N,M))
X = similar(B)
Ad = MtlMatrix(A)
Br = reshape(B, (M,N))
X = similar(Br)

mps_a = MPSMatrix(Ad)
mps_b = MPSMatrix(Bt)
mps_b = MPSMatrix(Br)
mps_x = MPSMatrix(X)

MTLCommandBuffer(queue) do cmdbuf
kernel = MPSMatrixSolveTriangular(dev, false, false, false, true, M, N, 1.0)

buf = MTLCommandBuffer(queue) do cmdbuf
kernel = MPSMatrixSolveTriangular(dev, true, false, false, true, M, N, 1.0)
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
end

Bt .= X
wait_completed(buf)

copy!(Br, X)
return B
end

Expand All @@ -271,20 +291,23 @@ function LinearAlgebra.ldiv!(A::LowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrM
dev = current_device()
queue = global_queue(dev)

Ad = MtlMatrix(A; storage=Private)
Bt = reshape(B, (N,M))
X = similar(B)
Ad = MtlMatrix(A)
Br = reshape(B, (M,N))
X = similar(Br)

mps_a = MPSMatrix(Ad)
mps_b = MPSMatrix(Bt)
mps_b = MPSMatrix(Br)
mps_x = MPSMatrix(X)

MTLCommandBuffer(queue) do cmdbuf
kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, M, N, 1.0)

buf = MTLCommandBuffer(queue) do cmdbuf
kernel = MPSMatrixSolveTriangular(dev, true, true, false, false, M, N, 1.0)
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
end

Bt .= X
wait_completed(buf)

copy!(Br, X)
return B
end

Expand All @@ -294,19 +317,22 @@ function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T, <:MtlMatrix{T}}, B::MtlVe
dev = current_device()
queue = global_queue(dev)

A = MtlMatrix(A; storage=Private)
Bt = reshape(B, (N,M))
X = similar(B)
Ad = MtlMatrix(A)
Br = reshape(B, (M,N))
X = similar(Br)

mps_a = MPSMatrix(A)
mps_b = MPSMatrix(Bt)
mps_a = MPSMatrix(Ad)
mps_b = MPSMatrix(Br)
mps_x = MPSMatrix(X)

MTLCommandBuffer(queue) do cmdbuf
kernel = MPSMatrixSolveTriangular(dev, false, true, false, true, M, N, 1.0)

buf = MTLCommandBuffer(queue) do cmdbuf
kernel = MPSMatrixSolveTriangular(dev, true, true, false, true, M, N, 1.0)
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
end

Bt .= X
wait_completed(buf)

copy!(Br, X)
return B
end
35 changes: 35 additions & 0 deletions test/mps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,39 @@ end
@test_throws SingularException lu(A)
end

@testset "solves" begin
b = MtlVector(rand(Float32, 1024))
B = MtlMatrix(rand(Float32, 1024, 1024))

A = MtlMatrix(rand(Float32, 1024, 512))
x = lu(A) \ b
@test A * x b
X = lu(A) \ B
@test A * X B

A = UpperTriangular(MtlMatrix(rand(Float32, 1024, 1024)))
x = A \ b
@test A * x b
X = A \ B
@test A * X B

A = UnitUpperTriangular(MtlMatrix(rand(Float32, 1024, 1024)))
x = A \ b
@test A * x b
X = A \ B
@test A * X B

A = LowerTriangular(MtlMatrix(rand(Float32, 1024, 1024)))
x = A \ b
@test A * x b
X = A \ B
@test A * X B

A = UnitLowerTriangular(MtlMatrix(rand(Float32, 1024, 1024)))
x = A \ b
@test A * x b
X = A \ B
@test A * X B
end

end

0 comments on commit 5ab3330

Please sign in to comment.