Commit
* Add __all__ for mistral & mixtral
* Add model import test
* Add to CI/CD
* Skip mixtral test until the use_safetensors arg is resolved
* Fix syntax
* Fix typo
* Remove unused imports
* Apply isort and black reformatting
* Remove streaming ckpt
* Update ckpt paths

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>
Signed-off-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com>
Co-authored-by: akoumpa <akoumpa@users.noreply.github.com>
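The first change adds an explicit export list to the Mistral and Mixtral model modules. A minimal sketch of what such a declaration looks like, assuming the module defines the config and model classes named below (the committed list may differ):

    # Hypothetical excerpt of a mistral model module; the exported names here
    # are assumptions, not the committed list.
    __all__ = ["MistralConfig7B", "MistralModel"]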
Showing 4 changed files with 115 additions and 0 deletions.
@@ -0,0 +1,85 @@
import torch

# Conversion/inference only; run everything without autograd.
torch.set_grad_enabled(False)


# Maps NeMo config class names to the Hugging Face checkpoints they mirror;
# commented-out entries are currently excluded from the test.
config_name_to_hf_id = {
    'MistralConfig7B': 'mistralai/Mistral-7B-v0.1',
    # 'Nemotron3Config4B': 'nvidia/Minitron-4B-Base',
    'Llama2Config7B': 'meta-llama/Llama-2-7b-hf',
    'Llama3Config8B': 'meta-llama/Meta-Llama-3-8B',
    # 'MixtralConfig8x7B': 'mistralai/Mixtral-8x7B-v0.1',
    # 'ChatGLM2Config6B': 'THUDM/chatglm2-6b',
    'GemmaConfig2B': 'google/gemma-2b',
    # 'Baichuan2Config7B': 'baichuan-inc/Baichuan2-7B-Base',
}


def strip_digits_from_end(s):
    # Remove trailing decimal digits, e.g. 'Llama2' -> 'Llama'.
    s = list(s)
    while s and s[-1].isdigit():
        s = s[:-1]
    return ''.join(s)
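
# Illustrative behavior (not part of the diff):
#   strip_digits_from_end('Llama2') -> 'Llama'
#   strip_digits_from_end('Gemma')  -> 'Gemma' (no trailing digits, unchanged)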


def get_modulename_from_config_name(config_name):
    # Finds the name of the model class from the config class name, e.g.
    # Llama2Config7B -> Llama2Model (fails) -> LlamaModel
    import nemo.collections.llm.gpt.model as nemo_ux_llms

    assert 'Config' in config_name, 'Expected config_name to contain "Config".'
    module_name = config_name.split('Config')[0] + "Model"
    if not hasattr(nemo_ux_llms, module_name):
        module_name = strip_digits_from_end(config_name.split('Config')[0]) + "Model"
    if not hasattr(nemo_ux_llms, module_name):
        raise ValueError("Failed to get modulename")
    return module_name
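
# Resolution examples (illustrative, not part of the diff):
#   'MistralConfig7B' -> 'MistralModel' on the first try.
#   'Llama2Config7B'  -> 'Llama2Model' is missing, so the trailing digits are
#   stripped from 'Llama2' and 'LlamaModel' is used instead.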


def generate_twolayer_checkpoints(config_name, hf_id):
    from transformers import AutoConfig, AutoModel, AutoTokenizer

    config = AutoConfig.from_pretrained(hf_id, trust_remote_code=True)
    # Shrink the model to two layers so the test checkpoint stays small.
    if hasattr(config, 'num_hidden_layers'):
        print(config.num_hidden_layers)
        config.num_hidden_layers = 2
    elif hasattr(config, 'num_layers'):
        print(config.num_layers)
        config.num_layers = 2
    else:
        print(config)
        raise ValueError("HF config has neither num_hidden_layers nor num_layers")

    # Random init on a real device is slow; build the module on the 'meta'
    # device first, then materialize empty tensors on CPU.
    with torch.device('meta'):
        model_2l = AutoModel.from_config(config, trust_remote_code=True)

    model_2l = model_2l.to_empty(device='cpu')
    state = model_2l.state_dict()
    # Fill each tensor with the constant i/n so the weights are deterministic.
    n = len(state)
    for i, key in enumerate(state.keys()):
        state[key] = torch.empty_like(state[key]).fill_(i / n)
    model_2l.load_state_dict(state)
    model_2l.save_pretrained(f'hf_ckpts/{config_name}/', safe_serialization=False)
    hf_tokenizer = AutoTokenizer.from_pretrained(hf_id, trust_remote_code=True)
    hf_tokenizer.save_pretrained(f'hf_ckpts/{config_name}/')
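
# Illustrative invocation (hypothetical; __main__ below does not call this,
# since the two-layer test checkpoints are pre-generated):
#   generate_twolayer_checkpoints('MistralConfig7B', 'mistralai/Mistral-7B-v0.1')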


def import_from_hf(config_name, hf_path):
    import nemo.collections.llm.gpt.model as nemo_ux_llms
    from nemo.collections.llm import import_ckpt

    # Instantiate the NeMo model class for this config, then convert the
    # Hugging Face checkpoint at hf_path into NeMo format.
    module_name = get_modulename_from_config_name(config_name)
    config_cls = getattr(nemo_ux_llms, config_name)
    model_cls = getattr(nemo_ux_llms, module_name)
    model = model_cls(config_cls())
    import_ckpt(model=model, source=hf_path)
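
# Illustrative single-model import (hypothetical source; the 'hf://' prefix,
# as used in __main__ below, marks a Hugging Face-style checkpoint source):
#   import_from_hf('GemmaConfig2B', 'hf://google/gemma-2b')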


if __name__ == '__main__':
    for config_name, hf_id in config_name_to_hf_id.items():
        # hf_id is unused here: the sweep imports from pre-staged local copies
        # rather than downloading from the Hub.
        src = f'hf:///home/TestData/nemo2_ckpt/{config_name}'
        import_from_hf(config_name, src)