Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable Half in mpi #1759

Open
wants to merge 10 commits into
base: develop
Choose a base branch
from
4 changes: 2 additions & 2 deletions common/cuda_hip/distributed/assembly_kernels.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -90,7 +90,7 @@ void count_non_owning_entries(
num_parts, local_part, row_part_ptrs.get_data(), send_count.get_data());
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_BASE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_COUNT_NON_OWNING_ENTRIES);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The base version of that macro is now unused.



Expand Down
18 changes: 9 additions & 9 deletions common/cuda_hip/distributed/matrix_kernels.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -137,11 +137,11 @@ void separate_local_nonlocal(
col_range_starting_indices[range_id];
};

using input_type = input_type<ValueType, GlobalIndexType>;
using input_type = input_type<device_type<ValueType>, GlobalIndexType>;
auto input_it = thrust::make_zip_iterator(thrust::make_tuple(
input.get_const_row_idxs(), input.get_const_col_idxs(),
input.get_const_values(), row_range_ids.get_const_data(),
col_range_ids.get_const_data()));
as_device_type(input.get_const_values()),
row_range_ids.get_const_data(), col_range_ids.get_const_data()));

// copy and transform local entries into arrays
local_row_idxs.resize_and_reset(num_local_elements);
Expand All @@ -157,9 +157,9 @@ void separate_local_nonlocal(
thrust::copy_if(
policy, local_it, local_it + input.get_num_stored_elements(),
range_ids_it,
thrust::make_zip_iterator(thrust::make_tuple(local_row_idxs.get_data(),
local_col_idxs.get_data(),
local_values.get_data())),
thrust::make_zip_iterator(thrust::make_tuple(
local_row_idxs.get_data(), local_col_idxs.get_data(),
as_device_type(local_values.get_data()))),
[local_part, row_part_ids, col_part_ids] __host__ __device__(
const thrust::tuple<size_type, size_type>& tuple) {
auto row_part = row_part_ids[thrust::get<0>(tuple)];
Expand All @@ -185,7 +185,7 @@ void separate_local_nonlocal(
range_ids_it,
thrust::make_zip_iterator(thrust::make_tuple(
non_local_row_idxs.get_data(), non_local_col_idxs.get_data(),
non_local_values.get_data())),
as_device_type(non_local_values.get_data()))),
[local_part, row_part_ids, col_part_ids] __host__ __device__(
const thrust::tuple<size_type, size_type>& tuple) {
auto row_part = row_part_ids[thrust::get<0>(tuple)];
Expand All @@ -194,7 +194,7 @@ void separate_local_nonlocal(
});
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_BASE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_SEPARATE_LOCAL_NONLOCAL);


Expand Down
4 changes: 2 additions & 2 deletions common/cuda_hip/distributed/vector_kernels.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -83,7 +83,7 @@ void build_local(
range_id.get_data(), local_mtx->get_values(), is_local_row);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_BASE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_DISTRIBUTED_VECTOR_BUILD_LOCAL);


Expand Down
4 changes: 2 additions & 2 deletions common/unified/distributed/assembly_kernels.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -48,7 +48,7 @@ void fill_send_buffers(
send_values.get_data());
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_BASE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_FILL_SEND_BUFFERS);


Expand Down
2 changes: 1 addition & 1 deletion core/device_hooks/common_kernels.inc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@
typename GlobalIndexType> \
_macro(ValueType, LocalIndexType, GlobalIndexType) \
GKO_NOT_COMPILED(GKO_HOOK_MODULE); \
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_BASE(_macro)
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(_macro)

#define GKO_STUB_TEMPLATE_TYPE_BASE(_macro) \
template <typename IndexType> \
Expand Down
4 changes: 2 additions & 2 deletions core/distributed/assembly.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -135,7 +135,7 @@ device_matrix_data<ValueType, GlobalIndexType> assemble_rows_from_neighbors(
mpi::communicator comm, \
const device_matrix_data<_value_type, _global_type>& input, \
ptr_param<const Partition<_local_type, _global_type>> partition)
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_BASE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_ASSEMBLE_ROWS_FROM_NEIGHBORS);


Expand Down
16 changes: 6 additions & 10 deletions core/distributed/helpers.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -122,15 +122,11 @@ void vector_dispatch(T* linop, F&& f, Args&&... args)
{
#if GINKGO_BUILD_MPI
if (is_distributed(linop)) {
if constexpr (std::is_same_v<remove_complex<ValueType>, half>) {
GKO_NOT_SUPPORTED(linop);
} else {
using type = std::conditional_t<
std::is_const<T>::value,
const experimental::distributed::Vector<ValueType>,
experimental::distributed::Vector<ValueType>>;
f(dynamic_cast<type*>(linop), std::forward<Args>(args)...);
}
using type = std::conditional_t<
std::is_const<T>::value,
const experimental::distributed::Vector<ValueType>,
experimental::distributed::Vector<ValueType>>;
f(dynamic_cast<type*>(linop), std::forward<Args>(args)...);
} else
#endif
{
Expand Down
52 changes: 46 additions & 6 deletions core/distributed/matrix.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -200,8 +200,8 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::create(

template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::convert_to(
Matrix<next_precision_base<value_type>, local_index_type,
global_index_type>* result) const
Matrix<next_precision<value_type>, local_index_type, global_index_type>*
result) const
{
GKO_ASSERT(this->get_communicator().size() ==
result->get_communicator().size());
Expand All @@ -219,8 +219,8 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::convert_to(

template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::move_to(
Matrix<next_precision_base<value_type>, local_index_type,
global_index_type>* result)
Matrix<next_precision<value_type>, local_index_type, global_index_type>*
result)
{
GKO_ASSERT(this->get_communicator().size() ==
result->get_communicator().size());
Expand All @@ -237,6 +237,46 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::move_to(
}


#if GINKGO_ENABLE_HALF
template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::convert_to(
Matrix<next_precision<next_precision<value_type>>, local_index_type,
global_index_type>* result) const
{
GKO_ASSERT(this->get_communicator().size() ==
result->get_communicator().size());
result->local_mtx_->copy_from(this->local_mtx_.get());
result->non_local_mtx_->copy_from(this->non_local_mtx_.get());
result->gather_idxs_ = this->gather_idxs_;
result->send_offsets_ = this->send_offsets_;
result->recv_offsets_ = this->recv_offsets_;
result->recv_sizes_ = this->recv_sizes_;
result->send_sizes_ = this->send_sizes_;
result->non_local_to_global_ = this->non_local_to_global_;
result->set_size(this->get_size());
}


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::move_to(
Matrix<next_precision<next_precision<value_type>>, local_index_type,
global_index_type>* result)
{
GKO_ASSERT(this->get_communicator().size() ==
result->get_communicator().size());
result->local_mtx_->move_from(this->local_mtx_.get());
result->non_local_mtx_->move_from(this->non_local_mtx_.get());
result->gather_idxs_ = std::move(this->gather_idxs_);
result->send_offsets_ = std::move(this->send_offsets_);
result->recv_offsets_ = std::move(this->recv_offsets_);
result->recv_sizes_ = std::move(this->recv_sizes_);
result->send_sizes_ = std::move(this->send_sizes_);
result->non_local_to_global_ = std::move(this->non_local_to_global_);
result->set_size(this->get_size());
this->set_size({});
}
#endif

template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
const device_matrix_data<value_type, global_index_type>& data,
Expand Down Expand Up @@ -661,7 +701,7 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::operator=(Matrix&& other)
#define GKO_DECLARE_DISTRIBUTED_MATRIX(ValueType, LocalIndexType, \
GlobalIndexType) \
class Matrix<ValueType, LocalIndexType, GlobalIndexType>
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_BASE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_DISTRIBUTED_MATRIX);


Expand Down
5 changes: 2 additions & 3 deletions core/distributed/preconditioner/schwarz.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -144,8 +144,7 @@ void Schwarz<ValueType, LocalIndexType, GlobalIndexType>::generate(

#define GKO_DECLARE_SCHWARZ(ValueType, LocalIndexType, GlobalIndexType) \
class Schwarz<ValueType, LocalIndexType, GlobalIndexType>
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_BASE(
GKO_DECLARE_SCHWARZ);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(GKO_DECLARE_SCHWARZ);


} // namespace preconditioner
Expand Down
Loading
Loading