From d23e00c23e9b848a8b613bc60ef2a7c03c4988a8 Mon Sep 17 00:00:00 2001
From: Bhavya Bahl
Date: Mon, 8 Jul 2024 10:07:37 -0700
Subject: [PATCH] Update test_pallas_spmd.py for CI

Update test_pallas_spmd.py to set the jax_default_matmul_precision
config with the strings "highest" and "default" instead of the
jax.lax.Precision.HIGHEST and jax.lax.Precision.DEFAULT enums.

Same as https://github.com/pytorch/xla/pull/7629
---
 test/test_pallas_spmd.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/test/test_pallas_spmd.py b/test/test_pallas_spmd.py
index 334345941916..ee7cfe018403 100644
--- a/test/test_pallas_spmd.py
+++ b/test/test_pallas_spmd.py
@@ -31,7 +31,7 @@ def _attention(self, q, k, v):
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
   def test_flash_attention_spmd_data_parallel(self):
-    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
+    jax.config.update('jax_default_matmul_precision', "highest")
     n_devices = xr.global_runtime_device_count()
     xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1)))
@@ -46,12 +46,12 @@ def test_flash_attention_spmd_data_parallel(self):
     expected_o = self._attention(q, k, v)
     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
 
-    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
+    jax.config.update('jax_default_matmul_precision', "default")
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                    "This test only works on TPUv3+.")
   def test_flash_attention_backward_spmd_data_parallel(self):
-    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
+    jax.config.update('jax_default_matmul_precision', "highest")
     n_devices = xr.global_runtime_device_count()
     xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1)))
@@ -96,7 +96,7 @@ def test_flash_attention_backward_spmd_data_parallel(self):
     for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
       self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
 
-    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)
+    jax.config.update('jax_default_matmul_precision', "default")
 
 if __name__ == '__main__':
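
Note: for reference, a minimal sketch of the string-based config update this
patch switches to. It assumes only that jax is installed; the try/finally
restore is an optional hardening, not something the patch itself adds (the
tests restore "default" inline instead):

    import jax
    import jax.numpy as jnp

    # 'jax_default_matmul_precision' accepts string values such as
    # "highest" and "default"; the patch uses these strings in place of
    # the jax.lax.Precision enum members.
    jax.config.update('jax_default_matmul_precision', 'highest')
    try:
        a = jnp.ones((128, 128), dtype=jnp.bfloat16)
        b = jnp.ones((128, 128), dtype=jnp.bfloat16)
        out = a @ b  # matmul runs at "highest" precision
    finally:
        # Restore the default so later tests are unaffected even if an
        # assertion above raises.
        jax.config.update('jax_default_matmul_precision', 'default')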