diff --git a/src/transformers/training_args_tf.py b/src/transformers/training_args_tf.py index 4b35b66b07e2c8..060b78e9220518 100644 --- a/src/transformers/training_args_tf.py +++ b/src/transformers/training_args_tf.py @@ -195,8 +195,7 @@ def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]: # Set to float16 at first if self.fp16: - policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16") - tf.keras.mixed_precision.experimental.set_policy(policy) + tf.keras.mixed_precision.set_global_policy("mixed_float16") if self.no_cuda: strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0") @@ -217,8 +216,7 @@ def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]: if tpu: # Set to bfloat16 in case of TPU if self.fp16: - policy = tf.keras.mixed_precision.experimental.Policy("mixed_bfloat16") - tf.keras.mixed_precision.experimental.set_policy(policy) + tf.keras.mixed_precision.set_global_policy("mixed_bfloat16") tf.config.experimental_connect_to_cluster(tpu) tf.tpu.experimental.initialize_tpu_system(tpu) diff --git a/tests/utils/test_modeling_tf_core.py b/tests/utils/test_modeling_tf_core.py index 8edfc8eab02d4c..abdce686835077 100644 --- a/tests/utils/test_modeling_tf_core.py +++ b/tests/utils/test_modeling_tf_core.py @@ -205,7 +205,7 @@ def test_saved_model_creation_extended(self): @slow def test_mixed_precision(self): - tf.keras.mixed_precision.experimental.set_policy("mixed_float16") + tf.keras.mixed_precision.set_global_policy("mixed_float16") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -216,7 +216,7 @@ def test_mixed_precision(self): self.assertIsNotNone(outputs) - tf.keras.mixed_precision.experimental.set_policy("float32") + tf.keras.mixed_precision.set_global_policy("float32") @slow def test_train_pipeline_custom_model(self):