From 214345ee4787b04636947af50ffb2f869c11613b Mon Sep 17 00:00:00 2001
From: JINO ROHIT
Date: Tue, 29 Oct 2024 19:43:56 +0530
Subject: [PATCH] ENH Check layers to transforms and layer pattern (#2159)

---
 src/peft/tuners/adalora/config.py   |  6 +++---
 src/peft/tuners/boft/config.py      |  3 +++
 src/peft/tuners/fourierft/config.py |  3 +++
 src/peft/tuners/hra/config.py       |  4 ++++
 src/peft/tuners/loha/config.py      |  3 +++
 src/peft/tuners/lokr/config.py      |  3 +++
 src/peft/tuners/lora/config.py      |  5 +++++
 src/peft/tuners/oft/config.py       |  3 +++
 src/peft/tuners/vblora/config.py    |  3 +++
 src/peft/tuners/vera/config.py      |  4 +++-
 tests/test_config.py                | 27 +++++++++++++++++++++++++++
 tests/test_tuners_utils.py          |  3 ---
 12 files changed, 60 insertions(+), 7 deletions(-)

diff --git a/src/peft/tuners/adalora/config.py b/src/peft/tuners/adalora/config.py
index 5419159397..f86a9dfd60 100644
--- a/src/peft/tuners/adalora/config.py
+++ b/src/peft/tuners/adalora/config.py
@@ -68,9 +68,9 @@ def __post_init__(self):
         if isinstance(self.target_modules, str) and self.layers_to_transform is not None:
             raise ValueError("`layers_to_transform` cannot be used when `target_modules` is a str.")
 
-        # if target_modules is a regex expression, then layers_pattern should be None
-        if isinstance(self.target_modules, str) and self.layers_pattern is not None:
-            raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.")
+        # check for layers_to_transform and layers_pattern
+        if self.layers_pattern and not self.layers_to_transform:
+            raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ")
 
         # Check if 'r' has been set to a non-default value
         if self.r != 8:  # 8 is the default value for 'r' in LoraConfig
diff --git a/src/peft/tuners/boft/config.py b/src/peft/tuners/boft/config.py
index 7559303272..856e1d8667 100644
--- a/src/peft/tuners/boft/config.py
+++ b/src/peft/tuners/boft/config.py
@@ -146,6 +146,9 @@ def __post_init__(self):
         self.exclude_modules = (
             set(self.exclude_modules) if isinstance(self.exclude_modules, list) else self.exclude_modules
         )
+        # check for layers_to_transform and layers_pattern
+        if self.layers_pattern and not self.layers_to_transform:
+            raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ")
         if self.boft_block_size == 0 and self.boft_block_num == 0:
             raise ValueError(
                 f"Either `boft_block_size` or `boft_block_num` must be non-zero. Currently, boft_block_size = {self.boft_block_size} and boft_block_num = {self.boft_block_num}."
diff --git a/src/peft/tuners/fourierft/config.py b/src/peft/tuners/fourierft/config.py
index 93aaf85f02..a95896c125 100644
--- a/src/peft/tuners/fourierft/config.py
+++ b/src/peft/tuners/fourierft/config.py
@@ -199,3 +199,6 @@ def __post_init__(self):
         # if target_modules is a regex expression, then layers_pattern should be None
         if isinstance(self.target_modules, str) and self.layers_pattern is not None:
             raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.")
+        # check for layers_to_transform and layers_pattern
+        if self.layers_pattern and not self.layers_to_transform:
+            raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ")
diff --git a/src/peft/tuners/hra/config.py b/src/peft/tuners/hra/config.py
index 81997bbfeb..470af9019e 100644
--- a/src/peft/tuners/hra/config.py
+++ b/src/peft/tuners/hra/config.py
@@ -129,3 +129,7 @@ def __post_init__(self):
         # if target_modules is a regex expression, then layers_pattern should be None
         if isinstance(self.target_modules, str) and self.layers_pattern is not None:
             raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.")
+
+        # check for layers_to_transform and layers_pattern
+        if self.layers_pattern and not self.layers_to_transform:
+            raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ")
diff --git a/src/peft/tuners/loha/config.py b/src/peft/tuners/loha/config.py
index aa05121be2..998bb26eaf 100644
--- a/src/peft/tuners/loha/config.py
+++ b/src/peft/tuners/loha/config.py
@@ -133,3 +133,6 @@ def __post_init__(self):
         self.exclude_modules = (
             set(self.exclude_modules) if isinstance(self.exclude_modules, list) else self.exclude_modules
         )
+        # check for layers_to_transform and layers_pattern
+        if self.layers_pattern and not self.layers_to_transform:
+            raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ")
diff --git a/src/peft/tuners/lokr/config.py b/src/peft/tuners/lokr/config.py
index b4879af558..a382f8cd68 100644
--- a/src/peft/tuners/lokr/config.py
+++ b/src/peft/tuners/lokr/config.py
@@ -142,3 +142,6 @@ def __post_init__(self):
         self.exclude_modules = (
             set(self.exclude_modules) if isinstance(self.exclude_modules, list) else self.exclude_modules
         )
+        # check for layers_to_transform and layers_pattern
+        if self.layers_pattern and not self.layers_to_transform:
+            raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ")
diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py
index 10522dcc42..2d18f5bb90 100644
--- a/src/peft/tuners/lora/config.py
+++ b/src/peft/tuners/lora/config.py
@@ -340,6 +340,7 @@ def __post_init__(self):
         self.exclude_modules = (
             set(self.exclude_modules) if isinstance(self.exclude_modules, list) else self.exclude_modules
         )
+
         # if target_modules is a regex expression, then layers_to_transform should be None
         if isinstance(self.target_modules, str) and self.layers_to_transform is not None:
             raise ValueError("`layers_to_transform` cannot be used when `target_modules` is a str.")
@@ -348,6 +349,10 @@ def __post_init__(self):
         if isinstance(self.target_modules, str) and self.layers_pattern is not None:
             raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.")
 
+        # check for layers_to_transform and layers_pattern
+        if self.layers_pattern and not self.layers_to_transform:
+            raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ")
+
         if self.use_dora and self.megatron_config:
             raise ValueError("DoRA does not support megatron_core, please set `use_dora=False`.")
 
diff --git a/src/peft/tuners/oft/config.py b/src/peft/tuners/oft/config.py
index 40aad205d3..8f54000e63 100644
--- a/src/peft/tuners/oft/config.py
+++ b/src/peft/tuners/oft/config.py
@@ -176,6 +176,9 @@ def __post_init__(self):
         self.exclude_modules = (
             set(self.exclude_modules) if isinstance(self.exclude_modules, list) else self.exclude_modules
         )
+        # check for layers_to_transform and layers_pattern
+        if self.layers_pattern and not self.layers_to_transform:
+            raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ")
         if self.r == 0 and self.oft_block_size == 0:
             raise ValueError(
                 f"Either `r` or `oft_block_size` must be non-zero. Currently, r = {self.r} and oft_block_size = {self.oft_block_size}."
diff --git a/src/peft/tuners/vblora/config.py b/src/peft/tuners/vblora/config.py
index a493c35312..c4d25c63e3 100644
--- a/src/peft/tuners/vblora/config.py
+++ b/src/peft/tuners/vblora/config.py
@@ -190,3 +190,6 @@ def __post_init__(self):
         self.exclude_modules = (
             set(self.exclude_modules) if isinstance(self.exclude_modules, list) else self.exclude_modules
         )
+        # check for layers_to_transform and layers_pattern
+        if self.layers_pattern and not self.layers_to_transform:
+            raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ")
diff --git a/src/peft/tuners/vera/config.py b/src/peft/tuners/vera/config.py
index 45ef903756..1aa3a7b4d7 100644
--- a/src/peft/tuners/vera/config.py
+++ b/src/peft/tuners/vera/config.py
@@ -150,7 +150,9 @@ def __post_init__(self):
         self.target_modules = (
             set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
         )
-
+        # check for layers_to_transform and layers_pattern
+        if self.layers_pattern and not self.layers_to_transform:
+            raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ")
         if not self.save_projection:
             warnings.warn(
                 "Specified to not save vera_A and vera_B within the state dictionary, instead they will be restored "
diff --git a/tests/test_config.py b/tests/test_config.py
index f02c28d197..3726840ecf 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -351,3 +351,30 @@ def test_from_pretrained_sanity_check(self, config_class, tmp_path):
         msg = f"The config that is trying to be loaded is not a valid {config_class.__name__} config"
         with pytest.raises(TypeError, match=msg):
             config_class.from_pretrained(tmp_path)
+
+    def test_lora_config_layers_to_transform_validation(self):
+        """Test that specifying layers_pattern without layers_to_transform raises an error"""
+        with pytest.raises(
+            ValueError, match="When `layers_pattern` is specified, `layers_to_transform` must also be specified."
+        ):
+            LoraConfig(r=8, lora_alpha=16, target_modules=["query", "value"], layers_pattern="model.layers")
+
+        # Test that specifying both layers_to_transform and layers_pattern works fine
+        config = LoraConfig(
+            r=8,
+            lora_alpha=16,
+            target_modules=["query", "value"],
+            layers_to_transform=[0, 1, 2],
+            layers_pattern="model.layers",
+        )
+        assert config.layers_to_transform == [0, 1, 2]
+        assert config.layers_pattern == "model.layers"
+
+        # Test that not specifying either works fine
+        config = LoraConfig(
+            r=8,
+            lora_alpha=16,
+            target_modules=["query", "value"],
+        )
+        assert config.layers_to_transform is None
+        assert config.layers_pattern is None
diff --git a/tests/test_tuners_utils.py b/tests/test_tuners_utils.py
index 06a47deb26..09cd62e04f 100644
--- a/tests/test_tuners_utils.py
+++ b/tests/test_tuners_utils.py
@@ -92,9 +92,6 @@
     ("foo.bar.1.baz", ["baz"], [0, 1, 2], ["bar"], True),
     ("foo.bar.1.baz", ["baz", "spam"], [1], ["bar"], True),
     ("foo.bar.1.baz", ["baz", "spam"], [0, 1, 2], ["bar"], True),
-    # empty layers_to_transform
-    ("foo.bar.7.baz", ["baz"], [], ["bar"], True),
-    ("foo.bar.7.baz", ["baz"], None, ["bar"], True),
     # empty layers_pattern
     ("foo.whatever.1.baz", ["baz"], [1], [], True),
     ("foo.whatever.1.baz", ["baz"], [0], [], False),