Skip to content

Commit 32d0be5

Browse files
authored
Introduce hook factories for enums (#705)
* Introduce hook factories for enums * Fix * Fix * Fix * Fix * Fix * Remove dead code * Add FunctionDispatch test * Tweak factory registration * Add PR link
1 parent 6941413 commit 32d0be5

File tree

12 files changed

+80
-31
lines changed

12 files changed

+80
-31
lines changed

HISTORY.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ Our backwards-compatibility policy can be found [here](https://github.com/python
2929
([#707](https://github.com/python-attrs/cattrs/issues/707) [#708](https://github.com/python-attrs/cattrs/pull/708))
3030
- The {mod}`tomlkit <cattrs.preconf.tomlkit>` preconf converter now passes date objects directly to _tomlkit_ for unstructuring.
3131
([#707](https://github.com/python-attrs/cattrs/issues/707) [#708](https://github.com/python-attrs/cattrs/pull/708))
32-
32+
- Enum handling has been optimized by switching to hook factories, improving performance especially for plain enums.
33+
([#705](https://github.com/python-attrs/cattrs/pull/705))
3334

3435
## 25.3.0 (2025-10-07)
3536

src/cattrs/converters.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
is_mutable_set,
4848
is_optional,
4949
is_protocol,
50+
is_subclass,
5051
is_tuple,
5152
is_typeddict,
5253
is_union_type,
@@ -76,6 +77,7 @@
7677
UnstructuredValue,
7778
UnstructureHook,
7879
)
80+
from .enums import enum_structure_factory, enum_unstructure_factory
7981
from .errors import (
8082
IterableValidationError,
8183
IterableValidationNote,
@@ -246,12 +248,12 @@ def __init__(
246248
lambda t: self.get_unstructure_hook(get_type_alias_base(t)),
247249
True,
248250
),
249-
(is_literal_containing_enums, self.unstructure),
250251
(is_mapping, self._unstructure_mapping),
251252
(is_sequence, self._unstructure_seq),
252253
(is_mutable_set, self._unstructure_seq),
253254
(is_frozenset, self._unstructure_seq),
254-
(lambda t: issubclass(t, Enum), self._unstructure_enum),
255+
(is_literal_containing_enums, self.unstructure),
256+
(lambda t: is_subclass(t, Enum), enum_unstructure_factory, "extended"),
255257
(has, self._unstructure_attrs),
256258
(is_union_type, self._unstructure_union),
257259
(lambda t: t in ANIES, self.unstructure),
@@ -298,6 +300,7 @@ def __init__(
298300
self._union_struct_registry.__getitem__,
299301
True,
300302
),
303+
(lambda t: is_subclass(t, Enum), enum_structure_factory, "extended"),
301304
(has, self._structure_attrs),
302305
]
303306
)
@@ -308,7 +311,6 @@ def __init__(
308311
(bytes, self._structure_call),
309312
(int, self._structure_call),
310313
(float, self._structure_call),
311-
(Enum, self._structure_enum),
312314
(Path, self._structure_call),
313315
]
314316
)
@@ -630,12 +632,6 @@ def unstructure_attrs_astuple(self, obj: Any) -> tuple[Any, ...]:
630632
res.append(dispatch(a.type or v.__class__)(v))
631633
return tuple(res)
632634

633-
def _unstructure_enum(self, obj: Enum) -> Any:
634-
"""Convert an enum to its unstructured value."""
635-
if "_value_" in obj.__class__.__annotations__:
636-
return self._unstructure_func.dispatch(obj.value.__class__)(obj.value)
637-
return obj.value
638-
639635
def _unstructure_seq(self, seq: Sequence[T]) -> Sequence[T]:
640636
"""Convert a sequence to primitive equivalents."""
641637
# We can reuse the sequence class, so tuples stay tuples.
@@ -715,15 +711,6 @@ def _structure_simple_literal(val, type):
715711
raise Exception(f"{val} not in literal {type}")
716712
return val
717713

718-
def _structure_enum(self, val: Any, cl: type[Enum]) -> Enum:
719-
"""Structure ``val`` if possible and return the enum it corresponds to.
720-
721-
Uses type hints for the "_value_" attribute if they exist to structure
722-
the enum values before returning the result."""
723-
if "_value_" in cl.__annotations__:
724-
val = self.structure(val, cl.__annotations__["_value_"])
725-
return cl(val)
726-
727714
@staticmethod
728715
def _structure_enum_literal(val, type):
729716
vals = {(x.value if isinstance(x, Enum) else x): x for x in type.__args__}

src/cattrs/enums.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from collections.abc import Callable
2+
from enum import Enum
3+
from typing import TYPE_CHECKING, Any
4+
5+
if TYPE_CHECKING:
6+
from .converters import BaseConverter
7+
8+
9+
def enum_unstructure_factory(
10+
type: type[Enum], converter: "BaseConverter"
11+
) -> Callable[[Enum], Any]:
12+
"""A factory for generating enum unstructure hooks.
13+
14+
If the enum is a typed enum (has `_value_`), we use the underlying value's hook.
15+
Otherwise, we use the value directly.
16+
"""
17+
if "_value_" in type.__annotations__:
18+
return lambda e: converter.unstructure(e.value)
19+
20+
return lambda e: e.value
21+
22+
23+
def enum_structure_factory(
24+
type: type[Enum], converter: "BaseConverter"
25+
) -> Callable[[Any, type[Enum]], Enum]:
26+
"""A factory for generating enum structure hooks.
27+
28+
If the enum is a typed enum (has `_value_`), we structure the value first.
29+
Otherwise, we use the value directly.
30+
"""
31+
if "_value_" in type.__annotations__:
32+
val_type = type.__annotations__["_value_"]
33+
val_hook = converter.get_structure_hook(val_type)
34+
return lambda v, _: type(val_hook(v, val_type))
35+
36+
return lambda v, _: type(v)

src/cattrs/preconf/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from collections.abc import Callable
12
from datetime import datetime
23
from enum import Enum
3-
from typing import Any, Callable, ParamSpec, TypeVar, get_args
4+
from typing import Any, ParamSpec, TypeVar, get_args
45

56
from .._compat import is_subclass
67
from ..converters import Converter, UnstructureHook

src/cattrs/preconf/bson.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,11 @@ def gen_structure_mapping(cl: Any) -> StructureHook:
9999

100100
# datetime inherits from date, so identity unstructure hook used
101101
# here to prevent the date unstructure hook running.
102-
converter.register_unstructure_hook(datetime, lambda v: v)
102+
converter.register_unstructure_hook(datetime, identity)
103103
converter.register_structure_hook(datetime, validate_datetime)
104104
converter.register_unstructure_hook(date, lambda v: v.isoformat())
105105
converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v))
106-
converter.register_unstructure_hook_func(is_primitive_enum, identity)
106+
converter.register_unstructure_hook_factory(is_primitive_enum, lambda t: identity)
107107
converter.register_unstructure_hook_factory(
108108
is_literal_containing_enums, literals_with_enums_unstructure_factory
109109
)

src/cattrs/preconf/cbor2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def configure_converter(converter: BaseConverter):
3737
)
3838
converter.register_unstructure_hook(date, lambda v: v.isoformat())
3939
converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v))
40-
converter.register_unstructure_hook_func(is_primitive_enum, identity)
40+
converter.register_unstructure_hook_factory(is_primitive_enum, lambda t: identity)
4141
converter.register_unstructure_hook_factory(
4242
is_literal_containing_enums, literals_with_enums_unstructure_factory
4343
)

src/cattrs/preconf/json.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def configure_converter(converter: BaseConverter) -> None:
5252
converter.register_unstructure_hook_factory(
5353
is_literal_containing_enums, literals_with_enums_unstructure_factory
5454
)
55-
converter.register_unstructure_hook_func(is_primitive_enum, identity)
55+
converter.register_unstructure_hook_factory(is_primitive_enum, lambda _: identity)
5656
configure_union_passthrough(Union[str, bool, int, float, None], converter)
5757

5858

src/cattrs/preconf/msgpack.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def configure_converter(converter: BaseConverter) -> None:
4646
converter.register_structure_hook(
4747
date, lambda v, _: datetime.fromtimestamp(v, timezone.utc).date()
4848
)
49-
converter.register_unstructure_hook_func(is_primitive_enum, identity)
49+
converter.register_unstructure_hook_factory(is_primitive_enum, lambda t: identity)
5050
converter.register_unstructure_hook_factory(
5151
is_literal_containing_enums, literals_with_enums_unstructure_factory
5252
)

src/cattrs/preconf/msgspec.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,27 @@
33
from __future__ import annotations
44

55
from base64 import b64decode
6+
from collections.abc import Callable
67
from dataclasses import is_dataclass
78
from datetime import date, datetime
89
from enum import Enum
910
from functools import partial
10-
from typing import Any, Callable, TypeVar, Union, get_type_hints
11+
from typing import Any, TypeVar, Union, get_type_hints
1112

1213
from attrs import has as attrs_has
1314
from attrs import resolve_types
1415
from msgspec import Struct, convert, to_builtins
1516
from msgspec.json import Encoder, decode
1617

17-
from .._compat import fields, get_args, get_origin, is_bare, is_mapping, is_sequence
18+
from .._compat import (
19+
fields,
20+
get_args,
21+
get_origin,
22+
is_bare,
23+
is_mapping,
24+
is_sequence,
25+
is_subclass,
26+
)
1827
from ..cols import is_namedtuple
1928
from ..converters import BaseConverter, Converter
2029
from ..dispatch import UnstructureHook
@@ -74,7 +83,9 @@ def configure_converter(converter: Converter) -> None:
7483
configure_passthroughs(converter)
7584

7685
converter.register_unstructure_hook(Struct, to_builtins)
77-
converter.register_unstructure_hook(Enum, identity)
86+
converter.register_unstructure_hook_factory(
87+
lambda t: is_subclass(t, Enum), lambda t, c: identity
88+
)
7889

7990
converter.register_structure_hook(Struct, convert)
8091
converter.register_structure_hook(bytes, lambda v, _: b64decode(v))

src/cattrs/preconf/orjson.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ def key_handler(v):
8787
),
8888
]
8989
)
90-
converter.register_unstructure_hook_func(
91-
partial(is_primitive_enum, include_bare_enums=True), identity
90+
converter.register_unstructure_hook_factory(
91+
partial(is_primitive_enum, include_bare_enums=True), lambda t: identity
9292
)
9393
converter.register_unstructure_hook_factory(
9494
is_literal_containing_enums, literals_with_enums_unstructure_factory

0 commit comments

Comments
 (0)