
[SPMD] auto-sharding PoC #6719

Merged: 39 commits merged into master from spmd_auto_alpa on Mar 14, 2024

Conversation

@yeounoh (Contributor) commented Mar 12, 2024

This implements a proof-of-concept (PoC) of auto-sharding on XLA:TPU, as described in #6322.

PyTorch/XLA auto-sharding can be enabled by one of the following:

- Setting the environment variable `XLA_SPMD_AUTO=1`
- Calling the SPMD API at the beginning of your code:

  ```python
  import torch_xla.runtime as xr
  xr.use_spmd(auto=True)
  ```

- Calling `torch.distributed._tensor.distribute_module` with the auto policy on an `xla` device mesh (an end-to-end sketch follows this list):

  ```python
  import torch_xla.runtime as xr
  from torch.distributed._tensor import DeviceMesh, distribute_module
  from torch_xla.distributed.spmd import auto_policy

  device_count = xr.global_runtime_device_count()
  device_mesh = DeviceMesh("xla", list(range(device_count)))

  # Currently, the model should be loaded onto the xla device via distribute_module.
  model = MyModule()  # nn.Module
  sharded_model = distribute_module(model, device_mesh, auto_policy)
  ```
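For context, here is a minimal end-to-end sketch of the `distribute_module` path above; the `MyModule` definition, tensor shapes, and plain SGD step are illustrative assumptions, not part of this PR:

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy


class MyModule(nn.Module):
    # Toy stand-in for a real workload (illustrative only).
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(128, 64)

    def forward(self, x):
        return self.linear(x)


device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))

# distribute_module loads the model onto the xla device and lets the
# auto-sharding pass choose the shardings.
sharded_model = distribute_module(MyModule(), device_mesh, auto_policy)

optimizer = torch.optim.SGD(sharded_model.parameters(), lr=0.1)
x = torch.randn(16, 128).to(xm.xla_device())

loss = sharded_model(x).sum()
loss.backward()
optimizer.step()
xm.mark_step()  # cut the lazy-tensor graph; compilation applies the shardings
```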

Some notable limitations that we will address in follow-ups:

- XLA:GPU is not supported
- TPU pod is not supported
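Until those land, a guard along these lines is a reasonable pattern; using `xr.device_type()` to detect the runtime, with the fallback to manual SPMD being an illustrative assumption:

```python
import torch_xla.runtime as xr

# Auto-sharding is a TPU-only (single-host) PoC for now.
if xr.device_type() == "TPU":
    xr.use_spmd(auto=True)
else:
    xr.use_spmd()  # fall back to manual sharding annotations
```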

cc @baoleai

@yeounoh requested a review from @JackCaoG on March 12, 2024 00:21
@yeounoh self-assigned this on March 12, 2024
@yeounoh marked this pull request as draft on March 12, 2024 00:22
@yeounoh force-pushed the spmd_auto_alpa branch 2 times, most recently from 126ceee to 4d568ef on March 12, 2024 00:25
WORKSPACE (review thread resolved; outdated)
setup.py (review thread resolved; outdated)
@yeounoh force-pushed the spmd_auto_alpa branch 2 times, most recently from 6ca8f97 to d6dc442 on March 12, 2024 00:38
@yeounoh force-pushed the spmd_auto_alpa branch 12 times, most recently from 303b239 to d3c1d70 on March 12, 2024 07:34
@yeounoh force-pushed the spmd_auto_alpa branch 4 times, most recently from 968bca4 to eadcae6 on March 14, 2024 18:15
```diff
@@ -226,6 +226,8 @@ function run_xla_op_tests3 {
   run_test "$CDIR/spmd/test_xla_distributed_checkpoint.py"
   run_test "$CDIR/spmd/test_xla_spmd_python_api_interaction.py"
   run_test "$CDIR/spmd/test_dtensor_integration.py"
+  run_test "$CDIR/spmd/test_dtensor_integration2.py"
```
Collaborator
Do we need this on the TPU CI as well, or is it ok to leave out?

@yeounoh (Contributor, Author)
Ohhh, I think it's ok to leave out. Want to run this sanity check on TPU!

@JackCaoG (Collaborator) left a comment

Feel free to address the remaining comments in a follow-up.

@yeounoh merged commit 370089a into master on Mar 14, 2024
18 checks passed
@yeounoh added a commit that referenced this pull request on Mar 14, 2024