diff --git a/src/deigma/proxy.py b/src/deigma/proxy.py
index 5a8c2e4..bad2c15 100644
--- a/src/deigma/proxy.py
+++ b/src/deigma/proxy.py
@@ -49,7 +49,10 @@ def _extract_subschema(schema: CoreSchema, name_or_idx) -> CoreSchema:
             # Prefer strict schema for subfield extraction
             return _extract_subschema(strict, name_or_idx)
         # JSON-or-python wrapper
-        case {"type": "json-or-python", "json_schema": inner} | {"type": "json-or-python", "python_schema": inner}:
+        case {"type": "json-or-python", "json_schema": inner} | {
+            "type": "json-or-python",
+            "python_schema": inner,
+        }:
             # Recurse through the inner schema
             return _extract_subschema(inner, name_or_idx)
         # BaseModel
@@ -118,6 +121,8 @@
     Note: SerializerFunction is a union of many function signatures with
     different arities. We wrap them generically here, which pyright can't
     fully verify.
+    Pydantic-core may call the serializer with a varying number of positional
+    arguments depending on its signature, hence the *args pattern to forward them all.
     """

     if pass_info:
@@ -133,6 +138,32 @@ def apply_to_unwrapped(proxy: "SerializationProxy[T]", *args):
     return apply_to_unwrapped  # pyright: ignore[reportReturnType]


+def _freeze_collections(obj):
+    """Recursively freeze mutable collections for true snapshot immutability.
+
+    This ensures that even direct access to .mapping or .serialized cannot
+    mutate the snapshot. Lists become tuples, dicts become MappingProxyType.
+
+    Args:
+        obj: The object to freeze (can be dict, list, tuple, or primitive)
+
+    Returns:
+        Frozen version of the object with all nested collections immutable
+    """
+    if isinstance(obj, dict):
+        # Freeze nested values, then wrap dict
+        return MappingProxyType({k: _freeze_collections(v) for k, v in obj.items()})
+    elif isinstance(obj, list):
+        # Convert list to tuple, freeze nested items
+        return tuple(_freeze_collections(item) for item in obj)
+    elif isinstance(obj, tuple):
+        # Already immutable container, but freeze nested items
+        return tuple(_freeze_collections(item) for item in obj)
+    else:
+        # Primitive or other immutable type (str, int, float, bool, None, etc.)
+        return obj
+
+
 # Cache size constants for memory management
 WRAPPED_SCHEMA_CACHE_SIZE = 256
 PROXY_TYPE_CACHE_SIZE = 256
@@ -140,7 +171,7 @@ def apply_to_unwrapped(proxy: "SerializationProxy[T]", *args):

 # Type-check tuples (hoisted to module scope for micro-optimization)
 _MAPPING_TYPES = (dict, MappingProxyType, Mapping)
-_COLLECTION_TYPES = (dict, list, tuple)
+_COLLECTION_TYPES = (dict, list, tuple, MappingProxyType)

 # Bounded cache for wrapped schemas to prevent memory leaks in long-running applications
 # Using OrderedDict for LRU eviction
@@ -153,7 +184,6 @@ def _wrap_core_schema(schema: CoreSchema) -> CoreSchema:
     """Wrap a CoreSchema to make it proxy-aware. Uses bounded LRU cache to avoid expensive deepcopy."""
     schema_id = id(schema)

-    # Check cache under lock (OrderedDict isn't thread-safe even for reads)
     with _WRAPPED_SCHEMA_CACHE_LOCK:
         tup = _wrapped_schema_cache.get(schema_id)
         if tup is not None:
@@ -163,8 +193,6 @@
             _wrapped_schema_cache.move_to_end(schema_id)
             return wrapped

-    # Build wrapped schema (outside lock to minimize critical section)
-    # Key insight: Don't mutate "type" - only wrap serialization behavior
     match schema:
         # something we can reference to (e.g. BaseModel, Dataclass, ...)
case {"ref": ref}: @@ -185,16 +213,15 @@ def _wrap_core_schema(schema: CoreSchema) -> CoreSchema: # Pyright can't verify all possible serialization schema types support "function". wrapped_schema["serialization"]["function"] = _unwrap_proxy_and_apply( # pyright: ignore[reportGeneralTypeIssues] func, # pyright: ignore[reportArgumentType] - pass_info=True + pass_info=True, ) # Has custom serializer without info_arg case {"serialization": {"function": func}}: wrapped_schema = deepcopy(schema) # We're wrapping the serializer function to unwrap proxies first. - # Pyright can't verify all possible serialization schema types support "function". wrapped_schema["serialization"]["function"] = _unwrap_proxy_and_apply( # pyright: ignore[reportGeneralTypeIssues] func, # pyright: ignore[reportArgumentType] - pass_info=False + pass_info=False, ) # No custom serializer - add one case _: @@ -244,6 +271,15 @@ class SerializationProxy(Generic[T]): - Deterministic serialization snapshots for concurrent operations """ + __slots__ = ( + "obj", + "serialized", + "root_adapter", + "_attr_cache", + "_attr_cache_lock", + "_version", + ) + core_schema: CoreSchema __pydantic_serializer__: SchemaSerializer # Note: __pydantic_validator__ is intentionally not set to avoid @@ -264,8 +300,21 @@ def __init__( self.root_adapter = root_adapter # Bounded LRU cache for accessed attributes to avoid rebuilding proxies # Keys are either strings (for attributes) or tuples (for items) - self._attr_cache: OrderedDict[str | tuple, "SerializationProxy"] = OrderedDict() + # Values are (version, proxy) tuples to invalidate stale entries + self._attr_cache: OrderedDict[ + str | tuple, tuple[int, "SerializationProxy"] + ] = OrderedDict() self._attr_cache_lock = RLock() + # Version counter for refresh() coherence (bumped on each refresh) + self._version = 0 + + def _current_version(self) -> int: + """Get the current version counter for cache coherence. + + Reading an int is atomic in CPython due to the GIL, so no lock needed. + Worst case: stale read leads to rejected cache entry on next access (safe). + """ + return self._version @classmethod def _build( @@ -275,11 +324,9 @@ def _build( adapter: TypeAdapter, core_schema: CoreSchema, ): - # Normalize: wrap dicts to ensure immutability - if isinstance(serialized, dict) and not isinstance( - serialized, MappingProxyType - ): - serialized = MappingProxyType(serialized) + # Freeze collections recursively for true immutability + # (This also wraps dicts in MappingProxyType) + serialized = _freeze_collections(serialized) schema_id = id(core_schema) @@ -339,9 +386,8 @@ def build( if adapter is None: adapter = TypeAdapter(type(obj)) serialized = adapter.dump_python(obj) - # If it's a dict, make it read-only to prevent accidental mutation from other threads - if isinstance(serialized, dict): - serialized = MappingProxyType(serialized) + # Freeze all collections recursively for true immutability + serialized = _freeze_collections(serialized) core_schema = adapter.core_schema return cls._build(obj, serialized, adapter, core_schema) @@ -357,6 +403,9 @@ def unwrap(self) -> T: def mapping(self) -> Mapping: """Access the underlying serialized mapping directly. + Returns an **immutable** Mapping (typically MappingProxyType) to prevent + mutation of the snapshot. All nested collections are recursively frozen. + Use this to access mapping methods when your data has fields with those names. 
         For example, if you have a field named 'items':
         - Access the field: proxy.items or proxy['items']
@@ -373,7 +422,7 @@ def mapping(self) -> Mapping:
         Use iteration or indexing directly for sequences.
         """
         ser = self.serialized
-        # Ensure it's wrapped for immutability
+        # Ensure it's wrapped for immutability (should already be frozen by _freeze_collections)
         if isinstance(ser, dict) and not isinstance(ser, MappingProxyType):
             ser = MappingProxyType(ser)
         # Guard against calling .items() on a sequence
@@ -389,21 +438,37 @@ def refresh(self) -> None:
         This recomputes the serialization and clears the attribute cache.
         Useful when the underlying object changes and you want the proxy
         to reflect those changes.
+
+        Thread safety: Uses copy-on-write with version stamping to prevent races.
+        Any cache entries built before the refresh are invalidated by version mismatch.
         """
-        self.serialized = self.root_adapter.dump_python(self.obj)
-        if isinstance(self.serialized, dict):
-            self.serialized = MappingProxyType(self.serialized)
+        # Compute new snapshot outside the lock (expensive operation)
+        new_serialized = self.root_adapter.dump_python(self.obj)
+        new_serialized = _freeze_collections(new_serialized)
+
+        # Atomically swap snapshot, bump version, and clear cache
         with self._attr_cache_lock:
+            self.serialized = new_serialized
+            self._version += 1
             self._attr_cache.clear()

     def __getattr__(self, name: str):
-        # Check attribute cache first (LRU: move to end on access)
+        # Capture current version for cache coherence
+        ver = self._current_version()
+
+        # Check cache with version validation
         with self._attr_cache_lock:
             if name in self._attr_cache:
-                self._attr_cache.move_to_end(name)
-                return self._attr_cache[name]
-
-        # Hoist to local for micro-optimization
+                cached_ver, cached_proxy = self._attr_cache[name]
+                if cached_ver == ver:
+                    # Cache hit with matching version
+                    self._attr_cache.move_to_end(name)
+                    return cached_proxy
+                else:
+                    # Stale entry from before a refresh; evict it
+                    self._attr_cache.pop(name, None)
+
+        # Cache miss or stale - build new proxy
         ser = self.serialized
         if isinstance(ser, _MAPPING_TYPES) and name in ser:
             sub_schema = _extract_subschema(self.core_schema, name)
@@ -415,34 +480,43 @@
             if not isinstance(child_ser, _COLLECTION_TYPES):
                 return child_ser

-            # Wrap child dicts to prevent mutation
-            if isinstance(child_ser, dict):
-                child_ser = MappingProxyType(child_ser)
+            # child_ser is already frozen by _freeze_collections; pass it straight to _build
             proxy = self._build(
                 getattr(self.obj, name),
                 child_ser,
                 self.root_adapter,
                 sub_schema,
             )
-            # Cache the built proxy with LRU eviction
+            # Cache with version stamp for coherence
             with self._attr_cache_lock:
                 # Prune BEFORE insert to maintain strict bound
                 if len(self._attr_cache) >= ATTR_CACHE_SIZE:
                     self._attr_cache.popitem(last=False)
-                self._attr_cache[name] = proxy
+                self._attr_cache[name] = (ver, proxy)
                 self._attr_cache.move_to_end(name)
             return proxy
         return getattr(self.obj, name)

     def __getitem__(self, key):
+        # Capture current version for cache coherence
+        ver = self._current_version()
+
         # For getitem, we use a tuple for cache key to avoid collisions
         cache_key = ("__item__", key)
+
+        # Check cache with version validation
         with self._attr_cache_lock:
             if cache_key in self._attr_cache:
-                self._attr_cache.move_to_end(cache_key)
-                return self._attr_cache[cache_key]
-
-        # Hoist to local for micro-optimization
+                cached_ver, cached_proxy = self._attr_cache[cache_key]
+                if cached_ver == ver:
+                    # Cache hit with matching version
+                    self._attr_cache.move_to_end(cache_key)
+                    return cached_proxy
+                else:
+                    # Stale entry from before a refresh; evict it
+                    self._attr_cache.pop(cache_key, None)
+
+        # Cache miss or stale - build new proxy
         ser = self.serialized
         sub_schema = _extract_subschema(self.core_schema, key)
         # ser is Mapping|Sequence|object, but we know it supports __getitem__ at runtime
@@ -454,9 +528,7 @@
         if not isinstance(child_ser, _COLLECTION_TYPES):
             return child_ser

-        # Wrap child dicts to prevent mutation
-        if isinstance(child_ser, dict):
-            child_ser = MappingProxyType(child_ser)
+        # child_ser is already frozen by _freeze_collections

         # Try to keep the real underlying object if possible; otherwise fall back to serialized
         try:
@@ -474,12 +546,12 @@
                 sub_schema,
             )

-        # Cache the built proxy with LRU eviction
+        # Cache with version stamp for coherence
         with self._attr_cache_lock:
             # Prune BEFORE insert to maintain strict bound
             if len(self._attr_cache) >= ATTR_CACHE_SIZE:
                 self._attr_cache.popitem(last=False)
-            self._attr_cache[cache_key] = proxy
+            self._attr_cache[cache_key] = (ver, proxy)
             self._attr_cache.move_to_end(cache_key)

         return proxy
@@ -530,7 +602,12 @@

     # Mapping-like methods for Jinja ergonomics
     def __contains__(self, key):
-        """Check if key exists in the serialized data."""
+        """Check if a key or value exists in the serialized data.
+
+        Semantics depend on the wrapped type:
+        - Mapping: checks for key membership (like dict.__contains__)
+        - Sequence: checks for value membership (like list.__contains__)
+        """
         try:
             return key in self.serialized
         except TypeError:
@@ -569,14 +646,16 @@ def __reversed__(self):
     def __repr__(self):
         match self.obj:
             case Dataclass():
+                # Access metadata from obj directly to avoid __getattr__ lookup
                 attrs = ", ".join(
                     f"{name}={getattr(self, name)!r}"
-                    for name in self.__dataclass_fields__
+                    for name in self.obj.__dataclass_fields__
                 )
                 return f"{type(self.obj).__name__}({attrs})"
             case BaseModel():
+                # Access metadata from obj directly to avoid __getattr__ lookup
                 attrs = ", ".join(
-                    f"{name}={getattr(self, name)!r}" for name in self.model_fields
+                    f"{name}={getattr(self, name)!r}" for name in self.obj.model_fields
                )
                return f"{type(self.obj).__name__}({attrs})"
            case _:
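
Notes (reviewer sketches, not part of the patch):

The deep freeze added by _freeze_collections is the core of the snapshot
guarantee: every nested dict becomes a MappingProxyType and every list becomes
a tuple, so even code that reaches into .mapping or .serialized cannot mutate
a snapshot in place. A minimal standalone sketch of that behavior (the merged
list/tuple branch and the sample data are illustrative, not the module's code):

    from types import MappingProxyType

    def freeze(obj):
        if isinstance(obj, dict):
            return MappingProxyType({k: freeze(v) for k, v in obj.items()})
        if isinstance(obj, (list, tuple)):
            return tuple(freeze(item) for item in obj)
        return obj

    snapshot = freeze({"user": {"name": "Ada"}, "tags": ["fast", "safe"]})
    assert snapshot["tags"] == ("fast", "safe")   # list became a tuple
    try:
        snapshot["user"]["name"] = "Eve"          # nested write is rejected
    except TypeError:
        print("snapshot is read-only all the way down")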
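
The version stamp closes a race that clearing the cache alone cannot: a reader
may capture the old snapshot, lose the CPU, and insert its stale proxy after
refresh() has already cleared the cache. Because entries are stored as
(version, proxy) and validated against the current version on lookup, such an
entry is evicted instead of served. A hedged usage sketch (the User model and
its field are invented; it assumes the build() classmethod shown above):

    from pydantic import BaseModel
    from deigma.proxy import SerializationProxy

    class User(BaseModel):
        name: str

    user = User(name="Ada")
    proxy = SerializationProxy.build(user)
    assert proxy.name == "Ada"     # scalar children are returned directly

    user.name = "Grace"            # mutate the underlying object
    proxy.refresh()                # new frozen snapshot, bumped version, cleared cache
    assert proxy.name == "Grace"   # entries cached before refresh() are never served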
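
The cache protocol itself is small enough to state in isolation. A condensed,
hypothetical restatement of the versioned LRU used by __getattr__ and
__getitem__ (names simplified; the real code interleaves this with proxy
construction):

    from collections import OrderedDict
    from threading import RLock

    class VersionedLRU:
        def __init__(self, maxsize: int = 256):
            self._maxsize = maxsize
            self._entries: OrderedDict = OrderedDict()  # key -> (version, value)
            self._lock = RLock()
            self.version = 0            # owner bumps this on each refresh()

        def get(self, key):
            with self._lock:
                entry = self._entries.get(key)
                if entry is None:
                    return None
                cached_version, value = entry
                if cached_version != self.version:
                    self._entries.pop(key, None)       # stale entry: evict
                    return None
                self._entries.move_to_end(key)         # LRU bookkeeping on hit
                return value

        def put(self, key, value, version):
            with self._lock:
                if len(self._entries) >= self._maxsize:
                    self._entries.popitem(last=False)  # prune before insert
                self._entries[key] = (version, value)
                self._entries.move_to_end(key)

Note that put() stamps the version captured before the snapshot was read, not
the version at insertion time; that is what makes a proxy built from a
pre-refresh snapshot detectable at the next get().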