Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@
# Optional DB override. By default uses sqlite file: data/proxy.db
#DATABASE_URL="sqlite+aiosqlite:///data/proxy.db"

# SQLite lock wait timeout in milliseconds.
#SQLITE_BUSY_TIMEOUT_MS=5000


# ------------------------------------------------------------------------------
# | [API KEYS] Provider API Keys |
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,7 @@ The proxy includes a powerful text-based UI for configuration and management.
| `API_TOKEN_PEPPER` | HMAC key for API token hashes | Required in prod |
| `CORS_ALLOW_ORIGINS` | Comma-separated browser origins | empty |
| `CORS_ALLOW_CREDENTIALS` | Allow credentialed CORS requests | false (unless origins configured) |
| `SQLITE_BUSY_TIMEOUT_MS` | SQLite lock timeout in milliseconds | `5000` |
| `OAUTH_REFRESH_INTERVAL` | Token refresh check interval (seconds) | `600` |
| `SKIP_OAUTH_INIT_CHECK` | Skip interactive OAuth setup on startup | `false` |

Expand Down
67 changes: 64 additions & 3 deletions src/proxy_app/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,14 @@
import os
from pathlib import Path

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy import event, select
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)

from proxy_app.db_models import Base, User

Expand Down Expand Up @@ -33,6 +39,45 @@ def get_database_url(root_dir: Path) -> str:
return f"sqlite+aiosqlite:///{db_dir / 'proxy.db'}"


def _is_sqlite_url(database_url: str) -> bool:
driver = make_url(database_url).get_backend_name()
return driver == "sqlite"


def _get_sqlite_busy_timeout_ms() -> int:
raw = os.getenv("SQLITE_BUSY_TIMEOUT_MS", "5000")
try:
timeout = int(raw)
except ValueError:
timeout = 5000
return max(1000, timeout)


def _configure_sqlite_engine(engine: AsyncEngine) -> None:
busy_timeout_ms = _get_sqlite_busy_timeout_ms()

@event.listens_for(engine.sync_engine, "connect")
def _set_sqlite_pragmas(dbapi_connection, _connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.execute("PRAGMA temp_store=MEMORY")
cursor.execute("PRAGMA foreign_keys=ON")
cursor.execute(f"PRAGMA busy_timeout={busy_timeout_ms}")
cursor.close()


def create_db_engine(database_url: str) -> AsyncEngine:
connect_args = {}
if _is_sqlite_url(database_url):
connect_args["timeout"] = _get_sqlite_busy_timeout_ms() / 1000

engine = create_async_engine(database_url, future=True, connect_args=connect_args)
if _is_sqlite_url(database_url):
_configure_sqlite_engine(engine)
return engine


async def _bootstrap_initial_admin(session: AsyncSession) -> bool:
username = (os.getenv("INITIAL_ADMIN_USERNAME") or "").strip()
password = os.getenv("INITIAL_ADMIN_PASSWORD") or ""
Expand All @@ -58,7 +103,7 @@ async def _bootstrap_initial_admin(session: AsyncSession) -> bool:

async def init_db(root_dir: Path) -> async_sessionmaker[AsyncSession]:
database_url = get_database_url(root_dir)
engine = create_async_engine(database_url, future=True)
engine = create_db_engine(database_url)
session_maker = async_sessionmaker(engine, expire_on_commit=False)

async with engine.begin() as conn:
Expand All @@ -68,3 +113,19 @@ async def init_db(root_dir: Path) -> async_sessionmaker[AsyncSession]:
await _bootstrap_initial_admin(session)

return session_maker


async def init_db_runtime(
root_dir: Path,
) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]:
database_url = get_database_url(root_dir)
engine = create_db_engine(database_url)
session_maker = async_sessionmaker(engine, expire_on_commit=False)

async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)

async with session_maker() as session:
await _bootstrap_initial_admin(session)

return engine, session_maker
9 changes: 7 additions & 2 deletions src/proxy_app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@
from proxy_app.batch_manager import EmbeddingBatcher
from proxy_app.api_token_auth import ApiActor, get_api_actor, require_admin_api_actor
from proxy_app.detailed_logger import RawIOLogger
from proxy_app.db import init_db
from proxy_app.db import init_db_runtime
from proxy_app.routers import admin_router, auth_router, ui_router, user_router
from proxy_app.usage_recorder import (
record_usage_event as record_usage_event_async,
Expand Down Expand Up @@ -436,7 +436,9 @@ def filter(self, record):
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Manage the RotatingClient's lifecycle with the app's lifespan."""
app.state.db_session_maker = await init_db(_root_dir)
db_engine, db_session_maker = await init_db_runtime(_root_dir)
app.state.db_engine = db_engine
app.state.db_session_maker = db_session_maker
app.state.usage_recorder = await start_usage_recorder(app.state.db_session_maker)

# [MODIFIED] Perform skippable OAuth initialization at startup
Expand Down Expand Up @@ -677,6 +679,9 @@ async def process_credential(provider: str, path: str, provider_instance):
else:
logging.info("RotatingClient closed.")

if hasattr(app.state, "db_engine") and app.state.db_engine:
await app.state.db_engine.dispose()


# --- FastAPI App Setup ---
app = FastAPI(lifespan=lifespan)
Expand Down
33 changes: 33 additions & 0 deletions tests/test_db_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest
from sqlalchemy import text

from proxy_app.db import create_db_engine


@pytest.mark.asyncio
async def test_sqlite_pragmas_are_applied(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
monkeypatch.setenv("SQLITE_BUSY_TIMEOUT_MS", "7000")
db_file = tmp_path / "pragmas.db"
engine = create_db_engine(f"sqlite+aiosqlite:///{db_file}")

try:
async with engine.connect() as conn:
journal_mode = (
await conn.execute(text("PRAGMA journal_mode"))
).scalar_one_or_none()
synchronous = (
await conn.execute(text("PRAGMA synchronous"))
).scalar_one_or_none()
foreign_keys = (
await conn.execute(text("PRAGMA foreign_keys"))
).scalar_one_or_none()
busy_timeout = (
await conn.execute(text("PRAGMA busy_timeout"))
).scalar_one_or_none()

assert str(journal_mode).lower() == "wal"
assert int(synchronous) == 1 # NORMAL
assert int(foreign_keys) == 1
assert int(busy_timeout) == 7000
finally:
await engine.dispose()
Loading