|
| 1 | +from copy import deepcopy |
| 2 | +import sys |
| 3 | +import json |
| 4 | +import logging |
| 5 | +from functools import wraps |
| 6 | +from traceback import format_exc |
| 7 | +from itertools import groupby |
| 8 | +from operator import itemgetter |
| 9 | +from typing import ( |
| 10 | + Mapping, |
| 11 | + Iterable, |
| 12 | + Optional, |
| 13 | +) |
| 14 | + |
| 15 | +from aiohttp import web |
| 16 | +from aiohttp.web import ( |
| 17 | + Request, |
| 18 | + Response, |
| 19 | + json_response, |
| 20 | +) |
| 21 | +from collections import defaultdict |
| 22 | +from jsonschema import ( |
| 23 | + validate, |
| 24 | + ValidationError, |
| 25 | + FormatChecker, |
| 26 | +) |
| 27 | +from jsonschema.validators import validator_for |
| 28 | + |
| 29 | + |
| 30 | +__all__ = ( |
| 31 | + 'validate_decorator', |
| 32 | +) |
| 33 | + |
| 34 | + |
| 35 | +logger = logging.getLogger(__name__) |
| 36 | + |
| 37 | + |
| 38 | +def serialize_error_response(message: str, code: int, padding='error', |
| 39 | + traceback: bool=False, **kwargs): |
| 40 | + obj = {padding: {'message': message, 'code': code, **kwargs}} |
| 41 | + if traceback and sys.exc_info()[0]: |
| 42 | + obj[padding]['traceback'] = format_exc() |
| 43 | + return json.dumps(obj, default=lambda x: str(x)) |
| 44 | + |
| 45 | + |
| 46 | +def multi_dict_to_dict(mld: Mapping) -> Mapping: |
| 47 | + return { |
| 48 | + key: value[0] |
| 49 | + if isinstance(value, (list, tuple)) and len(value) == 1 else value |
| 50 | + for key, value in mld.items() |
| 51 | + } |
| 52 | + |
| 53 | + |
| 54 | +def validate_schema(obj: Mapping, schema: Mapping): |
| 55 | + validate(obj, schema, format_checker=FormatChecker()) |
| 56 | + |
| 57 | + |
| 58 | +def validate_multi_dict(obj, schema): |
| 59 | + validate(multi_dict_to_dict(obj), schema, format_checker=FormatChecker()) |
| 60 | + |
| 61 | + |
| 62 | +def validate_content_type(swagger: Mapping, content_type: str): |
| 63 | + consumes = swagger.get('consumes') |
| 64 | + if consumes and not any(content_type == consume for consume in consumes): |
| 65 | + raise ValidationError( |
| 66 | + message='Unsupported content type: {}'.format(content_type)) |
| 67 | + |
| 68 | + |
| 69 | +async def validate_request( |
| 70 | + request: Request, |
| 71 | + parameter_groups: Mapping, |
| 72 | + swagger: Mapping): |
| 73 | + validate_content_type(swagger, request.content_type) |
| 74 | + for group_name, group_schemas in parameter_groups.items(): |
| 75 | + if group_name == 'header': |
| 76 | + headers = request.headers |
| 77 | + for schema in group_schemas: |
| 78 | + validate_multi_dict(headers, schema) |
| 79 | + if group_name == 'query': |
| 80 | + query = request.query |
| 81 | + for schema in group_schemas: |
| 82 | + validate_multi_dict(query, schema) |
| 83 | + if group_name == 'formData': |
| 84 | + try: |
| 85 | + data = await request.post() |
| 86 | + except ValueError: |
| 87 | + data = None |
| 88 | + for schema in group_schemas: |
| 89 | + validate_multi_dict(data, schema) |
| 90 | + if group_name == 'body': |
| 91 | + try: |
| 92 | + content = await request.json() |
| 93 | + except json.JSONDecodeError: |
| 94 | + content = None |
| 95 | + for schema in group_schemas: |
| 96 | + validate_schema(content, schema) |
| 97 | + if group_name == 'path': |
| 98 | + params = dict(request.match_info) |
| 99 | + for schema in group_schemas: |
| 100 | + validate_schema(params, schema) |
| 101 | + |
| 102 | + |
| 103 | +def adjust_swagger_item_to_json_schemes(*schemes: Mapping) -> Mapping: |
| 104 | + new_schema = { |
| 105 | + 'type': 'object', |
| 106 | + 'properties': {}, |
| 107 | + } |
| 108 | + required_fields = [] |
| 109 | + for schema in schemes: |
| 110 | + required = schema.get('required', False) |
| 111 | + name = schema['name'] |
| 112 | + _schema = schema.get('schema') |
| 113 | + if _schema is not None: |
| 114 | + new_schema['properties'][name] = _schema |
| 115 | + else: |
| 116 | + new_schema['properties'][name] = { |
| 117 | + key: value for key, value in schema.items() |
| 118 | + if key not in ('required',) |
| 119 | + } |
| 120 | + if required: |
| 121 | + required_fields.append(name) |
| 122 | + if required_fields: |
| 123 | + new_schema['required'] = required_fields |
| 124 | + validator_for(new_schema).check_schema(new_schema) |
| 125 | + return new_schema |
| 126 | + |
| 127 | + |
| 128 | +def adjust_swagger_body_item_to_json_schema(schema: Mapping) -> Mapping: |
| 129 | + required = schema.get('required', False) |
| 130 | + _schema = schema.get('schema') |
| 131 | + new_schema = deepcopy(_schema) |
| 132 | + if not required: |
| 133 | + new_schema = { |
| 134 | + 'anyOf': [ |
| 135 | + {'type': 'null'}, |
| 136 | + new_schema, |
| 137 | + ] |
| 138 | + } |
| 139 | + validator_for(new_schema).check_schema(new_schema) |
| 140 | + return new_schema |
| 141 | + |
| 142 | + |
| 143 | +def adjust_swagger_to_json_schema(parameter_groups: Iterable) -> Mapping: |
| 144 | + res = defaultdict(list) |
| 145 | + for group_name, group_schemas in parameter_groups: |
| 146 | + if group_name in ('query', 'header', 'path', 'formData'): |
| 147 | + json_schema = adjust_swagger_item_to_json_schemes(*group_schemas) |
| 148 | + res[group_name].append(json_schema) |
| 149 | + else: |
| 150 | + # only one possible schema for in: body |
| 151 | + schema = list(group_schemas)[0] |
| 152 | + json_schema = adjust_swagger_body_item_to_json_schema(schema) |
| 153 | + res[group_name].append(json_schema) |
| 154 | + return res |
| 155 | + |
| 156 | + |
| 157 | +def validation_exc_to_dict(exc, code=400): |
| 158 | + paths = list(exc.path) |
| 159 | + field = str(paths[-1]) if paths else '' |
| 160 | + value = exc.instance |
| 161 | + validator = exc.validator |
| 162 | + message = exc.message |
| 163 | + try: |
| 164 | + schema = dict(exc.schema) |
| 165 | + except TypeError: |
| 166 | + schema = {} |
| 167 | + return { |
| 168 | + 'message': message, |
| 169 | + 'code': code, |
| 170 | + 'description': { |
| 171 | + 'validator': validator, |
| 172 | + 'schema': schema, |
| 173 | + 'field': field, |
| 174 | + 'value': value, |
| 175 | + } |
| 176 | + } |
| 177 | + |
| 178 | + |
| 179 | +def validate_decorator(swagger: Mapping): |
| 180 | + |
| 181 | + parameters = swagger.get('parameters', []) |
| 182 | + parameter_groups = adjust_swagger_to_json_schema( |
| 183 | + groupby(parameters, key=itemgetter('in')) |
| 184 | + ) |
| 185 | + |
| 186 | + def _func_wrapper(func): |
| 187 | + |
| 188 | + @wraps(func) |
| 189 | + async def _wrapper(*args, **kwargs) -> Response: |
| 190 | + request = args[0].request \ |
| 191 | + if isinstance(args[0], web.View) else args[0] |
| 192 | + try: |
| 193 | + await validate_request(request, parameter_groups, swagger) |
| 194 | + except ValidationError as exc: |
| 195 | + logger.exception(exc) |
| 196 | + exc_dict = validation_exc_to_dict(exc) |
| 197 | + return json_response( |
| 198 | + text=serialize_error_response(**exc_dict), |
| 199 | + status=400 |
| 200 | + ) |
| 201 | + return await func(*args, **kwargs) |
| 202 | + |
| 203 | + return _wrapper |
| 204 | + |
| 205 | + return _func_wrapper |
0 commit comments