-
-
Notifications
You must be signed in to change notification settings - Fork 877
/
models.py
1379 lines (1189 loc) · 52 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Module for models and model loading"""
# pylint: disable=too-many-lines
import gc
import logging
import math
import os
import types
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
import addict
import bitsandbytes as bnb
import torch
import transformers
import transformers.modeling_utils
from accelerate import init_empty_weights
from bitsandbytes.nn import Params4bit
from peft import (
LoftQConfig,
PeftConfig,
PeftModel,
PeftModelForCausalLM,
prepare_model_for_kbit_training,
)
from peft.tuners.lora import QuantLinear
from torch import nn
from transformers import ( # noqa: F401
AddedToken,
AutoConfig,
AutoModelForCausalLM,
AutoModelForVision2Seq,
AutoProcessor,
AutoTokenizer,
AwqConfig,
BitsAndBytesConfig,
GPTQConfig,
LlavaForConditionalGeneration,
MllamaForConditionalGeneration,
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
)
from transformers.integrations.deepspeed import (
HfTrainerDeepSpeedConfig,
is_deepspeed_zero3_enabled,
)
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
patch_for_multipack,
)
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
LOG = logging.getLogger("axolotl")
# copied from accelerator.FullyShardedDataParallelPlugin
def get_module_class_from_name(module, name):
    """
    Recursively search a module tree for a class with the given name.

    The module itself is checked first; otherwise each child subtree is searched
    in order (depth-first) and the first match wins.

    Args:
        module (`torch.nn.Module`): Root of the module tree to search.
        name (`str`): Class name to look for (compared against `__class__.__name__`).

    Returns:
        The matching class, or `None` if no module in the tree has that class name.
    """
    children = list(module.children())
    if module.__class__.__name__ == name:
        return module.__class__

    # Walk the children lazily so the search stops at the first hit.
    candidates = (get_module_class_from_name(child, name) for child in children)
    return next((found for found in candidates if found is not None), None)
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
    """
    Validate that the training config is compatible with the model config.

    For multimodal configs the checks run against `model_config.text_config`.

    Args:
        cfg: Training config; reads `is_multimodal`, `gptq`, `load_in_4bit`,
            `adapter`, `tokens` and `lora_modules_to_save`.
        model_config: The model config to validate.

    Raises:
        ValueError: If `gptq` is set but the model config does not declare the
            `gptq` quant method; if the model config carries a quantization
            config while neither `gptq` nor `load_in_4bit` is set; or if tokens
            are added while using an adapter without listing the linear
            embedding layers in `lora_modules_to_save`.
    """
    if cfg.is_multimodal:
        model_config = model_config.text_config

    # A missing attribute and an empty/None value both count as "no quantization config".
    quant_cfg = getattr(model_config, "quantization_config", None)
    has_quant_cfg = bool(quant_cfg)
    is_gptq_checkpoint = (
        has_quant_cfg
        and "quant_method" in quant_cfg
        and quant_cfg["quant_method"] == "gptq"
    )

    if cfg.gptq and not is_gptq_checkpoint:
        raise ValueError(
            "model_config.quantization_config is not set or quant_method is not set to gptq. "
            "Please make sure to point to a GPTQ model."
        )

    if has_quant_cfg and not cfg.gptq and not cfg.load_in_4bit:
        raise ValueError(
            "model_config.quantization_config is set but `gptq` flag is not. "
            "Please use the `gptq` flag to train quantized model or point to a non-quantized model."
        )

    required_layers = get_linear_embedding_layers(model_config.model_type)

    # The remaining check only matters when new tokens are added under an adapter.
    if not (cfg.adapter and cfg.tokens):
        return

    configured_layers = cfg.lora_modules_to_save
    if configured_layers and all(
        layer in configured_layers for layer in required_layers
    ):
        return

    formatted_layers = ", ".join(f"`{layer}`" for layer in required_layers)
    raise ValueError(
        f"`lora_modules_to_save` not properly set when adding new tokens. Please include [{formatted_layers}] in `lora_modules_to_save`."
    )
def load_model_config(cfg):
    """
    Load the model config for the configured base model.

    The config is looked up under `cfg.base_model_config`, then `cfg.base_model`,
    then `cfg.tokenizer_config`. Entries in `cfg.overrides_of_model_config` are
    set as attributes on the loaded config, which is then validated with
    `check_model_config`.

    Args:
        cfg: Training config.

    Returns:
        The loaded config. If `AutoConfig` raises `ValueError` and the config name
        contains "mamba", a minimal `addict.Dict` with `model_type` "mamba" is
        returned instead (without overrides or validation).

    Raises:
        ValueError: Re-raised from `AutoConfig.from_pretrained` for non-mamba
            names, or raised by `check_model_config`.
    """
    config_source = cfg.base_model_config or cfg.base_model
    if not config_source and cfg.tokenizer_config:
        config_source = cfg.tokenizer_config

    extra_kwargs = (
        {"revision": cfg.revision_of_model} if cfg.revision_of_model else {}
    )

    try:
        model_config = AutoConfig.from_pretrained(
            config_source,
            trust_remote_code=cfg.trust_remote_code is True,
            **extra_kwargs,
        )
    except ValueError as err:
        if "mamba" not in config_source:
            raise err
        # Fall back to a stub config for mamba checkpoints.
        return addict.Dict({"model_type": "mamba"})

    for key, val in (cfg.overrides_of_model_config or {}).items():
        setattr(model_config, key, val)

    check_model_config(cfg, model_config)
    return model_config
def load_tokenizer(cfg):
"""
Load the tokenizer from `cfg.tokenizer_config` and apply config-driven adjustments.

After loading, this function: sets a pad token for Llama/CodeLlama and GPTNeoX
tokenizers, may switch the padding side to "left", fills in missing Qwen special
tokens, registers `cfg.special_tokens`, `cfg.tokens` and any
`additional_special_tokens`, logs the EOS/BOS/PAD/UNK tokens, and assigns a chat
template when `cfg.chat_template` is set.

NOTE(review): this copy of the function has lost its indentation, so the exact
nesting of statements under each `if`/`for` should be confirmed against the
original file.

Args:
cfg: training config, read via attribute access.

Returns:
The configured tokenizer.

Raises:
ValueError: from the special-token loop, when an adapter is in use and
`cfg.lora_modules_to_save` does not list the linear embedding layers
(see the conditions on that check below).
"""
# the model config is only used further down to look up the linear embedding layer names
model_config = load_model_config(cfg)
tokenizer_kwargs = {}
use_fast = True # this is the default
if cfg.tokenizer_use_fast is not None:
use_fast = cfg.tokenizer_use_fast
if cfg.tokenizer_legacy is not None:
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
# default to AutoTokenizer; a class named in `cfg.tokenizer_type` is looked up on `transformers`
tokenizer_cls = AutoTokenizer
if cfg.tokenizer_type:
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
tokenizer = tokenizer_cls.from_pretrained(
cfg.tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
)
# Llama-family tokenizers (matched by class name) that have no pad token yet
if (
tokenizer.__class__.__name__
in [
"LlamaTokenizer",
"LlamaTokenizerFast",
"CodeLlamaTokenizer",
"CodeLlamaTokenizerFast",
]
and hasattr(tokenizer, "pad_token")
and not tokenizer.pad_token
):
# set a pad_token, but use eos_token so we don't add a new token
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
# NOTE(review): with indentation lost it is unclear whether this env var is set only
# for GPTNeoX tokenizers or unconditionally — confirm against the original file
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Mistral's official FA implementation requires left padding
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
tokenizer.padding_side = "left"
# Qwen base only has single token, so we need to set the special tokens
if cfg.is_qwen_derived_model:
# any unset special-token id falls back to the tokenizer's `eod_id`
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
for attr_name in token_ids:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, tokenizer.eod_id)
# any unset special-token string falls back to "<|endoftext|>"
token_names = ["bos_token", "eos_token", "pad_token", "unk_token"]
for attr_name in token_names:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, "<|endoftext|>")
additional_special_tokens = None
if cfg.special_tokens:
special_tokens = cfg.special_tokens.to_dict()
# `additional_special_tokens` is pulled out here and registered separately at the end
additional_special_tokens = special_tokens.pop(
"additional_special_tokens", None
)
lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
for k, val in special_tokens.items():
# check if new special token is not already in tokenizer and
# is adapter training to make sure lora_modules_to_save is set
# NOTE(review): the `> 2` encoded-length threshold is not explained in this file;
# `pad_token` is exempt from the check
# pylint: disable=too-many-boolean-expressions
if (
(getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
and (len(tokenizer.encode(val, add_special_tokens=False)) > 2)
and cfg.adapter
and (
not cfg.lora_modules_to_save
or not all(
x in cfg.lora_modules_to_save for x in lora_modules_to_save
)
)
and k != "pad_token"
):
# the name is rebound to a display string only to build the error message
lora_modules_to_save = ", ".join(
[f"`{x}`" for x in lora_modules_to_save]
)
raise ValueError(
f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens."
)
tokenizer.add_special_tokens(
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
)
# If we add bos_token and eos_token, we need to update the post processor to
# handle them correctly.
# https://github.com/huggingface/transformers/pull/24132
# NOTE(review): despite the `_or_` in the name, this is True only when BOTH keys are present
bos_or_eos_in_special_tokens = (
"bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
)
if (
tokenizer.__class__.__name__
in (
"LlamaTokenizerFast",
"CodeLlamaTokenizerFast",
)
and bos_or_eos_in_special_tokens
):
tokenizer.update_post_processor()
# extra (non-special) tokens from the config
if cfg.tokens:
tokenizer.add_tokens(
[
AddedToken(token, rstrip=False, lstrip=False, normalized=False)
for token in cfg.tokens
]
)
# Additional special tokens are a List, and need to be treated differently than regular special
# tokens. We add them after we have called `add_tokens` in case these additional special tokens
# are new tokens.
#
# Usage:
#
# ```py
# special_tokens:
# additional_special_tokens: ["<|im_start|>", "<|im_end|>"]
# ```
if additional_special_tokens is not None:
tokenizer.add_special_tokens(
{"additional_special_tokens": additional_special_tokens}
)
# log the final special tokens inside the `zero_only()` context
with zero_only():
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if cfg.chat_template:
chat_template_string = get_chat_template_from_config(
cfg=cfg,
tokenizer=tokenizer,
)
# for chatml, the stock system prompt text is replaced with the configured default
if cfg.default_system_message and cfg.chat_template == "chatml":
chat_template_string = chat_template_string.replace(
"You are a helpful assistant.", cfg.default_system_message
)
tokenizer.chat_template = chat_template_string
else:
LOG.info(
"No Chat template selected. Consider adding a chat template for easier inference."
)
return tokenizer
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
    """
    Load the processor from `cfg.processor_config`.

    Args:
        cfg: Training config; reads `processor_type`, `processor_config` and
            `trust_remote_code`.
        tokenizer: Tokenizer passed through to the processor's `from_pretrained`.

    Returns:
        The processor built by `AutoProcessor`, or by the `transformers` class
        named in `cfg.processor_type` when that is set.
    """
    if cfg.processor_type:
        processor_cls = getattr(transformers, cfg.processor_type)
    else:
        processor_cls = AutoProcessor

    extra_kwargs: Dict[str, Any] = {}  # do we actually need this?

    return processor_cls.from_pretrained(
        cfg.processor_config,
        tokenizer=tokenizer,
        trust_remote_code=cfg.trust_remote_code or False,
        **extra_kwargs,
    )
class ModelLoader:
"""
ModelLoader: managing all the config and monkey patches while loading model
"""
def __init__(
    self,
    cfg: DictDefault,
    tokenizer: PreTrainedTokenizerBase,
    *,
    processor: ProcessorMixin = None,  # pylint: disable=unused-argument
    inference: bool = False,
    reference_model: bool = False,
    **kwargs,  # pylint: disable=unused-argument
) -> None:
    """
    Store the config and tokenizer and load the model config.

    Args:
        cfg: Training config.
        tokenizer: Tokenizer used with the model.
        processor: Accepted but not used by this constructor.
        inference: Whether the model is being loaded for inference.
        reference_model: Whether this loader is for a reference model.
        **kwargs: Accepted but not used by this constructor.
    """
    self.cfg = cfg
    self.tokenizer = tokenizer
    self.inference: bool = inference
    self.reference_model: bool = reference_model

    # keyword arguments for the model load, seeded from the user's `model_kwargs`
    self.model_kwargs: Dict[str, Any] = (
        dict(cfg.model_kwargs.items()) if cfg.model_kwargs else {}
    )

    # the model itself is assigned later; only its type is declared here
    self.model: PreTrainedModel
    self.base_model = cfg.base_model
    self.model_type = cfg.type_of_model

    # for multimodal models, the text sub-config is tracked separately
    self.model_config = load_model_config(cfg)
    self.text_model_config = (
        self.model_config.text_config if cfg.is_multimodal else self.model_config
    )

    self.AutoModelLoader = AutoModelForCausalLM  # pylint: disable=invalid-name
def apply_patches(self) -> None:
    """
    Apply plugin hooks and monkey patches that must be in place before model load.

    Raises:
        ValueError: If both `sample_packing` and `s2_attention` are enabled.
    """
    # load any patches from plugins
    from axolotl.integrations.base import PluginManager

    cfg = self.cfg
    PluginManager.get_instance().pre_model_load(cfg)

    if cfg.gradient_checkpointing == "unsloth":
        transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper

    if cfg.flash_attention:
        self.patch_attention()

    if cfg.sample_packing and cfg.s2_attention:
        raise ValueError(
            "Received `sample_packing=true` and `s2_attention=true`; however, "
            "shifted-sparse attention does not currently support sample packing."
        )

    use_multipack = (
        cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
        and cfg.flash_attention
        and cfg.sample_packing
    )

    if use_multipack:
        remote_code = (
            "auto_map" in self.model_config
            and "AutoModelForCausalLM" in self.model_config["auto_map"]
        )
        # if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
        remote_code = remote_code and cfg.trust_remote_code is not False

        patch_for_multipack(
            cfg.model_config_type,
            model_name=cfg.base_model,
            has_remote_code=remote_code,
        )

        if cfg.is_llama_derived_model:
            self.patch_loss()
            if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
                from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora

                patch_self_attn_lora()
    elif cfg.is_llama_derived_model:
        self.patch_llama_derived_model()

    if cfg.model_config_type == "mistral" and cfg.flash_attn_cross_entropy_loss:
        from axolotl.monkeypatch.mistral_attn_hijack_flash import (
            patch_mistral_cross_entropy,
        )

        patch_mistral_cross_entropy()
def patch_attention(self) -> None:
    """
    Apply the flash-attention monkey patch that matches the model type.

    Does nothing when the model config has no `model_type`. Handles `mllama`
    (when flash attention is enabled), `btlm`, and `stablelm_epoch` (when sample
    packing is enabled).
    """
    if not hasattr(self.model_config, "model_type"):
        return

    model_type = self.model_config.model_type

    if model_type == "mllama" and self.cfg.flash_attention:
        from axolotl.monkeypatch.attention.mllama import patch_mllama

        patch_mllama()

    if model_type == "btlm":
        from axolotl.monkeypatch.btlm_attn_hijack_flash import (
            replace_btlm_attn_with_flash_attn,
        )

        replace_btlm_attn_with_flash_attn(self.cfg.base_model)

    if model_type == "stablelm_epoch" and self.cfg.sample_packing:
        from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
            replace_stablelm_attn_with_flash_attn,
        )

        replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
def patch_loss(self) -> None:
    """
    Patch loss functions

    Applies the llama cross-entropy, RMS-norm and unsloth patches that are
    enabled in the config.
    """
    from axolotl.monkeypatch.llama_attn_hijack_flash import (
        patch_llama_cross_entropy,
        patch_llama_rms_norm,
    )

    cfg = self.cfg

    if cfg.flash_attn_cross_entropy:
        patch_llama_cross_entropy()

    # RMS norm: the flash-attn patch takes precedence over the unsloth one
    if cfg.flash_attn_rms_norm:
        patch_llama_rms_norm()
    elif cfg.unsloth_rms_norm:
        from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm

        patch_unsloth_layernorm()

    if cfg.unsloth_cross_entropy_loss:
        from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch

        integrate_cross_entropy_loss_patch(model_type="llama")

    if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
        from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora

        patch_self_attn_lora()
def patch_llama_derived_model(self) -> None:
    """
    Modify all llama derived models in one block

    Picks one attention patch (flash attention, xformers, or the 4D-mask patch
    for sample packing), then applies the optional unsloth patches.

    Raises:
        NotImplementedError: If `s2_attention` is requested without flash
            attention, xformers attention or sample packing being enabled.
    """
    cfg = self.cfg

    if cfg.flash_attention:
        from axolotl.monkeypatch.llama_attn_hijack_flash import (
            replace_llama_attn_with_flash_attn,
        )

        # Decide which flavour of the flash-attention patch applies (if any),
        # then make one call with the shared cross-entropy / rms-norm options.
        variant = None
        if cfg.sample_packing:
            if cfg.device not in ["mps", "cpu"] and not self.inference:
                LOG.info("patching with flash attention for sample packing")
                variant = {"packed": True}
        elif cfg.s2_attention:
            LOG.info("patching w/ flash-enabled, shifted-sparse attention")
            variant = {"packed": False, "use_shifted_sparse_attn": True}
        elif cfg.flash_attn_cross_entropy or cfg.flash_attn_rms_norm:
            variant = {"packed": False}

        if variant is not None:
            replace_llama_attn_with_flash_attn(
                cross_entropy=cfg.flash_attn_cross_entropy,
                rms_norm=cfg.flash_attn_rms_norm,
                **variant,
            )
    elif cfg.xformers_attention:
        from axolotl.monkeypatch.llama_attn_hijack_xformers import (
            hijack_llama_attention,
        )

        LOG.info("patching with xformers attention")
        hijack_llama_attention()
    elif cfg.sample_packing:
        from axolotl.monkeypatch.llama_patch_multipack import (
            hijack_llama_prepare_4d_mask,
        )

        LOG.info("patching llama _prepare_4d_causal_attention_mask*")
        hijack_llama_prepare_4d_mask()
    elif cfg.s2_attention:
        raise NotImplementedError(
            "Shifted-sparse attention not currently implemented without flash attention."
        )

    if cfg.unsloth_cross_entropy_loss:
        from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch

        integrate_cross_entropy_loss_patch(model_type="llama")

    if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
        from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora

        patch_self_attn_lora()
def set_auto_model_loader(self) -> None:
    """set self.AutoModelLoader
    - default value: AutoModelForCausalLM (set at __init__)
    - when using a multi modality model, self.AutoModelLoader should
    be set according to model type of the model
    """
    if not self.cfg.is_multimodal:
        return

    model_type = self.model_config.model_type
    if model_type == "llava":
        loader_cls = LlavaForConditionalGeneration
    elif model_type == "mllama":
        loader_cls = MllamaForConditionalGeneration
    else:
        # any other multimodal architecture goes through the generic vision-to-seq loader
        loader_cls = AutoModelForVision2Seq

    self.AutoModelLoader = loader_cls  # pylint: disable=invalid-name
def set_device_map_config(self) -> None:
    """
    Set `device_map` and `torch_dtype` in `self.model_kwargs`.

    When `cfg.gpu_memory_limit` or `cfg.max_memory` is given, a device map is
    inferred from a weightless copy of the model; otherwise `cfg.device_map` is
    used. On "mps"/"npu" devices the map is forced to device 0, and under
    DeepSpeed ZeRO-3 the `device_map` entry is removed again.
    """
    cfg = self.cfg
    device_map = cfg.device_map
    max_memory = cfg.max_memory

    if cfg.gpu_memory_limit:
        limit = cfg.gpu_memory_limit
        if isinstance(limit, int):
            # bare integers are interpreted as GiB
            limit = str(limit) + "GiB"

        max_memory = {index: limit for index in range(get_device_count())}
        max_memory["cpu"] = "256GiB"  # something sufficiently large to fit anything

    if max_memory is not None:
        # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
        from accelerate import infer_auto_device_map

        with init_empty_weights():
            model_canvas = self.AutoModelLoader.from_config(
                self.model_config,
                trust_remote_code=cfg.trust_remote_code or False,
            )
        model_canvas.tie_weights()
        # max_memory is not needed past this point: the device map encodes the placement
        device_map = infer_auto_device_map(
            model_canvas,
            max_memory=max_memory,
            dtype=cfg.torch_dtype,
        )

    self.model_kwargs["device_map"] = device_map
    self.model_kwargs["torch_dtype"] = cfg.torch_dtype

    device_name = str(get_device_type())
    if "mps" in device_name:
        self.model_kwargs["device_map"] = "mps:0"
    elif "npu" in device_name:
        self.model_kwargs["device_map"] = "npu:0"

    # TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss

    if is_deepspeed_zero3_enabled():
        del self.model_kwargs["device_map"]
def set_quantization_config(self) -> None:
    """
    Populate `self.model_kwargs` with the quantization settings for model loading.

    The `load_in_8bit` / `load_in_4bit` flags are copied from the config first.
    A `quantization_config` entry is then built from one of:

    - the checkpoint's own `quantization_config` (GPTQ when `cfg.gptq` is set;
      GPTQ / AWQ / bitsandbytes when training a LoRA/QLoRA adapter on a
      checkpoint that declares one of those quant methods),
    - a 4-bit NF4 bitsandbytes config for QLoRA with `load_in_4bit`,
    - an 8-bit bitsandbytes config for LoRA with `load_in_8bit`.

    When a `quantization_config` was set (or `cfg.gptq` is on), the
    `load_in_*bit` flags are removed from the kwargs again.
    """
    cfg = self.cfg
    self.model_kwargs["load_in_8bit"] = cfg.load_in_8bit
    self.model_kwargs["load_in_4bit"] = cfg.load_in_4bit

    if cfg.gptq:
        if not hasattr(self.model_config, "quantization_config"):
            LOG.warning(
                "model config does not contain quantization_config information"
            )
        else:
            if cfg.gptq_disable_exllama is not None:
                self.model_config.quantization_config[
                    "disable_exllama"
                ] = cfg.gptq_disable_exllama
            self.model_kwargs["quantization_config"] = GPTQConfig(
                **self.model_config.quantization_config
            )

    # Read the checkpoint's quant method defensively: a missing/None
    # quantization_config, or one without a "quant_method" key, must fall
    # through to the adapter branches below instead of raising
    # KeyError/TypeError on the subscript.
    checkpoint_quant_config = getattr(self.model_config, "quantization_config", None)
    quant_method = None
    if checkpoint_quant_config and "quant_method" in checkpoint_quant_config:
        quant_method = checkpoint_quant_config["quant_method"]

    if cfg.adapter in ["qlora", "lora"] and quant_method in [
        "gptq",
        "awq",
        "bitsandbytes",
    ]:
        # adapter training on a pre-quantized checkpoint: reuse its config as-is
        if quant_method == "gptq":
            self.model_kwargs["quantization_config"] = GPTQConfig(
                **checkpoint_quant_config
            )
        elif quant_method == "awq":
            self.model_kwargs["quantization_config"] = AwqConfig(
                **checkpoint_quant_config
            )
        elif quant_method == "bitsandbytes":
            self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
                **checkpoint_quant_config
            )
    elif cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
        bnb_config = {
            "load_in_4bit": True,
            "llm_int8_threshold": 6.0,
            "llm_int8_has_fp16_weight": False,
            "bnb_4bit_compute_dtype": cfg.torch_dtype,
            "bnb_4bit_use_double_quant": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_quant_storage": torch.bfloat16,
        }
        if cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
            cfg.deepspeed or cfg.fsdp
        ):
            # for some reason, this causes the loss to be off by an order of magnitude
            # but deepspeed needs this still in bfloat16
            bnb_config["bnb_4bit_quant_storage"] = torch.float32

        # user-supplied overrides win over the defaults above
        if cfg.bnb_config_kwargs:
            bnb_config.update(cfg.bnb_config_kwargs)

        self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
            **bnb_config,
        )
    elif cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]:
        bnb_config = {
            "load_in_8bit": True,
        }
        # Exclude mamba blocks from int8 quantization for jamba
        if cfg.model_config_type == "jamba":
            bnb_config["llm_int8_skip_modules"] = ["mamba"]
        self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
            **bnb_config,
        )

    # no longer needed per https://github.com/huggingface/transformers/pull/26610
    if "quantization_config" in self.model_kwargs or cfg.gptq:
        self.model_kwargs.pop("load_in_8bit", None)
        self.model_kwargs.pop("load_in_4bit", None)
def set_attention_config(self) -> None:
    """
    sample packing uses custom FA2 patch

    Chooses the attention implementation from the config flags, in priority
    order flash attention > SDPA > eager, and records it both in
    `self.model_kwargs["attn_implementation"]` and on
    `self.model_config._attn_implementation`. If none of the flags is set,
    neither is touched. Also sets `low_cpu_mem_usage` in the kwargs when it is
    enabled in the config.
    """
    cfg = self.cfg

    if cfg.flash_attention:
        attn_implementation = "flash_attention_2"
    elif cfg.sdp_attention:
        attn_implementation = "sdpa"
    elif cfg.eager_attention:
        attn_implementation = "eager"
    else:
        attn_implementation = None

    if attn_implementation is not None:
        self.model_kwargs["attn_implementation"] = attn_implementation
        self.model_config._attn_implementation = (  # pylint: disable=protected-access
            attn_implementation
        )

    if cfg.low_cpu_mem_usage:
        self.model_kwargs["low_cpu_mem_usage"] = True
def build_model(self, qlora_fsdp) -> bool:
"""
Instantiate `self.model` using the loading path selected by the config.

The branches below choose between the sharded loaders, a llama-specific path,
the Mamba path, an explicitly named `transformers` model class, and the generic
`self.AutoModelLoader` path.

NOTE(review): this copy of the method has lost its indentation, so the exact
nesting of the branches should be confirmed against the original file.

Args:
qlora_fsdp: flag consulted when choosing between the sharded-loading branches.

Returns:
`skip_move_to_device`: set to True by the sharded, FSDP RAM-efficient and
DeepSpeed ZeRO-3 paths below, otherwise False.
"""
def _configure_zero3_memory_efficient_loading():
"""
Set the deepspeed config to load the model into RAM first before moving to VRAM.
We need to return hf_ds_cfg as it needs to exist before model loading.
"""
hf_ds_cfg = None
# only acts when the ACCELERATE_DEEPSPEED_ZERO_STAGE env var is "3"
if os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
hf_ds_cfg = HfTrainerDeepSpeedConfig(self.cfg.deepspeed)
hf_ds_cfg.fill_match(
"train_micro_batch_size_per_gpu", self.cfg.micro_batch_size
)
hf_ds_cfg.fill_match(
"gradient_accumulation_steps", self.cfg.gradient_accumulation_steps
)
# total batch size = WORLD_SIZE (default 1) * micro batch size * accumulation steps
hf_ds_cfg.fill_match(
"train_batch_size",
int(os.getenv("WORLD_SIZE", "1"))
* self.cfg.micro_batch_size
* self.cfg.gradient_accumulation_steps,
)
if "device_map" in self.model_kwargs:
del self.model_kwargs["device_map"]
# replace the transformers ZeRO-3 checks with stubs that always return True
transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: True
transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = (
lambda: True
)
return hf_ds_cfg
skip_move_to_device = False
# NOTE: the literal `and False` makes this condition always false, so
# `load_sharded_model` is never reached from here
if ( # pylint: disable=condition-evals-to-constant)
(self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading)
and not qlora_fsdp
and False
):
self.model = load_sharded_model(
self.base_model,
self.model_config,
self.cfg,
torch_dtype=self.cfg.torch_dtype,
)
skip_move_to_device = True
# QLoRA + FSDP with RAM-efficient loading, for dbrx or when sharded loading is requested
elif (
qlora_fsdp
and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and (
self.cfg.model_config_type == "dbrx"
or self.cfg.qlora_sharded_model_loading
)
):
quant_storage = self.cfg.torch_dtype
# prefer a truthy quantization_config from the model config; otherwise use the one
# stored in self.model_kwargs (a KeyError is raised if that key is absent)
quantization_config = hasattr(
self.model_config, "quantization_config"
) and getattr(self.model_config, "quantization_config")
quantization_config = (
quantization_config or self.model_kwargs["quantization_config"]
)
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = load_sharded_model_quant(
self.base_model,
self.model_config,
self.cfg,
quant_storage=quant_storage,
quantization_config=quantization_config,
)
skip_move_to_device = True
# llama checkpoints loaded without remote code and without GPTQ
elif (
self.model_config.model_type == "llama"
and not self.cfg.trust_remote_code
and not self.cfg.gptq
):
if self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
skip_move_to_device = True
if "device_map" in self.model_kwargs:
del self.model_kwargs["device_map"]
# the returned config object is bound to `_` and not used further in this method
_ = _configure_zero3_memory_efficient_loading()
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.AutoModelLoader.from_pretrained(
self.base_model,
config=self.model_config,
**self.model_kwargs,
)
# TODO (MengqingCao) split these patches separately
if self.cfg.flash_attention and not self.inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
is_xformers_swiglu_available,
replace_llama_mlp_with_swiglu,
replace_llama_qkv_with_fused,
)
if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
LOG.info("patching with SwiGLU")
replace_llama_mlp_with_swiglu(self.model)
if self.cfg.flash_attn_fuse_qkv:
LOG.info("patching with fused QKV")
replace_llama_qkv_with_fused(self.model)
elif self.model_type == "MambaLMHeadModel":
# FIXME this is janky at best and hacked together to make it work
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
# the Mamba loader is given `dtype`/`device` in place of `torch_dtype`/`device_map`;
# both `del` statements raise KeyError if the key is not present
self.model_kwargs["dtype"] = self.model_kwargs["torch_dtype"]
self.model_kwargs["device"] = torch.cuda.current_device()
del self.model_kwargs["torch_dtype"]
del self.model_kwargs["device_map"]
self.model = MambaLMHeadModel.from_pretrained(
self.base_model,
**self.model_kwargs,
)
# an explicit model class name other than AutoModelForCausalLM, without remote code
elif (
self.model_type
and self.model_type != "AutoModelForCausalLM"
and not self.cfg.trust_remote_code
):
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
if self.cfg.gptq:
self.model = self.AutoModelLoader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
else:
# look the configured class up on the `transformers` module by name
self.model = getattr(transformers, self.model_type).from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
else:
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
# when training starts
# raise the text config's `max_seq_len` / `max_sequence_length` when the configured
# `sequence_len` exceeds it
if (
hasattr(self.text_model_config, "max_seq_len")
and self.text_model_config.max_seq_len
and self.cfg.sequence_len > self.text_model_config.max_seq_len
):
self.text_model_config.max_seq_len = self.cfg.sequence_len
LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
elif (
hasattr(self.text_model_config, "max_sequence_length")
and self.text_model_config.max_sequence_length
and self.cfg.sequence_len > self.text_model_config.max_sequence_length
):
self.text_model_config.max_sequence_length = self.cfg.sequence_len
LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
if self.cfg.gptq:
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.AutoModelLoader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
else:
if (
self.cfg.fsdp
and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
):
# disabling either of these two still leads to VRAM spike before setting back down
skip_move_to_device = True
if "device_map" in self.model_kwargs:
del self.model_kwargs["device_map"]
# the returned config object is bound to `_` and not used further in this method
_ = _configure_zero3_memory_efficient_loading()
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.AutoModelLoader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
# under DeepSpeed ZeRO-3 the caller is told to skip moving the model to a device
if is_deepspeed_zero3_enabled():
skip_move_to_device = True
return skip_move_to_device
def ajust_model_config(self) -> None:
    """
    Align the loaded model's config with the training config and tokenizer.

    - Raises `max_position_embeddings` to `cfg.sequence_len` when the configured
      sequence length exceeds it (a warning is logged).
    - Overwrites `bos_token_id` / `eos_token_id` with the tokenizer's value when
      the model config has one set and it differs from the tokenizer's.

    Token ids are compared with `is not None` rather than by truthiness, so a
    model config whose id is `0` is synced too. Does nothing if the model has
    no `config`.
    """
    config = getattr(self.model, "config", None)
    if config is None:
        return

    max_positions = getattr(config, "max_position_embeddings", None)
    if max_positions and self.cfg.sequence_len > max_positions:
        LOG.warning(
            f"increasing model.config.max_position_embeddings from {max_positions} to {self.cfg.sequence_len}"
        )
        config.max_position_embeddings = self.cfg.sequence_len

    for token_attr in ("bos_token_id", "eos_token_id"):
        model_token_id = getattr(config, token_attr, None)
        # The tokenizer is only consulted when the model config defines the id.
        if model_token_id is not None and model_token_id != getattr(
            self.tokenizer, token_attr
        ):
            setattr(config, token_attr, getattr(self.tokenizer, token_attr))
def set_z3_leaf_modules(self) -> None:
    """
    Register the model's MoE block classes as DeepSpeed ZeRO-3 leaf modules.

    Only acts when `cfg.model_config_type` has an entry in `MOE_ARCH_BLOCK`; the
    entry may be a single class name or a collection of names.
    """
    from deepspeed.utils import (  # pylint: disable=no-name-in-module
        set_z3_leaf_modules,
    )

    arch = self.cfg.model_config_type
    if arch not in MOE_ARCH_BLOCK:
        return

    block_names = MOE_ARCH_BLOCK[arch]
    if isinstance(block_names, str):
        block_names = [block_names]

    # resolve each class name against the modules actually present in the model
    leaf_classes = [
        get_module_class_from_name(self.model, block_name)
        for block_name in block_names
    ]
    set_z3_leaf_modules(self.model, leaf_classes)
def prepare_model(self, qlora_fsdp) -> None:
    """
    Run `prepare_model_for_kbit_training` on the model when applicable.

    The call is skipped for Qwen with LoRA, for LoRA with LoftQ bits configured,
    for QLoRA+FSDP or FSDP RAM-efficient loading, and under DeepSpeed ZeRO-3.
    Otherwise it runs only for `lora`/`qlora` adapters loaded in 8-bit or 4-bit.

    Args:
        qlora_fsdp: When truthy, k-bit preparation is skipped.
    """
    cfg = self.cfg

    # Qwen doesn't play nicely with LoRA if this is enabled
    qwen_with_lora = cfg.model_config_type == "qwen" and cfg.adapter == "lora"

    loftq_bits = (
        cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
    )
    lora_with_loftq = cfg.adapter == "lora" and loftq_bits

    # make sure everything is in the same dtype
    ram_efficient_fsdp = qlora_fsdp or (
        cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
    )

    zero3_enabled = is_deepspeed_zero3_enabled()

    if qwen_with_lora or lora_with_loftq or ram_efficient_fsdp or zero3_enabled:
        return
    if cfg.adapter not in ["lora", "qlora"]:
        return
    if not (cfg.load_in_8bit or cfg.load_in_4bit):
        return

    LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
    self.model = prepare_model_for_kbit_training(
        self.model, use_gradient_checkpointing=cfg.gradient_checkpointing
    )
def convert_embedding_modules_dtype(
self, embedding_modules, dist_dtype, before_kbit_train_or_finetune
) -> None:
for name, module in self.model.named_modules():
if "norm" in name:
module.to(dist_dtype)
if before_kbit_train_or_finetune:
if name.endswith(".gate"):
module.to(dist_dtype)
if self.model_config.model_type == "btlm":
# don't upcast lm_head for btlm
continue
if any(m in name for m in embedding_modules):
if hasattr(module, "weight"):