Skip to content

Commit

Permalink
ensure that the hftrainer deepspeed config is set before the trainer …
Browse files Browse the repository at this point in the history
…class is ever init'ed (#1850) [skip ci]
  • Loading branch information
winglian authored Aug 22, 2024
1 parent de4ea2d commit 2f8037f
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,12 +399,15 @@ def setup_torch_compile_env(cfg):


def setup_deepspeed_env(cfg, stage=None):
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig

os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
if stage:
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
if stage == 3:
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
HfTrainerDeepSpeedConfig(cfg.deepspeed)


def setup_fsdp_envs(cfg):
Expand Down

0 comments on commit 2f8037f

Please sign in to comment.