194 changes: 131 additions & 63 deletions src/deigma/proxy.py
@@ -1,6 +1,7 @@
from collections import OrderedDict
from collections.abc import Callable, Iterable, Mapping
from copy import deepcopy
from threading import Lock, RLock
from types import MappingProxyType
from typing import Generic, NamedTuple, TypeGuard, TypeVar

@@ -61,23 +62,39 @@ def apply_to_unwrapped(proxy: "SerializationProxy[T]") -> T:

# Bounded cache for wrapped schemas to prevent memory leaks in long-running applications
# Using OrderedDict for LRU eviction
# Store tuple (orig_schema, wrapped_schema) to prevent id() reuse bugs
_WRAPPED_SCHEMA_CACHE_SIZE = 256
_wrapped_schema_cache: OrderedDict[int, CoreSchema] = OrderedDict()
_wrapped_schema_cache: OrderedDict[int, tuple[CoreSchema, CoreSchema]] = OrderedDict()
_WRAPPED_SCHEMA_CACHE_LOCK = Lock()


def _wrap_core_schema(schema: CoreSchema) -> CoreSchema:
"""Wrap a CoreSchema to make it proxy-aware. Uses bounded LRU cache to avoid expensive deepcopy."""
schema_id = id(schema)

# Check cache first (LRU: move to end if found)
if schema_id in _wrapped_schema_cache:
# Move to end (most recently used)
_wrapped_schema_cache.move_to_end(schema_id)
return _wrapped_schema_cache[schema_id]

# Build wrapped schema
# Lock-free fast path: check cache without lock first
# Note: OrderedDict.get() can raise during concurrent modifications, so we catch any errors
try:
tup = _wrapped_schema_cache.get(schema_id)
if tup is not None and tup[0] is schema:
return tup[1]
except (ValueError, RuntimeError):
# Race condition detected, fall through to locked path
pass

# Slow path: take lock and recheck
with _WRAPPED_SCHEMA_CACHE_LOCK:
tup = _wrapped_schema_cache.get(schema_id)
if tup is not None:
orig, wrapped = tup
# Guard against id() reuse by verifying identity
if orig is schema:
_wrapped_schema_cache.move_to_end(schema_id)
return wrapped

# Build wrapped schema (outside lock to minimize critical section)
match schema:
# someting we can reference to (e.g. BaseModel, Dataclass, ...)
# something we can reference to (e.g. BaseModel, Dataclass, ...)
case {"ref": ref}:
wrapped_schema = core_schema.definitions_schema(
schema=core_schema.definition_reference_schema(
@@ -105,12 +122,13 @@ def _wrap_core_schema(schema: CoreSchema) -> CoreSchema:
)

# Cache with LRU eviction
_wrapped_schema_cache[schema_id] = wrapped_schema
_wrapped_schema_cache.move_to_end(schema_id)
with _WRAPPED_SCHEMA_CACHE_LOCK:
_wrapped_schema_cache[schema_id] = (schema, wrapped_schema)
_wrapped_schema_cache.move_to_end(schema_id)

# Evict oldest entry if cache is too large
if len(_wrapped_schema_cache) > _WRAPPED_SCHEMA_CACHE_SIZE:
_wrapped_schema_cache.popitem(last=False)
# Evict oldest entry if cache is too large
if len(_wrapped_schema_cache) > _WRAPPED_SCHEMA_CACHE_SIZE:
_wrapped_schema_cache.popitem(last=False)

return wrapped_schema
Comment on lines +125 to 133


medium

While adding a lock here prevents race conditions, this implementation of double-checked locking can be made more efficient. If two threads call this function with the same uncached schema, both will perform the expensive schema wrapping. The second thread to acquire the lock will simply overwrite the result of the first.

To avoid this redundant work, you can re-check the cache after acquiring the lock for writing. If another thread has populated the cache in the meantime, you can use its result and discard your own. This pattern is already correctly implemented in the _build method. Applying it here would improve efficiency and consistency.

Suggested change
    with _WRAPPED_SCHEMA_CACHE_LOCK:
        _wrapped_schema_cache[schema_id] = (schema, wrapped_schema)
        _wrapped_schema_cache.move_to_end(schema_id)
        # Evict oldest entry if cache is too large
        if len(_wrapped_schema_cache) > _WRAPPED_SCHEMA_CACHE_SIZE:
            _wrapped_schema_cache.popitem(last=False)
    return wrapped_schema
    with _WRAPPED_SCHEMA_CACHE_LOCK:
        # Re-check if another thread has already wrapped this schema while we were working
        tup = _wrapped_schema_cache.get(schema_id)
        if tup is None or tup[0] is not schema:
            _wrapped_schema_cache[schema_id] = (schema, wrapped_schema)
            _wrapped_schema_cache.move_to_end(schema_id)
            # Evict oldest entry if cache is too large
            if len(_wrapped_schema_cache) > _WRAPPED_SCHEMA_CACHE_SIZE:
                _wrapped_schema_cache.popitem(last=False)
            return wrapped_schema
        else:
            # Another thread won the race, use its result and update LRU
            _wrapped_schema_cache.move_to_end(schema_id)
            return tup[1]
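
To make the recommended pattern concrete, here is a minimal, self-contained sketch of double-checked locking over an id()-keyed LRU cache. It is illustrative only: expensive_wrap, _CACHE_SIZE, get_wrapped and the cache names are hypothetical stand-ins, not identifiers from this PR.

from collections import OrderedDict
from threading import Lock

_CACHE_SIZE = 256
_cache: OrderedDict[int, tuple[object, object]] = OrderedDict()
_cache_lock = Lock()


def expensive_wrap(value: object) -> object:
    # Hypothetical stand-in for the costly schema-wrapping step.
    return ("wrapped", value)


def get_wrapped(value: object) -> object:
    key = id(value)
    # First check: return an existing entry without doing the expensive work.
    with _cache_lock:
        entry = _cache.get(key)
        if entry is not None and entry[0] is value:  # identity guard against id() reuse
            _cache.move_to_end(key)
            return entry[1]
    # Do the expensive work outside the lock to keep the critical section small.
    wrapped = expensive_wrap(value)
    # Second check: another thread may have finished first; prefer its result.
    with _cache_lock:
        entry = _cache.get(key)
        if entry is not None and entry[0] is value:
            _cache.move_to_end(key)
            return entry[1]
        _cache[key] = (value, wrapped)
        if len(_cache) > _CACHE_SIZE:
            _cache.popitem(last=False)
        return wrapped

The key property is that redundant work is discarded rather than published, so all callers observe a single cached object per key.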


@@ -123,11 +141,12 @@ class SerializationProxy(Generic[T]):
# Bounded cache for proxy types to prevent memory leaks
_PROXY_TYPE_CACHE_SIZE = 256
_proxy_type_cache: OrderedDict[int, type["SerializationProxy"]] = OrderedDict()
_PROXY_TYPE_CACHE_LOCK = Lock()

def __init__(
self,
obj: T,
serialized: MappingProxyType,
serialized: Mapping | Iterable | object,
root_adapter: TypeAdapter,
):
self.obj = obj
@@ -136,42 +155,65 @@ def __init__(
# Cache for accessed attributes to avoid rebuilding proxies
# Keys are either strings (for attributes) or tuples (for items)
self._attr_cache: dict[str | tuple, "SerializationProxy"] = {}
self._attr_cache_lock = RLock()

@classmethod
def _build(
cls,
obj: T,
serialized: MappingProxyType,
serialized: Mapping | Iterable | object,
adapter: TypeAdapter,
core_schema: CoreSchema,
):
schema_id = id(core_schema)
# Normalize: wrap dicts to ensure immutability
if isinstance(serialized, dict) and not isinstance(serialized, MappingProxyType):
serialized = MappingProxyType(serialized)

# Check if we already have a cached proxy type for this schema (LRU)
if schema_id in cls._proxy_type_cache:
# Move to end (most recently used)
cls._proxy_type_cache.move_to_end(schema_id)
proxy_type = cls._proxy_type_cache[schema_id]
else:
# Build new proxy type
wrapped_core_schema = _wrap_core_schema(core_schema)
proxy_type = type(
f"SerializationProxy[{type(obj).__name__}]",
(cls,),
{
"core_schema": core_schema,
"__pydantic_serializer__": SchemaSerializer(wrapped_core_schema),
"__pydantic_core_schema__": wrapped_core_schema,
"__pydantic_validator__": adapter.validator,
},
)
# Cache the proxy type with LRU eviction
cls._proxy_type_cache[schema_id] = proxy_type
cls._proxy_type_cache.move_to_end(schema_id)
schema_id = id(core_schema)

# Evict oldest entry if cache is too large
if len(cls._proxy_type_cache) > cls._PROXY_TYPE_CACHE_SIZE:
cls._proxy_type_cache.popitem(last=False)
# Lock-free fast path: check cache without lock first
# Note: OrderedDict.get() can raise during concurrent modifications, so we catch any errors
try:
proxy_type = cls._proxy_type_cache.get(schema_id)
if proxy_type is not None and getattr(proxy_type, "core_schema", None) is core_schema:
return proxy_type(obj, serialized, adapter)
except (ValueError, RuntimeError):
# Race condition detected, fall through to locked path
pass

# Slow path: take lock and recheck
with cls._PROXY_TYPE_CACHE_LOCK:
proxy_type = cls._proxy_type_cache.get(schema_id)
# Guard against id() reuse by verifying identity
if proxy_type is not None and getattr(proxy_type, "core_schema", None) is core_schema:
cls._proxy_type_cache.move_to_end(schema_id)
return proxy_type(obj, serialized, adapter)

# Build new proxy type (outside lock to minimize critical section)
wrapped_core_schema = _wrap_core_schema(core_schema)
proxy_type = type(
f"SerializationProxy[{type(obj).__name__}]",
(cls,),
{
"core_schema": core_schema,
"__pydantic_serializer__": SchemaSerializer(wrapped_core_schema),
"__pydantic_core_schema__": wrapped_core_schema,
"__pydantic_validator__": adapter.validator,
},
)

# Publish under lock
with cls._PROXY_TYPE_CACHE_LOCK:
# Re-check if someone else beat us
existing = cls._proxy_type_cache.get(schema_id)
if existing is None or getattr(existing, "core_schema", None) is not core_schema:
cls._proxy_type_cache[schema_id] = proxy_type
cls._proxy_type_cache.move_to_end(schema_id)
# Evict oldest entry if cache is too large
if len(cls._proxy_type_cache) > cls._PROXY_TYPE_CACHE_SIZE:
cls._proxy_type_cache.popitem(last=False)
else:
proxy_type = existing

return proxy_type(obj, serialized, adapter)

@@ -187,51 +229,77 @@ def build(
if adapter is None:
adapter = TypeAdapter(type(obj))
serialized = adapter.dump_python(obj)
# If it's a dict, make it read-only to prevent accidental mutation from other threads
if isinstance(serialized, dict):
serialized = MappingProxyType(serialized)
core_schema = adapter.core_schema
return cls._build(obj, serialized, adapter, core_schema)

def __getattr__(self, name: str):
# Check attribute cache first
if name in self._attr_cache:
return self._attr_cache[name]
with self._attr_cache_lock:
if name in self._attr_cache:
return self._attr_cache[name]

if isinstance(self.serialized, dict) and name in self.serialized:
if isinstance(self.serialized, (dict, MappingProxyType, Mapping)) and name in self.serialized:
sub_schema = _extract_subschema(self.core_schema, name)
child_ser = self.serialized[name]

# For primitive types (non-dict/list serialized values), return the serialized value directly
# This ensures field serializers (like PlainSerializer) are properly applied
if not isinstance(child_ser, (dict, list, tuple)):
return child_ser

# Wrap child dicts to prevent mutation
if isinstance(child_ser, dict):
child_ser = MappingProxyType(child_ser)
proxy = self._build(
getattr(self.obj, name),
self.serialized[name],
child_ser,
self.root_adapter,
sub_schema,
)
# Cache the built proxy
self._attr_cache[name] = proxy
with self._attr_cache_lock:
self._attr_cache[name] = proxy


medium

This can be made more efficient and concise. If two threads access the same uncached attribute, both will build a proxy. By using setdefault, you can atomically (within the lock) check for and set the value, ensuring that only the first thread's proxy is cached and subsequent threads use the cached version. This avoids redundant work in highly concurrent scenarios.

Suggested change
self._attr_cache[name] = proxy
proxy = self._attr_cache.setdefault(name, proxy)
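
As a small illustration of the setdefault approach (the class and method names here are illustrative, not part of the PR):

from threading import RLock


class AttrCacheExample:
    def __init__(self) -> None:
        self._attr_cache: dict[str, object] = {}
        self._attr_cache_lock = RLock()

    def _publish(self, name: str, proxy: object) -> object:
        # setdefault stores proxy only if name is absent; otherwise it returns
        # the value another thread already cached, so every caller ends up
        # holding the same proxy object.
        with self._attr_cache_lock:
            return self._attr_cache.setdefault(name, proxy)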

return proxy
return getattr(self.obj, name)

def __getitem__(self, key):
# For getitem, we use a tuple for cache key to avoid collisions
cache_key = ("__item__", key)
if cache_key in self._attr_cache:
return self._attr_cache[cache_key]
with self._attr_cache_lock:
if cache_key in self._attr_cache:
return self._attr_cache[cache_key]

sub_schema = _extract_subschema(self.core_schema, key)
if type(self.serialized) is type(self.obj):
proxy = self._build(
self.obj[key],
self.serialized[key],
self.root_adapter,
sub_schema,
)
else:
proxy = self._build(
self.serialized[key],
self.serialized[key],
self.root_adapter,
sub_schema,
)
child_ser = self.serialized[key]

# For primitive types (non-dict/list serialized values), return the serialized value directly
# This ensures field serializers (like PlainSerializer) are properly applied
if not isinstance(child_ser, (dict, list, tuple)):
return child_ser

# Wrap child dicts to prevent mutation
if isinstance(child_ser, dict):
child_ser = MappingProxyType(child_ser)

# Try to keep the real underlying object if possible; otherwise fall back to serialized
try:
child_obj = self.obj[key]
except Exception:
child_obj = child_ser

proxy = self._build(
child_obj,
child_ser,
self.root_adapter,
sub_schema,
)

# Cache the built proxy
self._attr_cache[cache_key] = proxy
with self._attr_cache_lock:
self._attr_cache[cache_key] = proxy


medium

Similar to my other comments, you can make this caching more efficient in a concurrent environment by using setdefault. This will prevent redundant work if multiple threads try to access the same uncached item simultaneously, ensuring only one proxy is created and cached.

Suggested change
self._attr_cache[cache_key] = proxy
proxy = self._attr_cache.setdefault(cache_key, proxy)

return proxy

def __iter__(self):