Handle dynamo counter change after fake tensor cache (pytorch#6308)
JackCaoG authored Jan 17, 2024
1 parent 7fae2ab commit d6dc1a0
Showing 1 changed file with 19 additions and 17 deletions.
test/dynamo/test_dynamo.py: 36 changes (19 additions & 17 deletions)
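Note (editor's sketch): the rewritten assertions below lean on two behaviors of the `torch_xla.debug.metrics` helpers: `met.clear_all()` resets counters and metrics together (which is why the now-redundant `met.clear_counters()` calls are dropped), and `met.metric_data(name)` reports `None` until the named metric records a sample again. A minimal illustration of that pattern, assuming a working torch_xla install with an attached XLA device; the tensor shape and printed count are illustrative only:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()

met.clear_all()  # resets both counters and metrics
t = torch.randn(5, device=device)
xm.mark_step()  # force the pending lazy graph to compile and execute

# Once populated, metric_data(name)[0] is the sample count the tests assert on.
assert met.metric_data('CompileTime') is not None
print(met.metric_data('CompileTime')[0])  # e.g. 1 in a fresh process

met.clear_all()
# Nothing has run since the clear, so the metric reads as None again.
assert met.metric_data('CompileTime') is None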
@@ -130,7 +130,6 @@ def forward(self, index, copy_tensor, input_tensor, op_name):
         return output
 
     torch._dynamo.reset()
-    met.clear_counters()
     met.clear_all()
     device = xm.xla_device()
@@ -236,7 +235,6 @@ def fn_fallback(t):
       return torch._foobar(t)
 
     torch._dynamo.reset()
-    met.clear_counters()
     met.clear_all()
     device = xm.xla_device()
@@ -247,24 +245,25 @@ def fn_fallback(t):
     cpu_res = fn_fallback(t)
     xla_dynamo_res = dynamo_fn(t_xla)
     self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res.cpu()))
+    # 2 compilations are caused by `t_xla` init and a no-op graph.
     self.assertEqual(met.metric_data('CompileTime')[0], 2)
-    # TODO(JackCaoG): invesgate this execution, from the HLO it is creating
-    # a f32[5] with all zeros. The cause of the execution is
-    # run_node (/src/pytorch/torch/_dynamo/utils.py:1381)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 4)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 2)
 
     # Second tracing
+    met.clear_all()
     xla_dynamo_res_2 = dynamo_fn(t_xla)
     self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res_2.cpu()))
-    self.assertEqual(met.metric_data('CompileTime')[0], 2)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 4)
+    self.assertEqual(met.metric_data('CompileTime'), None)
+    self.assertEqual(met.metric_data('ExecuteTime'), None)
 
     # Verify that dynamo can handle different inputs
+    met.clear_all()
     xla_dynamo_res_3 = dynamo_fn(t_xla * 3)
     cpu_res_3 = fn_fallback(t * 3)
     self.assertTrue(torch.allclose(cpu_res_3, xla_dynamo_res_3.cpu()))
-    self.assertEqual(met.metric_data('CompileTime')[0], 3)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 5)
+    # Compilation and execution are caused by `t * 3`.
+    self.assertEqual(met.metric_data('CompileTime')[0], 1)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 1)
 
   def test_fallback_multiple_submodules(self):
 
@@ -276,7 +275,6 @@ def fn_fallback(t):
       return t_4
 
     torch._dynamo.reset()
-    met.clear_counters()
     met.clear_all()
     device = xm.xla_device()
@@ -288,21 +286,25 @@ def fn_fallback(t):
     xla_dynamo_res = dynamo_fn(t_xla)
     self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res.cpu()))
     self.assertEqual(met.metric_data('CompileTime')[0], 3)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 11)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 9)
 
     # Second tracing
-    met.clear_counters()
+    met.clear_all()
     xla_dynamo_res_2 = dynamo_fn(t_xla)
     self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res_2.cpu()))
-    self.assertEqual(met.metric_data('CompileTime')[0], 3)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 13)
+    # We don't expect any new compilations. There will be 2 new executions
+    # since there is a fallback in the middle.
+    self.assertEqual(met.metric_data('CompileTime'), None)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 2)
 
     # Verify that dynamo can handle different inputs
+    met.clear_all()
     xla_dynamo_res_3 = dynamo_fn(t_xla * 3)
     cpu_res_3 = fn_fallback(t * 3)
     self.assertTrue(torch.allclose(cpu_res_3, xla_dynamo_res_3.cpu()))
-    self.assertEqual(met.metric_data('CompileTime')[0], 4)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 16)
+    # We expect one more compilation and execution because the input `t_xla * 3` is a new computation.
+    self.assertEqual(met.metric_data('CompileTime')[0], 1)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 3)
 
 
 class DynamoTrainingBasicTest(unittest.TestCase):
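To reproduce the single-fallback scenario outside the test file, a rough, self-contained sketch follows. The construction of the dynamo wrapper is collapsed in the diff above, so the `backend="openxla"` choice is an assumption of this sketch, not taken from the commit:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


def fn_fallback(t):
  # torch._foobar is used by the test as an op without an XLA lowering,
  # so the compiled graph falls back for it.
  return torch._foobar(t)


torch._dynamo.reset()
met.clear_all()
device = xm.xla_device()

# Assumption: "openxla" is torch_xla's dynamo backend name at the time of
# this commit; the test builds its own wrapper elsewhere.
dynamo_fn = torch.compile(fn_fallback, backend="openxla")

t = torch.randn(5)
t_xla = t.to(device)

cpu_res = fn_fallback(t)
xla_res = dynamo_fn(t_xla)
assert torch.allclose(cpu_res, xla_res.cpu())

# Re-running the same input after warm-up should trigger no new compilations,
# mirroring the updated assertions in the single-fallback test above.
met.clear_all()
dynamo_fn(t_xla)
assert met.metric_data('CompileTime') is None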
