Skip to content

Commit 8fb380c

Browse files
committed
Added requests validation based on swagger schema.
1 parent be39d48 commit 8fb380c

File tree

10 files changed

+1008
-25
lines changed

10 files changed

+1008
-25
lines changed

aiohttp_swagger/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
generate_doc_from_each_end_point,
1313
load_doc_from_yaml_file,
1414
swagger_path,
15+
swagger_validation,
16+
add_swagger_validation,
1517
)
1618

1719
try:
@@ -89,7 +91,7 @@ def setup_swagger(app: web.Application,
8991
)
9092

9193
if swagger_validate_schema:
92-
pass
94+
add_swagger_validation(app, swagger_info)
9395

9496
swagger_info = json.dumps(swagger_info)
9597

aiohttp_swagger/helpers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .builders import * # noqa
22
from .decorators import * # noqa
3+
from .validation import * # noqa

aiohttp_swagger/helpers/builders.py

+51-9
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from typing import (
23
MutableMapping,
34
Mapping,
@@ -13,18 +14,21 @@
1314
from aiohttp import web
1415
from aiohttp.hdrs import METH_ANY, METH_ALL
1516
from jinja2 import Template
16-
1717
try:
1818
import ujson as json
1919
except ImportError: # pragma: no cover
2020
import json
2121

22+
from .validation import validate_decorator
23+
2224

2325
SWAGGER_TEMPLATE = abspath(join(dirname(__file__), "..", "templates"))
2426

2527

26-
def _extract_swagger_docs(end_point_doc, method="get"):
27-
# Find Swagger start point in doc
28+
def _extract_swagger_docs(end_point_doc: str) -> Mapping:
29+
"""
30+
Find Swagger start point in doc.
31+
"""
2832
end_point_swagger_start = 0
2933
for i, doc_line in enumerate(end_point_doc):
3034
if "---" in doc_line:
@@ -42,7 +46,7 @@ def _extract_swagger_docs(end_point_doc, method="get"):
4246
"from docstring ⚠",
4347
"tags": ["Invalid Swagger"]
4448
}
45-
return {method: end_point_swagger_doc}
49+
return end_point_swagger_doc
4650

4751

4852
def _build_doc_from_func_doc(route):
@@ -58,16 +62,14 @@ def _build_doc_from_func_doc(route):
5862
method = getattr(route.handler, method_name)
5963
if method.__doc__ is not None and "---" in method.__doc__:
6064
end_point_doc = method.__doc__.splitlines()
61-
out.update(
62-
_extract_swagger_docs(end_point_doc, method=method_name))
65+
out[method_name] = _extract_swagger_docs(end_point_doc)
6366

6467
else:
6568
try:
6669
end_point_doc = route.handler.__doc__.splitlines()
6770
except AttributeError:
6871
return {}
69-
out.update(_extract_swagger_docs(
70-
end_point_doc, method=route.method.lower()))
72+
out[route.method.lower()] = _extract_swagger_docs(end_point_doc)
7173
return out
7274

7375

@@ -150,7 +152,47 @@ def load_doc_from_yaml_file(doc_path: str) -> MutableMapping:
150152
return yaml.load(open(doc_path, "r").read())
151153

152154

155+
def add_swagger_validation(app, swagger_info: Mapping):
156+
for route in app.router.routes():
157+
method = route.method.lower()
158+
handler = route.handler
159+
url_info = route.get_info()
160+
url = url_info.get('path') or url_info.get('formatter')
161+
162+
if method != '*':
163+
swagger_endpoint_info_for_method = \
164+
swagger_info['paths'].get(url, {}).get(method)
165+
swagger_endpoint_info = \
166+
{method: swagger_endpoint_info_for_method} if \
167+
swagger_endpoint_info_for_method is not None else {}
168+
else:
169+
# all methods
170+
swagger_endpoint_info = swagger_info['paths'].get(url, {})
171+
for method, info in swagger_endpoint_info.items():
172+
logging.debug(
173+
'Added validation for method: {}. Path: {}'.
174+
format(method.upper(), url)
175+
)
176+
if issubclass(handler, web.View) and route.method == METH_ANY:
177+
# whole class validation
178+
should_be_validated = getattr(handler, 'validation', False)
179+
cls_method = getattr(handler, method, None)
180+
if cls_method is not None:
181+
if not should_be_validated:
182+
# method validation
183+
should_be_validated = \
184+
getattr(handler, 'validation', False)
185+
if should_be_validated:
186+
new_cls_method = validate_decorator(info)(cls_method)
187+
setattr(handler, method, new_cls_method)
188+
else:
189+
should_be_validated = getattr(handler, 'validation', False)
190+
if should_be_validated:
191+
route._handler = validate_decorator(info)(handler)
192+
193+
153194
__all__ = (
154195
"generate_doc_from_each_end_point",
155-
"load_doc_from_yaml_file"
196+
"load_doc_from_yaml_file",
197+
"add_swagger_validation",
156198
)

aiohttp_swagger/helpers/decorators.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,27 @@
1-
class swagger_path(object):
1+
from functools import partial
2+
from inspect import isfunction, isclass
3+
4+
__all__ = (
5+
'swagger_path',
6+
'swagger_validation',
7+
)
8+
9+
10+
class swagger_path:
11+
212
def __init__(self, swagger_file):
313
self.swagger_file = swagger_file
414

515
def __call__(self, f):
616
f.swagger_file = self.swagger_file
717
return f
18+
19+
20+
def swagger_validation(func=None, *, validation=True):
21+
22+
if func is None or not (isfunction(func) or isclass(func)):
23+
validation = func
24+
return partial(swagger_validation, validation=validation)
25+
26+
func.validation = validation
27+
return func

aiohttp_swagger/helpers/validation.py

+205
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
from copy import deepcopy
2+
import sys
3+
import json
4+
import logging
5+
from functools import wraps
6+
from traceback import format_exc
7+
from itertools import groupby
8+
from operator import itemgetter
9+
from typing import (
10+
Mapping,
11+
Iterable,
12+
Optional,
13+
)
14+
15+
from aiohttp import web
16+
from aiohttp.web import (
17+
Request,
18+
Response,
19+
json_response,
20+
)
21+
from collections import defaultdict
22+
from jsonschema import (
23+
validate,
24+
ValidationError,
25+
FormatChecker,
26+
)
27+
from jsonschema.validators import validator_for
28+
29+
30+
__all__ = (
31+
'validate_decorator',
32+
)
33+
34+
35+
logger = logging.getLogger(__name__)
36+
37+
38+
def serialize_error_response(message: str, code: int, padding='error',
39+
traceback: bool=False, **kwargs):
40+
obj = {padding: {'message': message, 'code': code, **kwargs}}
41+
if traceback and sys.exc_info()[0]:
42+
obj[padding]['traceback'] = format_exc()
43+
return json.dumps(obj, default=lambda x: str(x))
44+
45+
46+
def multi_dict_to_dict(mld: Mapping) -> Mapping:
47+
return {
48+
key: value[0]
49+
if isinstance(value, (list, tuple)) and len(value) == 1 else value
50+
for key, value in mld.items()
51+
}
52+
53+
54+
def validate_schema(obj: Mapping, schema: Mapping):
55+
validate(obj, schema, format_checker=FormatChecker())
56+
57+
58+
def validate_multi_dict(obj, schema):
59+
validate(multi_dict_to_dict(obj), schema, format_checker=FormatChecker())
60+
61+
62+
def validate_content_type(swagger: Mapping, content_type: str):
63+
consumes = swagger.get('consumes')
64+
if consumes and not any(content_type == consume for consume in consumes):
65+
raise ValidationError(
66+
message='Unsupported content type: {}'.format(content_type))
67+
68+
69+
async def validate_request(
70+
request: Request,
71+
parameter_groups: Mapping,
72+
swagger: Mapping):
73+
validate_content_type(swagger, request.content_type)
74+
for group_name, group_schemas in parameter_groups.items():
75+
if group_name == 'header':
76+
headers = request.headers
77+
for schema in group_schemas:
78+
validate_multi_dict(headers, schema)
79+
if group_name == 'query':
80+
query = request.query
81+
for schema in group_schemas:
82+
validate_multi_dict(query, schema)
83+
if group_name == 'formData':
84+
try:
85+
data = await request.post()
86+
except ValueError:
87+
data = None
88+
for schema in group_schemas:
89+
validate_multi_dict(data, schema)
90+
if group_name == 'body':
91+
try:
92+
content = await request.json()
93+
except json.JSONDecodeError:
94+
content = None
95+
for schema in group_schemas:
96+
validate_schema(content, schema)
97+
if group_name == 'path':
98+
params = dict(request.match_info)
99+
for schema in group_schemas:
100+
validate_schema(params, schema)
101+
102+
103+
def adjust_swagger_item_to_json_schemes(*schemes: Mapping) -> Mapping:
104+
new_schema = {
105+
'type': 'object',
106+
'properties': {},
107+
}
108+
required_fields = []
109+
for schema in schemes:
110+
required = schema.get('required', False)
111+
name = schema['name']
112+
_schema = schema.get('schema')
113+
if _schema is not None:
114+
new_schema['properties'][name] = _schema
115+
else:
116+
new_schema['properties'][name] = {
117+
key: value for key, value in schema.items()
118+
if key not in ('required',)
119+
}
120+
if required:
121+
required_fields.append(name)
122+
if required_fields:
123+
new_schema['required'] = required_fields
124+
validator_for(new_schema).check_schema(new_schema)
125+
return new_schema
126+
127+
128+
def adjust_swagger_body_item_to_json_schema(schema: Mapping) -> Mapping:
129+
required = schema.get('required', False)
130+
_schema = schema.get('schema')
131+
new_schema = deepcopy(_schema)
132+
if not required:
133+
new_schema = {
134+
'anyOf': [
135+
{'type': 'null'},
136+
new_schema,
137+
]
138+
}
139+
validator_for(new_schema).check_schema(new_schema)
140+
return new_schema
141+
142+
143+
def adjust_swagger_to_json_schema(parameter_groups: Iterable) -> Mapping:
144+
res = defaultdict(list)
145+
for group_name, group_schemas in parameter_groups:
146+
if group_name in ('query', 'header', 'path', 'formData'):
147+
json_schema = adjust_swagger_item_to_json_schemes(*group_schemas)
148+
res[group_name].append(json_schema)
149+
else:
150+
# only one possible schema for in: body
151+
schema = list(group_schemas)[0]
152+
json_schema = adjust_swagger_body_item_to_json_schema(schema)
153+
res[group_name].append(json_schema)
154+
return res
155+
156+
157+
def validation_exc_to_dict(exc, code=400):
158+
paths = list(exc.path)
159+
field = str(paths[-1]) if paths else ''
160+
value = exc.instance
161+
validator = exc.validator
162+
message = exc.message
163+
try:
164+
schema = dict(exc.schema)
165+
except TypeError:
166+
schema = {}
167+
return {
168+
'message': message,
169+
'code': code,
170+
'description': {
171+
'validator': validator,
172+
'schema': schema,
173+
'field': field,
174+
'value': value,
175+
}
176+
}
177+
178+
179+
def validate_decorator(swagger: Mapping):
180+
181+
parameters = swagger.get('parameters', [])
182+
parameter_groups = adjust_swagger_to_json_schema(
183+
groupby(parameters, key=itemgetter('in'))
184+
)
185+
186+
def _func_wrapper(func):
187+
188+
@wraps(func)
189+
async def _wrapper(*args, **kwargs) -> Response:
190+
request = args[0].request \
191+
if isinstance(args[0], web.View) else args[0]
192+
try:
193+
await validate_request(request, parameter_groups, swagger)
194+
except ValidationError as exc:
195+
logger.exception(exc)
196+
exc_dict = validation_exc_to_dict(exc)
197+
return json_response(
198+
text=serialize_error_response(**exc_dict),
199+
status=400
200+
)
201+
return await func(*args, **kwargs)
202+
203+
return _wrapper
204+
205+
return _func_wrapper

0 commit comments

Comments
 (0)