fix mamba convert/ add test #10224

Merged: 7 commits, Aug 21, 2024
16 changes: 16 additions & 0 deletions .github/workflows/cicd-main.yml
@@ -159,6 +159,21 @@ jobs:
         rm -f /home/TestData/nlp/megatron_ir/sbert/sbert.nemo
         rm -rf /home/TestData/nlp/megatron_ir/sbert/model_weights

+  L2_Community_LLM_Checkpoints_tests_Mamba2:
+    needs: [cicd-test-container-setup]
+    uses: ./.github/workflows/_test_template.yml
+    with:
+      RUNNER: self-hosted-azure
+      SCRIPT: |
+        python scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py \
+          --input_name_or_path /home/TestData/nlp/megatron_mamba/model_optim_rng.pt \
+          --output_path /home/TestData/nlp/megatron_mamba/converted_mamba.nemo \
+          --precision=bf16 \
+          --mamba_ssm_ngroups 1
+      AFTER_SCRIPT: |
+        rm -f /home/TestData/nlp/megatron_mamba/converted_mamba.nemo
+        rm -rf /home/TestData/nlp/megatron_mamba/model_weights
+
   L2_Community_LLM_Checkpoints_tests_Llama:
     needs: [cicd-test-container-setup]
     uses: ./.github/workflows/_test_template.yml
@@ -4745,6 +4760,7 @@ jobs:
       - L0_Unit_Tests_GPU
       #- OPTIONAL_L0_Unit_Tests_CPU
       - L2_Community_LLM_Checkpoints_tests_Bert
+      - L2_Community_LLM_Checkpoints_tests_Mamba2
      - L2_Community_LLM_Checkpoints_tests_Llama
       - L2_Community_LLM_Checkpoints_tests_StarCoder
       - L2_Community_LLM_Checkpoints_tests_Falcon
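The job above gives the converter its first CI coverage: it converts a stored Mamba2 checkpoint to .nemo and removes the artifacts afterwards. As context for the converter changes in the next file, here is a minimal sketch (hypothetical local path, not part of this PR) of the checkpoint layouts the script has to distinguish:

import torch

# Hypothetical path; substitute any Mamba2 checkpoint you have locally.
ckpt = torch.load('model_optim_rng.pt', map_location='cpu')

# Megatron-saved checkpoints nest the weights under a 'model' key, which is
# why the converter now indexes torch.load(...)['model']; checkpoints from
# the original mamba_ssm repo are already a flat state dict.
state_dict = ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt

# The converter branches on the first key: 'model.backbone.*' (Codestral),
# 'backbone.*' (Tri Dao / Albert Gu), or 'decoder.*' (Megatron-trained).
print(list(state_dict.keys())[:5])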
28 changes: 16 additions & 12 deletions scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py
@@ -26,7 +26,7 @@
 '''
 Example

-CUDA_VISIBLE_DEVICES="0" python /NeMo/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py \
+CUDA_VISIBLE_DEVICES="0" python /opt/NeMo/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py \
     --input_name_or_path <path to the source pytorch model> \
     --output_path <path to target .nemo model> \
     --mamba_ssm_ngroups 8 \
@@ -63,10 +63,24 @@ def get_args():

 def convert(args):

-    checkpoint_weights = torch.load(args.input_name_or_path, map_location='cpu')
+    checkpoint_weights = torch.load(args.input_name_or_path, map_location='cpu')['model']
     new_state_dict = {}

     if 'backbone' in list(checkpoint_weights.keys())[0]:
+        if 'model' in list(checkpoint_weights.keys())[0]:
+            checkpoint_weights = {key.replace('model.', '', 1): value for key, value in checkpoint_weights.items()}
+
+            # Codestral Mamba Model Tokenizer Settings
+            tokenizer_library = 'megatron'
+            tokenizer_type = 'GPTSentencePieceTokenizer'
+            tokenizer_model = args.tokenizer_model_dir
+
+        else:
+
+            # Tri Dao and Albert Gu Mamba Model Tokenizer Settings
+            tokenizer_library = 'huggingface'
+            tokenizer_type = 'EleutherAI/gpt-neox-20b'
+            tokenizer_model = None

         layer_keys = [key for key in checkpoint_weights.keys() if re.match(r'backbone\.layers\.\d+\.', key)]
         layer_numbers = set(int(re.search(r'backbone\.layers\.(\d+)\.', key).group(1)) for key in layer_keys)
@@ -103,11 +117,6 @@ def convert(args):
                 old_key = f'backbone.layers.{i}.{attr}'
                 new_state_dict[new_key] = checkpoint_weights[old_key]

-        # Tokenizer settings
-        tokenizer_library = 'huggingface'
-        tokenizer_type = 'EleutherAI/gpt-neox-20b'
-        tokenizer_model = None
-
     else:

         layer_keys = [key for key in checkpoint_weights.keys() if re.match(r'decoder\.layers\.\d+\.', key)]
@@ -124,11 +133,6 @@ def convert(args):
         tokenizer_type = 'GPTSentencePieceTokenizer'
         tokenizer_model = args.tokenizer_model_dir

-    # Tokenizer settings
-    tokenizer_library = 'megatron'
-    tokenizer_type = 'GPTSentencePieceTokenizer'
-    tokenizer_model = args.tokenizer_model_dir
-
     layers = defaultdict(list)

     for key in new_state_dict.keys():
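Read together, the hunks above replace the duplicated trailing tokenizer blocks with a single selection keyed off the first state-dict key. A consolidated sketch of the resulting logic (the helper function itself is ours, not in the script; the branch values mirror the diff):

def select_tokenizer(checkpoint_weights, tokenizer_model_dir):
    # Mirrors the converter's post-PR branching on the first key.
    first_key = list(checkpoint_weights.keys())[0]
    if 'backbone' in first_key:
        if 'model' in first_key:
            # Codestral Mamba: keys arrive as 'model.backbone.layers...'
            return 'megatron', 'GPTSentencePieceTokenizer', tokenizer_model_dir
        # Original Tri Dao / Albert Gu checkpoints: flat 'backbone.layers...'
        return 'huggingface', 'EleutherAI/gpt-neox-20b', None
    # NVIDIA Megatron-trained Mamba2: 'decoder.layers...' keys
    return 'megatron', 'GPTSentencePieceTokenizer', tokenizer_model_dir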