Skip to content

Commit

Permalink
Nemo ux HF import tests (#10274)
Browse files Browse the repository at this point in the history
* Add __all__ for mistral & mixtral

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Add model import test

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Add to cicd

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* skip mixtral test until use_safetensors arg is resolved

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* syntax

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* typo

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* remove unused imports

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>

* remove streaming ckpt

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Update ckpt paths

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

---------

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>
Signed-off-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com>
Co-authored-by: akoumpa <akoumpa@users.noreply.github.com>
  • Loading branch information
akoumpa and akoumpa authored Sep 27, 2024
1 parent 5e66cad commit 5b88aaa
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 0 deletions.
13 changes: 13 additions & 0 deletions .github/workflows/cicd-main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5223,6 +5223,18 @@ jobs:
AFTER_SCRIPT: |
rm -rf /home/TestData/nlp/megatron_mamba/nemo-ux-mamba/cicd_test_sft/${{ github.run_id }}
L2_NeMo_2_HF_MODEL_IMPORT:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
with:
RUNNER: self-hosted-azure
SCRIPT: |
python tests/collections/llm/gpt/model/test_model_import.py
AFTER_SCRIPT: |
rm -rf ~/.cache/nemo/models
L2_NeMo_2_T5_Pretraining:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
Expand Down Expand Up @@ -5375,6 +5387,7 @@ jobs:
#- OPTIONAL_L2_Stable_Diffusion_Training
- L2_NeMo_2_GPT_Pretraining_no_transformer_engine
- L2_NeMo_2_GPT_DDP_Param_Parity_check
- L2_NeMo_2_HF_MODEL_IMPORT
- L2_NeMo_2_SSM_Pretraining
- L2_NeMo_2_SSM_Finetuning
- L2_NeMo_2_T5_Pretraining
Expand Down
8 changes: 8 additions & 0 deletions nemo/collections/llm/gpt/model/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,11 @@ def _export_linear_fc1(linear_fc1):
gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0)

return gate_proj, up_proj


# Explicit public API of this module — controls `from ... import *` and lets
# downstream code discover the supported Mistral configs/model by name.
__all__ = [
    "MistralConfig7B",
    "MistralNeMo2407Config12B",
    "MistralNeMo2407Config123B",
    "MistralModel",
]
9 changes: 9 additions & 0 deletions nemo/collections/llm/gpt/model/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,3 +383,12 @@ def _export_moe_w1_w3(linear_fc1):
gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0)

return gate_proj, up_proj


# Explicit public API of this module — controls `from ... import *` and lets
# downstream code discover the supported Mixtral configs/model by name.
__all__ = [
    "MixtralConfig",
    "MixtralConfig8x3B",
    "MixtralConfig8x7B",
    "MixtralConfig8x22B",
    "MixtralModel",
]
85 changes: 85 additions & 0 deletions tests/collections/llm/gpt/model/test_model_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import torch

# Disable autograd globally: this script only instantiates models and imports
# checkpoints — no gradients are needed, and tracking them would waste memory.
torch.set_grad_enabled(False)


# Maps NeMo config class names to the HuggingFace Hub id of the corresponding
# upstream model. Commented-out entries are models intentionally excluded from
# the test run (e.g. the mixtral import is skipped pending an upstream fix —
# see commit history; others presumably need gated/remote-code access — TODO confirm).
config_name_to_hf_id = {
    'MistralConfig7B': 'mistralai/Mistral-7B-v0.1',
    # 'Nemotron3Config4B': 'nvidia/Minitron-4B-Base',
    'Llama2Config7B': 'meta-llama/Llama-2-7b-hf',
    'Llama3Config8B': 'meta-llama/Meta-Llama-3-8B',
    # 'MixtralConfig8x7B': 'mistralai/Mixtral-8x7B-v0.1',
    # 'ChatGLM2Config6B': 'THUDM/chatglm2-6b',
    'GemmaConfig2B': 'google/gemma-2b',
    # 'Baichuan2Config7B': 'baichuan-inc/Baichuan2-7B-Base',
}


def strip_digits_from_end(s):
    """Return *s* with any run of trailing ASCII digits removed.

    E.g. 'Llama2' -> 'Llama', 'abc' -> 'abc', '123' -> ''.
    """
    # str.rstrip strips every trailing character found in the given set,
    # which is exactly the hand-rolled list-pop loop this replaces.
    # Note: restricted to ASCII digits on purpose — str.isdigit() would also
    # match Unicode digits, which class names never contain.
    return s.rstrip("0123456789")


def get_modulename_from_config_name(config_name):
    """Derive the model class name that matches a config class name.

    Tries the versioned name first, then falls back to the digit-stripped
    family name, e.g. Llama2Config7B -> Llama2Model (absent) -> LlamaModel.
    """
    import nemo.collections.llm.gpt.model as nemo_ux_llms

    assert 'Config' in config_name, 'Expected config_name to contain "Config".'
    family = config_name.split('Config')[0]
    for candidate in (family + "Model", strip_digits_from_end(family) + "Model"):
        if hasattr(nemo_ux_llms, candidate):
            return candidate
    raise ValueError("Failed to get modulename")


def generate_twolayer_checkpoints(config_name, hf_id):
    """Build a tiny 2-layer HF checkpoint for *hf_id* under hf_ckpts/<config_name>/.

    The weights are filled with deterministic constants (tensor i gets value
    i/n) so an imported/converted checkpoint can be verified parameter by
    parameter. The matching tokenizer is saved alongside the weights.

    Raises:
        ValueError: if the HF config exposes neither ``num_hidden_layers``
            nor ``num_layers``.
    """
    from transformers import AutoConfig, AutoModel, AutoTokenizer

    config = AutoConfig.from_pretrained(hf_id, trust_remote_code=True)
    # Shrink to two layers so instantiation and serialization stay fast/small.
    if hasattr(config, 'num_hidden_layers'):
        print(config.num_hidden_layers)
        config.num_hidden_layers = 2
    elif hasattr(config, 'num_layers'):
        print(config.num_layers)
        config.num_layers = 2
    else:
        print(config)
        raise ValueError("HF config has neither num_hidden_layers nor num_layers")

    # Instantiate on the meta device: random weight init on CPU is slow and
    # pointless since every tensor is overwritten below.
    with torch.device('meta'):
        model_2l = AutoModel.from_config(config, trust_remote_code=True)

    model_2l = model_2l.to_empty(device='cpu')
    state = model_2l.state_dict()
    # Fill tensor i with the constant i/n — distinct and deterministic.
    n = len(state)  # fix: was len(state.items()), which built a throwaway view
    for i, key in enumerate(state):
        state[key] = torch.full_like(state[key], i / n)
    model_2l.load_state_dict(state)
    model_2l.save_pretrained(f'hf_ckpts/{config_name}/', safe_serialization=False)
    hf_tokenizer = AutoTokenizer.from_pretrained(hf_id, trust_remote_code=True)
    # NOTE(review): save_pretrained has no trust_remote_code parameter — the
    # kwarg appears to be silently swallowed; confirm and drop it.
    hf_tokenizer.save_pretrained(f'hf_ckpts/{config_name}/', trust_remote_code=True)


def import_from_hf(config_name, hf_path):
    """Instantiate the NeMo model matching *config_name* and import *hf_path* into it."""
    import nemo.collections.llm.gpt.model as nemo_ux_llms
    from nemo.collections.llm import import_ckpt

    # Resolve both the config class and its companion model class by name.
    model_class_name = get_modulename_from_config_name(config_name)
    config = getattr(nemo_ux_llms, config_name)()
    model = getattr(nemo_ux_llms, model_class_name)(config)
    import_ckpt(model=model, source=hf_path)


if __name__ == '__main__':
    # Import each pre-staged HF checkpoint from local CI test data. Only the
    # config names are needed here — the hub ids in the mapping's values are
    # used when (re)generating the checkpoints, not when importing them — so
    # iterate keys directly instead of unpacking .items() with an unused value.
    for config_name in config_name_to_hf_id:
        src = f'hf:///home/TestData/nemo2_ckpt/{config_name}'
        import_from_hf(config_name, src)

0 comments on commit 5b88aaa

Please sign in to comment.