Skip to content

Commit e7e7e9d

Browse files
authored
Fix activation lookup with Python 3.12.3 (#375)
We used the metaclass `EnumMeta`/`EnumType` to override reporting of missing enum values (to give the full set of supported activations). However, in Python 3.12.3, the default value of the `name` parameter of `EnumType.__call__` method was changed from `None` to `_not_given`: python/cpython@d771729 Even though this is a public API (which now uses a private default value), it seems too risky to continue using it. So in this change, we implement `Enum.__mising__` instead for the improved error reporting.
1 parent 8debb21 commit e7e7e9d

File tree

1 file changed

+10
-41
lines changed

1 file changed

+10
-41
lines changed

curated_transformers/layers/activations.py

Lines changed: 10 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,13 @@
11
import math
2-
from enum import Enum, EnumMeta
2+
from enum import Enum
33
from typing import Type
44

55
import torch
66
from torch import Tensor
77
from torch.nn import Module
88

99

10-
class _ActivationMeta(EnumMeta):
11-
"""
12-
``Enum`` metaclass to override the class ``__call__`` method with a more
13-
fine-grained exception for unknown activation functions.
14-
"""
15-
16-
def __call__(
17-
cls,
18-
value,
19-
names=None,
20-
*,
21-
module=None,
22-
qualname=None,
23-
type=None,
24-
start=1,
25-
):
26-
# Wrap superclass __call__ to give a nicer error message when
27-
# an unknown activation is used.
28-
if names is None:
29-
try:
30-
return EnumMeta.__call__(
31-
cls,
32-
value,
33-
names,
34-
module=module,
35-
qualname=qualname,
36-
type=type,
37-
start=start,
38-
)
39-
except ValueError:
40-
supported_activations = ", ".join(sorted(v.value for v in cls))
41-
raise ValueError(
42-
f"Invalid activation function `{value}`. "
43-
f"Supported functions: {supported_activations}"
44-
)
45-
else:
46-
return EnumMeta.__call__(cls, value, names, module, qualname, type, start)
47-
48-
49-
class Activation(Enum, metaclass=_ActivationMeta):
10+
class Activation(Enum):
5011
"""
5112
Activation functions.
5213
@@ -71,6 +32,14 @@ class Activation(Enum, metaclass=_ActivationMeta):
7132
#: Sigmoid Linear Unit (`Hendrycks et al., 2016`_).
7233
SiLU = "silu"
7334

35+
@classmethod
36+
def _missing_(cls, value):
37+
supported_activations = ", ".join(sorted(v.value for v in cls))
38+
raise ValueError(
39+
f"Invalid activation function `{value}`. "
40+
f"Supported functions: {supported_activations}"
41+
)
42+
7443
@property
7544
def module(self) -> Type[torch.nn.Module]:
7645
"""

0 commit comments

Comments
 (0)