diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index 95b7ff9..7af3b69 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -657,12 +657,16 @@ def train( if is_gpu_used and use_deepspeed: deepspeed_plugin = DeepSpeedPlugin() logger.info("Using DeepSpeed training.") + if not fp16: + logger.info("Setting FP16 to True for DeepSpeed ZeRO Training.") + fp16 = True + train_params = dict( accumulate_grad_batches=gradient_accumulation_steps, gpus=n_gpu, max_steps=num_steps, - gradient_clip_val=max_grad_norm if not fp16 else 0, + gradient_clip_val=max_grad_norm, checkpoint_callback=False, logger=loggers if loggers else False, weights_summary=None, diff --git a/setup.py b/setup.py index 0586412..767343f 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ install_requires=[ "transformers>=4.3.0", "fire>=0.3.0", - "pytorch-lightning>=1.2.0", + "pytorch-lightning>=1.2.3", "torch>=1.6.0", ], )