Commit 68c2007

code review + bug fix

Signed-off-by: Piotr Kaminski <piotrus.kaminski@gmail.com>
Laplasjan107 committed Dec 23, 2024
1 parent 3e7cdb5 commit 68c2007

Showing 3 changed files with 47 additions and 55 deletions.
nemo/export/sentencepiece_tokenizer.py (26 changes: 11 additions & 15 deletions)

@@ -39,22 +39,18 @@ def __init__(
         legacy: bool = False,
         tokenizer: Optional[sentencepiece.SentencePieceProcessor] = None,
     ):
-        if tokenizer is not None:
+        model_path_provided = model_path is not None
+        tokenizer_provided = tokenizer is not None
+        if not (model_path_provided ^ tokenizer_provided):
+            raise ValueError("Exactly one of the arguments 'model_path' and 'tokenizer' should be provided")
+
+        if tokenizer_provided:
             self.tokenizer = tokenizer
-            self.legacy = False
-            self.special_token_to_id = {}
-            self.id_to_special_token = {}
-            self.space_sensitive = self.text_to_tokens('x y') != self.text_to_tokens('x') + self.text_to_tokens('y')
-            self.original_vocab_size = self.tokenizer.get_piece_size()
-            self.vocab_size = self.tokenizer.get_piece_size()
-            return
-
-        if model_path is None:
-            raise ValueError("Neither tokenizer nor model_path were provided")
-        if not model_path or not os.path.exists(model_path):
-            raise ValueError(f"model_path: {model_path} is invalid")
-        self.tokenizer = sentencepiece.SentencePieceProcessor()
-        self.tokenizer.Load(model_path)
+        else:
+            if not model_path or not os.path.exists(model_path):
+                raise ValueError(f"model_path: {model_path} is invalid")
+            self.tokenizer = sentencepiece.SentencePieceProcessor()
+            self.tokenizer.Load(model_path)

         self.original_vocab_size = self.tokenizer.get_piece_size()
         self.vocab_size = self.tokenizer.get_piece_size()
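The hunk above replaces an implicit precedence rule (the old code silently preferred `tokenizer` when both arguments were given) with an explicit XOR check, so passing both or neither now fails fast, and both paths fall through to the shared vocab-size initialization. A minimal usage sketch of the new contract, assuming a trained SentencePiece model file at `tokenizer.model` and the class name `SentencePieceTokenizer` from this module (the class name is not shown in the diff itself):

    import sentencepiece

    from nemo.export.sentencepiece_tokenizer import SentencePieceTokenizer

    # Passing a pre-built processor: model_path must be omitted.
    sp = sentencepiece.SentencePieceProcessor()
    sp.Load("tokenizer.model")  # assumes a trained SentencePiece model on disk
    tok = SentencePieceTokenizer(tokenizer=sp)

    # Passing a path instead: the constructor loads the processor itself.
    tok = SentencePieceTokenizer(model_path="tokenizer.model")

    # Passing both (or neither) now raises, because
    # (model_path_provided ^ tokenizer_provided) is False in both cases.
    try:
        SentencePieceTokenizer(model_path="tokenizer.model", tokenizer=sp)
    except ValueError as e:
        print(e)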
nemo/export/vllm/model_config.py (74 changes: 35 additions & 39 deletions)

@@ -91,40 +91,14 @@ def __init__(
         if self.model_converter is None:
             raise RuntimeError(f'Unknown model type "{model_type}"')

-        hf_to_nemo_dict = {
-            'hidden_size': 'hidden_size',
-            'intermediate_size': 'ffn_hidden_size',
-            'num_hidden_layers': 'num_layers',
-            'num_attention_heads': 'num_attention_heads',
-            'num_key_value_heads': 'num_query_groups',
-            # 'hidden_act': 'activation', ## <- vLLM has good defaults for the models, nemo values are wrong
-            'max_position_embeddings': ['max_position_embeddings', 'encoder_seq_length'],
-            'rms_norm_eps': 'layernorm_epsilon',
-            'attention_dropout': 'attention_dropout',
-            'initializer_range': 'init_method_std',
-            'norm_epsilon': 'layernorm_epsilon',
-            'rope_theta': 'rotary_base',
-            'use_bias': ['bias', 'add_bias_linear'],
-        }
-
         if is_nemo2_checkpoint(nemo_checkpoint):
             from nemo.lightning.io import load_context

             nemo_checkpoint: Path = Path(nemo_checkpoint)
             self.nemo_model_config: dict = yaml.load(
                 (nemo_checkpoint / "context/model.yaml").open('r'), Loader=yaml.SafeLoader
             )
-            hf_args = {}
-            for hf_arg, nemo_arg in hf_to_nemo_dict.items():
-                if not isinstance(nemo_arg, list):
-                    nemo_arg = [nemo_arg]
-
-                for nemo_arg_option in nemo_arg:
-                    value = self.nemo_model_config['config'].get(nemo_arg_option)
-                    if value is not None:
-                        hf_args[hf_arg] = value
-                        break
-
+            hf_args = self._load_hf_arguments(self.nemo_model_config)
             tokenizer = load_context((nemo_checkpoint / "context"), subpath="model.tokenizer")

             if hasattr(tokenizer, 'bos_id'):
@@ -140,18 +114,7 @@ def __init__(
             with TarPath(nemo_checkpoint) as archive:
                 with (archive / "model_config.yaml").open("r") as model_config_file:
                     self.nemo_model_config = yaml.load(model_config_file, Loader=yaml.SafeLoader)
-
-            hf_args = {}
-            for hf_arg, nemo_arg in hf_to_nemo_dict.items():
-                if not isinstance(nemo_arg, list):
-                    nemo_arg = [nemo_arg]
-
-                for nemo_arg_option in nemo_arg:
-                    value = self.nemo_model_config.get(nemo_arg_option)
-                    if value is not None:
-                        hf_args[hf_arg] = value
-                        break
-
+            hf_args = self._load_hf_arguments(self.nemo_model_config)
         self.model_converter.convert_config(self.nemo_model_config, hf_args)
         self.hf_config = AutoConfig.for_model(model_type, **hf_args)

@@ -174,3 +137,36 @@ def __init__(
         self._verify_embedding_mode()
         self._verify_quantization()
         self._verify_cuda_graph()
+
+
+    def _load_hf_arguments(self, nemo_config: dict) -> dict:
+        """Maps argument names used in NeMo to their corresponding names in HF."""
+
+        hf_to_nemo_dict = {
+            'hidden_size': 'hidden_size',
+            'intermediate_size': 'ffn_hidden_size',
+            'num_hidden_layers': 'num_layers',
+            'num_attention_heads': 'num_attention_heads',
+            'num_key_value_heads': 'num_query_groups',
+            # 'hidden_act': 'activation', ## <- vLLM has good defaults for the models, nemo values are wrong
+            'max_position_embeddings': ['max_position_embeddings', 'encoder_seq_length'],
+            'rms_norm_eps': 'layernorm_epsilon',
+            'attention_dropout': 'attention_dropout',
+            'initializer_range': 'init_method_std',
+            'norm_epsilon': 'layernorm_epsilon',
+            'rope_theta': 'rotary_base',
+            'use_bias': ['bias', 'add_bias_linear'],
+        }
+
+        hf_args = {}
+        for hf_arg, nemo_arg in hf_to_nemo_dict.items():
+            if not isinstance(nemo_arg, list):
+                nemo_arg = [nemo_arg]
+
+            for nemo_arg_option in nemo_arg:
+                value = nemo_config.get(nemo_arg_option)
+                if value is not None:
+                    hf_args[hf_arg] = value
+                    break
+
+        return hf_args
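The extracted helper is a verbatim move of the lookup loop previously duplicated in the NeMo 1 and NeMo 2 branches of `__init__`. A self-contained sketch of the behavior it implements, using toy config values rather than a real checkpoint: for list-valued entries such as `'use_bias'`, the first NeMo key present in the config wins, and keys absent from the config are simply skipped.

    # Hypothetical stand-alone copy of the lookup loop, for illustration only.
    hf_to_nemo_dict = {
        'num_hidden_layers': 'num_layers',
        'use_bias': ['bias', 'add_bias_linear'],
    }

    nemo_config = {'num_layers': 32, 'add_bias_linear': False}  # toy values

    hf_args = {}
    for hf_arg, nemo_arg in hf_to_nemo_dict.items():
        if not isinstance(nemo_arg, list):
            nemo_arg = [nemo_arg]  # normalize to a list of candidate NeMo keys
        for nemo_arg_option in nemo_arg:
            value = nemo_config.get(nemo_arg_option)
            if value is not None:  # False and 0 are valid values, only None is "missing"
                hf_args[hf_arg] = value
                break

    print(hf_args)  # {'num_hidden_layers': 32, 'use_bias': False}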
nemo/export/vllm/model_converters.py (2 changes: 1 addition & 1 deletion)

@@ -42,7 +42,7 @@ def convert_config(self, nemo_model_config: dict, hf_config: dict) -> None:
         pass

     @abstractmethod
-    def convert_weights(self, nemo_model_config: dict, state_dict: dict) -> Generator[Tuple[str, torch.tensor]]:
+    def convert_weights(self, nemo_model_config: dict, state_dict: dict) -> Generator[Tuple[str, torch.tensor], None, None]:
         """
         Returns or yields a sequence of (name, tensor) tuples that contain model weights in the HF format.
         """
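This one-line fix makes the annotation well-formed: `typing.Generator` takes three parameters, the yield type, the send type, and the return type, so the two-parameter form was invalid. (The annotation still writes `torch.tensor`, the factory function, where the class `torch.Tensor` is presumably meant; the commit leaves that part untouched.) A minimal sketch with a hypothetical converter, not the repository's real implementation:

    from typing import Generator, Tuple

    import torch


    def fake_convert_weights(state_dict: dict) -> Generator[Tuple[str, torch.Tensor], None, None]:
        # Yields (name, tensor) pairs; sends nothing in, returns nothing,
        # hence the trailing None, None in the annotation.
        for name, tensor in state_dict.items():
            yield name, tensor


    for name, tensor in fake_convert_weights({"w": torch.zeros(2, 2)}):
        print(name, tensor.shape)

Since callers only iterate over the result, `Iterator[Tuple[str, torch.Tensor]]` would express the same contract more loosely; the three-parameter `Generator` form chosen here is the stricter spelling.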
