Skip to content

Commit

Permalink
Merge #1711 Add Half solver and residual norm
Browse files Browse the repository at this point in the history
This PR implements half precision for the solvers (mostly Krylov solvers) and the residual norm.

Related PR: #1711
  • Loading branch information
yhmtsai authored Dec 3, 2024
2 parents 76ef161 + 6ad32f2 commit 066ca3c
Show file tree
Hide file tree
Showing 68 changed files with 603 additions and 333 deletions.
54 changes: 38 additions & 16 deletions common/cuda_hip/solver/idr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,13 @@ __global__ __launch_bounds__(config::warp_size) void compute_omega_kernel(

if (!stop_status[global_id].has_stopped()) {
auto thr = omega[global_id];
const auto normt = sqrt(real(tht[global_id]));
if (normt == zero<remove_complex<ValueType>>()) {
omega[global_id] = zero<ValueType>();
return;
}
omega[global_id] /= tht[global_id];
auto absrho =
abs(thr / (sqrt(real(tht[global_id])) * residual_norm[global_id]));
auto absrho = abs(thr / (normt * residual_norm[global_id]));

if (absrho < kappa) {
omega[global_id] *= kappa / absrho;
Expand Down Expand Up @@ -450,11 +454,19 @@ void update_g_and_u(std::shared_ptr<const DefaultExecutor> exec,
if (nrhs > 1 || is_complex<ValueType>()) {
components::fill_array(exec, alpha->get_values(), nrhs,
zero<ValueType>());
multidot_kernel<<<grid_dim, block_dim, 0, exec->get_stream()>>>(
size, nrhs, as_device_type(p_i),
as_device_type(g_k->get_values()), g_k->get_stride(),
as_device_type(alpha->get_values()),
stop_status->get_const_data());
// 16-bit atomic operations are not supported
#if !(defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700))
if constexpr (std::is_same_v<remove_complex<ValueType>, half>) {
GKO_NOT_SUPPORTED(alpha);
} else
#endif
{
multidot_kernel<<<grid_dim, block_dim, 0, exec->get_stream()>>>(
size, nrhs, as_device_type(p_i),
as_device_type(g_k->get_values()), g_k->get_stride(),
as_device_type(alpha->get_values()),
stop_status->get_const_data());
}
} else {
blas::dot(exec->get_blas_handle(), size, p_i, 1, g_k->get_values(),
g_k->get_stride(), alpha->get_values());
Expand Down Expand Up @@ -501,10 +513,18 @@ void update_m(std::shared_ptr<const DefaultExecutor> exec, const size_type nrhs,
auto m_i = m->get_values() + i * m_stride + k * nrhs;
if (nrhs > 1 || is_complex<ValueType>()) {
components::fill_array(exec, m_i, nrhs, zero<ValueType>());
multidot_kernel<<<grid_dim, block_dim, 0, exec->get_stream()>>>(
size, nrhs, as_device_type(p_i),
as_device_type(g_k->get_const_values()), g_k->get_stride(),
as_device_type(m_i), stop_status->get_const_data());
// 16-bit atomic operations are not supported
#if !(defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700))
if constexpr (std::is_same_v<remove_complex<ValueType>, half>) {
GKO_NOT_SUPPORTED(m_i);
} else
#endif
{
multidot_kernel<<<grid_dim, block_dim, 0, exec->get_stream()>>>(
size, nrhs, as_device_type(p_i),
as_device_type(g_k->get_const_values()), g_k->get_stride(),
as_device_type(m_i), stop_status->get_const_data());
}
} else {
blas::dot(exec->get_blas_handle(), size, p_i, 1,
g_k->get_const_values(), g_k->get_stride(), m_i);
Expand Down Expand Up @@ -555,7 +575,8 @@ void initialize(std::shared_ptr<const DefaultExecutor> exec,
orthonormalize_subspace_vectors(exec, subspace_vectors);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_IDR_INITIALIZE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_IDR_INITIALIZE_KERNEL);


template <typename ValueType>
Expand All @@ -582,7 +603,7 @@ void step_1(std::shared_ptr<const DefaultExecutor> exec, const size_type nrhs,
stop_status->get_const_data());
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_IDR_STEP_1_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_IDR_STEP_1_KERNEL);


template <typename ValueType>
Expand All @@ -609,7 +630,7 @@ void step_2(std::shared_ptr<const DefaultExecutor> exec, const size_type nrhs,
stop_status->get_const_data());
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_IDR_STEP_2_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_IDR_STEP_2_KERNEL);


template <typename ValueType>
Expand All @@ -626,7 +647,7 @@ void step_3(std::shared_ptr<const DefaultExecutor> exec, const size_type nrhs,
update_x_r_and_f(exec, nrhs, k, m, g, u, f, residual, x, stop_status);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_IDR_STEP_3_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_IDR_STEP_3_KERNEL);


template <typename ValueType>
Expand All @@ -644,7 +665,8 @@ void compute_omega(
as_device_type(omega->get_values()), stop_status->get_const_data());
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_IDR_COMPUTE_OMEGA_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_IDR_COMPUTE_OMEGA_KERNEL);


} // namespace idr
Expand Down
5 changes: 3 additions & 2 deletions common/cuda_hip/stop/residual_norm_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ void residual_norm(std::shared_ptr<const DefaultExecutor> exec,
*one_changed = get_element(*device_storage, 1);
}

GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_RESIDUAL_NORM_KERNEL);


Expand Down Expand Up @@ -171,7 +171,8 @@ void implicit_residual_norm(
*one_changed = get_element(*device_storage, 1);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_IMPLICIT_RESIDUAL_NORM_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_IMPLICIT_RESIDUAL_NORM_KERNEL);


} // namespace implicit_residual_norm
Expand Down
7 changes: 4 additions & 3 deletions common/unified/solver/bicg_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ void initialize(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BICG_INITIALIZE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BICG_INITIALIZE_KERNEL);


template <typename ValueType>
Expand All @@ -90,7 +91,7 @@ void step_1(std::shared_ptr<const DefaultExecutor> exec,
row_vector(prev_rho), *stop_status);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BICG_STEP_1_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BICG_STEP_1_KERNEL);


template <typename ValueType>
Expand Down Expand Up @@ -119,7 +120,7 @@ void step_2(std::shared_ptr<const DefaultExecutor> exec,
default_stride(q2), row_vector(beta), row_vector(rho), *stop_status);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BICG_STEP_2_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BICG_STEP_2_KERNEL);


} // namespace bicg
Expand Down
15 changes: 10 additions & 5 deletions common/unified/solver/bicgstab_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ void initialize(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BICGSTAB_INITIALIZE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BICGSTAB_INITIALIZE_KERNEL);


template <typename ValueType>
Expand Down Expand Up @@ -98,7 +99,8 @@ void step_1(std::shared_ptr<const DefaultExecutor> exec,
row_vector(alpha), row_vector(omega), *stop_status);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BICGSTAB_STEP_1_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BICGSTAB_STEP_1_KERNEL);


template <typename ValueType>
Expand Down Expand Up @@ -127,7 +129,8 @@ void step_2(std::shared_ptr<const DefaultExecutor> exec,
*stop_status);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BICGSTAB_STEP_2_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BICGSTAB_STEP_2_KERNEL);


template <typename ValueType>
Expand Down Expand Up @@ -159,7 +162,8 @@ void step_3(
row_vector(omega), *stop_status);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BICGSTAB_STEP_3_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BICGSTAB_STEP_3_KERNEL);


template <typename ValueType>
Expand Down Expand Up @@ -188,7 +192,8 @@ void finalize(std::shared_ptr<const DefaultExecutor> exec,
x->get_size()[1], *stop_status);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BICGSTAB_FINALIZE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BICGSTAB_FINALIZE_KERNEL);


} // namespace bicgstab
Expand Down
6 changes: 3 additions & 3 deletions common/unified/solver/cg_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ void initialize(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_CG_INITIALIZE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_CG_INITIALIZE_KERNEL);


template <typename ValueType>
Expand All @@ -80,7 +80,7 @@ void step_1(std::shared_ptr<const DefaultExecutor> exec,
row_vector(rho), row_vector(prev_rho), *stop_status);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_CG_STEP_1_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_CG_STEP_1_KERNEL);


template <typename ValueType>
Expand All @@ -106,7 +106,7 @@ void step_2(std::shared_ptr<const DefaultExecutor> exec,
default_stride(q), row_vector(beta), row_vector(rho), *stop_status);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_CG_STEP_2_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_CG_STEP_2_KERNEL);


} // namespace cg
Expand Down
9 changes: 5 additions & 4 deletions common/unified/solver/cgs_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ void initialize(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_CGS_INITIALIZE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_CGS_INITIALIZE_KERNEL);


template <typename ValueType>
Expand Down Expand Up @@ -103,7 +104,7 @@ void step_1(std::shared_ptr<const DefaultExecutor> exec,
row_vector(prev_rho), *stop_status);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_CGS_STEP_1_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_CGS_STEP_1_KERNEL);


template <typename ValueType>
Expand Down Expand Up @@ -134,7 +135,7 @@ void step_2(std::shared_ptr<const DefaultExecutor> exec,
row_vector(alpha), row_vector(rho), row_vector(gamma), *stop_status);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_CGS_STEP_2_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_CGS_STEP_2_KERNEL);

template <typename ValueType>
void step_3(std::shared_ptr<const DefaultExecutor> exec,
Expand All @@ -157,7 +158,7 @@ void step_3(std::shared_ptr<const DefaultExecutor> exec,
*stop_status);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_CGS_STEP_3_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_CGS_STEP_3_KERNEL);


} // namespace cgs
Expand Down
7 changes: 4 additions & 3 deletions common/unified/solver/common_gmres_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ void initialize(std::shared_ptr<const DefaultExecutor> exec,
b->get_size()[0]);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_COMMON_GMRES_INITIALIZE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_COMMON_GMRES_INITIALIZE_KERNEL);


template <typename ValueType>
Expand Down Expand Up @@ -125,7 +126,7 @@ void hessenberg_qr(std::shared_ptr<const DefaultExecutor> exec,
stop_status);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_COMMON_GMRES_HESSENBERG_QR_KERNEL);


Expand Down Expand Up @@ -158,7 +159,7 @@ void solve_krylov(std::shared_ptr<const DefaultExecutor> exec,
residual_norm_collection->get_size()[1]);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_COMMON_GMRES_SOLVE_KRYLOV_KERNEL);


Expand Down
7 changes: 4 additions & 3 deletions common/unified/solver/fcg_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ void initialize(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_FCG_INITIALIZE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_FCG_INITIALIZE_KERNEL);


template <typename ValueType>
Expand All @@ -84,7 +85,7 @@ void step_1(std::shared_ptr<const DefaultExecutor> exec,
row_vector(rho_t), row_vector(prev_rho), *stop_status);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_FCG_STEP_1_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_FCG_STEP_1_KERNEL);


template <typename ValueType>
Expand Down Expand Up @@ -113,7 +114,7 @@ void step_2(std::shared_ptr<const DefaultExecutor> exec,
*stop_status);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_FCG_STEP_2_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_FCG_STEP_2_KERNEL);


} // namespace fcg
Expand Down
7 changes: 4 additions & 3 deletions common/unified/solver/gcr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ void initialize(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_GCR_INITIALIZE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_GCR_INITIALIZE_KERNEL);


template <typename ValueType>
Expand Down Expand Up @@ -78,7 +79,7 @@ void restart(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_GCR_RESTART_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_GCR_RESTART_KERNEL);


template <typename ValueType>
Expand All @@ -104,7 +105,7 @@ void step_1(std::shared_ptr<const DefaultExecutor> exec,
stop_status);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_GCR_STEP_1_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_GCR_STEP_1_KERNEL);

} // namespace gcr
} // namespace GKO_DEVICE_NAMESPACE
Expand Down
8 changes: 5 additions & 3 deletions common/unified/solver/gmres_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void restart(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_GMRES_RESTART_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_GMRES_RESTART_KERNEL);


template <typename ValueType>
Expand Down Expand Up @@ -92,7 +92,8 @@ void multi_axpy(std::shared_ptr<const DefaultExecutor> exec,
before_preconditioner->get_size()[1], stop_status);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_GMRES_MULTI_AXPY_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_GMRES_MULTI_AXPY_KERNEL);


template <typename ValueType>
Expand All @@ -119,7 +120,8 @@ void multi_dot(std::shared_ptr<const DefaultExecutor> exec,
next_krylov->get_size()[0]);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_GMRES_MULTI_DOT_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_GMRES_MULTI_DOT_KERNEL);

} // namespace gmres
} // namespace GKO_DEVICE_NAMESPACE
Expand Down
4 changes: 3 additions & 1 deletion core/config/config_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,9 @@ get_value(const pnode& config)
* This is the specialization for floating-point types
*/
template <typename ValueType>
inline std::enable_if_t<std::is_floating_point<ValueType>::value, ValueType>
inline std::enable_if_t<std::is_floating_point<ValueType>::value ||
std::is_same<ValueType, half>::value,
ValueType>
get_value(const pnode& config)
{
auto val = config.get_real();
Expand Down
Loading

0 comments on commit 066ca3c

Please sign in to comment.