[SPMD] auto-construct auto-sharding mesh ids (#6770) (#6782)
yeounoh authored Mar 20, 2024
1 parent ebedf4d commit c923e8f
Showing 6 changed files with 47 additions and 137 deletions.
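
For context: the auto-sharding pass reads its logical mesh shape from the XLA_AUTO_SPMD_MESH environment variable, and with this commit the logical device id assignment is read back from the compiled HLO module rather than always assumed to be the iota order 0..n-1. A minimal usage sketch from the Python side follows; XLA_AUTO_SPMD_MESH appears in this diff, while the XLA_AUTO_SPMD enable flag and the rest of the snippet are assumptions based on torch_xla's SPMD auto-sharding workflow, not part of this commit:

import os

# Illustrative sketch: configure the auto-sharding pass before torch_xla
# initializes. "2,2" requests a 2-by-2 logical mesh and must multiply out
# to at most the number of attached devices (see GetAutoShardingMesh below).
os.environ["XLA_AUTO_SPMD"] = "1"         # assumed enable flag
os.environ["XLA_AUTO_SPMD_MESH"] = "2,2"  # mesh shape parsed by this commit

import torch
import torch_xla.core.xla_model as xm

# Any subsequent compilation picks up the mesh shape; after this commit the
# device id assignment comes from the compiled HLO instead of a fixed iota.
t = torch.randn(4, 4, device=xm.xla_device())
print(t.sum())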
5 changes: 2 additions & 3 deletions WORKSPACE
@@ -49,11 +49,10 @@ http_archive(
"//openxla_patches:cache_urls.diff",
"//openxla_patches:gpu_race_condition.diff",
"//openxla_patches:f16_abi_clang.diff",
"//openxla_patches:quant_dequant_converter.diff",
],
strip_prefix = "xla-18cbd2019898d3a7b563aeb73683f0c5a6ce14fd",
strip_prefix = "xla-25c8a6781af6be51d3bc43a0953b07803ab761ea",
urls = [
"https://github.com/openxla/xla/archive/18cbd2019898d3a7b563aeb73683f0c5a6ce14fd.tar.gz",
"https://github.com/openxla/xla/archive/25c8a6781af6be51d3bc43a0953b07803ab761ea.tar.gz",
],
)

122 changes: 0 additions & 122 deletions openxla_patches/quant_dequant_converter.diff

This file was deleted.
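(Presumably the local quant/dequant converter patch is no longer needed because the new XLA commit pinned above already contains those changes upstream; the removal accompanies the strip_prefix/urls bump in WORKSPACE.)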

2 changes: 1 addition & 1 deletion setup.py
@@ -64,7 +64,7 @@

 base_dir = os.path.dirname(os.path.abspath(__file__))
 
-_date = '20240305'
+_date = '20240320'
 _libtpu_version = f'0.1.dev{_date}'
 _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl'
 _jax_version = f'0.4.26.dev{_date}'
8 changes: 5 additions & 3 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -1402,9 +1402,11 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(

   // Apply XLA_AUTO_SPMD_MESH if it is set.
   // TODO(yeounoh) allow multi mesh exploration.
-  auto mesh_shape_ids = ShardingUtil::GetAutoShardingMesh();
-  std::vector<int64_t> auto_spmd_mesh_shape = std::get<0>(mesh_shape_ids);
-  std::vector<int64_t> auto_spmd_mesh_ids = std::get<1>(mesh_shape_ids);
+  std::vector<int64_t> auto_spmd_mesh_shape =
+      ShardingUtil::GetAutoShardingMesh();
+  std::vector<int64_t> auto_spmd_mesh_ids =
+      ShardingUtil::GetAutoShardingMeshIds(
+          instances.front().computation.proto());
   instances.front().auto_spmd_mesh_shape = auto_spmd_mesh_shape;
   instances.front().auto_spmd_mesh_ids = auto_spmd_mesh_ids;
   TF_VLOG(5) << "auto_spmd_mesh_shape={"
42 changes: 36 additions & 6 deletions torch_xla/csrc/xla_sharding_util.cpp
@@ -624,14 +624,12 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData(
       source_tensors, GetVirtualDevice().toString(), global_shape, sharding);
 }
 
-std::tuple<std::vector<int64_t>, std::vector<int64_t>>
-ShardingUtil::GetAutoShardingMesh() {
+std::vector<int64_t> ShardingUtil::GetAutoShardingMesh() {
   // Auto-sharding uses mesh_shape = {n_devices, 1} if XLA_AUTO_SPMD_MESH
   // is not set. XLA_AUTO_SPMD_MESH takes the form of a string, e.g. "2,2",
   // which corresponds to a 2-by-2 mesh.
   std::vector<int64_t> mesh_shape = ParseStringToIntVector(
       runtime::sys_util::GetEnvString("XLA_AUTO_SPMD_MESH", ""));
-  std::vector<int64_t> device_mesh_ids;
   if (!mesh_shape.empty()) {
     int64_t total_devices = 1;
     for (auto i : mesh_shape) {
@@ -641,10 +639,42 @@ ShardingUtil::GetAutoShardingMesh() {
               runtime::GetComputationClient()->GetAllDevices().size())
         << "Invalid auto-sharding mesh_shape: "
         << absl::StrJoin(mesh_shape, ",");
-    device_mesh_ids = std::vector<int64_t>(total_devices);
-    std::iota(device_mesh_ids.begin(), device_mesh_ids.end(), 0);
   }
-  return std::make_tuple(mesh_shape, device_mesh_ids);
+  return mesh_shape;
 }
 
+std::vector<int64_t> ShardingUtil::GetAutoShardingMeshIds(
+    const xla::HloModuleProto& module) {
+  // Return the first non-default (non-iota) mesh id arrangement, since we
+  // expect only one such assignment and the logical mesh device assignment
+  // should be compatible with the other arrangements in the HLO. This is a
+  // workaround, as the auto-sharding pass currently takes one arrangement.
+  // TODO(yeounoh) this was not necessary before; replace if this can be done
+  // during the auto-sharding pass.
+  int64_t n_devices = runtime::GetComputationClient()->GetAllDevices().size();
+  std::vector<int64_t> device_mesh_ids = std::vector<int64_t>(n_devices);
+  std::iota(device_mesh_ids.begin(), device_mesh_ids.end(), 0);
+
+  // Unfortunately, we have to go through the instructions, since
+  // `spmd_parameters_shardings` is not available.
+  for (auto computation : module.computations()) {
+    for (auto instruction : computation.instructions()) {
+      if (instruction.opcode() == "parameter" && instruction.has_sharding()) {
+        xla::OpSharding sharding = instruction.sharding();
+        auto tile_assignment_devices = sharding.tile_assignment_devices();
+        if (!tile_assignment_devices.empty()) {
+          auto new_mesh_ids = std::vector<int64_t>(
+              tile_assignment_devices.begin(), tile_assignment_devices.end());
+          // Return the first non-default (non-iota) device assignment.
+          if (new_mesh_ids != device_mesh_ids) {
+            return new_mesh_ids;
+          }
+        }
+      }
+    }
+  }
+  // Return the default (iota) device assignment.
+  return device_mesh_ids;
+}
 
 void ShardingUtil::ReshardParameters(
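
To make the new control flow concrete, here is a small, self-contained Python mock of the two helpers above, using plain dicts in place of xla::HloModuleProto (the dict structure is hypothetical and for exposition only; the real code walks the proto API):

import os

def get_auto_sharding_mesh(n_devices):
    # Mirrors ShardingUtil::GetAutoShardingMesh: parse XLA_AUTO_SPMD_MESH
    # ("2,2" -> [2, 2]) and check it fits the attached device count.
    raw = os.environ.get("XLA_AUTO_SPMD_MESH", "")
    mesh_shape = [int(d) for d in raw.split(",") if d]
    if mesh_shape:
        total = 1
        for dim in mesh_shape:
            total *= dim
        assert total <= n_devices, f"Invalid auto-sharding mesh_shape: {raw}"
    return mesh_shape

def get_auto_sharding_mesh_ids(module, n_devices):
    # Mirrors ShardingUtil::GetAutoShardingMeshIds: scan parameter
    # instructions for the first non-default (non-iota) tile assignment,
    # falling back to the iota assignment [0, 1, ..., n_devices - 1].
    default_ids = list(range(n_devices))
    for computation in module["computations"]:
        for instruction in computation["instructions"]:
            if instruction["opcode"] == "parameter" and "sharding" in instruction:
                devices = instruction["sharding"].get("tile_assignment_devices", [])
                if devices and list(devices) != default_ids:
                    return list(devices)
    return default_ids

# A parameter sharded with a permuted device order makes the scan return
# that order instead of the iota default.
module = {
    "computations": [{
        "instructions": [
            {"opcode": "parameter",
             "sharding": {"tile_assignment_devices": [0, 2, 1, 3]}},
        ],
    }],
}
os.environ["XLA_AUTO_SPMD_MESH"] = "2,2"
assert get_auto_sharding_mesh(4) == [2, 2]
assert get_auto_sharding_mesh_ids(module, 4) == [0, 2, 1, 3]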
5 changes: 3 additions & 2 deletions torch_xla/csrc/xla_sharding_util.h
@@ -126,8 +126,9 @@ class ShardingUtil {

-  // Construct a device mesh for auto-sharding pass. Returns a tuple of mesh
-  // shape and device ids vectors.
-  static std::tuple<std::vector<int64_t>, std::vector<int64_t>>
-  GetAutoShardingMesh();
+  // Construct a device mesh shape for the auto-sharding pass, and extract
+  // the logical device id assignment from the compiled HLO module.
+  static std::vector<int64_t> GetAutoShardingMesh();
+  static std::vector<int64_t> GetAutoShardingMeshIds(
+      const xla::HloModuleProto& module);
 
   // Reshard the parameters if the expected shardings mismatch. Resharding is
   // expensive, especially for those already sharded. The cost can easily be