diff --git a/openapi_to_fastapi/tests/test_router.py b/openapi_to_fastapi/tests/test_router.py index 8744f61..f9e567d 100644 --- a/openapi_to_fastapi/tests/test_router.py +++ b/openapi_to_fastapi/tests/test_router.py @@ -1,4 +1,5 @@ -from typing import Any, Dict +import inspect +from typing import Any, Dict, Optional import pydantic import pytest @@ -596,3 +597,42 @@ def make_strict_request(json: Dict[str, Any]) -> Any: assert resp.status_code == expected_strict_code, resp.json() if resp.status_code != 200: assert json_snapshot == resp.json() + + +def test_modified_handler_signatures(app, client, specs_root): + + spec_router = SpecRouter(specs_root / "definitions") + + def handler_1(request, x_my_header: Optional[str] = Header(None)): + return {} + + def handler_2(request, x_my_header: Optional[str] = Header(None)): + return {} + + # Remove the header from handler_2 + sig = inspect.signature(handler_2) + params = sig.parameters + filtered_params = [ + param for param_name, param in params.items() if param_name != "x_my_header" + ] + handler_2.__signature__ = sig.replace(parameters=filtered_params) + + # Add handlers to router (non-decorator syntax) + spec_router.post("/TestValidation_v0.1")(handler_1) + spec_router.post("/TestValidation_v0.2")(handler_2) + + router = spec_router.to_fastapi_router() + app.include_router(router) + openapi_spec = app.openapi() + + route_spec_1 = openapi_spec["paths"]["/TestValidation_v0.1"] + route_spec_2 = openapi_spec["paths"]["/TestValidation_v0.2"] + + parameters_1 = route_spec_1["post"].get("parameters", {}) + parameters_2 = route_spec_2["post"].get("parameters", {}) + + headers_1 = {p.get("name") for p in parameters_1 if p.get("in") == "header"} + headers_2 = {p.get("name") for p in parameters_2 if p.get("in") == "header"} + + assert "x-my-header" in headers_1 + assert "x-my-header" not in headers_2 diff --git a/openapi_to_fastapi/utils.py b/openapi_to_fastapi/utils.py index af51b0d..93ef7a4 100644 --- a/openapi_to_fastapi/utils.py +++ b/openapi_to_fastapi/utils.py @@ -22,6 +22,11 @@ def copy_function(fn) -> Callable: ) g.__kwdefaults__ = deepcopy(fn.__kwdefaults__) g.__annotations__ = deepcopy(fn.__annotations__) + + # Signature is immutable, no need to copy/deepcopy + # Mypy doesn't know about __signature__: https://github.com/python/mypy/issues/12472 + g.__signature__ = inspect.signature(fn) # type: ignore[attr-defined] + return g @@ -32,9 +37,22 @@ def add_annotation_to_first_argument(fn: FunctionType, model: Type[pydantic.Base :param fn: Function to patch :param model: Type to add to the first argument """ - fn_spec = inspect.getfullargspec(fn) - if not len(fn_spec.args): + + sig = inspect.signature(fn) + params = sig.parameters + if not params: raise ValueError(f"Function {fn.__name__} has no arguments") - untyped_args = [a for a in fn_spec.args if a not in fn.__annotations__] - if untyped_args: - fn.__annotations__[untyped_args[0]] = model + + updated = False + updated_params = [] + for param_name, param in params.items(): + if not updated and param.annotation is inspect.Parameter.empty: + updated_params.append(param.replace(annotation=model)) + updated = True + else: + updated_params.append(param) + + # Mypy doesn't know about __signature__: https://github.com/python/mypy/issues/12472 + fn.__signature__ = sig.replace( # type: ignore[attr-defined] + parameters=updated_params + ) diff --git a/pyproject.toml b/pyproject.toml index 3b3f4f3..04ccc94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "openapi-to-fastapi" -version = "0.20.0" +version = "0.21.0" description = "Create FastAPI routes from OpenAPI spec" authors = ["IOXIO Ltd"] license = "BSD-3-Clause"