
[XLA:CPU] Add support for cross-process collectives using mpi. #7849

Closed
wanted to merge 7 commits from inailuig:mpi_collectives

Conversation

@inailuig (Contributor) commented Dec 16, 2023

MPI collectives as proposed in jax-ml/jax#11182.

I only implemented the inter-process communication, and this does not yet support more than one thread per process. Adding support for multiple threads/devices per process in the future seems quite a bit more involved if one wanted to do it properly.

For MPI I am building and linking against https://github.com/eschnett/MPItrampoline, which dlopens the (wrapped) MPI library at runtime. To wrap and load the desired MPI library, one needs to compile https://github.com/eschnett/MPIwrapper and set MPITRAMPOLINE_LIB=/path/to/libmpiwrapper.so.
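
To make concrete what the collectives backend ultimately does, here is a minimal, self-contained sketch of the plain MPI call that a cross-process all-reduce boils down to. This is illustrative only and not taken from mpi_collectives.cc; the buffer sizes, the reduction op, and the direct use of MPI_COMM_WORLD are assumptions made for the example.

```cpp
// Illustrative only: the kind of MPI call a cross-process all-reduce maps to.
// Not the code in this PR; communicator setup and dtype/op translation are
// simplified away.
#include <mpi.h>

#include <cstdio>
#include <vector>

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);

  int rank = 0, size = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);

  // Each process contributes its rank; the reduced result is the global sum.
  std::vector<double> send(4, static_cast<double>(rank));
  std::vector<double> recv(4, 0.0);
  MPI_Allreduce(send.data(), recv.data(), static_cast<int>(send.size()),
                MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);

  std::printf("rank %d/%d: allreduce -> %g\n", rank, size, recv[0]);

  MPI_Finalize();
  return 0;
}
```

With MPItrampoline the example would be compiled against the trampoline headers, and the concrete MPI implementation is selected at runtime through the wrapped library pointed to by MPITRAMPOLINE_LIB.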

@hawkinsp

@cheshire (Contributor) commented:

also @penpornk @ezhulenev

@penpornk (Member) left a comment

Thank you very much for the PR! :)

Can I ask why you picked MPItrampoline as the wrapper? Asking just because I've heard of a few other wrappers, e.g., WI4MPI, MUK, etc. and was wondering how they compare. (I've never used any.)

@inailuig (Contributor, Author) commented Dec 19, 2023

Can I ask why you picked MPItrampoline as the wrapper? Asking just because I've heard of a few other wrappers, e.g., WI4MPI, MUK, etc. and was wondering how they compare. (I've never used any.)

I had looked at all three of those, but I ended up choosing MPItrampoline (I haven't previously used any of them either).

  • MPItrampoline: more lightweight than WI4MPI and easy to use and compile, but it requires compiling a wrapper for every MPI implementation.

  • WI4MPI: seemed more complicated to use; I couldn't quite figure out how hard it would be to use it with some new/exotic MPI implementation that is not already supported (in "Interface" mode). Otherwise it looks like a valid option. It would have the advantage that it should work with the supported MPI implementations (IntelMPI, MPICH/MVAPICH, or OpenMPI) without compiling additional wrappers, as they hardcode the ABIs of these implementations and the wrappers are essentially already baked in.

  • Mukautuva: seemed less mature than MPItrampoline.

If we additionally add an easy way to compile XLA against the MPI implementation present on the system, I would be open to trying WI4MPI.

I hope that once the MPI ABI standardization is finished we can rely on that, so that only non-compliant MPI implementations need to be translated/wrapped.
Mukautuva already uses the prototype, and the developer of MPItrampoline seems to be working on an implementation as well.

@kamaljeeti (Contributor) commented:

Hi @inailuig, this PR is in draft; any update on this, please? Thank you!

@inailuig (Contributor, Author) commented Jan 8, 2024

Hi @inailuig, this PR is in draft; any update on this, please? Thank you!

I was waiting for an answer on #7849 (comment) from @penpornk.
I can try to implement my proposal in the next couple of days if that makes it easier.

@inailuig (Contributor, Author) commented Jan 13, 2024

Alright, I implemented a reasonably elegant way of building the GlobalDeviceID -> MPI_COMM_WORLD rank mapping with a minimal amount of non-blocking MPI communication.

This resolves #7849 (comment), and does not involve changing the CollectivesInterface like my original proposal in #7849 (comment).

The MPI communicator now, in some sense, autodiscovers which device is on which rank and does not require external information about the topology (just like the other communicators, such as Gloo).
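
For illustration, here is a rough sketch of how such a mapping can be autodiscovered with a single non-blocking allgather, assuming one device per process as in this PR; this is not the code added here, and the function and variable names are made up for the example.

```cpp
// Sketch (not the PR's implementation): every process announces the global
// device id it owns, so afterwards every rank knows which rank hosts which
// device. A single MPI_Iallgather suffices in the one-device-per-process case.
#include <mpi.h>

#include <cstdint>
#include <map>
#include <vector>

std::map<int64_t, int> BuildDeviceToRankMap(int64_t my_global_device_id,
                                            MPI_Comm comm) {
  int size = 0;
  MPI_Comm_size(comm, &size);

  std::vector<int64_t> all_ids(size, -1);
  MPI_Request req;
  MPI_Iallgather(&my_global_device_id, 1, MPI_INT64_T, all_ids.data(), 1,
                 MPI_INT64_T, comm, &req);
  // Other setup work could overlap with the communication here.
  MPI_Wait(&req, MPI_STATUS_IGNORE);

  std::map<int64_t, int> device_to_rank;
  for (int r = 0; r < size; ++r) {
    device_to_rank[all_ids[r]] = r;
  }
  return device_to_rank;
}
```

Once every rank knows the full mapping, the ranks belonging to a given replica group can be picked out and a subcommunicator formed, e.g. with MPI_Comm_split.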

@inailuig inailuig marked this pull request as ready for review January 13, 2024 09:59
@penpornk (Member) left a comment

Thank you for all the hard work and sorry for the delayed response!

@penpornk (Member) left a comment

Thank you!

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 29, 2024
Imported from GitHub PR openxla/xla#7849

MPI collectives as proposed in jax-ml/jax#11182.

I only implemented the inter-process communication, and this does not yet support more than one thread per process. Adding support for multiple threads/devices per process in the future seems quite a bit more involved if one wanted to do it properly.

For MPI I am building and linking against https://github.com/eschnett/MPItrampoline, which dlopens the (wrapped) MPI library at runtime. To wrap and load the desired MPI library, one needs to compile https://github.com/eschnett/MPIwrapper and set `MPITRAMPOLINE_LIB=/path/to/libmpiwrapper.so`.

@hawkinsp
Copybara import of the project:

--
b74bbb909d902bd30523f943a7c15f2c754cf98a by Clemens Giuliani <clemens@inailuig.it>:

add mpi collectives

--
23508eb46848464f6711dd8f3f91830ea1adb16d by Clemens Giuliani <clemens@inailuig.it>:

add explicit Init and Finalize methods and export them to python

--
bbe5840b8eb56a306a66ed03d701fd8976e01491 by Clemens Giuliani <clemens@inailuig.it>:

add comment

--
38d156282ecc89509f4b21d80db1a37cb290437a by Clemens Giuliani <clemens@inailuig.it>:

fix windows build

--
201f7238f166197ede5cf5d4d70e117a91eddcd7 by Clemens Giuliani <clemens@inailuig.it>:

fmt

--
2784869df650c1c123c346401db2f67cb153b03e by Clemens Giuliani <clemens@inailuig.it>:

bump xla_extension_version

Merging this change closes #7849

PiperOrigin-RevId: 620302264
copybara-service bot pushed a commit that referenced this pull request Apr 24, 2024
Imported from GitHub PR #11721

we forgot this in #7849.
Copybara import of the project:

--
3924cc0 by Clemens Giuliani <clemens@inailuig.it>:

[XLA:CPU] add missing type annotations for the mpi collectives

Merging this change closes #11721

COPYBARA_INTEGRATE_REVIEW=#11721 from inailuig:mpicollectives_pytype 3924cc0
PiperOrigin-RevId: 627544891
steeve pushed a commit to zml/xla that referenced this pull request Aug 30, 2024