Add unit tests for a few foo.Tensor aten ops (pytorch#6749)
ghpvnist authored Mar 18, 2024
1 parent ff2bfe6 commit 6ac3223
Showing 2 changed files with 98 additions and 3 deletions.
89 changes: 89 additions & 0 deletions test/stablehlo/test_unbounded_dynamism.py
@@ -157,6 +157,24 @@ def test_conv(self):
        compare_exported_program_and_saved_model_result(ep, tempdir, args)

  def test_div(self):
    args = (torch.rand((10, 12, 197)), torch.rand((10, 12, 197)))
    dynamic_shapes = ([{0: Dim("dim")}, {0: Dim("dim")}],)
    m = wrap_func_as_nn_module(torch.ops.aten.div.Tensor)
    ep = export(m, args=args, dynamic_shapes=dynamic_shapes)
    shlo_module = exported_program_to_stablehlo(ep)
    shlo_text = shlo_module.get_stablehlo_text()
    self.assertTrue(
        re.search(
            r'tensor<\?x12x197xf32>.*tensor<\?x12x197xf32>.*->.*tensor<\?x12x197xf32>',
            shlo_text) is not None)
    if has_tf_package():
      with tempfile.TemporaryDirectory() as tempdir:
        save_torch_module_as_tf_saved_model(
            m, args, tempdir, dynamic_shapes=dynamic_shapes)
        self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
        compare_exported_program_and_saved_model_result(ep, tempdir, args)

  def test_div_scalar(self):
    args = (torch.rand((10, 12, 197)), 8.0)
    dynamic_shapes = ([{0: Dim("dim")}, None],)
    m = wrap_func_as_nn_module(torch.ops.aten.div.Tensor)
@@ -187,6 +205,42 @@ def test_gelu(self):
# "(%arg0: tensor<?x2xi64>, %arg1: tensor<?x2xi64>) -> tensor<?x2xi64>" in
# shlo_text)

def test_mul(self):
args = (torch.rand((10, 2, 768)), torch.rand((10, 2, 768)))
dynamic_shapes = ([{0: Dim("dim")}, {0: Dim("dim")}],)
m = wrap_func_as_nn_module(torch.ops.aten.mul.Tensor)
ep = export(m, args=args, dynamic_shapes=dynamic_shapes)
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(
r'tensor<\?x2x768xf32>.*tensor<\?x2x768xf32>.*->.*tensor<\?x2x768xf32>',
shlo_text) is not None)
if has_tf_package():
with tempfile.TemporaryDirectory() as tempdir:
save_torch_module_as_tf_saved_model(
m, args, tempdir, dynamic_shapes=dynamic_shapes)
self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
compare_exported_program_and_saved_model_result(ep, tempdir, args)

def test_mul_scalar(self):
args = (torch.rand((10, 2, 768)), 0.125)
dynamic_shapes = ([{0: Dim("dim")}, None],)
m = wrap_func_as_nn_module(torch.ops.aten.mul.Tensor)
ep = export(m, args=args, dynamic_shapes=dynamic_shapes)
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(
r'tensor<f32>.*tensor<\?x2x768xf32>.*->.*tensor<\?x2x768xf32>',
shlo_text) is not None)
if has_tf_package():
with tempfile.TemporaryDirectory() as tempdir:
save_torch_module_as_tf_saved_model(
m, args, tempdir, dynamic_shapes=dynamic_shapes)
self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
compare_exported_program_and_saved_model_result(ep, tempdir, args)

@unittest.skip("Unbounded Dynamism not supported yet.")
def test_native_layer_norm(self):
args = (
@@ -276,6 +330,41 @@ def test_softmax(self):
        self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
        compare_exported_program_and_saved_model_result(ep, tempdir, args)

  def test_sub(self):
    args = (torch.rand((10, 1, 1, 10)), torch.rand((10, 1, 1, 10)))
    dynamic_shapes = ([{0: Dim("dim")}, {0: Dim("dim")}],)
    m = wrap_func_as_nn_module(torch.ops.aten.sub.Tensor)
    ep = export(m, args=args, dynamic_shapes=dynamic_shapes)
    shlo_module = exported_program_to_stablehlo(ep)
    shlo_text = shlo_module.get_stablehlo_text()
    self.assertTrue(
        re.search(
            r'tensor<\?x1x1x10xf32>.*tensor<\?x1x1x10xf32>.*->.*tensor<\?x1x1x10xf32>',
            shlo_text) is not None)
    if has_tf_package():
      with tempfile.TemporaryDirectory() as tempdir:
        save_torch_module_as_tf_saved_model(
            m, args, tempdir, dynamic_shapes=dynamic_shapes)
        self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
        compare_exported_program_and_saved_model_result(ep, tempdir, args)

  def test_sub_scalar(self):
    args = (1.0, torch.rand((10, 1, 1, 10)))
    dynamic_shapes = ([None, {0: Dim("dim")}],)
    m = wrap_func_as_nn_module(torch.ops.aten.sub.Tensor)
    ep = export(m, args=args, dynamic_shapes=dynamic_shapes)
    shlo_module = exported_program_to_stablehlo(ep)
    shlo_text = shlo_module.get_stablehlo_text()
    self.assertTrue(
        re.search(r'tensor<\?x1x1x10xf32>.*->.*tensor<\?x1x1x10xf32>',
                  shlo_text) is not None)
    if has_tf_package():
      with tempfile.TemporaryDirectory() as tempdir:
        save_torch_module_as_tf_saved_model(
            m, args, tempdir, dynamic_shapes=dynamic_shapes)
        self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
        compare_exported_program_and_saved_model_result(ep, tempdir, args)


if __name__ == "__main__":
  test = unittest.main()
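
For reference, the tests use a wrap_func_as_nn_module helper that is defined elsewhere in this test file and is not shown in the diff. A minimal sketch of what such a helper might look like — an assumption for readability, not the commit's code:

import torch


def wrap_func_as_nn_module(f):
  """Wraps a plain callable (for example an aten op) in an nn.Module so it
  can be passed to torch.export.export, as the tests above do."""

  class FuncModule(torch.nn.Module):

    def forward(self, *args):
      return f(*args)

  return FuncModule().eval()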
12 changes: 9 additions & 3 deletions torch_xla/csrc/tensor_methods.cpp
@@ -2761,14 +2761,20 @@ XLATensorPtr sub(const XLATensorPtr& input, const XLATensorPtr& other,
   xla::Shape input_shape = input->shape().get();
   xla::Shape other_shape = other->shape().get();
   torch::lazy::Value alpha_xla;
+  const torch::lazy::BackendDevice& device = input->GetDevice();
   if (!input_shape.is_dynamic() && !other_shape.is_dynamic()) {
     alpha_xla = XLAGraphExecutor::Get()->GetIrValueForScalar(
-        alpha, other->shape(), logical_element_type, input->GetDevice());
+        alpha,
+        xla::ShapeUtil::MakeScalarShape(
+            MakeXlaPrimitiveType(other->dtype(), &device)),
+        logical_element_type, device);
   } else {
     SymIntElements sym_int_elements(other->GetIrValue());
     alpha_xla = XLAGraphExecutor::Get()->GetIrValueForScalar(
-        alpha, other->shape(), sym_int_elements, logical_element_type,
-        input->GetDevice());
+        alpha,
+        xla::ShapeUtil::MakeScalarShape(
+            MakeXlaPrimitiveType(other->dtype(), &device)),
+        sym_int_elements, logical_element_type, device);
   }

   return input->CreateFrom(
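
The tensor_methods.cpp change gives the alpha scalar a rank-0 XLA shape built from other's element type instead of reusing other->shape(), so lowering sub no longer ties the scalar to a possibly dynamic operand shape. A rough Python-level illustration of the kind of export this is meant to keep working — a sketch only, not part of the commit, assuming the same torch.export and torch_xla.stablehlo APIs used by the tests above:

import re

import torch
from torch.export import Dim, export
from torch_xla.stablehlo import exported_program_to_stablehlo


class SubWithAlpha(torch.nn.Module):
  """aten.sub.Tensor with a non-default alpha scalar."""

  def forward(self, x, y):
    return torch.ops.aten.sub.Tensor(x, y, alpha=0.5)


def check_sub_alpha_unbounded_dynamism():
  args = (torch.rand((10, 1, 1, 10)), torch.rand((10, 1, 1, 10)))
  dynamic_shapes = ({0: Dim("dim")}, {0: Dim("dim")})
  ep = export(SubWithAlpha(), args=args, dynamic_shapes=dynamic_shapes)
  shlo_text = exported_program_to_stablehlo(ep).get_stablehlo_text()
  # The unbounded leading dimension should survive into the StableHLO types.
  assert re.search(r'tensor<\?x1x1x10xf32>', shlo_text) is not None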