Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions aws_lambda_powertools/event_handler/openapi/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,11 +1112,23 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup
"""
annotated_args = get_args(annotation)
type_annotation = annotated_args[0]
powertools_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, FieldInfo)]

# Handle both FieldInfo instances and FieldInfo subclasses (e.g., Body vs Body())
powertools_annotations: list[FieldInfo] = []
for arg in annotated_args[1:]:
if isinstance(arg, FieldInfo):
powertools_annotations.append(arg)
elif isinstance(arg, type) and issubclass(arg, FieldInfo):
# If it's a class (e.g., Body instead of Body()), instantiate it
powertools_annotations.append(arg())

# Preserve non-FieldInfo metadata (like annotated_types constraints)
# This is important for constraints like Interval, Gt, Lt, etc.
other_metadata = [arg for arg in annotated_args[1:] if not isinstance(arg, FieldInfo)]
other_metadata = [
arg
for arg in annotated_args[1:]
if not isinstance(arg, FieldInfo) and not (isinstance(arg, type) and issubclass(arg, FieldInfo))
]

# Determine which annotation to use
powertools_annotation: FieldInfo | None = None
Expand Down
120 changes: 120 additions & 0 deletions tests/functional/event_handler/_pydantic/test_openapi_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -1267,3 +1267,123 @@ def list_items(limit: Annotated[constrained_int, Query()] = 10):
assert limit_param.schema_.type == "integer"
assert limit_param.schema_.default == 10
assert limit_param.required is False


def test_body_class_annotation_without_parentheses():
"""
GIVEN an endpoint using Body class (not instance) in Annotated
WHEN sending a valid request body
THEN the request should be validated correctly
"""
app = APIGatewayRestResolver(enable_validation=True)

class MyRequest(BaseModel):
foo: str
bar: str = "default_bar"

class MyResponse(BaseModel):
concatenated: str

# Using Body (class) instead of Body() (instance)
@app.patch("/test")
def handler(body: Annotated[MyRequest, Body]) -> MyResponse:
return MyResponse(concatenated=body.foo + body.bar)

event = {
"resource": "/test",
"path": "/test",
"httpMethod": "PATCH",
"body": '{"foo": "hello"}',
"isBase64Encoded": False,
}

result = app(event, {})
assert result["statusCode"] == 200
response_body = json.loads(result["body"])
assert response_body["concatenated"] == "hellodefault_bar"


def test_body_instance_annotation_with_parentheses():
"""
GIVEN an endpoint using Body() instance in Annotated
WHEN sending a valid request body
THEN the request should be validated correctly
"""
app = APIGatewayRestResolver(enable_validation=True)

class MyRequest(BaseModel):
foo: str
bar: str = "default_bar"

class MyResponse(BaseModel):
concatenated: str

# Using Body() (instance)
@app.patch("/test")
def handler(body: Annotated[MyRequest, Body()]) -> MyResponse:
return MyResponse(concatenated=body.foo + body.bar)

event = {
"resource": "/test",
"path": "/test",
"httpMethod": "PATCH",
"body": '{"foo": "hello"}',
"isBase64Encoded": False,
}

result = app(event, {})
assert result["statusCode"] == 200
response_body = json.loads(result["body"])
assert response_body["concatenated"] == "hellodefault_bar"


def test_query_class_annotation_without_parentheses():
"""
GIVEN an endpoint using Query class (not instance) in Annotated
WHEN sending a valid query parameter
THEN the request should be validated correctly
"""
app = APIGatewayRestResolver(enable_validation=True)

@app.get("/test")
def handler(name: Annotated[str, Query]) -> dict:
return {"name": name}

event = {
"resource": "/test",
"path": "/test",
"httpMethod": "GET",
"queryStringParameters": {"name": "hello"},
"isBase64Encoded": False,
}

result = app(event, {})
assert result["statusCode"] == 200
response_body = json.loads(result["body"])
assert response_body["name"] == "hello"


def test_header_class_annotation_without_parentheses():
"""
GIVEN an endpoint using Header class (not instance) in Annotated
WHEN sending a valid header
THEN the request should be validated correctly
"""
app = APIGatewayRestResolver(enable_validation=True)

@app.get("/test")
def handler(x_custom: Annotated[str, Header]) -> dict:
return {"header": x_custom}

event = {
"resource": "/test",
"path": "/test",
"httpMethod": "GET",
"headers": {"x-custom": "my-value"},
"isBase64Encoded": False,
}

result = app(event, {})
assert result["statusCode"] == 200
response_body = json.loads(result["body"])
assert response_body["header"] == "my-value"