
[Feature] Support TPU backend for ShardParallel #764

Merged
ZYHowell merged 21 commits into main from tpu-support on Nov 8, 2022

Conversation

ZYHowell
Collaborator

ZYHowell commented Nov 3, 2022

This PR:

  • Add has_cuda and backend to global_config;
  • Add IS_CUDA to setup.py;
  • Change the semantics of run_auto_sharding_pass and run_backend_compilation: the output of run_auto_sharding_pass (and the input of run_backend_compilation) is no longer the SPMD-partitioned module but the sharding-annotated module, because the TPU backend cannot compile an SPMD-partitioned module with sharding configurations (see the sketch after this list);
  • Write test cases specifically for TPUs. The pass that rewrites AllReduce into ReduceScatter and AllGather runs after the SPMD partitioner, so it cannot be applied on the TPU backend.
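
To make the backend split above concrete, here is a minimal, self-contained sketch of the control flow. It is illustrative only: GlobalConfig, the stub pass functions, and the use_reduce_scatter flag are simplified stand-ins for this discussion, not Alpa's actual signatures.

from dataclasses import dataclass

@dataclass
class GlobalConfig:
    backend: str = "gpu"             # "gpu" or "tpu", as added by this PR
    has_cuda: bool = True            # set from IS_CUDA at build time
    use_reduce_scatter: bool = True  # hypothetical stand-in flag

def run_auto_sharding_pass(hlo):
    # New contract: returns a sharding-ANNOTATED module, not an
    # SPMD-partitioned one, so both backends can consume it.
    return f"annotated({hlo})"

def run_spmd_partitioner_pass(annotated):
    return f"partitioned({annotated})"

def rewrite_all_reduce(partitioned):
    # AllReduce -> ReduceScatter + AllGather; only possible after the
    # SPMD partitioner, hence unavailable on the TPU path.
    return f"reduce_scatter({partitioned})"

def run_backend_compilation(module, backend):
    return f"compiled[{backend}]({module})"

def compile_shard_parallel(hlo, config: GlobalConfig):
    annotated = run_auto_sharding_pass(hlo)
    if config.backend == "tpu":
        # TPU: hand the sharding-annotated module to the compiler,
        # which runs the SPMD partitioner internally.
        return run_backend_compilation(annotated, "tpu")
    # GPU: partition explicitly, then optionally rewrite AllReduce.
    partitioned = run_spmd_partitioner_pass(annotated)
    if config.use_reduce_scatter:
        partitioned = rewrite_all_reduce(partitioned)
    return run_backend_compilation(partitioned, "gpu")

print(compile_shard_parallel("hlo", GlobalConfig(backend="tpu")))
print(compile_shard_parallel("hlo", GlobalConfig(backend="gpu")))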

@zhisbug
Copy link
Member

zhisbug commented Nov 3, 2022

One quick Q: how do we run TPU tests in CI?

@ZYHowell
Copy link
Collaborator Author

ZYHowell commented Nov 3, 2022

One quick Q: how do we run TPU tests in CI?

Google provides free TPU access for several days, but I've already extended the time twice (the first extension gave me 90 days and the second 180 days). Can we use that machine for CI?

@ZYHowell ZYHowell changed the title [WIP][Feature] Support TPU backend for ShardParallel [Feature] Support TPU backend for ShardParallel Nov 7, 2022
@ZYHowell ZYHowell merged commit 4f14955 into main Nov 8, 2022
@ZYHowell ZYHowell deleted the tpu-support branch November 8, 2022 17:39
@merrymercy merrymercy mentioned this pull request Nov 9, 2022
@OhadRubin

@merrymercy hey, it seems all the TPU-specific tests are being skipped (on mobile, will provide a link soon), so it is not clear to me what the current working TPU capabilities are.

@ZYHowell
Collaborator Author

ZYHowell commented Nov 9, 2022

@merrymercy hey, it seems all the TPU-specific tests are being skipped (on mobile, will provide a link soon), so it is not clear to me what the current working TPU capabilities are.

Currently we skip all tests other than those for the ShardParallel/reduce-scatter optimization. Tests such as test_n_layer_mlp_data_parallel are kept; for the full list of kept tests, please refer to the if __name__ == "__main__": part.

@merrymercy
Member

merrymercy commented Nov 9, 2022

Not all test cases are skipped. The current code can pass the following test cases on TPU (a sketch of how such a suite is assembled follows the list):

add_mlp("test_n_layer_mlp_data_parallel")
add_mlp("test_n_layer_mlp_model_parallel")
add_mlp("test_n_layer_mlp_2d_mesh")
add_mlp("test_n_layer_mlp_force_data_parallel")
add_mlp("test_n_layer_mlp_force_batch_dim_mapping")
add_mlp("test_weight_init")
add_moe("test_moe_layer")
add_moe("test_moe_layer_2d")
add_moe("test_moe_lm")
add_moe("test_moe_lm_2d")
add_moe("test_moe_lm_data_parallel")
