# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import os
import queue
import warnings
from contextlib import nullcontext
from dataclasses import fields
from functools import cache, partial
from importlib.metadata import version
from typing import Any, Dict, Iterator, List, Optional, Union
import torch
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from pkg_resources import packaging
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.loops.fetchers import _DataFetcherWrapper
from pytorch_lightning.trainer.trainer import Trainer
from nemo.collections.common.parts.utils import extend_instance
from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import (
MegatronPretrainingRandomSampler,
MegatronPretrainingSampler,
)
from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import build_train_valid_test_datasets
from nemo.collections.nlp.data.language_modeling.megatron.gpt_fim_dataset import GPTFIMDataset, GPTFIMDatasetConfig
from nemo.collections.nlp.models.language_modeling.megatron.falcon.falcon_spec import get_falcon_layer_spec
from nemo.collections.nlp.models.language_modeling.megatron.gpt_full_te_layer_autocast_spec import (
get_gpt_full_te_layer_autocast_spec,
)
from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec
from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import GPTModel
from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel
from nemo.collections.nlp.modules.common.megatron.build_model import build_model
from nemo.collections.nlp.modules.common.megatron.module import Float16Module
from nemo.collections.nlp.modules.common.megatron.utils import (
ApexGuardDefaults,
average_losses_across_data_parallel_group,
get_all_params_for_weight_decay_optimization,
get_ltor_masks_and_position_ids,
get_params_for_weight_decay_optimization,
)
from nemo.collections.nlp.modules.common.text_generation_strategy import TextGenerationStrategy
from nemo.collections.nlp.modules.common.text_generation_utils import (
generate,
get_computeprob_response,
get_default_length_params,
get_default_sampling_params,
megatron_gpt_generate,
)
from nemo.collections.nlp.modules.common.transformer.text_generation import (
LengthParam,
OutputType,
SamplingParam,
TextGeneration,
)
from nemo.collections.nlp.parts import utils_funcs
from nemo.collections.nlp.parts.utils_funcs import activation_to_func, get_last_rank
from nemo.core.classes import Exportable
from nemo.core.classes.common import PretrainedModelInfo
from nemo.core.neural_types import ChannelType, NeuralType
from nemo.utils import logging
from nemo.utils.te_utils import is_float8tensor
try:
import apex.transformer.pipeline_parallel.utils
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
HAVE_APEX = False
try:
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace
from megatron.core.dist_checkpointing.mapping import LocalNonpersitentObject, ShardedObject
from megatron.core.distributed import DistributedDataParallel as McoreDDP
from megatron.core.distributed import DistributedDataParallelConfig, finalize_model_grads
# NeMo's implementation of the get_gpt_layer_ammo_spec function is temporarily used
# from megatron.core.inference.gpt.model_specs import get_gpt_layer_ammo_spec
from megatron.core.models.gpt import GPTModel as MCoreGPTModel
from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
)
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
from megatron.core.transformer.module import Float16Module as MCoreFloat16Module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import (
drain_embedding_wgrad_compute,
get_model_config,
init_method_normal,
scaled_init_method_normal,
)
HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
TransformerConfig = ApexGuardDefaults
HAVE_MEGATRON_CORE = False
try:
import transformer_engine
from transformer_engine.pytorch import module as te_module
HAVE_TE = True
except (ImportError, ModuleNotFoundError):
HAVE_TE = False
@cache
def mcore_supports_moe() -> bool:
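    """Return True if the installed megatron-core supports MoE (detected by importing TopKRouter)."""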
global HAVE_MEGATRON_CORE
if not HAVE_MEGATRON_CORE:
return False
try:
from megatron.core.transformer.moe.router import TopKRouter
return True
except ImportError:
return False
def get_specs(spec_name, num_experts=None, moe_grouped_gemm=False, use_te=True):
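    """Return a Megatron-Core transformer layer spec for ``spec_name``.

    An empty name resolves to the Transformer Engine GPT spec when ``use_te`` is True,
    otherwise to the local GPT layer spec. ``num_experts`` and ``moe_grouped_gemm`` are
    forwarded to the specs that support mixture-of-experts.
    """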
if num_experts is not None:
assert mcore_supports_moe(), "Megatron-core >= v0.5.0 is required for MoE"
if use_te and spec_name == '':
spec_name = 'te_gpt'
name_spec_dict = {
"": get_gpt_layer_local_spec(num_experts, moe_grouped_gemm),
"te_gpt": get_gpt_layer_with_transformer_engine_spec(num_experts, moe_grouped_gemm),
"megatron_falcon_gpt": get_falcon_layer_spec(),
"megatron_gpt_full_te_layer_autocast": get_gpt_full_te_layer_autocast_spec(),
"modelopt": get_gpt_layer_modelopt_spec(),
}
if spec_name not in name_spec_dict:
raise ValueError(f"Spec name '{spec_name}' is not recognized.")
return name_spec_dict[spec_name]
class EmbeddingScalingMixin(torch.nn.Module):
"""
A mixin class for scaling embeddings in Megatron GPT.
The scaling is applied only if the configuration (accessible via `self.config`)
includes `apply_embedding_scaling` set to True.
"""
def forward(self, **kwargs):
"""
Forward pass that scales the output embeddings from the `forward` method of
the superclass by the square root of the hidden size specified in the configuration.
"""
embeddings = super().forward(**kwargs)
return embeddings * torch.tensor(self.config.hidden_size**0.5, dtype=embeddings.dtype)
class MegatronGPTExportableModel(torch.nn.Module, Exportable):
"""
Megatron GPT Wrapper for ONNX export
"""
def __init__(self, model):
super().__init__()
self.model = model
self.fp8_enabled = model.cfg.get('fp8', False)
self.fp8_recipe = None
if self.fp8_enabled and HAVE_TE:
self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
margin=0, interval=1, fp8_format=transformer_engine.common.recipe.Format.E4M3
)
self.dtype = utils_funcs.torch_dtype_from_precision(model.cfg.precision)
def forward(self, tokens, position_ids, attention_mask):
if self.fp8_enabled and HAVE_TE:
with (
transformer_engine.pytorch.onnx_export(self.fp8_enabled),
transformer_engine.pytorch.fp8_autocast(enabled=self.fp8_enabled, fp8_recipe=self.fp8_recipe),
torch.no_grad(),
torch.inference_mode(),
torch.autocast('cuda', dtype=self.dtype),
warnings.catch_warnings(),
):
warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning, module=r'.*')
assert tokens.shape == position_ids.shape
assert attention_mask.shape[2] == attention_mask.shape[3] == tokens.shape[1] == position_ids.shape[1]
output_tensor = self.model.forward(
tokens=tokens.cuda(),
text_position_ids=position_ids.cuda(),
attention_mask=attention_mask.cuda(),
labels=None,
)
else:
with (
torch.no_grad(),
torch.inference_mode(),
torch.autocast('cuda', dtype=self.dtype),
warnings.catch_warnings(),
):
warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning, module=r'.*')
assert tokens.shape == position_ids.shape
assert attention_mask.shape[2] == attention_mask.shape[3] == tokens.shape[1] == position_ids.shape[1]
output_tensor = self.model.forward(
tokens=tokens.cuda(),
text_position_ids=position_ids.cuda(),
attention_mask=attention_mask.cuda(),
labels=None,
)
return output_tensor
def freeze(self):
for param in self.parameters():
param.requires_grad = False
def input_example(self, max_batch=1, max_dim=768, seq_len=6):
ids = [self.model.tokenizer.text_to_ids(text) for text in ["how is the weather on Sunday"]]
id_tensors = [torch.unsqueeze(torch.LongTensor(id_list), dim=0) for id_list in ids]
masks_and_position_ids = [
get_ltor_masks_and_position_ids(id_tensor, self.model.tokenizer.eos_id, False, False, False)
for id_tensor in id_tensors
]
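        # Only a single example prompt is constructed above, so the loop below returns on its first iteration.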
for tokens, attn_mask_and_pos_ids in zip(id_tensors, masks_and_position_ids):
attn_mask, _, pos_ids = attn_mask_and_pos_ids
return tokens, pos_ids, attn_mask
@property
def input_types(self) -> Optional[Dict[str, NeuralType]]:
return {
"input_ids": NeuralType(('B', 'T'), ChannelType()),
"position_ids": NeuralType(('B', 'T'), ChannelType()),
"attention_mask": NeuralType(('D', 'D', 'T', 'T'), ChannelType()),
}
@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
return {"logits": NeuralType(('B', 'T', 'D'), ChannelType())}
@property
def input_names(self) -> List[str]:
return ['input_ids', 'position_ids', 'attention_mask']
@property
def output_names(self) -> List[str]:
return ['logits']
class MegatronGPTModel(MegatronBaseModel, TextGeneration):
"""
Megatron GPT pretraining
"""
def __init__(self, cfg: DictConfig, trainer: Trainer):
if not HAVE_APEX:
raise ImportError(
"Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
)
if not HAVE_MEGATRON_CORE:
logging.warning(
"megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
)
        # this prevents the base constructor from initializing the tokenizer
self.tokenizer = None
super().__init__(cfg, trainer=trainer, no_lm_init=True)
self._validate_trainer()
# build the transformer config
# TODO: add type hint once pip package is out
self.transformer_config = self.build_transformer_config()
self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False)
self.mcore_gpt = cfg.get('mcore_gpt', False)
self.spec_name = cfg.get('name', '')
if cfg.get('fp8', False):
self.prev_step_training = True
self.rampup_batch_size = self.cfg.get('rampup_batch_size', None)
if self.rampup_batch_size:
self.prev_consumed_samples = 0
self.if_first_step = 0
self.prev_global_batch_size = None
if cfg.get('data', None) is not None:
self.reset_position_ids = cfg.data.get('reset_position_ids', False)
self.reset_attention_mask = cfg.data.get('reset_attention_mask', False)
self.eod_mask_loss = cfg.data.get('eod_mask_loss', False)
if not self.megatron_amp_O2 and self.cfg.get('virtual_pipeline_model_parallel_size', None):
raise ValueError('Virtual pipeline model parallel is only supported when using megatron_amp_O2')
if not self.megatron_amp_O2 and self.cfg.get('expert_model_parallel_size', 1) > 1:
raise ValueError('Expert parallelism is only supported when using megatron_amp_O2')
if self.cfg.get('expert_model_parallel_size', 1) > 1 and self.with_distributed_adam:
if not self.use_mcore_dist_optim:
raise ValueError(
                    'Expert parallelism does not currently support the Apex distributed optimizer; use the Mcore distributed optimizer instead'
)
self.transformer_engine = cfg.get('transformer_engine', False)
if self.megatron_amp_O2 and not self.transformer_engine:
logging.warning('megatron_amp_O2 is enabled but transformer-engine is not.')
# build_model returns a list of modules which are used for interleaved pipeline parallelism
if isinstance(self.trainer.accelerator, CPUAccelerator):
self.model = build_model(
model_provider_func=self.model_provider_func,
wrap_with_ddp=False,
on_cpu=True,
virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None),
)
else:
build_model_context = nullcontext
if HAVE_TE and self.cfg.get('fp8', False) and self.cfg.get('fp8_params', False):
build_model_context = transformer_engine.pytorch.fp8_model_init
with build_model_context():
self.model = build_model(
model_provider_func=self.model_provider_func,
wrap_with_ddp=False,
virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None),
on_cpu=cfg.get('use_cpu_initialization', False),
)
# if we're not using interleaved, then self.model is a module.
if self.cfg.get('virtual_pipeline_model_parallel_size', None) is None and (not self.use_mcore_dist_optim):
self.model = self.model[0]
if self.megatron_amp_O2:
if not self.with_distributed_adam and not self.cfg.get("use_cpu_initialization", False):
# Pre-allocate the model on GPU to have master parameters allocated on the same device with matching data type
if isinstance(self.model, list):
for module in self.model:
module.cuda(torch.cuda.current_device())
else:
self.model.cuda(torch.cuda.current_device())
self._wrap_model_for_O2()
        # Autocast is only needed for O1-style (non-megatron_amp_O2) fp16/bf16 mixed precision.
        self.enable_autocast = (not self.megatron_amp_O2) and (
            self.autocast_dtype in [torch.float16, torch.bfloat16]
        )
# configuration used for inference
self._inference_config = None
# Convert the global-batch-based profile index to micro-batch index
if hasattr(self, '_nsys_profile_enabled') or hasattr(self, '_memory_profile_enabled'):
mp_size = cfg.get('tensor_model_parallel_size', 1) * cfg.get('pipeline_model_parallel_size', 1)
cp_size = cfg.get('context_parallel_size', 1)
data_parallel_world_size = trainer.world_size // (mp_size * cp_size)
grad_accum_steps = cfg.get('global_batch_size') // (cfg.get('micro_batch_size') * data_parallel_world_size)
if hasattr(self, '_nsys_profile_enabled'):
self._nsys_profile_start_step *= grad_accum_steps
self._nsys_profile_end_step *= grad_accum_steps
if hasattr(self, '_memory_profile_enabled'):
self._memory_profile_start_step *= grad_accum_steps
self._memory_profile_end_step *= grad_accum_steps
self.get_attention_mask_from_fusion = self.cfg.get('get_attention_mask_from_fusion', True)
self.initialize_ub = self.cfg.get('ub_tp_comm_overlap', False)
self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1)))
self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0)))
self.loss_broadcast_src_rank = None
data_cfg = cfg.get('data', {})
self.return_output_tensors = data_cfg.get('return_output_tensors', False)
self.validation_drop_last = data_cfg.get('validation_drop_last', True)
self.sample_weight = data_cfg.get('sample_weight', 'token')
self.validation_param_sync_overlap = self.cfg.get('validation_param_sync_overlap', False)
self.inference_params = None
# default to false since this doesn't work with sequence parallelism currently
self.use_loss_mask = self.cfg.get('use_loss_mask', False)
if self.use_loss_mask and self.transformer_config.sequence_parallel:
raise ValueError('Loss mask is not supported with sequence parallelism.')
def set_inference_config(self, inference_config):
self._inference_config = inference_config
def get_inference_config(self):
return self._inference_config
def model_provider_func(self, pre_process, post_process):
"""Model depends on pipeline paralellism."""
if self.mcore_gpt:
model = MCoreGPTModel(
config=self.transformer_config,
transformer_layer_spec=get_specs(
self.spec_name,
self.transformer_config.num_moe_experts,
self.transformer_config.moe_grouped_gemm,
self.transformer_engine,
),
vocab_size=self.cfg.get('override_vocab_size', self.padded_vocab_size),
max_sequence_length=self.cfg.get('encoder_seq_length', 512),
pre_process=pre_process,
post_process=post_process,
parallel_output=True,
share_embeddings_and_output_weights=self.cfg.get('share_embeddings_and_output_weights', True),
position_embedding_type=self.cfg.get('position_embedding_type', 'learned_absolute'),
rotary_percent=self.cfg.get('rotary_percentage', 1.0),
seq_len_interpolation_factor=self.cfg.get('seq_len_interpolation_factor', None),
rotary_base=self.cfg.get('rotary_base', 10000),
)
if self.cfg.get("apply_embedding_scaling", False) and parallel_state.is_pipeline_first_stage():
extend_instance(model.embedding, EmbeddingScalingMixin)
else:
assert self.cfg.get('num_query_groups', None) is None or self.cfg.get(
'num_query_groups', None
) == self.cfg.get(
'num_attention_heads', None
), "Group Query Attention is only supported in Megatron Core. Set 'mcore_gpt' to use GQA."
model = GPTModel(
config=self.model_parallel_config,
vocab_size=self.cfg.get('override_vocab_size', self.padded_vocab_size),
hidden_size=self.cfg.hidden_size,
max_position_embeddings=self.cfg.max_position_embeddings,
num_layers=self.cfg.num_layers,
num_attention_heads=self.cfg.num_attention_heads,
apply_query_key_layer_scaling=self.cfg.get('apply_query_key_layer_scaling', True),
kv_channels=self.cfg.get('kv_channels', None),
ffn_hidden_size=self.cfg.ffn_hidden_size,
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process,
init_method_std=self.cfg.get('init_method_std', 0.02),
use_scaled_init_method=self.cfg.get('use_scaled_init_method', True),
fp16_lm_cross_entropy=self.cfg.get('fp16_lm_cross_entropy', False),
hidden_dropout=self.cfg.get('hidden_dropout', 0.1),
attention_dropout=self.cfg.get('attention_dropout', 0.1),
ffn_dropout=self.cfg.get('ffn_dropout', 0.0),
precision=self.cfg.get('precision', 16),
fp32_residual_connection=self.cfg.get('fp32_residual_connection', False),
activations_checkpoint_granularity=self.cfg.get('activations_checkpoint_granularity', None),
activations_checkpoint_method=self.cfg.get('activations_checkpoint_method', None),
activations_checkpoint_num_layers=self.cfg.get('activations_checkpoint_num_layers', 1),
activations_checkpoint_layers_per_pipeline=self.cfg.get(
'activations_checkpoint_layers_per_pipeline', None
),
normalization=self.cfg.get('normalization', 'layernorm'),
layernorm_epsilon=self.cfg.get('layernorm_epsilon', 1e-5),
onnx_safe=self.cfg.get('onnx_safe', False),
bias=self.cfg.get('bias', True),
bias_activation_fusion=self.cfg.get('bias_activation_fusion', True),
bias_dropout_add_fusion=self.cfg.get('bias_dropout_add_fusion', True),
activation=self.cfg.get('activation', 'gelu'),
headscale=self.cfg.get('headscale', False),
transformer_block_type=self.cfg.get('transformer_block_type', 'pre_ln'),
openai_gelu=self.cfg.get('openai_gelu', False),
normalize_attention_scores=self.cfg.get('normalize_attention_scores', True),
position_embedding_type=self.cfg.get('position_embedding_type', 'learned_absolute'),
rotary_percentage=self.cfg.get('rotary_percentage', 1.0),
share_embeddings_and_output_weights=self.cfg.get('share_embeddings_and_output_weights', True),
attention_type=self.cfg.get('attention_type', 'multihead'),
masked_softmax_fusion=self.cfg.get('masked_softmax_fusion', True),
persist_layer_norm=self.cfg.get('persist_layer_norm', False),
transformer_engine=self.cfg.get('transformer_engine', False),
fp8=self.cfg.get('fp8', False),
fp8_e4m3=self.cfg.get('fp8_e4m3', False),
fp8_hybrid=self.cfg.get('fp8_hybrid', False),
fp8_margin=self.cfg.get('fp8_margin', 0),
fp8_interval=self.cfg.get('fp8_interval', 1),
fp8_amax_history_len=self.cfg.get('fp8_amax_history_len', 1024),
fp8_amax_compute_algo=self.cfg.get('fp8_amax_compute_algo', 'max'),
reduce_amax=self.cfg.get('reduce_amax', True),
use_emha=self.cfg.get('use_emha', False),
ub_tp_comm_overlap=self.cfg.get('ub_tp_comm_overlap', False),
use_flash_attention=self.cfg.get('use_flash_attention', False),
megatron_legacy=self.cfg.get('megatron_legacy', False),
seq_len_interpolation_factor=self.cfg.get('seq_len_interpolation_factor', None),
rotary_base=self.cfg.get('rotary_base', 10000),
)
if self.cfg.get("apply_embedding_scaling", False) and parallel_state.is_pipeline_first_stage():
extend_instance(model.language_model.embedding, EmbeddingScalingMixin)
return model
def setup_optimizer_param_groups(self):
"""ModelPT override. Optimizer will get self._optimizer_param_groups"""
if self.cfg.get('do_layer_norm_weight_decay', False):
if isinstance(self.model, list):
self._optimizer_param_groups = get_all_params_for_weight_decay_optimization(self.model)
else:
self._optimizer_param_groups = get_all_params_for_weight_decay_optimization([self.model])
else:
self._optimizer_param_groups = get_params_for_weight_decay_optimization(self.model)
def setup_mcore_distributed_parallel(self):
"""Set up mcore distributed data parallel"""
if self.with_distributed_adam and self.use_mcore_dist_optim:
config = get_model_config(self.model[0])
ddp_config = DistributedDataParallelConfig(
grad_reduce_in_fp32=(self.cfg.optim.get('grad_sync_dtype', 'fp32') == 'fp32'),
overlap_grad_reduce=self.cfg.optim.get('overlap_grad_sync', False),
use_distributed_optimizer=True,
check_for_nan_in_grad=self.cfg.optim.get('check_for_nan_in_grad', False),
                # mcore bucket_size is specified as a number of parameters, so
                # bucket_cap_mb is not used to configure it here
bucket_size=self.cfg.optim.get('ddp_bucket_size', None),
)
self.model = [
McoreDDP(
config,
ddp_config,
model_chunk,
data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
# Turn off bucketing for model_chunk 2 onwards, since communication for these
# model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0),
)
for (model_chunk_idx, model_chunk) in enumerate(self.model)
]
# (TODO) Broadcast params from data parallel src rank to other data parallel ranks.
# by calling model_module.broadcast_params() if the model is randomly initialized.
def configure_optimizers(self):
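        """When the (non-mcore) distributed Adam optimizer is used, flag embedding and
        sequence-parallel grads for special handling and build parameter buckets for
        overlapped grad/param syncs, then delegate to the base class implementation."""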
if self.with_distributed_adam and not self.use_mcore_dist_optim:
# Special handling for embedding grads
with_fp32_embedding_grads = self.cfg.get('with_fp32_embedding_grads', True)
modules = self.get_model_module_list()
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
module = modules[0] # first virtual rank has the embeddings
# Word embeddings: use FP32 grads and disable
# overlapped grad sync with pipeline parallelism
word_embeddings = (
module.shared_embedding_or_output_weight() if self.mcore_gpt else module.word_embeddings_weight()
)
word_embeddings._with_fp32_optimizer = with_fp32_embedding_grads
if parallel_state.get_pipeline_model_parallel_world_size() > 1 and self.cfg.get(
'share_embeddings_and_output_weights', True
):
word_embeddings._disable_greedy_grad_copy = not self.megatron_amp_O2
word_embeddings._disable_overlap_grad_sync = True
# Position embeddings: use FP32 grads
position_embeddings = None
if self.mcore_gpt:
if module.embedding.add_position_embedding:
position_embeddings = module.embedding.position_embeddings.weight
else:
position_embeddings = module.position_embeddings_weight()
if position_embeddings is not None:
position_embeddings._with_fp32_optimizer = with_fp32_embedding_grads
# Handle case where embeddings are used in output layer
if parallel_state.is_pipeline_last_stage(ignore_virtual=True) and self.cfg.get(
'share_embeddings_and_output_weights', True
):
module = modules[-1] # last virtual rank has the embeddings
word_embeddings = (
module.shared_embedding_or_output_weight() if self.mcore_gpt else module.word_embeddings_weight()
)
word_embeddings._with_fp32_optimizer = with_fp32_embedding_grads
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
word_embeddings._disable_greedy_grad_copy = not self.megatron_amp_O2
word_embeddings._disable_overlap_grad_sync = True
# Disable overlapped grad sync for layer norm grads when
# sequence parallelism is enabled
for param in self.parameters():
if getattr(param, 'sequence_parallel', False):
param._disable_greedy_grad_copy = not self.megatron_amp_O2
param._disable_overlap_grad_sync = True
# Initialize parameter buckets for overlapped grad and param syncs
# Note: Params with disabled overlapping and params in the
# first layer are put together in a bucket. If FP8 tensors
# are detected, those are also put in the first layer's
# bucket.
def make_parameter_bucket(module: torch.nn.Module) -> List[torch.nn.Parameter]:
bucket = [
param for param in module.parameters() if not getattr(param, '_disable_overlap_grad_sync', False)
]
if any(is_float8tensor(param) for param in bucket):
bucket = list(filter(is_float8tensor, bucket))
return bucket
buckets = []
if self.cfg.get('virtual_pipeline_model_parallel_size', None) is not None:
# Initialize a bucket for each virtual pipeline stage
for module in self.model:
if isinstance(module, (Float16Module, MCoreFloat16Module)):
module = module.module
layers = module.decoder.layers if self.mcore_gpt else module.language_model.encoder.layers
buckets.extend(make_parameter_bucket(layer) for layer in layers)
else:
# Initialize a bucket for each Transformer layer
modules = self.model if isinstance(self.model, list) else [self.model]
for module in modules:
if isinstance(module, (Float16Module, MCoreFloat16Module)):
module = module.module
layers = module.decoder.layers if self.mcore_gpt else module.language_model.encoder.layers
buckets.extend(make_parameter_bucket(layer) for layer in layers)
buckets.reverse()
used_params = set(itertools.chain.from_iterable(buckets))
buckets[-1].extend(p for p in self.parameters() if p not in used_params)
self.distributed_adam_buckets = buckets
return super().configure_optimizers()
def forward(self, tokens, text_position_ids, attention_mask, labels):
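        """Thin wrapper that forwards a batch through the underlying GPT module."""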
output_tensor = self.model(tokens, text_position_ids, attention_mask, labels=labels)
return output_tensor
def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None):
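        """Run the Megatron-Core forward-backward schedule over one global batch of micro-batches
        and return the loss reduced across micro-batches (losses are only available on the last
        pipeline stage)."""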
# handle asynchronous grad reduction
no_sync_func = None
grad_sync_func = None
param_sync_func = None
if self.with_distributed_adam:
if forward_only:
if self.validation_param_sync_overlap:
param_sync_func = self.sync_overlap_parameters
elif not self.use_mcore_dist_optim:
no_sync_func = partial(
self._optimizer.no_sync,
greedy_grad_copy=self.megatron_amp_O2,
)
grad_sync_func = self.reduce_overlap_gradients
param_sync_func = self.sync_overlap_parameters
else:
if self.cfg.optim.get("overlap_grad_sync", False):
no_sync_func = [model_chunk.no_sync for model_chunk in self.model]
no_sync_func = no_sync_func[0] if len(self.model) == 1 else no_sync_func
if self.cfg.optim.get("delay_grad_reduce", True):
grad_sync_func = [model_chunk.start_grad_sync for model_chunk in self.model]
grad_sync_func = grad_sync_func[0] if len(self.model) == 1 else grad_sync_func
if self.cfg.optim.get("overlap_param_sync", False) and self.cfg.optim.get("delay_param_gather", False):
param_sync_func = [
lambda x, model_index=model_index: self._optimizer.finish_param_sync(model_index, x)
for model_index in range(len(self.model))
]
param_sync_func = param_sync_func[0] if len(self.model) == 1 else param_sync_func
# pipeline schedules will get these from self.model.config
for module in self.get_model_module_list():
module.config.no_sync_func = no_sync_func
module.config.grad_sync_func = grad_sync_func
module.config.param_sync_func = param_sync_func
if self.use_mcore_dist_optim:
module.config.finalize_model_grads_func = finalize_model_grads
# run forward and backwards passes for an entire global batch
# we do this inside training_step to support pipeline parallelism
fwd_bwd_function = get_forward_backward_func()
# TODO @akhattar: add num_micro_batches_with_partial_activation_checkpoints when ready
losses_reduced_per_micro_batch = fwd_bwd_function(
forward_step_func=self.get_forward_output_and_loss_func(forward_only),
data_iterator=self._make_data_iterator_list(dataloader_iter),
model=self.model,
num_microbatches=get_num_microbatches(),
forward_only=forward_only,
seq_length=self.cfg.encoder_seq_length,
micro_batch_size=self.cfg.micro_batch_size,
first_val_step=first_val_step,
)
# only the last stages of the pipeline return losses
if losses_reduced_per_micro_batch:
if (not forward_only) or self.validation_drop_last:
# average loss across micro batches
loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch]
loss_tensor = torch.concat(loss_tensors_list)
loss_mean = loss_tensor.mean()
else:
# Get the total loss since micro batches sizes are not uniform
loss_sum_tensors_list = [
loss_sum['loss_sum_and_ub_size']
for loss_sum in losses_reduced_per_micro_batch
if loss_sum['loss_sum_and_ub_size'][1] > 0
]
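                # 'loss_sum_and_ub_size' is assumed to hold [summed loss, number of valid tokens]
                # per micro-batch, so stacking and summing keeps both quantities consistent.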
loss_sum = (
torch.vstack(loss_sum_tensors_list).sum(axis=0)
if len(loss_sum_tensors_list) > 0
else torch.tensor([0.0, 0.0]).cuda()
)
return loss_sum
else:
# we're not on the last pipeline stage so no losses
if forward_only:
loss_mean = []
else:
loss_mean = torch.tensor(0.0).cuda()
return loss_mean
def initialize_ub_func(self):
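        """Initialize Transformer Engine userbuffer communicators for tensor-parallel communication overlap."""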
ub_cfgs = self.cfg.get('ub_tp_comm_overlap_cfg', None)
if ub_cfgs is None:
warnings.warn(
"Couldn't find TP config. Please check the path correctness. Initializing TP comm overlap with the default config."
)
input_shape = [
self.cfg.get('encoder_seq_length')
* self.cfg.get('micro_batch_size')
// self.cfg.get('context_parallel_size', 1),
self.cfg.get('hidden_size'),
]
te_module.base.initialize_ub(
shape=input_shape,
tp_size=self.cfg.get('tensor_model_parallel_size'),
use_fp8=self.cfg.get('fp8'),
ub_cfgs=ub_cfgs,
)
self.initialize_ub = False
def training_step_fwd_bwd_step_call(self, dataloader_iter, forward_only):
"""
This method is called from the training_step method.
It is separated out to allow for overriding in the MegatronGPTEmbeddingModel
"""
loss_mean = self.fwd_bwd_step(dataloader_iter, forward_only)
return loss_mean
def training_step(self, dataloader_iter):
"""
We pass the dataloader iterator function to the micro-batch scheduler.
The input batch to each micro-batch is fetched using the dataloader function
in the micro-batch fwd function.
"""
# Initialize userbuffer communicators.
if self.initialize_ub:
self.initialize_ub_func()
if self.rampup_batch_size:
num_microbatch_calculator = apex.transformer.pipeline_parallel.utils._GLOBAL_NUM_MICROBATCHES_CALCULATOR
current_global_batch_size = num_microbatch_calculator.current_global_batch_size
# do validation and save the checkpoint when gbs is changed
if self.prev_global_batch_size != current_global_batch_size and self.prev_global_batch_size:
self.trainer.should_stop = True
# zero out the mcore grad buf
if self.use_mcore_dist_optim:
for model_chunk in self.model:
model_chunk.zero_grad_buffer()
# we zero grads here because we also call backward in the megatron-core fwd/bwd functions
self._optimizer.zero_grad()
if self.with_distributed_adam and not self.use_mcore_dist_optim:
# hack to enable overlapping param sync and forward compute
# note: the distributed optimizer monkey-patches each
# parameter's __getattribute__ function so that it can
# launch parameter all-gathers the first time the
# parameter is accessed after the optimizer step. However,
# PyTorch directly passes embedding parameters into a C++,
# bypassing this process. A quick-and-dirty hack is to
# manually interact with the parameter.
modules = self.model if isinstance(self.model, list) else [self.model]
for module in modules:
if isinstance(module, (Float16Module, MCoreFloat16Module)):
module = module.module
if not self.mcore_gpt:
module = module.language_model
if hasattr(module, 'embedding'):
for param in module.embedding.parameters():
param.data_ptr()
if self.cfg.get('pipeline_model_parallel_size', 1) > 1 and parallel_state.is_pipeline_last_stage(
ignore_virtual=True
):
if (
self.cfg.get('defer_embedding_wgrad_compute', False) and self.mcore_gpt
): # Silently ignore the optimization if MCORE is not used
module_list = self.get_model_module_list()
if len(module_list) > 1:
embedding_module = module_list[-1]
else:
embedding_module = module_list[0]
embedding_module.embedding_activation_buffer.clear()
assert (
len(embedding_module.embedding_activation_buffer) == 0
), "When you defer wgrads, this buffer should not hold stray activations"
loss_mean = self.training_step_fwd_bwd_step_call(dataloader_iter, forward_only=False)
if self.cfg.get('fp8', False):
self.prev_step_training = self.training
# Optimization: Defer the embedding GEMM Wgrads of the last PP stage to pipeline flush waiting time
if self.cfg.get('pipeline_model_parallel_size', 1) > 1 and parallel_state.is_pipeline_last_stage(
ignore_virtual=True
):
if (
self.cfg.get('defer_embedding_wgrad_compute', False) and self.mcore_gpt
): # Silently ignore the optimization if MCORE is not used
module_list = self.get_model_module_list()
if len(module_list) > 1:
embedding_module = module_list[-1]
else:
embedding_module = module_list[0]
embedding_activation_buffer = embedding_module.embedding_activation_buffer
grad_output_buffer = embedding_module.grad_output_buffer
if self.cfg.get('share_embeddings_and_output_weights', True):
weight = embedding_module.shared_embedding_or_output_weight()
else:
weight = embedding_module.output_layer.weight
drain_embedding_wgrad_compute(
embedding_module.config, embedding_activation_buffer, grad_output_buffer, weight
)
# when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced
if self.cfg.get('tensor_model_parallel_size', 1) > 1 and self.cfg.get('sequence_parallel', False):
# Mcore DistOpt handles this, so we don't have to
if not self.use_mcore_dist_optim:
self.megatron_timer_start('allreduce_sequence_parallel_gradients', log_level=1)
self.allreduce_sequence_parallel_gradients()
self.megatron_timer_stop('allreduce_sequence_parallel_gradients')
self.megatron_timer_start('gradient_allreduce', log_level=1)
if self.use_fsdp:
# Reduce the gradients omitted from FSDP-sharding
self.allreduce_fsdp_sharding_omitted_gradients()
elif self.with_distributed_adam:
if not self.use_mcore_dist_optim:
# synchronize asynchronous grad reductions
# note: not necessary, but reduces performance degradation
# from multiple simultaneous NCCL calls
self._optimizer._finish_bucket_grad_sync()
# else: Mcore distributed optim calls finalize_model_grads to finish grad sync
elif self.megatron_amp_O2:
# when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously)
if (
self.cfg.get('pipeline_model_parallel_size', 1) > 1
or self.cfg.get('sequence_parallel', False)
or not self.cfg.get('async_grad_allreduce', True)
):
# main grads are stored in the MainParamsOptimizer wrapper
self._optimizer.allreduce_main_grads()
else:
# async grad allreduce is not currently implemented for O1/autocasting mixed precision training
# so we all-reduce gradients after the pipeline
self.allreduce_gradients() # @sangkug we think this is causing memory to blow up (hurts perf)
self.megatron_timer_stop('gradient_allreduce')
if (
not self.use_mcore_dist_optim
and self.cfg.get('pipeline_model_parallel_size', 1) > 1
and self.cfg.get('share_embeddings_and_output_weights', True)
):
self.megatron_timer_start('allreduce_first_last_embeddings', log_level=1)
# when using pipeline parallelism the first and last stage must keep embeddings in sync
self.allreduce_first_last_embeddings()
self.megatron_timer_stop('allreduce_first_last_embeddings')
if self.log_memory_usage:
max_memory_reserved = torch.cuda.max_memory_reserved()
memory_allocated = torch.cuda.memory_allocated()
self.log(
'peak_memory_usage',
max_memory_reserved,
prog_bar=True,
rank_zero_only=True,
batch_size=1,
)
self.log(
'memory_allocated',
memory_allocated,
prog_bar=True,
rank_zero_only=True,
batch_size=1,
)
## logging
if self.log_train_loss:
# When using pipeline parallelism, loss is calculated only in the last pipeline stage and
            # it should be broadcast to the other pipeline stages for logging.
# we can avoid this broadcast by updating the PTL log function to accept specific ranks
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
if torch.distributed.get_rank() == get_last_rank():
torch.distributed.send(loss_mean, 0)
elif torch.distributed.get_rank() == 0:
torch.distributed.recv(loss_mean, get_last_rank())
self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1)
# (@adithyare) we need to check for the _scaler attribute to enable pp>1 for adapter training
if self.cfg.precision == 16 and hasattr(self.trainer.precision_plugin.scaler, "_scale"):
loss_scale = self.trainer.precision_plugin.scaler._scale
if loss_scale is not None:
self.log('loss_scale', loss_scale, batch_size=1)
lr = self._optimizer.param_groups[0]['lr']
self.log('lr', lr, rank_zero_only=True, batch_size=1)
self.log(
'global_step',
self.trainer.global_step,
prog_bar=True,
rank_zero_only=True,
batch_size=1,
)
consumed_samples = self._compute_consumed_samples_after_training_step()
# TODO: make sure compute_consumed_samples works for pipeline parallelism
self.log(
'consumed_samples',
consumed_samples,
prog_bar=True,
rank_zero_only=True,
batch_size=1,
)
if self.rampup_batch_size:
self.prev_global_batch_size = current_global_batch_size
self.prev_consumed_samples = consumed_samples
num_microbatch_calculator.update(
consumed_samples=consumed_samples,
consistency_check=False,
)
current_global_batch_size = num_microbatch_calculator.current_global_batch_size
self.log('global_batch_size', current_global_batch_size, prog_bar=True, rank_zero_only=True, batch_size=1)
self.if_first_step = 1
return loss_mean
def backward(self, *args, **kwargs):
"""LightningModule hook to do backward.
We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core.
No need to call it here.
"""
return
def optimizer_zero_grad(self, *args, **kwargs):
"""LightningModule hook to zero grad.
We want this to do nothing as we are zeroing grads during the training_step.
"""
return
def _append_sequence_parallel_module_grads(self, module, grads):
"""Helper method for allreduce_sequence_parallel_gradients"""
for param in module.parameters():
sequence_parallel_param = getattr(param, 'sequence_parallel', False) or getattr(
param, 'sequence_parallel_enabled', False
)
# (@adithyare) adapter training now extends MegatronGPTModel
# so we have to add this check here to ensure we do not
# perform all_reduce when grad is None.
# grad can be None when performing PEFT training.
if sequence_parallel_param and param.requires_grad:
if self.megatron_amp_O2:
grad = param.main_grad
else:
grad = param.grad
grads.append(grad.data)
def allreduce_sequence_parallel_gradients(self):
"""All-reduce layernorm parameters across model parallel nodes when sequence parallelism is used.
Modified from megatron-lm:
https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/3f91f09bb2ab32f9904b47f46f19d2fc3f518ed8/megatron/training.py#L425
"""
grads = []
if isinstance(self.model, list):
for module in self.model:
self._append_sequence_parallel_module_grads(module, grads)
else:
self._append_sequence_parallel_module_grads(self.model, grads)