Skip to content

Commit

Permalink
Increase tolerance for tan (#5915)
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi authored Dec 1, 2023
1 parent e98cb66 commit 2c4983d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
11 changes: 10 additions & 1 deletion FIX_LOWERING_FOR_CORE_ATEN_OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,16 @@ This subtest calls the torch_xla version of the op. If you've made changes to lo

### `torch_xla_diff`

This subtest compares the output of the op between torch and torch_xla. If this subtest fails, it implies that your lowering runs successfully but contains a bug and/or logical error. We recommend you to review your lowering code. And again, feel free to leave a comment in your assigned GitHub issue if you're blocked and/or unable to debug further.
This subtest compares the output of the op between torch and torch_xla.
If this subtest fails, it implies that your lowering runs successfully
but produced a different result than torch eager mode.

If the test uses 16-bit floats (float16, bfloat16); This is very likely
that the tolerances that we give to `torch.allclose` to compare was to
strict. You can relax it a bit. Take a look at [this issue](https://github.com/pytorch/xla/issues/5934) of one such example.

If the result torchxla produces is totally different than what torch produces, that means it's a bug in lowering code; and probably need
more work. Feel free to tag more people (such as qihqi to look).

### `can_export`, `can_convert_to_stablehlo`, `stablehlo_can_run`, `stablehlo_diff`

Expand Down
23 changes: 15 additions & 8 deletions test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,24 @@
import unittest


def diff_output(testcase, output1, output2, atol):
def diff_output(testcase, output1, output2, rtol, atol):
if isinstance(output1, torch.Tensor):
testcase.assertIsInstance(output2, torch.Tensor)
output2_cpu = output2.detach().cpu()
if output2_cpu.dtype != output1.dtype:
output2_cpu = output2_cpu.to(output1.dtype)
testcase.assertTrue(torch.allclose(output1, output2_cpu, atol=atol))
testcase.assertTrue(
torch.allclose(output1, output2_cpu, atol=atol, rtol=rtol))
elif isinstance(output1, (tuple, list)):
testcase.assertIsInstance(output2, (tuple, list))
testcase.assertEqual(len(output1), len(output2))
for o1, o2 in zip(output1, output2):
diff_output(testcase, o1, o2, atol)
diff_output(testcase, o1, o2, rtol, atol)
else:
testcase.assertEqual(output1, output2)


def run_export_and_compare(testcase, func, args, kwargs, atol=1e-3):
def run_export_and_compare(testcase, func, args, kwargs, atol=1e-3, rtol=1e-5):
device = xm.xla_device()
with testcase.subTest('torch_eval'):
res = func(*args, **kwargs)
Expand All @@ -36,15 +37,15 @@ def run_export_and_compare(testcase, func, args, kwargs, atol=1e-3):
lambda x: x.to(device=device), kwargs)
res_xla = func(*args2, **kwargs2)
with testcase.subTest('torch_xla_diff:' + str(atol)):
diff_output(testcase, res, res_xla, atol)
diff_output(testcase, res, res_xla, atol=atol, rtol=rtol)
with testcase.subTest('can_export'):
exported = torch.export.export(func, args, kwargs)
with testcase.subTest('can_convert_to_stablehlo'):
shlo = exported_program_to_stablehlo(exported)
with testcase.subTest('stablehlo_can_run'):
res2 = shlo(*args, **kwargs)
with testcase.subTest('stablehlo_diff: ' + str(atol)):
diff_output(testcase, res, res2, atol)
diff_output(testcase, res, res2, rtol=rtol, atol=atol)


class AtenOpTest(unittest.TestCase):
Expand Down Expand Up @@ -4372,11 +4373,17 @@ def test_aten_tan_0(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.tan, args, kwargs)

@unittest.skip
def test_aten_tan_1(self):
args = (torch.randn((10, 10)).to(torch.float16),)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.tan, args, kwargs)
run_export_and_compare(
self,
torch.ops.aten.tan,
args,
kwargs,
rtol=0.001,
atol=0.01,
)

def test_aten_tan_2(self):
args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
Expand Down

0 comments on commit 2c4983d

Please sign in to comment.