Skip to content

Commit fffd722

Browse files
committed
feat: Add support for Literal types.
This commit extends support of dataclasses_json to dataclasses with fields annotated with Literal types. Literal types allow users to specify a list of valid values, e.g., ```python @DataClass class DataClassWithLiteral(DataClassJsonMixin): languages: Literal["C", "C++", "Java"] ``` When de-serializing data, this commit now validates that the JSON's values are one of those specified in the Literal type. Change in behavior: Using literal types would previously give users the following warning: ``` dataclasses_json/mm.py:357: UserWarning: Unknown type C at Foo.langs: typing.Literal['C', 'C++', 'Java']. It's advised to pass the correct marshmallow type to `mm_field`. ```
1 parent 538ff15 commit fffd722

File tree

3 files changed

+156
-7
lines changed

3 files changed

+156
-7
lines changed

dataclasses_json/core.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Tuple, TypeVar, Type)
1818
from uuid import UUID
1919

20-
from typing_inspect import is_union_type # type: ignore
20+
from typing_inspect import is_union_type, is_literal_type # type: ignore
2121

2222
from dataclasses_json import cfg
2323
from dataclasses_json.utils import (_get_type_cons, _get_type_origin,
@@ -358,7 +358,8 @@ def _decode_dict_keys(key_type, xs, infer_missing):
358358
# This is a special case for Python 3.7 and Python 3.8.
359359
# By some reason, "unbound" dicts are counted
360360
# as having key type parameter to be TypeVar('KT')
361-
if key_type is None or key_type == Any or isinstance(key_type, TypeVar):
361+
# Literal types are also passed through without any decoding.
362+
if key_type is None or key_type == Any or isinstance(key_type, TypeVar) or is_literal_type(key_type):
362363
decode_function = key_type = (lambda x: x)
363364
# handle a nested python dict that has tuples for keys. E.g. for
364365
# Dict[Tuple[int], int], key_type will be typing.Tuple[int], but

dataclasses_json/mm.py

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111
from uuid import UUID
1212
from enum import Enum
1313

14-
from typing_inspect import is_union_type # type: ignore
14+
from typing_inspect import is_union_type, is_literal_type # type: ignore
1515

1616
from marshmallow import fields, Schema, post_load # type: ignore
1717
from marshmallow.exceptions import ValidationError # type: ignore
1818

1919
from dataclasses_json.core import (_is_supported_generic, _decode_dataclass,
2020
_ExtendedEncoder, _user_overrides_or_exts)
21-
from dataclasses_json.utils import (_is_collection, _is_optional,
21+
from dataclasses_json.utils import (_get_type_args, _is_collection, _is_optional,
2222
_issubclass_safe, _timestamp_to_dt_aware,
2323
_is_new_type, _get_type_origin,
2424
_handle_undefined_parameters_safe,
@@ -130,6 +130,46 @@ def _deserialize(self, value, attr, data, **kwargs):
130130
return None if optional_list is None else tuple(optional_list)
131131

132132

133+
class _LiteralField(fields.Field):
134+
def __init__(self, literal_values, cls, field, *args, **kwargs):
135+
"""Create a new Literal field.
136+
137+
Literals allow you to specify the set of valid _values_ for a field. The field
138+
implementation validates against these values on deserialization.
139+
140+
Example:
141+
>>> @dataclass
142+
... class DataClassWithLiteral(DataClassJsonMixin):
143+
... read_mode: Literal["r", "w", "a"]
144+
145+
Args:
146+
literal_values: A sequence of possible values for the field.
147+
cls: The dataclass that the field belongs to.
148+
field: The field that the schema describes.
149+
"""
150+
self.literal_values = literal_values
151+
self.cls = cls
152+
self.field = field
153+
super().__init__(*args, **kwargs)
154+
155+
def _serialize(self, value, attr, obj, **kwargs):
156+
if self.allow_none and value is None:
157+
return None
158+
if value not in self.literal_values:
159+
warnings.warn(
160+
f'The value "{value}" is not one of the values of typing.Literal '
161+
f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
162+
f'Value will not be de-serialized properly.')
163+
return super()._serialize(value, attr, obj, **kwargs)
164+
165+
def _deserialize(self, value, attr, data, **kwargs):
166+
if value not in self.literal_values:
167+
raise ValidationError(
168+
f'Value "{value}" is not one in typing.Literal{self.literal_values} '
169+
f'(dataclass: {self.cls.__name__}, field: {self.field.name}).')
170+
return super()._deserialize(value, attr, data, **kwargs)
171+
172+
133173
TYPES = {
134174
typing.Mapping: fields.Mapping,
135175
typing.MutableMapping: fields.Mapping,
@@ -259,9 +299,14 @@ def inner(type_, options):
259299
f"`dataclass_json` decorator or mixin.")
260300
return fields.Field(**options)
261301

262-
origin = getattr(type_, '__origin__', type_)
263-
args = [inner(a, {}) for a in getattr(type_, '__args__', []) if
264-
a is not type(None)]
302+
origin = _get_type_origin(type_)
303+
304+
# Type arguments are typically types (e.g. int in list[int]) except for Literal
305+
# types, where they are values.
306+
if is_literal_type(type_):
307+
args = []
308+
else:
309+
args = [inner(a, {}) for a in _get_type_args(type_) if a is not type(None)]
265310

266311
if type_ == Ellipsis:
267312
return type_
@@ -279,6 +324,10 @@ def inner(type_, options):
279324
if _issubclass_safe(origin, Enum):
280325
return fields.Enum(enum=origin, by_value=True, *args, **options)
281326

327+
if is_literal_type(type_):
328+
literal_values = _get_type_args(type_)
329+
return _LiteralField(literal_values, cls, field, **options)
330+
282331
if is_union_type(type_):
283332
union_types = [a for a in getattr(type_, '__args__', []) if
284333
a is not type(None)]

tests/test_literals.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""Test dataclasses_json handling of Literal types."""
2+
import sys
3+
import pytest
4+
5+
if sys.version_info < (3, 8):
6+
pytest.skip("Literal types are only supported in Python 3.8+", allow_module_level=True)
7+
8+
import json
9+
from typing import Literal, Optional
10+
11+
from dataclasses import dataclass
12+
13+
from dataclasses_json import dataclass_json, DataClassJsonMixin
14+
from marshmallow.exceptions import ValidationError # type: ignore
15+
16+
17+
@dataclass_json
18+
@dataclass
19+
class DataClassWithLiteral(DataClassJsonMixin):
20+
numeric_literals: Literal[0, 1]
21+
string_literals: Literal["one", "two", "three"]
22+
mixed_literals: Literal[0, "one", 2]
23+
24+
25+
with_valid_literal_json = '{"numeric_literals": 0, "string_literals": "one", "mixed_literals": 2}'
26+
with_valid_literal_data = DataClassWithLiteral(numeric_literals=0, string_literals="one", mixed_literals=2)
27+
with_invalid_literal_json = '{"numeric_literals": 9, "string_literals": "four", "mixed_literals": []}'
28+
with_invalid_literal_data = DataClassWithLiteral(numeric_literals=9, string_literals="four", mixed_literals=[]) # type: ignore
29+
30+
@dataclass_json
31+
@dataclass
32+
class DataClassWithNestedLiteral(DataClassJsonMixin):
33+
list_of_literals: list[Literal[0, 1]]
34+
dict_of_literals: dict[Literal["one", "two", "three"], Literal[0, 1]]
35+
optional_literal: Optional[Literal[0, 1]]
36+
37+
with_valid_nested_literal_json = '{"list_of_literals": [0, 1], "dict_of_literals": {"one": 0, "two": 1}, "optional_literal": 1}'
38+
with_valid_nested_literal_data = DataClassWithNestedLiteral(list_of_literals=[0, 1], dict_of_literals={"one": 0, "two": 1}, optional_literal=1)
39+
with_invalid_nested_literal_json = '{"list_of_literals": [0, 2], "dict_of_literals": {"one": 0, "four": 2}, "optional_literal": 2}'
40+
with_invalid_nested_literal_data = DataClassWithNestedLiteral(list_of_literals=[0, 2], dict_of_literals={"one": 0, "four": 2}, optional_literal=2) # type: ignore
41+
42+
class TestEncoder:
43+
def test_valid_literal(self):
44+
assert with_valid_literal_data.to_dict(encode_json=True) == json.loads(with_valid_literal_json)
45+
46+
def test_invalid_literal(self):
47+
assert with_invalid_literal_data.to_dict(encode_json=True) == json.loads(with_invalid_literal_json)
48+
49+
def test_valid_nested_literal(self):
50+
assert with_valid_nested_literal_data.to_dict(encode_json=True) == json.loads(with_valid_nested_literal_json)
51+
52+
def test_invalid_nested_literal(self):
53+
assert with_invalid_nested_literal_data.to_dict(encode_json=True) == json.loads(with_invalid_nested_literal_json)
54+
55+
56+
class TestSchemaEncoder:
57+
def test_valid_literal(self):
58+
actual = DataClassWithLiteral.schema().dumps(with_valid_literal_data)
59+
assert actual == with_valid_literal_json
60+
61+
def test_invalid_literal(self):
62+
actual = DataClassWithLiteral.schema().dumps(with_invalid_literal_data)
63+
assert actual == with_invalid_literal_json
64+
65+
def test_valid_nested_literal(self):
66+
actual = DataClassWithNestedLiteral.schema().dumps(with_valid_nested_literal_data)
67+
assert actual == with_valid_nested_literal_json
68+
69+
def test_invalid_nested_literal(self):
70+
actual = DataClassWithNestedLiteral.schema().dumps(with_invalid_nested_literal_data)
71+
assert actual == with_invalid_nested_literal_json
72+
73+
class TestDecoder:
74+
def test_valid_literal(self):
75+
actual = DataClassWithLiteral.from_json(with_valid_literal_json)
76+
assert actual == with_valid_literal_data
77+
78+
def test_invalid_literal(self):
79+
expected = DataClassWithLiteral(numeric_literals=9, string_literals="four", mixed_literals=[]) # type: ignore
80+
actual = DataClassWithLiteral.from_json(with_invalid_literal_json)
81+
assert actual == expected
82+
83+
84+
class TestSchemaDecoder:
85+
def test_valid_literal(self):
86+
actual = DataClassWithLiteral.schema().loads(with_valid_literal_json)
87+
assert actual == with_valid_literal_data
88+
89+
def test_invalid_literal(self):
90+
with pytest.raises(ValidationError):
91+
DataClassWithLiteral.schema().loads(with_invalid_literal_json)
92+
93+
def test_valid_nested_literal(self):
94+
actual = DataClassWithNestedLiteral.schema().loads(with_valid_nested_literal_json)
95+
assert actual == with_valid_nested_literal_data
96+
97+
def test_invalid_nested_literal(self):
98+
with pytest.raises(ValidationError):
99+
DataClassWithNestedLiteral.schema().loads(with_invalid_nested_literal_json)

0 commit comments

Comments
 (0)