From 2a2ca859fe0d32163ca7b12e1e237144aab19819 Mon Sep 17 00:00:00 2001
From: raushan
Date: Thu, 26 Sep 2024 16:03:04 +0200
Subject: [PATCH 1/3] ignore keys on check rope

---
 src/transformers/modeling_rope_utils.py       | 40 ++++++++++++-------
 .../models/qwen2_vl/configuration_qwen2_vl.py |  6 ++-
 2 files changed, 29 insertions(+), 17 deletions(-)

diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py
index e7aa1ceb921329..150c402c5e1371 100644
--- a/src/transformers/modeling_rope_utils.py
+++ b/src/transformers/modeling_rope_utils.py
@@ -360,13 +360,23 @@ def _compute_llama3_parameters(
 }
 
 
-def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
+def _check_received_keys(
+    rope_type: str,
+    received_keys: set,
+    required_keys: set,
+    optional_keys: Optional[set] = None,
+    ignore_keys: Optional[set] = None,
+):
     """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
     # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
     if "type" in received_keys:
         received_keys -= {"type"}
         required_keys.add("rope_type")
 
+    # Some models need to store model-specific keys, and we don't want to throw a warning for them
+    if ignore_keys is not None:
+        received_keys -= ignore_keys
+
     missing_keys = required_keys - received_keys
     if missing_keys:
         raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
@@ -379,47 +389,47 @@ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set,
         logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
 
 
-def _validate_default_rope_parameters(config: PretrainedConfig):
+def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
 
 
-def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
+def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
 
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
         logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
 
-def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
+def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
     # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
     optional_keys = {"original_max_position_embeddings"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
 
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
         logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
 
-def _validate_yarn_parameters(config: PretrainedConfig):
+def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
     optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
 
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
@@ -444,14 +454,14 @@ def _validate_yarn_parameters(config: PretrainedConfig):
         )
 
 
-def _validate_longrope_parameters(config: PretrainedConfig):
+def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "short_factor", "long_factor"}
     # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
     optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
 
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
@@ -494,12 +504,12 @@ def _validate_longrope_parameters(config: PretrainedConfig):
         )
 
 
-def _validate_llama3_parameters(config: PretrainedConfig):
+def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
 
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
@@ -541,7 +551,7 @@ def _validate_llama3_parameters(config: PretrainedConfig):
 }
 
 
-def rope_config_validation(config: PretrainedConfig):
+def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     """
     Validate the RoPE config arguments, given a `PretrainedConfig` object
     """
@@ -553,7 +563,7 @@ def rope_config_validation(config: PretrainedConfig):
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
     validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
     if validation_fn is not None:
-        validation_fn(config)
+        validation_fn(config, ignore_keys=ignore_keys)
     else:
         logger.warning(
             f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
         )
diff --git a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py
index 27615eb789f0b0..1349006e768cd4 100644
--- a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py
@@ -235,11 +235,13 @@ def __init__(
 
         # Validate the correctness of rotary position embeddings parameters
         # BC: if there is a 'type' field, move it to 'rope_type'.
-        # and change type from 'mrope' to 'default'
+        # and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
+        # one can set it to "linear"/"dynamic" etc. to have scaled RoPE
+        # TODO: @raushan update config in the hub
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            if self.rope_scaling["type"] == "mrope":
                self.rope_scaling["type"] = "default"
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
-        rope_config_validation(self)
+        rope_config_validation(self, ignore_keys={"mrope_section"})
 
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
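
The Qwen2-VL change above is the intended usage pattern for the new `ignore_keys` argument: the config keeps a model-specific entry inside `rope_scaling` and excludes it from the strict key check. A minimal sketch of the same wiring for a hypothetical config class (the class name and default values are illustrative, not part of the patch, and it assumes this patch is applied):

    from transformers import PretrainedConfig
    from transformers.modeling_rope_utils import rope_config_validation


    class MyMultimodalTextConfig(PretrainedConfig):
        model_type = "my_multimodal_text"

        def __init__(self, rope_scaling=None, **kwargs):
            if rope_scaling is None:
                # "mrope_section" is consumed by the model's rotary embedding, not by the validator
                rope_scaling = {"rope_type": "default", "mrope_section": [16, 24, 24]}
            self.rope_scaling = rope_scaling
            # Skip the "unrecognized keys" check for the model-specific key only
            rope_config_validation(self, ignore_keys={"mrope_section"})
            super().__init__(**kwargs)
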
From 75098338cc54a323f48845a3fad3075f7351b886 Mon Sep 17 00:00:00 2001
From: raushan
Date: Mon, 30 Sep 2024 10:49:51 +0200
Subject: [PATCH 2/3] add tests

---
 src/transformers/modeling_rope_utils.py | 45 +++++++++++++------------
 tests/utils/test_modeling_rope_utils.py | 13 +++++++
 2 files changed, 36 insertions(+), 22 deletions(-)

diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py
index 150c402c5e1371..c157aa30d3427c 100644
--- a/src/transformers/modeling_rope_utils.py
+++ b/src/transformers/modeling_rope_utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import math
+import warnings
 from typing import Optional, Tuple
 
 from .configuration_utils import PretrainedConfig
@@ -386,7 +387,7 @@ def _check_received_keys(
     else:
         unused_keys = received_keys - required_keys
     if unused_keys:
-        logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
+        warnings.warn(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
 
 
 def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
@@ -406,7 +407,7 @@ def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_ke
 
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
-        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+        warnings.warn(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
 
 def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
@@ -420,7 +421,7 @@ def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_k
 
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
-        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+        warnings.warn(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
 
 def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
@@ -433,22 +434,22 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[se
 
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
-        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+        warnings.warn(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
     attention_factor = rope_scaling.get("attention_factor")
     if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
-        logger.warning(
+        warnings.warn(
             f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
         )
     beta_fast = rope_scaling.get("beta_fast")
     if beta_fast is not None and not isinstance(beta_fast, float):
-        logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
+        warnings.warn(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
     beta_slow = rope_scaling.get("beta_slow")
     if beta_slow is not None and not isinstance(beta_slow, float):
-        logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
+        warnings.warn(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
 
     if (beta_fast or 32) < (beta_slow or 1):
-        logger.warning(
+        warnings.warn(
             f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
             f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
         )
@@ -469,15 +470,15 @@ def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optiona
 
     short_factor = rope_scaling.get("short_factor")
     if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
-        logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
+        warnings.warn(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
     if not len(short_factor) == dim // 2:
-        logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
+        warnings.warn(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
 
     long_factor = rope_scaling.get("long_factor")
     if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor):
-        logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
+        warnings.warn(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
     if not len(long_factor) == dim // 2:
-        logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
+        warnings.warn(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
 
     # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
     # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
@@ -492,14 +493,14 @@ def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optiona
     else:
         factor = rope_scaling.get("factor")
         if factor is None:
-            logger.warning("Missing required keys in `rope_scaling`: 'factor'")
+            warnings.warn("Missing required keys in `rope_scaling`: 'factor'")
         elif not isinstance(factor, float) or factor < 1.0:
-            logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+            warnings.warn(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
         attention_factor = rope_scaling.get("attention_factor")
         if attention_factor is not None:
             if not isinstance(attention_factor, float) or attention_factor < 0.0:
-                logger.warning(
+                warnings.warn(
                     f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
                 )
 
@@ -513,28 +514,28 @@ def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[
 
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
-        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+        warnings.warn(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
     low_freq_factor = rope_scaling["low_freq_factor"]
     high_freq_factor = rope_scaling["high_freq_factor"]
     if low_freq_factor is None or not isinstance(low_freq_factor, float):
-        logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
+        warnings.warn(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
     if high_freq_factor is None or not isinstance(high_freq_factor, float):
-        logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
+        warnings.warn(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
     if high_freq_factor <= low_freq_factor:
-        logger.warning(
+        warnings.warn(
             "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
             f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
         )
 
     original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
     if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
-        logger.warning(
+        warnings.warn(
             "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
             f"{original_max_position_embeddings}"
         )
     if original_max_position_embeddings >= config.max_position_embeddings:
-        logger.warning(
+        warnings.warn(
             "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
             f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
         )
@@ -565,6 +566,6 @@ def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set]
     if validation_fn is not None:
         validation_fn(config, ignore_keys=ignore_keys)
     else:
-        logger.warning(
+        warnings.warn(
             f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
         )
diff --git a/tests/utils/test_modeling_rope_utils.py b/tests/utils/test_modeling_rope_utils.py
index a1d1fd6b922ab3..ac204d56bedf75 100644
--- a/tests/utils/test_modeling_rope_utils.py
+++ b/tests/utils/test_modeling_rope_utils.py
@@ -16,6 +16,7 @@
 import math
 import unittest
+import warnings
 
 from transformers import LlamaConfig
 from transformers.testing_utils import is_torch_available, require_torch, torch_device
 
@@ -65,6 +66,18 @@ def test_rope_validation(self):
             with self.assertRaises(KeyError):
                 rope_config_validation(config)
 
+        # Any other parameters passed to RoPE will raise a warning that a particular key is not used
+        # But sometimes we can have model-specific RoPE kwargs and bypass the warning with `ignore_keys`
+        model_specific_kwarg = "mrope_sections"  # e.g. in Qwen2-VL
+
+        for rope_type in all_rope_types:
+            if rope_type == "default":
+                config.rope_scaling = {"rope_type": rope_type, model_specific_kwarg: True}
+                rope_config_validation(config, ignore_keys={model_specific_kwarg})
+                with warnings.catch_warnings(record=True) as warning_list:
+                    rope_config_validation(config)
+                    self.assertEqual(len(warning_list), 1)
+
     def test_default_rope_function_bc(self):
         config = LlamaConfig()
         device = torch_device
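
At this point in the series the unused-key message is emitted with `warnings.warn` (the next patch moves it back to the logger), so it can be captured with the standard library alone. A rough sketch, assuming the tree as of this patch; the "unexpected_key" entry is hypothetical:

    import warnings

    from transformers import LlamaConfig
    from transformers.modeling_rope_utils import rope_config_validation

    config = LlamaConfig()
    config.rope_scaling = {"rope_type": "default", "unexpected_key": True}

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # make sure no warning filter swallows the message
        rope_config_validation(config)

    assert any("Unrecognized keys" in str(w.message) for w in caught)
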
From 07e2bcfe83e577c515409719633319f82821aa08 Mon Sep 17 00:00:00 2001
From: raushan
Date: Tue, 1 Oct 2024 10:46:22 +0200
Subject: [PATCH 3/3] fix tests, so maybe better leave at logger lvl

---
 src/transformers/modeling_rope_utils.py | 45 ++++++++++++-------------
 tests/utils/test_modeling_rope_utils.py |  6 ++--
 2 files changed, 25 insertions(+), 26 deletions(-)

diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py
index c157aa30d3427c..150c402c5e1371 100644
--- a/src/transformers/modeling_rope_utils.py
+++ b/src/transformers/modeling_rope_utils.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import math
-import warnings
 from typing import Optional, Tuple
 
 from .configuration_utils import PretrainedConfig
@@ -387,7 +386,7 @@ def _check_received_keys(
     else:
         unused_keys = received_keys - required_keys
     if unused_keys:
-        warnings.warn(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
+        logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
 
 
 def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
@@ -407,7 +406,7 @@ def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_ke
 
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
-        warnings.warn(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
 
 def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
@@ -421,7 +420,7 @@ def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_k
 
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
-        warnings.warn(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
 
 def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
@@ -434,22 +433,22 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[se
 
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
-        warnings.warn(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
     attention_factor = rope_scaling.get("attention_factor")
     if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
-        warnings.warn(
+        logger.warning(
             f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
         )
     beta_fast = rope_scaling.get("beta_fast")
     if beta_fast is not None and not isinstance(beta_fast, float):
-        warnings.warn(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
+        logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
     beta_slow = rope_scaling.get("beta_slow")
     if beta_slow is not None and not isinstance(beta_slow, float):
-        warnings.warn(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
+        logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
 
     if (beta_fast or 32) < (beta_slow or 1):
-        warnings.warn(
+        logger.warning(
             f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
             f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
         )
@@ -470,15 +469,15 @@ def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optiona
 
     short_factor = rope_scaling.get("short_factor")
     if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
-        warnings.warn(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
+        logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
     if not len(short_factor) == dim // 2:
-        warnings.warn(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
+        logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
 
     long_factor = rope_scaling.get("long_factor")
     if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor):
-        warnings.warn(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
+        logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
     if not len(long_factor) == dim // 2:
-        warnings.warn(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
+        logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
 
     # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
     # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
@@ -493,14 +492,14 @@ def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optiona
     else:
         factor = rope_scaling.get("factor")
         if factor is None:
-            warnings.warn("Missing required keys in `rope_scaling`: 'factor'")
+            logger.warning("Missing required keys in `rope_scaling`: 'factor'")
         elif not isinstance(factor, float) or factor < 1.0:
-            warnings.warn(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+            logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
         attention_factor = rope_scaling.get("attention_factor")
         if attention_factor is not None:
             if not isinstance(attention_factor, float) or attention_factor < 0.0:
-                warnings.warn(
+                logger.warning(
                     f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
                 )
 
@@ -514,28 +513,28 @@ def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[
 
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
-        warnings.warn(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
     low_freq_factor = rope_scaling["low_freq_factor"]
     high_freq_factor = rope_scaling["high_freq_factor"]
     if low_freq_factor is None or not isinstance(low_freq_factor, float):
-        warnings.warn(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
+        logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
     if high_freq_factor is None or not isinstance(high_freq_factor, float):
-        warnings.warn(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
+        logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
     if high_freq_factor <= low_freq_factor:
-        warnings.warn(
+        logger.warning(
             "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
             f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
         )
 
     original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
     if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
-        warnings.warn(
+        logger.warning(
             "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
             f"{original_max_position_embeddings}"
         )
     if original_max_position_embeddings >= config.max_position_embeddings:
-        warnings.warn(
+        logger.warning(
             "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
             f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
         )
@@ -566,6 +565,6 @@ def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set]
     if validation_fn is not None:
         validation_fn(config, ignore_keys=ignore_keys)
     else:
-        warnings.warn(
+        logger.warning(
             f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
         )
diff --git a/tests/utils/test_modeling_rope_utils.py b/tests/utils/test_modeling_rope_utils.py
index ac204d56bedf75..d51f534055872a 100644
--- a/tests/utils/test_modeling_rope_utils.py
+++ b/tests/utils/test_modeling_rope_utils.py
@@ -16,7 +16,6 @@
 import math
 import unittest
-import warnings
 
 from transformers import LlamaConfig
 from transformers.testing_utils import is_torch_available, require_torch, torch_device
 
@@ -74,9 +73,10 @@ def test_rope_validation(self):
             if rope_type == "default":
                 config.rope_scaling = {"rope_type": rope_type, model_specific_kwarg: True}
                 rope_config_validation(config, ignore_keys={model_specific_kwarg})
-                with warnings.catch_warnings(record=True) as warning_list:
+                with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
                     rope_config_validation(config)
-                    self.assertEqual(len(warning_list), 1)
+                    self.assertEqual(len(logs.output), 1)
+                    self.assertIn(model_specific_kwarg, logs.output[0])
 
     def test_default_rope_function_bc(self):
         config = LlamaConfig()
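
With the final patch the messages go back through the `transformers` logger, so the behaviour is observed (or silenced) via the library's logging utilities rather than the `warnings` module. A minimal sketch of the end state of the series, with illustrative values:

    from transformers import LlamaConfig
    from transformers.modeling_rope_utils import rope_config_validation
    from transformers.utils import logging

    logging.set_verbosity_warning()

    config = LlamaConfig()
    config.rope_scaling = {"rope_type": "default", "mrope_section": [16, 24, 24]}

    rope_config_validation(config)                                 # logs: Unrecognized keys ... {'mrope_section'}
    rope_config_validation(config, ignore_keys={"mrope_section"})  # no warning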