Skip to content

Commit 64a9f0a

Browse files
add type checker
1 parent b958c0f commit 64a9f0a

File tree

2 files changed

+148
-0
lines changed

2 files changed

+148
-0
lines changed

pina/type_checker.py

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""Module for enforcing type hints in Python functions."""
2+
3+
import inspect
4+
import typing
5+
import logging
6+
7+
8+
def enforce_types(func):
9+
"""
10+
Function decorator to enforce type hints at runtime.
11+
12+
This decorator checks the types of the arguments and of the return value of
13+
the decorated function against the type hints specified in the function
14+
signature. If the types do not match, a TypeError is raised.
15+
Type checking is only performed when the logging level is set to `DEBUG`.
16+
17+
:param Callable func: The function to be decorated.
18+
:return: The decorated function with enforced type hints.
19+
:rtype: Callable
20+
21+
:Example:
22+
23+
>>> @enforce_types
24+
def dummy_function(a: int, b: float) -> float:
25+
... return a+b
26+
27+
# This always works.
28+
dummy_function(1, 2.0)
29+
30+
# This raises a TypeError for the second argument, if logging is set to
31+
# `DEBUG`.
32+
dummy_function(1, "Hello, world!")
33+
34+
35+
>>> @enforce_types
36+
def dummy_function2(a: int, right: bool) -> float:
37+
... if right:
38+
... return float(a)
39+
... else:
40+
... return "Hello, world!"
41+
42+
# This always works.
43+
dummy_function2(1, right=True)
44+
45+
# This raises a TypeError for the return value if logging is set to
46+
# `DEBUG`.
47+
dummy_function2(1, right=False)
48+
"""
49+
50+
def wrapper(*args, **kwargs):
51+
"""
52+
Wrapper function to enforce type hints.
53+
54+
:param tuple args: Positional arguments passed to the function.
55+
:param dict kwargs: Keyword arguments passed to the function.
56+
:raises TypeError: If the argument or return type does not match the
57+
specified type hints.
58+
:return: The result of the decorated function.
59+
:rtype: Any
60+
"""
61+
level = logging.getLevelName(logging.getLogger().getEffectiveLevel())
62+
63+
# Enforce type hints only in debug mode
64+
if level != "DEBUG":
65+
return func(*args, **kwargs)
66+
67+
# Get the type hints for the function arguments
68+
hints = typing.get_type_hints(func)
69+
sig = inspect.signature(func)
70+
bound = sig.bind(*args, **kwargs)
71+
bound.apply_defaults()
72+
73+
for arg_name, arg_value in bound.arguments.items():
74+
expected_type = hints.get(arg_name)
75+
if expected_type and not isinstance(arg_value, expected_type):
76+
raise TypeError(
77+
f"Argument '{arg_name}' must be {expected_type.__name__}, "
78+
f"but got {type(arg_value).__name__}!"
79+
)
80+
81+
# Get the type hints for the return values
82+
return_type = hints.get("return")
83+
result = func(*args, **kwargs)
84+
85+
if return_type and not isinstance(result, return_type):
86+
raise TypeError(
87+
f"Return value must be {return_type.__name__}, "
88+
f"but got {type(result).__name__}!"
89+
)
90+
91+
return result
92+
93+
return wrapper

tests/test_type_checker.py

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import pytest
2+
import logging
3+
import math
4+
from pina.type_checker import enforce_types
5+
6+
7+
# Definition of a test function for arguments
8+
@enforce_types
9+
def foo_function1(a: int, b: float) -> float:
10+
return a + b
11+
12+
13+
# Definition of a test function for return values
14+
@enforce_types
15+
def foo_function2(a: int, right: bool) -> float:
16+
if right:
17+
return float(a)
18+
else:
19+
return "Hello, world!"
20+
21+
22+
def test_argument_type_checking():
23+
24+
# Setting logging level to INFO, which should not trigger type checking
25+
logging.getLogger().setLevel(logging.INFO)
26+
27+
# Both should work, even if the arguments are not of the expected type
28+
assert math.isclose(foo_function1(a=1, b=2.0), 3.0)
29+
assert math.isclose(foo_function1(a=1, b=2), 3.0)
30+
31+
# Setting logging level to DEBUG, which should trigger type checking
32+
logging.getLogger().setLevel(logging.DEBUG)
33+
34+
# The second should fail, as the second argument is an int
35+
assert math.isclose(foo_function1(a=1, b=2.0), 3.0)
36+
with pytest.raises(TypeError):
37+
foo_function1(a=1, b=2)
38+
39+
40+
def test_return_type_checking():
41+
42+
# Setting logging level to INFO, which should not trigger type checking
43+
logging.getLogger().setLevel(logging.INFO)
44+
45+
# Both should work, even if the return value is not of the expected type
46+
assert math.isclose(foo_function2(a=1, right=True), 1.0)
47+
assert foo_function2(a=1, right=False) == "Hello, world!"
48+
49+
# Setting logging level to DEBUG, which should trigger type checking
50+
logging.getLogger().setLevel(logging.DEBUG)
51+
52+
# The second should fail, as the return value is a string
53+
assert math.isclose(foo_function2(a=1, right=True), 1.0)
54+
with pytest.raises(TypeError):
55+
foo_function2(a=1, right=False)

0 commit comments

Comments
 (0)