Skip to content

Commit

Permalink
Add new API for custom mark sharding op and update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wonjoolee95 committed Oct 26, 2023
1 parent 2b5f14a commit ad4cc4f
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 3 deletions.
16 changes: 15 additions & 1 deletion test/spmd/test_dynamo_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,20 @@ def test_dynamo_spmd_basic(self):
# TODO(JackCaoG): add counter checks after ExecuteReplicated also creates
# a ExecuteMetric.

def test_dynamo_spmd_basic_with_custom_mark_sharding_op(self):
device = xm.xla_device()
linear = SimpleLinear().to(device)
linear.eval()
xla_x = torch.randn(1, 128, device=device)
xs.mark_sharding_dynamo_custom_op(linear.fc2.weight, self._get_mesh((1, self.n_devices)),
(1, 0))
xla_res = linear(xla_x)
xm.mark_step()

dynamo_linear = torch.compile(linear, backend="openxla")
dynamo_res = dynamo_linear(xla_x)
torch.allclose(xla_res.cpu(), dynamo_res.cpu())

def test_dynamo_spmd_output_sharding_spec(self):
device = xm.xla_device()
linear = SimpleLinear().to(device)
Expand Down Expand Up @@ -177,7 +191,7 @@ def fn_simple(x):
y = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]],
dtype=torch.float,
device=xm.xla_device())
ys = xs.mark_sharding(y, self._get_mesh((1, self.n_devices)), (0, 1))
ys = xs.mark_sharding_dynamo_custom_op(y, self._get_mesh((1, self.n_devices)), (0, 1))

return x + ys

Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1626,7 +1626,7 @@ void InitXlaModuleBindings(py::module m) {
// Register sharded tensor data.
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
});
m.def("_xla_mark_sharding_custom_op",
m.def("_xla_mark_sharding_dynamo_custom_op",
[](const at::Tensor& input, xla::OpSharding sharding) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
tensor_methods::custom_mark_sharding(xtensor, sharding);
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/xla_lower_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,7 @@ xla::XlaOp BuildCustomSharding(const xla::XlaOp& input) {

xla::XlaOp BuildCustomMarkSharding(const torch::lazy::BackendDevice& device,
const xla::XlaOp& input,
const xla::XlaOp sharding) {
const xla::XlaOp& sharding) {
return xla::CustomCall(input.builder(), /*call_target_name=*/"MarkSharding",
{input}, ShapeHelper::ShapeOfXlaOp(input));
}
Expand Down
26 changes: 26 additions & 0 deletions torch_xla/experimental/xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,32 @@ def mark_sharding(
torch_xla._XLAC._xla_mark_sharding(t, op_sharding)
return XLAShardedTensor(t)

@xr.requires_pjrt
def mark_sharding_dynamo_custom_op(
t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
partition_spec: Tuple[Union[Tuple, int, str, None]]) -> XLAShardedTensor:
"""
Same functionality as `mark_sharding` above, except this variant uses the custom mark_sharding op in torch_xla._XLAC to allow dynamo to recognize and trace it.
"""
num_devices = xr.global_runtime_device_count()
assert num_devices > 0, "This requires XLA supported device(s)."
assert mesh.size() == num_devices, \
f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
# We only allow fully specified `partition_spec` to be applicable, as opposed
# to filling in the unspecified replicated dims. Fully specified `partiion_spec`
# should be of the same rank as `t`. This is to support partial replication
# where the group assignment may vary with different input ranks.
assert len(t.shape) == len(partition_spec), \
f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."

op_sharding = mesh.get_op_sharding(partition_spec)

if isinstance(t, XLAShardedTensor):
torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding)
return t
torch_xla._XLAC._xla_mark_sharding(t, op_sharding)
return XLAShardedTensor(t)


def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor:
"""Clear sharding annotation from the input tensor and return a `cpu` casted tensor."""
Expand Down

0 comments on commit ad4cc4f

Please sign in to comment.