class RestApiMetricsMiddleware:  # pylint: disable=too-few-public-methods
    """Pure ASGI middleware that records REST API metrics.

    HTTP requests whose path is listed in ``app_routes_paths`` are timed with
    the ``response_duration_seconds`` histogram and counted per path and per
    status code in ``rest_api_calls_total``. Paths ending in ``/metrics`` are
    timed but not counted. Non-HTTP scopes and unlisted paths are forwarded
    to the wrapped application unchanged.

    The raw ASGI interface is used (rather than Starlette's
    ``BaseHTTPMiddleware``) to avoid the ``RuntimeError: No response
    returned`` bug that occurs when ``call_next`` is used with long-running
    handlers such as LLM inference.
    See https://issues.redhat.com/browse/RSPEED-2413.
    """

    def __init__(self, app: ASGIApp) -> None:  # pylint: disable=redefined-outer-name
        """Remember the ASGI application this middleware wraps."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Dispatch one ASGI connection, measuring it when it is a known route."""
        if scope["type"] == "http":
            request_path = scope["path"]
            logger.debug("Received request for path: %s", request_path)
            if request_path in app_routes_paths:
                await self._measure(request_path, scope, receive, send)
                return

        # Non-HTTP traffic and paths outside the app routes are not measured.
        await self.app(scope, receive, send)

    async def _measure(
        self, request_path: str, scope: Scope, receive: Receive, send: Send
    ) -> None:
        """Forward the request while recording its duration and status code."""
        logger.debug("Processing API request for path: %s", request_path)

        # Stays at 500 unless the wrapped app gets as far as sending headers.
        observed = {"status": 500}

        async def recording_send(message: Message) -> None:
            if message["type"] == "http.response.start":
                observed["status"] = message["status"]
            await send(message)

        # try/finally guarantees the call counter is updated even when the
        # wrapped application raises.
        try:
            with metrics.response_duration_seconds.labels(request_path).time():
                await self.app(scope, receive, recording_send)
        finally:
            # The /metrics endpoint is polled periodically, so it is not counted.
            if not request_path.endswith("/metrics"):
                calls_total = metrics.rest_api_calls_total
                calls_total.labels(request_path, observed["status"]).inc()


class GlobalExceptionMiddleware:  # pylint: disable=too-few-public-methods
    """Pure ASGI middleware that handles uncaught exceptions from endpoints.

    ``HTTPException`` is re-raised as is. Any other ``Exception`` is logged;
    if no response has been started yet it is replaced by the generic
    ``InternalServerErrorResponse`` rendered as JSON, otherwise it is
    re-raised because the response headers have already been sent.

    The raw ASGI interface is used (rather than Starlette's
    ``BaseHTTPMiddleware``) to avoid the ``RuntimeError: No response
    returned`` bug that occurs when ``call_next`` is used with long-running
    handlers such as LLM inference.
    See https://issues.redhat.com/browse/RSPEED-2413.
    """

    def __init__(self, app: ASGIApp) -> None:  # pylint: disable=redefined-outer-name
        """Remember the ASGI application this middleware wraps."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the wrapped app and turn unexpected failures into a 500 reply."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        headers_sent = False

        async def tracking_send(message: Message) -> None:
            nonlocal headers_sent
            if message["type"] == "http.response.start":
                headers_sent = True
            await send(message)

        try:
            await self.app(scope, receive, tracking_send)
        except HTTPException:
            raise
        except Exception as exc:  # pylint: disable=broad-exception-caught
            logger.exception("Uncaught exception in endpoint: %s", exc)
            # A second "http.response.start" cannot be sent once the first
            # one is out, so the only option left is to propagate the error.
            if headers_sent:
                raise
            generic_error = InternalServerErrorResponse.generic()
            fallback = JSONResponse(
                status_code=generic_error.status_code,
                content={"detail": generic_error.detail.model_dump()},
            )
            await fallback(scope, receive, send)


# Register pure ASGI middlewares. Middleware execution order is the reverse of
# registration order: GlobalExceptionMiddleware (registered first) is innermost,
# RestApiMetricsMiddleware (registered last) is outermost. This ensures metrics
# always observe a status code — including 500s synthesised by the exception
# middleware — rather than seeing a raw exception with no response.
+app.add_middleware(GlobalExceptionMiddleware) +app.add_middleware(RestApiMetricsMiddleware) diff --git a/tests/unit/app/test_main_middleware.py b/tests/unit/app/test_main_middleware.py index 3b7184d4..a10d0f06 100644 --- a/tests/unit/app/test_main_middleware.py +++ b/tests/unit/app/test_main_middleware.py @@ -1,47 +1,82 @@ -"""Unit tests for the global exception middleware in main.py.""" +"""Unit tests for the pure ASGI middlewares in main.py.""" import json from typing import cast -from unittest.mock import Mock +from unittest.mock import patch import pytest -from fastapi import HTTPException, Request, Response, status -from fastapi.responses import JSONResponse -from starlette.requests import Request as StarletteRequest +from fastapi import HTTPException, status +from starlette.types import Message, Receive, Scope, Send +from app.main import GlobalExceptionMiddleware, RestApiMetricsMiddleware from models.responses import InternalServerErrorResponse -from app.main import global_exception_middleware + + +def _make_scope(path: str = "/test") -> dict: + """Build a minimal HTTP ASGI scope.""" + return { + "type": "http", + "method": "GET", + "path": path, + "query_string": b"", + "headers": [], + } + + +async def _noop_receive() -> dict: + """Minimal ASGI receive callable.""" + return {"type": "http.request", "body": b""} + + +class _ResponseCollector: + """Accumulate ASGI messages so tests can inspect them.""" + + def __init__(self) -> None: + self.messages: list[Message] = [] + + async def __call__(self, message: Message) -> None: + self.messages.append(message) + + @property + def status_code(self) -> int: + """Return the HTTP status code from the collected response.""" + for msg in self.messages: + if msg["type"] == "http.response.start": + return msg["status"] + raise AssertionError("No http.response.start message") + + @property + def body_json(self) -> dict: + """Return the response body decoded as JSON.""" + body = b"" + for msg in self.messages: + if 
msg["type"] == "http.response.body": + body += msg.get("body", b"") + return json.loads(body) + + +# --------------------------------------------------------------------------- +# GlobalExceptionMiddleware +# --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_global_exception_middleware_catches_unexpected_exception() -> None: - """Test that global exception middleware catches unexpected exceptions.""" - - mock_request = Mock(spec=StarletteRequest) - mock_request.url.path = "/test" + """Test that GlobalExceptionMiddleware catches unexpected exceptions.""" - async def mock_call_next_raises_error(request: Request) -> Response: - """Mock call_next that raises an unexpected exception.""" + async def failing_app(scope: Scope, receive: Receive, send: Send) -> None: raise ValueError("This is an unexpected error for testing") - response = await global_exception_middleware( - mock_request, mock_call_next_raises_error - ) + middleware = GlobalExceptionMiddleware(failing_app) + collector = _ResponseCollector() + + await middleware(_make_scope(), _noop_receive, collector) - # Verify it returns a JSONResponse - assert isinstance(response, JSONResponse) - assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert collector.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR - # Parse the response body - response_body_bytes = bytes(response.body) - response_body = json.loads(response_body_bytes.decode("utf-8")) - assert "detail" in response_body - detail = response_body["detail"] + detail = collector.body_json["detail"] assert isinstance(detail, dict) - assert "response" in detail - assert "cause" in detail - # Verify it matches the generic InternalServerErrorResponse expected_response = InternalServerErrorResponse.generic() expected_detail = expected_response.model_dump()["detail"] detail_dict = cast(dict[str, str], detail) @@ -51,24 +86,92 @@ async def mock_call_next_raises_error(request: Request) 
-> Response: @pytest.mark.asyncio async def test_global_exception_middleware_passes_through_http_exception() -> None: - """Test that global exception middleware passes through HTTPException unchanged.""" + """Test that GlobalExceptionMiddleware passes through HTTPException.""" - mock_request = Mock(spec=StarletteRequest) - mock_request.url.path = "/test" - - async def mock_call_next_raises_http_exception(request: Request) -> Response: - """Mock call_next that raises HTTPException.""" + async def http_error_app(scope: Scope, receive: Receive, send: Send) -> None: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail={"response": "Test error", "cause": "This is a test"}, ) + middleware = GlobalExceptionMiddleware(http_error_app) + collector = _ResponseCollector() + with pytest.raises(HTTPException) as exc_info: - await global_exception_middleware( - mock_request, mock_call_next_raises_http_exception - ) + await middleware(_make_scope(), _noop_receive, collector) assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST detail = cast(dict[str, str], exc_info.value.detail) assert detail["response"] == "Test error" assert detail["cause"] == "This is a test" + + +@pytest.mark.asyncio +async def test_global_exception_middleware_reraises_when_response_started() -> None: + """Test that exceptions after response headers are sent are re-raised.""" + + async def partial_response_app( + _scope: Scope, _receive: Receive, send: Send + ) -> None: + await send({"type": "http.response.start", "status": 200, "headers": []}) + raise RuntimeError("error after headers sent") + + middleware = GlobalExceptionMiddleware(partial_response_app) + collector = _ResponseCollector() + + with pytest.raises(RuntimeError, match="error after headers sent"): + await middleware(_make_scope(), _noop_receive, collector) + + +@pytest.mark.asyncio +async def test_global_exception_middleware_skips_non_http() -> None: + """Test that non-HTTP scopes pass through untouched.""" + called = 
False + + async def inner_app(_scope: Scope, _receive: Receive, _send: Send) -> None: + nonlocal called + called = True + + middleware = GlobalExceptionMiddleware(inner_app) + await middleware({"type": "websocket"}, _noop_receive, _ResponseCollector()) + assert called + + +# --------------------------------------------------------------------------- +# RestApiMetricsMiddleware +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_rest_api_metrics_skips_non_http() -> None: + """Test that non-HTTP scopes pass through untouched.""" + called = False + + async def inner_app(_scope: Scope, _receive: Receive, _send: Send) -> None: + nonlocal called + called = True + + middleware = RestApiMetricsMiddleware(inner_app) + await middleware({"type": "websocket"}, _noop_receive, _ResponseCollector()) + assert called + + +@pytest.mark.asyncio +@patch("app.main.app_routes_paths", ["/v1/infer"]) +async def test_rest_api_metrics_increments_counter_on_exception() -> None: + """Counter must be incremented even when the inner app raises.""" + + async def failing_app(_scope: Scope, _receive: Receive, _send: Send) -> None: + raise RuntimeError("boom") + + with patch("app.main.metrics") as mock_metrics: + middleware = RestApiMetricsMiddleware(failing_app) + + with pytest.raises(RuntimeError, match="boom"): + await middleware( + _make_scope("/v1/infer"), _noop_receive, _ResponseCollector() + ) + + mock_metrics.response_duration_seconds.labels.assert_called_once_with("/v1/infer") + mock_metrics.rest_api_calls_total.labels.assert_called_once_with("/v1/infer", 500) + mock_metrics.rest_api_calls_total.labels.return_value.inc.assert_called_once()