Skip to content

Commit

Permalink
Use shared memory in band matrix solve
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed May 20, 2024
1 parent db8780f commit 46cc9c4
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 34 deletions.
23 changes: 15 additions & 8 deletions ext/cuda/matrix_fields_multiple_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ NVTX.@annotate function multiple_field_solve!(
b,
x1,
)
Ni, Nj, _, _, Nh = size(Fields.field_values(x1))
Ni, Nj, _, Nv, Nh = size(Fields.field_values(x1))
names = MatrixFields.matrix_row_keys(keys(A))
Nnames = length(names)
nthreads, nblocks = _configure_threadblock(Ni * Nj * Nh * Nnames)
nthreads, nblocks = _configure_threadblock(Ni * Nj * Nh * Nnames * Nv)
sscache = Operators.strip_space(cache)
ssx = Operators.strip_space(x)
ssA = Operators.strip_space(A)
Expand All @@ -38,7 +38,7 @@ NVTX.@annotate function multiple_field_solve!(

device = ClimaComms.device(x[first(names)])

args = (device, caches, xs, As, bs, x1, Val(Nnames))
args = (device, caches, xs, As, bs, x1, Val(Nnames), Val(Nv))

auto_launch!(
multiple_field_solve_kernel!,
Expand All @@ -62,9 +62,11 @@ Base.@propagate_inbounds column_A(A, i, j, h) = Spaces.column(A, i, j, h)
i,
j,
h,
v,
iname,
::Val{Nnames},
) where {Nnames}
::Val{Nv},
) where {Nnames, Nv}
return quote
Base.Cartesian.@nif $Nnames ξ -> (iname == ξ) ξ -> begin
_single_field_solve!(
Expand All @@ -73,6 +75,8 @@ Base.@propagate_inbounds column_A(A, i, j, h) = Spaces.column(A, i, j, h)
column_A(xs[ξ], i, j, h),
column_A(As[ξ], i, j, h),
column_A(bs[ξ], i, j, h),
v,
Val(Nv),
)
end
end
Expand All @@ -86,13 +90,14 @@ function multiple_field_solve_kernel!(
bs,
x1,
::Val{Nnames},
) where {Nnames}
::Val{Nv},
) where {Nnames, Nv}
@inbounds begin
Ni, Nj, _, _, Nh = size(Fields.field_values(x1))
tidx = (CUDA.blockIdx().x - 1) * CUDA.blockDim().x + CUDA.threadIdx().x
if 1 tidx prod((Ni, Nj, Nh, Nnames))
(i, j, h, iname) =
CartesianIndices((1:Ni, 1:Nj, 1:Nh, 1:Nnames))[tidx].I
if 1 tidx prod((Ni, Nj, Nh, Nv, Nnames))
(i, j, h, v, iname) =
CartesianIndices((1:Ni, 1:Nj, 1:Nh, 1:Nv, 1:Nnames))[tidx].I
generated_single_field_solve!(
device,
caches,
Expand All @@ -102,8 +107,10 @@ function multiple_field_solve_kernel!(
i,
j,
h,
v,
iname,
Val(Nnames),
Val(Nv),
)
end
end
Expand Down
138 changes: 120 additions & 18 deletions ext/cuda/matrix_fields_single_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ import ClimaCore.Fields: Field
import ClimaCore.Fields
import ClimaCore.Spaces
import ClimaCore.Topologies
import ClimaCore.MatrixFields
import ClimaCore.MatrixFields: single_field_solve!
import ClimaCore.MatrixFields: _single_field_solve!
import ClimaCore.MatrixFields: band_matrix_solve!, unzip_tuple_field_values
import ClimaCore.RecursiveApply: , , , rmap, rzero, rdiv

function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
Ni, Nj, _, _, Nh = size(Fields.field_values(A))
Ni, Nj, _, _, Nh = size(Fields.field_values(A))
nitems = Ni * Nj * Nh
Ni, Nj, _, Nv, Nh = size(Fields.field_values(A))
nitems = Ni * Nj * Nh * Nv
nthreads = min(256, nitems)
nblocks = cld(nitems, nthreads)
args = (device, cache, x, A, b)
args = (device, cache, x, A, b, Val(Nv))
auto_launch!(
single_field_solve_kernel!,
args,
Expand All @@ -27,17 +27,26 @@ function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
)
end

function single_field_solve_kernel!(device, cache, x, A, b)
function single_field_solve_kernel!(
device,
cache,
x,
A,
b,
::Val{Nv},
) where {Nv}
idx = CUDA.threadIdx().x + (CUDA.blockIdx().x - 1) * CUDA.blockDim().x
Ni, Nj, _, _, Nh = size(Fields.field_values(A))
if idx <= Ni * Nj * Nh
i, j, h = Topologies._get_idx((Ni, Nj, Nh), idx)
if idx <= Ni * Nj * Nh * Nv
(v, i, j, h) = Topologies._get_idx((Nv, Ni, Nj, Nh), idx)
_single_field_solve!(
device,
Spaces.column(cache, i, j, h),
Spaces.column(x, i, j, h),
Spaces.column(A, i, j, h),
Spaces.column(b, i, j, h),
v,
Val(Nv),
)
end
return nothing
Expand All @@ -49,13 +58,17 @@ function _single_field_solve!(
x::Fields.ColumnField,
A::Fields.ColumnField,
b::Fields.ColumnField,
)
band_matrix_solve!(
v,
::Val{Nv},
) where {Nv}
band_matrix_solve_parallel!(
eltype(A),
unzip_tuple_field_values(Fields.field_values(cache)),
Fields.field_values(x),
unzip_tuple_field_values(Fields.field_values(A.entries)),
Fields.field_values(b),
v,
Val(Nv),
)
end

Expand All @@ -65,13 +78,12 @@ function _single_field_solve!(
x::Fields.ColumnField,
A::UniformScaling,
b::Fields.ColumnField,
)
v,
::Val{Nv},
) where {Nv}
x_data = Fields.field_values(x)
b_data = Fields.field_values(b)
n = length(x_data)
@inbounds for i in 1:n
x_data[i] = inv(A.λ) b_data[i]
end
x_data[v] = inv(A.λ) b_data[v]
end

function _single_field_solve!(
Expand All @@ -80,11 +92,101 @@ function _single_field_solve!(
x::Fields.PointDataField,
A::UniformScaling,
b::Fields.PointDataField,
)
v,
::Val{Nv},
) where {Nv}
x_data = Fields.field_values(x)
b_data = Fields.field_values(b)
n = length(x_data)
@inbounds begin
x_data[] = inv(A.λ) b_data[]
x_data[] = inv(A.λ) b_data[]
end

function band_matrix_solve_parallel!(
t::Type{<:MatrixFields.TridiagonalMatrixRow},
cache,
x,
Aⱼs,
b,
v,
::Val{Nv},
) where {Nv}
Ux, U₊₁ = cache
Ux_shmem = CUDA.CuStaticSharedArray(eltype(Ux), Nv)
U₊₁_shmem = CUDA.CuStaticSharedArray(eltype(U₊₁), Nv)
x_shmem = CUDA.CuStaticSharedArray(eltype(x), Nv)
cache_shmem = (Ux_shmem, U₊₁_shmem)
A₋₁, A₀, A₊₁ = Aⱼs
A₋₁_shmem = CUDA.CuStaticSharedArray(eltype(A₋₁), Nv)
A₀_shmem = CUDA.CuStaticSharedArray(eltype(A₀), Nv)
A₊₁_shmem = CUDA.CuStaticSharedArray(eltype(A₊₁), Nv)
A₋₁_shmem[v] = A₋₁[v]
A₀_shmem[v] = A₀[v]
A₊₁_shmem[v] = A₊₁[v]
b_shmem = CUDA.CuStaticSharedArray(eltype(b), Nv)
b_shmem[v] = b[v]
CUDA.sync_threads()
Aⱼs_shmem = (A₋₁, A₀, A₊₁)
if v == 1
band_matrix_solve!(t, cache_shmem, x_shmem, Aⱼs_shmem, b_shmem)
end
CUDA.sync_threads()
x[v] = x_shmem[v]
return nothing
end

function band_matrix_solve!(::Type{<:PentadiagonalMatrixRow}, cache, x, Aⱼs, b)
A₋₂, A₋₁, A₀, A₊₁, A₊₂ = Aⱼs
Ux, U₊₁, U₊₂ = cache
n = length(x)
end

function band_matrix_solve_parallel!(
t::Type{<:MatrixFields.PentadiagonalMatrixRow},
cache,
x,
Aⱼs,
b,
v,
::Val{Nv},
) where {Nv}
Ux, U₊₁ = cache
Ux_shmem = CUDA.CuStaticSharedArray(eltype(Ux), Nv)
U₊₁_shmem = CUDA.CuStaticSharedArray(eltype(U₊₁), Nv)
U₊₂_shmem = CUDA.CuStaticSharedArray(eltype(U₊₂), Nv)
x_shmem = CUDA.CuStaticSharedArray(eltype(x), Nv)
cache_shmem = (Ux_shmem, U₊₁_shmem)
A₋₂, A₋₁, A₀, A₊₁, A₊₂ = Aⱼs
A₋₂_shmem = CUDA.CuStaticSharedArray(eltype(A₋₂), Nv)
A₋₁_shmem = CUDA.CuStaticSharedArray(eltype(A₋₁), Nv)
A₀_shmem = CUDA.CuStaticSharedArray(eltype(A₀), Nv)
A₊₁_shmem = CUDA.CuStaticSharedArray(eltype(A₊₁), Nv)
A₊₂_shmem = CUDA.CuStaticSharedArray(eltype(A₊₂), Nv)
A₋₂_shmem[v] = A₋₂[v]
A₋₁_shmem[v] = A₋₁[v]
A₀_shmem[v] = A₀[v]
A₊₁_shmem[v] = A₊₁[v]
A₊₂_shmem[v] = A₊₂[v]
b_shmem = CUDA.CuStaticSharedArray(eltype(b), Nv)
b_shmem[v] = b[v]
CUDA.sync_threads()
Aⱼs_shmem = (A₋₂, A₋₁, A₀, A₊₁, A₊₂)
if v == 1
band_matrix_solve!(t, cache_shmem, x_shmem, Aⱼs_shmem, b_shmem)
end
CUDA.sync_threads()
x[v] = x_shmem[v]
return nothing
end

function band_matrix_solve_parallel!(
t::Type{<:MatrixFields.DiagonalMatrixRow},
cache,
x,
Aⱼs,
b,
v,
_,
)
(A₀,) = Aⱼs
x[v] = inv(A₀[v]) b[v]
return nothing
end
8 changes: 0 additions & 8 deletions src/MatrixFields/single_field_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,6 @@ function _single_field_solve_col!(
end
end

_single_field_solve!(
cache::Fields.Field,
x::Fields.Field,
A::Union{Fields.Field, UniformScaling},
b::Fields.Field,
dev::ClimaComms.AbstractCPUDevice,
) = _single_field_solve_col!(dev, cache, x, A, b)

unzip_tuple_field_values(data) =
ntuple(i -> data.:($i), Val(length(propertynames(data))))

Expand Down

0 comments on commit 46cc9c4

Please sign in to comment.