Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/mlpa/core/app_attest/qa_certificates.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

from google.cloud import storage
from google.cloud.exceptions import NotFound
from loguru import logger

from mlpa.core.config import env
from mlpa.core.logger import logger

QA_CERT_DIR = Path(env.APP_ATTEST_QA_CERT_DIR)
QA_CERT_FILENAMES: tuple[str, ...] = (
Expand Down
40 changes: 11 additions & 29 deletions src/mlpa/core/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import httpx
import tiktoken
from fastapi import HTTPException
from loguru import logger

from mlpa.core.classes import AuthorizedChatRequest
from mlpa.core.config import (
Expand All @@ -16,8 +15,9 @@
env,
)
from mlpa.core.http_client import get_http_client
from mlpa.core.logger import logger
from mlpa.core.prometheus_metrics import PrometheusResult, metrics
from mlpa.core.utils import is_rate_limit_error
from mlpa.core.utils import is_rate_limit_error, raise_and_log

# Global default tokenizer - initialized once at module load time
_global_default_tokenizer: Optional[tiktoken.Encoding] = None
Expand Down Expand Up @@ -124,10 +124,7 @@ async def stream_completion(authorized_chat_request: AuthorizedChatRequest):
return

# For other errors or if we couldn't parse the error
logger.error(
f"Upstream service returned an error: {e.response.status_code} - {error_text_str}"
)
yield f'data: {{"error": "Upstream service returned an error"}}\n\n'.encode()
yield raise_and_log(e, True)
return

async for chunk in response.aiter_bytes():
Expand All @@ -147,13 +144,15 @@ async def stream_completion(authorized_chat_request: AuthorizedChatRequest):
metrics.chat_tokens.labels(type="completion").inc(num_completion_tokens)
result = PrometheusResult.SUCCESS
except httpx.HTTPStatusError as e:
logger.error(f"Upstream service returned an error: {e}")
if not streaming_started:
yield f'data: {{"error": "Upstream service returned an error"}}\n\n'.encode()
yield raise_and_log(e, True)
else:
logger.error(f"Upstream service returned an error: {e}")
except Exception as e:
logger.error(f"Failed to proxy request to {LITELLM_COMPLETIONS_URL}: {e}")
if not streaming_started:
yield f'data: {{"error": "Failed to proxy request"}}\n\n'.encode()
yield raise_and_log(e, True, 502, "Failed to proxy request")
else:
logger.error(f"Upstream service returned an error: {e}")
finally:
metrics.chat_completion_latency.labels(result=result).observe(
time.time() - start_time
Expand Down Expand Up @@ -188,20 +187,7 @@ async def get_completion(authorized_chat_request: AuthorizedChatRequest):
except httpx.HTTPStatusError as e:
if e.response.status_code in {400, 429}:
_handle_rate_limit_error(e.response.text, authorized_chat_request.user)
logger.error(
f"Upstream service returned an error: {e.response.status_code} - {e.response.text}"
)
raise HTTPException(
status_code=e.response.status_code,
detail={"error": "Upstream service returned an error"},
)
logger.error(
f"Upstream service returned an error: {e.response.status_code} - {e.response.text}"
)
raise HTTPException(
status_code=e.response.status_code,
detail={"error": "Upstream service returned an error"},
)
raise_and_log(e)
data = response.json()
usage = data.get("usage", {})
prompt_tokens = usage.get("prompt_tokens", 0)
Expand All @@ -215,11 +201,7 @@ async def get_completion(authorized_chat_request: AuthorizedChatRequest):
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to proxy request to {LITELLM_COMPLETIONS_URL}: {e}")
raise HTTPException(
status_code=502,
detail={"error": f"Failed to proxy request"},
)
raise_and_log(e, False, 502, "Failed to proxy request")
finally:
metrics.chat_completion_latency.labels(result=result).observe(
time.time() - start_time
Expand Down
67 changes: 34 additions & 33 deletions src/mlpa/core/pg_services/app_attest_pg_service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from loguru import logger

from mlpa.core.config import env
from mlpa.core.logger import logger
from mlpa.core.pg_services.pg_service import PGService


Expand All @@ -11,36 +10,36 @@ def __init__(self):
# Challenges #
async def store_challenge(self, key_id_b64: str, challenge: str):
try:
async with self.pg.acquire() as conn:
query = """
INSERT INTO challenges (key_id_b64, challenge)
VALUES ($1, $2)
ON CONFLICT (key_id_b64) DO UPDATE SET
challenge = EXCLUDED.challenge,
created_at = NOW()
await self.pg.fetchval(
"""
stmt = await self._get_prepared_statement(conn, query)
await stmt.execute(key_id_b64, challenge)
INSERT INTO challenges (key_id_b64, challenge)
VALUES ($1, $2)
ON CONFLICT (key_id_b64) DO UPDATE SET
challenge = EXCLUDED.challenge,
created_at = NOW()
RETURNING 1
""",
key_id_b64,
challenge,
)
except Exception as e:
logger.error(f"Error storing challenge: {e}")

async def get_challenge(self, key_id_b64: str) -> dict | None:
try:
async with self.pg.acquire() as conn:
stmt = await self._get_prepared_statement(
conn,
"SELECT challenge, created_at FROM challenges WHERE key_id_b64 = $1",
)
return await stmt.fetchrow(key_id_b64)
return await self.pg.fetchrow(
"SELECT challenge, created_at FROM challenges WHERE key_id_b64 = $1",
key_id_b64,
)
except Exception as e:
logger.error(f"Error retrieving challenge: {e}")

async def delete_challenge(self, key_id_b64: str):
try:
async with self.pg.acquire() as conn:
query = "DELETE FROM challenges WHERE key_id_b64 = $1"
stmt = await self._get_prepared_statement(conn, query)
await stmt.execute(key_id_b64)
await self.pg.fetchval(
"DELETE FROM challenges WHERE key_id_b64 = $1 RETURNING 1",
key_id_b64,
)
except Exception as e:
logger.error(f"Error deleting challenge: {e}")

Expand All @@ -65,25 +64,27 @@ async def store_key(self, key_id_b64: str, public_key_pem: str, counter: int):

async def get_key(self, key_id_b64: str) -> dict | None:
try:
async with self.pg.acquire() as conn:
query = "SELECT public_key_pem, counter FROM public_keys WHERE key_id_b64 = $1"
stmt = await self._get_prepared_statement(conn, query)
return await stmt.fetchrow(key_id_b64)
return await self.pg.fetchrow(
"SELECT public_key_pem, counter FROM public_keys WHERE key_id_b64 = $1",
key_id_b64,
)
except Exception as e:
logger.error(f"Error retrieving key: {e}")
return None

async def update_key_counter(self, key_id_b64: str, counter: int):
try:
async with self.pg.acquire() as conn:
query = """
UPDATE public_keys
SET counter = $2,
updated_at = NOW()
WHERE key_id_b64 = $1 AND counter < $2
await self.pg.fetchval(
"""
stmt = await self._get_prepared_statement(conn, query)
await stmt.execute(key_id_b64, counter)
UPDATE public_keys
SET counter = $2,
updated_at = NOW()
WHERE key_id_b64 = $1 AND counter < $2
RETURNING 1
""",
key_id_b64,
counter,
)
except Exception as e:
logger.error(f"Error updating key counter: {e}")

Expand Down
48 changes: 25 additions & 23 deletions src/mlpa/core/pg_services/litellm_pg_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi import HTTPException
from loguru import logger

from mlpa.core.config import env
from mlpa.core.logger import logger
from mlpa.core.pg_services.pg_service import PGService


Expand All @@ -14,19 +14,21 @@ def __init__(self):
super().__init__(env.LITELLM_DB_NAME)

async def get_user(self, user_id: str):
async with self.pg.acquire() as conn:
query = 'SELECT * FROM "LiteLLM_EndUserTable" WHERE user_id = $1'
stmt = await self._get_prepared_statement(conn, query)
user = await stmt.fetchrow(user_id)
return dict(user) if user else None
user = await self.pg.fetchrow(
'SELECT * FROM "LiteLLM_EndUserTable" WHERE user_id = $1',
user_id,
)
return dict(user) if user else None

async def block_user(self, user_id: str, blocked: bool = True) -> dict:
try:
async with self.pg.acquire() as conn:
async with conn.transaction():
query = 'UPDATE "LiteLLM_EndUserTable" SET "blocked" = $1 WHERE user_id = $2 RETURNING *'
stmt = await self._get_prepared_statement(conn, query)
updated_user_record = await stmt.fetchrow(blocked, user_id)
updated_user_record = await conn.fetchrow(
'UPDATE "LiteLLM_EndUserTable" SET "blocked" = $1 WHERE user_id = $2 RETURNING *',
blocked,
user_id,
)

if updated_user_record is None:
logger.error(
Expand All @@ -48,21 +50,21 @@ async def block_user(self, user_id: str, blocked: bool = True) -> dict:

async def list_users(self, limit: int = 50, offset: int = 0) -> dict:
try:
async with self.pg.acquire() as conn:
count_query = 'SELECT COUNT(*) FROM "LiteLLM_EndUserTable"'
count_stmt = await self._get_prepared_statement(conn, count_query)
total = await count_stmt.fetchval()

query = 'SELECT * FROM "LiteLLM_EndUserTable" ORDER BY user_id LIMIT $1 OFFSET $2'
stmt = await self._get_prepared_statement(conn, query)
users = await stmt.fetch(limit, offset)
total = await self.pg.fetchval(
'SELECT COUNT(*) FROM "LiteLLM_EndUserTable"'
)
users = await self.pg.fetch(
'SELECT * FROM "LiteLLM_EndUserTable" ORDER BY user_id LIMIT $1 OFFSET $2',
limit,
offset,
)

return {
"users": [dict(user) for user in users],
"total": total,
"limit": limit,
"offset": offset,
}
return {
"users": [dict(user) for user in users],
"total": total,
"limit": limit,
"offset": offset,
}
except Exception as e:
logger.error(f"Error listing users: {e}")
raise HTTPException(
Expand Down
23 changes: 2 additions & 21 deletions src/mlpa/core/pg_services/pg_service.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import sys
from collections.abc import MutableMapping
from typing import Any

import asyncpg
from cachetools import LRUCache
from loguru import logger

from mlpa.core.config import env
from mlpa.core.logger import logger


class PGService:
Expand All @@ -18,29 +15,13 @@ def __init__(self, db_name: str):
self.connected = False
self.pg = None

def _create_stmt_cache(self) -> MutableMapping[str, Any]:
"""Create a prepared statement cache with LRU eviction."""
return LRUCache(maxsize=env.PG_PREPARED_STMT_CACHE_MAX_SIZE)

async def _get_prepared_statement(self, conn: asyncpg.Connection, query: str):
stmt_cache = getattr(conn, "_mlpa_stmt_cache", None)
if stmt_cache is None:
stmt_cache = self._create_stmt_cache()
conn._mlpa_stmt_cache = stmt_cache

if query in stmt_cache:
return stmt_cache[query]

prepared_stmt = await conn.prepare(query)
stmt_cache[query] = prepared_stmt
return prepared_stmt

async def connect(self):
try:
self.pg = await asyncpg.create_pool(
self.db_url,
min_size=env.PG_POOL_MIN_SIZE,
max_size=env.PG_POOL_MAX_SIZE,
statement_cache_size=env.PG_PREPARED_STMT_CACHE_MAX_SIZE,
)
self.connected = True
logger.info(f"Connected to /{self.db_name}")
Expand Down
2 changes: 1 addition & 1 deletion src/mlpa/core/routers/appattest/appattest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
from cryptography.x509.base import load_pem_x509_certificate
from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool
from loguru import logger
from pyattest.assertion import Assertion
from pyattest.attestation import Attestation
from pyattest.configs.apple import AppleConfig

from mlpa.core.app_attest import QA_CERT_DIR, ensure_qa_certificates
from mlpa.core.config import env
from mlpa.core.logger import logger
from mlpa.core.pg_services.services import app_attest_pg
from mlpa.core.prometheus_metrics import PrometheusResult, metrics
from mlpa.core.utils import b64decode_safe
Expand Down
2 changes: 1 addition & 1 deletion src/mlpa/core/routers/appattest/middleware.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Annotated

from fastapi import APIRouter, Header, HTTPException
from loguru import logger

from mlpa.core.classes import AssertionAuth, ChatRequest
from mlpa.core.config import env
from mlpa.core.logger import logger
from mlpa.core.routers.appattest import (
generate_client_challenge,
validate_challenge,
Expand Down
2 changes: 1 addition & 1 deletion src/mlpa/core/routers/fxa/fxa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from typing import Annotated

from fastapi import APIRouter, Header, HTTPException
from loguru import logger

from mlpa.core.logger import logger
from mlpa.core.prometheus_metrics import PrometheusResult, metrics
from mlpa.core.utils import get_fxa_client

Expand Down
8 changes: 7 additions & 1 deletion src/mlpa/core/routers/user/user.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Annotated

import httpx
from fastapi import APIRouter, Depends, Header, HTTPException, Query
from loguru import logger

from mlpa.core.config import LITELLM_MASTER_AUTH_HEADERS, env
from mlpa.core.http_client import get_http_client
from mlpa.core.logger import logger
from mlpa.core.pg_services.services import litellm_pg
from mlpa.core.utils import raise_and_log

router = APIRouter()

Expand Down Expand Up @@ -45,6 +47,10 @@ async def user_info(user_id: str):
params=params,
headers=LITELLM_MASTER_AUTH_HEADERS,
)
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise_and_log(e, False, e.response.status_code, "Error fetching user info")
user = response.json()

if not user:
Expand Down
Loading