-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
Copy pathtest_custom_models.py
3541 lines (3016 loc) · 139 KB
/
test_custom_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python3
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import re
import shutil
import tempfile
import time
import unittest
from contextlib import contextmanager
from functools import partial
import pytest
import torch
from parameterized import parameterized
from safetensors.torch import load_file as safe_load_file
from torch import nn
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification
from transformers.pytorch_utils import Conv1D
from peft import (
AdaLoraConfig,
BOFTConfig,
FourierFTConfig,
HRAConfig,
IA3Config,
LNTuningConfig,
LoHaConfig,
LoKrConfig,
LoraConfig,
OFTConfig,
PeftModel,
TaskType,
VeraConfig,
get_peft_model,
)
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import ModulesToSaveWrapper, infer_device
from .testing_common import PeftCommonTester
from .testing_utils import get_state_dict, require_torch_gpu
# Model ids used below (resolved to the custom modules in this file by MockTransformerWrapper.from_pretrained):
# - MLP: vanilla feed-forward network with only linear layers (lin0, lin1)
# - MLP_LayerNorm: like MLP, with a LayerNorm in front of each linear layer (layernorm0, layernorm1)
# - MLP2: like MLP, but with a wider hidden layer
# - EmbConv1D: an embedding (emb), a transformers Conv1D layer (conv1d) and a single linear layer (lin0)
# - Conv2d: a Conv2d layer (conv2d) followed by a linear layer (lin0)
# - Conv2d2: a linear layer (lin0), a Conv2d layer (conv2d) and another linear layer (lin1)
# Each entry is (test name, model id, config class, config kwargs); test names must be unique.
TEST_CASES = [
    ########
    # LoRA #
    ########
    ("Vanilla MLP 1 LoRA", "MLP", LoraConfig, {"target_modules": "lin0"}),
    ("Vanilla MLP 2 LoRA", "MLP", LoraConfig, {"target_modules": ["lin0"]}),
    ("Vanilla MLP 3 LoRA", "MLP", LoraConfig, {"target_modules": ["lin1"]}),
    ("Vanilla MLP 4 LoRA", "MLP", LoraConfig, {"target_modules": ["lin0", "lin1"]}),
    ("Vanilla MLP 5 LoRA", "MLP", LoraConfig, {"target_modules": ["lin0"], "modules_to_save": ["lin1"]}),
    (
        "Vanilla MLP 6 LoRA",
        "MLP",
        LoraConfig,
        {
            "target_modules": ["lin0"],
            "lora_alpha": 4,
            "lora_dropout": 0.1,
        },
    ),
    ("Vanilla MLP 7 LoRA with DoRA", "MLP", LoraConfig, {"target_modules": ["lin0"], "use_dora": True}),
    ("Vanilla MLP 8 LoRA with DoRA", "MLP", LoraConfig, {"target_modules": ["lin0", "lin1"], "use_dora": True}),
    (
        "Vanilla MLP 9 LoRA with DoRA",
        "MLP",
        LoraConfig,
        {"target_modules": "lin1", "use_dora": True, "lora_alpha": 32},
    ),
    ("Embedding + transformers Conv1D 1 LoRA", "EmbConv1D", LoraConfig, {"target_modules": ["conv1d"]}),
    ("Embedding + transformers Conv1D 2 LoRA", "EmbConv1D", LoraConfig, {"target_modules": ["emb"]}),
    ("Embedding + transformers Conv1D 3 LoRA", "EmbConv1D", LoraConfig, {"target_modules": ["emb", "conv1d"]}),
    ("Conv2d 1 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"]}),
    ("Conv2d 2 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"]}),
    ("Conv2d 1 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"], "use_dora": True}),
    ("Conv2d 2 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"], "use_dora": True}),
    #######
    # IA³ #
    #######
    ("Vanilla MLP 1 IA3", "MLP", IA3Config, {"target_modules": "lin0", "feedforward_modules": []}),
    ("Vanilla MLP 2 IA3", "MLP", IA3Config, {"target_modules": "lin0", "feedforward_modules": "lin0"}),
    ("Vanilla MLP 3 IA3", "MLP", IA3Config, {"target_modules": ["lin0"], "feedforward_modules": []}),
    ("Vanilla MLP 4 IA3", "MLP", IA3Config, {"target_modules": ["lin0"], "feedforward_modules": ["lin0"]}),
    ("Vanilla MLP 5 IA3", "MLP", IA3Config, {"target_modules": ["lin1"], "feedforward_modules": []}),
    ("Vanilla MLP 6 IA3", "MLP", IA3Config, {"target_modules": ["lin1"], "feedforward_modules": ["lin1"]}),
    (
        "Vanilla MLP 7 IA3",
        "MLP",
        IA3Config,
        {"target_modules": ["lin0", "lin1"], "feedforward_modules": []},
    ),
    (
        "Vanilla MLP 8 IA3",
        "MLP",
        IA3Config,
        {"target_modules": ["lin0", "lin1"], "feedforward_modules": ["lin0", "lin1"]},
    ),
    (
        "Vanilla MLP 9 IA3",
        "MLP",
        IA3Config,
        {"target_modules": ["lin0"], "modules_to_save": ["lin1"], "feedforward_modules": ["lin0"]},
    ),
    (
        "transformers Conv1D 1 IA3",
        "EmbConv1D",
        IA3Config,
        {"target_modules": ["conv1d"], "feedforward_modules": ["conv1d"]},
    ),
    (
        "transformers Conv1D 2 IA3",
        "EmbConv1D",
        IA3Config,
        {"target_modules": ["conv1d", "lin0"], "feedforward_modules": ["conv1d", "lin0"]},
    ),
    (
        # EmbConv1D has no `lin1`; its only linear layer is `lin0`, so that is the module to save here
        "transformers Conv1D 3 IA3",
        "EmbConv1D",
        IA3Config,
        {"target_modules": ["conv1d"], "feedforward_modules": ["conv1d"], "modules_to_save": ["lin0"]},
    ),
    ("Conv2d 1 IA3", "Conv2d", IA3Config, {"target_modules": ["conv2d"], "feedforward_modules": []}),
    ("Conv2d 2 IA3", "Conv2d", IA3Config, {"target_modules": ["conv2d"], "feedforward_modules": ["conv2d"]}),
    (
        "Conv2d 3 IA3",
        "Conv2d",
        IA3Config,
        {"target_modules": ["conv2d", "lin0"], "feedforward_modules": []},
    ),
    (
        "Conv2d 4 IA3",
        "Conv2d",
        IA3Config,
        {"target_modules": ["conv2d", "lin0"], "feedforward_modules": ["conv2d"]},
    ),
    (
        "Conv2d 5 IA3",
        "Conv2d",
        IA3Config,
        {"target_modules": ["conv2d", "lin0"], "feedforward_modules": ["conv2d", "lin0"]},
    ),
    ########
    # LoHa #
    ########
    ("Vanilla MLP 1 LOHA", "MLP", LoHaConfig, {"target_modules": "lin0"}),
    ("Vanilla MLP 2 LOHA", "MLP", LoHaConfig, {"target_modules": ["lin0"]}),
    ("Vanilla MLP 3 LOHA", "MLP", LoHaConfig, {"target_modules": ["lin1"]}),
    ("Vanilla MLP 4 LOHA", "MLP", LoHaConfig, {"target_modules": ["lin0", "lin1"]}),
    ("Vanilla MLP 5 LOHA", "MLP", LoHaConfig, {"target_modules": ["lin0"], "modules_to_save": ["lin1"]}),
    (
        "Vanilla MLP 6 LOHA",
        "MLP",
        LoHaConfig,
        {
            "target_modules": ["lin0"],
            "alpha": 4,
            "module_dropout": 0.1,
        },
    ),
    ("Vanilla MLP 7 LOHA", "MLP", LoHaConfig, {"target_modules": "lin0", "rank_dropout": 0.5}),
    ("Conv2d 1 LOHA", "Conv2d", LoHaConfig, {"target_modules": ["conv2d"]}),
    ("Conv2d 2 LOHA", "Conv2d", LoHaConfig, {"target_modules": ["conv2d", "lin0"]}),
    ("Conv2d 3 LOHA", "Conv2d", LoHaConfig, {"target_modules": ["conv2d"], "use_effective_conv2d": True}),
    ("Conv2d 4 LOHA", "Conv2d", LoHaConfig, {"target_modules": ["conv2d", "lin0"], "use_effective_conv2d": True}),
    ########
    # LoKr #
    ########
    ("Vanilla MLP 1 LOKR", "MLP", LoKrConfig, {"target_modules": "lin0"}),
    ("Vanilla MLP 2 LOKR", "MLP", LoKrConfig, {"target_modules": ["lin0"]}),
    ("Vanilla MLP 3 LOKR", "MLP", LoKrConfig, {"target_modules": ["lin1"]}),
    ("Vanilla MLP 4 LOKR", "MLP", LoKrConfig, {"target_modules": ["lin0", "lin1"]}),
    ("Vanilla MLP 5 LOKR", "MLP", LoKrConfig, {"target_modules": ["lin0"], "modules_to_save": ["lin1"]}),
    (
        "Vanilla MLP 6 LOKR",
        "MLP",
        LoKrConfig,
        {
            "target_modules": ["lin0"],
            "alpha": 4,
            "module_dropout": 0.1,
        },
    ),
    ("Vanilla MLP 7 LOKR", "MLP", LoKrConfig, {"target_modules": "lin0", "rank_dropout": 0.5}),
    ("Vanilla MLP 8 LOKR", "MLP", LoKrConfig, {"target_modules": "lin0", "decompose_both": True, "r": 1, "alpha": 1}),
    ("Conv2d 1 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d"]}),
    ("Conv2d 2 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d", "lin0"]}),
    ("Conv2d 3 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d"], "use_effective_conv2d": True}),
    ("Conv2d 4 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d", "lin0"], "use_effective_conv2d": True}),
    (
        "Conv2d 5 LOKR",
        "Conv2d",
        LoKrConfig,
        {"target_modules": ["conv2d", "lin0"], "use_effective_conv2d": True, "decompose_both": True},
    ),
    (
        "Conv2d 6 LOKR",
        "Conv2d",
        LoKrConfig,
        {"target_modules": ["conv2d", "lin0"], "use_effective_conv2d": True, "decompose_factor": 4},
    ),
    (
        "Conv2d 7 LOKR",
        "Conv2d",
        LoKrConfig,
        {
            "target_modules": ["conv2d", "lin0"],
            "use_effective_conv2d": True,
            "decompose_both": True,
            "decompose_factor": 4,
        },
    ),
    ########
    # OFT #
    ########
    ("Vanilla MLP 1 OFT", "MLP", OFTConfig, {"target_modules": "lin0"}),
    ("Vanilla MLP 2 OFT", "MLP", OFTConfig, {"target_modules": ["lin0"]}),
    ("Vanilla MLP 5 OFT", "MLP", OFTConfig, {"target_modules": ["lin0"], "modules_to_save": ["lin1"]}),
    (
        "Vanilla MLP 6 OFT",
        "MLP",
        OFTConfig,
        {
            "target_modules": ["lin0"],
            "module_dropout": 0.1,
        },
    ),
    ("Vanilla MLP 7 OFT", "MLP", OFTConfig, {"target_modules": ["lin0"], "coft": True}),
    ("Vanilla MLP 8 OFT", "MLP", OFTConfig, {"target_modules": ["lin0"], "block_share": True}),
    ("Vanilla MLP 9 OFT", "MLP", OFTConfig, {"target_modules": ["lin0"], "coft": True, "block_share": True}),
    ("Conv2d 1 OFT", "Conv2d", OFTConfig, {"target_modules": ["conv2d"]}),
    ("Conv2d 3 OFT", "Conv2d", OFTConfig, {"target_modules": ["conv2d"], "coft": True}),
    ("Conv2d 4 OFT", "Conv2d", OFTConfig, {"target_modules": ["conv2d"], "block_share": True}),
    ("Conv2d 5 OFT", "Conv2d", OFTConfig, {"target_modules": ["conv2d"], "coft": True, "block_share": True}),
    ########
    # HRA #
    ########
    ("Vanilla MLP 1 HRA", "MLP", HRAConfig, {"target_modules": "lin0"}),
    ("Vanilla MLP 2 HRA", "MLP", HRAConfig, {"target_modules": ["lin0"]}),
    ("Vanilla MLP 3 HRA", "MLP", HRAConfig, {"target_modules": ["lin0", "lin1"]}),
    ("Vanilla MLP 5 HRA", "MLP", HRAConfig, {"target_modules": ["lin0"], "modules_to_save": ["lin1"]}),
    ("Conv2d 1 HRA", "Conv2d", HRAConfig, {"target_modules": ["conv2d"]}),
    #############
    # LN Tuning #
    #############
    ("LayerNorm 1 LNTuning", "MLP_LayerNorm", LNTuningConfig, {"target_modules": "layernorm0"}),
    ("LayerNorm 2 LNTuning", "MLP_LayerNorm", LNTuningConfig, {"target_modules": ["layernorm0"]}),
    (
        "LayerNorm 3 LNTuning",
        "MLP_LayerNorm",
        LNTuningConfig,
        {"target_modules": ["layernorm0"], "modules_to_save": ["layernorm1"]},
    ),
    ("Linear 4 LNTuning", "MLP_LayerNorm", LNTuningConfig, {"target_modules": "lin0"}),
    ("Linear 5 LNTuning", "MLP_LayerNorm", LNTuningConfig, {"target_modules": ["lin0"]}),
    ########
    # BOFT #
    ########
    ("Vanilla MLP 1 BOFT", "MLP", BOFTConfig, {"target_modules": ["lin1"], "boft_block_size": 2}),
    (
        "Vanilla MLP 2 BOFT",
        "MLP",
        BOFTConfig,
        {"target_modules": ["lin1"], "modules_to_save": ["lin0"], "boft_block_size": 2},
    ),
    (
        "Vanilla MLP 3 BOFT",
        "MLP",
        BOFTConfig,
        {
            "target_modules": ["lin1"],
            "boft_block_size": 2,
            "boft_dropout": 0.1,
        },
    ),
    (
        "Vanilla MLP 4 BOFT",
        "MLP",
        BOFTConfig,
        {"target_modules": ["lin1"], "boft_block_size": 2, "boft_block_num": 0, "boft_n_butterfly_factor": 1},
    ),
    (
        "Vanilla MLP 5 BOFT",
        "MLP",
        BOFTConfig,
        {"target_modules": ["lin1"], "boft_block_size": 0, "boft_block_num": 2, "boft_n_butterfly_factor": 1},
    ),
    (
        "Vanilla MLP 6 BOFT",
        "MLP",
        BOFTConfig,
        {"target_modules": ["lin1"], "boft_block_size": 10, "boft_block_num": 0, "boft_n_butterfly_factor": 2},
    ),
    (
        "Conv2d 1 BOFT",
        "Conv2d",
        BOFTConfig,
        {"target_modules": ["conv2d"], "boft_block_size": 45, "boft_block_num": 0, "boft_n_butterfly_factor": 1},
    ),
    (
        "Conv2d 2 BOFT",
        "Conv2d",
        BOFTConfig,
        {"target_modules": ["conv2d"], "boft_block_size": 0, "boft_block_num": 1, "boft_n_butterfly_factor": 1},
    ),
    (
        "MLP2 1 BOFT",
        "MLP2",
        BOFTConfig,
        {"target_modules": ["lin1"], "boft_block_size": 2, "boft_block_num": 0, "boft_n_butterfly_factor": 3},
    ),
    (
        "MLP2 2 BOFT",
        "MLP2",
        BOFTConfig,
        {"target_modules": ["lin1"], "boft_block_size": 0, "boft_block_num": 8, "boft_n_butterfly_factor": 3},
    ),
    (
        "Conv2d2 1 BOFT",
        "Conv2d2",
        BOFTConfig,
        {"target_modules": ["conv2d"], "boft_block_size": 2, "boft_block_num": 0, "boft_n_butterfly_factor": 2},
    ),
    (
        "Conv2d2 2 BOFT",
        "Conv2d2",
        BOFTConfig,
        {"target_modules": ["conv2d"], "boft_block_size": 2, "boft_block_num": 0, "boft_n_butterfly_factor": 3},
    ),
    ########
    # VeRA #
    ########
    ("Vanilla MLP 1 VeRA", "MLP", VeraConfig, {"target_modules": "lin0"}),
    ("Vanilla MLP 2 VeRA", "MLP", VeraConfig, {"target_modules": ["lin0"]}),
    ("Vanilla MLP 3 VeRA", "MLP", VeraConfig, {"target_modules": ["lin1"]}),
    ("Vanilla MLP 4 VeRA", "MLP", VeraConfig, {"target_modules": ["lin0", "lin1"]}),
    (
        "Vanilla MLP 5 VeRA",
        "MLP",
        VeraConfig,
        {"target_modules": ["lin0"], "modules_to_save": ["lin1"]},
    ),
    (
        "Embedding + transformers Conv1D 1 VeRA",
        "EmbConv1D",
        VeraConfig,
        {"target_modules": ["conv1d"]},
    ),
    #############
    # FourierFT #
    #############
    ("Vanilla MLP 1 FourierFT", "MLP", FourierFTConfig, {"n_frequency": 10, "target_modules": "lin0"}),
    ("Vanilla MLP 2 FourierFT", "MLP", FourierFTConfig, {"n_frequency": 10, "target_modules": ["lin0"]}),
    ("Vanilla MLP 3 FourierFT", "MLP", FourierFTConfig, {"n_frequency": 10, "target_modules": ["lin1"]}),
    (
        "Vanilla MLP 5 FourierFT",
        "MLP",
        FourierFTConfig,
        {"n_frequency": 10, "target_modules": ["lin0"], "modules_to_save": ["lin1"]},
    ),
    (
        "Vanilla MLP 6 FourierFT",
        "MLP",
        FourierFTConfig,
        {"n_frequency": 10, "target_modules": ["lin0", "lin1"], "modules_to_save": ["lin1"]},
    ),
    (
        "Vanilla MLP 7 FourierFT",
        "MLP",
        FourierFTConfig,
        {
            "n_frequency_pattern": {"lin0": 5, "lin1": 10},
            "target_modules": ["lin0", "lin1"],
            "modules_to_save": ["lin1"],
        },
    ),
]
# Test matrix with two adapter configs per case (presumably consumed by the multiple-active-adapters tests
# further down in this file). Each tuple consists of:
# - test name
# - tuner method (lowercase PEFT method name, e.g. "lora", "ia3")
# - config_cls
# - 1st config kwargs
# - 2nd config kwargs
# The model used for this test is `MLP`, which uses linear layers `lin0` and `lin1`. "Same" cases put both adapters
# on the same layer, "Different" cases put them on different layers.
MULTIPLE_ACTIVE_ADAPTERS_TEST_CASES = [
    (
        "LoRA Same",
        "lora",
        LoraConfig,
        {"target_modules": ["lin0"], "init_lora_weights": False},
        {"target_modules": ["lin0"], "init_lora_weights": False},
    ),
    (
        "LoRA Different",
        "lora",
        LoraConfig,
        {"target_modules": ["lin0"], "init_lora_weights": False},
        {"target_modules": ["lin1"], "init_lora_weights": False},
    ),
    (
        "IA3 Same",
        "ia3",
        IA3Config,
        {
            "target_modules": ["lin0"],
            "feedforward_modules": ["lin0"],
            "init_ia3_weights": False,
        },
        {
            "target_modules": ["lin0"],
            "feedforward_modules": ["lin0"],
            "init_ia3_weights": False,
        },
    ),
    (
        "IA3 Different",
        "ia3",
        IA3Config,
        {
            "target_modules": ["lin0"],
            "feedforward_modules": ["lin0"],
            "init_ia3_weights": False,
        },
        {
            "target_modules": ["lin1"],
            "feedforward_modules": ["lin1"],
            "init_ia3_weights": False,
        },
    ),
    # AdaLoRA cases are created with inference_mode=True (NOTE(review): presumably to avoid AdaLoRA's
    # training-only constraints when two adapters are present — confirm)
    (
        "AdaLora Same",
        "adalora",
        AdaLoraConfig,
        {"target_modules": ["lin0"], "init_lora_weights": False, "inference_mode": True},
        {"target_modules": ["lin0"], "init_lora_weights": False, "inference_mode": True},
    ),
    (
        "AdaLora Different",
        "adalora",
        AdaLoraConfig,
        {"target_modules": ["lin0"], "init_lora_weights": False, "inference_mode": True},
        {"target_modules": ["lin1"], "init_lora_weights": False, "inference_mode": True},
    ),
    (
        "FourierFT Same",
        "fourierft",
        FourierFTConfig,
        {"n_frequency": 10, "target_modules": ["lin0"]},
        {"n_frequency": 10, "target_modules": ["lin0"]},
    ),
    (
        "FourierFT Different",
        "fourierft",
        FourierFTConfig,
        {"n_frequency": 10, "target_modules": ["lin0"]},
        {"n_frequency": 10, "target_modules": ["lin1"]},
    ),
    # Note: Currently, we cannot target lin0 and lin1 with different adapters when using VeRA. The reason is that the
    # first adapter being created will result in a vera_A or vera_B shape that is too small for the next adapter
    # (remember that VeRA shares these parameters across all layers), which results in an error.
    (
        "VeRA Same",
        "vera",
        VeraConfig,
        {"target_modules": ["lin0"], "init_weights": False},
        {"target_modules": ["lin0"], "init_weights": False},
    ),
    (
        "HRA Same",
        "hra",
        HRAConfig,
        {"target_modules": ["lin0"], "init_weights": False},
        {"target_modules": ["lin0"], "init_weights": False},
    ),
    (
        "HRA Different",
        "hra",
        HRAConfig,
        {"target_modules": ["lin0"], "init_weights": False},
        {"target_modules": ["lin1"], "init_weights": False},
    ),
]
# Substring that marks each PEFT method's adapter parameters in `named_parameters()`. test_only_params_are_updated
# expects every parameter whose name contains this prefix (or "modules_to_save") to change during training, and
# every other parameter to stay unchanged. Only config classes that appear in TEST_CASES are listed.
PREFIXES = {
    IA3Config: "ia3_",
    LoraConfig: "lora_",
    LoHaConfig: "hada_",
    LoKrConfig: "lokr_",
    OFTConfig: "oft_",
    BOFTConfig: "boft_",
    LNTuningConfig: "ln_tuning_",
    # Only VeRA's lambda vectors are matched; presumably the shared vera_A/vera_B projections are not
    # trainable parameters — confirm against the VeRA implementation.
    VeraConfig: "vera_lambda_",
    FourierFTConfig: "fourierft_",
    HRAConfig: "hra_",
}
class MLP(nn.Module):
    """Small classifier: Linear(10 -> 20) -> ReLU -> Dropout(0.5) -> Linear(20 -> 2) -> LogSoftmax.

    Args:
        bias: whether both linear layers carry a bias term.
    """

    def __init__(self, bias=True):
        super().__init__()
        # Registration order is kept stable: it fixes both the RNG draws used for init and the state_dict order.
        self.lin0 = nn.Linear(10, 20, bias=bias)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(0.5)
        self.lin1 = nn.Linear(20, 2, bias=bias)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        # Inputs may be integer tensors (see prepare_inputs_for_testing), hence the float cast.
        hidden = self.drop(self.relu(self.lin0(X.float())))
        return self.sm(self.lin1(hidden))
class MLP_LayerNorm(nn.Module):
    """Like ``MLP`` but with a LayerNorm in front of each linear layer (used by the LN tuning test cases).

    Layout: LayerNorm(10) -> Linear(10 -> 20) -> ReLU -> Dropout(0.5) -> LayerNorm(20) -> Linear(20 -> 2)
    -> LogSoftmax.

    Args:
        bias: whether both linear layers carry a bias term.
    """

    def __init__(self, bias=True):
        super().__init__()
        # The second positional argument of nn.LayerNorm is ``eps``: the former ``nn.LayerNorm(10, 10)`` /
        # ``nn.LayerNorm(20, 20)`` silently used eps=10 / eps=20, which strongly damps the normalization. Only the
        # normalized shape is intended here; LayerNorm init uses no RNG, so the linear weights are unaffected.
        self.layernorm0 = nn.LayerNorm(10)
        self.lin0 = nn.Linear(10, 20, bias=bias)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(0.5)
        self.layernorm1 = nn.LayerNorm(20)
        self.lin1 = nn.Linear(20, 2, bias=bias)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        X = X.float()
        X = self.layernorm0(X)
        X = self.lin0(X)
        X = self.relu(X)
        X = self.drop(X)
        X = self.layernorm1(X)
        X = self.lin1(X)
        X = self.sm(X)
        return X
class MLP2(nn.Module):
    """Like ``MLP`` but with a 32-unit hidden layer: Linear(10 -> 32) -> ReLU -> Dropout -> Linear(32 -> 2).

    Args:
        bias: whether both linear layers carry a bias term.
    """

    def __init__(self, bias=True):
        super().__init__()
        self.lin0 = nn.Linear(10, 32, bias=bias)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(0.5)
        self.lin1 = nn.Linear(32, 2, bias=bias)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        activations = self.relu(self.lin0(X.float()))
        logits = self.lin1(self.drop(activations))
        return self.sm(logits)
class Block(nn.Module):
    """Feed-forward block Linear(10 -> 20) -> ReLU -> Dropout(0.5) -> Linear(20 -> 10).

    Input and output width are both 10, so blocks can be stacked (see ``DeepMLP``).

    Args:
        bias: whether both linear layers carry a bias term.
    """

    def __init__(self, bias=True):
        super().__init__()
        self.lin0 = nn.Linear(10, 20, bias=bias)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(0.5)
        self.lin1 = nn.Linear(20, 10, bias=bias)

    def forward(self, X):
        X = X.float()
        X = self.lin0(X)
        X = self.relu(X)
        X = self.drop(X)
        X = self.lin1(X)
        return X


class DeepMLP(nn.Module):
    """Stack of ``num_hidden_layers`` ``Block`` modules followed by Linear(10 -> 2) and LogSoftmax.

    Args:
        bias: whether the linear layers carry a bias term.
        num_hidden_layers: number of stacked ``Block`` modules.
    """

    def __init__(self, bias=True, num_hidden_layers=12):
        super().__init__()
        self.layers = nn.ModuleList([Block(bias=bias) for _ in range(num_hidden_layers)])
        self.out = nn.Linear(10, 2, bias=bias)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        # Tensor.float() takes no positional input; the former ``X.float(X)`` raised a TypeError on every call.
        X = X.float()
        for layer in self.layers:
            X = layer(X)
        X = self.out(X)
        X = self.sm(X)
        return X
class ModelEmbConv1D(nn.Module):
    """Embedding -> transformers ``Conv1D`` -> ReLU -> Flatten -> Linear(10 -> 2) -> LogSoftmax.

    Takes integer token ids. ``Conv1D(1, 5)`` presumably maps each 5-dim embedding to one value, so 10 ids per
    sample flatten to the 10 features expected by ``lin0``.

    Args:
        emb_size: number of rows in the embedding table, i.e. the vocabulary size.
    """

    def __init__(self, emb_size=100):
        super().__init__()
        self.emb = nn.Embedding(emb_size, 5)
        self.conv1d = Conv1D(1, 5)
        self.relu = nn.ReLU()
        self.flat = nn.Flatten()
        self.lin0 = nn.Linear(10, 2)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        embedded = self.emb(X)
        features = self.flat(self.relu(self.conv1d(embedded)))
        return self.sm(self.lin0(features))
class ModelEmbWithEmbeddingUtils(nn.Module):
    """Same layout as ``ModelEmbConv1D`` (embedding table of 100 x 5), plus embedding accessors.

    ``get_input_embeddings`` and ``get_output_embeddings`` mimic the corresponding 🤗 transformers model methods.
    """

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(100, 5)
        self.conv1d = Conv1D(1, 5)
        self.relu = nn.ReLU()
        self.flat = nn.Flatten()
        self.lin0 = nn.Linear(10, 2)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        hidden = self.relu(self.conv1d(self.embed_tokens(X)))
        logits = self.lin0(self.flat(hidden))
        return self.sm(logits)

    def get_input_embeddings(self):
        """Return the token embedding layer."""
        return self.embed_tokens

    def get_output_embeddings(self):
        """Return ``None``: this model has no separate output embedding layer."""
        return None
class ModelConv2D(nn.Module):
    """Conv2d(5 -> 10 channels, kernel 3) -> ReLU -> Flatten -> Linear(10 -> 2) -> LogSoftmax."""

    def __init__(self):
        super().__init__()
        self.conv2d = nn.Conv2d(5, 10, 3)
        self.relu = nn.ReLU()
        self.flat = nn.Flatten()
        self.lin0 = nn.Linear(10, 2)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        # All values are regrouped into 5-channel 3x3 images (45 values each). The batch dimension is recomputed,
        # so e.g. a (9, 10) input becomes 2 images.
        images = X.float().reshape(-1, 5, 3, 3)
        features = self.flat(self.relu(self.conv2d(images)))
        return self.sm(self.lin0(features))
class ModelConv2D2(nn.Module):
    """Linear(10 -> 40) -> ReLU -> reshape to 8x3x3 -> Conv2d(8 -> 32, kernel 3) -> ReLU -> Flatten
    -> Linear(32 -> 2) -> LogSoftmax.

    NOTE(review): the 40 features per sample do not divide into 8 * 3 * 3 = 72, so the reshape regroups values
    across samples (a (9, 10) input yields 5 images). This looks unintended, but the Conv2d2 BOFT test cases depend
    on this model, so confirm before changing it.
    """

    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 40)
        self.conv2d = nn.Conv2d(8, 32, 3)
        self.relu = nn.ReLU()
        self.flat = nn.Flatten()
        self.lin1 = nn.Linear(32, 2)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        projected = self.relu(self.lin0(X.float()))
        images = projected.reshape(-1, 8, 3, 3)
        features = self.flat(self.relu(self.conv2d(images)))
        return self.sm(self.lin1(features))
class MockTransformerWrapper:
    """Stand-in for a transformers auto class.

    The common tests build models via ``transformers_class.from_pretrained(model_id)``. This class maps the
    ``model_id`` strings used in the test matrices to the small custom modules defined in this file.
    """

    @classmethod
    def from_pretrained(cls, model_id, torch_dtype=None):
        """Build a freshly seeded model for ``model_id``, cast to ``torch_dtype`` (float32 if not given).

        Raises:
            ValueError: if ``model_id`` is not one of the known ids.
        """
        # Re-seed on every call so repeated loads of the same model_id produce identical weights.
        torch.manual_seed(0)
        dtype = torch.float32 if torch_dtype is None else torch_dtype
        registry = (
            ("MLP", MLP),
            ("EmbConv1D", ModelEmbConv1D),
            ("Conv2d", ModelConv2D),
            ("MLP_LayerNorm", MLP_LayerNorm),
            ("MLP2", MLP2),
            ("Conv2d2", ModelConv2D2),
        )
        for known_id, model_cls in registry:
            if model_id == known_id:
                return model_cls().to(dtype)
        raise ValueError(f"model_id {model_id} not implemented")
class PeftCustomModelTester(unittest.TestCase, PeftCommonTester):
"""TODO"""
transformers_class = MockTransformerWrapper
def prepare_inputs_for_testing(self):
    """Return the keyword inputs for the custom models: ``X`` is a 9x10 grid of the integers 0..89."""
    grid = torch.arange(90).view(9, 10)
    return {"X": grid.to(self.torch_device)}
@parameterized.expand(TEST_CASES)
def test_attributes_parametrized(self, test_name, model_id, config_cls, config_kwargs):
    """Run the shared ``_test_model_attr`` helper for one TEST_CASES entry (``test_name`` only labels the case)."""
    self._test_model_attr(model_id, config_cls, config_kwargs)
@parameterized.expand(TEST_CASES)
def test_adapter_name(self, test_name, model_id, config_cls, config_kwargs):
    """Run the shared ``_test_adapter_name`` helper for one TEST_CASES entry."""
    self._test_adapter_name(model_id, config_cls, config_kwargs)
@parameterized.expand(TEST_CASES)
def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls, config_kwargs):
    """Intentionally a no-op for custom models.

    The shared check assumes there is always a ``get_input_embeddings`` method returning a layer that needs no
    updates, which these custom models do not guarantee. Other tests in this class check the adapter training
    behavior directly instead.
    """
@parameterized.expand(TEST_CASES)
def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs):
    """Run the shared ``_test_save_pretrained`` helper with its default serialization for one TEST_CASES entry."""
    self._test_save_pretrained(model_id, config_cls, config_kwargs)
@parameterized.expand(TEST_CASES)
def test_save_pretrained_pickle(self, test_name, model_id, config_cls, config_kwargs):
    """Same as ``test_save_pretrained`` but passes ``safe_serialization=False`` to the shared helper."""
    self._test_save_pretrained(model_id, config_cls, config_kwargs, safe_serialization=False)
@parameterized.expand(TEST_CASES)
def test_from_pretrained_config_construction(self, test_name, model_id, config_cls, config_kwargs):
    """Run the shared ``_test_from_pretrained_config_construction`` helper for one TEST_CASES entry."""
    self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs)
@parameterized.expand(TEST_CASES)
def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
    """Run the shared merge check with each method's init-weights flag switched off.

    LNTuning takes no such flag and is passed through unchanged. The caller's ``config_kwargs`` is not mutated.
    """
    if issubclass(config_cls, LoraConfig):
        init_key = "init_lora_weights"
    elif issubclass(config_cls, IA3Config):
        init_key = "init_ia3_weights"
    elif issubclass(config_cls, LNTuningConfig):
        init_key = None
    else:
        init_key = "init_weights"
    kwargs = dict(config_kwargs)
    if init_key is not None:
        kwargs[init_key] = False
    self._test_merge_layers(model_id, config_cls, kwargs)
@parameterized.expand(TEST_CASES)
def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs):
    """Run the shared fp16 merge check; only LoRA and IA3 get their init-weights flag switched off.

    NOTE(review): unlike test_merge_layers, the other methods keep their default init here. Confirm whether that
    is intentional.
    """
    init_key = None
    if issubclass(config_cls, LoraConfig):
        init_key = "init_lora_weights"
    elif issubclass(config_cls, IA3Config):
        init_key = "init_ia3_weights"
    kwargs = dict(config_kwargs)
    if init_key is not None:
        kwargs[init_key] = False
    self._test_merge_layers_fp16(model_id, config_cls, kwargs)
@parameterized.expand(TEST_CASES)
def test_merge_layers_is_idempotent(self, test_name, model_id, config_cls, config_kwargs):
    """Calling merge twice with the same arguments should not change the output.

    Only LoRA and IA3 get their init-weights flag switched off before running the shared check.
    """
    init_key = None
    if issubclass(config_cls, LoraConfig):
        init_key = "init_lora_weights"
    elif issubclass(config_cls, IA3Config):
        init_key = "init_ia3_weights"
    kwargs = dict(config_kwargs)
    if init_key is not None:
        kwargs[init_key] = False
    self._test_merge_layers_is_idempotent(model_id, config_cls, kwargs)
@parameterized.expand(TEST_CASES)
def test_safe_merge(self, test_name, model_id, config_cls, config_kwargs):
    """Run the shared ``_test_safe_merge`` check with each method's init-weights flag switched off.

    (The previous comment here was copied from the idempotency test and did not describe this test.)
    """
    # Presumably the flags are switched off so the adapter actually changes the weights being merged.
    config_kwargs = config_kwargs.copy()
    if issubclass(config_cls, LoraConfig):
        config_kwargs["init_lora_weights"] = False
    elif issubclass(config_cls, IA3Config):
        config_kwargs["init_ia3_weights"] = False
    elif issubclass(config_cls, LNTuningConfig):
        # LNTuning does not take an init_weights argument
        pass
    else:
        config_kwargs["init_weights"] = False
    self._test_safe_merge(model_id, config_cls, config_kwargs)
@parameterized.expand(TEST_CASES)
def test_generate(self, test_name, model_id, config_cls, config_kwargs):
    """Not performed: custom models do not (necessarily) have a ``generate`` method."""
@parameterized.expand(TEST_CASES)
def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs):
    """Not performed: custom models do not (necessarily) have a ``generate`` method."""
@parameterized.expand(TEST_CASES)
def test_training_custom_models(self, test_name, model_id, config_cls, config_kwargs):
    """Run the shared ``_test_training`` helper for one TEST_CASES entry."""
    self._test_training(model_id, config_cls, config_kwargs)
@parameterized.expand(TEST_CASES)
def test_training_custom_models_layer_indexing(self, test_name, model_id, config_cls, config_kwargs):
    """Not performed: layer indexing currently only works for layer names following a specific pattern.

    The custom models here are not guaranteed to follow that pattern.
    """
@parameterized.expand(TEST_CASES)
def test_training_custom_models_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs):
    """Run the shared ``_test_training_gradient_checkpointing`` helper for one TEST_CASES entry."""
    self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
@parameterized.expand(TEST_CASES)
def test_inference_safetensors(self, test_name, model_id, config_cls, config_kwargs):
    """Run the shared ``_test_inference_safetensors`` helper for one TEST_CASES entry."""
    self._test_inference_safetensors(model_id, config_cls, config_kwargs)
@parameterized.expand(TEST_CASES)
def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwargs):
    """Run the shared ``_test_peft_model_device_map`` helper for one TEST_CASES entry."""
    self._test_peft_model_device_map(model_id, config_cls, config_kwargs)
@parameterized.expand(TEST_CASES)
def test_forward_output_finite(self, test_name, model_id, config_cls, config_kwargs):
    """An eval-mode forward pass through the freshly wrapped PEFT model must contain no NaN or inf."""
    inputs = self.prepare_inputs_for_testing()
    base_model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
    peft_config = config_cls(base_model_name_or_path=model_id, **config_kwargs)
    peft_model = get_peft_model(base_model, peft_config)
    peft_model.eval()
    with torch.no_grad():
        logits = peft_model(**inputs)
    assert torch.isfinite(logits).all()
@parameterized.expand(TEST_CASES)
def test_only_params_are_updated(self, test_name, model_id, config_cls, config_kwargs):
    """Training a PEFT-wrapped custom model must update only adapter parameters and modules_to_save copies.

    Adapter parameters are recognized by the method-specific substring in ``PREFIXES``.
    """
    inputs = self.prepare_inputs_for_testing()
    base_model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
    peft_config = config_cls(base_model_name_or_path=model_id, **config_kwargs)
    model = get_peft_model(base_model, peft_config)
    snapshot = copy.deepcopy(model)
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

    # train for 3 steps (probably needed so that layers initialized with constants break symmetry and every
    # adapter parameter gets updated)
    for _ in range(3):
        optimizer.zero_grad()
        loss = model(**inputs).sum()
        loss.backward()
        optimizer.step()

    tol = 1e-4
    before = dict(snapshot.named_parameters())
    after = dict(model.named_parameters())
    assert before.keys() == after.keys()

    prefix = PREFIXES[config_cls]
    for name, old_value in before.items():
        new_value = after[name]
        is_trainable = prefix in name or "modules_to_save" in name
        if is_trainable:
            # adapter parameters and modules_to_save copies must have moved
            assert not torch.allclose(old_value, new_value, atol=tol, rtol=tol)
        else:
            # everything else belongs to the frozen base model
            assert torch.allclose(old_value, new_value, atol=tol, rtol=tol)
@parameterized.expand(TEST_CASES)
def test_parameters_after_loading_model(self, test_name, model_id, config_cls, config_kwargs):
    """Parameters of a trained PEFT model must survive save_pretrained/from_pretrained unchanged (issue #808)."""
    inputs = self.prepare_inputs_for_testing()
    base_model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
    peft_config = config_cls(base_model_name_or_path=model_id, **config_kwargs)
    model = get_peft_model(base_model, peft_config)
    model.train()
    # DoRA needs a smaller step size, otherwise training produces NaN
    learning_rate = 0.1 if config_kwargs.get("use_dora") else 0.5
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    # train for 3 steps (probably needed so that layers initialized with constants break symmetry and every
    # adapter parameter gets updated)
    for _ in range(3):
        optimizer.zero_grad()
        model(**inputs).sum().backward()
        optimizer.step()

    tol = 1e-4
    expected = get_state_dict(model)
    # whether training changed the parameters at all is checked by test_only_params_are_updated
    with tempfile.TemporaryDirectory() as tmp_dirname:
        model.save_pretrained(tmp_dirname)
        reloaded = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
        reloaded = PeftModel.from_pretrained(reloaded, tmp_dirname)
        actual = get_state_dict(reloaded)

    assert expected.keys() == actual.keys()
    for name, saved_value in expected.items():
        assert torch.allclose(saved_value, actual[name], atol=tol, rtol=tol)
@parameterized.expand(TEST_CASES)
def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
    """Check ``disable_adapter`` on a briefly trained PEFT model.

    Before training, the PEFT model must reproduce the base model's outputs. After training, outputs inside the
    ``disable_adapter()`` context must equal the untrained outputs. Outputs after leaving the context must equal
    the trained outputs again.
    """
    X = self.prepare_inputs_for_testing()
    model = self.transformers_class.from_pretrained(model_id).to(self.torch_device).eval()
    outputs_base = model(**X)
    if issubclass(config_cls, FourierFTConfig):
        # FourierFT is given init_weights=True so the fresh adapter leaves the outputs unchanged (asserted below)
        config_kwargs = config_kwargs.copy()
        config_kwargs["init_weights"] = True
    config = config_cls(
        base_model_name_or_path=model_id,
        **config_kwargs,
    )
    model = get_peft_model(model, config)
    model.eval()
    outputs_before = model(**X)
    assert torch.allclose(outputs_base, outputs_before)

    model.train()
    # EmbConv1D is slow to learn for some reason
    lr = 0.01 if model_id != "EmbConv1D" else 1.0
    # `config_cls` is a class, not an instance, so issubclass is required: the former isinstance check was always
    # False and LN tuning never received its larger learning rate.
    if issubclass(config_cls, LNTuningConfig):
        # LayerNorm tuning is slow to learn
        lr = 1.0
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    # train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
    # breaking of some LoRA layers that are initialized with constants)
    for _ in range(3):
        optimizer.zero_grad()
        y_pred = model(**X)
        # alternating 0/1 class targets
        y = torch.arange(len(y_pred)).to(self.torch_device) % 2
        loss = nn.functional.nll_loss(y_pred, y)
        loss.backward()
        optimizer.step()

    model.eval()
    outputs_after = model(**X)

    with model.disable_adapter():
        outputs_disabled = model(**X)

    # check that after leaving the disable_adapter context, everything is enabled again
    outputs_enabled_after_disable = model(**X)

    if self.torch_device == "cpu":
        # LayerNorm is running float32 on cpu, so difference in outputs are smaller
        rtol, atol = 1e-8, 1e-8
    else:
        rtol, atol = 1e-5, 1e-8
    assert not torch.allclose(outputs_before, outputs_after, rtol=rtol, atol=atol)
    assert torch.allclose(outputs_before, outputs_disabled)
    assert torch.allclose(outputs_after, outputs_enabled_after_disable)
@parameterized.expand(TEST_CASES)
def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, config_kwargs):
# same as test_disable_adapters, but with merging
X = self.prepare_inputs_for_testing()