From 276690ac91653cb94778574a5425d2fad8277b38 Mon Sep 17 00:00:00 2001
From: SumanthRH
Date: Tue, 23 Jan 2024 16:56:46 -0800
Subject: [PATCH 1/2] add positional args

---
 src/peft/peft_model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index 974cedace3..30154b8848 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -1130,14 +1130,14 @@ def forward(
             inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
             return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
 
-    def generate(self, **kwargs):
+    def generate(self, *args, **kwargs):
         self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
         if hasattr(self.base_model, "model"):
             self.base_model.model.generation_config = self.generation_config
         else:
             self.base_model.generation_config = self.generation_config
         try:
-            outputs = self.base_model.generate(**kwargs)
+            outputs = self.base_model.generate(*args, **kwargs)
         except:
             self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
             raise

From 1a1599e0154483d9c4f4c38caec4bdcfc33c7d90 Mon Sep 17 00:00:00 2001
From: SumanthRH
Date: Wed, 24 Jan 2024 13:47:49 -0800
Subject: [PATCH 2/2] update tests

---
 tests/test_adaption_prompt.py         |  5 ++---
 tests/test_decoder_models.py          |  5 +++++
 tests/test_encoder_decoder_models.py  |  6 ++++++
 tests/test_multitask_prompt_tuning.py |  5 ++---
 tests/test_tuners_utils.py            |  2 +-
 tests/testing_common.py               | 22 ++++++++++++++++------
 6 files changed, 32 insertions(+), 13 deletions(-)

diff --git a/tests/test_adaption_prompt.py b/tests/test_adaption_prompt.py
index 92bbd72017..2607c185e9 100644
--- a/tests/test_adaption_prompt.py
+++ b/tests/test_adaption_prompt.py
@@ -267,9 +267,8 @@ def test_generate(self) -> None:
         # check if `generate` works
         _ = model.generate(input_ids=input_ids, attention_mask=attention_mask)
 
-        with self.assertRaises(TypeError):
-            # check if `generate` raises an error if no positional arguments are passed
-            _ = model.generate(input_ids, attention_mask=attention_mask)
+        # check if `generate` works if positional arguments are passed
+        _ = model.generate(input_ids, attention_mask=attention_mask)
 
     def test_sequence_adapter_ops(self) -> None:
         """Test sequence of adapter operations."""
diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py
index ab49c3eea5..f7bf9efc28 100644
--- a/tests/test_decoder_models.py
+++ b/tests/test_decoder_models.py
@@ -195,6 +195,11 @@ def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs):
     def test_generate(self, test_name, model_id, config_cls, config_kwargs):
         self._test_generate(model_id, config_cls, config_kwargs)
 
+    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
+    def test_generate_pos_args(self, test_name, model_id, config_cls, config_kwargs):
+        # positional args are supported for PeftModelForCausalLM
+        self._test_generate_pos_args(model_id, config_cls, config_kwargs, raises_err=False)
+
     @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
     def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs):
         self._test_merge_layers_fp16(model_id, config_cls, config_kwargs)
diff --git a/tests/test_encoder_decoder_models.py b/tests/test_encoder_decoder_models.py
index 8aab9ed044..3669f712b4 100644
--- a/tests/test_encoder_decoder_models.py
+++ b/tests/test_encoder_decoder_models.py
@@ -104,6 +104,12 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
     def test_generate(self, test_name, model_id, config_cls, config_kwargs):
         self._test_generate(model_id, config_cls, config_kwargs)
 
+    # skip non lora models - generate does not work for prefix tuning, prompt tuning
+    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
+    def test_generate_pos_args(self, test_name, model_id, config_cls, config_kwargs):
+        # positional arguments are not supported for PeftModelForSeq2SeqLM
+        self._test_generate_pos_args(model_id, config_cls, config_kwargs, raises_err=True)
+
     @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
     def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs):
         self._test_generate_half_prec(model_id, config_cls, config_kwargs)
diff --git a/tests/test_multitask_prompt_tuning.py b/tests/test_multitask_prompt_tuning.py
index be548aaa3e..b9e229073a 100644
--- a/tests/test_multitask_prompt_tuning.py
+++ b/tests/test_multitask_prompt_tuning.py
@@ -214,9 +214,8 @@ def test_generate(self) -> None:
         # check if `generate` works
         _ = model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=task_ids)
 
-        with self.assertRaises(TypeError):
-            # check if `generate` raises an error if no positional arguments are passed
-            _ = model.generate(input_ids, attention_mask=attention_mask)
+        # check if `generate` works if positional arguments are passed
+        _ = model.generate(input_ids, attention_mask=attention_mask, task_ids=task_ids)
 
     def test_use_cache(self) -> None:
         """Test that MultiTaskPromptTuning works when Llama config use_cache=True."""
diff --git a/tests/test_tuners_utils.py b/tests/test_tuners_utils.py
index 9678862cc6..d931d795c2 100644
--- a/tests/test_tuners_utils.py
+++ b/tests/test_tuners_utils.py
@@ -24,11 +24,11 @@
 
 from peft import IA3Config, LoHaConfig, LoraConfig, get_peft_model
 from peft.tuners.tuners_utils import (
-    INCLUDE_LINEAR_LAYERS_SHORTHAND,
     _maybe_include_all_linear_layers,
     check_target_module_exists,
     inspect_matched_modules,
 )
+from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND
 
 from .testing_utils import require_bitsandbytes, require_torch_gpu
 
diff --git a/tests/testing_common.py b/tests/testing_common.py
index 1c3bf8ebb1..8faceccbbe 100644
--- a/tests/testing_common.py
+++ b/tests/testing_common.py
@@ -650,8 +650,22 @@ def _test_generate(self, model_id, config_cls, config_kwargs):
         # check if `generate` works
         _ = model.generate(**inputs)
 
-        with self.assertRaises(TypeError):
-            # check if `generate` raises an error if no positional arguments are passed
+    def _test_generate_pos_args(self, model_id, config_cls, config_kwargs, raises_err: bool):
+        model = self.transformers_class.from_pretrained(model_id)
+        config = config_cls(
+            base_model_name_or_path=model_id,
+            **config_kwargs,
+        )
+        model = get_peft_model(model, config)
+        model = model.to(self.torch_device)
+
+        inputs = self.prepare_inputs_for_testing()
+        if raises_err:
+            with self.assertRaises(TypeError):
+                # check if `generate` raises an error if positional arguments are passed
+                _ = model.generate(inputs["input_ids"])
+        else:
+            # check if `generate` works if positional arguments are passed
             _ = model.generate(inputs["input_ids"])
 
     def _test_generate_half_prec(self, model_id, config_cls, config_kwargs):
@@ -672,10 +686,6 @@ def _test_generate_half_prec(self, model_id, config_cls, config_kwargs):
         # check if `generate` works
         _ = model.generate(input_ids=input_ids, attention_mask=attention_mask)
 
-        with self.assertRaises(TypeError):
-            # check if `generate` raises an error if no positional arguments are passed
-            _ = model.generate(input_ids, attention_mask=attention_mask)
-
     def _test_prefix_tuning_half_prec_conversion(self, model_id, config_cls, config_kwargs):
         if config_cls not in (PrefixTuningConfig,):
             return
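
Usage sketch (not part of the patches above): a minimal example of the behaviour the updated tests assert, assuming a PEFT build that includes PATCH 1/2. The "gpt2" checkpoint, the prompt text, and the LoraConfig settings are illustrative placeholders and not taken from the patch.

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import LoraConfig, get_peft_model

    # hypothetical setup; any causal LM checkpoint supported by PEFT would do
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    base_model = AutoModelForCausalLM.from_pretrained("gpt2")
    model = get_peft_model(base_model, LoraConfig(task_type="CAUSAL_LM"))

    inputs = tokenizer("Hello", return_tensors="pt")

    # Before the patch, the patched generate method only forwarded keyword
    # arguments to base_model.generate; with *args forwarded as well, passing
    # input_ids positionally works for PeftModelForCausalLM
    # (see test_decoder_models.py above).
    _ = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=5)

    # PeftModelForSeq2SeqLM still does not accept positional arguments, so the
    # equivalent call on a seq2seq PEFT model is expected to raise TypeError
    # (see test_encoder_decoder_models.py above).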