Skip to content

Commit

Permalink
add explicit Init and Finalize methods and export them to python
Browse files Browse the repository at this point in the history
  • Loading branch information
inailuig committed Mar 19, 2024
1 parent b74bbb9 commit 23508eb
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
4 changes: 2 additions & 2 deletions xla/pjrt/cpu/mpi_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,15 +236,15 @@ absl::Status MpiCollectivesCommunicator::ReduceScatter(
input_buffer, output_buffer, recvcounts.data(), type, op, comm_));
}

MpiCollectives::MpiCollectives() {
void MpiCollectives::Init() {
int provided;
MPI_Init_thread(NULL, NULL, MPI_THREAD_FUNNELED, &provided);
MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank_);
MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size_);
VLOG(1) << "MPI rank=" << mpi_world_rank_ << " size=" << mpi_world_size_;
}

MpiCollectives::~MpiCollectives() {
void MpiCollectives::Finalize() {
contexts_.clear();
MPI_Finalize();
}
Expand Down
5 changes: 3 additions & 2 deletions xla/pjrt/cpu/mpi_collectives.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,9 @@ class MpiCollectivesCommunicator : public CollectivesCommunicator {

class MpiCollectives : public CollectivesInterface {
public:
MpiCollectives();
~MpiCollectives() override;

void Init();
void Finalize();

absl::StatusOr<std::shared_ptr<CollectivesCommunicator>> GetCommunicator(
absl::Span<GlobalDeviceId const> global_devices, int rank) override;
Expand Down
8 changes: 7 additions & 1 deletion xla/python/xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,15 @@ NB_MODULE(xla_extension, m_nb) {
nb::arg("distributed_client"), nb::arg("hostname").none() = std::nullopt,
nb::arg("interface").none() = std::nullopt);

#ifndef _WIN32
nb::class_<cpu::MpiCollectives> mpi_collectives(m_nb, "MpiCollectives", cpu_collectives);
mpi_collectives.def("Init", &cpu::MpiCollectives::Init);
mpi_collectives.def("Finalize", &cpu::MpiCollectives::Finalize);
#endif // _WIN32

m_nb.def(
"make_mpi_collectives",
[]() -> std::shared_ptr<xla::cpu::CollectivesInterface> {
[]() -> std::shared_ptr<cpu::MpiCollectives> {
#ifndef _WIN32
return std::make_shared<cpu::MpiCollectives>();
#else // _WIN32
Expand Down

0 comments on commit 23508eb

Please sign in to comment.