diff --git a/shield/core/backends/file.py b/shield/core/backends/file.py index 9781463..7d72261 100644 --- a/shield/core/backends/file.py +++ b/shield/core/backends/file.py @@ -10,6 +10,23 @@ Format is chosen at construction time from the file extension. The data structure is identical across all formats; only the serialisation differs. + +Performance design +------------------ +The original implementation read and wrote the entire file on every +``get_state`` / ``set_state`` / ``write_audit`` call — O(N) file I/O per +operation. + +This version introduces a **write-through in-memory cache**: + +* All reads are served from the in-memory ``_states`` dict — zero file I/O. +* Writes update the in-memory dict immediately (O(1)), then schedule a + **debounced disk flush** (50 ms window). Rapid sequential writes are + coalesced into a single file write. +* A dedicated ``_io_lock`` serialises concurrent flushes so the file is + never corrupted by interleaved writes. +* The cache is populated lazily on the first operation via ``_ensure_loaded``. +* ``shutdown()`` cancels any pending debounce and flushes synchronously. """ from __future__ import annotations @@ -28,6 +45,9 @@ _MAX_AUDIT_ENTRIES = 1000 +# Debounce window: rapid sequential writes are coalesced into one disk flush. +_WRITE_DEBOUNCE_SECONDS = 0.05 + # Supported extensions mapped to a canonical format name. _EXT_TO_FORMAT: dict[str, str] = { ".json": "json", @@ -41,8 +61,12 @@ class FileBackend(ShieldBackend): """Backend that persists state to a file via ``aiofiles``. Survives process restarts. Suitable for simple single-instance - deployments. A single ``asyncio.Lock`` prevents concurrent write - corruption. + deployments. + + All reads are served from an in-memory cache populated on first access. + Writes update the cache immediately then schedule a debounced disk flush — + meaning rapid bursts of state changes (e.g. startup route registration) + result in a single file write rather than N writes. 
The file format is auto-detected from the extension: @@ -78,7 +102,12 @@ class FileBackend(ShieldBackend): def __init__(self, path: str) -> None: self._path = Path(path) - self._lock = asyncio.Lock() + # Serializes concurrent disk flushes — held only during I/O, not + # during in-memory mutations, so reads are never blocked. + self._io_lock = asyncio.Lock() + # Guards the initial file load to prevent duplicate reads when + # multiple coroutines first access the backend concurrently. + self._load_lock = asyncio.Lock() ext = self._path.suffix.lower() if ext not in _EXT_TO_FORMAT: @@ -89,6 +118,17 @@ def __init__(self, path: str) -> None: ) self._format: str = _EXT_TO_FORMAT[ext] + # In-memory cache — populated lazily by _ensure_loaded(). + # Raw dicts (not Pydantic models) are stored so that _flush_to_disk() + # can snapshot and serialise without re-encoding. + self._states: dict[str, Any] = {} + self._audit: list[dict[str, Any]] = [] + self._loaded: bool = False + + # Debounced write task — cancelled and replaced on each write so that + # a burst of N writes results in one disk flush rather than N. + self._pending_write: asyncio.Task[None] | None = None + # ------------------------------------------------------------------ # Format-specific serialisation # ------------------------------------------------------------------ @@ -140,11 +180,12 @@ def _serialize(self, data: dict[str, Any]) -> str: # Internal read / write # ------------------------------------------------------------------ - async def _read(self) -> dict[str, Any]: - """Read and parse the state file. + async def _read_from_disk(self) -> dict[str, Any]: + """Read and parse the state file from disk. Returns an empty ``{"states": {}, "audit": []}`` structure if the - file does not exist or is blank. + file does not exist or is blank. Called only once during cache + initialisation — all subsequent reads go through the in-memory cache. 
""" if not self._path.exists(): return {"states": {}, "audit": []} @@ -163,52 +204,136 @@ async def _write(self, data: dict[str, Any]) -> None: async with aiofiles.open(self._path, "w") as f: await f.write(self._serialize(data)) + async def _ensure_loaded(self) -> None: + """Populate the in-memory cache from disk on first access. + + Uses a lock + double-check to ensure exactly one disk read even when + multiple coroutines call a backend method concurrently before the + cache is warm. + """ + if self._loaded: + return + async with self._load_lock: + if self._loaded: # another coroutine beat us here + return + data = await self._read_from_disk() + self._states = data["states"] + self._audit = data["audit"] + self._loaded = True + + def _schedule_write(self) -> None: + """Schedule a debounced disk flush. + + Each call cancels any previously-scheduled flush and creates a new + one, so a burst of N writes within the debounce window results in + exactly one disk write. Falls back to a no-op when there is no + running event loop (e.g. during synchronous test teardown). + """ + if self._pending_write is not None and not self._pending_write.done(): + self._pending_write.cancel() + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return # no event loop — write will be flushed on shutdown() + self._pending_write = loop.create_task(self._debounced_write()) + + async def _debounced_write(self) -> None: + """Wait the debounce window then flush to disk.""" + try: + await asyncio.sleep(_WRITE_DEBOUNCE_SECONDS) + await self._flush_to_disk() + except asyncio.CancelledError: + pass # a newer write superseded us — that task will flush instead + + async def _flush_to_disk(self) -> None: + """Snapshot the current in-memory state and write it to disk. + + The ``_io_lock`` prevents concurrent flushes from interleaving. + The snapshot is taken inside the lock so the written data is always + consistent — no partial view of an in-progress update. 
+ + In asyncio, dict operations without ``await`` are atomic (cooperative + scheduling guarantees no interleaving between two non-awaiting + statements), so snapshotting inside the lock is sufficient. + """ + async with self._io_lock: + data: dict[str, Any] = { + "states": dict(self._states), + "audit": list(self._audit), + } + await self._write(data) + # ------------------------------------------------------------------ # ShieldBackend interface # ------------------------------------------------------------------ async def get_state(self, path: str) -> RouteState: - """Return the current state for *path*. + """Return the current state for *path* from the in-memory cache. Raises ``KeyError`` if no state has been registered for *path*. + Zero file I/O after the cache is warm. """ - data = await self._read() - if path not in data["states"]: + await self._ensure_loaded() + raw = self._states.get(path) + if raw is None: raise KeyError(f"No state registered for path {path!r}") - return RouteState.model_validate(data["states"][path]) + return RouteState.model_validate(raw) async def set_state(self, path: str, state: RouteState) -> None: - """Persist *state* for *path*.""" - async with self._lock: - data = await self._read() - data["states"][path] = json.loads(state.model_dump_json()) - await self._write(data) + """Update the in-memory cache and flush to disk immediately. + + State changes are written synchronously so that a second + ``FileBackend`` instance (e.g. the CLI) reading the same file sees + the update right away. Unlike ``write_audit``, state mutations + are not debounced — durability is more important than batching here. + """ + await self._ensure_loaded() + self._states[path] = json.loads(state.model_dump_json()) + # Cancel any pending debounced audit flush — the full flush below + # will include both the new state and any queued audit entries. 
+ if self._pending_write is not None and not self._pending_write.done(): + self._pending_write.cancel() + self._pending_write = None + await self._flush_to_disk() async def delete_state(self, path: str) -> None: - """Remove state for *path*. No-op if not registered.""" - async with self._lock: - data = await self._read() - data["states"].pop(path, None) - await self._write(data) + """Remove state for *path* from cache and flush to disk immediately. + + No-op if *path* is not registered. + """ + await self._ensure_loaded() + self._states.pop(path, None) + if self._pending_write is not None and not self._pending_write.done(): + self._pending_write.cancel() + self._pending_write = None + await self._flush_to_disk() async def list_states(self) -> list[RouteState]: - """Return all registered route states.""" - data = await self._read() - return [RouteState.model_validate(v) for v in data["states"].values()] + """Return all registered route states from the in-memory cache. + + Zero file I/O after the cache is warm. + """ + await self._ensure_loaded() + return [RouteState.model_validate(v) for v in self._states.values()] async def write_audit(self, entry: AuditEntry) -> None: - """Append *entry* to the audit log, capping at 1000 entries.""" - async with self._lock: - data = await self._read() - data["audit"].append(json.loads(entry.model_dump_json())) - if len(data["audit"]) > _MAX_AUDIT_ENTRIES: - data["audit"] = data["audit"][-_MAX_AUDIT_ENTRIES:] - await self._write(data) + """Append *entry* to the in-memory audit log (capped at 1000 entries) + and schedule a debounced disk flush. + """ + await self._ensure_loaded() + self._audit.append(json.loads(entry.model_dump_json())) + if len(self._audit) > _MAX_AUDIT_ENTRIES: + # Trim to the most-recent 1000 entries. 
+ del self._audit[: len(self._audit) - _MAX_AUDIT_ENTRIES] + self._schedule_write() async def get_audit_log(self, path: str | None = None, limit: int = 100) -> list[AuditEntry]: - """Return audit entries, newest first, optionally filtered by *path*.""" - data = await self._read() - entries = data["audit"] + """Return audit entries, newest first, optionally filtered by *path*. + + Served entirely from the in-memory cache — zero file I/O. + """ + await self._ensure_loaded() + entries = self._audit if path is not None: entries = [e for e in entries if e["path"] == path] return [AuditEntry.model_validate(e) for e in reversed(entries)][:limit] @@ -217,3 +342,19 @@ async def subscribe(self) -> AsyncIterator[RouteState]: """Not supported — raises ``NotImplementedError``.""" raise NotImplementedError("FileBackend does not support pub/sub. Use polling instead.") yield + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + async def shutdown(self) -> None: + """Flush any pending write and release resources. + + Cancels the debounce timer and performs a synchronous flush so that + in-flight state changes are not lost on graceful shutdown. + """ + if self._pending_write is not None and not self._pending_write.done(): + self._pending_write.cancel() + self._pending_write = None + if self._loaded: + await self._flush_to_disk() diff --git a/shield/core/backends/memory.py b/shield/core/backends/memory.py index 9312a56..a8997ef 100644 --- a/shield/core/backends/memory.py +++ b/shield/core/backends/memory.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +from collections import defaultdict, deque from collections.abc import AsyncIterator from shield.core.backends.base import ShieldBackend @@ -16,11 +17,19 @@ class MemoryBackend(ShieldBackend): Default backend. Ideal for single-instance apps and testing. State is lost when the process restarts. 
+ + Audit log is stored in a ``deque`` (O(1) append/evict) with a parallel + per-path index (``dict[path, list[AuditEntry]]``) so that filtered + queries — ``get_audit_log(path=...)`` — are O(k) where k is the number + of entries for that specific path, not O(total entries). """ def __init__(self) -> None: self._states: dict[str, RouteState] = {} - self._audit: list[AuditEntry] = [] + # Ordered audit log — deque gives O(1) append and O(1) popleft eviction. + self._audit: deque[AuditEntry] = deque() + # Per-path index for O(1)-lookup filtered audit queries. + self._audit_by_path: defaultdict[str, list[AuditEntry]] = defaultdict(list) self._subscribers: list[asyncio.Queue[RouteState]] = [] async def get_state(self, path: str) -> RouteState: @@ -48,15 +57,36 @@ async def list_states(self) -> list[RouteState]: return list(self._states.values()) async def write_audit(self, entry: AuditEntry) -> None: - """Append *entry* to the audit log, capping at 1000 entries.""" + """Append *entry* to the audit log, capping at 1000 entries. + + When the cap is reached the oldest entry is evicted from both the + ordered deque and the per-path index in O(1) / O(k) time respectively, + where k is the number of entries for the evicted path (≪ total entries). + """ + if len(self._audit) >= _MAX_AUDIT_ENTRIES: + evicted = self._audit.popleft() + # Clean up the per-path index for the evicted entry. 
+ path_list = self._audit_by_path.get(evicted.path) + if path_list: + try: + path_list.remove(evicted) + except ValueError: + pass + self._audit.append(entry) - if len(self._audit) > _MAX_AUDIT_ENTRIES: - self._audit = self._audit[-_MAX_AUDIT_ENTRIES:] + self._audit_by_path[entry.path].append(entry) async def get_audit_log(self, path: str | None = None, limit: int = 100) -> list[AuditEntry]: - """Return audit entries, newest first, optionally filtered by *path*.""" - entries = self._audit if path is None else [e for e in self._audit if e.path == path] - return list(reversed(entries))[:limit] + """Return audit entries, newest first, optionally filtered by *path*. + + When *path* is provided the per-path index is used — O(k) where k is + the number of entries for that route — instead of scanning all 1000 + entries (O(N)). + """ + if path is None: + return list(reversed(self._audit))[:limit] + path_entries = self._audit_by_path.get(path, []) + return list(reversed(path_entries))[:limit] async def subscribe(self) -> AsyncIterator[RouteState]: """Yield ``RouteState`` objects as they are updated.""" diff --git a/shield/core/backends/redis.py b/shield/core/backends/redis.py index 44127a5..f2c235e 100644 --- a/shield/core/backends/redis.py +++ b/shield/core/backends/redis.py @@ -5,9 +5,25 @@ Key schema ---------- -``shield:state:{path}`` — JSON-serialized ``RouteState`` -``shield:audit`` — Redis list, newest-first (LPUSH + LTRIM to 1000) -``shield:changes`` — pub/sub channel for live state updates +``shield:state:{path}`` — JSON-serialized ``RouteState`` +``shield:route-index`` — Redis Set of all registered route paths + (replaces dangerous ``KEYS`` scans with safe + O(N) ``SMEMBERS`` that does not block the server) +``shield:audit`` — Redis list, newest-first (LPUSH + LTRIM to 1000) +``shield:audit:path:{path}`` — Per-path audit list for O(limit) filtered queries + instead of fetching all 1000 entries to filter in Python +``shield:changes`` — pub/sub channel for live 
state updates + +Performance notes +----------------- +* ``list_states()`` uses ``SMEMBERS shield:route-index`` + ``MGET`` instead + of ``KEYS shield:state:*``. ``KEYS`` is an O(keyspace) blocking command + that freezes Redis on production instances; ``SMEMBERS`` on a dedicated + set is safe and equally fast. +* ``set_state()`` / ``delete_state()`` maintain the route-index atomically + via pipeline so the set and the state key are always in sync. +* ``get_audit_log(path=X)`` reads directly from ``shield:audit:path:X`` + instead of fetching up to 1000 global entries and filtering in Python. """ from __future__ import annotations @@ -25,6 +41,7 @@ logger = logging.getLogger(__name__) _AUDIT_KEY = "shield:audit" +_ROUTE_INDEX_KEY = "shield:route-index" _CHANGES_CHANNEL = "shield:changes" _MAX_AUDIT_ENTRIES = 1000 @@ -33,6 +50,11 @@ def _state_key(path: str) -> str: return f"shield:state:{path}" +def _audit_path_key(path: str) -> str: + """Per-path audit list key for O(limit) filtered audit queries.""" + return f"shield:audit:path:{path}" + + class RedisBackend(ShieldBackend): """Backend that stores all state in Redis. @@ -74,32 +96,54 @@ async def get_state(self, path: str) -> RouteState: return RouteState.model_validate(json.loads(raw)) async def set_state(self, path: str, state: RouteState) -> None: - """Persist *state* for *path* and publish to ``shield:changes``.""" + """Persist *state* for *path*, update the route-index, and publish to + ``shield:changes``. + + The state key and the route-index entry are written atomically in a + single pipeline so ``list_states()`` can never see a state key that + is missing from the index (or vice-versa). 
+ """ payload = state.model_dump_json() try: async with self._client() as r: - await r.set(_state_key(path), payload) - await r.publish(_CHANGES_CHANNEL, payload) + pipe = r.pipeline() + pipe.set(_state_key(path), payload) + pipe.sadd(_ROUTE_INDEX_KEY, path) + pipe.publish(_CHANGES_CHANNEL, payload) + await pipe.execute() except Exception as exc: logger.error("shield: redis set_state error for %r: %s", path, exc) raise async def delete_state(self, path: str) -> None: - """Remove state for *path*. No-op if not registered.""" + """Remove state for *path* and remove it from the route-index. + + No-op if *path* is not registered. + """ try: async with self._client() as r: - await r.delete(_state_key(path)) + pipe = r.pipeline() + pipe.delete(_state_key(path)) + pipe.srem(_ROUTE_INDEX_KEY, path) + await pipe.execute() except Exception as exc: logger.error("shield: redis delete_state error: %s", exc) raise async def list_states(self) -> list[RouteState]: - """Return all registered route states.""" + """Return all registered route states. + + Uses ``SMEMBERS shield:route-index`` + ``MGET`` instead of the + dangerous ``KEYS shield:state:*`` pattern. ``KEYS`` is an O(keyspace) + blocking command that can freeze a busy Redis server; ``SMEMBERS`` on + the dedicated route-index set is safe to use in production. + """ try: async with self._client() as r: - keys: list[str] = await r.keys("shield:state:*") - if not keys: + paths: set[str] = await r.smembers(_ROUTE_INDEX_KEY) # type: ignore[misc] + if not paths: return [] + keys = [_state_key(p) for p in paths] values: list[str | None] = await r.mget(*keys) except Exception as exc: logger.error("shield: redis list_states error: %s", exc) @@ -112,37 +156,51 @@ async def list_states(self) -> list[RouteState]: return states async def write_audit(self, entry: AuditEntry) -> None: - """Append *entry* to the Redis audit list (capped at 1000).""" + """Append *entry* to both the global audit list and the per-path list. 
+ + Both lists are capped at 1000 entries via ``LTRIM``. Writing to a + per-path list means ``get_audit_log(path=X)`` can fetch exactly the + required entries directly — no full-list fetch-then-filter in Python. + """ payload = entry.model_dump_json() + path_key = _audit_path_key(entry.path) try: async with self._client() as r: pipe = r.pipeline() + # Global audit list (for unfiltered queries). pipe.lpush(_AUDIT_KEY, payload) pipe.ltrim(_AUDIT_KEY, 0, _MAX_AUDIT_ENTRIES - 1) + # Per-path audit list (for filtered queries — O(limit) instead + # of O(1000) fetch-then-filter). + pipe.lpush(path_key, payload) + pipe.ltrim(path_key, 0, _MAX_AUDIT_ENTRIES - 1) await pipe.execute() except Exception as exc: logger.error("shield: redis write_audit error: %s", exc) raise async def get_audit_log(self, path: str | None = None, limit: int = 100) -> list[AuditEntry]: - """Return audit entries, newest first, optionally filtered by *path*.""" + """Return audit entries, newest first. + + When *path* is provided the per-path list is used — fetches exactly + *limit* entries via a single ``LRANGE`` call, eliminating the + fetch-all-then-filter pattern of the previous implementation. + """ try: async with self._client() as r: - # Fetch more than limit to allow post-filter narrowing. - fetch = limit if path is None else _MAX_AUDIT_ENTRIES - raws: list[str] = await r.lrange(_AUDIT_KEY, 0, fetch - 1) # type: ignore[misc] + if path is not None: + # Per-path list: fetch exactly what we need — O(limit). + raws: list[str] = await r.lrange( # type: ignore[misc] + _audit_path_key(path), 0, limit - 1 + ) + else: + # Global list: all entries newest-first. 
+ raws = await r.lrange(_AUDIT_KEY, 0, limit - 1) # type: ignore[misc] except Exception as exc: logger.error("shield: redis get_audit_log error: %s", exc) raise - entries: list[AuditEntry] = [] - for raw in raws: - entry = AuditEntry.model_validate(json.loads(raw)) - if path is None or entry.path == path: - entries.append(entry) - if len(entries) >= limit: - break - return entries + return [AuditEntry.model_validate(json.loads(raw)) for raw in raws] async def subscribe(self) -> AsyncIterator[RouteState]: """Yield ``RouteState`` objects as they are updated via pub/sub.""" diff --git a/shield/core/engine.py b/shield/core/engine.py index 9d24862..91f20d3 100644 --- a/shield/core/engine.py +++ b/shield/core/engine.py @@ -61,6 +61,14 @@ def __init__( self.scheduler: MaintenanceScheduler = MaintenanceScheduler(engine=self) # Webhook registry: list of (url, formatter) pairs. self._webhooks: list[tuple[str, WebhookFormatter]] = [] + # Global config cache — avoids a backend round-trip on every request. + # Invalidated whenever the global config is written in this process. + # Acceptable stale window for multi-instance deployments: until the + # next write in this process (usually a human-initiated action). + self._global_config_cache: GlobalMaintenanceConfig | None = None + # Monotonic counter bumped on every state change. Used by the OpenAPI + # filter to detect when the cached schema needs to be rebuilt. + self._schema_version: int = 0 # ------------------------------------------------------------------ # Async context manager — calls backend lifecycle hooks @@ -119,12 +127,12 @@ async def check( """ # 1. Global maintenance check — highest priority. 
try: - global_cfg = await self.backend.get_global_config() + global_cfg = await self._get_global_config_cached() if global_cfg.enabled: method_key = f"{method.upper()}:{path}" if method else None - is_exempt = path in global_cfg.exempt_paths or ( - method_key is not None and method_key in global_cfg.exempt_paths - ) + # Use frozenset for O(1) membership tests instead of O(M) list scan. + exempt = global_cfg.exempt_set + is_exempt = path in exempt or (method_key is not None and method_key in exempt) if not is_exempt: raise MaintenanceException(reason=global_cfg.reason) except MaintenanceException: @@ -244,6 +252,63 @@ async def register(self, path: str, meta: dict[str, Any]) -> None: await self.backend.set_state(path, state) + async def register_batch(self, routes: list[tuple[str, dict[str, Any]]]) -> None: + """Register multiple routes in a single backend round-trip. + + Replaces N individual ``register()`` calls (each doing one + ``backend.get_state()`` read) with a single ``backend.list_states()`` + call to discover already-persisted routes, then only writes the truly + new ones. For ``FileBackend`` this means one file read instead of N, + and the debounced writer coalesces all new-state writes into a single + disk flush. + + Like ``register()``, persisted state always wins over decorator state — + routes already present in the backend are left untouched. + + Parameters + ---------- + routes: + Sequence of ``(path, meta)`` pairs exactly as accumulated by + ``ShieldRouter._shield_routes``. + """ + if not routes: + return + + # One backend call to discover every already-persisted route. 
+ try: + existing = await self.backend.list_states() + existing_keys: set[str] = {s.path for s in existing} + except Exception: + logger.exception( + "shield: register_batch — failed to list existing states, " + "falling back to per-route registration" + ) + for path, meta in routes: + await self.register(path, meta) + return + + for path, meta in routes: + if path in existing_keys: + continue # persisted state wins — skip + + is_force_active = bool(meta.get("force_active")) + status_str: str = meta.get("status", RouteStatus.ACTIVE) + status = RouteStatus(status_str) + + state = RouteState( + path=path, + status=status, + reason=meta.get("reason", ""), + allowed_envs=meta.get("allowed_envs", []), + sunset_date=meta.get("sunset_date"), + successor_path=meta.get("successor_path"), + force_active=is_force_active, + ) + if "window" in meta and meta["window"] is not None: + state.window = meta["window"] + + await self.backend.set_state(path, state) + # ------------------------------------------------------------------ # State mutation methods # ------------------------------------------------------------------ @@ -261,6 +326,32 @@ async def _assert_mutable(self, path: str) -> RouteState: raise RouteProtectedException(path) return state + async def _get_global_config_cached(self) -> GlobalMaintenanceConfig: + """Return the global config, using the in-process cache when available. + + The cache is populated on first call and invalidated whenever this + process writes a new global config (enable/disable/set_exempt_paths). + For single-instance deployments the cache is always fresh. For + multi-instance Redis deployments, cross-process changes are visible + on the next write from *this* process — an acceptable tradeoff given + that global maintenance is a rare, operator-initiated action. 
+ """ + if self._global_config_cache is not None: + return self._global_config_cache + cfg = await self.backend.get_global_config() + self._global_config_cache = cfg + return cfg + + def _invalidate_global_config_cache(self) -> None: + """Drop the cached global config so the next check re-fetches from backend.""" + self._global_config_cache = None + + def _bump_schema_version(self) -> None: + """Increment the schema version counter to signal that cached OpenAPI schemas + are stale and need to be rebuilt on the next ``/docs`` or ``/openapi.json`` request. + """ + self._schema_version += 1 + async def enable(self, path: str, actor: str = "system", reason: str = "") -> RouteState: """Enable *path*, returning the updated ``RouteState``. @@ -280,6 +371,7 @@ async def enable(self, path: str, actor: str = "system", reason: str = "") -> Ro update={"status": RouteStatus.ACTIVE, "reason": reason, "window": None} ) await self.backend.set_state(path, new_state) + self._bump_schema_version() await self._audit( path=path, action="enable", @@ -296,6 +388,7 @@ async def disable(self, path: str, reason: str = "", actor: str = "system") -> R old_state = await self._assert_mutable(path) new_state = old_state.model_copy(update={"status": RouteStatus.DISABLED, "reason": reason}) await self.backend.set_state(path, new_state) + self._bump_schema_version() await self._audit( path=path, action="disable", @@ -324,6 +417,7 @@ async def set_maintenance( } ) await self.backend.set_state(path, new_state) + self._bump_schema_version() await self._audit( path=path, action="maintenance_on", @@ -357,6 +451,7 @@ async def set_env_only(self, path: str, envs: list[str], actor: str = "system") update={"status": RouteStatus.ENV_GATED, "allowed_envs": envs} ) await self.backend.set_state(path, new_state) + self._bump_schema_version() await self._audit( path=path, action="env_gate", @@ -405,6 +500,8 @@ async def enable_global_maintenance( include_force_active=include_force_active, ) await 
self.backend.set_global_config(cfg) + self._invalidate_global_config_cache() + self._bump_schema_version() prev = RouteStatus.MAINTENANCE if old_cfg.enabled else RouteStatus.ACTIVE await self._audit( path="__global__", @@ -421,6 +518,8 @@ async def disable_global_maintenance(self, actor: str = "system") -> GlobalMaint old_cfg = await self.backend.get_global_config() cfg = GlobalMaintenanceConfig(enabled=False) await self.backend.set_global_config(cfg) + self._invalidate_global_config_cache() + self._bump_schema_version() prev = RouteStatus.MAINTENANCE if old_cfg.enabled else RouteStatus.ACTIVE await self._audit( path="__global__", @@ -439,6 +538,8 @@ async def set_global_exempt_paths( cfg = await self.backend.get_global_config() updated = cfg.model_copy(update={"exempt_paths": paths}) await self.backend.set_global_config(updated) + self._invalidate_global_config_cache() + self._bump_schema_version() return updated # ------------------------------------------------------------------ diff --git a/shield/core/models.py b/shield/core/models.py index 5d6ca08..0fa09f3 100644 --- a/shield/core/models.py +++ b/shield/core/models.py @@ -69,6 +69,15 @@ class GlobalMaintenanceConfig(BaseModel): exempt_paths: list[str] = Field(default_factory=list) include_force_active: bool = False + @property + def exempt_set(self) -> frozenset[str]: + """Return ``exempt_paths`` as a frozenset for O(1) membership tests. + + Constructed once per config object — avoids rebuilding a set on every + call to ``engine.check()`` when global maintenance is active. 
+ """ + return frozenset(self.exempt_paths) + class AuditEntry(BaseModel): """An immutable record of a route state change.""" diff --git a/shield/fastapi/middleware.py b/shield/fastapi/middleware.py index 9050cd3..f67606a 100644 --- a/shield/fastapi/middleware.py +++ b/shield/fastapi/middleware.py @@ -34,7 +34,7 @@ from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request from starlette.responses import JSONResponse, Response -from starlette.routing import Match +from starlette.routing import Match, Route from starlette.types import ASGIApp, Receive, Scope, Send from shield.core.engine import ShieldEngine @@ -65,6 +65,12 @@ def __init__(self, app: ASGIApp, engine: ShieldEngine) -> None: self.engine = engine self._scan_lock: asyncio.Lock = asyncio.Lock() self._routes_scanned: bool = False + # Pre-built route lookup cache — populated after scan_routes() completes. + # Static paths (no path params) get an O(1) dict lookup. + # Parameterised paths fall back to a short list scan (usually << total routes). + self._static_route_meta: dict[str, tuple[bool, str]] = {} + self._param_routes: list[tuple[Route, bool, str]] = [] + self._route_cache_built: bool = False # ------------------------------------------------------------------ # ASGI entry point — intercept lifespan for eager startup scan @@ -118,6 +124,45 @@ async def _do_scan(self, app: Any) -> None: await scan_routes(app, self.engine) self._routes_scanned = True + # Build the O(1) route-lookup cache now that all routes are registered. + self._build_route_cache(app) + + def _build_route_cache(self, app: Any) -> None: + """Pre-build a fast route-metadata lookup structure. + + Splits app routes into two buckets: + + * ``_static_route_meta`` — exact-path routes (no ``{params}``). + Resolved in O(1) via dict lookup on every request. + * ``_param_routes`` — parameterised routes (e.g. ``/items/{id}``). 
+ Stored as a short list; still requires ``route.matches()`` but + the list is typically much smaller than the total route count. + + The structure stores ``(is_force_active, template_path)`` per route + so ``_resolve_route`` can answer both questions in a single pass. + """ + static: dict[str, tuple[bool, str]] = {} + param: list[tuple[Route, bool, str]] = [] + + for route in getattr(app, "routes", []): + if not isinstance(route, Route): + continue + endpoint = getattr(route, "endpoint", None) + meta = getattr(endpoint, "__shield_meta__", {}) if endpoint else {} + is_force_active = bool(meta.get("force_active")) + template = getattr(route, "path", None) or "" + + if "{" not in template: + # Static path — exact dict key match on every request. + static[template] = (is_force_active, template) + else: + # Parameterised path — requires regex matching per request + # but the list is usually a small fraction of total routes. + param.append((route, is_force_active, template)) + + self._static_route_meta = static + self._param_routes = param + self._route_cache_built = True async def _ensure_routes_scanned(self, app: Any) -> None: """Lazy fallback scan for environments without ASGI lifespan support.""" @@ -181,7 +226,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) - return response def _resolve_route(self, request: Request) -> tuple[bool, str | None]: - """Match the request against app routes in a single pass. + """Match the request against app routes using the pre-built cache. Returns ``(is_force_active, template_path)`` where: @@ -194,7 +239,37 @@ def _resolve_route(self, request: Request) -> tuple[bool, str | None]: are resolved correctly. Returns ``(False, None)`` when no route matches (unregistered path). + + Performance + ----------- + After ``_build_route_cache()`` runs at startup: + + * Static paths (no ``{params}``) resolve in O(1) via dict lookup. 
+        * Parameterised paths scan only ``_param_routes`` — a small subset of
+          all routes.  NOTE(review): unlike ``route.matches()``, the static
+          dict lookup never checks the HTTP method — confirm 405 behaviour.
+
+        Falls back to the original O(N) walk when the cache is not yet built
+        (e.g. in environments without ASGI lifespan support where the lazy
+        scan has not completed for the current request).
         """
+        path = request.url.path
+
+        if self._route_cache_built:
+            # Fast path: O(1) dict lookup for static routes.
+            entry = self._static_route_meta.get(path)
+            if entry is not None:
+                return entry
+
+            # Parameterised routes — scan only the short param-route list.
+            for route, is_force_active, template in self._param_routes:
+                match, _ = route.matches(request.scope)
+                if match == Match.FULL:
+                    return is_force_active, template
+
+            return False, None
+
+        # Fallback: full O(N) walk used before the cache is ready.
         routes = getattr(request.app, "routes", [])
         for route in routes:
             match, _ = route.matches(request.scope)
diff --git a/shield/fastapi/openapi.py b/shield/fastapi/openapi.py
index 56bc1b7..dafde4b 100644
--- a/shield/fastapi/openapi.py
+++ b/shield/fastapi/openapi.py
@@ -58,7 +58,20 @@ def apply_shield_to_openapi(app: FastAPI, engine: ShieldEngine) -> None:
     """
     original_openapi = app.openapi
 
+    # Schema cache — keyed by engine._schema_version so it is automatically
+    # invalidated whenever any route state changes (enable, disable, maintenance, etc.).
+    # The cache avoids the O(N * verbs) rebuild on every /docs or /openapi.json request.
+    _cached_schema: dict[str, Any] | None = None
+    _cached_schema_version: int = -1
+
     def patched_openapi() -> dict[str, Any]:
+        nonlocal _cached_schema, _cached_schema_version
+
+        # Return cached schema when the engine state hasn't changed.
+ current_version = engine._schema_version + if _cached_schema is not None and _cached_schema_version == current_version: + return _cached_schema + base = original_openapi() states = _fetch_states(engine) global_cfg = _fetch_global_config(engine) @@ -186,6 +199,11 @@ def patched_openapi() -> dict[str, Any]: schema["info"]["x-shield-global-maintenance"] = {"enabled": False} schema["paths"] = filtered + + # Store in cache so identical /docs requests within the same state + # version are served instantly without rebuilding. + _cached_schema = schema + _cached_schema_version = current_version return schema app.openapi = patched_openapi # type: ignore[method-assign] diff --git a/shield/fastapi/router.py b/shield/fastapi/router.py index 74e5904..984e734 100644 --- a/shield/fastapi/router.py +++ b/shield/fastapi/router.py @@ -39,7 +39,10 @@ async def scan_routes(app: Any, engine: ShieldEngine) -> None: This function is **idempotent**: routes already registered (e.g. by a ``ShieldRouter`` startup hook or a previous ``scan_routes()`` call) are - left untouched because ``engine.register()`` honours persisted state. + left untouched because ``engine.register_batch()`` honours persisted state. + + Uses ``engine.register_batch()`` to discover all already-persisted routes + in one backend call instead of N individual ``get_state()`` reads. Parameters ---------- @@ -49,6 +52,8 @@ async def scan_routes(app: Any, engine: ShieldEngine) -> None: engine: The ``ShieldEngine`` that owns all route state. 
""" + routes_to_register: list[tuple[str, dict[str, Any]]] = [] + for route in getattr(app, "routes", []): if not isinstance(route, Route): continue @@ -65,9 +70,11 @@ async def scan_routes(app: Any, engine: ShieldEngine) -> None: methods: set[str] = route.methods or set() if methods: for method in sorted(methods): - await engine.register(f"{method}:{route.path}", meta) + routes_to_register.append((f"{method}:{route.path}", meta)) else: - await engine.register(route.path, meta) + routes_to_register.append((route.path, meta)) + + await engine.register_batch(routes_to_register) class ShieldRouter(APIRouter): @@ -155,9 +162,13 @@ async def register_shield_routes(self) -> None: Call this during application startup (e.g. via a ``lifespan`` handler or ``on_startup`` event). ``ShieldRouter`` calls this automatically when you pass it to ``app.include_router()``. + + Uses ``engine.register_batch()`` — a single ``list_states()`` backend + call discovers all already-persisted routes, then only the truly new + routes are written. For ``FileBackend`` this means one file read and + one debounced file write instead of N reads and N writes. """ - for path, meta in self._shield_routes: - await self._shield_engine.register(path, meta) + await self._shield_engine.register_batch(list(self._shield_routes)) # ------------------------------------------------------------------ # Hook into include_router so startup fires automatically