src/deigma/proxy.py (165 changes: 122 additions & 43 deletions)
@@ -49,7 +49,10 @@ def _extract_subschema(schema: CoreSchema, name_or_idx) -> CoreSchema:
# Prefer strict schema for subfield extraction
return _extract_subschema(strict, name_or_idx)
# JSON-or-python wrapper
case {"type": "json-or-python", "json_schema": inner} | {"type": "json-or-python", "python_schema": inner}:
case {"type": "json-or-python", "json_schema": inner} | {
"type": "json-or-python",
"python_schema": inner,
}:
# Recurse through the inner schema
return _extract_subschema(inner, name_or_idx)
# BaseModel
@@ -118,6 +121,8 @@ def _unwrap_proxy_and_apply(

Note: SerializerFunction is a union of many function signatures with different
arities. We wrap them generically here, which pyright can't fully verify.
Pydantic-core may pass a varying number of positional arguments (e.g., an
info object carrying the serialization mode), hence the *args pattern to
forward them all.
"""
if pass_info:

@@ -133,14 +138,40 @@ def apply_to_unwrapped(proxy: "SerializationProxy[T]", *args):
return apply_to_unwrapped # pyright: ignore[reportReturnType]
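# Illustration (annotation, not part of the diff): with pass_info=True, the
# returned wrapper unwraps a SerializationProxy before calling the original
# serializer, forwarding positional args. `ser` is a hypothetical serializer.
#
#     def ser(value, info):
#         return {"v": value}
#
#     wrapped = _unwrap_proxy_and_apply(ser, pass_info=True)
#     wrapped(proxy_instance, info)   # roughly ser(proxy_instance.obj, info)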


def _freeze_collections(obj):
"""Recursively freeze mutable collections for true snapshot immutability.

This ensures that even direct access to .mapping or .serialized cannot
mutate the snapshot. Lists become tuples, dicts become MappingProxyType.

Args:
obj: The object to freeze (can be dict, list, tuple, or primitive)

Returns:
Frozen version of the object with all nested collections immutable
"""
if isinstance(obj, dict):
# Freeze nested values, then wrap dict
return MappingProxyType({k: _freeze_collections(v) for k, v in obj.items()})
elif isinstance(obj, list):
# Convert list to tuple, freeze nested items
return tuple(_freeze_collections(item) for item in obj)
elif isinstance(obj, tuple):
# Already immutable container, but freeze nested items
return tuple(_freeze_collections(item) for item in obj)
else:
# Primitive or other immutable type (str, int, float, bool, None, etc.)
return obj
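# Illustration (annotation, not part of the diff): expected behavior on a
# small nested structure, per the docstring above.
#
#     snap = _freeze_collections({"tags": ["a", "b"], "meta": {"n": 1}})
#     isinstance(snap, MappingProxyType)          # True
#     snap["tags"] == ("a", "b")                  # lists become tuples
#     isinstance(snap["meta"], MappingProxyType)  # True: nested dicts frozen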


# Cache size constants for memory management
WRAPPED_SCHEMA_CACHE_SIZE = 256
PROXY_TYPE_CACHE_SIZE = 256
ATTR_CACHE_SIZE = 512

# Type-check tuples (hoisted to module scope for micro-optimization)
_MAPPING_TYPES = (dict, MappingProxyType, Mapping)
_COLLECTION_TYPES = (dict, list, tuple)
_COLLECTION_TYPES = (dict, list, tuple, MappingProxyType)

# Bounded cache for wrapped schemas to prevent memory leaks in long-running applications
# Using OrderedDict for LRU eviction
@@ -153,7 +184,6 @@ def _wrap_core_schema(schema: CoreSchema) -> CoreSchema:
"""Wrap a CoreSchema to make it proxy-aware. Uses bounded LRU cache to avoid expensive deepcopy."""
schema_id = id(schema)

# Check cache under lock (OrderedDict isn't thread-safe even for reads)
with _WRAPPED_SCHEMA_CACHE_LOCK:
tup = _wrapped_schema_cache.get(schema_id)
if tup is not None:
@@ -163,8 +193,6 @@ def _wrap_core_schema(schema: CoreSchema) -> CoreSchema:
_wrapped_schema_cache.move_to_end(schema_id)
return wrapped

# Build wrapped schema (outside lock to minimize critical section)
# Key insight: Don't mutate "type" - only wrap serialization behavior
match schema:
# something we can reference to (e.g. BaseModel, Dataclass, ...)
case {"ref": ref}:
@@ -185,16 +213,15 @@ def _wrap_core_schema(schema: CoreSchema) -> CoreSchema:
# Pyright can't verify all possible serialization schema types support "function".
wrapped_schema["serialization"]["function"] = _unwrap_proxy_and_apply( # pyright: ignore[reportGeneralTypeIssues]
func, # pyright: ignore[reportArgumentType]
pass_info=True
pass_info=True,
)
# Has custom serializer without info_arg
case {"serialization": {"function": func}}:
wrapped_schema = deepcopy(schema)
# We're wrapping the serializer function to unwrap proxies first.
# Pyright can't verify all possible serialization schema types support "function".
wrapped_schema["serialization"]["function"] = _unwrap_proxy_and_apply( # pyright: ignore[reportGeneralTypeIssues]
func, # pyright: ignore[reportArgumentType]
pass_info=False
pass_info=False,
)
# No custom serializer - add one
case _:
@@ -244,6 +271,15 @@ class SerializationProxy(Generic[T]):
- Deterministic serialization snapshots for concurrent operations
"""

__slots__ = (
"obj",
"serialized",
"root_adapter",
"_attr_cache",
"_attr_cache_lock",
"_version",
)

core_schema: CoreSchema
__pydantic_serializer__: SchemaSerializer
# Note: __pydantic_validator__ is intentionally not set to avoid
@@ -264,8 +300,21 @@ def __init__(
self.root_adapter = root_adapter
# Bounded LRU cache for accessed attributes to avoid rebuilding proxies
# Keys are either strings (for attributes) or tuples (for items)
self._attr_cache: OrderedDict[str | tuple, "SerializationProxy"] = OrderedDict()
# Values are (version, proxy) tuples to invalidate stale entries
self._attr_cache: OrderedDict[
str | tuple, tuple[int, "SerializationProxy"]
] = OrderedDict()
self._attr_cache_lock = RLock()
# Version counter for refresh() coherence (bumped on each refresh)
self._version = 0

def _current_version(self) -> int:
"""Get the current version counter for cache coherence.

Reading an int is atomic in CPython due to the GIL, so no lock needed.
Worst case: stale read leads to rejected cache entry on next access (safe).
"""
return self._version

@classmethod
def _build(
@@ -275,11 +324,9 @@ def _build(
adapter: TypeAdapter,
core_schema: CoreSchema,
):
# Normalize: wrap dicts to ensure immutability
if isinstance(serialized, dict) and not isinstance(
serialized, MappingProxyType
):
serialized = MappingProxyType(serialized)
# Freeze collections recursively for true immutability
# (This also wraps dicts in MappingProxyType)
serialized = _freeze_collections(serialized)

schema_id = id(core_schema)

@@ -339,9 +386,8 @@ def build(
if adapter is None:
adapter = TypeAdapter(type(obj))
serialized = adapter.dump_python(obj)
# If it's a dict, make it read-only to prevent accidental mutation from other threads
if isinstance(serialized, dict):
serialized = MappingProxyType(serialized)
# Freeze all collections recursively for true immutability
serialized = _freeze_collections(serialized)
core_schema = adapter.core_schema
return cls._build(obj, serialized, adapter, core_schema)

@@ -357,6 +403,9 @@ def unwrap(self) -> T:
def mapping(self) -> Mapping:
"""Access the underlying serialized mapping directly.

Returns an **immutable** Mapping (typically MappingProxyType) to prevent
mutation of the snapshot. All nested collections are recursively frozen.

Use this to access mapping methods when your data has fields with those names.
For example, if you have a field named 'items':
- Access the field: proxy.items or proxy['items']
@@ -373,7 +422,7 @@ def mapping(self) -> Mapping:
Use iteration or indexing directly for sequences.
"""
ser = self.serialized
# Ensure it's wrapped for immutability
# Ensure it's wrapped for immutability (should already be from _freeze_collections)
if isinstance(ser, dict) and not isinstance(ser, MappingProxyType):
ser = MappingProxyType(ser)
Review comment on lines 426 to 427 (severity: medium):

This check appears to be redundant. The self.serialized attribute is consistently processed by _freeze_collections during proxy creation and refresh, which wraps any dict in a MappingProxyType. Therefore, isinstance(ser, dict) should always be False here for a mapping type. Removing this defensive code would make the implementation cleaner and more reliant on the immutability guarantee provided elsewhere.
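A minimal sketch of the simplified property under that invariant, assuming _freeze_collections has always run on self.serialized and that the elided sequence guard raises TypeError (message is hypothetical):

    @property
    def mapping(self) -> Mapping:
        ser = self.serialized
        # Frozen upstream by _freeze_collections; no defensive re-wrapping.
        if not isinstance(ser, _MAPPING_TYPES):
            raise TypeError("mapping is only available for mapping-like data")
        return ser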

# Guard against calling .items() on a sequence
@@ -389,21 +438,37 @@ def refresh(self) -> None:

This recomputes the serialization and clears the attribute cache.
Useful when the underlying object changes and you want the proxy to reflect those changes.

Thread safety: Uses copy-on-write with version stamping to prevent races.
Any cache entries built before the refresh are invalidated by version mismatch.
"""
self.serialized = self.root_adapter.dump_python(self.obj)
if isinstance(self.serialized, dict):
self.serialized = MappingProxyType(self.serialized)
# Compute new snapshot outside the lock (expensive operation)
new_serialized = self.root_adapter.dump_python(self.obj)
new_serialized = _freeze_collections(new_serialized)

# Atomically swap snapshot, bump version, and clear cache
with self._attr_cache_lock:
self.serialized = new_serialized
self._version += 1
self._attr_cache.clear()
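# Illustration (annotation, not part of the diff): how version stamping keeps
# cached sub-proxies coherent across refresh(), per the docstring above.
# `field` is a hypothetical attribute name.
#
#     p = SerializationProxy.build(obj)
#     _ = p.field          # cached under version 0
#     p.refresh()          # new snapshot swapped in, version bumped to 1, cache cleared
#     _ = p.field          # a straggler (0, ...) entry would be evicted and rebuilt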

def __getattr__(self, name: str):
# Check attribute cache first (LRU: move to end on access)
# Capture current version for cache coherence
ver = self._current_version()

# Check cache with version validation
with self._attr_cache_lock:
if name in self._attr_cache:
self._attr_cache.move_to_end(name)
return self._attr_cache[name]

# Hoist to local for micro-optimization
cached_ver, cached_proxy = self._attr_cache[name]
if cached_ver == ver:
# Cache hit with matching version
self._attr_cache.move_to_end(name)
return cached_proxy
else:
# Stale entry from pre-refresh, evict it
self._attr_cache.pop(name, None)

# Cache miss or stale - build new proxy
ser = self.serialized
if isinstance(ser, _MAPPING_TYPES) and name in ser:
sub_schema = _extract_subschema(self.core_schema, name)
@@ -415,34 +480,43 @@ def __getattr__(self, name: str):
if not isinstance(child_ser, _COLLECTION_TYPES):
return child_ser

# Wrap child dicts to prevent mutation
if isinstance(child_ser, dict):
child_ser = MappingProxyType(child_ser)
# child_ser is already frozen by _freeze_collections; pass it straight to _build
proxy = self._build(
getattr(self.obj, name),
child_ser,
self.root_adapter,
sub_schema,
)
# Cache the built proxy with LRU eviction
# Cache with version stamp for coherence
with self._attr_cache_lock:
# Prune BEFORE insert to maintain strict bound
if len(self._attr_cache) >= ATTR_CACHE_SIZE:
self._attr_cache.popitem(last=False)
self._attr_cache[name] = proxy
self._attr_cache[name] = (ver, proxy)
self._attr_cache.move_to_end(name)
return proxy
return getattr(self.obj, name)

def __getitem__(self, key):
# Capture current version for cache coherence
ver = self._current_version()

# For getitem, use a tuple cache key to avoid collisions with attribute-name keys
cache_key = ("__item__", key)

# Check cache with version validation
with self._attr_cache_lock:
if cache_key in self._attr_cache:
self._attr_cache.move_to_end(cache_key)
return self._attr_cache[cache_key]

# Hoist to local for micro-optimization
cached_ver, cached_proxy = self._attr_cache[cache_key]
if cached_ver == ver:
# Cache hit with matching version
self._attr_cache.move_to_end(cache_key)
return cached_proxy
else:
# Stale entry from pre-refresh, evict it
self._attr_cache.pop(cache_key, None)
Review comment on lines 508 to 517 (severity: medium):

The logic for checking and retrieving from the versioned cache is duplicated here and in __getattr__ (lines 460-470). Similarly, the logic for adding a new entry to the cache is duplicated (here at lines 550-555 and in __getattr__ at lines 491-496). To improve maintainability and reduce code duplication, consider extracting this cache management logic into private helper methods (e.g., _get_cached_proxy and _cache_proxy).
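A minimal sketch of those helpers, assuming the version-stamped OrderedDict cache shown in this diff (the names _get_cached_proxy and _cache_proxy are the reviewer's suggestion, not code in the PR):

    def _get_cached_proxy(self, cache_key, ver):
        # Return a version-current cached proxy, or None on miss/stale.
        with self._attr_cache_lock:
            entry = self._attr_cache.get(cache_key)
            if entry is None:
                return None
            cached_ver, cached_proxy = entry
            if cached_ver == ver:
                self._attr_cache.move_to_end(cache_key)
                return cached_proxy
            # Stale entry from before a refresh(); evict it.
            self._attr_cache.pop(cache_key, None)
            return None

    def _cache_proxy(self, cache_key, ver, proxy):
        # Insert with LRU eviction, pruning BEFORE insert to keep a strict bound.
        with self._attr_cache_lock:
            if len(self._attr_cache) >= ATTR_CACHE_SIZE:
                self._attr_cache.popitem(last=False)
            self._attr_cache[cache_key] = (ver, proxy)
            self._attr_cache.move_to_end(cache_key)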


# Cache miss or stale - build new proxy
ser = self.serialized
sub_schema = _extract_subschema(self.core_schema, key)
# ser is Mapping|Sequence|object, but we know it supports __getitem__ at runtime
@@ -454,9 +528,7 @@ def __getitem__(self, key):
if not isinstance(child_ser, _COLLECTION_TYPES):
return child_ser

# Wrap child dicts to prevent mutation
if isinstance(child_ser, dict):
child_ser = MappingProxyType(child_ser)
# child_ser is already frozen by _freeze_collections

# Try to keep the real underlying object if possible; otherwise fall back to serialized
try:
@@ -474,12 +546,12 @@ def __getitem__(self, key):
sub_schema,
)

# Cache the built proxy with LRU eviction
# Cache with version stamp for coherence
with self._attr_cache_lock:
# Prune BEFORE insert to maintain strict bound
if len(self._attr_cache) >= ATTR_CACHE_SIZE:
self._attr_cache.popitem(last=False)
self._attr_cache[cache_key] = proxy
self._attr_cache[cache_key] = (ver, proxy)
self._attr_cache.move_to_end(cache_key)
return proxy

@@ -530,7 +602,12 @@ def __bool__(self):

# Mapping-like methods for Jinja ergonomics
def __contains__(self, key):
"""Check if key exists in the serialized data."""
"""Check if key/value exists in the serialized data.

Semantics depend on the wrapped type:
- Mapping: Checks for key membership (like dict.__contains__)
- Sequence: Checks for value membership (like list.__contains__)
"""
try:
return key in self.serialized
except TypeError:
@@ -569,14 +646,16 @@ def __reversed__(self):
def __repr__(self):
match self.obj:
case Dataclass():
# Access metadata from obj directly to avoid __getattr__ lookup
attrs = ", ".join(
f"{name}={getattr(self, name)!r}"
for name in self.__dataclass_fields__
for name in self.obj.__dataclass_fields__
)
return f"{type(self.obj).__name__}({attrs})"
case BaseModel():
# Access metadata from obj directly to avoid __getattr__ lookup
attrs = ", ".join(
f"{name}={getattr(self, name)!r}" for name in self.model_fields
f"{name}={getattr(self, name)!r}" for name in self.obj.model_fields
)
return f"{type(self.obj).__name__}({attrs})"
case _: