add rope scaling and convertor (#9815)
* add rope scaling and convertor

* Apply isort and black reformatting

Signed-off-by: JRD971000 <JRD971000@users.noreply.github.com>

* add rope scaling

* Apply isort and black reformatting

Signed-off-by: JRD971000 <JRD971000@users.noreply.github.com>

* add version for flag

* Apply isort and black reformatting

Signed-off-by: JRD971000 <JRD971000@users.noreply.github.com>

* minor fix

* Apply isort and black reformatting

Signed-off-by: JRD971000 <JRD971000@users.noreply.github.com>

* minor fix

* move patch to nemo

* Apply isort and black reformatting

Signed-off-by: JRD971000 <JRD971000@users.noreply.github.com>

* add tokenizer name

* Apply isort and black reformatting

Signed-off-by: JRD971000 <JRD971000@users.noreply.github.com>

* add tokenizer llama31 in convertor

* Apply isort and black reformatting

Signed-off-by: JRD971000 <JRD971000@users.noreply.github.com>

* Apply isort and black reformatting

Signed-off-by: JRD971000 <JRD971000@users.noreply.github.com>

---------

Signed-off-by: JRD971000 <JRD971000@users.noreply.github.com>
Signed-off-by: Ali Taghibakhshi <71892896+JRD971000@users.noreply.github.com>
Co-authored-by: Ali Taghibakhshi <ataghibakhsh@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: JRD971000 <JRD971000@users.noreply.github.com>
3 people authored Aug 15, 2024
1 parent d298c97 commit 05ced1e
Showing 6 changed files with 669 additions and 6 deletions.
@@ -117,6 +117,7 @@ model:
batch_p2p_comm: True # Batch consecutive inter-peer send/recv operations. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1
seq_len_interpolation_factor: null # RoPE Interpolation factor for sequence length. This is used to build long-context models with RoPE ex: https://arxiv.org/abs/2306.15595.
num_query_groups: null # Number of query groups for group query attention. If None, normal attention is used.
scale_positional_embedding: False # Apply scaling for RoPE frequencies

## Reset learning rate schedule.
# 1. reset_lr=True, reset_lr_steps=False. When pre-training an existing checkpoint "from scratch" on a different dataset.
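The new `scale_positional_embedding` knob defaults to False; setting it to True makes the GPT model rewrite its rotary inverse frequencies with Llama 3.1-style scaling at construction time (see the model change further down). A minimal sketch of flipping the flag programmatically with OmegaConf; the config file name is a placeholder, not something introduced by this commit:

```python
from omegaconf import OmegaConf

# Load an existing NeMo GPT training config (placeholder path) and enable
# Llama 3.1-style RoPE frequency scaling for continued pretraining.
cfg = OmegaConf.load("megatron_gpt_config.yaml")  # hypothetical file name
cfg.model.scale_positional_embedding = True

print(OmegaConf.to_yaml(cfg.model))  # the model section now shows the flag set to True
```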
28 changes: 28 additions & 0 deletions nemo/collections/common/parts/utils.py
@@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import math
import os
from typing import Iterable, List

logger = logging.getLogger(__name__)

import einops
import torch
import torch.nn as nn
@@ -109,6 +112,31 @@ def extend_instance(obj, mixin):
) # mixin needs to go first for our forward() logic to work


def apply_rope_scaling(freqs):
# Apply scaling for RoPE frequencies
logger.info("apply rope scaling ...")
# Values obtained from grid search
scale_factor = 8
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scale_factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


def mask_sequence_tensor(tensor: torch.Tensor, lengths: torch.Tensor):
"""
For tensors containing sequences, zero out out-of-bound elements given lengths of every element in the batch.
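`apply_rope_scaling` implements the Llama 3.1 recipe: frequency components whose wavelength is shorter than `old_context_len / high_freq_factor` are left untouched, components with wavelength longer than `old_context_len / low_freq_factor` are divided by `scale_factor`, and the band in between is linearly interpolated between those two treatments. A minimal sketch of the three regimes on a toy inverse-frequency tensor, assuming NeMo with this change is importable; the head dimension and rotary base below are illustrative values, not taken from this commit:

```python
import math

import torch

from nemo.collections.common.parts.utils import apply_rope_scaling

# Standard RoPE inverse frequencies for an illustrative head dimension of 128
# and rotary base 10000 (the usual 1 / base**(2i/d) formula).
head_dim, base = 128, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))

scaled = apply_rope_scaling(inv_freq)

wavelen = 2 * math.pi / inv_freq
ratio = scaled / inv_freq

print(ratio[wavelen < 8192 / 4])                         # high-frequency band: unchanged (all 1.0)
print(ratio[wavelen > 8192 / 1])                         # low-frequency band: divided by 8 (all 0.125)
print(ratio[(wavelen >= 8192 / 4) & (wavelen <= 8192)])  # middle band: smooth blend between 1.0 and 0.125
```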
@@ -30,7 +30,7 @@
from pytorch_lightning.loops.fetchers import _DataFetcherWrapper
from pytorch_lightning.trainer.trainer import Trainer

-from nemo.collections.common.parts.utils import extend_instance
+from nemo.collections.common.parts.utils import apply_rope_scaling, extend_instance
from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import (
MegatronCorePretrainingSampler,
MegatronPretrainingRandomSampler,
@@ -77,6 +77,7 @@
from nemo.utils.te_utils import is_float8tensor

try:
import megatron.core as core
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset
@@ -429,6 +430,7 @@ def get_inference_config(self):
def model_provider_func(self, pre_process, post_process):
"""Model depends on pipeline paralellism."""
if self.mcore_gpt:

model = MCoreGPTModel(
config=self.transformer_config,
transformer_layer_spec=get_specs(
@@ -448,6 +450,10 @@ def model_provider_func(self, pre_process, post_process):
seq_len_interpolation_factor=self.cfg.get('seq_len_interpolation_factor', None),
rotary_base=self.cfg.get('rotary_base', 10000),
)

if self.cfg.get('scale_positional_embedding', False):
model.rotary_pos_emb.inv_freq = apply_rope_scaling(model.rotary_pos_emb.inv_freq)

if self.cfg.get("apply_embedding_scaling", False) and parallel_state.is_pipeline_first_stage():
extend_instance(model.embedding, EmbeddingScalingMixin)
else:
31 changes: 26 additions & 5 deletions scripts/checkpoint_converters/convert_llama_hf_to_nemo.py
@@ -18,6 +18,8 @@
python convert_llama_hf_to_nemo.py \
--input_name_or_path <path_to_hf_checkpoints_folder> \
--output_path <path_to_output_nemo_file> \
--precision bf16 \
--llama31 True
"""

import os
@@ -60,6 +62,13 @@ def get_args():
required=False,
help="Path config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml",
)
parser.add_argument(
"--llama31",
type=bool,
default=True,
required=False,
help="Whether the model is from LLaMa 3.1 family. LLaMa 3.1 enables scaling for RoPE frequencies.",
)
parser.add_argument("--precision", type=str, default="16", help="Model precision")
args = parser.parse_args()
return args
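One caveat with `--llama31`: argparse's `type=bool` simply calls `bool()` on the raw string, and any non-empty string (including the literal "False") is truthy, so `--llama31 False` still parses as True; only `--llama31 ""` yields False. A hedged sketch of an explicit string-to-boolean parser that could be supplied as `type=`; this helper is illustrative and not part of this commit:

```python
import argparse


def str2bool(value: str) -> bool:
    """Parse common true/false spellings; illustrative helper, not part of this commit."""
    if value.lower() in ("yes", "true", "t", "1"):
        return True
    if value.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


# Usage sketch: parser.add_argument("--llama31", type=str2bool, default=True)
```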
@@ -110,6 +119,7 @@ def load_config(args, llama_config):
while llama_config['vocab_size'] % base != 0:
base //= 2
nemo_config.make_vocab_size_divisible_by = base
nemo_config.scale_positional_embedding = args.llama31

return nemo_config

@@ -161,6 +171,7 @@ def convert(args):
plugins.append(PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))

nemo_config.precision = precision
nemo_config.micro_batch_size = 1
print(f"nemo_config: {nemo_config}")

# Remove precision arg, since with PTL >= 2.1 both precision and precision plugin cannot exist together.
@@ -298,12 +309,22 @@ def convert(args):

# We make sure that the tokenizer can be instantiated later regardless of args.input_name_or_path
if 'tokenizer_model' not in hf_config:
-    if hf_config['num_hidden_layers'] == 32:
-        model.cfg.tokenizer.update(type='meta-llama/Meta-Llama-3-8B')
-    elif hf_config['num_hidden_layers'] == 80:
-        model.cfg.tokenizer.update(type='meta-llama/Meta-Llama-3-70B')
+    if args.llama31:
+        if hf_config['num_hidden_layers'] == 32:
+            model.cfg.tokenizer.update(type='meta-llama/Meta-Llama-3.1-8B')
+        elif hf_config['num_hidden_layers'] == 80:
+            model.cfg.tokenizer.update(type='meta-llama/Meta-Llama-3.1-70B')
+        elif hf_config['num_hidden_layers'] == 126:
+            model.cfg.tokenizer.update(type='meta-llama/Meta-Llama-3.1-8B')  # 405B tokenizer is the same as 8B
+        else:
+            logging.warning("Unexpected model config for Llama3. Tokenizer config has not been modified.")
     else:
-        logging.warning("Unexpected model config for Llama3. Tokenizer config has not been modified.")
+        if hf_config['num_hidden_layers'] == 32:
+            model.cfg.tokenizer.update(type='meta-llama/Meta-Llama-3-8B')
+        elif hf_config['num_hidden_layers'] == 80:
+            model.cfg.tokenizer.update(type='meta-llama/Meta-Llama-3-70B')
+        else:
+            logging.warning("Unexpected model config for Llama3. Tokenizer config has not been modified.")

# cast to target precision and disable cpu init
dtype = torch_dtype_from_precision(precision)
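For reference, the tokenizer selection above keys off `args.llama31` and `num_hidden_layers`; the same mapping can be written as a small lookup table. This is purely an illustrative restructuring, not code from this commit:

```python
from typing import Optional

# (is_llama31, num_hidden_layers) -> Hugging Face tokenizer name, mirroring the converter logic above.
TOKENIZER_BY_DEPTH = {
    (True, 32): 'meta-llama/Meta-Llama-3.1-8B',
    (True, 80): 'meta-llama/Meta-Llama-3.1-70B',
    (True, 126): 'meta-llama/Meta-Llama-3.1-8B',  # the 405B model reuses the 8B tokenizer
    (False, 32): 'meta-llama/Meta-Llama-3-8B',
    (False, 80): 'meta-llama/Meta-Llama-3-70B',
}


def pick_tokenizer(is_llama31: bool, num_hidden_layers: int) -> Optional[str]:
    """Return the tokenizer name for a known layer count, or None to leave the config untouched."""
    return TOKENIZER_BY_DEPTH.get((is_llama31, num_hidden_layers))


print(pick_tokenizer(True, 126))   # meta-llama/Meta-Llama-3.1-8B
print(pick_tokenizer(False, 48))   # None -> keep the existing tokenizer config and warn
```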