From b60edd47baecadb9e0db1f9efbef5bff4186c412 Mon Sep 17 00:00:00 2001 From: Sanskar Jethi Date: Tue, 26 Nov 2024 02:48:14 +0000 Subject: [PATCH 1/4] fix: openapi schema generation --- robyn/openapi.py | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/robyn/openapi.py b/robyn/openapi.py index 6532fe2f5..92f9b03a3 100644 --- a/robyn/openapi.py +++ b/robyn/openapi.py @@ -5,7 +5,7 @@ from importlib import resources from inspect import Signature from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict +from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict, Union from robyn.responses import html from robyn.robyn import QueryParams, Response @@ -389,25 +389,44 @@ def get_schema_object(self, parameter: str, param_type: Any) -> dict: list: "array", } + # Handle basic types for type_name in type_mapping: if param_type is type_name: properties["type"] = type_mapping[type_name] return properties - # check for Optional type - if param_type.__module__ == "typing": - properties["anyOf"] = [{"type": self.get_openapi_type(param_type.__args__[0])}, {"type": "null"}] - return properties - # check for custom classes and TypedDicts + # Handle typing module types (Optional, List, etc) + if hasattr(param_type, "__module__") and param_type.__module__ == "typing": + origin = typing.get_origin(param_type) + args = typing.get_args(param_type) + + # Handle Optional types + if origin is Union and type(None) in args: + non_none_type = next(t for t in args if t is not type(None)) + properties["anyOf"] = [ + {"type": self.get_openapi_type(non_none_type)}, + {"type": "null"} + ] + return properties + + # Handle List types + elif origin in (list, List): + properties["type"] = "array" + if args: + item_type = args[0] + properties["items"] = self.get_schema_object("item", item_type) + return properties + + # Handle custom classes and TypedDicts elif inspect.isclass(param_type): properties["type"] = "object" - properties["properties"] = {} - for e in param_type.__annotations__: - properties["properties"][e] = self.get_schema_object(e, param_type.__annotations__[e]) - - properties["type"] = "object" + if hasattr(param_type, "__annotations__"): + for e in param_type.__annotations__: + properties["properties"][e] = self.get_schema_object( + e, param_type.__annotations__[e] + ) return properties From 586d7d63881527d19ba29a03a1576b1758434d66 Mon Sep 17 00:00:00 2001 From: Sanskar Jethi Date: Tue, 26 Nov 2024 02:52:53 +0000 Subject: [PATCH 2/4] update --- integration_tests/base_routes.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/integration_tests/base_routes.py b/integration_tests/base_routes.py index 5092ae9a1..1ea7c5101 100644 --- a/integration_tests/base_routes.py +++ b/integration_tests/base_routes.py @@ -1,7 +1,7 @@ import os import pathlib from collections import defaultdict -from typing import Optional +from typing import Optional, List from integration_tests.subroutes import di_subrouter, sub_router from robyn import Headers, Request, Response, Robyn, WebSocket, WebSocketConnector, jsonify, serve_file, serve_html @@ -560,6 +560,10 @@ async def async_dict_post(): # Body +class TestMyRequest(Body): + items: List[str] + numbers: list[int] + @app.post("/sync/body") def sync_body_post(request: Request): @@ -575,6 +579,11 @@ async def async_body_post(request: Request): def sync_form_data(request: Request): return request.headers["Content-Type"] +@app.post("/sync/body/typed") +def sync_body_typed(body: TestMyRequest): + # the server should just start + return "This is not tested with a request" + # JSON Request From f413c8c586b1b97ac661ec3f81ae69e778b4b4c1 Mon Sep 17 00:00:00 2001 From: Sanskar Jethi Date: Tue, 26 Nov 2024 20:34:55 +0000 Subject: [PATCH 3/4] update --- robyn/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/robyn/__init__.py b/robyn/__init__.py index ca81226f9..fe8b70284 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -607,6 +607,9 @@ def ALLOW_CORS(app: Robyn, origins: Union[List[str], str]): @app.before_request() def cors_middleware(request): + if request is None: + return None + origin = request.headers.get("Origin") # If specific origins are set, validate the request origin From 77b84955a49148e851269dce8c45e7cb3dde4749 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 11 Jan 2025 22:35:52 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- integration_tests/base_routes.py | 2 ++ robyn/__init__.py | 2 +- robyn/openapi.py | 11 +++-------- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/integration_tests/base_routes.py b/integration_tests/base_routes.py index 1ea7c5101..de66e43e0 100644 --- a/integration_tests/base_routes.py +++ b/integration_tests/base_routes.py @@ -560,6 +560,7 @@ async def async_dict_post(): # Body + class TestMyRequest(Body): items: List[str] numbers: list[int] @@ -579,6 +580,7 @@ async def async_body_post(request: Request): def sync_form_data(request: Request): return request.headers["Content-Type"] + @app.post("/sync/body/typed") def sync_body_typed(body: TestMyRequest): # the server should just start diff --git a/robyn/__init__.py b/robyn/__init__.py index fe8b70284..084d4a6dd 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -609,7 +609,7 @@ def ALLOW_CORS(app: Robyn, origins: Union[List[str], str]): def cors_middleware(request): if request is None: return None - + origin = request.headers.get("Origin") # If specific origins are set, validate the request origin diff --git a/robyn/openapi.py b/robyn/openapi.py index 92f9b03a3..75dae7a6f 100644 --- a/robyn/openapi.py +++ b/robyn/openapi.py @@ -403,12 +403,9 @@ def get_schema_object(self, parameter: str, param_type: Any) -> dict: # Handle Optional types if origin is Union and type(None) in args: non_none_type = next(t for t in args if t is not type(None)) - properties["anyOf"] = [ - {"type": self.get_openapi_type(non_none_type)}, - {"type": "null"} - ] + properties["anyOf"] = [{"type": self.get_openapi_type(non_none_type)}, {"type": "null"}] return properties - + # Handle List types elif origin in (list, List): properties["type"] = "array" @@ -424,9 +421,7 @@ def get_schema_object(self, parameter: str, param_type: Any) -> dict: if hasattr(param_type, "__annotations__"): for e in param_type.__annotations__: - properties["properties"][e] = self.get_schema_object( - e, param_type.__annotations__[e] - ) + properties["properties"][e] = self.get_schema_object(e, param_type.__annotations__[e]) return properties