diff --git a/src/mlpa/core/app_attest/qa_certificates.py b/src/mlpa/core/app_attest/qa_certificates.py index 15c415c..a379f38 100644 --- a/src/mlpa/core/app_attest/qa_certificates.py +++ b/src/mlpa/core/app_attest/qa_certificates.py @@ -6,9 +6,9 @@ from google.cloud import storage from google.cloud.exceptions import NotFound -from loguru import logger from mlpa.core.config import env +from mlpa.core.logger import logger QA_CERT_DIR = Path(env.APP_ATTEST_QA_CERT_DIR) QA_CERT_FILENAMES: tuple[str, ...] = ( diff --git a/src/mlpa/core/completions.py b/src/mlpa/core/completions.py index 13900d1..a409f45 100644 --- a/src/mlpa/core/completions.py +++ b/src/mlpa/core/completions.py @@ -5,7 +5,6 @@ import httpx import tiktoken from fastapi import HTTPException -from loguru import logger from mlpa.core.classes import AuthorizedChatRequest from mlpa.core.config import ( @@ -16,8 +15,9 @@ env, ) from mlpa.core.http_client import get_http_client +from mlpa.core.logger import logger from mlpa.core.prometheus_metrics import PrometheusResult, metrics -from mlpa.core.utils import is_rate_limit_error +from mlpa.core.utils import is_rate_limit_error, raise_and_log # Global default tokenizer - initialized once at module load time _global_default_tokenizer: Optional[tiktoken.Encoding] = None @@ -124,10 +124,7 @@ async def stream_completion(authorized_chat_request: AuthorizedChatRequest): return # For other errors or if we couldn't parse the error - logger.error( - f"Upstream service returned an error: {e.response.status_code} - {error_text_str}" - ) - yield f'data: {{"error": "Upstream service returned an error"}}\n\n'.encode() + yield raise_and_log(e, True) return async for chunk in response.aiter_bytes(): @@ -147,13 +144,15 @@ async def stream_completion(authorized_chat_request: AuthorizedChatRequest): metrics.chat_tokens.labels(type="completion").inc(num_completion_tokens) result = PrometheusResult.SUCCESS except httpx.HTTPStatusError as e: - logger.error(f"Upstream service returned an error: {e}") if not streaming_started: - yield f'data: {{"error": "Upstream service returned an error"}}\n\n'.encode() + yield raise_and_log(e, True) + else: + logger.error(f"Upstream service returned an error: {e}") except Exception as e: - logger.error(f"Failed to proxy request to {LITELLM_COMPLETIONS_URL}: {e}") if not streaming_started: - yield f'data: {{"error": "Failed to proxy request"}}\n\n'.encode() + yield raise_and_log(e, True, 502, "Failed to proxy request") + else: + logger.error(f"Upstream service returned an error: {e}") finally: metrics.chat_completion_latency.labels(result=result).observe( time.time() - start_time @@ -188,20 +187,7 @@ async def get_completion(authorized_chat_request: AuthorizedChatRequest): except httpx.HTTPStatusError as e: if e.response.status_code in {400, 429}: _handle_rate_limit_error(e.response.text, authorized_chat_request.user) - logger.error( - f"Upstream service returned an error: {e.response.status_code} - {e.response.text}" - ) - raise HTTPException( - status_code=e.response.status_code, - detail={"error": "Upstream service returned an error"}, - ) - logger.error( - f"Upstream service returned an error: {e.response.status_code} - {e.response.text}" - ) - raise HTTPException( - status_code=e.response.status_code, - detail={"error": "Upstream service returned an error"}, - ) + raise_and_log(e) data = response.json() usage = data.get("usage", {}) prompt_tokens = usage.get("prompt_tokens", 0) @@ -215,11 +201,7 @@ async def get_completion(authorized_chat_request: AuthorizedChatRequest): except HTTPException: raise except Exception as e: - logger.error(f"Failed to proxy request to {LITELLM_COMPLETIONS_URL}: {e}") - raise HTTPException( - status_code=502, - detail={"error": f"Failed to proxy request"}, - ) + raise_and_log(e, False, 502, "Failed to proxy request") finally: metrics.chat_completion_latency.labels(result=result).observe( time.time() - start_time diff --git a/src/mlpa/core/pg_services/app_attest_pg_service.py b/src/mlpa/core/pg_services/app_attest_pg_service.py index bcea25b..acef18e 100644 --- a/src/mlpa/core/pg_services/app_attest_pg_service.py +++ b/src/mlpa/core/pg_services/app_attest_pg_service.py @@ -1,6 +1,5 @@ -from loguru import logger - from mlpa.core.config import env +from mlpa.core.logger import logger from mlpa.core.pg_services.pg_service import PGService @@ -11,36 +10,36 @@ def __init__(self): # Challenges # async def store_challenge(self, key_id_b64: str, challenge: str): try: - async with self.pg.acquire() as conn: - query = """ - INSERT INTO challenges (key_id_b64, challenge) - VALUES ($1, $2) - ON CONFLICT (key_id_b64) DO UPDATE SET - challenge = EXCLUDED.challenge, - created_at = NOW() + await self.pg.fetchval( """ - stmt = await self._get_prepared_statement(conn, query) - await stmt.execute(key_id_b64, challenge) + INSERT INTO challenges (key_id_b64, challenge) + VALUES ($1, $2) + ON CONFLICT (key_id_b64) DO UPDATE SET + challenge = EXCLUDED.challenge, + created_at = NOW() + RETURNING 1 + """, + key_id_b64, + challenge, + ) except Exception as e: logger.error(f"Error storing challenge: {e}") async def get_challenge(self, key_id_b64: str) -> dict | None: try: - async with self.pg.acquire() as conn: - stmt = await self._get_prepared_statement( - conn, - "SELECT challenge, created_at FROM challenges WHERE key_id_b64 = $1", - ) - return await stmt.fetchrow(key_id_b64) + return await self.pg.fetchrow( + "SELECT challenge, created_at FROM challenges WHERE key_id_b64 = $1", + key_id_b64, + ) except Exception as e: logger.error(f"Error retrieving challenge: {e}") async def delete_challenge(self, key_id_b64: str): try: - async with self.pg.acquire() as conn: - query = "DELETE FROM challenges WHERE key_id_b64 = $1" - stmt = await self._get_prepared_statement(conn, query) - await stmt.execute(key_id_b64) + await self.pg.fetchval( + "DELETE FROM challenges WHERE key_id_b64 = $1 RETURNING 1", + key_id_b64, + ) except Exception as e: logger.error(f"Error deleting challenge: {e}") @@ -65,25 +64,27 @@ async def store_key(self, key_id_b64: str, public_key_pem: str, counter: int): async def get_key(self, key_id_b64: str) -> dict | None: try: - async with self.pg.acquire() as conn: - query = "SELECT public_key_pem, counter FROM public_keys WHERE key_id_b64 = $1" - stmt = await self._get_prepared_statement(conn, query) - return await stmt.fetchrow(key_id_b64) + return await self.pg.fetchrow( + "SELECT public_key_pem, counter FROM public_keys WHERE key_id_b64 = $1", + key_id_b64, + ) except Exception as e: logger.error(f"Error retrieving key: {e}") return None async def update_key_counter(self, key_id_b64: str, counter: int): try: - async with self.pg.acquire() as conn: - query = """ - UPDATE public_keys - SET counter = $2, - updated_at = NOW() - WHERE key_id_b64 = $1 AND counter < $2 + await self.pg.fetchval( """ - stmt = await self._get_prepared_statement(conn, query) - await stmt.execute(key_id_b64, counter) + UPDATE public_keys + SET counter = $2, + updated_at = NOW() + WHERE key_id_b64 = $1 AND counter < $2 + RETURNING 1 + """, + key_id_b64, + counter, + ) except Exception as e: logger.error(f"Error updating key counter: {e}") diff --git a/src/mlpa/core/pg_services/litellm_pg_service.py b/src/mlpa/core/pg_services/litellm_pg_service.py index 50ee774..b3385ec 100644 --- a/src/mlpa/core/pg_services/litellm_pg_service.py +++ b/src/mlpa/core/pg_services/litellm_pg_service.py @@ -1,7 +1,7 @@ from fastapi import HTTPException -from loguru import logger from mlpa.core.config import env +from mlpa.core.logger import logger from mlpa.core.pg_services.pg_service import PGService @@ -14,19 +14,21 @@ def __init__(self): super().__init__(env.LITELLM_DB_NAME) async def get_user(self, user_id: str): - async with self.pg.acquire() as conn: - query = 'SELECT * FROM "LiteLLM_EndUserTable" WHERE user_id = $1' - stmt = await self._get_prepared_statement(conn, query) - user = await stmt.fetchrow(user_id) - return dict(user) if user else None + user = await self.pg.fetchrow( + 'SELECT * FROM "LiteLLM_EndUserTable" WHERE user_id = $1', + user_id, + ) + return dict(user) if user else None async def block_user(self, user_id: str, blocked: bool = True) -> dict: try: async with self.pg.acquire() as conn: async with conn.transaction(): - query = 'UPDATE "LiteLLM_EndUserTable" SET "blocked" = $1 WHERE user_id = $2 RETURNING *' - stmt = await self._get_prepared_statement(conn, query) - updated_user_record = await stmt.fetchrow(blocked, user_id) + updated_user_record = await conn.fetchrow( + 'UPDATE "LiteLLM_EndUserTable" SET "blocked" = $1 WHERE user_id = $2 RETURNING *', + blocked, + user_id, + ) if updated_user_record is None: logger.error( @@ -48,21 +50,21 @@ async def block_user(self, user_id: str, blocked: bool = True) -> dict: async def list_users(self, limit: int = 50, offset: int = 0) -> dict: try: - async with self.pg.acquire() as conn: - count_query = 'SELECT COUNT(*) FROM "LiteLLM_EndUserTable"' - count_stmt = await self._get_prepared_statement(conn, count_query) - total = await count_stmt.fetchval() - - query = 'SELECT * FROM "LiteLLM_EndUserTable" ORDER BY user_id LIMIT $1 OFFSET $2' - stmt = await self._get_prepared_statement(conn, query) - users = await stmt.fetch(limit, offset) + total = await self.pg.fetchval( + 'SELECT COUNT(*) FROM "LiteLLM_EndUserTable"' + ) + users = await self.pg.fetch( + 'SELECT * FROM "LiteLLM_EndUserTable" ORDER BY user_id LIMIT $1 OFFSET $2', + limit, + offset, + ) - return { - "users": [dict(user) for user in users], - "total": total, - "limit": limit, - "offset": offset, - } + return { + "users": [dict(user) for user in users], + "total": total, + "limit": limit, + "offset": offset, + } except Exception as e: logger.error(f"Error listing users: {e}") raise HTTPException( diff --git a/src/mlpa/core/pg_services/pg_service.py b/src/mlpa/core/pg_services/pg_service.py index 26edde9..13e3758 100644 --- a/src/mlpa/core/pg_services/pg_service.py +++ b/src/mlpa/core/pg_services/pg_service.py @@ -1,12 +1,9 @@ import sys -from collections.abc import MutableMapping -from typing import Any import asyncpg -from cachetools import LRUCache -from loguru import logger from mlpa.core.config import env +from mlpa.core.logger import logger class PGService: @@ -18,29 +15,13 @@ def __init__(self, db_name: str): self.connected = False self.pg = None - def _create_stmt_cache(self) -> MutableMapping[str, Any]: - """Create a prepared statement cache with LRU eviction.""" - return LRUCache(maxsize=env.PG_PREPARED_STMT_CACHE_MAX_SIZE) - - async def _get_prepared_statement(self, conn: asyncpg.Connection, query: str): - stmt_cache = getattr(conn, "_mlpa_stmt_cache", None) - if stmt_cache is None: - stmt_cache = self._create_stmt_cache() - conn._mlpa_stmt_cache = stmt_cache - - if query in stmt_cache: - return stmt_cache[query] - - prepared_stmt = await conn.prepare(query) - stmt_cache[query] = prepared_stmt - return prepared_stmt - async def connect(self): try: self.pg = await asyncpg.create_pool( self.db_url, min_size=env.PG_POOL_MIN_SIZE, max_size=env.PG_POOL_MAX_SIZE, + statement_cache_size=env.PG_PREPARED_STMT_CACHE_MAX_SIZE, ) self.connected = True logger.info(f"Connected to /{self.db_name}") diff --git a/src/mlpa/core/routers/appattest/appattest.py b/src/mlpa/core/routers/appattest/appattest.py index 733a275..5d325ad 100644 --- a/src/mlpa/core/routers/appattest/appattest.py +++ b/src/mlpa/core/routers/appattest/appattest.py @@ -12,13 +12,13 @@ from cryptography.x509.base import load_pem_x509_certificate from fastapi import HTTPException from fastapi.concurrency import run_in_threadpool -from loguru import logger from pyattest.assertion import Assertion from pyattest.attestation import Attestation from pyattest.configs.apple import AppleConfig from mlpa.core.app_attest import QA_CERT_DIR, ensure_qa_certificates from mlpa.core.config import env +from mlpa.core.logger import logger from mlpa.core.pg_services.services import app_attest_pg from mlpa.core.prometheus_metrics import PrometheusResult, metrics from mlpa.core.utils import b64decode_safe diff --git a/src/mlpa/core/routers/appattest/middleware.py b/src/mlpa/core/routers/appattest/middleware.py index fa9dedd..b6a3fe5 100644 --- a/src/mlpa/core/routers/appattest/middleware.py +++ b/src/mlpa/core/routers/appattest/middleware.py @@ -1,10 +1,10 @@ from typing import Annotated from fastapi import APIRouter, Header, HTTPException -from loguru import logger from mlpa.core.classes import AssertionAuth, ChatRequest from mlpa.core.config import env +from mlpa.core.logger import logger from mlpa.core.routers.appattest import ( generate_client_challenge, validate_challenge, diff --git a/src/mlpa/core/routers/fxa/fxa.py b/src/mlpa/core/routers/fxa/fxa.py index 6419f64..6f31162 100644 --- a/src/mlpa/core/routers/fxa/fxa.py +++ b/src/mlpa/core/routers/fxa/fxa.py @@ -2,8 +2,8 @@ from typing import Annotated from fastapi import APIRouter, Header, HTTPException -from loguru import logger +from mlpa.core.logger import logger from mlpa.core.prometheus_metrics import PrometheusResult, metrics from mlpa.core.utils import get_fxa_client diff --git a/src/mlpa/core/routers/user/user.py b/src/mlpa/core/routers/user/user.py index 2604f59..925ec9a 100644 --- a/src/mlpa/core/routers/user/user.py +++ b/src/mlpa/core/routers/user/user.py @@ -1,11 +1,13 @@ from typing import Annotated +import httpx from fastapi import APIRouter, Depends, Header, HTTPException, Query -from loguru import logger from mlpa.core.config import LITELLM_MASTER_AUTH_HEADERS, env from mlpa.core.http_client import get_http_client +from mlpa.core.logger import logger from mlpa.core.pg_services.services import litellm_pg +from mlpa.core.utils import raise_and_log router = APIRouter() @@ -45,6 +47,10 @@ async def user_info(user_id: str): params=params, headers=LITELLM_MASTER_AUTH_HEADERS, ) + try: + response.raise_for_status() + except httpx.HTTPStatusError as e: + raise_and_log(e, False, e.response.status_code, "Error fetching user info") user = response.json() if not user: diff --git a/src/mlpa/core/utils.py b/src/mlpa/core/utils.py index 8e8bcb6..494ae91 100644 --- a/src/mlpa/core/utils.py +++ b/src/mlpa/core/utils.py @@ -1,13 +1,15 @@ +import ast import base64 +import json from fastapi import HTTPException from fxa.oauth import Client from jwtoxide import DecodingKey, ValidationOptions, decode -from loguru import logger from mlpa.core.classes import AssertionAuth, AttestationAuth from mlpa.core.config import LITELLM_MASTER_AUTH_HEADERS, env from mlpa.core.http_client import get_http_client +from mlpa.core.logger import logger async def get_or_create_user(user_id: str): @@ -107,3 +109,54 @@ def parse_app_attest_jwt(authorization: str, type: str): logger.error(f"App {type} JWT decode error: {e}") raise HTTPException(status_code=401, detail=f"Invalid App {type}") return appAuth + + +GENERIC_UPSTREAM_ERROR = "Upstream service returned an error" + + +def raise_and_log( + e: Exception, + stream: bool = False, + response_code: int | None = None, + response_text_prefix: str | None = None, +): + """ + Log an upstream exception and return or raise a standardized FastAPI response. + + When streaming, returns an SSE payload as bytes. Otherwise, raises an + HTTPException with the chosen status code and a sanitized error message. + If the upstream error body contains a nested error message, it is extracted + so clients receive the actual upstream detail in debug mode. (dev environment only) + """ + response = getattr(e, "response", None) + error_text = response.text if response is not None else "" + detail_text = error_text or str(e) + if error_text: + try: + error_payload = json.loads(error_text) + message = error_payload.get("error", {}).get("message") + if isinstance(message, str) and message.startswith("{'error':"): + try: + message_obj = ast.literal_eval(message) + message = message_obj.get("error", message) + except (ValueError, SyntaxError): + pass + if isinstance(message, str) and message: + detail_text = message + except (json.JSONDecodeError, AttributeError, TypeError): + pass + status_code = response_code or getattr(response, "status_code", None) or 500 + logger.error(f"{response_text_prefix or GENERIC_UPSTREAM_ERROR}: {detail_text}") + if stream: + error_msg = detail_text if env.MLPA_DEBUG else GENERIC_UPSTREAM_ERROR + payload = {"code": status_code, "error": error_msg} + return f"data: {json.dumps(payload)}\n\n".encode() + else: + raise HTTPException( + status_code=status_code, + detail={ + "error": detail_text + if env.MLPA_DEBUG + else response_text_prefix or GENERIC_UPSTREAM_ERROR + }, + ) diff --git a/src/tests/mocks.py b/src/tests/mocks.py index 773f916..904dca6 100644 --- a/src/tests/mocks.py +++ b/src/tests/mocks.py @@ -5,13 +5,10 @@ from cryptography.hazmat.primitives import serialization from cryptography.x509 import load_der_x509_certificate from fastapi import HTTPException -from fastapi.responses import JSONResponse -from loguru import logger -from pyattest.assertion import Assertion from pyattest.testutils.factories.attestation import apple as apple_factory from mlpa.core.classes import AuthorizedChatRequest, ChatRequest -from mlpa.core.config import env +from mlpa.core.logger import logger from mlpa.core.routers.appattest.appattest import validate_challenge from mlpa.core.utils import b64decode_safe, parse_app_attest_jwt from tests.consts import ( diff --git a/src/tests/unit/test_completions.py b/src/tests/unit/test_completions.py index 8952cd0..1799080 100644 --- a/src/tests/unit/test_completions.py +++ b/src/tests/unit/test_completions.py @@ -344,9 +344,6 @@ async def test_get_completion_400_non_rate_limit_error(mocker): mock_get_client = mocker.patch("mlpa.core.completions.get_http_client") mock_get_client.return_value = mock_client - mock_metrics = mocker.patch("mlpa.core.completions.metrics") - mock_logger = mocker.patch("mlpa.core.completions.logger") - # Act & Assert: Expect a 400 HTTPException with the upstream error payload with pytest.raises(HTTPException) as exc_info: await get_completion(SAMPLE_REQUEST) @@ -552,14 +549,14 @@ async def test_stream_completion_400_non_rate_limit_error( ) mock_metrics = mocker.patch("mlpa.core.completions.metrics") - mock_logger = mocker.patch("mlpa.core.completions.logger") + mock_logger = mocker.patch("mlpa.core.utils.logger") received_chunks = [chunk async for chunk in stream_completion(SAMPLE_REQUEST)] assert len(received_chunks) == 1 assert ( received_chunks[0] - == b'data: {"error": "Upstream service returned an error"}\n\n' + == b'data: {"code": 400, "error": "Upstream service returned an error"}\n\n' ) mock_logger.error.assert_called_once() mock_metrics.chat_completion_latency.labels.assert_called_once_with( @@ -584,14 +581,14 @@ async def test_stream_completion_429_non_rate_limit_error( ) mock_metrics = mocker.patch("mlpa.core.completions.metrics") - mock_logger = mocker.patch("mlpa.core.completions.logger") + mock_logger = mocker.patch("mlpa.core.utils.logger") received_chunks = [chunk async for chunk in stream_completion(SAMPLE_REQUEST)] assert len(received_chunks) == 1 assert ( received_chunks[0] - == b'data: {"error": "Upstream service returned an error"}\n\n' + == b'data: {"code": 429, "error": "Upstream service returned an error"}\n\n' ) mock_logger.error.assert_called_once() mock_metrics.chat_completion_latency.labels.assert_called_once_with( @@ -611,14 +608,14 @@ async def test_stream_completion_429_invalid_json(httpx_mock: HTTPXMock, mocker) ) mock_metrics = mocker.patch("mlpa.core.completions.metrics") - mock_logger = mocker.patch("mlpa.core.completions.logger") + mock_logger = mocker.patch("mlpa.core.utils.logger") received_chunks = [chunk async for chunk in stream_completion(SAMPLE_REQUEST)] assert len(received_chunks) == 1 assert ( received_chunks[0] - == b'data: {"error": "Upstream service returned an error"}\n\n' + == b'data: {"code": 429, "error": "Upstream service returned an error"}\n\n' ) mock_logger.error.assert_called_once() mock_metrics.chat_completion_latency.labels.assert_called_once_with(