Skip to content

Commit

Permalink
mix precision setup & make fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
statelesshz authored and LysandreJik committed Oct 6, 2023
1 parent 1af9b59 commit 140bd1d
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/transformers/integrations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,4 @@
else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
13 changes: 6 additions & 7 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,13 +575,12 @@ def __init__(
f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
"but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
)
if args.fp16 or args.bf16:
if args.half_precision_backend == "auto":
if args.device == torch.device("cpu"):
if args.fp16:
raise ValueError("Tried to use `fp16` but it is not supported on cpu")
else:
args.half_precision_backend = "cpu_amp"
if (args.fp16 or args.bf16) and args.half_precision_backend == "auto":
if args.device == torch.device("cpu"):
if args.fp16:
raise ValueError("Tried to use `fp16` but it is not supported on cpu")
else:
args.half_precision_backend = "cpu_amp"
logger.info(f"Using {args.half_precision_backend} half precision backend")

if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
Expand Down
10 changes: 5 additions & 5 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,9 @@ class TrainingArguments:
fp16_backend (`str`, *optional*, defaults to `"auto"`):
This argument is deprecated. Use `half_precision_backend` instead.
half_precision_backend (`str`, *optional*, defaults to `"auto"`):
The backend to use for mixed precision training. Must be one of `"auto", "cuda_amp", "apex", "cpu_amp"`.
`"auto"` will use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices
will force the requested backend.
The backend to use for mixed precision training. Must be one of `"auto", "apex", "cpu_amp"`. `"auto"` will
use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices will force the
requested backend.
bf16_full_eval (`bool`, *optional*, defaults to `False`):
Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm
metric values. This is an experimental API and it may change.
Expand Down Expand Up @@ -861,7 +861,7 @@ class TrainingArguments:
default="auto",
metadata={
"help": "The backend to be used for half precision.",
"choices": ["auto", "cuda_amp", "apex", "cpu_amp"],
"choices": ["auto", "apex", "cpu_amp"],
},
)
bf16_full_eval: bool = field(
Expand Down Expand Up @@ -1127,7 +1127,7 @@ class TrainingArguments:
default="auto",
metadata={
"help": "Deprecated. Use half_precision_backend instead",
"choices": ["auto", "cuda_amp", "apex", "cpu_amp"],
"choices": ["auto", "apex", "cpu_amp"],
},
)
push_to_hub_model_id: Optional[str] = field(
Expand Down

0 comments on commit 140bd1d

Please sign in to comment.