-
Notifications
You must be signed in to change notification settings - Fork 0
fix(proxy): Enforce immutable snapshots and #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: update-readme
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -49,7 +49,10 @@ def _extract_subschema(schema: CoreSchema, name_or_idx) -> CoreSchema: | |
| # Prefer strict schema for subfield extraction | ||
| return _extract_subschema(strict, name_or_idx) | ||
| # JSON-or-python wrapper | ||
| case {"type": "json-or-python", "json_schema": inner} | {"type": "json-or-python", "python_schema": inner}: | ||
| case {"type": "json-or-python", "json_schema": inner} | { | ||
| "type": "json-or-python", | ||
| "python_schema": inner, | ||
| }: | ||
| # Recurse through the inner schema | ||
| return _extract_subschema(inner, name_or_idx) | ||
| # BaseModel | ||
|
|
@@ -118,6 +121,8 @@ def _unwrap_proxy_and_apply( | |
|
|
||
| Note: SerializerFunction is a union of many function signatures with different | ||
| arities. We wrap them generically here, which pyright can't fully verify. | ||
| Pydantic-core may pass additional keyword-only args (e.g., mode), hence the | ||
| *args pattern to forward all positional arguments. | ||
| """ | ||
| if pass_info: | ||
|
|
||
|
|
@@ -133,14 +138,40 @@ def apply_to_unwrapped(proxy: "SerializationProxy[T]", *args): | |
| return apply_to_unwrapped # pyright: ignore[reportReturnType] | ||
|
|
||
|
|
||
| def _freeze_collections(obj): | ||
| """Recursively freeze mutable collections for true snapshot immutability. | ||
|
|
||
| This ensures that even direct access to .mapping or .serialized cannot | ||
| mutate the snapshot. Lists become tuples, dicts become MappingProxyType. | ||
|
|
||
| Args: | ||
| obj: The object to freeze (can be dict, list, tuple, or primitive) | ||
|
|
||
| Returns: | ||
| Frozen version of the object with all nested collections immutable | ||
| """ | ||
| if isinstance(obj, dict): | ||
| # Freeze nested values, then wrap dict | ||
| return MappingProxyType({k: _freeze_collections(v) for k, v in obj.items()}) | ||
| elif isinstance(obj, list): | ||
| # Convert list to tuple, freeze nested items | ||
| return tuple(_freeze_collections(item) for item in obj) | ||
| elif isinstance(obj, tuple): | ||
| # Already immutable container, but freeze nested items | ||
| return tuple(_freeze_collections(item) for item in obj) | ||
| else: | ||
| # Primitive or other immutable type (str, int, float, bool, None, etc.) | ||
| return obj | ||
|
|
||
|
|
||
| # Cache size constants for memory management | ||
| WRAPPED_SCHEMA_CACHE_SIZE = 256 | ||
| PROXY_TYPE_CACHE_SIZE = 256 | ||
| ATTR_CACHE_SIZE = 512 | ||
|
|
||
| # Type-check tuples (hoisted to module scope for micro-optimization) | ||
| _MAPPING_TYPES = (dict, MappingProxyType, Mapping) | ||
| _COLLECTION_TYPES = (dict, list, tuple) | ||
| _COLLECTION_TYPES = (dict, list, tuple, MappingProxyType) | ||
|
|
||
| # Bounded cache for wrapped schemas to prevent memory leaks in long-running applications | ||
| # Using OrderedDict for LRU eviction | ||
|
|
@@ -153,7 +184,6 @@ def _wrap_core_schema(schema: CoreSchema) -> CoreSchema: | |
| """Wrap a CoreSchema to make it proxy-aware. Uses bounded LRU cache to avoid expensive deepcopy.""" | ||
| schema_id = id(schema) | ||
|
|
||
| # Check cache under lock (OrderedDict isn't thread-safe even for reads) | ||
| with _WRAPPED_SCHEMA_CACHE_LOCK: | ||
| tup = _wrapped_schema_cache.get(schema_id) | ||
| if tup is not None: | ||
|
|
@@ -163,8 +193,6 @@ def _wrap_core_schema(schema: CoreSchema) -> CoreSchema: | |
| _wrapped_schema_cache.move_to_end(schema_id) | ||
| return wrapped | ||
|
|
||
| # Build wrapped schema (outside lock to minimize critical section) | ||
| # Key insight: Don't mutate "type" - only wrap serialization behavior | ||
| match schema: | ||
| # something we can reference to (e.g. BaseModel, Dataclass, ...) | ||
| case {"ref": ref}: | ||
|
|
@@ -185,16 +213,15 @@ def _wrap_core_schema(schema: CoreSchema) -> CoreSchema: | |
| # Pyright can't verify all possible serialization schema types support "function". | ||
| wrapped_schema["serialization"]["function"] = _unwrap_proxy_and_apply( # pyright: ignore[reportGeneralTypeIssues] | ||
| func, # pyright: ignore[reportArgumentType] | ||
| pass_info=True | ||
| pass_info=True, | ||
| ) | ||
| # Has custom serializer without info_arg | ||
| case {"serialization": {"function": func}}: | ||
| wrapped_schema = deepcopy(schema) | ||
| # We're wrapping the serializer function to unwrap proxies first. | ||
| # Pyright can't verify all possible serialization schema types support "function". | ||
| wrapped_schema["serialization"]["function"] = _unwrap_proxy_and_apply( # pyright: ignore[reportGeneralTypeIssues] | ||
| func, # pyright: ignore[reportArgumentType] | ||
| pass_info=False | ||
| pass_info=False, | ||
| ) | ||
| # No custom serializer - add one | ||
| case _: | ||
|
|
@@ -244,6 +271,15 @@ class SerializationProxy(Generic[T]): | |
| - Deterministic serialization snapshots for concurrent operations | ||
| """ | ||
|
|
||
| __slots__ = ( | ||
| "obj", | ||
| "serialized", | ||
| "root_adapter", | ||
| "_attr_cache", | ||
| "_attr_cache_lock", | ||
| "_version", | ||
| ) | ||
|
|
||
| core_schema: CoreSchema | ||
| __pydantic_serializer__: SchemaSerializer | ||
| # Note: __pydantic_validator__ is intentionally not set to avoid | ||
|
|
@@ -264,8 +300,21 @@ def __init__( | |
| self.root_adapter = root_adapter | ||
| # Bounded LRU cache for accessed attributes to avoid rebuilding proxies | ||
| # Keys are either strings (for attributes) or tuples (for items) | ||
| self._attr_cache: OrderedDict[str | tuple, "SerializationProxy"] = OrderedDict() | ||
| # Values are (version, proxy) tuples to invalidate stale entries | ||
| self._attr_cache: OrderedDict[ | ||
| str | tuple, tuple[int, "SerializationProxy"] | ||
| ] = OrderedDict() | ||
| self._attr_cache_lock = RLock() | ||
| # Version counter for refresh() coherence (bumped on each refresh) | ||
| self._version = 0 | ||
|
|
||
| def _current_version(self) -> int: | ||
| """Get the current version counter for cache coherence. | ||
|
|
||
| Reading an int is atomic in CPython due to the GIL, so no lock needed. | ||
| Worst case: stale read leads to rejected cache entry on next access (safe). | ||
| """ | ||
| return self._version | ||
|
|
||
| @classmethod | ||
| def _build( | ||
|
|
@@ -275,11 +324,9 @@ def _build( | |
| adapter: TypeAdapter, | ||
| core_schema: CoreSchema, | ||
| ): | ||
| # Normalize: wrap dicts to ensure immutability | ||
| if isinstance(serialized, dict) and not isinstance( | ||
| serialized, MappingProxyType | ||
| ): | ||
| serialized = MappingProxyType(serialized) | ||
| # Freeze collections recursively for true immutability | ||
| # (This also wraps dicts in MappingProxyType) | ||
| serialized = _freeze_collections(serialized) | ||
|
|
||
| schema_id = id(core_schema) | ||
|
|
||
|
|
@@ -339,9 +386,8 @@ def build( | |
| if adapter is None: | ||
| adapter = TypeAdapter(type(obj)) | ||
| serialized = adapter.dump_python(obj) | ||
| # If it's a dict, make it read-only to prevent accidental mutation from other threads | ||
| if isinstance(serialized, dict): | ||
| serialized = MappingProxyType(serialized) | ||
| # Freeze all collections recursively for true immutability | ||
| serialized = _freeze_collections(serialized) | ||
| core_schema = adapter.core_schema | ||
| return cls._build(obj, serialized, adapter, core_schema) | ||
|
|
||
|
|
@@ -357,6 +403,9 @@ def unwrap(self) -> T: | |
| def mapping(self) -> Mapping: | ||
| """Access the underlying serialized mapping directly. | ||
|
|
||
| Returns an **immutable** Mapping (typically MappingProxyType) to prevent | ||
| mutation of the snapshot. All nested collections are recursively frozen. | ||
|
|
||
| Use this to access mapping methods when your data has fields with those names. | ||
| For example, if you have a field named 'items': | ||
| - Access the field: proxy.items or proxy['items'] | ||
|
|
@@ -373,7 +422,7 @@ def mapping(self) -> Mapping: | |
| Use iteration or indexing directly for sequences. | ||
| """ | ||
| ser = self.serialized | ||
| # Ensure it's wrapped for immutability | ||
| # Ensure it's wrapped for immutability (should already be from _freeze_collections) | ||
| if isinstance(ser, dict) and not isinstance(ser, MappingProxyType): | ||
| ser = MappingProxyType(ser) | ||
| # Guard against calling .items() on a sequence | ||
|
|
@@ -389,21 +438,37 @@ def refresh(self) -> None: | |
|
|
||
| This recomputes the serialization and clears the attribute cache. | ||
| Useful when the underlying object changes and you want the proxy to reflect those changes. | ||
|
|
||
| Thread safety: Uses copy-on-write with version stamping to prevent races. | ||
| Any cache entries built before the refresh are invalidated by version mismatch. | ||
| """ | ||
| self.serialized = self.root_adapter.dump_python(self.obj) | ||
| if isinstance(self.serialized, dict): | ||
| self.serialized = MappingProxyType(self.serialized) | ||
| # Compute new snapshot outside the lock (expensive operation) | ||
| new_serialized = self.root_adapter.dump_python(self.obj) | ||
| new_serialized = _freeze_collections(new_serialized) | ||
|
|
||
| # Atomically swap snapshot, bump version, and clear cache | ||
| with self._attr_cache_lock: | ||
| self.serialized = new_serialized | ||
| self._version += 1 | ||
| self._attr_cache.clear() | ||
|
|
||
| def __getattr__(self, name: str): | ||
| # Check attribute cache first (LRU: move to end on access) | ||
| # Capture current version for cache coherence | ||
| ver = self._current_version() | ||
|
|
||
| # Check cache with version validation | ||
| with self._attr_cache_lock: | ||
| if name in self._attr_cache: | ||
| self._attr_cache.move_to_end(name) | ||
| return self._attr_cache[name] | ||
|
|
||
| # Hoist to local for micro-optimization | ||
| cached_ver, cached_proxy = self._attr_cache[name] | ||
| if cached_ver == ver: | ||
| # Cache hit with matching version | ||
| self._attr_cache.move_to_end(name) | ||
| return cached_proxy | ||
| else: | ||
| # Stale entry from pre-refresh, evict it | ||
| self._attr_cache.pop(name, None) | ||
|
|
||
| # Cache miss or stale - build new proxy | ||
| ser = self.serialized | ||
| if isinstance(ser, _MAPPING_TYPES) and name in ser: | ||
| sub_schema = _extract_subschema(self.core_schema, name) | ||
|
|
@@ -415,34 +480,43 @@ def __getattr__(self, name: str): | |
| if not isinstance(child_ser, _COLLECTION_TYPES): | ||
| return child_ser | ||
|
|
||
| # Wrap child dicts to prevent mutation | ||
| if isinstance(child_ser, dict): | ||
| child_ser = MappingProxyType(child_ser) | ||
| # child_ser is already frozen by _freeze_collections, but _build expects it | ||
| proxy = self._build( | ||
| getattr(self.obj, name), | ||
| child_ser, | ||
| self.root_adapter, | ||
| sub_schema, | ||
| ) | ||
| # Cache the built proxy with LRU eviction | ||
| # Cache with version stamp for coherence | ||
| with self._attr_cache_lock: | ||
| # Prune BEFORE insert to maintain strict bound | ||
| if len(self._attr_cache) >= ATTR_CACHE_SIZE: | ||
| self._attr_cache.popitem(last=False) | ||
| self._attr_cache[name] = proxy | ||
| self._attr_cache[name] = (ver, proxy) | ||
| self._attr_cache.move_to_end(name) | ||
| return proxy | ||
| return getattr(self.obj, name) | ||
|
|
||
| def __getitem__(self, key): | ||
| # Capture current version for cache coherence | ||
| ver = self._current_version() | ||
|
|
||
| # For getitem, we use a tuple for cache key to avoid collisions | ||
| cache_key = ("__item__", key) | ||
|
|
||
| # Check cache with version validation | ||
| with self._attr_cache_lock: | ||
| if cache_key in self._attr_cache: | ||
| self._attr_cache.move_to_end(cache_key) | ||
| return self._attr_cache[cache_key] | ||
|
|
||
| # Hoist to local for micro-optimization | ||
| cached_ver, cached_proxy = self._attr_cache[cache_key] | ||
| if cached_ver == ver: | ||
| # Cache hit with matching version | ||
| self._attr_cache.move_to_end(cache_key) | ||
| return cached_proxy | ||
| else: | ||
| # Stale entry from pre-refresh, evict it | ||
| self._attr_cache.pop(cache_key, None) | ||
|
Comment on lines 508 to +517
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The logic for checking and retrieving from the versioned cache is duplicated here and in `__getattr__`. |
||
|
|
||
| # Cache miss or stale - build new proxy | ||
| ser = self.serialized | ||
| sub_schema = _extract_subschema(self.core_schema, key) | ||
| # ser is Mapping|Sequence|object, but we know it supports __getitem__ at runtime | ||
|
|
@@ -454,9 +528,7 @@ def __getitem__(self, key): | |
| if not isinstance(child_ser, _COLLECTION_TYPES): | ||
| return child_ser | ||
|
|
||
| # Wrap child dicts to prevent mutation | ||
| if isinstance(child_ser, dict): | ||
| child_ser = MappingProxyType(child_ser) | ||
| # child_ser is already frozen by _freeze_collections | ||
|
|
||
| # Try to keep the real underlying object if possible; otherwise fall back to serialized | ||
| try: | ||
|
|
@@ -474,12 +546,12 @@ def __getitem__(self, key): | |
| sub_schema, | ||
| ) | ||
|
|
||
| # Cache the built proxy with LRU eviction | ||
| # Cache with version stamp for coherence | ||
| with self._attr_cache_lock: | ||
| # Prune BEFORE insert to maintain strict bound | ||
| if len(self._attr_cache) >= ATTR_CACHE_SIZE: | ||
| self._attr_cache.popitem(last=False) | ||
| self._attr_cache[cache_key] = proxy | ||
| self._attr_cache[cache_key] = (ver, proxy) | ||
| self._attr_cache.move_to_end(cache_key) | ||
| return proxy | ||
|
|
||
|
|
@@ -530,7 +602,12 @@ def __bool__(self): | |
|
|
||
| # Mapping-like methods for Jinja ergonomics | ||
| def __contains__(self, key): | ||
| """Check if key exists in the serialized data.""" | ||
| """Check if key/value exists in the serialized data. | ||
|
|
||
| Semantics depend on the wrapped type: | ||
| - Mapping: Checks for key membership (like dict.__contains__) | ||
| - Sequence: Checks for value membership (like list.__contains__) | ||
| """ | ||
| try: | ||
| return key in self.serialized | ||
| except TypeError: | ||
|
|
@@ -569,14 +646,16 @@ def __reversed__(self): | |
| def __repr__(self): | ||
| match self.obj: | ||
| case Dataclass(): | ||
| # Access metadata from obj directly to avoid __getattr__ lookup | ||
| attrs = ", ".join( | ||
| f"{name}={getattr(self, name)!r}" | ||
| for name in self.__dataclass_fields__ | ||
| for name in self.obj.__dataclass_fields__ | ||
| ) | ||
| return f"{type(self.obj).__name__}({attrs})" | ||
| case BaseModel(): | ||
| # Access metadata from obj directly to avoid __getattr__ lookup | ||
| attrs = ", ".join( | ||
| f"{name}={getattr(self, name)!r}" for name in self.model_fields | ||
| f"{name}={getattr(self, name)!r}" for name in self.obj.model_fields | ||
| ) | ||
| return f"{type(self.obj).__name__}({attrs})" | ||
| case _: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This check appears to be redundant. The `self.serialized` attribute is consistently processed by `_freeze_collections` during proxy creation and refresh, which wraps any `dict` in a `MappingProxyType`. Therefore, `isinstance(ser, dict)` should always be `False` here for a mapping type. Removing this defensive code would make the implementation cleaner and more reliant on the immutability guarantee provided elsewhere.