From aceb42cb1f22eb8321cd8d1ed8bc0ea103ff560a Mon Sep 17 00:00:00 2001 From: statelesshz Date: Sat, 9 Sep 2023 10:04:49 +0800 Subject: [PATCH] mix precision setup & make fixup --- src/transformers/integrations/__init__.py | 2 +- src/transformers/trainer.py | 13 ++++++------- src/transformers/training_args.py | 10 +++++----- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 4bb81d3dc6c880..ddd36955b3bf36 100644 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -137,4 +137,4 @@ else: import sys - sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) \ No newline at end of file + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 4a3b6f80b14a43..a257822dae446f 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -571,13 +571,12 @@ def __init__( f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, " "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer." ) - if args.fp16 or args.bf16: - if args.half_precision_backend == "auto": - if args.device == torch.device("cpu"): - if args.fp16: - raise ValueError("Tried to use `fp16` but it is not supported on cpu") - else: - args.half_precision_backend = "cpu_amp" + if (args.fp16 or args.bf16) and args.half_precision_backend == "auto": + if args.device == torch.device("cpu"): + if args.fp16: + raise ValueError("Tried to use `fp16` but it is not supported on cpu") + else: + args.half_precision_backend = "cpu_amp" logger.info(f"Using {args.half_precision_backend} half precision backend") if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()): diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index c02c02c9bf37d9..a7b60b5e7d8cde 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -325,9 +325,9 @@ class TrainingArguments: fp16_backend (`str`, *optional*, defaults to `"auto"`): This argument is deprecated. Use `half_precision_backend` instead. half_precision_backend (`str`, *optional*, defaults to `"auto"`): - The backend to use for mixed precision training. Must be one of `"auto", "cuda_amp", "apex", "cpu_amp"`. - `"auto"` will use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices - will force the requested backend. + The backend to use for mixed precision training. Must be one of `"auto", "apex", "cpu_amp"`. `"auto"` will + use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices will force the + requested backend. bf16_full_eval (`bool`, *optional*, defaults to `False`): Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm metric values. This is an experimental API and it may change. @@ -859,7 +859,7 @@ class TrainingArguments: default="auto", metadata={ "help": "The backend to be used for half precision.", - "choices": ["auto", "cuda_amp", "apex", "cpu_amp"], + "choices": ["auto", "apex", "cpu_amp"], }, ) bf16_full_eval: bool = field( @@ -1125,7 +1125,7 @@ class TrainingArguments: default="auto", metadata={ "help": "Deprecated. Use half_precision_backend instead", - "choices": ["auto", "cuda_amp", "apex", "cpu_amp"], + "choices": ["auto", "apex", "cpu_amp"], }, ) push_to_hub_model_id: Optional[str] = field(