From 61af41c92ce3a9d6dbdf08b6b4744efc1159a01b Mon Sep 17 00:00:00 2001
From: Ali Taghibakhshi <71892896+JRD971000@users.noreply.github.com>
Date: Wed, 21 Aug 2024 18:30:14 -0500
Subject: [PATCH] fix mamba convert/ add test (#10224)

* fix mamba convert/ add test

* Apply isort and black reformatting

Signed-off-by: JRD971000

* add mamba test

* fix ngroup in cicd

---------

Signed-off-by: JRD971000
Co-authored-by: JRD971000
---
 .github/workflows/cicd-main.yml                 | 16 +++++++++++
 .../convert_mamba2_pyt_to_nemo.py               | 28 +++++++++++--------
 2 files changed, 32 insertions(+), 12 deletions(-)

diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml
index 797b7888b01e1..3fc2b1a127e7a 100644
--- a/.github/workflows/cicd-main.yml
+++ b/.github/workflows/cicd-main.yml
@@ -159,6 +159,21 @@ jobs:
           rm -f /home/TestData/nlp/megatron_ir/sbert/sbert.nemo
           rm -rf /home/TestData/nlp/megatron_ir/sbert/model_weights
 
+  L2_Community_LLM_Checkpoints_tests_Mamba2:
+    needs: [cicd-test-container-setup]
+    uses: ./.github/workflows/_test_template.yml
+    with:
+      RUNNER: self-hosted-azure
+      SCRIPT: |
+        python scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py \
+          --input_name_or_path /home/TestData/nlp/megatron_mamba/model_optim_rng.pt \
+          --output_path /home/TestData/nlp/megatron_mamba/converted_mamba.nemo \
+          --precision=bf16 \
+          --mamba_ssm_ngroups 1
+      AFTER_SCRIPT: |
+        rm -f /home/TestData/nlp/megatron_mamba/converted_mamba.nemo
+        rm -rf /home/TestData/nlp/megatron_mamba/model_weights
+
   L2_Community_LLM_Checkpoints_tests_Llama:
     needs: [cicd-test-container-setup]
     uses: ./.github/workflows/_test_template.yml
@@ -4745,6 +4760,7 @@ jobs:
       - L0_Unit_Tests_GPU
       #- OPTIONAL_L0_Unit_Tests_CPU
       - L2_Community_LLM_Checkpoints_tests_Bert
+      - L2_Community_LLM_Checkpoints_tests_Mamba2
      - L2_Community_LLM_Checkpoints_tests_Llama
      - L2_Community_LLM_Checkpoints_tests_StarCoder
      - L2_Community_LLM_Checkpoints_tests_Falcon
diff --git a/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py b/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py
index 1a0a13709421d..7a7484bf9c206 100644
--- a/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py
+++ b/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py
@@ -26,7 +26,7 @@
 
 '''
 Example
-CUDA_VISIBLE_DEVICES="0" python /NeMo/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py \
+CUDA_VISIBLE_DEVICES="0" python /opt/NeMo/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py \
     --input_name_or_path \
     --output_path \
     --mamba_ssm_ngroups 8 \
@@ -63,10 +63,24 @@ def get_args():
 
 
 def convert(args):
-    checkpoint_weights = torch.load(args.input_name_or_path, map_location='cpu')
+    checkpoint_weights = torch.load(args.input_name_or_path, map_location='cpu')['model']
     new_state_dict = {}
 
     if 'backbone' in list(checkpoint_weights.keys())[0]:
+        if 'model' in list(checkpoint_weights.keys())[0]:
+            checkpoint_weights = {key.replace('model.', '', 1): value for key, value in checkpoint_weights.items()}
+
+            # Codestral Mamba Model Tokenizer Settings
+            tokenizer_library = 'megatron'
+            tokenizer_type = 'GPTSentencePieceTokenizer'
+            tokenizer_model = args.tokenizer_model_dir
+
+        else:
+
+            # Tri Dao and Albert Gu Mamba Model Tokenizer Settings
+            tokenizer_library = 'huggingface'
+            tokenizer_type = 'EleutherAI/gpt-neox-20b'
+            tokenizer_model = None
 
         layer_keys = [key for key in checkpoint_weights.keys() if re.match(r'backbone\.layers\.\d+\.', key)]
         layer_numbers = set(int(re.search(r'backbone\.layers\.(\d+)\.', key).group(1)) for key in layer_keys)
@@ -103,11 +117,6 @@ def convert(args):
                 old_key = f'backbone.layers.{i}.{attr}'
                 new_state_dict[new_key] = checkpoint_weights[old_key]
 
-        # Tokenizer settings
-        tokenizer_library = 'huggingface'
-        tokenizer_type = 'EleutherAI/gpt-neox-20b'
-        tokenizer_model = None
-
     else:
 
         layer_keys = [key for key in checkpoint_weights.keys() if re.match(r'decoder\.layers\.\d+\.', key)]
@@ -124,11 +133,6 @@ def convert(args):
         tokenizer_type = 'GPTSentencePieceTokenizer'
         tokenizer_model = args.tokenizer_model_dir
 
-    # Tokenizer settings
-    tokenizer_library = 'megatron'
-    tokenizer_type = 'GPTSentencePieceTokenizer'
-    tokenizer_model = args.tokenizer_model_dir
-
     layers = defaultdict(list)
 
     for key in new_state_dict.keys():
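
Reviewer note: the heart of the converter fix is the dispatch on the first key of
the loaded state dict, which distinguishes three checkpoint flavors and strips the
'model.' prefix from Codestral Mamba keys. A minimal runnable sketch of that logic
follows; the helper name detect_and_normalize and the toy key are hypothetical, and
the real script additionally remaps every layer weight into NeMo's naming scheme.

    # Sketch (assumptions noted above): mirrors the flavor detection the patch
    # adds to convert(), without the layer-by-layer weight remapping.
    def detect_and_normalize(checkpoint_weights):
        first_key = list(checkpoint_weights.keys())[0]
        if 'backbone' in first_key:
            if 'model' in first_key:
                # Codestral Mamba: keys look like 'model.backbone.*', so strip
                # the leading 'model.' once per key and use the Megatron
                # SentencePiece tokenizer.
                checkpoint_weights = {
                    key.replace('model.', '', 1): value
                    for key, value in checkpoint_weights.items()
                }
                tokenizer = ('megatron', 'GPTSentencePieceTokenizer')
            else:
                # Original Tri Dao / Albert Gu release: keys start at
                # 'backbone.*' and the HF gpt-neox-20b tokenizer applies.
                tokenizer = ('huggingface', 'EleutherAI/gpt-neox-20b')
        else:
            # Megatron-core style checkpoint with 'decoder.layers.*' keys.
            tokenizer = ('megatron', 'GPTSentencePieceTokenizer')
        return checkpoint_weights, tokenizer

    # Hypothetical toy key imitating a Codestral Mamba state dict:
    weights = {'model.backbone.layers.0.mixer.A_log': None}
    normalized, (library, tok_type) = detect_and_normalize(weights)
    assert 'backbone.layers.0.mixer.A_log' in normalized and library == 'megatron'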