Skip to content

Commit 5f4e9ee

Browse files
committed
feat(core): add AttrDict class for attribute-style access
Introduces a new dictionary subclass that enables both traditional dictionary access and attribute-style access patterns. Supports nested merging operations, type safety with generics, and recursive conversion of nested dictionaries. Provides convenient methods for deep copying and converting back to regular dictionaries.
1 parent 8597eba commit 5f4e9ee

File tree

2 files changed

+130
-0
lines changed

2 files changed

+130
-0
lines changed

src/polykit/core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def get_setting(self, key):
3434

3535
from __future__ import annotations
3636

37+
from .attr_dict import AttrDict
3738
from .decorators import async_retry_on_exception, retry_on_exception, with_retries
3839
from .detect import platform_check
3940
from .setup import polykit_setup

src/polykit/core/attr_dict.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# type: ignore[reportArgumentType, reportIncompatibleMethodOverride, reportOperatorIssue]
2+
3+
from __future__ import annotations
4+
5+
from collections.abc import Iterable, Iterator, Mapping, MutableMapping
6+
from typing import Any, TypeVar
7+
8+
# Type variables for keys and values
9+
KT = TypeVar("KT")
10+
VT = TypeVar("VT")
11+
12+
13+
class AttrDict[KT, VT](MutableMapping[KT, VT]):
14+
"""A dictionary that allows for attribute-style access."""
15+
16+
def __init__(self, *args: Mapping[KT, VT] | Iterable[tuple[KT, VT]], **kwargs: Any):
17+
self._data: dict[KT, VT] = {}
18+
self.update(dict(*args, **kwargs))
19+
20+
def __setitem__(self, key: KT, value: VT) -> None:
21+
self._data[key] = self._convert(value)
22+
23+
def __getitem__(self, key: KT) -> VT:
24+
return self._data[key]
25+
26+
def __delitem__(self, key: KT) -> None:
27+
del self._data[key]
28+
29+
def __iter__(self) -> Iterator[KT]:
30+
return iter(self._data)
31+
32+
def __len__(self) -> int:
33+
return len(self._data)
34+
35+
def __repr__(self) -> str:
36+
return f"AttrDict({self._data})"
37+
38+
def __eq__(self, other: object) -> bool:
39+
if isinstance(other, AttrDict):
40+
return self._data == other._data
41+
return self._data == other if isinstance(other, dict) else False
42+
43+
def __getattr__(self, name: str) -> Any:
44+
try:
45+
return self[name]
46+
except KeyError as e:
47+
msg = f"'AttrDict' object has no attribute '{name}'"
48+
raise AttributeError(msg) from e
49+
50+
def __setattr__(self, name: str, value: Any) -> None:
51+
if name == "_data":
52+
super().__setattr__(name, value)
53+
else:
54+
self[name] = value
55+
56+
def __dir__(self) -> list[str]:
57+
return list(set(super().__dir__()) | {str(k) for k in self._data})
58+
59+
def __or__(self, other: Mapping[KT, VT] | Iterable[tuple[KT, VT]] | None) -> AttrDict[KT, VT]:
60+
"""Implement the | operator for AttrDict with nested merging."""
61+
if other is None:
62+
return self.copy()
63+
64+
result = self.copy()
65+
other_dict = dict(other)
66+
67+
for key, value in other_dict.items():
68+
if (
69+
key in result
70+
and isinstance(result[key], AttrDict)
71+
and isinstance(value, dict | AttrDict)
72+
):
73+
result[key] |= AttrDict(value)
74+
else:
75+
result[key] = self._convert(value)
76+
77+
return result
78+
79+
def __ror__(self, other: Mapping[KT, VT] | Iterable[tuple[KT, VT]] | None) -> AttrDict[KT, VT]:
80+
"""Implement reverse | operator for AttrDict with nested merging."""
81+
return self.copy() if other is None else AttrDict(other) | self
82+
83+
@classmethod
84+
def _convert(cls, value: Any) -> Any:
85+
if isinstance(value, Mapping) and not isinstance(value, AttrDict):
86+
return cls(value)
87+
return [cls._convert(v) for v in value] if isinstance(value, list) else value
88+
89+
def to_dict(self) -> dict[KT, Any]:
90+
"""Convert AttrDict to a regular dictionary recursively."""
91+
92+
def _to_dict(value: Any) -> Any:
93+
if isinstance(value, AttrDict):
94+
return value.to_dict()
95+
return [_to_dict(v) for v in value] if isinstance(value, list) else value
96+
97+
return {k: _to_dict(v) for k, v in self._data.items()}
98+
99+
def copy(self) -> AttrDict[KT, VT]:
100+
"""Return a shallow copy of the AttrDict."""
101+
return AttrDict(self._data)
102+
103+
def deep_copy(self) -> AttrDict[KT, VT]:
104+
"""Return a deep copy of the AttrDict."""
105+
from copy import deepcopy
106+
107+
return AttrDict(deepcopy(self._data))
108+
109+
def update(self, *args: Mapping[KT, VT] | Iterable[tuple[KT, VT]], **kwargs: Any) -> None:
110+
"""Update the AttrDict with the key/value pairs from other, overwriting existing keys."""
111+
for k, v in dict(*args, **kwargs).items():
112+
self[k] = v
113+
114+
def setdefault(self, key: KT, default: VT | None = None) -> VT:
115+
"""Insert key with a value of default if key is not in the dictionary."""
116+
if key not in self:
117+
self[key] = default
118+
return self[key]
119+
120+
def get(self, key: KT, default: Any | None = None) -> Any:
121+
"""Return the value for key if key is in the dictionary, else default."""
122+
return self._data.get(key, default)
123+
124+
def pop(self, key: KT, default: Any | None = None) -> Any:
125+
"""Remove specified key and return the corresponding value."""
126+
return self._data.pop(key, default)
127+
128+
def __contains__(self, key: object) -> bool:
129+
return key in self._data

0 commit comments

Comments
 (0)