diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index be4256703ba..4c4b66ed21b 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -1112,11 +1112,23 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup """ annotated_args = get_args(annotation) type_annotation = annotated_args[0] - powertools_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, FieldInfo)] + + # Handle both FieldInfo instances and FieldInfo subclasses (e.g., Body vs Body()) + powertools_annotations: list[FieldInfo] = [] + for arg in annotated_args[1:]: + if isinstance(arg, FieldInfo): + powertools_annotations.append(arg) + elif isinstance(arg, type) and issubclass(arg, FieldInfo): + # If it's a class (e.g., Body instead of Body()), instantiate it + powertools_annotations.append(arg()) # Preserve non-FieldInfo metadata (like annotated_types constraints) # This is important for constraints like Interval, Gt, Lt, etc. - other_metadata = [arg for arg in annotated_args[1:] if not isinstance(arg, FieldInfo)] + other_metadata = [ + arg + for arg in annotated_args[1:] + if not isinstance(arg, FieldInfo) and not (isinstance(arg, type) and issubclass(arg, FieldInfo)) + ] # Determine which annotation to use powertools_annotation: FieldInfo | None = None diff --git a/tests/functional/event_handler/_pydantic/test_openapi_params.py b/tests/functional/event_handler/_pydantic/test_openapi_params.py index 7efba0e16b3..c1d2d304741 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_params.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_params.py @@ -1267,3 +1267,123 @@ def list_items(limit: Annotated[constrained_int, Query()] = 10): assert limit_param.schema_.type == "integer" assert limit_param.schema_.default == 10 assert limit_param.required is False + + +def test_body_class_annotation_without_parentheses(): + """ + GIVEN an endpoint using Body class (not instance) in Annotated + WHEN sending a valid request body + THEN the request should be validated correctly + """ + app = APIGatewayRestResolver(enable_validation=True) + + class MyRequest(BaseModel): + foo: str + bar: str = "default_bar" + + class MyResponse(BaseModel): + concatenated: str + + # Using Body (class) instead of Body() (instance) + @app.patch("/test") + def handler(body: Annotated[MyRequest, Body]) -> MyResponse: + return MyResponse(concatenated=body.foo + body.bar) + + event = { + "resource": "/test", + "path": "/test", + "httpMethod": "PATCH", + "body": '{"foo": "hello"}', + "isBase64Encoded": False, + } + + result = app(event, {}) + assert result["statusCode"] == 200 + response_body = json.loads(result["body"]) + assert response_body["concatenated"] == "hellodefault_bar" + + +def test_body_instance_annotation_with_parentheses(): + """ + GIVEN an endpoint using Body() instance in Annotated + WHEN sending a valid request body + THEN the request should be validated correctly + """ + app = APIGatewayRestResolver(enable_validation=True) + + class MyRequest(BaseModel): + foo: str + bar: str = "default_bar" + + class MyResponse(BaseModel): + concatenated: str + + # Using Body() (instance) + @app.patch("/test") + def handler(body: Annotated[MyRequest, Body()]) -> MyResponse: + return MyResponse(concatenated=body.foo + body.bar) + + event = { + "resource": "/test", + "path": "/test", + "httpMethod": "PATCH", + "body": '{"foo": "hello"}', + "isBase64Encoded": False, + } + + result = app(event, {}) + assert result["statusCode"] == 200 + response_body = json.loads(result["body"]) + assert response_body["concatenated"] == "hellodefault_bar" + + +def test_query_class_annotation_without_parentheses(): + """ + GIVEN an endpoint using Query class (not instance) in Annotated + WHEN sending a valid query parameter + THEN the request should be validated correctly + """ + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/test") + def handler(name: Annotated[str, Query]) -> dict: + return {"name": name} + + event = { + "resource": "/test", + "path": "/test", + "httpMethod": "GET", + "queryStringParameters": {"name": "hello"}, + "isBase64Encoded": False, + } + + result = app(event, {}) + assert result["statusCode"] == 200 + response_body = json.loads(result["body"]) + assert response_body["name"] == "hello" + + +def test_header_class_annotation_without_parentheses(): + """ + GIVEN an endpoint using Header class (not instance) in Annotated + WHEN sending a valid header + THEN the request should be validated correctly + """ + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/test") + def handler(x_custom: Annotated[str, Header]) -> dict: + return {"header": x_custom} + + event = { + "resource": "/test", + "path": "/test", + "httpMethod": "GET", + "headers": {"x-custom": "my-value"}, + "isBase64Encoded": False, + } + + result = app(event, {}) + assert result["statusCode"] == 200 + response_body = json.loads(result["body"]) + assert response_body["header"] == "my-value"