Skip to content

Commit

Permalink
Decouple tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ltoniazzi authored and Lorenzo Toniazzi committed Aug 26, 2024
1 parent 7f290ea commit 7926888
Showing 1 changed file with 39 additions and 77 deletions.
116 changes: 39 additions & 77 deletions tests/test_tuners_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# limitations under the License.
import re
import unittest
import warnings
from copy import deepcopy

import pytest
Expand Down Expand Up @@ -1092,7 +1091,7 @@ class ModelWithNoConfig(nn.Module):
pass


class TestBaseTunerMethods(unittest.TestCase):
class TestBaseTunerGetModelConfig(unittest.TestCase):
def test_get_model_config_use_to_dict(self):
config = BaseTuner.get_model_config(ModelWithConfig())
assert config == MockModelConfig.config
Expand All @@ -1105,82 +1104,45 @@ def test_get_model_config_with_no_config(self):
config = BaseTuner.get_model_config(ModelWithNoConfig())
assert config == DUMMY_MODEL_CONFIG

def test_warn_for_tied_embeddings_inject_and_merge(self):
model_id = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
warn_start = "Model with `tie_word_embeddings=True` and the tied_target_modules=['lm_head']"
warn_end_inject = "huggingface/peft/issues/2018."
warn_end_merge = "tie_word_embeddings=False)\n```\n"

def assert_warning_triggered(records, warn_start, warn_end, triggered=True):
warning_triggered = False
for record in records:
if str(record.message).startswith(warn_start) and str(record.message).endswith(warn_end):
warning_triggered = True
if triggered:
assert warning_triggered
else:
assert not warning_triggered

# Capture warning when loading model and merging with tie_word_embeddings and relevant target module
with pytest.warns(UserWarning) as records_inject:
model = get_peft_model(
AutoModelForCausalLM.from_pretrained(model_id, tie_word_embeddings=True),
LoraConfig(target_modules=["lm_head"]),
)
with pytest.warns(UserWarning) as records_merge:
model.merge_and_unload(safe_merge=True, adapter_names=["default"])

assert_warning_triggered(
records_inject,
warn_start=warn_start,
warn_end=warn_end_inject,
)
assert_warning_triggered(
records_merge,
warn_start=warn_start,
warn_end=warn_end_merge,
)
class TestBaseTunerWarnForTiedEmbeddings:
model_id = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
warn_end_inject = "huggingface/peft/issues/2018."
warn_end_merge = "tie_word_embeddings=False)\n```\n"

# No warning when loading model with no tie_word_embeddings although a relevant target module
with warnings.catch_warnings(record=True) as records_inject_not_tied:
model_not_tied = get_peft_model(
AutoModelForCausalLM.from_pretrained(model_id, tie_word_embeddings=False),
LoraConfig(target_modules=["lm_head"]),
)
with warnings.catch_warnings(record=True) as records_merge_not_tied:
model_not_tied.merge_and_unload(safe_merge=True, adapter_names=["default"])

assert_warning_triggered(
records_inject_not_tied,
warn_start=warn_start,
warn_end=warn_end_inject,
triggered=False,
)
assert_warning_triggered(
records_merge_not_tied,
warn_start=warn_start,
warn_end=warn_end_merge,
triggered=False,
def _get_peft_model(self, tie_word_embeddings, target_module):
model = get_peft_model(
AutoModelForCausalLM.from_pretrained(self.model_id, tie_word_embeddings=tie_word_embeddings),
LoraConfig(target_modules=[target_module]),
)
return model

# No warning when loading model with tie_word_embeddings but not relevant target module
with warnings.catch_warnings(record=True) as records_inject_tied_no_target:
model_not_tied = get_peft_model(
AutoModelForCausalLM.from_pretrained(model_id, tie_word_embeddings=False),
LoraConfig(target_modules=["q_proj"]),
)
with warnings.catch_warnings(record=True) as records_merge_tied_no_target:
model_not_tied.merge_and_unload(safe_merge=True, adapter_names=["default"])

assert_warning_triggered(
records_inject_tied_no_target,
warn_start=warn_start,
warn_end=warn_end_inject,
triggered=False,
)
assert_warning_triggered(
records_merge_tied_no_target,
warn_start=warn_start,
warn_end=warn_end_merge,
triggered=False,
)
def _is_warn_triggered(self, rrecwarn, endswith):
return any(str(warning.message).endswith(endswith) for warning in rrecwarn.list)

def test_warn_for_tied_embeddings_inject(self, recwarn):
self._get_peft_model(tie_word_embeddings=True, target_module="lm_head")
assert self._is_warn_triggered(recwarn, self.warn_end_inject)

def test_warn_for_tied_embeddings_merge(self, recwarn):
model = self._get_peft_model(tie_word_embeddings=True, target_module="lm_head")
model.merge_and_unload()
assert self._is_warn_triggered(recwarn, self.warn_end_merge)

def test_no_warn_for_untied_embeddings_inject(self, recwarn):
self._get_peft_model(tie_word_embeddings=False, target_module="lm_head")
assert not self._is_warn_triggered(recwarn, self.warn_end_inject)

def test_no_warn_for_untied_embeddings_merge(self, recwarn):
model_not_tied = self._get_peft_model(tie_word_embeddings=False, target_module="lm_head")
model_not_tied.merge_and_unload()
assert not self._is_warn_triggered(recwarn, self.warn_end_merge)

def test_no_warn_for_no_target_module_inject(self, recwarn):
self._get_peft_model(tie_word_embeddings=True, target_module="q_proj")
assert not self._is_warn_triggered(recwarn, self.warn_end_inject)

def test_no_warn_for_no_target_module_merge(self, recwarn):
model_no_target_module = self._get_peft_model(tie_word_embeddings=True, target_module="q_proj")
model_no_target_module.merge_and_unload()
assert not self._is_warn_triggered(recwarn, self.warn_end_merge)

0 comments on commit 7926888

Please sign in to comment.