[XLA:CPU] Add support for cross-process collectives using mpi. #7849
Conversation
also @penpornk @ezhulenev
If we additionally add an easy way to compile XLA against the MPI implementation present on the system, I would be open to trying WI4MPI. I hope that at some point in the future, once the MPI ABI standardization is done, we can rely on that, so that only non-compliant MPI implementations need to be translated/wrapped.
Hi @inailuig, this PR is still in draft. Any update on this, please? Thank you!
I was waiting for an answer on #7849 (comment) from @penpornk.
Alright. I implemented a reasonably elegant way of building the GlobalDeviceID -> MPI_COMM_WORLD rank mapping with a minimal amount of non-blocking MPI communication. This resolves #7849 (comment) and does not involve changing the CollectivesInterface, unlike my original proposal in #7849 (comment). The MPI communicator now, in some sense, autodiscovers which device is on which rank and does not require external information about the topology (just like the other communicators, such as Gloo).
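For illustration, here is a minimal sketch, not the code from this PR, of how such a GlobalDeviceID -> MPI_COMM_WORLD rank mapping can be autodiscovered with a single non-blocking allgather. Function and variable names are hypothetical; MPI is assumed to be initialized already.

```cpp
// Sketch only: each rank contributes its own global device id via a
// non-blocking allgather, after which every rank knows which device
// lives on which MPI rank. Not the actual XLA implementation.
#include <mpi.h>

#include <cstdint>
#include <map>
#include <vector>

std::map<int64_t, int> DiscoverDeviceToRankMapping(int64_t my_global_device_id) {
  int rank = 0, world_size = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &world_size);

  // The result is the list of global device ids indexed by MPI rank.
  std::vector<int64_t> device_ids(world_size);
  MPI_Request request;
  MPI_Iallgather(&my_global_device_id, 1, MPI_INT64_T,
                 device_ids.data(), 1, MPI_INT64_T,
                 MPI_COMM_WORLD, &request);
  MPI_Wait(&request, MPI_STATUS_IGNORE);

  // Build the GlobalDeviceID -> rank mapping.
  std::map<int64_t, int> device_to_rank;
  for (int r = 0; r < world_size; ++r) {
    device_to_rank[device_ids[r]] = r;
  }
  return device_to_rank;
}
```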
Thank you for all the hard work and sorry for the delayed response!
Thank you!
Imported from GitHub PR openxla/xla#7849

MPI collectives as proposed in jax-ml/jax#11182. I only implemented the inter-process communication, and this does not yet support more than one thread per process. Adding support for multiple threads/devices per process in the future seems quite a bit more involved if one wanted to do it properly.

For MPI I am building and linking against https://github.com/eschnett/MPItrampoline, which dlopens the (wrapped) MPI library at runtime. To wrap and load the desired MPI library one needs to compile https://github.com/eschnett/MPIwrapper and set `MPITRAMPOLINE_LIB=/path/to/libmpiwrapper.so`.

@hawkinsp

Copybara import of the project:

-- b74bbb909d902bd30523f943a7c15f2c754cf98a by Clemens Giuliani <clemens@inailuig.it>: add mpi collectives
-- 23508eb46848464f6711dd8f3f91830ea1adb16d by Clemens Giuliani <clemens@inailuig.it>: add explicit Init and Finalize methods and export them to python
-- bbe5840b8eb56a306a66ed03d701fd8976e01491 by Clemens Giuliani <clemens@inailuig.it>: add comment
-- 38d156282ecc89509f4b21d80db1a37cb290437a by Clemens Giuliani <clemens@inailuig.it>: fix windows build
-- 201f7238f166197ede5cf5d4d70e117a91eddcd7 by Clemens Giuliani <clemens@inailuig.it>: fmt
-- 2784869df650c1c123c346401db2f67cb153b03e by Clemens Giuliani <clemens@inailuig.it>: bump xla_extension_version

Merging this change closes #7849

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#7849 from inailuig:mpi_collectives 2784869df650c1c123c346401db2f67cb153b03e
PiperOrigin-RevId: 620001290
Previously, the counter for `MemorySpaceAssignment::FixSchedule` started at `0` and checked for async copies scheduled before and after the counter, but only against the exact counter value. Some async copies set their `start_after` value to `-1`, meaning we would skip inserting them at the earliest point and only catch that they had not been inserted by their `start_before` time, inserting them then. This led to a few async copy operations where the `*-start` was scheduled immediately before its corresponding `*-done` operation, so none of the latency was hidden.

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#7849 from inailuig:mpi_collectives 2784869df650c1c123c346401db2f67cb153b03e
PiperOrigin-RevId: 617018384
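To make the off-by-one concrete, here is a simplified, hypothetical sketch; it does not mirror the actual MemorySpaceAssignment code, but it shows why a counter that starts at `0` and compares only for equality never picks up copies whose `start_after` is `-1`, and how starting the counter at `-1` avoids deferring them to their `start_before` deadline.

```cpp
// Hypothetical sketch of the scheduling issue described above.
#include <cstdint>
#include <vector>

struct AsyncCopy {
  int64_t start_after;   // earliest schedule position (may be -1)
  int64_t start_before;  // latest schedule position
  bool inserted = false;
};

void FixScheduleSketch(std::vector<AsyncCopy>& copies, int64_t num_positions) {
  // Starting the counter at -1 (instead of 0) lets copies with
  // start_after == -1 be inserted at the earliest opportunity instead of
  // being caught only at their start_before deadline, where the copy-start
  // would land right before the copy-done and hide no latency.
  for (int64_t pos = -1; pos < num_positions; ++pos) {
    for (AsyncCopy& copy : copies) {
      if (!copy.inserted && copy.start_after == pos) {
        copy.inserted = true;  // schedule the copy-start here
      }
    }
    // A real implementation would also force insertion before start_before.
  }
}
```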
(so Jax persistent cache won't reuse code from incompatible machines)

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#7849 from inailuig:mpi_collectives 2784869df650c1c123c346401db2f67cb153b03e
PiperOrigin-RevId: 617288134
The default stack size for thread pool threads is too small on macOS. An older version of this runtime based on StreamExecutor set a 2 MiB stack size as well, but that change was most likely lost during the TFRT rewrite. Fixes jax-ml/jax#20428

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#7849 from inailuig:mpi_collectives 2784869df650c1c123c346401db2f67cb153b03e
PiperOrigin-RevId: 620227968
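As a minimal illustration of the underlying mechanism (plain pthreads, not the XLA/TFRT thread pool API), this is how a worker thread is given an explicit 2 MiB stack instead of the small macOS default for secondary threads.

```cpp
// Sketch: request a 2 MiB stack for a worker thread via pthread attributes.
#include <pthread.h>

void* Worker(void* /*arg*/) {
  // Deep recursion or large stack frames would overflow a too-small stack.
  return nullptr;
}

int main() {
  pthread_attr_t attr;
  pthread_attr_init(&attr);
  // Request a 2 MiB stack instead of the platform default, which is only
  // 512 KiB for non-main threads on macOS.
  pthread_attr_setstacksize(&attr, 2 * 1024 * 1024);

  pthread_t thread;
  pthread_create(&thread, &attr, Worker, nullptr);
  pthread_join(thread, nullptr);
  pthread_attr_destroy(&attr);
  return 0;
}
```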
Imported from GitHub PR openxla/xla#11721

We forgot this in openxla/xla#7849.

Copybara import of the project:

-- 3924cc0fbbb63e9503f38a59aede3b8e817b17fa by Clemens Giuliani <clemens@inailuig.it>: [XLA:CPU] add missing type annotations for the mpi collectives

Merging this change closes #11721

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11721 from inailuig:mpicollectives_pytype 3924cc0fbbb63e9503f38a59aede3b8e817b17fa
PiperOrigin-RevId: 627055918
MPI collectives as proposed in jax-ml/jax#11182.
I only implemented the inter-process communication, and this does not yet support more than one thread per process. Adding support for multiple threads/devices per process in the future seems quite a bit more involved if one wanted to do it properly.
For MPI I am building and linking against https://github.com/eschnett/MPItrampoline, which dlopens the (wrapped) MPI library at runtime. To wrap and load the desired MPI library one needs to compile https://github.com/eschnett/MPIwrapper and set
`MPITRAMPOLINE_LIB=/path/to/libmpiwrapper.so`.

@hawkinsp
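For readers unfamiliar with MPI, here is a conceptual, standalone sketch (plain MPI, not the CollectivesInterface implementation added in this PR) of the kind of cross-process all-reduce these collectives map onto; the program name is made up.

```cpp
// allreduce_demo.cc: each process contributes a local buffer and
// MPI_Allreduce sums it element-wise across all processes, which is what an
// XLA all-reduce with one device per process boils down to.
#include <mpi.h>

#include <cstdio>
#include <vector>

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);

  int rank = 0, world_size = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &world_size);

  std::vector<float> local(4, static_cast<float>(rank));
  std::vector<float> reduced(4);
  MPI_Allreduce(local.data(), reduced.data(), static_cast<int>(local.size()),
                MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD);

  std::printf("rank %d of %d: reduced[0] = %f\n", rank, world_size, reduced[0]);

  MPI_Finalize();
  return 0;
}
```

Run with, e.g., `mpirun -n 4 ./allreduce_demo`. With MPItrampoline in the picture, the MPI implementation actually used is selected at runtime by pointing `MPITRAMPOLINE_LIB` at the corresponding `libmpiwrapper.so`.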