Skip to content

Commit

Permalink
Use local memory in band matrix solve
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed May 21, 2024
1 parent db8780f commit aa4f0ac
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 33 deletions.
12 changes: 8 additions & 4 deletions ext/cuda/matrix_fields_multiple_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ NVTX.@annotate function multiple_field_solve!(
b,
x1,
)
Ni, Nj, _, _, Nh = size(Fields.field_values(x1))
Ni, Nj, _, Nv, Nh = size(Fields.field_values(x1))
names = MatrixFields.matrix_row_keys(keys(A))
Nnames = length(names)
nthreads, nblocks = _configure_threadblock(Ni * Nj * Nh * Nnames)
Expand All @@ -38,7 +38,7 @@ NVTX.@annotate function multiple_field_solve!(

device = ClimaComms.device(x[first(names)])

args = (device, caches, xs, As, bs, x1, Val(Nnames))
args = (device, caches, xs, As, bs, x1, Val(Nv), Val(Nnames))

auto_launch!(
multiple_field_solve_kernel!,
Expand All @@ -62,9 +62,10 @@ Base.@propagate_inbounds column_A(A, i, j, h) = Spaces.column(A, i, j, h)
i,
j,
h,
::Val{Nv},
iname,
::Val{Nnames},
) where {Nnames}
) where {Nnames, Nv}
return quote
Base.Cartesian.@nif $Nnames ξ -> (iname == ξ) ξ -> begin
_single_field_solve!(
Expand All @@ -73,6 +74,7 @@ Base.@propagate_inbounds column_A(A, i, j, h) = Spaces.column(A, i, j, h)
column_A(xs[ξ], i, j, h),
column_A(As[ξ], i, j, h),
column_A(bs[ξ], i, j, h),
Val(Nv),
)
end
end
Expand All @@ -85,8 +87,9 @@ function multiple_field_solve_kernel!(
As,
bs,
x1,
::Val{Nv},
::Val{Nnames},
) where {Nnames}
) where {Nnames, Nv}
@inbounds begin
Ni, Nj, _, _, Nh = size(Fields.field_values(x1))
tidx = (CUDA.blockIdx().x - 1) * CUDA.blockDim().x + CUDA.threadIdx().x
Expand All @@ -102,6 +105,7 @@ function multiple_field_solve_kernel!(
i,
j,
h,
Val(Nv),
iname,
Val(Nnames),
)
Expand Down
125 changes: 110 additions & 15 deletions ext/cuda/matrix_fields_single_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ import ClimaCore.Fields: Field
import ClimaCore.Fields
import ClimaCore.Spaces
import ClimaCore.Topologies
import ClimaCore.MatrixFields
import ClimaCore.MatrixFields: single_field_solve!
import ClimaCore.MatrixFields: _single_field_solve!
import ClimaCore.MatrixFields: band_matrix_solve!, unzip_tuple_field_values
import ClimaCore.RecursiveApply: ⊠, ⊞, ⊟, rmap, rzero, rdiv

function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
Ni, Nj, _, _, Nh = size(Fields.field_values(A))
Ni, Nj, _, _, Nh = size(Fields.field_values(A))
Ni, Nj, _, Nv, Nh = size(Fields.field_values(A))
nitems = Ni * Nj * Nh
nthreads = min(256, nitems)
nblocks = cld(nitems, nthreads)
args = (device, cache, x, A, b)
args = (device, cache, x, A, b, Val(Nv))
auto_launch!(
single_field_solve_kernel!,
args,
Expand All @@ -27,17 +27,26 @@ function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
)
end

function single_field_solve_kernel!(device, cache, x, A, b)
function single_field_solve_kernel!(
device,
cache,
x,
A,
b,
::Val{Nv},
) where {Nv}
idx = CUDA.threadIdx().x + (CUDA.blockIdx().x - 1) * CUDA.blockDim().x
Ni, Nj, _, _, Nh = size(Fields.field_values(A))
if idx <= Ni * Nj * Nh
i, j, h = Topologies._get_idx((Ni, Nj, Nh), idx)
(i, j, h) = CartesianIndices((1:Ni, 1:Nj, 1:Nh))[idx].I

_single_field_solve!(
device,
Spaces.column(cache, i, j, h),
Spaces.column(x, i, j, h),
Spaces.column(A, i, j, h),
Spaces.column(b, i, j, h),
Val(Nv),
)
end
return nothing
Expand All @@ -49,13 +58,15 @@ function _single_field_solve!(
x::Fields.ColumnField,
A::Fields.ColumnField,
b::Fields.ColumnField,
)
band_matrix_solve!(
::Val{Nv},
) where {Nv}
band_matrix_solve_local_mem!(
eltype(A),
unzip_tuple_field_values(Fields.field_values(cache)),
Fields.field_values(x),
unzip_tuple_field_values(Fields.field_values(A.entries)),
Fields.field_values(b),
Val(Nv),
)
end

Expand All @@ -65,12 +76,12 @@ function _single_field_solve!(
x::Fields.ColumnField,
A::UniformScaling,
b::Fields.ColumnField,
)
::Val{Nv},
) where {Nv}
x_data = Fields.field_values(x)
b_data = Fields.field_values(b)
n = length(x_data)
@inbounds for i in 1:n
x_data[i] = inv(A.λ) b_data[i]
@inbounds for v in 1:Nv
x_data[v] = inv(A.λ) b_data[v]
end
end

Expand All @@ -80,11 +91,95 @@ function _single_field_solve!(
x::Fields.PointDataField,
A::UniformScaling,
b::Fields.PointDataField,
)
::Val{Nv},
) where {Nv}
x_data = Fields.field_values(x)
b_data = Fields.field_values(b)
n = length(x_data)
@inbounds begin
x_data[] = inv(A.λ) b_data[]
x_data[] = inv(A.λ) b_data[]
end

using StaticArrays: MArray
# Tridiagonal variant of the band-matrix solve that works on statically sized
# `MArray` buffers of length `Nv`.
#
# Arguments:
#   t      - row type selecting this method (a `TridiagonalMatrixRow` subtype);
#            forwarded unchanged to `band_matrix_solve!`.
#   cache  - 2-tuple `(Ux, U₊₁)`; only the element types are read here, the
#            scratch buffers actually handed to the solver are fresh `MArray`s.
#   x      - destination; entries `1:Nv` are overwritten with the solution.
#   Aⱼs    - 3-tuple `(A₋₁, A₀, A₊₁)` of matrix bands, indexed `1:Nv`.
#   b      - right-hand side, indexed `1:Nv`.
#   Val{Nv} - number of entries per column, lifted to the type domain so the
#            `MArray` sizes are static.
#
# Returns `nothing`.
function band_matrix_solve_local_mem!(
    t::Type{<:MatrixFields.TridiagonalMatrixRow},
    cache,
    x,
    Aⱼs,
    b,
    ::Val{Nv},
) where {Nv}
    Ux, U₊₁ = cache
    A₋₁, A₀, A₊₁ = Aⱼs

    # Scratch and solution buffers are left uninitialized on purpose.
    # NOTE(review): this relies on `band_matrix_solve!` (defined elsewhere)
    # writing every entry it later reads — confirm against its implementation.
    Ux_local = MArray{Tuple{Nv}, eltype(Ux)}(undef)
    U₊₁_local = MArray{Tuple{Nv}, eltype(U₊₁)}(undef)
    x_local = MArray{Tuple{Nv}, eltype(x)}(undef)
    A₋₁_local = MArray{Tuple{Nv}, eltype(A₋₁)}(undef)
    A₀_local = MArray{Tuple{Nv}, eltype(A₀)}(undef)
    A₊₁_local = MArray{Tuple{Nv}, eltype(A₊₁)}(undef)
    b_local = MArray{Tuple{Nv}, eltype(b)}(undef)

    # Stage the inputs into the static buffers.
    @inbounds for v in 1:Nv
        A₋₁_local[v] = A₋₁[v]
        A₀_local[v] = A₀[v]
        A₊₁_local[v] = A₊₁[v]
        b_local[v] = b[v]
    end

    cache_local = (Ux_local, U₊₁_local)
    # Pass the staged copies to the solver. Previously the original bands were
    # passed here, which left the `*_local` coefficient copies unused.
    Aⱼs_local = (A₋₁_local, A₀_local, A₊₁_local)
    band_matrix_solve!(t, cache_local, x_local, Aⱼs_local, b_local)

    # Write the solution back to the caller's storage.
    @inbounds for v in 1:Nv
        x[v] = x_local[v]
    end
    return nothing
end

# Pentadiagonal variant of the band-matrix solve that works on statically
# sized `MArray` buffers of length `Nv`.
#
# Arguments:
#   t      - row type selecting this method (a `PentadiagonalMatrixRow`
#            subtype); forwarded unchanged to `band_matrix_solve!`.
#   cache  - 3-tuple `(Ux, U₊₁, U₊₂)`; only the element types are read here,
#            the scratch buffers handed to the solver are fresh `MArray`s.
#   x      - destination; entries `1:Nv` are overwritten with the solution.
#   Aⱼs    - 5-tuple `(A₋₂, A₋₁, A₀, A₊₁, A₊₂)` of matrix bands, indexed `1:Nv`.
#   b      - right-hand side, indexed `1:Nv`.
#   Val{Nv} - number of entries per column, lifted to the type domain so the
#            `MArray` sizes are static.
#
# Returns `nothing`.
function band_matrix_solve_local_mem!(
    t::Type{<:MatrixFields.PentadiagonalMatrixRow},
    cache,
    x,
    Aⱼs,
    b,
    ::Val{Nv},
) where {Nv}
    Ux, U₊₁, U₊₂ = cache
    A₋₂, A₋₁, A₀, A₊₁, A₊₂ = Aⱼs

    # Scratch and solution buffers are left uninitialized on purpose.
    # NOTE(review): this relies on `band_matrix_solve!` (defined elsewhere)
    # writing every entry it later reads — confirm against its implementation.
    Ux_local = MArray{Tuple{Nv}, eltype(Ux)}(undef)
    U₊₁_local = MArray{Tuple{Nv}, eltype(U₊₁)}(undef)
    U₊₂_local = MArray{Tuple{Nv}, eltype(U₊₂)}(undef)
    x_local = MArray{Tuple{Nv}, eltype(x)}(undef)
    A₋₂_local = MArray{Tuple{Nv}, eltype(A₋₂)}(undef)
    A₋₁_local = MArray{Tuple{Nv}, eltype(A₋₁)}(undef)
    A₀_local = MArray{Tuple{Nv}, eltype(A₀)}(undef)
    A₊₁_local = MArray{Tuple{Nv}, eltype(A₊₁)}(undef)
    A₊₂_local = MArray{Tuple{Nv}, eltype(A₊₂)}(undef)
    b_local = MArray{Tuple{Nv}, eltype(b)}(undef)

    # Stage the inputs into the static buffers.
    @inbounds for v in 1:Nv
        A₋₂_local[v] = A₋₂[v]
        A₋₁_local[v] = A₋₁[v]
        A₀_local[v] = A₀[v]
        A₊₁_local[v] = A₊₁[v]
        A₊₂_local[v] = A₊₂[v]
        b_local[v] = b[v]
    end

    cache_local = (Ux_local, U₊₁_local, U₊₂_local)
    # Pass the staged copies to the solver. Previously the original bands were
    # passed here, which left the `*_local` coefficient copies unused.
    Aⱼs_local = (A₋₂_local, A₋₁_local, A₀_local, A₊₁_local, A₊₂_local)
    band_matrix_solve!(t, cache_local, x_local, Aⱼs_local, b_local)

    # Write the solution back to the caller's storage.
    @inbounds for v in 1:Nv
        x[v] = x_local[v]
    end
    return nothing
end

# Diagonal variant: each entry is solved independently, so no staging buffers
# and no call to `band_matrix_solve!` are needed.
#
# Arguments:
#   t      - row type selecting this method (a `DiagonalMatrixRow` subtype).
#   cache  - unused; accepted so the signature matches the other variants.
#   x      - destination; entries `1:Nv` are overwritten with the solution.
#   Aⱼs    - 1-tuple `(A₀,)` holding the diagonal, indexed `1:Nv`.
#   b      - right-hand side, indexed `1:Nv`.
#   Val{Nv} - number of entries per column. It must be bound as a type
#            parameter: the loop below uses `Nv`, which was undefined when
#            this argument was an anonymous `_`.
#
# Returns `nothing`.
function band_matrix_solve_local_mem!(
    t::Type{<:MatrixFields.DiagonalMatrixRow},
    cache,
    x,
    Aⱼs,
    b,
    ::Val{Nv},
) where {Nv}
    (A₀,) = Aⱼs
    @inbounds for v in 1:Nv
        # `⊠` is the multiplication operator imported from
        # ClimaCore.RecursiveApply at the top of the file.
        x[v] = inv(A₀[v]) ⊠ b[v]
    end
    return nothing
end
14 changes: 13 additions & 1 deletion src/MatrixFields/field_matrix_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,17 @@ function check_field_matrix_solver(::BlockDiagonalSolve, _, A, b)
end
end

# TODO: we can remove the uniform_vertical_levels
# limitation while still using static shared memory
# once Nv is in the type space.
#
# Return `true` when every field `x[name]` for `name in names` reports the
# same fourth entry of `size(Fields.field_values(...))` as the first one does.
function uniform_vertical_levels(x, names)
    # Fourth component of the 5-component size tuple of a field's values.
    function level_count(name)
        _, _, _, Nv, _ = size(Fields.field_values(x[name]))
        return Nv
    end
    Nv_first = level_count(first(names))
    return all(name -> level_count(name) == Nv_first, Base.tail(names))
end

NVTX.@annotate function run_field_matrix_solver!(
::BlockDiagonalSolve,
cache,
Expand All @@ -256,7 +267,8 @@ NVTX.@annotate function run_field_matrix_solver!(
)
names = matrix_row_keys(keys(A))
if length(names) == 1 ||
all(name -> A[name, name] isa UniformScaling, names.values)
all(name -> A[name, name] isa UniformScaling, names.values) ||
!uniform_vertical_levels(x, names.values)
foreach(names) do name
single_field_solve!(cache[name], x[name], A[name, name], b[name])
end
Expand Down
8 changes: 0 additions & 8 deletions src/MatrixFields/single_field_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,6 @@ function _single_field_solve_col!(
end
end

# Method for CPU devices: forwards to `_single_field_solve_col!`, moving the
# device from the last argument position to the first.
function _single_field_solve!(
    cache::Fields.Field,
    x::Fields.Field,
    A::Union{Fields.Field, UniformScaling},
    b::Fields.Field,
    dev::ClimaComms.AbstractCPUDevice,
)
    return _single_field_solve_col!(dev, cache, x, A, b)
end

# Split `data` into a tuple of its components: one entry per property name,
# fetched by integer position (`data.:(1)`, `data.:(2)`, ...).
function unzip_tuple_field_values(data)
    nprops = length(propertynames(data))
    return ntuple(Val(nprops)) do i
        data.:($i)
    end
end

Expand Down
22 changes: 17 additions & 5 deletions test/MatrixFields/field_matrix_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Revise; include(joinpath("test", "MatrixFields", "field_matrix_solvers.jl"
import Logging
import Logging: Debug
import LinearAlgebra: I, norm
import ClimaComms
import ClimaCore.Utilities: half
import ClimaCore.RecursiveApply:
import ClimaCore.MatrixFields: @name
Expand All @@ -21,8 +22,16 @@ function test_field_matrix_solver(; test_name, alg, A, b, use_rel_error = false)
solver = FieldMatrixSolver(alg, A, b)
args = (solver, x, A, b)

solve_time = @benchmark field_matrix_solve!(args...)
mul_time = @benchmark field_matrix_mul!(b_test, A, x)
solve_time =
@benchmark ClimaComms.@cuda_sync comms_device field_matrix_solve!(
args...,
)
mul_time =
@benchmark ClimaComms.@cuda_sync comms_device field_matrix_mul!(
b_test,
A,
x,
)

solve_time_rounded = round(solve_time; sigdigits = 2)
mul_time_rounded = round(mul_time; sigdigits = 2)
Expand Down Expand Up @@ -58,11 +67,14 @@ function test_field_matrix_solver(; test_name, alg, A, b, use_rel_error = false)
AnyFrameModule(MatrixFields.KrylovKit),
AnyFrameModule(Base.CoreLogging),
)
@test_opt ignored_modules = ignored FieldMatrixSolver(alg, A, b)
@test_opt ignored_modules = ignored field_matrix_solve!(args...)
using_cuda ||
@test_opt ignored_modules = ignored FieldMatrixSolver(alg, A, b)
using_cuda ||
@test_opt ignored_modules = ignored field_matrix_solve!(args...)
@test_opt ignored_modules = ignored field_matrix_mul!(b, A, x)

using_cuda || @test @allocated(field_matrix_solve!(args...)) == 0
# TODO: fix broken test when Nv is added to the type space
using_cuda || @test_broken @allocated(field_matrix_solve!(args...)) == 0
using_cuda || @test @allocated(field_matrix_mul!(b, A, x)) == 0
end
end
Expand Down

0 comments on commit aa4f0ac

Please sign in to comment.