Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit d51271f

Browse files
markusschmausbrandonwillard
authored andcommitted
turn HasDataType and HasShape into Protocol\s
1 parent ec82b9f commit d51271f

File tree

5 files changed

+26
-14
lines changed

5 files changed

+26
-14
lines changed

aesara/graph/type.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import abstractmethod
22
from typing import Any, Generic, Optional, Text, Tuple, TypeVar, Union
33

4-
from typing_extensions import TypeAlias
4+
from typing_extensions import Protocol, TypeAlias, runtime_checkable
55

66
from aesara.graph import utils
77
from aesara.graph.basic import Constant, Variable
@@ -262,14 +262,22 @@ def values_eq_approx(cls, a: D, b: D) -> bool:
262262
return cls.values_eq(a, b)
263263

264264

265-
class HasDataType:
266-
"""A mixin for a type that has a :attr:`dtype` attribute."""
265+
DataType = str
267266

268-
dtype: str
269267

268+
@runtime_checkable
269+
class HasDataType(Protocol):
270+
"""A protocol matching any class with :attr:`dtype` attribute."""
270271

271-
class HasShape:
272-
"""A mixin for a type that has :attr:`shape` and :attr:`ndim` attributes."""
272+
dtype: DataType
273+
274+
275+
ShapeType = Tuple[Optional[int], ...]
276+
277+
278+
@runtime_checkable
279+
class HasShape(Protocol):
280+
"""A protocol matching any class that has :attr:`shape` and :attr:`ndim` attributes."""
273281

274282
ndim: int
275-
shape: Tuple[Optional[int], ...]
283+
shape: ShapeType

aesara/link/c/cmodule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2441,7 +2441,7 @@ def linking_patch(lib_dirs: List[str], libs: List[str]) -> List[str]:
24412441
if sys.platform != "win32":
24422442
return [f"-l{l}" for l in libs]
24432443

2444-
def sort_key(lib): # type: ignore
2444+
def sort_key(lib):
24452445
name, *numbers, extension = lib.split(".")
24462446
return (extension == "dll", tuple(map(int, numbers)))
24472447

aesara/scalar/basic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from aesara.graph.basic import Apply, Constant, Variable, clone, list_of_nodes
2929
from aesara.graph.fg import FunctionGraph
3030
from aesara.graph.rewriting.basic import MergeOptimizer
31-
from aesara.graph.type import HasDataType, HasShape
31+
from aesara.graph.type import DataType
3232
from aesara.graph.utils import MetaObject, MethodNotDefined
3333
from aesara.link.c.op import COp
3434
from aesara.link.c.type import CType
@@ -268,7 +268,7 @@ def convert(x, dtype=None):
268268
return x_
269269

270270

271-
class ScalarType(CType, HasDataType, HasShape):
271+
class ScalarType(CType):
272272

273273
"""
274274
Internal class, should not be used by clients.
@@ -284,6 +284,7 @@ class ScalarType(CType, HasDataType, HasShape):
284284
__props__ = ("dtype",)
285285
ndim = 0
286286
shape = ()
287+
dtype: DataType
287288

288289
def __init__(self, dtype):
289290
if isinstance(dtype, str) and dtype == "floatX":

aesara/sparse/type.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import aesara
88
from aesara import scalar as aes
99
from aesara.graph.basic import Variable
10-
from aesara.graph.type import HasDataType
1110
from aesara.tensor.type import DenseTensorType, TensorType
1211

1312

@@ -33,7 +32,7 @@ def _is_sparse(x):
3332
return isinstance(x, scipy.sparse.spmatrix)
3433

3534

36-
class SparseTensorType(TensorType, HasDataType):
35+
class SparseTensorType(TensorType):
3736
"""A `Type` for sparse tensors.
3837
3938
Notes

aesara/tensor/type.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from aesara import scalar as aes
99
from aesara.configdefaults import config
1010
from aesara.graph.basic import Variable
11-
from aesara.graph.type import HasDataType, HasShape
11+
from aesara.graph.type import DataType, ShapeType
1212
from aesara.graph.utils import MetaType
1313
from aesara.link.c.type import CType
1414
from aesara.misc.safe_asarray import _asarray
@@ -48,11 +48,15 @@
4848
}
4949

5050

51-
class TensorType(CType[np.ndarray], HasDataType, HasShape):
51+
class TensorType(CType[np.ndarray]):
5252
r"""Symbolic `Type` representing `numpy.ndarray`\s."""
5353

5454
__props__: Tuple[str, ...] = ("dtype", "shape")
5555

56+
ndim: int
57+
shape: ShapeType
58+
dtype: DataType
59+
5660
dtype_specs_map = dtype_specs_map
5761
context_name = "cpu"
5862
filter_checks_isfinite = False

0 commit comments

Comments
 (0)