diff --git a/API.md b/API.md index e5e786d..1299f5f 100644 --- a/API.md +++ b/API.md @@ -317,7 +317,7 @@ Get public token information including uploads. "allow_public_downloads": false, "uploads": [ { - "id": 1, + "public_id": "rT72ZKGMPdldiEmA9eDI7kik", "filename": "document.pdf", "ext": "pdf", "mimetype": "application/pdf", @@ -328,8 +328,8 @@ Get public token information including uploads. "status": "completed", "created_at": "2025-12-23T12:00:00Z", "completed_at": "2025-12-23T12:01:00Z", - "download_url": "http://localhost:8000/api/tokens/fbc_token/uploads/1", - "upload_url": "http://localhost:8000/api/uploads/1/tus" + "download_url": "http://localhost:8000/api/tokens/fbc_token/uploads/rT72ZKGMPdldiEmA9eDI7kik", + "upload_url": "http://localhost:8000/api/uploads/rT72ZKGMPdldiEmA9eDI7kik/tus" } ] } @@ -364,7 +364,7 @@ List all uploads for a specific token. ```json [ { - "id": 1, + "public_id": "rT72ZKGMPdldiEmA9eDI7kik", "filename": "document.pdf", "ext": "pdf", "mimetype": "application/pdf", @@ -375,8 +375,8 @@ List all uploads for a specific token. "status": "completed", "created_at": "2025-12-23T12:00:00Z", "completed_at": "2025-12-23T12:01:00Z", - "download_url": "http://localhost:8000/api/tokens/fbc_token/uploads/1", - "upload_url": "http://localhost:8000/api/uploads/1/tus" + "download_url": "http://localhost:8000/api/tokens/fbc_token/uploads/rT72ZKGMPdldiEmA9eDI7kik", + "upload_url": "http://localhost:8000/api/uploads/rT72ZKGMPdldiEmA9eDI7kik/tus" } ] ``` @@ -399,7 +399,7 @@ Get metadata information about a completed upload. **Response (200):** ```json { - "id": 1, + "public_id": "rT72ZKGMPdldiEmA9eDI7kik", "filename": "document.pdf", "ext": "pdf", "mimetype": "application/pdf", @@ -412,8 +412,8 @@ Get metadata information about a completed upload. 
"status": "completed", "created_at": "2025-01-01T12:00:00Z", "completed_at": "2025-01-01T12:05:00Z", - "upload_url": "http://localhost:8000/api/uploads/1/tus", - "download_url": "http://localhost:8000/api/tokens/fbc_token/uploads/1/download" + "upload_url": "http://localhost:8000/api/uploads/rT72ZKGMPdldiEmA9eDI7kik/tus", + "download_url": "http://localhost:8000/api/tokens/fbc_token/uploads/rT72ZKGMPdldiEmA9eDI7kik/download" } ``` @@ -476,9 +476,9 @@ Initiate a new file upload. **Response (201):** ```json { - "upload_id": 1, - "upload_url": "http://localhost:8000/api/uploads/1/tus", - "download_url": "http://localhost:8000/api/tokens/fbc_token/uploads/1", + "upload_id": "rT72ZKGMPdldiEmA9eDI7kik", + "upload_url": "http://localhost:8000/api/uploads/rT72ZKGMPdldiEmA9eDI7kik/tus", + "download_url": "http://localhost:8000/api/tokens/fbc_token/uploads/rT72ZKGMPdldiEmA9eDI7kik", "meta_data": { "title": "My Document", "category": "reports" @@ -545,7 +545,7 @@ Upload file chunk (TUS protocol). **Authentication:** None **Path Parameters:** -- `upload_id` (integer): The upload record ID +- `upload_id` (string): The upload record public ID (random string) **Required Headers:** - `Upload-Offset` (integer): Current upload offset (must match server state) @@ -588,7 +588,7 @@ Delete an upload and its associated file (TUS protocol). **Authentication:** None **Path Parameters:** -- `upload_id` (integer): The upload record ID +- `upload_id` (string): The upload record public ID (random string) **Response (204):** No content @@ -609,7 +609,7 @@ Cancel an in-progress upload and restore the token slot. **Authentication:** Required via query parameter **Path Parameters:** -- `upload_id` (integer): The upload record ID +- `upload_id` (string): The upload record public ID (random string) **Query Parameters:** - `token` (string, required): The upload token @@ -640,12 +640,12 @@ Manually mark an upload as complete. 
**Authentication:** None **Path Parameters:** -- `upload_id` (integer): The upload record ID +- `upload_id` (string): The upload record public ID (random string) **Response (200):** ```json { - "id": 1, + "public_id": "rT72ZKGMPdldiEmA9eDI7kik", "filename": "document.pdf", "ext": "pdf", "mimetype": "application/pdf", @@ -819,7 +819,7 @@ Delete an upload record and its file (Admin only). **Authentication:** Required (Admin) **Path Parameters:** -- `upload_id` (integer): The upload record ID +- `upload_id` (string): The upload record public ID (random string) **Response (204):** No content @@ -897,10 +897,10 @@ Typical upload flow: 4. **Upload File Chunks (TUS Protocol)** ```http # Check current offset - HEAD /api/uploads/1/tus + HEAD /api/uploads/rT72ZKGMPdldiEmA9eDI7kik/tus # Upload chunk - PATCH /api/uploads/1/tus + PATCH /api/uploads/rT72ZKGMPdldiEmA9eDI7kik/tus Upload-Offset: 0 Tus-Resumable: 1.0.0 Content-Type: application/offset+octet-stream @@ -910,7 +910,7 @@ Typical upload flow: 5. 
**Download File** ```http - GET /api/tokens/{download_token}/uploads/1 + GET /api/tokens/{download_token}/uploads/rT72ZKGMPdldiEmA9eDI7kik Authorization: Bearer YOUR_API_KEY ``` @@ -941,6 +941,7 @@ Typical upload flow: - File paths are resolved and stored as absolute paths - Upload tokens are 18-character URL-safe strings - Download tokens are prefixed with `fbc_` followed by 16-character URL-safe strings +- Upload IDs (`public_id`) are 18-character URL-safe random strings (not sequential integers for security) - Metadata is stored as JSON in the database (`meta_data` column) - TUS protocol is recommended for files larger than a few MB for reliability - Maximum chunk size is controlled by `FBC_MAX_CHUNK_BYTES` (default: 90MB) diff --git a/backend/app/cleanup.py b/backend/app/cleanup.py index 7c148f7..538c7d7 100644 --- a/backend/app/cleanup.py +++ b/backend/app/cleanup.py @@ -3,17 +3,26 @@ import logging from datetime import UTC, datetime, timedelta from pathlib import Path +from typing import TYPE_CHECKING, Any from sqlalchemy import select, update +from sqlalchemy.engine.result import Result from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql.selectable import Select from . import config, models from .db import SessionLocal +if TYPE_CHECKING: + from sqlalchemy.engine.result import Result + from sqlalchemy.sql.dml import Update + from sqlalchemy.sql.selectable import Select + logger: logging.Logger = logging.getLogger(__name__) async def _cleanup_once() -> None: + """Perform a single cleanup operation.""" async with SessionLocal() as session: await _disable_expired_tokens(session) if config.settings.incomplete_ttl_hours > 0: @@ -22,114 +31,122 @@ async def _cleanup_once() -> None: await _remove_disabled_tokens(session) -async def _disable_expired_tokens(session: AsyncSession) -> None: - now = datetime.now(UTC) - stmt = ( +async def _disable_expired_tokens(session: AsyncSession) -> int: + """ + Disable tokens that have expired. 
+ + Args: + session (AsyncSession): The database session. + + Returns: + int: The number of tokens disabled. + + """ + now: datetime = datetime.now(UTC) + stmt: Update = ( update(models.UploadToken) .where(models.UploadToken.expires_at < now) .where(models.UploadToken.disabled.is_(False)) .values(disabled=True) ) - res = await session.execute(stmt) + + res: Result[Any] = await session.execute(stmt) + if res.rowcount: logger.info("Disabled %d expired tokens", res.rowcount) + await session.commit() + return res.rowcount -async def _remove_stale_uploads(session: AsyncSession) -> None: - """Remove stale uploads in batches to avoid loading millions of records into memory.""" - cutoff = datetime.now(UTC) - timedelta(hours=config.settings.incomplete_ttl_hours) - cutoff_naive = cutoff.replace(tzinfo=None) +async def _remove_stale_uploads(session: AsyncSession) -> int: + """ + Remove stale uploads. - total_removed = 0 - batch_size = 100 + Args: + session (AsyncSession): The database session. - while True: - stmt = ( - select(models.UploadRecord) - .where(models.UploadRecord.status != "completed") - .where(models.UploadRecord.created_at < cutoff_naive) - .limit(batch_size) - ) - res = await session.execute(stmt) - batch = res.scalars().all() + Returns: + int: The number of uploads removed. 
- if not batch: - break + """ + cutoff: datetime = datetime.now(UTC) - timedelta(hours=config.settings.incomplete_ttl_hours) + cutoff_naive: datetime = cutoff.replace(tzinfo=None) - for record in batch: - if record.storage_path: - path = Path(record.storage_path) - if path.exists(): - try: - path.unlink() - except OSError: - logger.warning("Failed to remove stale upload file: %s", path) - await session.delete(record) + total_removed = 0 - await session.flush() - await session.commit() - total_removed += len(batch) + stmt: Select[tuple[models.UploadRecord]] = ( + select(models.UploadRecord).where(models.UploadRecord.status != "completed").where(models.UploadRecord.created_at < cutoff_naive) + ) + res: Result[tuple[models.UploadRecord]] = await session.execute(stmt) - if len(batch) < batch_size: - break + for record in res.scalars().all(): + if record.storage_path: + path = Path(record.storage_path) + if path.exists(): + try: + path.unlink() + except OSError: + logger.warning("Failed to remove stale upload file: %s", path) + + total_removed += 1 + await session.delete(record) + + await session.flush() + await session.commit() if total_removed > 0: logger.info("Removed %d stale uploads", total_removed) + return total_removed -async def _remove_disabled_tokens(session: AsyncSession) -> None: - """Remove old disabled tokens in batches to avoid loading millions of records into memory.""" - cutoff = datetime.now(UTC) - timedelta(days=config.settings.disabled_tokens_ttl_days) + +async def _remove_disabled_tokens(session: AsyncSession) -> int: + """ + Remove old disabled tokens. + + Args: + session (AsyncSession): The database session. 
+ + Returns: + int: The number of tokens removed + + """ + cutoff: datetime = datetime.now(UTC) - timedelta(days=config.settings.disabled_tokens_ttl_days) total_removed = 0 - batch_size = 50 - while True: - stmt = ( - select(models.UploadToken) - .where(models.UploadToken.disabled.is_(True)) - .where(models.UploadToken.expires_at < cutoff) - .limit(batch_size) - ) - res = await session.execute(stmt) - batch = res.scalars().all() - - if not batch: - break - - for token in batch: - if config.settings.delete_files_on_token_cleanup: - uploads_stmt = select(models.UploadRecord).where(models.UploadRecord.token_id == token.id) - uploads_res = await session.execute(uploads_stmt) - uploads = uploads_res.scalars().all() - - for upload in uploads: - if upload.storage_path: - path = Path(upload.storage_path) - if path.exists(): - try: - path.unlink() - except OSError: - logger.warning( - "Failed to remove upload file during token cleanup: %s", - path, - ) - await session.delete(upload) - - storage_dir = Path(config.settings.storage_path).expanduser().resolve() / token.token - if storage_dir.exists() and storage_dir.is_dir(): - with contextlib.suppress(OSError): - storage_dir.rmdir() - - await session.delete(token) - - await session.flush() - await session.commit() - total_removed += len(batch) - - if len(batch) < batch_size: - break + stmt: Select[tuple[models.UploadToken]] = ( + select(models.UploadToken).where(models.UploadToken.disabled.is_(True)).where(models.UploadToken.expires_at < cutoff) + ) + res: Result[tuple[models.UploadToken]] = await session.execute(stmt) + + for token in res.scalars().all(): + if config.settings.delete_files_on_token_cleanup: + uploads_stmt: Select[tuple[models.UploadRecord]] = select(models.UploadRecord).where(models.UploadRecord.token_id == token.id) + uploads_res: Result[tuple[models.UploadRecord]] = await session.execute(uploads_stmt) + + for upload in uploads_res.scalars().all(): + if upload.storage_path: + path = Path(upload.storage_path) 
+ if path.exists(): + try: + path.unlink() + except OSError: + logger.warning("Failed to remove upload file during token cleanup: %s", path) + + total_removed += 1 + await session.delete(upload) + + storage_dir: Path = Path(config.settings.storage_path).expanduser().resolve() / token.token + if storage_dir.exists() and storage_dir.is_dir(): + with contextlib.suppress(OSError): + storage_dir.rmdir() + + await session.delete(token) + + await session.flush() + await session.commit() if total_removed > 0: logger.info( @@ -138,6 +155,8 @@ async def _remove_disabled_tokens(session: AsyncSession) -> None: config.settings.delete_files_on_token_cleanup, ) + return total_removed + async def start_cleanup_loop() -> None: while True: diff --git a/backend/app/cli.py b/backend/app/cli.py index 200b4ad..e064e61 100644 --- a/backend/app/cli.py +++ b/backend/app/cli.py @@ -2,6 +2,7 @@ import json import sys from pathlib import Path +from typing import Any import httpx import typer @@ -13,6 +14,7 @@ from backend.app import version from backend.app.config import settings from backend.app.main import app as fastapi_app +from backend.app.utils import parse_size app = typer.Typer(help=f"FBC Uploader {version.APP_VERSION} CLI", no_args_is_help=True) @@ -21,27 +23,6 @@ def _default_base_url() -> str: return settings.public_base_url or "http://127.0.0.1:8000" -MULTIPLIERS: dict[str, int] = { - "k": 1024, - "m": 1024**2, - "g": 1024**3, - "t": 1024**4, -} - - -def parse_size(text: str) -> int: - """Parse human-readable sizes: 100, 10M, 1G, 500k.""" - s = text.strip().lower() - if s[-1].isalpha(): - num = float(s[:-1]) - unit = s[-1] - if unit not in MULTIPLIERS: - msg = "Unknown size suffix; use K/M/G/T" - raise typer.BadParameter(msg) - return int(num * MULTIPLIERS[unit]) - return int(s) - - @app.command("create-token") def create_token( max_uploads: int = typer.Option(1, help="Max number of uploads"), @@ -54,22 +35,34 @@ def create_token( admin_key: str | None = typer.Option(None, 
envvar="FBC_ADMIN_API_KEY", help="Admin API key"), base_url: str = typer.Option(None, envvar="FBC_PUBLIC_BASE_URL", help="API base URL"), ) -> None: - """Create upload token.""" + """ + Create upload token. + + Args: + max_uploads: Maximum number of uploads allowed with this token + max_size: Maximum size per upload (e.g., 1G, 500M + allowed_mime: Comma-separated MIME patterns (e.g., application/pdf,video/*). Omit to allow any. + admin_key: Admin API key for authentication + base_url: Base URL of the API + + """ key: str = admin_key or settings.admin_api_key url_base: str = base_url or _default_base_url() - max_size_bytes: int = parse_size(max_size) - mime_list: list[str] | None = [m.strip() for m in allowed_mime.split(",")] if allowed_mime else None - payload = { + try: + max_size_bytes: int = parse_size(max_size) + except ValueError as e: + raise typer.BadParameter(str(e)) from e + + payload: dict[str, Any] = { "max_uploads": max_uploads, "max_size_bytes": max_size_bytes, - "allowed_mime": mime_list, + "allowed_mime": [m.strip() for m in allowed_mime.split(",")] if allowed_mime else None, } async def _run(): async with httpx.AsyncClient(base_url=url_base) as client: - create_url = str(fastapi_app.url_path_for("create_token")) - r = await client.post( - create_url, + r: httpx.Response = await client.post( + str(fastapi_app.url_path_for("create_token")), json=payload, headers={"Authorization": f"Bearer {key}"}, timeout=30, @@ -86,16 +79,23 @@ def view_token( admin_key: str | None = typer.Option(None, envvar="FBC_ADMIN_API_KEY", help="Admin API key"), base_url: str = typer.Option(None, envvar="FBC_PUBLIC_BASE_URL", help="API base URL"), ) -> None: - """View upload token.""" + """ + View upload token. 
+ + Args: + token: Token to inspect + admin_key: Admin API key for authentication + base_url: Base URL of the API + + """ key: str = admin_key or settings.admin_api_key url_base: str = base_url or _default_base_url() async def _run(): try: async with httpx.AsyncClient(base_url=url_base) as client: - list_url = str(fastapi_app.url_path_for("list_token_uploads", token_value=token)) - r = await client.get( - list_url, + r: httpx.Response = await client.get( + str(fastapi_app.url_path_for("get_token", token_value=token)), headers={"Authorization": f"Bearer {key}"}, timeout=30, ) diff --git a/backend/app/config.py b/backend/app/config.py index 02852e0..7eee8aa 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -28,27 +28,28 @@ class Settings(BaseSettings): model_config = SettingsConfigDict(env_file=".env", env_prefix="FBC_", extra="ignore") - def model_post_init(self, __context) -> None: - if not self.database_url: - cfg_dir = Path(self.config_path).expanduser().resolve() - default_db_path = cfg_dir / "fbc.db" - self.database_url = f"sqlite+aiosqlite:///{default_db_path}" - - cfg_dir = Path(self.config_path).expanduser().resolve() + def model_post_init(self, _) -> None: + cfg_dir: Path = Path(self.config_path).expanduser().resolve() cfg_dir.mkdir(parents=True, exist_ok=True) + if not self.database_url: + default_db_path: Path = cfg_dir / "fbc.db" + self.database_url = f"sqlite+aiosqlite:///{default_db_path!s}" + if self.admin_api_key == "change-me": api_path: Path = cfg_dir / "secret.key" - if api_path.exists(): - self.admin_api_key = api_path.read_text().strip() - else: + if not api_path.exists(): from secrets import token_urlsafe key: str = token_urlsafe(32) api_path.write_text(key) + with contextlib.suppress(OSError): api_path.chmod(0o600) + self.admin_api_key = key + else: + self.admin_api_key = api_path.read_text().strip() @lru_cache @@ -56,4 +57,4 @@ def get_settings() -> Settings: return Settings() -settings = get_settings() +settings: Settings = 
get_settings() diff --git a/backend/app/db.py b/backend/app/db.py index ca00926..8deeb2f 100644 --- a/backend/app/db.py +++ b/backend/app/db.py @@ -1,22 +1,26 @@ +from collections.abc import AsyncGenerator from pathlib import Path +from typing import Any from sqlalchemy.engine import make_url +from sqlalchemy.engine.url import URL from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.ext.asyncio.engine import AsyncEngine from sqlalchemy.orm import declarative_base from .config import settings -url = make_url(settings.database_url) +url: URL = make_url(settings.database_url) if url.drivername.startswith("sqlite") and url.database: Path(url.database).expanduser().parent.mkdir(parents=True, exist_ok=True) -engine = create_async_engine(settings.database_url, future=True) -SessionLocal = async_sessionmaker(bind=engine, expire_on_commit=False, class_=AsyncSession) +engine: AsyncEngine = create_async_engine(settings.database_url, future=True) +SessionLocal: async_sessionmaker[AsyncSession] = async_sessionmaker(bind=engine, expire_on_commit=False, class_=AsyncSession) -Base = declarative_base() +Base: Any = declarative_base() -async def get_db(): +async def get_db() -> AsyncGenerator[AsyncSession]: async with SessionLocal() as session: yield session diff --git a/backend/app/main.py b/backend/app/main.py index 9d9d4a4..eb66f59 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,9 +1,10 @@ import asyncio import logging +import os from contextlib import asynccontextmanager, suppress from pathlib import Path -from fastapi import FastAPI, HTTPException, Request +from fastapi import FastAPI, HTTPException, Request, status from fastapi.concurrency import run_in_threadpool from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse, JSONResponse @@ -17,29 +18,36 @@ from .migrate import run_migrations -def _ensure_storage() -> None: - Path(settings.storage_path).mkdir(parents=True, 
exist_ok=True) - Path(settings.config_path).mkdir(parents=True, exist_ok=True) - - def create_app() -> FastAPI: logging.basicConfig(level=logging.INFO, format="%(levelname)s [%(name)s] %(message)s") - _ensure_storage() + Path(settings.storage_path).mkdir(parents=True, exist_ok=True) + Path(settings.config_path).mkdir(parents=True, exist_ok=True) @asynccontextmanager async def lifespan(app: FastAPI): + """ + Application lifespan context manager. + + Args: + app (FastAPI): The FastAPI application instance. + + """ if not settings.skip_migrations: await run_in_threadpool(run_migrations) + if not settings.skip_cleanup: app.state.cleanup_task = asyncio.create_task(start_cleanup_loop(), name="cleanup_loop") + yield + if not settings.skip_cleanup: - task = getattr(app.state, "cleanup_task", None) + task: asyncio.Task | None = getattr(app.state, "cleanup_task", None) if task: task.cancel() with suppress(asyncio.CancelledError): await task + await engine.dispose() app = FastAPI( @@ -47,7 +55,7 @@ async def lifespan(app: FastAPI): lifespan=lifespan, version=version.APP_VERSION, redirect_slashes=True, - docs_url=None, + docs_url="/docs" if bool(os.getenv("FBC_DEV_MODE", "0") == "1") else None, redoc_url=None, ) @@ -55,6 +63,17 @@ async def lifespan(app: FastAPI): @app.middleware("http") async def proxy_headers_middleware(request: Request, call_next): + """ + Middleware to trust proxy headers for scheme and host. + + Args: + request (Request): The incoming HTTP request. + call_next: Function to call the next middleware or route handler. + + Returns: + Response: The HTTP response. + + """ if forwarded_proto := request.headers.get("X-Forwarded-Proto"): request.scope["scheme"] = forwarded_proto @@ -84,43 +103,76 @@ async def proxy_headers_middleware(request: Request, call_next): @app.middleware("http") async def log_exceptions(request: Request, call_next): + """ + Middleware to log unhandled exceptions. + + Args: + request (Request): The incoming HTTP request. 
+ call_next: Function to call the next middleware or route handler. + + Returns: + Response: The HTTP response. + + """ try: return await call_next(request) except Exception as exc: logging.exception("Unhandled exception", exc_info=exc) - return JSONResponse(status_code=500, content={"detail": "Internal Server Error"}) + return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "Internal Server Error"}) @app.get("/api/health", name="health") - def health(): + def health() -> dict[str, str]: + """Health check endpoint.""" return {"status": "ok"} + @app.get("/api/version", name="version") + def app_version() -> dict[str, str]: + """Get application version information.""" + return { + "version": version.APP_VERSION, + "commit_sha": version.APP_COMMIT_SHA, + "build_date": version.APP_BUILD_DATE, + "branch": version.APP_BRANCH, + } + for _route in routers.__all__: app.include_router(getattr(routers, _route).router) frontend_dir: Path = Path(settings.frontend_export_path).resolve() if frontend_dir.exists(): - @app.get("/{full_path:path}") - async def spa_fallback(full_path: str): + @app.get("/{full_path:path}", name="static_frontend") + async def frontend(full_path: str) -> FileResponse: + """ + Serve static frontend files. + + Args: + full_path (str): The requested file path. + + Returns: + FileResponse: The response containing the requested file. 
+ + """ if full_path.startswith("api/"): - raise HTTPException(status_code=404) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) - if not full_path or full_path == "/": - index_file = frontend_dir / "index.html" + if not full_path or "/" == full_path: + index_file: Path = frontend_dir / "index.html" if index_file.exists(): - return FileResponse(index_file, status_code=200) - raise HTTPException(status_code=404) + return FileResponse(index_file, status_code=status.HTTP_200_OK) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) - requested_file = frontend_dir / full_path + requested_file: Path = frontend_dir / full_path if requested_file.is_file(): - return FileResponse(requested_file, status_code=200) + return FileResponse(requested_file, status_code=status.HTTP_200_OK) - index_file = frontend_dir / "index.html" + index_file: Path = frontend_dir / "index.html" if index_file.exists(): - return FileResponse(index_file, status_code=200) - raise HTTPException(status_code=404) + return FileResponse(index_file, status_code=status.HTTP_200_OK) + + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) return app -app = create_app() +app: FastAPI = create_app() diff --git a/backend/app/metadata_schema.py b/backend/app/metadata_schema.py index ae8a0ad..0823fd4 100644 --- a/backend/app/metadata_schema.py +++ b/backend/app/metadata_schema.py @@ -11,17 +11,24 @@ def load_schema() -> list[dict]: + """ + Load metadata schema from configuration file. 
+ + Returns: + list[dict]: List of metadata field definitions + + """ path = Path(settings.config_path).expanduser() / "metadata.json" if not path.exists(): return [] - mtime = path.stat().st_mtime + mtime: float = path.stat().st_mtime if _cache["mtime"] == mtime: return _cache["schema"] data = json.loads(path.read_text()) if not isinstance(data, list): - raise HTTPException(status_code=500, detail="metadata.json must be a list") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="metadata.json must be a list") _cache["mtime"] = mtime _cache["schema"] = data @@ -29,18 +36,43 @@ def load_schema() -> list[dict]: def _error(field: str, msg: str) -> HTTPException: + """ + Create a standardized HTTPException for metadata validation errors. + + Args: + field (str): The metadata field that caused the error. + msg (str): The error message. + + Returns: + HTTPException: The constructed exception with status 422. + + """ return HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail={"field": field, "message": msg}, ) -def _coerce_type(value: Any, ftype: str, field: str): +def _coerce_type(value: Any, ftype: str, field: str) -> Any: + """ + Coerce a value to the specified metadata field type. + + Args: + value (Any): The value to coerce. + ftype (str): The target field type. + field (str): The metadata field name (for error reporting). + + Returns: + Any: The coerced value. 
+ + """ if value is None: return None + try: if ftype in ("string", "text"): return str(value) + if ftype == "boolean": if isinstance(value, bool): return value @@ -49,72 +81,99 @@ def _coerce_type(value: Any, ftype: str, field: str): if str(value).lower() in ("false", "0", "no", "off"): return False raise ValueError + if ftype == "number": return float(value) + if ftype == "integer": return int(value) + if ftype == "date": return date.fromisoformat(str(value)) + if ftype == "datetime": return datetime.fromisoformat(str(value)) + if ftype in ("select", "multiselect"): return value + except Exception: raise _error(field, f"Invalid {ftype} value") + return value def validate_metadata(values: dict[str, Any]) -> dict[str, Any]: - schema = load_schema() + """ + Validate and clean metadata values against the schema. + + Args: + values (dict[str, Any]): The metadata values to validate. + + Returns: + dict[str, Any]: The cleaned metadata values. + + """ + schema: list[dict] = load_schema() + cleaned: dict[str, Any] = {} + for field in schema: - key = field["key"] - ftype = field.get("type", "string") - required = field.get("required", False) - val = values.get(key) + key: str = field["key"] + ftype: str = field.get("type", "string") + required: bool = field.get("required", False) + val: Any = values.get(key) + if val is None: if required: raise _error(key, "Field is required") + continue + val = _coerce_type(val, ftype, key) - allow_custom = field.get("allowCustom") or field.get("allow_custom") - if ftype == "multiselect": + allow_custom: bool = field.get("allowCustom") or field.get("allow_custom") + + if "multiselect" == ftype: if not isinstance(val, list): raise _error(key, "Must be a list") - allowed = field.get("options") + + allowed: list | None = field.get("options") if allowed and not allow_custom: - allowed_vals = [a if isinstance(a, str) else a.get("value") for a in allowed] + allowed_vals: list[str] = [a if isinstance(a, str) else a.get("value") for a in allowed] 
for v in val: if v not in allowed_vals: raise _error(key, f"Invalid option: {v}") - if ftype == "select": - allowed = field.get("options") + + if "select" == ftype: + allowed: list | None = field.get("options") if allowed and not allow_custom: - allowed_vals = [a if isinstance(a, str) else a.get("value") for a in allowed] + allowed_vals: list[str] = [a if isinstance(a, str) else a.get("value") for a in allowed] if val not in allowed_vals: raise _error(key, "Invalid option") + if ftype in ("string", "text"): - min_len = field.get("minLength") - max_len = field.get("maxLength") - if min_len and len(val) < min_len: + if (min_len := field.get("minLength")) and len(val) < min_len: raise _error(key, f"Must be at least {min_len} characters") - if max_len and len(val) > max_len: + + if (max_len := field.get("maxLength")) and len(val) > max_len: raise _error(key, f"Must be at most {max_len} characters") - regex = field.get("regex") - if regex: + + if regex := field.get("regex"): import re if not re.fullmatch(regex, val): raise _error(key, "Invalid format") + if ftype in ("number", "integer"): - min_v = field.get("min") - max_v = field.get("max") + min_v: int | None = field.get("min") + max_v: int | None = field.get("max") + if min_v is not None and val < min_v: raise _error(key, f"Must be >= {min_v}") + if max_v is not None and val > max_v: raise _error(key, f"Must be <= {max_v}") - if isinstance(val, (datetime, date)): - cleaned[key] = val.isoformat() - else: - cleaned[key] = val + + cleaned[key] = val.isoformat() if isinstance(val, (datetime, date)) else val + return cleaned diff --git a/backend/app/migrate.py b/backend/app/migrate.py index 8815268..9f83851 100644 --- a/backend/app/migrate.py +++ b/backend/app/migrate.py @@ -8,7 +8,7 @@ def run_migrations() -> None: """Run Alembic migrations to head. 
Uses project root alembic.ini and overrides DB URL from settings.""" - root_cfg = Path(__file__).resolve().parents[2] / "alembic.ini" + root_cfg: Path = Path(__file__).resolve().parents[2] / "alembic.ini" cfg = Config(str(root_cfg)) cfg.set_main_option("script_location", str(Path(__file__).resolve().parents[1] / "migrations")) cfg.set_main_option("sqlalchemy.url", settings.database_url) diff --git a/backend/app/models.py b/backend/app/models.py index 02579a3..c793f30 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -31,6 +31,7 @@ class UploadRecord(Base): __tablename__: str = "uploads" id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True) + public_id: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True) token_id: Mapped[int] = mapped_column(Integer, ForeignKey("upload_tokens.id"), nullable=False) filename: Mapped[str | None] = mapped_column(String(255)) ext: Mapped[str | None] = mapped_column(String(32)) diff --git a/backend/app/routers/admin.py b/backend/app/routers/admin.py index 29ece68..810e459 100644 --- a/backend/app/routers/admin.py +++ b/backend/app/routers/admin.py @@ -1,4 +1,5 @@ -from typing import Annotated +from pathlib import Path +from typing import TYPE_CHECKING, Annotated from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy import select @@ -8,29 +9,40 @@ from backend.app.models import UploadRecord from backend.app.security import verify_admin +if TYPE_CHECKING: + from sqlalchemy.engine.result import Result + from sqlalchemy.sql.selectable import Select + router = APIRouter(prefix="/api/admin", tags=["admin"]) @router.get("/validate", name="validate_api_key") -async def validate_api_key(_: Annotated[bool, Depends(verify_admin)]): +async def validate_api_key(_: Annotated[bool, Depends(verify_admin)]) -> dict[str, bool]: """Validate the provided admin API key.""" return {"status": True} @router.delete("/uploads/{upload_id}", name="delete_upload") async def 
delete_upload( - upload_id: int, + upload_id: str, _: Annotated[bool, Depends(verify_admin)], db: Annotated[AsyncSession, Depends(get_db)], -): - """Delete an upload record and its associated file.""" - from pathlib import Path +) -> dict[str, str]: + """ + Delete an upload record and its associated file. + + Args: + upload_id (str): The public ID of the upload to delete. + db (AsyncSession): The database session. + + Returns: + dict[str, str]: A confirmation message with the deleted upload ID. - stmt = select(UploadRecord).where(UploadRecord.id == upload_id) - res = await db.execute(stmt) - upload = res.scalar_one_or_none() + """ + stmt: Select[tuple[UploadRecord]] = select(UploadRecord).where(UploadRecord.public_id == upload_id) + res: Result[tuple[UploadRecord]] = await db.execute(stmt) - if not upload: + if not (upload := res.scalar_one_or_none()): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Upload not found") if upload.storage_path: @@ -41,4 +53,4 @@ async def delete_upload( await db.delete(upload) await db.commit() - return {"status": "deleted", "id": upload_id} + return {"status": "deleted", "public_id": upload_id} diff --git a/backend/app/routers/metadata.py b/backend/app/routers/metadata.py index cdb4024..cc98e02 100644 --- a/backend/app/routers/metadata.py +++ b/backend/app/routers/metadata.py @@ -1,3 +1,5 @@ +from typing import Any + from fastapi import APIRouter from backend.app.metadata_schema import load_schema, validate_metadata @@ -6,11 +8,27 @@ @router.get("/", name="metadata_schema") -async def get_metadata_schema(): +async def get_metadata_schema() -> dict[str, list[dict]]: + """ + Retrieve the metadata schema fields. + + Returns: + dict: A dictionary containing the metadata schema fields. 
+ + """ return {"fields": load_schema()} @router.post("/validate", name="metadata_schema_validate") -async def validate_metadata_payload(payload: dict): - cleaned = validate_metadata(payload.get("metadata", payload)) - return {"metadata": cleaned} +async def validate_metadata_payload(payload: dict) -> dict[str, dict[str, Any]]: + """ + Validate and clean the provided metadata payload. + + Args: + payload: Metadata payload to validate + + Returns: + dict: A dictionary containing the cleaned metadata + + """ + return {"metadata": validate_metadata(payload.get("metadata", payload))} diff --git a/backend/app/routers/notice.py b/backend/app/routers/notice.py index 0c6d9d7..e107904 100644 --- a/backend/app/routers/notice.py +++ b/backend/app/routers/notice.py @@ -8,8 +8,15 @@ @router.get("/", name="get_notice") -async def get_notice(): - notice_file = Path(settings.config_path) / "notice.md" +async def get_notice() -> dict[str, str | None]: + """ + Retrieve the site notice content from the notice.md file. + + Returns: + dict: A dictionary with the notice content, or None if no notice is set. 
+ + """ + notice_file: Path = Path(settings.config_path) / "notice.md" if not notice_file.exists(): return {"notice": None} content: str = notice_file.read_text(encoding="utf-8") diff --git a/backend/app/routers/tokens.py b/backend/app/routers/tokens.py index d8f18c9..45d5254 100644 --- a/backend/app/routers/tokens.py +++ b/backend/app/routers/tokens.py @@ -2,57 +2,80 @@ import secrets from datetime import UTC, datetime, timedelta from pathlib import Path -from typing import Annotated +from typing import TYPE_CHECKING, Annotated -from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, Response, status +from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response, status from fastapi.responses import FileResponse -from sqlalchemy import select +from sqlalchemy import Sequence, select +from sqlalchemy.engine.result import Result from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql.selectable import Select from backend.app import models, schemas from backend.app.config import settings from backend.app.db import get_db -from backend.app.security import verify_admin - -router = APIRouter(prefix="/api/tokens", tags=["tokens"]) +from backend.app.security import optional_admin_check, verify_admin +if TYPE_CHECKING: + from sqlalchemy.engine.result import Result + from sqlalchemy.sql.selectable import Select -def optional_admin_check( - authorization: Annotated[str | None, Header()] = None, - api_key: Annotated[str | None, Query(description="API key")] = None, -) -> bool: - """Check admin authentication only if public downloads are disabled.""" - if settings.allow_public_downloads: - return True - return verify_admin(authorization, api_key) +router = APIRouter(prefix="/api/tokens", tags=["tokens"]) -@router.get("/", name="list_tokens") +@router.get("/", response_model=schemas.TokenListResponse, name="list_tokens") async def list_tokens( db: Annotated[AsyncSession, Depends(get_db)], _: Annotated[bool, 
Depends(verify_admin)], skip: Annotated[int, Query(ge=0, description="Number of records to skip")] = 0, limit: Annotated[int, Query(ge=1, le=100, description="Maximum number of records to return")] = 10, -): - count_stmt = select(models.UploadToken) - count_res = await db.execute(count_stmt) - total = len(count_res.scalars().all()) +) -> schemas.TokenListResponse: + """ + List all created upload tokens. + + Args: + db (AsyncSession): The database session. + skip (int): Number of records to skip. + limit (int): Maximum number of records to return. + + Returns: + TokenListResponse: A list of upload tokens with total count. + + """ + count_stmt: Select[tuple[models.UploadToken]] = select(models.UploadToken) + count_res: Result[tuple[models.UploadToken]] = await db.execute(count_stmt) - stmt = select(models.UploadToken).order_by(models.UploadToken.created_at.desc()).offset(skip).limit(limit) - res = await db.execute(stmt) - tokens = res.scalars().all() + stmt: Select[tuple[models.UploadToken]] = ( + select(models.UploadToken).order_by(models.UploadToken.created_at.desc()).offset(skip).limit(limit) + ) + res: Result[tuple[models.UploadToken]] = await db.execute(stmt) - return {"tokens": tokens, "total": total} + return schemas.TokenListResponse( + tokens=res.scalars().all(), + total=len(count_res.scalars().all()), + ) -@router.post("/", response_model=schemas.TokenResponse, status_code=201, name="create_token") +@router.post("/", response_model=schemas.TokenResponse, status_code=status.HTTP_201_CREATED, name="create_token") async def create_token( request: Request, payload: schemas.TokenCreate, db: Annotated[AsyncSession, Depends(get_db)], _: Annotated[bool, Depends(verify_admin)], -): - expires_at = payload.expiry_datetime or datetime.now(UTC) + timedelta(hours=settings.default_token_ttl_hours) +) -> schemas.TokenResponse: + """ + Create a new upload token. + + Args: + request (Request): The FastAPI request object. + payload (TokenCreate): The token creation payload. 
+ db (AsyncSession): The database session. + + Returns: + TokenResponse: The created upload token details. + + """ + expires_at: datetime = payload.expiry_datetime or datetime.now(UTC) + timedelta(hours=settings.default_token_ttl_hours) if expires_at.tzinfo is None: expires_at = expires_at.replace(tzinfo=UTC) @@ -73,7 +96,7 @@ async def create_token( upload_url = str(request.url_for("health")) if upload_token: - upload_url = upload_url.replace("/api/health", f"/t/{upload_token}") + upload_url: str = upload_url.replace("/api/health", f"/t/{upload_token}") return schemas.TokenResponse( token=upload_token, @@ -86,18 +109,71 @@ async def create_token( ) -@router.get("/{token_value}", response_model=schemas.TokenInfo, name="get_token") +@router.get("/{token_value}", response_model=schemas.TokenPublicInfo, name="get_token") async def get_token( + request: Request, token_value: str, db: Annotated[AsyncSession, Depends(get_db)], - _: Annotated[bool, Depends(optional_admin_check)], -): - stmt = select(models.UploadToken).where((models.UploadToken.token == token_value) | (models.UploadToken.download_token == token_value)) - res = await db.execute(stmt) - record = res.scalar_one_or_none() - if not record: +) -> schemas.TokenPublicInfo: + """ + Get information about an upload token. + + Args: + request (Request): The FastAPI request object. + token_value (str): The upload or download token value. + db (AsyncSession): The database session. + + Returns: + TokenPublicInfo: The upload token information. 
+ + """ + stmt: Select[tuple[models.UploadToken]] = select(models.UploadToken).where( + (models.UploadToken.token == token_value) | (models.UploadToken.download_token == token_value) + ) + res: Result[tuple[models.UploadToken]] = await db.execute(stmt) + + if not (token_row := res.scalar_one_or_none()): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Token not found") - return record + + now: datetime = datetime.now(UTC) + expires_at: datetime = token_row.expires_at + + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=UTC) + + if token_row.disabled: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Token is disabled") + + if expires_at < now: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Token has expired") + + uploads_stmt: Select[tuple[models.UploadRecord]] = ( + select(models.UploadRecord).where(models.UploadRecord.token_id == token_row.id).order_by(models.UploadRecord.created_at.desc()) + ) + uploads_res: Result[tuple[models.UploadRecord]] = await db.execute(uploads_stmt) + uploads: Sequence[models.UploadRecord] = uploads_res.scalars().all() + + uploads_list: list[schemas.UploadRecordResponse] = [] + for u in uploads: + item: schemas.UploadRecordResponse = schemas.UploadRecordResponse.model_validate(u, from_attributes=True) + item.upload_url = str(request.url_for("tus_head", upload_id=u.public_id)) + item.download_url = str(request.url_for("download_file", download_token=token_row.download_token, upload_id=u.public_id)) + item.info_url = str(request.url_for("get_file_info", download_token=token_row.download_token, upload_id=u.public_id)) + uploads_list.append(item) + + return schemas.TokenPublicInfo( + token=token_row.token if token_value == token_row.token else None, + download_token=token_row.download_token, + remaining_uploads=token_row.remaining_uploads, + max_uploads=token_row.max_uploads, + max_size_bytes=token_row.max_size_bytes, + max_chunk_bytes=settings.max_chunk_bytes, 
+ allowed_mime=token_row.allowed_mime, + expires_at=token_row.expires_at, + disabled=token_row.disabled, + allow_public_downloads=settings.allow_public_downloads, + uploads=uploads_list, + ) @router.patch("/{token_value}", response_model=schemas.TokenInfo, name="update_token") @@ -106,13 +182,25 @@ async def update_token( payload: schemas.TokenUpdate, db: Annotated[AsyncSession, Depends(get_db)], _: Annotated[bool, Depends(verify_admin)], -): - token_stmt = select(models.UploadToken).where( +) -> models.UploadToken: + """ + Update an existing upload token. + + Args: + token_value (str): The upload or download token value. + payload (TokenUpdate): The token update payload. + db (AsyncSession): The database session. + + Returns: + UploadToken: The updated upload token. + + """ + token_stmt: Select[tuple[models.UploadToken]] = select(models.UploadToken).where( (models.UploadToken.token == token_value) | (models.UploadToken.download_token == token_value) ) - token_res = await db.execute(token_stmt) - token_row = token_res.scalar_one_or_none() - if not token_row: + token_res: Result[tuple[models.UploadToken]] = await db.execute(token_stmt) + + if not (token_row := token_res.scalar_one_or_none()): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Token not found") if payload.max_uploads is not None: @@ -130,15 +218,19 @@ async def update_token( token_row.allowed_mime = payload.allowed_mime if payload.expiry_datetime: - expires_at = payload.expiry_datetime + expires_at: datetime = payload.expiry_datetime + if expires_at.tzinfo is None: expires_at = expires_at.replace(tzinfo=UTC) + token_row.expires_at = expires_at if payload.extend_hours: expires_at = token_row.expires_at + if expires_at.tzinfo is None: expires_at = expires_at.replace(tzinfo=UTC) + token_row.expires_at = expires_at + timedelta(hours=payload.extend_hours) if payload.disabled is not None: @@ -149,26 +241,38 @@ async def update_token( return token_row -@router.delete("/{token_value}", 
status_code=204, name="delete_token") +@router.delete("/{token_value}", status_code=status.HTTP_204_NO_CONTENT, name="delete_token") async def delete_token( token_value: str, *, delete_files: Annotated[bool, Query(..., description="Also delete uploaded files")] = False, db: Annotated[AsyncSession, Depends(get_db)], _: Annotated[bool, Depends(verify_admin)], -): - token_stmt = select(models.UploadToken).where( +) -> Response: + """ + Delete an upload token and optionally its associated files. + + Args: + token_value (str): The upload or download token value. + delete_files (bool): Whether to delete associated uploaded files. + db (AsyncSession): The database session. + + Returns: + Response: A response with status code 204 No Content. + + """ + token_stmt: Select[tuple[models.UploadToken]] = select(models.UploadToken).where( (models.UploadToken.token == token_value) | (models.UploadToken.download_token == token_value) ) - token_res = await db.execute(token_stmt) - token_row = token_res.scalar_one_or_none() - if not token_row: + token_res: Result[tuple[models.UploadToken]] = await db.execute(token_stmt) + + if not (token_row := token_res.scalar_one_or_none()): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Token not found") if delete_files: - uploads_stmt = select(models.UploadRecord).where(models.UploadRecord.token_id == token_row.id) - uploads_res = await db.execute(uploads_stmt) - uploads = uploads_res.scalars().all() + uploads_stmt: Select[tuple[models.UploadRecord]] = select(models.UploadRecord).where(models.UploadRecord.token_id == token_row.id) + uploads_res: Result[tuple[models.UploadRecord]] = await db.execute(uploads_stmt) + uploads: Sequence[models.UploadRecord] = uploads_res.scalars().all() for record in uploads: if record.storage_path: path = Path(record.storage_path) @@ -183,7 +287,7 @@ async def delete_token( await db.delete(token_row) await db.commit() - return Response(status_code=204) + return 
Response(status_code=status.HTTP_204_NO_CONTENT) @router.get("/{token_value}/uploads", response_model=list[schemas.UploadRecordResponse], name="list_token_uploads") @@ -193,138 +297,131 @@ async def list_token_uploads( token_value: str, db: Annotated[AsyncSession, Depends(get_db)], _: Annotated[bool, Depends(optional_admin_check)], -): - token_stmt = select(models.UploadToken).where( +) -> list[schemas.UploadRecordResponse]: + """ + List all uploads associated with a given token. + + Args: + request (Request): The FastAPI request object. + token_value (str): The upload or download token value. + db (AsyncSession): The database session. + + Returns: + list[UploadRecordResponse]: A list of upload records with URLs. + + """ + token_stmt: Select[tuple[models.UploadToken]] = select(models.UploadToken).where( (models.UploadToken.token == token_value) | (models.UploadToken.download_token == token_value) ) - token_res = await db.execute(token_stmt) - token_row = token_res.scalar_one_or_none() - if not token_row: - raise HTTPException(status_code=404, detail="Token not found") - - stmt = select(models.UploadRecord).where(models.UploadRecord.token_id == token_row.id).order_by(models.UploadRecord.created_at.desc()) - res = await db.execute(stmt) - uploads = res.scalars().all() - enriched = [] - for u in uploads: - item = schemas.UploadRecordResponse.model_validate(u, from_attributes=True) - item.download_url = str(request.url_for("download_file", download_token=token_row.download_token, upload_id=u.id)) - if not settings.allow_public_downloads: - item.download_url += f"?api_key={settings.admin_api_key}" - item.upload_url = str(request.url_for("tus_head", upload_id=u.id)) - item.info_url = str(request.url_for("get_file_info", download_token=token_row.download_token, upload_id=u.id)) - enriched.append(item) - return enriched - - -@router.get("/{token_value}/info", response_model=schemas.TokenPublicInfo, name="get_public_token_info") -@router.get("/{token_value}/info/", 
response_model=schemas.TokenPublicInfo, name="token_info_trailing_slash") -async def token_info(request: Request, token_value: str, db: Annotated[AsyncSession, Depends(get_db)]): - stmt = select(models.UploadToken).where(models.UploadToken.token == token_value) - res = await db.execute(stmt) - token_row = res.scalar_one_or_none() - if not token_row: - raise HTTPException(status_code=404, detail="Token not found") - - now = datetime.now(UTC) - expires_at = token_row.expires_at - if expires_at.tzinfo is None: - expires_at = expires_at.replace(tzinfo=UTC) - if token_row.disabled: - raise HTTPException(status_code=403, detail="Token is disabled") - if expires_at < now: - raise HTTPException(status_code=403, detail="Token has expired") - uploads_stmt = ( + token_res: Result[tuple[models.UploadToken]] = await db.execute(token_stmt) + + if not (token_row := token_res.scalar_one_or_none()): + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Token not found") + + stmt: Select[tuple[models.UploadRecord]] = ( select(models.UploadRecord).where(models.UploadRecord.token_id == token_row.id).order_by(models.UploadRecord.created_at.desc()) ) - uploads_res = await db.execute(uploads_stmt) - uploads = uploads_res.scalars().all() - - enriched_uploads = [] + res: Result[tuple[models.UploadRecord]] = await db.execute(stmt) + uploads: Sequence[models.UploadRecord] = res.scalars().all() + uploads_list: list[schemas.UploadRecordResponse] = [] for u in uploads: - item = schemas.UploadRecordResponse.model_validate(u, from_attributes=True) - item.upload_url = str(request.url_for("tus_head", upload_id=u.id)) - item.download_url = str(request.url_for("download_file", download_token=token_row.download_token, upload_id=u.id)) - item.info_url = str(request.url_for("get_file_info", download_token=token_row.download_token, upload_id=u.id)) - enriched_uploads.append(item) + item: schemas.UploadRecordResponse = schemas.UploadRecordResponse.model_validate(u, from_attributes=True) + 
item.download_url = str(request.url_for("download_file", download_token=token_row.download_token, upload_id=u.public_id)) + item.upload_url = str(request.url_for("tus_head", upload_id=u.public_id)) + item.info_url = str(request.url_for("get_file_info", download_token=token_row.download_token, upload_id=u.public_id)) + uploads_list.append(item) - return schemas.TokenPublicInfo( - token=token_row.token, - download_token=token_row.download_token, - remaining_uploads=token_row.remaining_uploads, - max_uploads=token_row.max_uploads, - max_size_bytes=token_row.max_size_bytes, - max_chunk_bytes=settings.max_chunk_bytes, - allowed_mime=token_row.allowed_mime, - expires_at=token_row.expires_at, - disabled=token_row.disabled, - allow_public_downloads=settings.allow_public_downloads, - uploads=enriched_uploads, - ) + return uploads_list @router.get("/{download_token}/uploads/{upload_id}", name="get_file_info", summary="Get upload file info") -@router.get("/{download_token}/uploads/{upload_id}/", name="get_file_info_trailing_slash") +@router.get("/{download_token}/uploads/{upload_id}/") async def get_file_info( request: Request, download_token: str, - upload_id: int, + upload_id: str, db: Annotated[AsyncSession, Depends(get_db)], _: Annotated[bool, Depends(optional_admin_check)], -): - token_stmt = select(models.UploadToken).where(models.UploadToken.download_token == download_token) - token_res = await db.execute(token_stmt) - token_row = token_res.scalar_one_or_none() - if not token_row: - raise HTTPException(status_code=404, detail="Download token not found") - - upload_stmt = select(models.UploadRecord).where(models.UploadRecord.id == upload_id, models.UploadRecord.token_id == token_row.id) - upload_res = await db.execute(upload_stmt) - record = upload_res.scalar_one_or_none() - if not record: - raise HTTPException(status_code=404, detail="Upload not found") - if record.status != "completed": - raise HTTPException(status_code=409, detail="Upload not yet completed") +) -> 
schemas.UploadRecordResponse: + """ + Retrieve metadata about a specific uploaded file. + + Args: + request (Request): The FastAPI request object. + download_token (str): The download token associated with the upload. + upload_id (str): The public ID of the upload. + db (AsyncSession): The database session. + + Returns: + UploadRecordResponse: Metadata about the uploaded file. + + """ + token_stmt: Select[tuple[models.UploadToken]] = select(models.UploadToken).where(models.UploadToken.download_token == download_token) + token_res: Result[tuple[models.UploadToken]] = await db.execute(token_stmt) + + if not (token_row := token_res.scalar_one_or_none()): + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Download token not found") + + upload_stmt: Select[tuple[models.UploadRecord]] = select(models.UploadRecord).where( + models.UploadRecord.public_id == upload_id, models.UploadRecord.token_id == token_row.id + ) + upload_res: Result[tuple[models.UploadRecord]] = await db.execute(upload_stmt) + + if not (record := upload_res.scalar_one_or_none()): + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Upload not found") + + if "completed" != record.status: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Upload not yet completed") path = Path(record.storage_path or "") if not path.exists(): - raise HTTPException(status_code=404, detail="File missing") + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File missing") - # Return JSON metadata about the file - item = schemas.UploadRecordResponse.model_validate(record, from_attributes=True) + item: schemas.UploadRecordResponse = schemas.UploadRecordResponse.model_validate(record, from_attributes=True) item.download_url = str(request.url_for("download_file", download_token=download_token, upload_id=upload_id)) - if not settings.allow_public_downloads: - item.download_url += f"?api_key={settings.admin_api_key}" - item.upload_url = 
str(request.url_for("tus_head", upload_id=upload_id)) item.info_url = str(request.url_for("get_file_info", download_token=download_token, upload_id=upload_id)) return item @router.get("/{download_token}/uploads/{upload_id}/download", name="download_file") -@router.get("/{download_token}/uploads/{upload_id}/download/", name="download_file_trailing_slash") +@router.get("/{download_token}/uploads/{upload_id}/download/") async def download_file( download_token: str, - upload_id: int, + upload_id: str, db: Annotated[AsyncSession, Depends(get_db)], _: Annotated[bool, Depends(optional_admin_check)], -): - token_stmt = select(models.UploadToken).where(models.UploadToken.download_token == download_token) - token_res = await db.execute(token_stmt) - token_row = token_res.scalar_one_or_none() - if not token_row: - raise HTTPException(status_code=404, detail="Download token not found") - - upload_stmt = select(models.UploadRecord).where(models.UploadRecord.id == upload_id, models.UploadRecord.token_id == token_row.id) - upload_res = await db.execute(upload_stmt) - record = upload_res.scalar_one_or_none() - if not record: - raise HTTPException(status_code=404, detail="Upload not found") - if record.status != "completed": - raise HTTPException(status_code=409, detail="Upload not yet completed") +) -> FileResponse: + """ + Download the file associated with a specific upload. + + Args: + download_token (str): The download token associated with the upload. + upload_id (str): The public ID of the upload. + db (AsyncSession): The database session. + + Returns: + FileResponse: The file response for downloading the file. 
+ + """ + token_stmt: Select[tuple[models.UploadToken]] = select(models.UploadToken).where(models.UploadToken.download_token == download_token) + token_res: Result[tuple[models.UploadToken]] = await db.execute(token_stmt) + if not (token_row := token_res.scalar_one_or_none()): + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Download token not found") + + upload_stmt: Select[tuple[models.UploadRecord]] = select(models.UploadRecord).where( + models.UploadRecord.public_id == upload_id, models.UploadRecord.token_id == token_row.id + ) + upload_res: Result[tuple[models.UploadRecord]] = await db.execute(upload_stmt) + + if not (record := upload_res.scalar_one_or_none()): + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Upload not found") + + if "completed" != record.status: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Upload not yet completed") path = Path(record.storage_path or "") if not path.exists(): - raise HTTPException(status_code=404, detail="File missing") + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File missing") return FileResponse(path, filename=record.filename or path.name, media_type=record.mimetype or "application/octet-stream") diff --git a/backend/app/routers/uploads.py b/backend/app/routers/uploads.py index 0664e55..40825cf 100644 --- a/backend/app/routers/uploads.py +++ b/backend/app/routers/uploads.py @@ -1,78 +1,123 @@ import contextlib +import secrets from datetime import UTC, datetime from pathlib import Path -from typing import Annotated +from typing import TYPE_CHECKING, Annotated, Any import aiofiles from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, Response, status -from sqlalchemy import select, update +from sqlalchemy import select +from sqlalchemy.engine.result import Result from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import attributes +from sqlalchemy.sql.selectable import Select from backend.app 
import models, schemas from backend.app.config import settings from backend.app.db import get_db from backend.app.metadata_schema import validate_metadata -from backend.app.utils import detect_mimetype, extract_ffprobe_metadata, is_multimedia +from backend.app.utils import detect_mimetype, extract_ffprobe_metadata, is_multimedia, mime_allowed + +if TYPE_CHECKING: + from sqlalchemy.engine.result import Result + from sqlalchemy.sql.selectable import Select router = APIRouter(prefix="/api/uploads", tags=["uploads"]) async def _ensure_token(db: AsyncSession, token_value: str) -> models.UploadToken: - stmt = select(models.UploadToken).where(models.UploadToken.token == token_value) - res = await db.execute(stmt) - token = res.scalar_one_or_none() - if not token: + """ + Ensure the upload token is valid, not expired or disabled, and has remaining uploads. + + Args: + db (AsyncSession): Database session. + token_value (str): The upload token string. + + Returns: + UploadToken: The valid upload token object. 
+ + """ + stmt: Select[tuple[models.UploadToken]] = select(models.UploadToken).where(models.UploadToken.token == token_value) + res: Result[tuple[models.UploadToken]] = await db.execute(stmt) + if not (token := res.scalar_one_or_none()): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Token not found") - now = datetime.now(UTC) - expires_at = token.expires_at + + now: datetime = datetime.now(UTC) + expires_at: datetime = token.expires_at + if expires_at.tzinfo is None: expires_at = expires_at.replace(tzinfo=UTC) + if token.disabled or expires_at < now: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Token expired or disabled") + if token.remaining_uploads <= 0: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Upload limit reached") + return token -def _mime_allowed(filetype: str | None, allowed: list[str] | None) -> bool: - if not allowed or not filetype: - return True - for pattern in allowed: - if pattern.endswith("/*"): - prefix = pattern.split("/")[0] - if filetype.startswith(prefix + "/"): - return True - elif filetype == pattern: - return True - return False +async def _get_upload_record(db: AsyncSession, upload_id: str) -> models.UploadRecord: + """ + Retrieve the upload record by its public ID. + + Args: + db (AsyncSession): Database session. + upload_id (str): The public ID of the upload. + + Returns: + UploadRecord: The upload record object. 
+ + """ + stmt: Select[tuple[models.UploadRecord]] = select(models.UploadRecord).where(models.UploadRecord.public_id == upload_id) + res: Result[tuple[models.UploadRecord]] = await db.execute(stmt) + if not (record := res.scalar_one_or_none()): + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Upload not found") + + return record -@router.post("/initiate", response_model=dict, status_code=201, name="initiate_upload") +@router.post("/initiate", response_model=schemas.InitiateUploadResponse, status_code=status.HTTP_201_CREATED, name="initiate_upload") async def initiate_upload( request: Request, payload: schemas.UploadRequest, db: Annotated[AsyncSession, Depends(get_db)], token: Annotated[str, Query(description="Upload token")] = ..., -): - token_row = await _ensure_token(db, token) - cleaned_metadata = validate_metadata(payload.meta_data or {}) +) -> schemas.InitiateUploadResponse: + """ + Initiate a new upload record and prepare for TUS upload. + + Args: + request (Request): The incoming HTTP request. + payload (UploadRequest): The upload request payload. + db (AsyncSession): Database session. + token (str): The upload token string. + + Returns: + InitiateUploadResponse: Details for the initiated upload. 
+ + """ + token_row: models.UploadToken = await _ensure_token(db, token) + cleaned_metadata: dict[str, Any] = validate_metadata(payload.meta_data or {}) + if payload.size_bytes and payload.size_bytes > token_row.max_size_bytes: raise HTTPException( status_code=status.HTTP_413_CONTENT_TOO_LARGE, detail="File size exceeds token limit", ) - if not _mime_allowed(payload.filetype, token_row.allowed_mime): + + if not mime_allowed(payload.filetype, token_row.allowed_mime): raise HTTPException( status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, detail="File type not allowed for this token", ) - ext = None + ext: str | None = None if payload.filename: - ext = Path(payload.filename).suffix.lstrip(".") + ext: str = Path(payload.filename).suffix.lstrip(".") record = models.UploadRecord( + public_id=secrets.token_urlsafe(18), token_id=token_row.id, filename=payload.filename, ext=ext, @@ -97,35 +142,36 @@ async def initiate_upload( await db.refresh(record) await db.refresh(token_row) - upload_url = str(request.url_for("tus_head", upload_id=record.id)) - download_url = str(request.url_for("download_file", download_token=token_row.download_token, upload_id=record.id)) + return schemas.InitiateUploadResponse( + upload_id=record.public_id, + upload_url=str(request.url_for("tus_head", upload_id=record.public_id)), + download_url=str(request.url_for("download_file", download_token=token_row.download_token, upload_id=record.public_id)), + meta_data=cleaned_metadata, + allowed_mime=token_row.allowed_mime, + remaining_uploads=token_row.remaining_uploads, + ) - return { - "upload_id": record.id, - "upload_url": upload_url, - "download_url": download_url, - "meta_data": cleaned_metadata, - "allowed_mime": token_row.allowed_mime, - "remaining_uploads": token_row.remaining_uploads, - } +@router.head("/{upload_id}/tus", name="tus_head") +async def tus_head(upload_id: str, db: Annotated[AsyncSession, Depends(get_db)]): + """ + Handle TUS protocol HEAD request to get upload status. 
-async def _get_upload_record(db: AsyncSession, upload_id: int) -> models.UploadRecord: - stmt = select(models.UploadRecord).where(models.UploadRecord.id == upload_id) - res = await db.execute(stmt) - record = res.scalar_one_or_none() - if not record: - raise HTTPException(status_code=404, detail="Upload not found") - return record + Args: + upload_id (str): The public ID of the upload. + db (AsyncSession): Database session. + Returns: + Response: HTTP response with upload offset and length. + + """ + record: models.UploadRecord = await _get_upload_record(db, upload_id) -@router.head("/{upload_id}/tus", name="tus_head") -async def tus_head(upload_id: int, db: Annotated[AsyncSession, Depends(get_db)]): - record = await _get_upload_record(db, upload_id) if record.upload_length is None: - raise HTTPException(status_code=409, detail="Upload length unknown") + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Upload length unknown") + return Response( - status_code=200, + status_code=status.HTTP_200_OK, headers={ "Upload-Offset": str(record.upload_offset or 0), "Upload-Length": str(record.upload_length), @@ -135,35 +181,54 @@ async def tus_head(upload_id: int, db: Annotated[AsyncSession, Depends(get_db)]) @router.patch("/{upload_id}/tus", name="tus_patch") -async def tus_patch( # noqa: PLR0915 - upload_id: int, +async def tus_patch( + upload_id: str, request: Request, db: Annotated[AsyncSession, Depends(get_db)], upload_offset: Annotated[int, Header(convert_underscores=False, alias="Upload-Offset")] = ..., content_length: Annotated[int | None, Header()] = None, content_type: Annotated[str, Header(convert_underscores=False, alias="Content-Type")] = ..., -): +) -> Response: + """ + Handle TUS protocol PATCH request to upload file chunks. + + Args: + upload_id (str): The public ID of the upload. + request (Request): The incoming HTTP request. + db (AsyncSession): Database session. + upload_offset (int): The current upload offset from the client. 
+ content_length (int | None): The Content-Length header value. + content_type (str): The Content-Type header value. + + Returns: + Response: HTTP response with updated upload offset. + + """ from starlette.requests import ClientDisconnect if content_type != "application/offset+octet-stream": - raise HTTPException(status_code=415, detail="Invalid Content-Type") + raise HTTPException(status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, detail="Invalid Content-Type") + if content_length and content_length > settings.max_chunk_bytes: raise HTTPException( status_code=status.HTTP_413_CONTENT_TOO_LARGE, detail=f"Chunk too large. Max {settings.max_chunk_bytes} bytes", ) - record = await _get_upload_record(db, upload_id) + + record: models.UploadRecord = await _get_upload_record(db, upload_id) + if record.upload_length is None: - raise HTTPException(status_code=409, detail="Upload length unknown") + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Upload length unknown") + if record.upload_offset != upload_offset: - raise HTTPException(status_code=409, detail="Mismatched Upload-Offset") + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Mismatched Upload-Offset") - if record.status == "completed": - return Response(status_code=204, headers={"Upload-Offset": str(record.upload_offset)}) + if "completed" == record.status: + return Response(status_code=status.HTTP_204_NO_CONTENT, headers={"Upload-Offset": str(record.upload_offset)}) path = Path(record.storage_path) path.parent.mkdir(parents=True, exist_ok=True) - bytes_written = 0 + bytes_written: int = 0 try: async with aiofiles.open(path, "ab") as f: @@ -176,7 +241,7 @@ async def tus_patch( # noqa: PLR0915 if bytes_written > 0: record.upload_offset += bytes_written if record.upload_offset > record.upload_length: - raise HTTPException(status_code=413, detail="Upload exceeds declared length") + raise HTTPException(status_code=status.HTTP_413_CONTENT_TOO_LARGE, detail="Upload exceeds declared 
length") if record.upload_offset == record.upload_length: try: @@ -187,13 +252,13 @@ async def tus_patch( # noqa: PLR0915 detail=f"Failed to detect file type: {e}", ) - stmt = select(models.UploadToken).where(models.UploadToken.id == record.token_id) - res = await db.execute(stmt) - token = res.scalar_one_or_none() - if not token: + stmt: Select[tuple[models.UploadToken]] = select(models.UploadToken).where(models.UploadToken.id == record.token_id) + res: Result[tuple[models.UploadToken]] = await db.execute(stmt) + + if not (token := res.scalar_one_or_none()): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Token not found") - if not _mime_allowed(actual_mimetype, token.allowed_mime): + if not mime_allowed(actual_mimetype, token.allowed_mime): path.unlink(missing_ok=True) await db.delete(record) await db.commit() @@ -204,15 +269,13 @@ async def tus_patch( # noqa: PLR0915 record.mimetype = actual_mimetype - # Extract ffprobe metadata for multimedia files if is_multimedia(actual_mimetype): - ffprobe_data = await extract_ffprobe_metadata(path) + ffprobe_data: dict | None = await extract_ffprobe_metadata(path) if ffprobe_data is not None: - # Update meta_data with ffprobe results if record.meta_data is None: record.meta_data = {} + record.meta_data["ffprobe"] = ffprobe_data - # Mark the attribute as modified for SQLAlchemy to detect the change attributes.flag_modified(record, "meta_data") record.status = "completed" @@ -224,12 +287,11 @@ async def tus_patch( # noqa: PLR0915 await db.commit() await db.refresh(record) except Exception: - # On concurrent update conflict, refresh and try again await db.rollback() await db.refresh(record) return Response( - status_code=204, + status_code=status.HTTP_204_NO_CONTENT, headers={ "Upload-Offset": str(record.upload_offset), "Tus-Resumable": "1.0.0", @@ -239,9 +301,10 @@ async def tus_patch( # noqa: PLR0915 @router.options("/tus", name="tus_options") -async def tus_options(): +async def tus_options() -> Response: 
+ """Handle TUS protocol OPTIONS request.""" return Response( - status_code=204, + status_code=status.HTTP_204_NO_CONTENT, headers={ "Tus-Resumable": "1.0.0", "Tus-Version": "1.0.0", @@ -250,25 +313,50 @@ async def tus_options(): ) -@router.delete("/{upload_id}/tus", status_code=204, name="tus_delete") -async def tus_delete(upload_id: int, db: Annotated[AsyncSession, Depends(get_db)]): - record = await _get_upload_record(db, upload_id) +@router.delete("/{upload_id}/tus", status_code=status.HTTP_204_NO_CONTENT, name="tus_delete") +async def tus_delete(upload_id: str, db: Annotated[AsyncSession, Depends(get_db)]) -> Response: + """ + Delete an upload record and its associated file. + + Args: + upload_id (str): The public ID of the upload. + db (AsyncSession): Database session. + + Returns: + Response: HTTP 204 No Content response. + + """ + record: models.UploadRecord = await _get_upload_record(db, upload_id) path = Path(record.storage_path or "") + if path.exists(): with contextlib.suppress(OSError): path.unlink() + await db.delete(record) await db.commit() - return Response(status_code=204) + return Response(status_code=status.HTTP_204_NO_CONTENT) @router.post("/{upload_id}/complete", response_model=schemas.UploadRecordResponse, name="mark_complete") -async def mark_complete(upload_id: int, db: Annotated[AsyncSession, Depends(get_db)]): - stmt = select(models.UploadRecord).where(models.UploadRecord.id == upload_id) - res = await db.execute(stmt) - record = res.scalar_one_or_none() - if not record: - raise HTTPException(status_code=404, detail="Upload not found") +async def mark_complete(upload_id: str, db: Annotated[AsyncSession, Depends(get_db)]) -> models.UploadRecord: + """ + Mark an upload as complete. + + Args: + upload_id (str): The public ID of the upload. + db (AsyncSession): Database session. + + Returns: + UploadRecord: The updated upload record. 
+ + """ + stmt: Select[tuple[models.UploadRecord]] = select(models.UploadRecord).where(models.UploadRecord.public_id == upload_id) + res: Result[tuple[models.UploadRecord]] = await db.execute(stmt) + + if not (record := res.scalar_one_or_none()): + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Upload not found") + record.status = "completed" record.completed_at = datetime.now(UTC) await db.commit() @@ -278,23 +366,34 @@ async def mark_complete(upload_id: int, db: Annotated[AsyncSession, Depends(get_ @router.delete("/{upload_id}/cancel", response_model=dict, name="cancel_upload") async def cancel_upload( - upload_id: int, + upload_id: str, db: Annotated[AsyncSession, Depends(get_db)], token: Annotated[str, Query(description="Upload token")] = ..., -): - """Cancel an incomplete upload, delete the record, and restore the upload slot.""" - record = await _get_upload_record(db, upload_id) - - stmt = select(models.UploadToken).where(models.UploadToken.token == token) - res = await db.execute(stmt) - token_row = res.scalar_one_or_none() - if not token_row: +) -> dict[str, Any]: + """ + Cancel an incomplete upload, delete the record, and restore the upload slot. + + Args: + upload_id (str): The public ID of the upload. + db (AsyncSession): Database session. + token (str): The upload token string. + + Returns: + dict: Confirmation message and remaining uploads. 
+ + """ + record: models.UploadRecord = await _get_upload_record(db, upload_id) + + stmt: Select[tuple[models.UploadToken]] = select(models.UploadToken).where(models.UploadToken.token == token) + res: Result[tuple[models.UploadToken]] = await db.execute(stmt) + + if not (token_row := res.scalar_one_or_none()): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Token not found") if record.token_id != token_row.id: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Upload does not belong to this token") - if record.status == "completed": + if "completed" == record.status: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot cancel completed upload") path = Path(record.storage_path or "") @@ -304,9 +403,7 @@ async def cancel_upload( await db.delete(record) - await db.execute( - update(models.UploadToken).where(models.UploadToken.id == token_row.id).values(uploads_used=models.UploadToken.uploads_used - 1) - ) + token_row.uploads_used -= 1 await db.commit() await db.refresh(token_row) diff --git a/backend/app/schemas.py b/backend/app/schemas.py index d1b960c..654a89d 100644 --- a/backend/app/schemas.py +++ b/backend/app/schemas.py @@ -51,7 +51,7 @@ class TokenInfo(BaseModel): class UploadRecordResponse(BaseModel): - id: int + public_id: str filename: str | None ext: str | None mimetype: str | None @@ -70,7 +70,7 @@ class UploadRecordResponse(BaseModel): class TokenPublicInfo(BaseModel): - token: str + token: str | None download_token: str remaining_uploads: int max_uploads: int @@ -95,3 +95,17 @@ class UploadRequest(BaseModel): filename: str | None = None filetype: str | None = None size_bytes: int | None = Field(None, gt=0) + + +class TokenListResponse(BaseModel): + tokens: list[TokenAdmin] + total: int + + +class InitiateUploadResponse(BaseModel): + upload_id: str + upload_url: str + download_url: str + meta_data: dict[str, Any] + allowed_mime: list[str] | None + remaining_uploads: int diff --git 
a/backend/app/security.py b/backend/app/security.py index 2678a21..0fd32dc 100644 --- a/backend/app/security.py +++ b/backend/app/security.py @@ -5,10 +5,45 @@ from .config import settings +def optional_admin_check( + authorization: Annotated[str | None, Header()] = None, + api_key: Annotated[str | None, Query(description="API key")] = None, +) -> bool: + """ + Check admin authentication only if public downloads are disabled. + + Args: + authorization (str | None): The Authorization header. + api_key (str | None): The API key query parameter. + + Returns: + bool: True if admin is verified or public downloads are allowed. + + Raises: + HTTPException: If admin verification fails when required. + + """ + return True if settings.allow_public_downloads else verify_admin(authorization, api_key) + + def verify_admin( authorization: Annotated[str | None, Header()] = None, api_key: Annotated[str | None, Query(description="API key")] = None, ) -> bool: + """ + Verify the provided admin API key. + + Args: + authorization (str | None): The Authorization header. + api_key (str | None): The API key query parameter. + + Returns: + bool: True if the API key is valid. + + Raises: + HTTPException: If the API key is invalid or missing. 
+ + """ key = None if api_key is not None: @@ -17,10 +52,14 @@ def verify_admin( if authorization and authorization.lower().startswith("bearer "): key: str = authorization.split(" ", 1)[1] + if key: + key = key.strip() + if key != settings.admin_api_key: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials", headers={"WWW-Authenticate": "Bearer"}, ) + return True diff --git a/backend/app/utils.py b/backend/app/utils.py index 91ba195..5ea265d 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -6,6 +6,15 @@ import magic +MIME = magic.Magic(mime=True) + +MULTIPLIERS: dict[str, int] = { + "k": 1024, + "m": 1024**2, + "g": 1024**3, + "t": 1024**4, +} + def detect_mimetype(file_path: str | Path) -> str: """ @@ -22,13 +31,12 @@ def detect_mimetype(file_path: str | Path) -> str: OSError: If the file cannot be read """ - path = Path(file_path) + path: Path = Path(file_path) if not path.exists(): - msg = f"File not found: {file_path}" + msg: str = f"File not found: {file_path}" raise FileNotFoundError(msg) - mime = magic.Magic(mime=True) - return mime.from_file(str(path)) + return MIME.from_file(str(path)) def is_multimedia(mimetype: str) -> bool: @@ -50,7 +58,7 @@ async def extract_ffprobe_metadata(file_path: str | Path) -> dict | None: Extract multimedia metadata using ffprobe. 
Args: - file_path: Path to the multimedia file + file_path (str | Path): Path to the multimedia file Returns: Dictionary containing ffprobe output in JSON format, or None if extraction fails @@ -82,10 +90,64 @@ async def extract_ffprobe_metadata(file_path: str | Path) -> dict | None: if proc.returncode != 0: return None - dct = json.loads(stdout.decode()) - if "format" in dct and "filename" in dct.get("format"): + dct: dict | None = json.loads(stdout.decode()) + if dct and "format" in dct and "filename" in dct.get("format"): dct["format"].pop("filename", None) except Exception: return None else: return dct + + +def mime_allowed(filetype: str | None, allowed: list[str] | None) -> bool: + """ + Check if a given MIME type is allowed based on a list of allowed patterns. + + Args: + filetype: The MIME type to check (e.g., 'video/mp4') + allowed: List of allowed MIME patterns (e.g., ['application/pdf', 'video/* + + Returns: + True if the MIME type is allowed, False otherwise + + """ + if not allowed or not filetype: + return True + + for pattern in allowed: + if pattern.endswith("/*"): + prefix: str = pattern.split("/")[0] + if filetype.startswith(prefix + "/"): + return True + + elif filetype == pattern: + return True + + return False + + +def parse_size(text: str) -> int: + """ + Parse human-readable sizes: 100, 10M, 1G, 500k. + + Args: + text (str): Size string to parse. + + Returns: + int: Size in bytes as an integer. + + Raises: + ValueError: If the size string is invalid. 
+ + """ + s: str = text.strip().lower() + if s[-1].isalpha(): + num = float(s[:-1]) + unit: str = s[-1] + if unit not in MULTIPLIERS: + msg = "Unknown size suffix; use K/M/G/T" + raise ValueError(msg) + + return int(num * MULTIPLIERS[unit]) + + return int(s) diff --git a/backend/main.py b/backend/main.py index f7caec7..ff1d4c7 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,19 +1,18 @@ import logging +import os import sys from pathlib import Path import uvicorn -ROOT = Path(__file__).resolve().parent.parent +ROOT: Path = Path(__file__).resolve().parent.parent if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) -LOG = logging.getLogger("fbc-uploader") +LOG: logging.Logger = logging.getLogger("fbc-uploader") def main() -> None: - import os - uvicorn.run("backend.app.main:app", host="0.0.0.0", port=8000, reload=bool(os.getenv("FBC_DEV_MODE", "0") == "1")) diff --git a/backend/migrations/versions/cbb603dd8b0b_add_public_id_to_uploads.py b/backend/migrations/versions/cbb603dd8b0b_add_public_id_to_uploads.py new file mode 100644 index 0000000..09cf7c0 --- /dev/null +++ b/backend/migrations/versions/cbb603dd8b0b_add_public_id_to_uploads.py @@ -0,0 +1,45 @@ +"""add_public_id_to_uploads + +Revision ID: cbb603dd8b0b +Revises: 5dc14b681055 +Create Date: 2025-12-30 17:50:45.019140 +""" +from alembic import op +import sqlalchemy as sa +import secrets + + +# revision identifiers, used by Alembic. 
+revision = 'cbb603dd8b0b' +down_revision = '5dc14b681055' +branch_labels = None +depends_on = None + + +def upgrade(): + # Add public_id column as nullable first + op.add_column('uploads', sa.Column('public_id', sa.String(length=64), nullable=True)) + + # Populate existing rows with random public_ids + connection = op.get_bind() + result = connection.execute(sa.text("SELECT id FROM uploads")) + for row in result: + public_id = secrets.token_urlsafe(18) + connection.execute( + sa.text("UPDATE uploads SET public_id = :public_id WHERE id = :id"), + {"public_id": public_id, "id": row[0]} + ) + + # For SQLite, we need to use batch operations to make column non-nullable + with op.batch_alter_table('uploads', schema=None) as batch_op: + batch_op.alter_column('public_id', nullable=False) + batch_op.create_unique_constraint('uq_uploads_public_id', ['public_id']) + batch_op.create_index('ix_uploads_public_id', ['public_id']) + + +def downgrade(): + with op.batch_alter_table('uploads', schema=None) as batch_op: + batch_op.drop_index('ix_uploads_public_id') + batch_op.drop_constraint('uq_uploads_public_id', type_='unique') + batch_op.drop_column('public_id') + diff --git a/backend/tests/test_admin.py b/backend/tests/test_admin.py index fbcee6c..4232a89 100644 --- a/backend/tests/test_admin.py +++ b/backend/tests/test_admin.py @@ -67,7 +67,7 @@ async def test_delete_upload(): upload_id = upload_data["upload_id"] async with SessionLocal() as session: - stmt = select(models.UploadRecord).where(models.UploadRecord.id == upload_id) + stmt = select(models.UploadRecord).where(models.UploadRecord.public_id == upload_id) res = await session.execute(stmt) upload = res.scalar_one_or_none() assert upload is not None, "Upload should exist before deletion" @@ -80,7 +80,7 @@ async def test_delete_upload(): assert delete_resp.json()["status"] == "deleted", "Response should indicate deletion" async with SessionLocal() as session: - stmt = 
select(models.UploadRecord).where(models.UploadRecord.id == upload_id) + stmt = select(models.UploadRecord).where(models.UploadRecord.public_id == upload_id) res = await session.execute(stmt) upload = res.scalar_one_or_none() assert upload is None, "Upload should be deleted from database" @@ -92,7 +92,7 @@ async def test_delete_upload_not_found(): transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://testserver") as client: r = await client.delete( - app.url_path_for("delete_upload", upload_id=99999), + app.url_path_for("delete_upload", upload_id="nonexistent_id"), headers={"Authorization": f"Bearer {settings.admin_api_key}"}, ) assert r.status_code == status.HTTP_404_NOT_FOUND, "Deleting non-existent upload should return 404" @@ -103,11 +103,11 @@ async def test_delete_upload_requires_admin(): """Test delete upload requires admin authentication.""" transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://testserver") as client: - r = await client.delete(app.url_path_for("delete_upload", upload_id=1)) + r = await client.delete(app.url_path_for("delete_upload", upload_id="some_id")) assert r.status_code == status.HTTP_401_UNAUTHORIZED, "Delete without auth should return 401" r = await client.delete( - app.url_path_for("delete_upload", upload_id=1), + app.url_path_for("delete_upload", upload_id="some_id"), headers={"Authorization": "Bearer invalid-key"}, ) assert r.status_code == status.HTTP_401_UNAUTHORIZED, "Delete with invalid key should return 401" diff --git a/backend/tests/test_advanced.py b/backend/tests/test_advanced.py index 993efef..f415ae6 100644 --- a/backend/tests/test_advanced.py +++ b/backend/tests/test_advanced.py @@ -59,7 +59,7 @@ async def test_reject_upload_with_disallowed_mime(): "meta_data": {"broadcast_date": "2025-01-01", "title": "Test", "source": "youtube"}, }, ) - assert upload_resp.status_code in [status.HTTP_201_CREATED, status.HTTP_403_FORBIDDEN, 
status.HTTP_415_UNSUPPORTED_MEDIA_TYPE], "Disallowed MIME type should be rejected or require special handling" + assert upload_resp.status_code == status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, "Disallowed MIME type should be rejected" @pytest.mark.asyncio @@ -81,7 +81,9 @@ async def test_reject_upload_exceeding_size(): "meta_data": {"broadcast_date": "2025-01-01", "title": "Test", "source": "youtube"}, }, ) - assert upload_resp.status_code in [status.HTTP_403_FORBIDDEN, status.HTTP_413_CONTENT_TOO_LARGE], "Upload exceeding size limit should be rejected" + assert upload_resp.status_code in [status.HTTP_403_FORBIDDEN, status.HTTP_413_CONTENT_TOO_LARGE], ( + "Upload exceeding size limit should be rejected" + ) assert ( "exceeds" in upload_resp.json()["detail"].lower() or "too large" in upload_resp.json()["detail"].lower() diff --git a/backend/tests/test_cleanup.py b/backend/tests/test_cleanup.py index 028be9d..04fb048 100644 --- a/backend/tests/test_cleanup.py +++ b/backend/tests/test_cleanup.py @@ -1,5 +1,6 @@ from datetime import UTC, datetime, timedelta from pathlib import Path +import secrets import pytest from sqlalchemy import select @@ -69,6 +70,7 @@ async def test_remove_stale_uploads_deletes_files(monkeypatch): await session.refresh(token) stale_upload = models.UploadRecord( + public_id=secrets.token_urlsafe(18), token_id=token.id, filename="stale.bin", size_bytes=5, @@ -131,6 +133,7 @@ async def test_remove_disabled_tokens_cleans_records_and_storage(monkeypatch): await session.refresh(recent_token) old_upload = models.UploadRecord( + public_id=secrets.token_urlsafe(18), token_id=old_token.id, filename="old.txt", size_bytes=9, @@ -140,6 +143,7 @@ async def test_remove_disabled_tokens_cleans_records_and_storage(monkeypatch): completed_at=now - timedelta(days=2), ) recent_upload = models.UploadRecord( + public_id=secrets.token_urlsafe(18), token_id=recent_token.id, filename="recent.txt", size_bytes=11, diff --git a/backend/tests/test_download_url_security.py 
b/backend/tests/test_download_url_security.py new file mode 100644 index 0000000..c0d9edd --- /dev/null +++ b/backend/tests/test_download_url_security.py @@ -0,0 +1,67 @@ +""" +Test that API responses don't expose admin API key in download URLs. +The frontend should append the API key when needed. +""" + +import pytest +from fastapi import status +from httpx import ASGITransport, AsyncClient + +from backend.app.config import settings +from backend.app.main import app +from backend.tests.utils import create_token, initiate_upload, upload_file_via_tus + + +@pytest.mark.asyncio +async def test_list_token_uploads_does_not_expose_api_key(): + """Test that list_token_uploads returns clean download URLs without API key.""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://testserver") as client: + token_data = await create_token(client) + upload_data = await initiate_upload( + client, token_data["token"], filename="test.txt", size_bytes=11, filetype="text/plain", meta_data={} + ) + await upload_file_via_tus(client, upload_data["upload_id"], b"hello world") + + # Get uploads list as admin + response = await client.get( + app.url_path_for("list_token_uploads", token_value=token_data["token"]), + headers={"Authorization": f"Bearer {settings.admin_api_key}"}, + ) + assert response.status_code == status.HTTP_200_OK, "Should return uploads list" + uploads = response.json() + assert len(uploads) > 0, "Should have at least one upload" + + # Verify download_url does NOT contain api_key + download_url = uploads[0]["download_url"] + assert "api_key" not in download_url, "Download URL should not contain api_key" + assert "?api_key=" not in download_url, "Download URL should not have api_key query param" + + +@pytest.mark.asyncio +async def test_get_file_info_does_not_expose_api_key(): + """Test that get_file_info returns clean download URLs without API key.""" + transport = ASGITransport(app=app) + async with 
AsyncClient(transport=transport, base_url="http://testserver") as client: + # Create token and upload a file + token_data = await create_token(client) + upload_data = await initiate_upload( + client, token_data["token"], filename="test.txt", size_bytes=11, filetype="text/plain", meta_data={} + ) + await upload_file_via_tus(client, upload_data["upload_id"], b"hello world") + + response = await client.get( + app.url_path_for( + "get_file_info", + download_token=token_data["download_token"], + upload_id=upload_data["upload_id"], + ), + headers={"Authorization": f"Bearer {settings.admin_api_key}"}, + ) + assert response.status_code == status.HTTP_200_OK, "Should return file info" + file_info = response.json() + + # Verify download_url does NOT contain api_key + download_url = file_info["download_url"] + assert "api_key" not in download_url, "Download URL should not contain api_key" + assert "?api_key=" not in download_url, "Download URL should not have api_key query param" diff --git a/backend/tests/test_mimetype_validation.py b/backend/tests/test_mimetype_validation.py index 49660eb..54bd983 100644 --- a/backend/tests/test_mimetype_validation.py +++ b/backend/tests/test_mimetype_validation.py @@ -2,6 +2,7 @@ import shutil from pathlib import Path +from fastapi import status import pytest from httpx import ASGITransport, AsyncClient @@ -27,7 +28,7 @@ async def test_mimetype_spoofing_rejected(): }, headers={"Authorization": f"Bearer {settings.admin_api_key}"}, ) - assert resp.status_code == 201, "Token creation should return 201" + assert resp.status_code == status.HTTP_201_CREATED, "Token creation should return 201" token_data = resp.json() token_value = token_data["token"] @@ -43,7 +44,7 @@ async def test_mimetype_spoofing_rejected(): }, params={"token": token_value}, ) - assert init_resp.status_code == 201, "Upload initiation should return 201" + assert init_resp.status_code == status.HTTP_201_CREATED, "Upload initiation should return 201" upload_data = 
init_resp.json() upload_id = upload_data["upload_id"] @@ -57,11 +58,11 @@ async def test_mimetype_spoofing_rejected(): }, ) - assert patch_resp.status_code == 415, "Fake video file should be rejected with 415" + assert patch_resp.status_code == status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, "Fake video file should be rejected with 415" assert "does not match allowed types" in patch_resp.json()["detail"], "Error should indicate type mismatch" head_resp = await client.head(app.url_path_for("tus_head", upload_id=upload_id)) - assert head_resp.status_code == 404, "Rejected upload should be removed" + assert head_resp.status_code == status.HTTP_404_NOT_FOUND, "Rejected upload should be removed" @pytest.mark.asyncio @@ -78,7 +79,7 @@ async def test_valid_mimetype_accepted(): }, headers={"Authorization": f"Bearer {settings.admin_api_key}"}, ) - assert resp.status_code == 201, "Token creation should return 201" + assert resp.status_code == status.HTTP_201_CREATED, "Token creation should return 201" token_data = resp.json() token_value = token_data["token"] @@ -92,14 +93,14 @@ async def test_valid_mimetype_accepted(): }, params={"token": token_value}, ) - assert init_resp.status_code == 201, "Upload initiation should return 201" + assert init_resp.status_code == status.HTTP_201_CREATED, "Upload initiation should return 201" upload_data = init_resp.json() upload_id = upload_data["upload_id"] text_content = b"This is a text file." 
head_resp = await client.head(app.url_path_for("tus_head", upload_id=upload_id)) - assert head_resp.status_code == 200, "TUS HEAD should return 200" + assert head_resp.status_code == status.HTTP_200_OK, "TUS HEAD should return 200" patch_resp = await client.patch( app.url_path_for("tus_patch", upload_id=upload_id), @@ -111,10 +112,10 @@ async def test_valid_mimetype_accepted(): }, ) - assert patch_resp.status_code == 204, "Valid text file should be accepted" + assert patch_resp.status_code == status.HTTP_204_NO_CONTENT, "Valid text file should be accepted" head_resp = await client.head(app.url_path_for("tus_head", upload_id=upload_id)) - assert head_resp.status_code == 200, "Upload should still exist after completion" + assert head_resp.status_code == status.HTTP_200_OK, "Upload should still exist after completion" @pytest.mark.asyncio @@ -130,7 +131,7 @@ async def test_mimetype_updated_on_completion(): }, headers={"Authorization": f"Bearer {settings.admin_api_key}"}, ) - assert resp.status_code == 201, "Token creation should return 201" + assert resp.status_code == status.HTTP_201_CREATED, "Token creation should return 201" token_data = resp.json() token_value = token_data["token"] @@ -144,7 +145,7 @@ async def test_mimetype_updated_on_completion(): }, params={"token": token_value}, ) - assert init_resp.status_code == 201, "Upload initiation should return 201" + assert init_resp.status_code == status.HTTP_201_CREATED, "Upload initiation should return 201" upload_data = init_resp.json() upload_id = upload_data["upload_id"] @@ -159,10 +160,10 @@ async def test_mimetype_updated_on_completion(): "Content-Length": str(len(text_content)), }, ) - assert patch_resp.status_code == 204, "Upload completion should return 204" + assert patch_resp.status_code == status.HTTP_204_NO_CONTENT, "Upload completion should return 204" async with SessionLocal() as session: - stmt = select(models.UploadRecord).where(models.UploadRecord.id == upload_id) + stmt = 
select(models.UploadRecord).where(models.UploadRecord.public_id == upload_id) res = await session.execute(stmt) upload = res.scalar_one_or_none() assert upload is not None, "Upload record should exist" @@ -184,7 +185,7 @@ async def test_ffprobe_extracts_metadata_for_video(): }, headers={"Authorization": f"Bearer {settings.admin_api_key}"}, ) - assert resp.status_code == 201, "Token creation should return 201" + assert resp.status_code == status.HTTP_201_CREATED, "Token creation should return 201" token_data = resp.json() token_value = token_data["token"] @@ -200,7 +201,7 @@ async def test_ffprobe_extracts_metadata_for_video(): }, params={"token": token_value}, ) - assert init_resp.status_code == 201, "Upload initiation should return 201" + assert init_resp.status_code == status.HTTP_201_CREATED, "Upload initiation should return 201" upload_data = init_resp.json() upload_id = upload_data["upload_id"] @@ -213,10 +214,10 @@ async def test_ffprobe_extracts_metadata_for_video(): "Content-Length": str(file.stat().st_size), }, ) - assert patch_resp.status_code == 204, "Video upload should complete successfully" + assert patch_resp.status_code == status.HTTP_204_NO_CONTENT, "Video upload should complete successfully" async with SessionLocal() as session: - stmt = select(models.UploadRecord).where(models.UploadRecord.id == upload_id) + stmt = select(models.UploadRecord).where(models.UploadRecord.public_id == upload_id) res = await session.execute(stmt) upload = res.scalar_one_or_none() assert upload is not None, "Upload record should exist" @@ -224,7 +225,9 @@ async def test_ffprobe_extracts_metadata_for_video(): assert upload.meta_data is not None, "Metadata should be extracted" if "ffprobe" in upload.meta_data: assert isinstance(upload.meta_data["ffprobe"], dict), "ffprobe data should be a dict" - assert "format" in upload.meta_data["ffprobe"] or "streams" in upload.meta_data["ffprobe"], "ffprobe should contain format or streams info" + assert "format" in 
upload.meta_data["ffprobe"] or "streams" in upload.meta_data["ffprobe"], ( + "ffprobe should contain format or streams info" + ) @pytest.mark.asyncio @@ -241,7 +244,7 @@ async def test_ffprobe_not_run_for_non_multimedia(): }, headers={"Authorization": f"Bearer {settings.admin_api_key}"}, ) - assert resp.status_code == 201, "Token creation should return 201" + assert resp.status_code == status.HTTP_201_CREATED, "Token creation should return 201" token_data = resp.json() token_value = token_data["token"] @@ -256,7 +259,7 @@ async def test_ffprobe_not_run_for_non_multimedia(): }, params={"token": token_value}, ) - assert init_resp.status_code == 201, "Upload initiation should return 201" + assert init_resp.status_code == status.HTTP_201_CREATED, "Upload initiation should return 201" upload_data = init_resp.json() upload_id = upload_data["upload_id"] @@ -269,10 +272,10 @@ async def test_ffprobe_not_run_for_non_multimedia(): "Content-Length": str(len(text_content)), }, ) - assert patch_resp.status_code == 204, "Text file upload should complete successfully" + assert patch_resp.status_code == status.HTTP_204_NO_CONTENT, "Text upload should complete successfully" async with SessionLocal() as session: - stmt = select(models.UploadRecord).where(models.UploadRecord.id == upload_id) + stmt = select(models.UploadRecord).where(models.UploadRecord.public_id == upload_id) res = await session.execute(stmt) upload = res.scalar_one_or_none() assert upload is not None, "Upload record should exist" diff --git a/backend/tests/test_share_view.py b/backend/tests/test_share_view.py new file mode 100644 index 0000000..5f6cb86 --- /dev/null +++ b/backend/tests/test_share_view.py @@ -0,0 +1,59 @@ +"""Test share view endpoint returns appropriate data based on token type.""" + +import pytest +from httpx import ASGITransport, AsyncClient +from fastapi import status + +from backend.app.main import app +from backend.tests.utils import create_token + + +@pytest.mark.asyncio +async def 
test_get_token_with_upload_token_returns_full_info(): + """Test that accessing with upload token returns full token info including upload token.""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://testserver") as client: + token_data = await create_token(client, max_uploads=1, max_size_bytes=1000000) + upload_token = token_data["token"] + + response = await client.get(app.url_path_for("get_token", token_value=upload_token)) + + assert response.status_code == status.HTTP_200_OK, "Should return 200 for valid upload token with auth" + data = response.json() + assert "token" in data, "Should include upload token field" + assert data["token"] == upload_token, "Upload token should match" + assert "download_token" in data, "Should include download token field" + assert "remaining_uploads" in data, "Should include remaining_uploads field" + + +@pytest.mark.asyncio +async def test_get_token_with_download_token_returns_limited_info(): + """Test that accessing with download token returns share info without upload token.""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://testserver") as client: + token_data = await create_token(client, max_uploads=1, max_size_bytes=1000000) + download_token = token_data["download_token"] + + response = await client.get(app.url_path_for("get_token", token_value=download_token)) + + assert response.status_code == status.HTTP_200_OK, "Should return 200 for valid download token" + data = response.json() + assert data["token"] is None or data["token"] == "", "Token field should be empty/None for share info" + assert "download_token" in data, "Should include download token field" + assert data["download_token"] == download_token, "Download token should match" + assert "max_uploads" in data, "Should include max_uploads field" + assert "allowed_mime" in data, "Should include allowed_mime field" + assert "allow_public_downloads" in data, "Should 
include allow_public_downloads field" + + +@pytest.mark.asyncio +async def test_get_token_invalid_token_returns_404(): + """Test that invalid token returns 404.""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://testserver") as client: + response = await client.get(app.url_path_for("get_token", token_value="invalid_token")) + + assert response.status_code == status.HTTP_404_NOT_FOUND, "Should return 404 for invalid token" + data = response.json() + assert "detail" in data, "Should include error detail" + assert "not found" in data["detail"].lower(), "Error should mention token not found" diff --git a/backend/tests/test_upload_cancel.py b/backend/tests/test_upload_cancel.py index c992837..93294dd 100644 --- a/backend/tests/test_upload_cancel.py +++ b/backend/tests/test_upload_cancel.py @@ -95,10 +95,10 @@ async def test_cancel_upload_invalid_upload_id(): token = r.json()["token"] response = await client.delete( - app.url_path_for("cancel_upload", upload_id=999999), + app.url_path_for("cancel_upload", upload_id="nonexistent_upload_id"), params={"token": token}, ) - assert response.status_code == 404, "Canceling non-existent upload should return 404" + assert response.status_code == status.HTTP_404_NOT_FOUND, "Canceling non-existent upload should return 404" assert response.json()["detail"] == "Upload not found", "Error should indicate upload not found" @@ -126,7 +126,7 @@ async def test_cancel_completed_upload_fails(): "meta_data": {"broadcast_date": "2024-01-01", "title": "Test", "source": "youtube"}, }, ) - assert init_response.status_code == 201, "Upload initiation should return 201" + assert init_response.status_code == status.HTTP_201_CREATED, "Upload initiation should return 201" upload_id = init_response.json()["upload_id"] tus_patch_url = app.url_path_for("tus_patch", upload_id=upload_id) @@ -139,13 +139,13 @@ async def test_cancel_completed_upload_fails(): }, content=b"hello", ) - assert 
patch_response.status_code == 204, "TUS PATCH should return 204" + assert patch_response.status_code == status.HTTP_204_NO_CONTENT, "TUS PATCH should return 204" response = await client.delete( app.url_path_for("cancel_upload", upload_id=upload_id), params={"token": token}, ) - assert response.status_code == 400, "Canceling completed upload should return 400" + assert response.status_code == status.HTTP_400_BAD_REQUEST, "Canceling completed upload should return 400" assert "Cannot cancel completed upload" in response.json()["detail"], "Error should indicate upload is completed" @@ -175,10 +175,10 @@ async def test_cancel_multiple_uploads(): "meta_data": {"broadcast_date": "2024-01-01", "title": f"Test{i}", "source": "youtube"}, }, ) - assert response.status_code == 201, f"Upload initiation {i} should return 201" + assert response.status_code == status.HTTP_201_CREATED, f"Upload initiation {i} should return 201" upload_ids.append(response.json()["upload_id"]) - info_url = app.url_path_for("get_public_token_info", token_value=token) + info_url = app.url_path_for("get_token", token_value=token) info = await client.get(info_url) assert info.json()["remaining_uploads"] == 0, "All upload slots should be used" @@ -186,14 +186,14 @@ async def test_cancel_multiple_uploads(): app.url_path_for("cancel_upload", upload_id=upload_ids[0]), params={"token": token}, ) - assert response.status_code == 200, "First cancellation should return 200" + assert response.status_code == status.HTTP_200_OK, "First cancellation should return 200" assert response.json()["remaining_uploads"] == 1, "First cancellation should restore one slot" response = await client.delete( app.url_path_for("cancel_upload", upload_id=upload_ids[1]), params={"token": token}, ) - assert response.status_code == 200, "Second cancellation should return 200" + assert response.status_code == status.HTTP_200_OK, "Second cancellation should return 200" assert response.json()["remaining_uploads"] == 2, "Second cancellation 
should restore another slot" info = await client.get(info_url) @@ -224,12 +224,12 @@ async def test_cancel_with_nonexistent_token(): "meta_data": {"broadcast_date": "2024-01-01", "title": "Test", "source": "youtube"}, }, ) - assert response.status_code == 201, "Upload initiation should return 201" + assert response.status_code == status.HTTP_201_CREATED, "Upload initiation should return 201" upload_id = response.json()["upload_id"] response = await client.delete( app.url_path_for("cancel_upload", upload_id=upload_id), params={"token": "fake_token"}, ) - assert response.status_code == 404, "Non-existent token should return 404" + assert response.status_code == status.HTTP_404_NOT_FOUND, "Non-existent token should return 404" assert response.json()["detail"] == "Token not found", "Error should indicate token not found" diff --git a/backend/tests/test_upload_flow.py b/backend/tests/test_upload_flow.py index 9f11ed1..0b8277d 100644 --- a/backend/tests/test_upload_flow.py +++ b/backend/tests/test_upload_flow.py @@ -30,7 +30,7 @@ async def test_token_info_and_initiate(): body = init.json() print(f"init response: {body}") assert init.status_code == status.HTTP_201_CREATED, "Initiate upload should return 201" - assert body["upload_id"] > 0, "Upload ID should be a positive integer" + assert isinstance(body["upload_id"], str) and len(body["upload_id"]) > 0, "Upload ID should be a non-empty string" assert body["remaining_uploads"] == 0, "Remaining uploads should decrease to 0" diff --git a/backend/tests/utils.py b/backend/tests/utils.py index cdd1f6f..93d059b 100644 --- a/backend/tests/utils.py +++ b/backend/tests/utils.py @@ -134,7 +134,7 @@ async def get_token_info( Token info response JSON """ - resp = await client.get(app.url_path_for("get_public_token_info", token_value=token)) + resp = await client.get(app.url_path_for("get_token", token_value=token)) return resp.json() if resp.status_code == status.HTTP_200_OK else {} diff --git 
a/frontend/app/components/AdminTokensTable.vue b/frontend/app/components/AdminTokensTable.vue index 8e59fa0..30cdf96 100644 --- a/frontend/app/components/AdminTokensTable.vue +++ b/frontend/app/components/AdminTokensTable.vue @@ -34,10 +34,6 @@ @@ -63,12 +59,25 @@
- - - + + + + + + + + + + + + + + + +
@@ -81,6 +90,8 @@ import type { AdminToken } from "~/types/token"; import { copyText, formatBytes, formatDate } from "~/utils"; +const toast = useToast(); + defineProps<{ tokens: AdminToken[]; loading?: boolean; @@ -92,8 +103,27 @@ defineEmits<{ delete: [token: AdminToken]; }>(); -const copyUrl = (token: string) => { - const url = `${window.location.origin}/t/${token}`; +const copyUrl = (path: string, token: string) => { + const url = `${window.location.origin}/${path}/${token}`; + console.log("Copying URL:", url); copyText(url); + toast.add({ + title: 'link copied to clipboard.', + color: 'success', + icon: 'i-heroicons-check-circle-20-solid', + }) } - + +const getCopyMenuItems = (token: AdminToken) => [ + [{ + label: 'Copy upload link', + icon: 'i-heroicons-arrow-up-tray-20-solid', + onSelect: () => copyUrl("t", token.token) + }], + [{ + label: 'Copy share link', + icon: 'i-heroicons-share-20-solid', + onSelect: () => copyUrl("f", token.download_token) + }] +] + \ No newline at end of file diff --git a/frontend/app/components/AdminUploadsTable.vue b/frontend/app/components/AdminUploadsTable.vue index 497de2f..baf7bef 100644 --- a/frontend/app/components/AdminUploadsTable.vue +++ b/frontend/app/components/AdminUploadsTable.vue @@ -25,7 +25,7 @@
- {{ upload.filename || 'Unnamed file' }} @@ -71,11 +71,13 @@ diff --git a/frontend/app/composables/useTokenInfo.ts b/frontend/app/composables/useTokenInfo.ts index ba99c07..ac38c80 100644 --- a/frontend/app/composables/useTokenInfo.ts +++ b/frontend/app/composables/useTokenInfo.ts @@ -8,7 +8,7 @@ export function useTokenInfo(tokenValue: Ref) { const shareLinkText = computed(() => { if (!tokenInfo.value) return '' - return `${window.location.origin}/api/tokens/${tokenInfo.value.download_token}/uploads` + return `${window.location.origin}/f/${tokenInfo.value.download_token}` }) async function fetchTokenInfo() { @@ -18,7 +18,7 @@ export function useTokenInfo(tokenValue: Ref) { } tokenError.value = '' try { - const data = await $fetch('/api/tokens/' + tokenValue.value + '/info') + const data = await $fetch('/api/tokens/' + tokenValue.value) tokenInfo.value = data as any notFound.value = false } catch (err: any) { diff --git a/frontend/app/layouts/default.vue b/frontend/app/layouts/default.vue index e928fff..d088f58 100644 --- a/frontend/app/layouts/default.vue +++ b/frontend/app/layouts/default.vue @@ -27,6 +27,35 @@ + +
@@ -34,6 +63,19 @@ diff --git a/frontend/app/pages/admin/index.vue b/frontend/app/pages/admin/index.vue index 5e62d2f..1e57029 100644 --- a/frontend/app/pages/admin/index.vue +++ b/frontend/app/pages/admin/index.vue @@ -57,29 +57,19 @@ - + - + @@ -160,6 +150,7 @@ const editOpen = ref(false); const uploadsOpen = ref(false); const uploadsToken = ref(null); const uploads = ref([]); +const allowPublicDownloads = ref(true); const deleteOpen = ref(false); const deleteTarget = ref(null); @@ -266,8 +257,14 @@ async function openUploads(token: AdminToken) { uploadsOpen.value = true; loadingUploads.value = true; try { - const res = await $apiFetch(`/api/tokens/${token.token}/uploads`); - uploads.value = res; + const tokenValue = ref(token.token); + const { tokenInfo, fetchTokenInfo } = useTokenInfo(tokenValue); + await fetchTokenInfo(); + + if (tokenInfo.value) { + allowPublicDownloads.value = tokenInfo.value.allow_public_downloads ?? true; + uploads.value = tokenInfo.value.uploads; + } } catch (err: any) { handleAuthError(err); } finally { diff --git a/frontend/app/pages/f/[token].vue b/frontend/app/pages/f/[token].vue new file mode 100644 index 0000000..c1f48c6 --- /dev/null +++ b/frontend/app/pages/f/[token].vue @@ -0,0 +1,262 @@ + + + diff --git a/frontend/app/tests/useTokenInfo.test.ts b/frontend/app/tests/useTokenInfo.test.ts index bfac931..7f0db66 100644 --- a/frontend/app/tests/useTokenInfo.test.ts +++ b/frontend/app/tests/useTokenInfo.test.ts @@ -22,10 +22,10 @@ describe('useTokenInfo', () => { await fetchTokenInfo() - expect(fetchMock).toHaveBeenCalledWith('/api/tokens/abc123/info') + expect(fetchMock).toHaveBeenCalledWith('/api/tokens/abc123') expect(notFound.value).toBe(false) expect(tokenInfo.value?.download_token).toBe('dl-token') - expect(shareLinkText.value).toBe(`${window.location.origin}/api/tokens/dl-token/uploads`) + expect(shareLinkText.value).toBe(`${window.location.origin}/f/dl-token`) }) it('sets error state when fetch fails', async () => { diff --git 
a/frontend/app/utils/index.ts b/frontend/app/utils/index.ts index 2ce35fa..289e9a2 100644 --- a/frontend/app/utils/index.ts +++ b/frontend/app/utils/index.ts @@ -73,4 +73,13 @@ function formatValue(val: any): string { return String(val); } -export { copyText, formatBytes, formatDate, percent, formatKey, formatValue }; \ No newline at end of file +/** + * Add admin API key to download URL if public downloads are disabled + */ +function addAdminKeyToUrl(url: string, allowPublicDownloads: boolean, apiKey: string | null): string { + if (allowPublicDownloads || !apiKey) return url; + const separator = url.includes('?') ? '&' : '?'; + return `${url}${separator}api_key=${apiKey}`; +} + +export { copyText, formatBytes, formatDate, percent, formatKey, formatValue, addAdminKeyToUrl }; \ No newline at end of file diff --git a/frontend/package.json b/frontend/package.json index f98684f..3058b4e 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -14,7 +14,8 @@ "lint:fix": "eslint . --fix", "test": "vitest run", "test:watch": "vitest", - "test:ci": "vitest run --silent --reporter=dot" + "test:ci": "vitest run --silent --reporter=dot", + "lint:tsc": "vue-tsc --noEmit" }, "dependencies": { "@iconify-json/lucide": "^1.2.82", diff --git a/pyproject.toml b/pyproject.toml index 3d61da2..b561800 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,6 +108,7 @@ ignore = [ "SLF001", "FBT001", "FBT002", + "PLR0915" ] fixable = ["ALL"] unfixable = [] diff --git a/tools/fbc_extractor.py b/tools/fbc_extractor.py index 63dce02..17d5073 100644 --- a/tools/fbc_extractor.py +++ b/tools/fbc_extractor.py @@ -1,16 +1,19 @@ import os import re import typing +from urllib.parse import ParseResult, urlparse, urlunparse from yt_dlp.extractor.common import InfoExtractor # type: ignore from yt_dlp.utils import int_or_none, str_or_none, traverse_obj # type: ignore class FBCIE(InfoExtractor): - _APIKEY = None - _NETRC_MACHINE = "fbc_uploader" - _VALID_URL = 
r"https?:\/\/.+\/api\/tokens\/(?Pfbc_[A-Za-z0-9_-]{22})\/?" - _VALID_FBC: re.Pattern[str] = re.compile(r"https?:\/\/.+\/api\/tokens\/(?Pfbc_[A-Za-z0-9_-]{22})\/uploads/?(?P\d+)?/?$") + _APIKEY: str | None = None + _NETRC_MACHINE: str = "fbc_uploader" + _VALID_URL: str = r"https?:\/\/.+\/(?:api\/tokens|f)\/(?Pfbc_[A-Za-z0-9_-]{22})\/?" + _VALID_FBC: re.Pattern[str] = re.compile( + r"https?:\/\/.+?\/(?:api\/tokens|f)\/(?Pfbc_[A-Za-z0-9_-]{22})(?:\/uploads)?/?(?P[A-Za-z0-9_-]+)?/?$" + ) _POSSIBLE_METADATA_FIELDS: typing.ClassVar = { "title": "title", "description": "description", @@ -24,7 +27,7 @@ def _match_valid_url(cls, url) -> re.Match[str] | None: @classmethod def _match_id(cls, url) -> None | str: - mat = cls._match_valid_url(url) + mat: re.Match[str] | None = cls._match_valid_url(url) if not mat: return None @@ -34,7 +37,28 @@ def _match_id(cls, url) -> None | str: return mat.group("id") def _perform_login(self, _, password): - FBCIE._APIKEY = password + FBCIE._APIKEY: str = password + + def _convert_to_api_url(self, url: str) -> str: + """Convert share URL format to API URL format.""" + mat: re.Match[str] | None = self._match_valid_url(url) + if not mat: + return url + + parsed: ParseResult = urlparse(url) + token_id: str | None = mat.group("id") + file_id: str | None = mat.group("fid") + + return urlunparse( + ( + parsed.scheme, + parsed.netloc, + f"/api/tokens/{token_id}/uploads/{file_id}" if file_id else f"/api/tokens/{token_id}/uploads", + "", + "", + "", + ) + ) def _real_extract(self, url): video_id: str = self._match_id(url) @@ -43,12 +67,12 @@ def _real_extract(self, url): if apikey := (FBCIE._APIKEY or os.environ.get("FBC_API_KEY")): headers["Authorization"] = f"Bearer {apikey!s}" - err_note = "Failed to download token info" + err_note = "Failed to download token info." if not apikey: - err_note += ", you may need to provide a valid API key --password or via FBC_API_KEY environment variable." 
+ err_note += " You may need to provide a valid API key via --password or FBC_API_KEY environment variable." items_info = self._download_json( - url, + self._convert_to_api_url(url), video_id=video_id, headers=headers, note="Downloading token info", @@ -62,10 +86,8 @@ playlist: list[dict] = [ self._format_item( video_data, - video_id, "video" if is_single or len(items_info) < 2 else "url", headers=headers, - is_single=is_single, ) for video_data in items_info if "completed" == video_data.get("status") @@ -130,13 +152,13 @@ from datetime import datetime as dt try: - _dt = dt.fromisoformat(date_str) + _dt: dt = dt.fromisoformat(date_str) except Exception: return None return dateformat.format(year=_dt.year, month=_dt.month, day=_dt.day) - def _expand_format(self, format_dict: dict, ffprobe_data: dict) -> dict: # noqa: PLR0915 + def _expand_format(self, format_dict: dict, ffprobe_data: dict) -> dict: """Enrich format dictionary with data from ffprobe output.""" if not ffprobe_data: return format_dict @@ -214,7 +236,7 @@ def _expand_format(self, format_dict: dict, ffprobe_data: dict) -> dict: # noqa return format_dict - def _format_item(self, video_data: dict, video_id: str, _type: str, headers: dict | None = None, is_single: bool = False) -> dict: + def _format_item(self, video_data: dict, _type: str, headers: dict | None = None) -> dict: base_format = { "url": video_data.get("download_url"), "ext": video_data.get("ext"), @@ -226,7 +248,7 @@ def _format_item(self, video_data: dict, video_id: str, _type: str, headers: dic base_format = self._expand_format(base_format, ffprobe_data) dct = { - "id": f"{video_id}-{video_data.get('id')}" if not is_single else video_id, + "id": video_data.get("public_id", video_data.get("id")), "_type": _type, "ext": video_data.get("ext"), "mimetype": video_data.get("mimetype"),