Skip to content

Commit

Permalink
* Enable test_xla_auto_sharding.py
Browse files Browse the repository at this point in the history
* Linter fix
  • Loading branch information
yeounoh committed Mar 12, 2024
1 parent 5946e19 commit d3c1d70
Show file tree
Hide file tree
Showing 14 changed files with 53 additions and 46 deletions.
17 changes: 7 additions & 10 deletions test/cpp/test_xla_sharding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
xla::Shape tensor_shape =
CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
std::fill_n(tensors.begin(), tensors.size(), tensor);
std::vector<std::string> devices(2);
std::vector<std::string> devices(3);
std::fill_n(devices.begin(), devices.size(),
bridge::GetDefaultDevice()->toString());
std::vector<XLATensor::ShardingSpecPtr> shardings = {
Expand All @@ -345,10 +345,7 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
shards[0]->shape()));
EXPECT_TRUE(XlaDataValuesEqual(tensors_data[0], shards[0], at::kFloat));

// Returns multiple input shards, explicitly replicated
int64_t n_devices =
torch_xla::runtime::GetComputationClient()->GetLocalDevices().size();
if (n_devices > 1) {
// Returns multiple input shards, explicitly replicated
auto sharded_xla_data =
std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
tensors_data[1]);
Expand All @@ -358,11 +355,10 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(sharded_xla_data->shape(),
shards[0]->shape()));
EXPECT_TRUE(XlaDataValuesEqual(shards[0], shards[1], at::kFloat));
}

// Returns multiple input shards, implicitly replicated
if (n_devices > 1) {
auto sharded_xla_data =
// Returns multiple input shards, implicitly replicated

sharded_xla_data =
std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
tensors_data[2]);
shards = torch_xla::runtime::GetComputationClient()->GetDataShards(
Expand Down Expand Up @@ -427,7 +423,8 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
xla::HloSharding::Replicate().ToProto(), sharding_specs[0]->sharding));
}

// Check if the placeholder is on a SPMD device (sharded) with no real values.
// Check if the placeholder is on a SPMD device (sharded) with no real
// values.
EXPECT_EQ(data_placeholders.size(), 1);
EXPECT_EQ(
std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
Expand Down
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ function run_xla_op_tests3 {
run_test "$CDIR/spmd/test_xla_distributed_checkpoint.py"
run_test "$CDIR/spmd/test_xla_spmd_python_api_interaction.py"
run_test "$CDIR/spmd/test_dtensor_integration.py"
run_test "$CDIR/spmd/test_xla_auto_sharding.py"
run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
run_test "$CDIR/test_input_output_aliases.py"
run_test "$CDIR/test_torch_distributed_xla_backend.py"
Expand Down
11 changes: 6 additions & 5 deletions test/spmd/test_dynamo_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,23 +122,24 @@ def test_dynamo_input_sharding_changed(self):
xs.mark_sharding(xla_x, self._get_mesh((1, self.n_devices)), (1, 0))
dynamo_res_sharded = dynamo_linear(xla_x)
torch.allclose(dynamo_res.cpu(), dynamo_res_sharded.cpu())
# one graph is being generated by .cpu call above
# one graph is being generated per computation to get `dynamo_res_sharded` (1),
# and another for resharding the data (1) and to move the data to cpu (2).
if self.n_devices > 1:
self.assertEqual(met.metric_data('CompileTime')[0], 3)
self.assertEqual(met.metric_data('CompileTime')[0], 1 + 1 + 1 + 2)
else:
# If there is only one device (CPU), the sharding spec will be
# replicated, hence no change.
self.assertEqual(met.metric_data('CompileTime')[0], 1)
self.assertEqual(met.metric_data('CompileTime')[0], 1 + 2)

# Call the dynamo function with a different input with different sharding
xs.mark_sharding(xla_y, self._get_mesh((1, self.n_devices)), (0, 1))
dynamo_res_sharded_2 = dynamo_linear(xla_y)
if self.n_devices > 1:
self.assertEqual(met.metric_data('CompileTime')[0], 4)
self.assertEqual(met.metric_data('CompileTime')[0], 5 + 1 + 1)
else:
# If there is only one device (CPU), the sharding spec will be
# replicated, hence no change.
self.assertEqual(met.metric_data('CompileTime')[0], 1)
self.assertEqual(met.metric_data('CompileTime')[0], 1 + 2)
torch.allclose(linear(xla_y).cpu(), dynamo_res_sharded_2.cpu())

@unittest.skipIf(xr.global_runtime_device_count() == 1,
Expand Down
12 changes: 5 additions & 7 deletions test/spmd/test_xla_auto_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,12 @@ def setUpClass(cls):
xr.use_spmd(auto=True)
super().setUpClass()


@unittest.skipUnless(xr.device_type() in ["TPU", "CPU"],
"Auto-sharding currently supports TPU device.")
def test_matmul(self):
met.clear_counters()
t1 = torch.randn(64, 128)
t2 = torch.randn(128, 256)
t1 = torch.ones(64, 128)
t2 = torch.ones(128, 256)
t3 = (t1 @ t2).sum()

xt1 = t1.to(xm.xla_device())
Expand All @@ -48,7 +47,6 @@ def test_matmul(self):
self.assertEqual(met.counter_value("CompileWithAutoSharding"), 1)
self.assertTrue(torch.allclose(t3, xt3.cpu()))


@unittest.skipUnless(xr.device_type() in ["TPU", "CPU"],
"Auto-sharding currently supports TPU device.")
def test_simple_linear_training(self):
Expand All @@ -68,9 +66,9 @@ def test_simple_linear_training(self):
optimizer.step()
xm.mark_step()

self.assertEqual(met.counter_value("UncachedCompile"), 2)
self.assertEqual(met.counter_value("CachedCompile"), 3)
self.assertEqual(met.counter_value("CompileWithAutoSharding"), 2)
self.assertEqual(met.counter_value("UncachedCompile"), 3)
self.assertEqual(met.counter_value("CachedCompile"), 2)
self.assertEqual(met.counter_value("CompileWithAutoSharding"), 3)


if __name__ == '__main__':
Expand Down
3 changes: 2 additions & 1 deletion test/spmd/test_xla_sharding_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def setUpClass(cls):
@classmethod
def tearDownClass(cls):
del os.environ['XLA_USE_SPMD']
del os.environ['XLA_AUTO_SPMD']
if 'XLA_AUTO_SPMD' in os.environ:
del os.environ['XLA_AUTO_SPMD']

def _get_mesh(self, mesh_shape, device_ids=None, axis_names=None):
assert type(mesh_shape) is tuple, 'mesh_shape must be Tuple[int]'
Expand Down
3 changes: 0 additions & 3 deletions test/test_train_mp_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@

import os
import sys
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(parent_dir)

import schedulers
import numpy as np
import torch
Expand Down
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ python3 test/spmd/test_xla_virtual_device.py
python3 test/spmd/test_xla_distributed_checkpoint.py
python3 test/spmd/test_train_spmd_linear_model.py
python3 test/spmd/test_xla_spmd_python_api_interaction.py
python3 test/spmd/test_xla_auto_sharding.py
XLA_EXPERIMENTAL=nonzero:masked_select python3 test/ds/test_dynamic_shape_models.py -v
XLA_EXPERIMENTAL=nonzero:masked_select python3 test/ds/test_dynamic_shapes.py -v
python3 test/test_autocast.py
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ torch::lazy::BackendDevice GetVirtualDevice();
bool UseVirtualDevice(bool force_spmd = false);

// Return true if device is of "SPMD" device type.
bool IsVirtualDevice(const std::string& device);
bool IsVirtualDevice(const std::string& device);

// Return true if SPMD config can be switched. That is, no device has been
// initialized yet.
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,8 @@ xla::StatusOr<xla::XlaComputation> XlaHelpers::WrapXlaComputation(
for (int i = 0; i < parameter_shapes.size(); ++i) {
*input_tuple_shape.add_tuple_shapes() = parameter_shapes[i];
}
xla::XlaOp input_tuple = xla::Parameter(&builder, 0, input_tuple_shape, "in.");
xla::XlaOp input_tuple =
xla::Parameter(&builder, 0, input_tuple_shape, "in.");

// Handle the results of the original computation.
std::vector<xla::XlaOp> inner_params;
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::ReshardData(
&program_shape.result(),
/*should_wrap_parameter=*/false,
/*is_sharded=*/true,
/*allow_spmd_sharding_propagation_to_output=*/true});
/*allow_spmd_sharding_propagation_to_output=*/false});
std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>
computation = Compile(std::move(instances)).front();

Expand Down
10 changes: 6 additions & 4 deletions torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ torch::lazy::BackendDataPtr XLATensor::GetXlaData() {
return data()->handle;
}

void XLATensor::SetShardingSpec(const ShardingSpec& sharding, bool allow_overwrite) {
void XLATensor::SetShardingSpec(const ShardingSpec& sharding,
bool allow_overwrite) {
// Existing annotation must be cleared explicitly. We do not clear and
// overwrite the existing sharding on the user's behalf. This is a no-op if
// the same sharding is already applied.
Expand Down Expand Up @@ -288,9 +289,10 @@ XLATensor::ShardingSpecPtr XLATensor::sharding_spec() const {
// Re-sync the sharding annotation from the node to the tensor if there is
// one attached to the node. A new sharding annotation is attached
// directly to the node, and gets synced to the tensor after this.
// If sharding is attached via SetShardingSpec, then it flows from the tensor
// to the node. If sharding is attached by the compiler pass, then it first
// gets attached to the graph node, and then synced to the tensor here.
// If sharding is attached via SetShardingSpec, then it flows from the
// tensor to the node. If sharding is attached by the compiler pass, then
// it first gets attached to the graph node, and then synced to the tensor
// here.
if (!sharding ||
(sharding && !ShardingUtil::EqualOpShardings(*new_op_sharding,
sharding->sharding))) {
Expand Down
12 changes: 9 additions & 3 deletions torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1437,9 +1437,15 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
TF_VLOG(5) << "Parameter sequence hash after resharding: "
<< torch::lazy::Hash(po_data->parameter_sequence);
}
XLA_CHECK(po_data->parameters_data.size() == should_wrap_parameter
? program_shape.parameters()[0].tuple_shapes_size()
: program_shape.parameters_size());

if (should_wrap_parameter) {
XLA_CHECK_EQ(program_shape.parameters_size(), 1);
XLA_CHECK_EQ(program_shape.parameters()[0].tuple_shapes_size(),
po_data->parameters_data.size());
} else {
XLA_CHECK_EQ(program_shape.parameters_size(),
po_data->parameters_data.size());
}

return {/*device=*/coll.device,
/*emitted_nodes=*/lowering_ctx.GetEmittedNodeCount(),
Expand Down
20 changes: 11 additions & 9 deletions torch_xla/csrc/xla_sharding_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -779,12 +779,16 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input,
// to clear the existing sharding first.
XLATensor::ShardingSpecPtr current_sharding_spec = xtensor->sharding_spec();
if (current_sharding_spec) {
XLA_CHECK(current_sharding_spec->sharding.type() ==
xla::OpSharding::UNKNOWN ||
ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
*current_sharding_spec))
<< "Existing annotation must be cleared first.";
return;
if (ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
*current_sharding_spec)) {
return;
}
auto type = current_sharding_spec->sharding.type();
if (type != xla::OpSharding::REPLICATED &&
type != xla::OpSharding::UNKNOWN) {
XLA_CHECK(false) << "Existing annotation must be cleared first: "
<< current_sharding_spec->sharding.DebugString();
}
}

// If the at::Tensor data is not present, we need to re-download the
Expand Down Expand Up @@ -850,7 +854,5 @@ void ShardingUtil::SetAutoSharding() {
// This stays on throughout the program.
use_auto_sharding = true;
}
bool ShardingUtil::GetAutoSharding() {
return use_auto_sharding;
}
bool ShardingUtil::GetAutoSharding() { return use_auto_sharding; }
} // namespace torch_xla
2 changes: 1 addition & 1 deletion torch_xla/distributed/spmd/xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ class ShardingType(IntEnum):
TILED = 3
MANUAL = 4
PARTIAL = 5
UNKNOWN = 6 # implicit replication. TODO(yeounoh) wait for auto-sharding support
UNKNOWN = 6 # implicit replication. TODO(yeounoh) wait for auto-sharding support


def _get_sharding_type(partition_spec: Tuple[Union[int, None]],
Expand Down

0 comments on commit d3c1d70

Please sign in to comment.