diff --git a/test/spmd/test_xla_spmd_python_api_interaction.py b/test/spmd/test_xla_spmd_python_api_interaction.py index edb51ac9d91..8f9c319ac54 100644 --- a/test/spmd/test_xla_spmd_python_api_interaction.py +++ b/test/spmd/test_xla_spmd_python_api_interaction.py @@ -135,6 +135,7 @@ def test_xla_autocast_api(self): class BasicDistributedTest(test_xla_sharding_base.XlaShardingTest): + @classmethod def setUpClass(cls): xr.use_spmd()