Skip to content

Commit

Permalink
fix config
Browse files Browse the repository at this point in the history
  • Loading branch information
mayank31398 committed Sep 13, 2024
1 parent b5abefb commit f3283a1
Showing 1 changed file with 3 additions and 20 deletions.
23 changes: 3 additions & 20 deletions src/transformers/models/granitemoe/configuration_granitemoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"""GraniteMoe model configuration"""

from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging


Expand Down Expand Up @@ -165,7 +166,7 @@ def __init__(
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_validation()

self.attention_bias = attention_bias
self.attention_dropout = attention_dropout

Expand All @@ -187,22 +188,4 @@ def __init__(
**kwargs,
)

def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
rope_config_validation(self)

0 comments on commit f3283a1

Please sign in to comment.