Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ZeRO-3] Partitioned init with deepspeed.zero.Init() #1190

Merged
merged 3 commits into from
Mar 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import math
import sys
from contextlib import nullcontext

import torch
import deepspeed
Expand Down Expand Up @@ -426,13 +427,15 @@ def get_model(neox_args, use_cache=False):
# If mup isn't being used anyways, this has no effect.
old_use_mup = neox_args.use_mup
neox_args.use_mup = False
model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
parallel_output=True,
topology=mpu.get_topology(),
use_cache=use_cache,
)

with deepspeed.zero.Init() if neox_args.zero_stage == 3 else nullcontext() as gs:
model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
parallel_output=True,
topology=mpu.get_topology(),
use_cache=use_cache,
)

### soft prompt tuning stuff ###
if neox_args.soft_prompt_tuning is not None and neox_args.soft_prompt_tuning.get(
Expand Down
Loading