From 4365eed6c984baad24372b389580ee2882556c51 Mon Sep 17 00:00:00 2001
From: mojo-opencode
Date: Sat, 14 Feb 2026 03:07:03 +0000
Subject: [PATCH] fix(db): enable sqlite WAL pragmas and dispose engine on
 shutdown

---
 .env.example             |  3 ++
 README.md                |  1 +
 src/proxy_app/db.py      | 67 ++++++++++++++++++++++++++++++++++++++--
 src/proxy_app/main.py    |  9 ++++-
 tests/test_db_runtime.py | 33 ++++++++++++++++++++
 5 files changed, 108 insertions(+), 5 deletions(-)
 create mode 100644 tests/test_db_runtime.py

diff --git a/.env.example b/.env.example
index 35525aad..43d3647e 100644
--- a/.env.example
+++ b/.env.example
@@ -68,6 +68,9 @@
 # Optional DB override. By default uses sqlite file: data/proxy.db
 #DATABASE_URL="sqlite+aiosqlite:///data/proxy.db"
 
+# SQLite lock wait timeout in milliseconds.
+#SQLITE_BUSY_TIMEOUT_MS=5000
+
 # ------------------------------------------------------------------------------
 # |                      [API KEYS] Provider API Keys                          |
 # ------------------------------------------------------------------------------
diff --git a/README.md b/README.md
index 41b42f88..3b04870f 100644
--- a/README.md
+++ b/README.md
@@ -507,6 +507,7 @@ The proxy includes a powerful text-based UI for configuration and management.
 | `API_TOKEN_PEPPER` | HMAC key for API token hashes | Required in prod |
 | `CORS_ALLOW_ORIGINS` | Comma-separated browser origins | empty |
 | `CORS_ALLOW_CREDENTIALS` | Allow credentialed CORS requests | false (unless origins configured) |
+| `SQLITE_BUSY_TIMEOUT_MS` | SQLite lock timeout in milliseconds | `5000` |
 | `OAUTH_REFRESH_INTERVAL` | Token refresh check interval (seconds) | `600` |
 | `SKIP_OAUTH_INIT_CHECK` | Skip interactive OAuth setup on startup | `false` |
 
diff --git a/src/proxy_app/db.py b/src/proxy_app/db.py
index b5cd0720..31d4fff9 100644
--- a/src/proxy_app/db.py
+++ b/src/proxy_app/db.py
@@ -4,8 +4,14 @@
 import os
 from pathlib import Path
 
-from sqlalchemy import select
-from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
+from sqlalchemy import event, select
+from sqlalchemy.engine import make_url
+from sqlalchemy.ext.asyncio import (
+    AsyncEngine,
+    AsyncSession,
+    async_sessionmaker,
+    create_async_engine,
+)
 
 from proxy_app.db_models import Base, User
 
@@ -33,6 +39,45 @@ def get_database_url(root_dir: Path) -> str:
     return f"sqlite+aiosqlite:///{db_dir / 'proxy.db'}"
 
 
+def _is_sqlite_url(database_url: str) -> bool:
+    driver = make_url(database_url).get_backend_name()
+    return driver == "sqlite"
+
+
+def _get_sqlite_busy_timeout_ms() -> int:
+    raw = os.getenv("SQLITE_BUSY_TIMEOUT_MS", "5000")
+    try:
+        timeout = int(raw)
+    except ValueError:
+        timeout = 5000
+    return max(1000, timeout)
+
+
+def _configure_sqlite_engine(engine: AsyncEngine) -> None:
+    busy_timeout_ms = _get_sqlite_busy_timeout_ms()
+
+    @event.listens_for(engine.sync_engine, "connect")
+    def _set_sqlite_pragmas(dbapi_connection, _connection_record):
+        cursor = dbapi_connection.cursor()
+        cursor.execute("PRAGMA journal_mode=WAL")
+        cursor.execute("PRAGMA synchronous=NORMAL")
+        cursor.execute("PRAGMA temp_store=MEMORY")
+        cursor.execute("PRAGMA foreign_keys=ON")
+        cursor.execute(f"PRAGMA busy_timeout={busy_timeout_ms}")
+        cursor.close()
+
+
+def create_db_engine(database_url: str) -> AsyncEngine:
+    connect_args = {}
+    if _is_sqlite_url(database_url):
+        connect_args["timeout"] = _get_sqlite_busy_timeout_ms() / 1000
+
+    engine = create_async_engine(database_url, future=True, connect_args=connect_args)
+    if _is_sqlite_url(database_url):
+        _configure_sqlite_engine(engine)
+    return engine
+
+
 async def _bootstrap_initial_admin(session: AsyncSession) -> bool:
     username = (os.getenv("INITIAL_ADMIN_USERNAME") or "").strip()
     password = os.getenv("INITIAL_ADMIN_PASSWORD") or ""
@@ -58,7 +103,7 @@ async def _bootstrap_initial_admin(session: AsyncSession) -> bool:
 
 async def init_db(root_dir: Path) -> async_sessionmaker[AsyncSession]:
     database_url = get_database_url(root_dir)
-    engine = create_async_engine(database_url, future=True)
+    engine = create_db_engine(database_url)
     session_maker = async_sessionmaker(engine, expire_on_commit=False)
 
     async with engine.begin() as conn:
@@ -68,3 +113,19 @@
         await _bootstrap_initial_admin(session)
 
     return session_maker
+
+
+async def init_db_runtime(
+    root_dir: Path,
+) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]:
+    database_url = get_database_url(root_dir)
+    engine = create_db_engine(database_url)
+    session_maker = async_sessionmaker(engine, expire_on_commit=False)
+
+    async with engine.begin() as conn:
+        await conn.run_sync(Base.metadata.create_all)
+
+    async with session_maker() as session:
+        await _bootstrap_initial_admin(session)
+
+    return engine, session_maker
diff --git a/src/proxy_app/main.py b/src/proxy_app/main.py
index fcd85b55..2eeb27f0 100644
--- a/src/proxy_app/main.py
+++ b/src/proxy_app/main.py
@@ -140,7 +140,7 @@
 from proxy_app.batch_manager import EmbeddingBatcher
 from proxy_app.api_token_auth import ApiActor, get_api_actor, require_admin_api_actor
 from proxy_app.detailed_logger import RawIOLogger
-from proxy_app.db import init_db
+from proxy_app.db import init_db_runtime
 from proxy_app.routers import admin_router, auth_router, ui_router, user_router
 from proxy_app.usage_recorder import (
     record_usage_event as record_usage_event_async,
@@ -436,7 +436,9 @@ def filter(self, record):
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """Manage the RotatingClient's lifecycle with the app's lifespan."""
-    app.state.db_session_maker = await init_db(_root_dir)
+    db_engine, db_session_maker = await init_db_runtime(_root_dir)
+    app.state.db_engine = db_engine
+    app.state.db_session_maker = db_session_maker
     app.state.usage_recorder = await start_usage_recorder(app.state.db_session_maker)
 
     # [MODIFIED] Perform skippable OAuth initialization at startup
@@ -677,6 +679,9 @@ async def process_credential(provider: str, path: str, provider_instance):
     else:
         logging.info("RotatingClient closed.")
 
+    if hasattr(app.state, "db_engine") and app.state.db_engine:
+        await app.state.db_engine.dispose()
+
 
 # --- FastAPI App Setup ---
 app = FastAPI(lifespan=lifespan)
diff --git a/tests/test_db_runtime.py b/tests/test_db_runtime.py
new file mode 100644
index 00000000..a95070ee
--- /dev/null
+++ b/tests/test_db_runtime.py
@@ -0,0 +1,33 @@
+import pytest
+from sqlalchemy import text
+
+from proxy_app.db import create_db_engine
+
+
+@pytest.mark.asyncio
+async def test_sqlite_pragmas_are_applied(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
+    monkeypatch.setenv("SQLITE_BUSY_TIMEOUT_MS", "7000")
+    db_file = tmp_path / "pragmas.db"
+    engine = create_db_engine(f"sqlite+aiosqlite:///{db_file}")
+
+    try:
+        async with engine.connect() as conn:
+            journal_mode = (
+                await conn.execute(text("PRAGMA journal_mode"))
+            ).scalar_one_or_none()
+            synchronous = (
+                await conn.execute(text("PRAGMA synchronous"))
+            ).scalar_one_or_none()
+            foreign_keys = (
+                await conn.execute(text("PRAGMA foreign_keys"))
+            ).scalar_one_or_none()
+            busy_timeout = (
+                await conn.execute(text("PRAGMA busy_timeout"))
+            ).scalar_one_or_none()
+
+        assert str(journal_mode).lower() == "wal"
+        assert int(synchronous) == 1  # NORMAL
+        assert int(foreign_keys) == 1
+        assert int(busy_timeout) == 7000
+    finally:
+        await engine.dispose()