Skip to content

Commit

Permalink
fix kernel parameter passing for SOR DPCPP kernels
Browse files: browse the repository at this point in the history.
  • Loading branch information
upsj committed Nov 12, 2024
1 parent 97deeaa commit 82a740f
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions dpcpp/preconditioner/sor_kernels.dp.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -31,14 +31,18 @@ void initialize_weighted_l(
1, 1};

auto inv_weight = one(weight) / weight;
const auto in_row_ptrs = system_matrix->get_const_row_ptrs();
const auto in_col_idxs = system_matrix->get_const_col_idxs();
const auto in_values = system_matrix->get_const_values();
const auto l_row_ptrs = l_mtx->get_const_row_ptrs();
const auto l_col_idxs = l_mtx->get_col_idxs();
const auto l_values = l_mtx->get_values();

exec->get_queue()->parallel_for(
sycl_nd_range(grid_dim, block_size), [=](sycl::nd_item<3> item_ct1) {
factorization::helpers::initialize_l(
num_rows, system_matrix->get_const_row_ptrs(),
system_matrix->get_const_col_idxs(),
system_matrix->get_const_values(), l_mtx->get_const_row_ptrs(),
l_mtx->get_col_idxs(), l_mtx->get_values(),
num_rows, in_row_ptrs, in_col_idxs, in_values, l_row_ptrs,
l_col_idxs, l_values,
factorization::helpers::triangular_mtx_closure(
[inv_weight](auto val) { return val * inv_weight; },
factorization::helpers::identity{}),
Expand Down Expand Up @@ -67,15 +71,21 @@ void initialize_weighted_l_u(
auto inv_two_minus_weight =
one(weight) / (static_cast<remove_complex<ValueType>>(2.0) - weight);

const auto in_row_ptrs = system_matrix->get_const_row_ptrs();
const auto in_col_idxs = system_matrix->get_const_col_idxs();
const auto in_values = system_matrix->get_const_values();
const auto l_row_ptrs = l_mtx->get_const_row_ptrs();
const auto l_col_idxs = l_mtx->get_col_idxs();
const auto l_values = l_mtx->get_values();
const auto u_row_ptrs = u_mtx->get_const_row_ptrs();
const auto u_col_idxs = u_mtx->get_col_idxs();
const auto u_values = u_mtx->get_values();

exec->get_queue()->parallel_for(
sycl_nd_range(grid_dim, block_size), [=](sycl::nd_item<3> item_ct1) {
factorization::helpers::initialize_l_u(
num_rows, system_matrix->get_const_row_ptrs(),
system_matrix->get_const_col_idxs(),
system_matrix->get_const_values(), l_mtx->get_const_row_ptrs(),
l_mtx->get_col_idxs(), l_mtx->get_values(),
u_mtx->get_const_row_ptrs(), u_mtx->get_col_idxs(),
u_mtx->get_values(),
num_rows, in_row_ptrs, in_col_idxs, in_values, l_row_ptrs,
l_col_idxs, l_values, u_row_ptrs, u_col_idxs, u_values,
factorization::helpers::triangular_mtx_closure(
[inv_weight](auto val) { return val * inv_weight; },
factorization::helpers::identity{}),
Expand Down

0 comments on commit 82a740f

Please sign in to comment.