Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 174 additions & 33 deletions shield/core/backends/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,23 @@

Format is chosen at construction time from the file extension. The data
structure is identical across all formats; only the serialisation differs.

Performance design
------------------
The original implementation read and wrote the entire file on every
``get_state`` / ``set_state`` / ``write_audit`` call — O(N) file I/O per
operation.

This version introduces a **write-through in-memory cache**:

* All reads are served from the in-memory ``_states`` dict — zero file I/O.
* Writes update the in-memory dict immediately (O(1)), then schedule a
**debounced disk flush** (50 ms window). Rapid sequential writes are
coalesced into a single file write.
* A dedicated ``_io_lock`` serialises concurrent flushes so the file is
never corrupted by interleaved writes.
* The cache is populated lazily on the first operation via ``_ensure_loaded``.
* ``shutdown()`` cancels any pending debounce and flushes synchronously.
"""

from __future__ import annotations
Expand All @@ -28,6 +45,9 @@

_MAX_AUDIT_ENTRIES = 1000

# Debounce window: rapid sequential writes are coalesced into one disk flush.
_WRITE_DEBOUNCE_SECONDS = 0.05

# Supported extensions mapped to a canonical format name.
_EXT_TO_FORMAT: dict[str, str] = {
".json": "json",
Expand All @@ -41,8 +61,12 @@ class FileBackend(ShieldBackend):
"""Backend that persists state to a file via ``aiofiles``.

Survives process restarts. Suitable for simple single-instance
deployments. A single ``asyncio.Lock`` prevents concurrent write
corruption.
deployments.

All reads are served from an in-memory cache populated on first access.
Writes update the cache immediately then schedule a debounced disk flush —
meaning rapid bursts of state changes (e.g. startup route registration)
result in a single file write rather than N writes.

The file format is auto-detected from the extension:

Expand Down Expand Up @@ -78,7 +102,12 @@ class FileBackend(ShieldBackend):

def __init__(self, path: str) -> None:
self._path = Path(path)
self._lock = asyncio.Lock()
# Serializes concurrent disk flushes — held only during I/O, not
# during in-memory mutations, so reads are never blocked.
self._io_lock = asyncio.Lock()
# Guards the initial file load to prevent duplicate reads when
# multiple coroutines first access the backend concurrently.
self._load_lock = asyncio.Lock()

ext = self._path.suffix.lower()
if ext not in _EXT_TO_FORMAT:
Expand All @@ -89,6 +118,17 @@ def __init__(self, path: str) -> None:
)
self._format: str = _EXT_TO_FORMAT[ext]

# In-memory cache — populated lazily by _ensure_loaded().
# Raw dicts (not Pydantic models) are stored so that _flush_to_disk()
# can snapshot and serialise without re-encoding.
self._states: dict[str, Any] = {}
self._audit: list[dict[str, Any]] = []
self._loaded: bool = False

# Debounced write task — cancelled and replaced on each write so that
# a burst of N writes results in one disk flush rather than N.
self._pending_write: asyncio.Task[None] | None = None

# ------------------------------------------------------------------
# Format-specific serialisation
# ------------------------------------------------------------------
Expand Down Expand Up @@ -140,11 +180,12 @@ def _serialize(self, data: dict[str, Any]) -> str:
# Internal read / write
# ------------------------------------------------------------------

async def _read(self) -> dict[str, Any]:
"""Read and parse the state file.
async def _read_from_disk(self) -> dict[str, Any]:
"""Read and parse the state file from disk.

Returns an empty ``{"states": {}, "audit": []}`` structure if the
file does not exist or is blank.
file does not exist or is blank. Called only once during cache
initialisation — all subsequent reads go through the in-memory cache.
"""
if not self._path.exists():
return {"states": {}, "audit": []}
Expand All @@ -163,52 +204,136 @@ async def _write(self, data: dict[str, Any]) -> None:
async with aiofiles.open(self._path, "w") as f:
await f.write(self._serialize(data))

async def _ensure_loaded(self) -> None:
    """Lazily warm the in-memory cache from disk, exactly once.

    Double-checked locking: the fast path returns without touching the
    lock once the cache is warm; the slow path re-checks under
    ``_load_lock`` so concurrent first-callers trigger only a single
    disk read between them.
    """
    if self._loaded:
        return
    async with self._load_lock:
        if not self._loaded:  # re-check: another coroutine may have won
            snapshot = await self._read_from_disk()
            self._states = snapshot["states"]
            self._audit = snapshot["audit"]
            self._loaded = True

def _schedule_write(self) -> None:
    """Schedule a debounced disk flush.

    Each call replaces any previously-scheduled flush with a fresh one,
    so a burst of N writes within the debounce window results in exactly
    one disk write.

    Fix: the running event loop is resolved *before* the pending task is
    cancelled. The original cancelled first and then bailed out when
    ``get_running_loop`` raised — discarding an already-scheduled flush
    whose data would then only be persisted by an explicit ``shutdown()``.
    Now, with no running loop (e.g. synchronous test teardown), the
    existing pending task is left untouched.
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No event loop — keep any pending flush alive; shutdown()
        # remains the fallback for persisting unflushed writes.
        return
    if self._pending_write is not None and not self._pending_write.done():
        self._pending_write.cancel()
    self._pending_write = loop.create_task(self._debounced_write())

async def _debounced_write(self) -> None:
    """Sleep out the debounce window, then flush the cache to disk.

    Cancellation — from a newer write superseding this one, or from
    ``shutdown()`` — is absorbed silently: whichever operation cancelled
    us takes over responsibility for the eventual flush.
    """
    try:
        await asyncio.sleep(_WRITE_DEBOUNCE_SECONDS)
    except asyncio.CancelledError:
        return  # superseded before the window elapsed
    try:
        await self._flush_to_disk()
    except asyncio.CancelledError:
        pass  # superseded mid-flush; the canceller flushes instead

async def _flush_to_disk(self) -> None:
    """Write a consistent snapshot of the in-memory state to disk.

    ``_io_lock`` serialises flushes so two concurrent callers cannot
    interleave their file writes. The shallow copies are taken while
    the lock is held, with no intervening ``await`` — under asyncio's
    cooperative scheduling no other coroutine can run between the two
    copy statements, so the snapshot never observes a half-applied
    in-memory update.
    """
    async with self._io_lock:
        snapshot: dict[str, Any] = {
            "states": dict(self._states),
            "audit": list(self._audit),
        }
        await self._write(snapshot)

# ------------------------------------------------------------------
# ShieldBackend interface
# ------------------------------------------------------------------

async def get_state(self, path: str) -> RouteState:
    """Return the current state for *path*, served from the cache.

    Performs no file I/O once the cache is warm.

    Raises ``KeyError`` if no state has been registered for *path*.
    """
    await self._ensure_loaded()
    cached = self._states.get(path)
    if cached is None:
        raise KeyError(f"No state registered for path {path!r}")
    return RouteState.model_validate(cached)

async def set_state(self, path: str, state: RouteState) -> None:
    """Store *state* for *path* in the cache and flush to disk at once.

    Unlike ``write_audit``, state mutations are not debounced: a second
    ``FileBackend`` instance reading the same file (e.g. the CLI) must
    observe the change immediately, so durability wins over batching.
    """
    await self._ensure_loaded()
    self._states[path] = json.loads(state.model_dump_json())
    # Any queued debounced flush is now redundant — the immediate flush
    # below persists both this state change and any pending audit entries.
    pending = self._pending_write
    if pending is not None and not pending.done():
        pending.cancel()
    self._pending_write = None
    await self._flush_to_disk()

async def delete_state(self, path: str) -> None:
    """Drop *path* from the cache and flush to disk immediately.

    No-op (but still flushes) when *path* was never registered.
    """
    await self._ensure_loaded()
    self._states.pop(path, None)
    # Supersede any queued debounced flush; the immediate flush below
    # writes the complete current state, audit entries included.
    pending = self._pending_write
    if pending is not None and not pending.done():
        pending.cancel()
    self._pending_write = None
    await self._flush_to_disk()

async def list_states(self) -> list[RouteState]:
    """Return every registered route state.

    Served from the in-memory cache — no file I/O once warm.
    """
    await self._ensure_loaded()
    return [RouteState.model_validate(raw) for raw in self._states.values()]

async def write_audit(self, entry: AuditEntry) -> None:
    """Append *entry* to the cached audit log and debounce a disk flush.

    The log is capped at ``_MAX_AUDIT_ENTRIES``; overflow is trimmed
    from the oldest end. The flush itself is debounced, so a burst of
    audit writes coalesces into a single file write.
    """
    await self._ensure_loaded()
    self._audit.append(json.loads(entry.model_dump_json()))
    overflow = len(self._audit) - _MAX_AUDIT_ENTRIES
    if overflow > 0:
        # Trim in place, keeping only the most-recent entries.
        del self._audit[:overflow]
    self._schedule_write()

async def get_audit_log(self, path: str | None = None, limit: int = 100) -> list[AuditEntry]:
    """Return audit entries, newest first, optionally filtered by *path*.

    Served entirely from the in-memory cache — zero file I/O.

    Fix: slice down to *limit* raw entries *before* validating them into
    ``AuditEntry`` models. The original validated every matching entry
    and then sliced — up to ``_MAX_AUDIT_ENTRIES`` (1000) Pydantic
    validations per call to return at most 100 results. Element
    selection is unchanged: slicing raw dicts picks exactly the same
    entries as slicing the validated models did.
    """
    await self._ensure_loaded()
    entries = self._audit
    if path is not None:
        entries = [e for e in entries if e["path"] == path]
    newest_first = list(reversed(entries))[:limit]
    return [AuditEntry.model_validate(e) for e in newest_first]
Expand All @@ -217,3 +342,19 @@ async def subscribe(self) -> AsyncIterator[RouteState]:
"""Not supported — raises ``NotImplementedError``."""
raise NotImplementedError("FileBackend does not support pub/sub. Use polling instead.")
yield

# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------

async def shutdown(self) -> None:
    """Cancel the debounce timer and flush outstanding changes to disk.

    Called on graceful shutdown so in-flight (debounced) writes are not
    lost. The flush only runs once the cache has actually been loaded —
    an untouched backend has nothing to persist.
    """
    pending, self._pending_write = self._pending_write, None
    if pending is not None and not pending.done():
        pending.cancel()
    if self._loaded:
        await self._flush_to_disk()
44 changes: 37 additions & 7 deletions shield/core/backends/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import asyncio
from collections import defaultdict, deque
from collections.abc import AsyncIterator

from shield.core.backends.base import ShieldBackend
Expand All @@ -16,11 +17,19 @@ class MemoryBackend(ShieldBackend):

Default backend. Ideal for single-instance apps and testing.
State is lost when the process restarts.

Audit log is stored in a ``deque`` (O(1) append/evict) with a parallel
per-path index (``dict[path, list[AuditEntry]]``) so that filtered
queries — ``get_audit_log(path=...)`` — are O(k) where k is the number
of entries for that specific path, not O(total entries).
"""

def __init__(self) -> None:
    # Route path -> current state.
    self._states: dict[str, RouteState] = {}
    # Ordered audit log; deque gives O(1) append and O(1) left-eviction.
    self._audit: deque[AuditEntry] = deque()
    # Secondary index (path -> that path's entries, oldest first) so
    # filtered audit queries cost O(k) in that path's entry count
    # rather than O(total log size).
    self._audit_by_path: defaultdict[str, list[AuditEntry]] = defaultdict(list)
    # One queue per active subscribe() consumer.
    self._subscribers: list[asyncio.Queue[RouteState]] = []

async def get_state(self, path: str) -> RouteState:
Expand Down Expand Up @@ -48,15 +57,36 @@ async def list_states(self) -> list[RouteState]:
return list(self._states.values())

async def write_audit(self, entry: AuditEntry) -> None:
    """Append *entry* to the audit log, capping at 1000 entries.

    On overflow the globally-oldest entry is evicted: O(1) from the
    ordered deque, O(k) from its per-path index bucket (k = number of
    entries for that one path, typically ≪ total entries).

    Fix: an emptied bucket is now deleted outright. The original left
    the empty list behind in ``_audit_by_path``, leaking one dead key
    per retired path for the life of the process.
    """
    if len(self._audit) >= _MAX_AUDIT_ENTRIES:
        evicted = self._audit.popleft()
        bucket = self._audit_by_path.get(evicted.path)
        if bucket:
            try:
                # Equal entries are interchangeable, so removing the
                # first match keeps the bucket's ordering invariant.
                bucket.remove(evicted)
            except ValueError:
                pass  # index already out of sync — tolerate, don't crash
            if not bucket:
                # Drop the empty bucket so the index cannot grow without
                # bound across many distinct, since-retired paths.
                del self._audit_by_path[evicted.path]

    self._audit.append(entry)
    self._audit_by_path[entry.path].append(entry)

async def get_audit_log(self, path: str | None = None, limit: int = 100) -> list[AuditEntry]:
    """Return audit entries, newest first, optionally filtered by *path*.

    The filtered form reads the per-path index — O(k) in that route's
    entry count — instead of scanning the entire log.
    """
    source = self._audit if path is None else self._audit_by_path.get(path, [])
    return list(reversed(source))[:limit]

async def subscribe(self) -> AsyncIterator[RouteState]:
"""Yield ``RouteState`` objects as they are updated."""
Expand Down
Loading
Loading