Skip to content

Commit

Permalink
Fix edge case, apply formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed May 21, 2024
1 parent 9d2a3e5 commit 841dd36
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
14 changes: 13 additions & 1 deletion src/MatrixFields/field_matrix_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,17 @@ function check_field_matrix_solver(::BlockDiagonalSolve, _, A, b)
end
end

# TODO: we can remove the uniform_vertical_levels
# limitation while still using static shared memory
# once Nv is in the type space.
function uniform_vertical_levels(x, names)
_, _, _, Nv1, _ = size(Fields.field_values(x[first(names)]))
return all(Base.tail(names)) do name
_, _, _, Nv, _ = size(Fields.field_values(x[name]))
Nv == Nv1
end
end

NVTX.@annotate function run_field_matrix_solver!(
::BlockDiagonalSolve,
cache,
Expand All @@ -256,7 +267,8 @@ NVTX.@annotate function run_field_matrix_solver!(
)
names = matrix_row_keys(keys(A))
if length(names) == 1 ||
all(name -> A[name, name] isa UniformScaling, names.values)
all(name -> A[name, name] isa UniformScaling, names.values) ||
!uniform_vertical_levels(x, names.values)
foreach(names) do name
single_field_solve!(cache[name], x[name], A[name, name], b[name])
end
Expand Down
6 changes: 4 additions & 2 deletions test/MatrixFields/field_matrix_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ function test_field_matrix_solver(; test_name, alg, A, b, use_rel_error = false)
AnyFrameModule(MatrixFields.KrylovKit),
AnyFrameModule(Base.CoreLogging),
)
using_cuda || @test_opt ignored_modules = ignored FieldMatrixSolver(alg, A, b)
using_cuda || @test_opt ignored_modules = ignored field_matrix_solve!(args...)
using_cuda ||
@test_opt ignored_modules = ignored FieldMatrixSolver(alg, A, b)
using_cuda ||
@test_opt ignored_modules = ignored field_matrix_solve!(args...)
@test_opt ignored_modules = ignored field_matrix_mul!(b, A, x)

using_cuda || @test @allocated(field_matrix_solve!(args...)) == 0
Expand Down

0 comments on commit 841dd36

Please sign in to comment.