Backport: Fix activation lookup with Python 3.12.3 (#375) (#377)
* Fix activation lookup with Python 3.12.3 (#375)

We used the metaclass `EnumMeta`/`EnumType` to override reporting of
missing enum values (to give the full set of supported activations).
However, in Python 3.12.3, the default value of the `names` parameter of
the `EnumType.__call__` method was changed from `None` to `_not_given`:

python/cpython@d771729

Even though this is a public API (which now uses a private default
value), it seems too risky to continue using it. So in this change, we
implement `Enum._missing_` instead for the improved error reporting (a
minimal standalone sketch follows below).

* Set version to 1.3.2

* Adjust two cross-tests for changes in HF transformers (#367)

* Fix `test_rotary_embeddings_against_hf` for latest transformers

* xfail test because HfFileSystem is currently broken
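A minimal standalone sketch of the `_missing_` approach described above. The enum members here are an illustrative subset, not the library's full list of activations; the point is that overriding `_missing_` on a plain `Enum` reproduces the friendlier error message without touching `EnumMeta.__call__`, so it is unaffected by the Python 3.12.3 change.

from enum import Enum

class Activation(Enum):
    # Illustrative subset of activation names (the library defines more).
    GELU = "gelu"
    ReLU = "relu"
    SiLU = "silu"

    @classmethod
    def _missing_(cls, value):
        # Called by Enum when `cls(value)` matches no member; raising a
        # ValueError here replaces the generic "is not a valid" message.
        supported_activations = ", ".join(sorted(v.value for v in cls))
        raise ValueError(
            f"Invalid activation function `{value}`. "
            f"Supported functions: {supported_activations}"
        )

Activation("gelu")   # -> Activation.GELU
Activation("geluu")  # ValueError: Invalid activation function `geluu`. Supported functions: gelu, relu, silu

Note that `_missing_` may also return an existing member (e.g. to support aliases); here it only raises, matching the behavior of the old metaclass override.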
danieldk authored Apr 17, 2024
1 parent b192987 commit 3e6180f
Showing 4 changed files with 14 additions and 43 deletions.
51 changes: 10 additions & 41 deletions curated_transformers/layers/activations.py
@@ -1,52 +1,13 @@
import math
from enum import Enum, EnumMeta
from enum import Enum
from typing import Type

import torch
from torch import Tensor
from torch.nn import Module


class _ActivationMeta(EnumMeta):
    """
    ``Enum`` metaclass to override the class ``__call__`` method with a more
    fine-grained exception for unknown activation functions.
    """

    def __call__(
        cls,
        value,
        names=None,
        *,
        module=None,
        qualname=None,
        type=None,
        start=1,
    ):
        # Wrap superclass __call__ to give a nicer error message when
        # an unknown activation is used.
        if names is None:
            try:
                return EnumMeta.__call__(
                    cls,
                    value,
                    names,
                    module=module,
                    qualname=qualname,
                    type=type,
                    start=start,
                )
            except ValueError:
                supported_activations = ", ".join(sorted(v.value for v in cls))
                raise ValueError(
                    f"Invalid activation function `{value}`. "
                    f"Supported functions: {supported_activations}"
                )
        else:
            return EnumMeta.__call__(cls, value, names, module, qualname, type, start)


class Activation(Enum, metaclass=_ActivationMeta):
class Activation(Enum):
"""
Activation functions.
@@ -71,6 +32,14 @@ class Activation(Enum, metaclass=_ActivationMeta):
    #: Sigmoid Linear Unit (`Hendrycks et al., 2016`_).
    SiLU = "silu"

    @classmethod
    def _missing_(cls, value):
        supported_activations = ", ".join(sorted(v.value for v in cls))
        raise ValueError(
            f"Invalid activation function `{value}`. "
            f"Supported functions: {supported_activations}"
        )

    @property
    def module(self) -> Type[torch.nn.Module]:
        """
3 changes: 2 additions & 1 deletion curated_transformers/tests/layers/test_embeddings.py
@@ -24,7 +24,8 @@ def test_rotary_embeddings_against_hf(device):

    X = torch.rand(16, 12, 64, 768, device=device)
    Y = re(X)
    hf_re_cos, hf_re_sin = hf_re(X, seq_len=X.shape[-2])
    positions = torch.arange(X.shape[2], device=device).view([1, -1])
    hf_re_cos, hf_re_sin = hf_re(X, positions)
    Y_hf = hf_re_cos * X + hf_re_sin * rotate_half(X)

    torch_assertclose(Y, Y_hf)
1 change: 1 addition & 0 deletions curated_transformers/tests/tokenizers/test_hf_hub.py
@@ -51,6 +51,7 @@ def test_from_hf_hub_to_cache_legacy():
        )


@pytest.mark.xfail(reason="HfFileSystem calls safetensors with incorrect arguments")
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
def test_fsspec(sample_texts):
    # We only test one model, since using fsspec downloads the model
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,5 +1,5 @@
[metadata]
version = 1.3.1
version = 1.3.2
description = A PyTorch library of transformer models and components
url = https://github.com/explosion/curated-transformers
author = Explosion
