Skip to content

Commit

Permalink
Add get model config tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ltoniazzi authored and Lorenzo Toniazzi committed Aug 25, 2024
1 parent eca05db commit cf4bf3e
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions tests/test_tuners_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@
get_peft_model,
)
from peft.tuners.tuners_utils import (
BaseTuner,
BaseTunerLayer,
_maybe_include_all_linear_layers,
check_target_module_exists,
inspect_matched_modules,
)
from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, ModulesToSaveWrapper, infer_device
from peft.utils.constants import DUMMY_MODEL_CONFIG

from .testing_utils import require_bitsandbytes, require_non_cpu, require_torch_gpu

Expand Down Expand Up @@ -1065,3 +1067,39 @@ def forward(self, X):

with pytest.raises(TypeError, match="get_model_status is not supported for PeftMixedModel"):
model.get_model_status()


# Tests for BaseTuner
class MockModelConfig:
config = {"mock_key": "mock_value"}

def to_dict(self):
return self.config


class ModelWithConfig(nn.Module):
def __init__(self):
self.config = MockModelConfig()


class ModelWithDictConfig(nn.Module):
def __init__(self):
self.config = MockModelConfig.config


class ModelWithNoConfig(nn.Module):
pass


class TestBaseTunerMethods(unittest.TestCase):
def test_get_model_config_use_to_dict(self):
config = BaseTuner.get_model_config(ModelWithConfig())
assert config == MockModelConfig.config

def test_get_model_config_as_dict(self):
config = BaseTuner.get_model_config(ModelWithDictConfig())
assert config == MockModelConfig.config

def test_get_model_config_with_no_config(self):
config = BaseTuner.get_model_config(ModelWithNoConfig())
assert config == DUMMY_MODEL_CONFIG

0 comments on commit cf4bf3e

Please sign in to comment.