Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 97 additions & 51 deletions src/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import os
from contextlib import asynccontextmanager
from typing import AsyncIterator, Awaitable, Callable
from typing import AsyncIterator

from fastapi import FastAPI, HTTPException, Request, Response
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.routing import Mount, Route, WebSocketRoute
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from llama_stack_client import APIConnectionError

from authorization.azure_token_manager import AzureEntraIDManager
Expand Down Expand Up @@ -115,66 +116,103 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
)


@app.middleware("")
async def rest_api_metrics(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Middleware with REST API counter update logic.
class RestApiMetricsMiddleware: # pylint: disable=too-few-public-methods
"""Pure ASGI middleware for REST API metrics.

Record REST API request metrics for application routes and forward the
request to the next REST API handler.
request to the next ASGI handler.

Only requests whose path is listed in the application's `app_routes_paths`
are measured. For measured requests, this middleware records request
duration and increments a per-path/per-status counter; it does not
increment counters for the `/metrics` endpoint.
Only requests whose path is listed in the application's routes are
measured. For measured requests, this middleware records request duration
and increments a per-path / per-status counter; it does not increment
counters for the ``/metrics`` endpoint.

Parameters:
request (Request): The incoming HTTP request.
call_next (Callable[[Request], Awaitable[Response]]): Callable that
forwards the request to the next ASGI/route handler and returns a
Response.

Returns:
Response: The HTTP response produced by the next handler.
This is implemented as a pure ASGI middleware (instead of using Starlette's
``BaseHTTPMiddleware``) to avoid the ``RuntimeError: No response returned``
bug that occurs when ``call_next`` is used with long-running handlers such
as LLM inference. See https://issues.redhat.com/browse/RSPEED-2413.
"""
path = request.url.path
logger.debug("Received request for path: %s", path)

# ignore paths that are not part of the app routes
if path not in app_routes_paths:
return await call_next(request)
def __init__(self, app: ASGIApp) -> None: # pylint: disable=redefined-outer-name
"""Initialize the middleware."""
self.app = app

logger.debug("Processing API request for path: %s", path)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Process an ASGI request."""
if scope["type"] != "http":
await self.app(scope, receive, send)
return

# measure time to handle duration + update histogram
with metrics.response_duration_seconds.labels(path).time():
response = await call_next(request)
path = scope["path"]
logger.debug("Received request for path: %s", path)

# ignore /metrics endpoint that will be called periodically
if not path.endswith("/metrics"):
# just update metrics
metrics.rest_api_calls_total.labels(path, response.status_code).inc()
return response
# Ignore paths that are not part of the app routes.
if path not in app_routes_paths:
await self.app(scope, receive, send)
return

logger.debug("Processing API request for path: %s", path)

status_code = 500

async def send_wrapper(message: Message) -> None:
nonlocal status_code
if message["type"] == "http.response.start":
status_code = message["status"]
await send(message)

# Measure duration and forward the request. Use try/finally so the
# call counter is always incremented, even when the inner app raises.
try:
with metrics.response_duration_seconds.labels(path).time():
await self.app(scope, receive, send_wrapper)
finally:
# Ignore /metrics endpoint that will be called periodically.
if not path.endswith("/metrics"):
metrics.rest_api_calls_total.labels(path, status_code).inc()

@app.middleware("http")
async def global_exception_middleware(
    request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
    """Turn uncaught endpoint exceptions into a generic 500 JSON response.

    Parameters:
        request (Request): The incoming HTTP request.
        call_next (Callable[[Request], Awaitable[Response]]): Callable that
            forwards the request to the next handler.

    Returns:
        Response: The downstream response, or a ``JSONResponse`` built from
        ``InternalServerErrorResponse.generic()`` when the handler raised.

    Raises:
        HTTPException: Re-raised unchanged, without being logged here.
    """
    try:
        return await call_next(request)
    except HTTPException:
        # Left for the framework's own HTTPException handling.
        raise
    except Exception as exc:  # pylint: disable=broad-exception-caught
        logger.exception("Uncaught exception in endpoint: %s", exc)
        fallback = InternalServerErrorResponse.generic()
        status_code = fallback.status_code
        payload = {"detail": fallback.detail.model_dump()}
        return JSONResponse(status_code=status_code, content=payload)

class GlobalExceptionMiddleware:  # pylint: disable=too-few-public-methods
    """Pure ASGI middleware that converts uncaught endpoint errors into a 500.

    HTTP requests are forwarded to the wrapped application. ``HTTPException``
    is re-raised untouched. Any other exception is logged; when no
    ``http.response.start`` message has been sent yet, the generic
    ``InternalServerErrorResponse`` is sent as JSON, otherwise the exception
    is re-raised. Non-HTTP scopes are forwarded without any handling.

    Written as a pure ASGI middleware rather than on top of Starlette's
    ``BaseHTTPMiddleware`` to avoid the ``RuntimeError: No response returned``
    bug that occurs when ``call_next`` is used with long-running handlers such
    as LLM inference. See https://issues.redhat.com/browse/RSPEED-2413.
    """

    def __init__(self, app: ASGIApp) -> None:  # pylint: disable=redefined-outer-name
        """Store the wrapped ASGI application."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Dispatch one ASGI connection; only HTTP scopes are guarded."""
        if scope["type"] == "http":
            await self._guard_http(scope, receive, send)
        else:
            await self.app(scope, receive, send)

    async def _guard_http(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the wrapped app for an HTTP scope and handle uncaught errors."""
        headers_sent = False

        async def tracking_send(message: Message) -> None:
            # Remember whether the response has begun, then forward as-is.
            nonlocal headers_sent
            if message["type"] == "http.response.start":
                headers_sent = True
            await send(message)

        try:
            await self.app(scope, receive, tracking_send)
        except HTTPException:
            raise
        except Exception as exc:  # pylint: disable=broad-exception-caught
            logger.exception("Uncaught exception in endpoint: %s", exc)
            if headers_sent:
                raise
            fallback = InternalServerErrorResponse.generic()
            status_code = fallback.status_code
            payload = {"detail": fallback.detail.model_dump()}
            error_reply = JSONResponse(status_code=status_code, content=payload)
            # Use the original ``send``: nothing has been sent for this request.
            await error_reply(scope, receive, send)


logger.info("Including routers")
Expand All @@ -185,3 +223,11 @@ async def global_exception_middleware(
for route in app.routes
if isinstance(route, (Mount, Route, WebSocketRoute))
]

# Register the pure ASGI middlewares. Middleware execution order is the reverse
# of registration order: GlobalExceptionMiddleware (registered first) is
# innermost, RestApiMetricsMiddleware (registered last) is outermost. With this
# ordering the metrics middleware observes the status code of the 500 response
# synthesised by the exception middleware instead of a raw exception.
# Exceptions that GlobalExceptionMiddleware re-raises (HTTPException, or any
# error raised after the response has started) still propagate through the
# metrics middleware, which then records its default status code of 500
# unless a response start message was already seen.
app.add_middleware(GlobalExceptionMiddleware)
app.add_middleware(RestApiMetricsMiddleware)
173 changes: 138 additions & 35 deletions tests/unit/app/test_main_middleware.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,82 @@
"""Unit tests for the global exception middleware in main.py."""
"""Unit tests for the pure ASGI middlewares in main.py."""

import json
from typing import cast
from unittest.mock import Mock
from unittest.mock import patch

import pytest
from fastapi import HTTPException, Request, Response, status
from fastapi.responses import JSONResponse
from starlette.requests import Request as StarletteRequest
from fastapi import HTTPException, status
from starlette.types import Message, Receive, Scope, Send

from app.main import GlobalExceptionMiddleware, RestApiMetricsMiddleware
from models.responses import InternalServerErrorResponse
from app.main import global_exception_middleware


def _make_scope(path: str = "/test") -> dict:
    """Return a bare-bones ASGI HTTP scope for a GET request to ``path``."""
    scope: dict = dict(
        type="http",
        method="GET",
        path=path,
        query_string=b"",
        headers=[],
    )
    return scope


async def _noop_receive() -> dict:
    """ASGI ``receive`` stand-in that yields an ``http.request`` with no body."""
    empty_request = {"type": "http.request", "body": b""}
    return empty_request


class _ResponseCollector:
    """ASGI ``send`` stand-in that records every message for later inspection."""

    def __init__(self) -> None:
        self.messages: list[Message] = []

    async def __call__(self, message: Message) -> None:
        self.messages.append(message)

    @property
    def status_code(self) -> int:
        """Status of the first ``http.response.start`` message collected.

        Raises:
            AssertionError: If no such message has been collected.
        """
        start = next(
            (m for m in self.messages if m["type"] == "http.response.start"),
            None,
        )
        if start is None:
            raise AssertionError("No http.response.start message")
        return start["status"]

    @property
    def body_json(self) -> dict:
        """Concatenate all ``http.response.body`` chunks and parse them as JSON."""
        chunks = (
            m.get("body", b"")
            for m in self.messages
            if m["type"] == "http.response.body"
        )
        return json.loads(b"".join(chunks))


# ---------------------------------------------------------------------------
# GlobalExceptionMiddleware
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_global_exception_middleware_catches_unexpected_exception() -> None:
"""Test that global exception middleware catches unexpected exceptions."""

mock_request = Mock(spec=StarletteRequest)
mock_request.url.path = "/test"
"""Test that GlobalExceptionMiddleware catches unexpected exceptions."""

async def mock_call_next_raises_error(request: Request) -> Response:
"""Mock call_next that raises an unexpected exception."""
async def failing_app(scope: Scope, receive: Receive, send: Send) -> None:
raise ValueError("This is an unexpected error for testing")

response = await global_exception_middleware(
mock_request, mock_call_next_raises_error
)
middleware = GlobalExceptionMiddleware(failing_app)
collector = _ResponseCollector()

await middleware(_make_scope(), _noop_receive, collector)

# Verify it returns a JSONResponse
assert isinstance(response, JSONResponse)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert collector.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

# Parse the response body
response_body_bytes = bytes(response.body)
response_body = json.loads(response_body_bytes.decode("utf-8"))
assert "detail" in response_body
detail = response_body["detail"]
detail = collector.body_json["detail"]
assert isinstance(detail, dict)
assert "response" in detail
assert "cause" in detail

# Verify it matches the generic InternalServerErrorResponse
expected_response = InternalServerErrorResponse.generic()
expected_detail = expected_response.model_dump()["detail"]
detail_dict = cast(dict[str, str], detail)
Expand All @@ -51,24 +86,92 @@ async def mock_call_next_raises_error(request: Request) -> Response:

@pytest.mark.asyncio
async def test_global_exception_middleware_passes_through_http_exception() -> None:
"""Test that global exception middleware passes through HTTPException unchanged."""
"""Test that GlobalExceptionMiddleware passes through HTTPException."""

mock_request = Mock(spec=StarletteRequest)
mock_request.url.path = "/test"

async def mock_call_next_raises_http_exception(request: Request) -> Response:
"""Mock call_next that raises HTTPException."""
async def http_error_app(scope: Scope, receive: Receive, send: Send) -> None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"response": "Test error", "cause": "This is a test"},
)

middleware = GlobalExceptionMiddleware(http_error_app)
collector = _ResponseCollector()

with pytest.raises(HTTPException) as exc_info:
await global_exception_middleware(
mock_request, mock_call_next_raises_http_exception
)
await middleware(_make_scope(), _noop_receive, collector)

assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
detail = cast(dict[str, str], exc_info.value.detail)
assert detail["response"] == "Test error"
assert detail["cause"] == "This is a test"


@pytest.mark.asyncio
async def test_global_exception_middleware_reraises_when_response_started() -> None:
    """An error raised after ``http.response.start`` propagates to the caller."""
    start_message = {"type": "http.response.start", "status": 200, "headers": []}

    async def app_failing_mid_response(
        _scope: Scope, _receive: Receive, send: Send
    ) -> None:
        # Begin the response, then fail before any body is sent.
        await send(start_message)
        raise RuntimeError("error after headers sent")

    sink = _ResponseCollector()
    wrapped = GlobalExceptionMiddleware(app_failing_mid_response)

    with pytest.raises(RuntimeError, match="error after headers sent"):
        await wrapped(_make_scope(), _noop_receive, sink)


@pytest.mark.asyncio
async def test_global_exception_middleware_skips_non_http() -> None:
    """A websocket scope reaches the wrapped app through the middleware."""
    invocations = 0

    async def inner_app(_scope: Scope, _receive: Receive, _send: Send) -> None:
        nonlocal invocations
        invocations += 1

    websocket_scope = {"type": "websocket"}
    wrapped = GlobalExceptionMiddleware(inner_app)
    await wrapped(websocket_scope, _noop_receive, _ResponseCollector())

    assert invocations


# ---------------------------------------------------------------------------
# RestApiMetricsMiddleware
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_rest_api_metrics_skips_non_http() -> None:
    """A websocket scope reaches the wrapped app through the middleware."""
    invocations = 0

    async def inner_app(_scope: Scope, _receive: Receive, _send: Send) -> None:
        nonlocal invocations
        invocations += 1

    websocket_scope = {"type": "websocket"}
    wrapped = RestApiMetricsMiddleware(inner_app)
    await wrapped(websocket_scope, _noop_receive, _ResponseCollector())

    assert invocations


@pytest.mark.asyncio
@patch("app.main.app_routes_paths", ["/v1/infer"])
async def test_rest_api_metrics_increments_counter_on_exception() -> None:
    """The call counter is incremented with status 500 when the inner app raises."""
    measured_path = "/v1/infer"

    async def failing_app(_scope: Scope, _receive: Receive, _send: Send) -> None:
        raise RuntimeError("boom")

    with patch("app.main.metrics") as mock_metrics:
        wrapped = RestApiMetricsMiddleware(failing_app)

        with pytest.raises(RuntimeError, match="boom"):
            await wrapped(
                _make_scope(measured_path), _noop_receive, _ResponseCollector()
            )

    # The mock keeps its call history after the patch context has exited.
    histogram_labels = mock_metrics.response_duration_seconds.labels
    counter_labels = mock_metrics.rest_api_calls_total.labels
    histogram_labels.assert_called_once_with(measured_path)
    counter_labels.assert_called_once_with(measured_path, 500)
    counter_labels.return_value.inc.assert_called_once()
Loading