-
Notifications
You must be signed in to change notification settings - Fork 0
Fix thread safety in SerializationProxy #8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: claude/dev-ergonomics-improvements-011CUVRN18PCYh3xd7J7c1td
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,6 +1,7 @@ | ||||||
| from collections import OrderedDict | ||||||
| from collections.abc import Callable, Iterable, Mapping | ||||||
| from copy import deepcopy | ||||||
| from threading import Lock, RLock | ||||||
| from types import MappingProxyType | ||||||
| from typing import Generic, NamedTuple, TypeGuard, TypeVar | ||||||
|
|
||||||
|
|
@@ -61,23 +62,39 @@ def apply_to_unwrapped(proxy: "SerializationProxy[T]") -> T: | |||||
|
|
||||||
| # Bounded cache for wrapped schemas to prevent memory leaks in long-running applications | ||||||
| # Using OrderedDict for LRU eviction | ||||||
| # Store tuple (orig_schema, wrapped_schema) to prevent id() reuse bugs | ||||||
| _WRAPPED_SCHEMA_CACHE_SIZE = 256 | ||||||
| _wrapped_schema_cache: OrderedDict[int, CoreSchema] = OrderedDict() | ||||||
| _wrapped_schema_cache: OrderedDict[int, tuple[CoreSchema, CoreSchema]] = OrderedDict() | ||||||
| _WRAPPED_SCHEMA_CACHE_LOCK = Lock() | ||||||
|
|
||||||
|
|
||||||
| def _wrap_core_schema(schema: CoreSchema) -> CoreSchema: | ||||||
| """Wrap a CoreSchema to make it proxy-aware. Uses bounded LRU cache to avoid expensive deepcopy.""" | ||||||
| schema_id = id(schema) | ||||||
|
|
||||||
| # Check cache first (LRU: move to end if found) | ||||||
| if schema_id in _wrapped_schema_cache: | ||||||
| # Move to end (most recently used) | ||||||
| _wrapped_schema_cache.move_to_end(schema_id) | ||||||
| return _wrapped_schema_cache[schema_id] | ||||||
|
|
||||||
| # Build wrapped schema | ||||||
| # Lock-free fast path: check cache without lock first | ||||||
| # Note: OrderedDict.get() can raise during concurrent modifications, so we catch any errors | ||||||
| try: | ||||||
| tup = _wrapped_schema_cache.get(schema_id) | ||||||
| if tup is not None and tup[0] is schema: | ||||||
| return tup[1] | ||||||
| except (ValueError, RuntimeError): | ||||||
| # Race condition detected, fall through to locked path | ||||||
| pass | ||||||
|
|
||||||
| # Slow path: take lock and recheck | ||||||
| with _WRAPPED_SCHEMA_CACHE_LOCK: | ||||||
| tup = _wrapped_schema_cache.get(schema_id) | ||||||
| if tup is not None: | ||||||
| orig, wrapped = tup | ||||||
| # Guard against id() reuse by verifying identity | ||||||
| if orig is schema: | ||||||
| _wrapped_schema_cache.move_to_end(schema_id) | ||||||
| return wrapped | ||||||
|
|
||||||
| # Build wrapped schema (outside lock to minimize critical section) | ||||||
| match schema: | ||||||
| # someting we can reference to (e.g. BaseModel, Dataclass, ...) | ||||||
| # something we can reference to (e.g. BaseModel, Dataclass, ...) | ||||||
| case {"ref": ref}: | ||||||
| wrapped_schema = core_schema.definitions_schema( | ||||||
| schema=core_schema.definition_reference_schema( | ||||||
|
|
@@ -105,12 +122,13 @@ def _wrap_core_schema(schema: CoreSchema) -> CoreSchema: | |||||
| ) | ||||||
|
|
||||||
| # Cache with LRU eviction | ||||||
| _wrapped_schema_cache[schema_id] = wrapped_schema | ||||||
| _wrapped_schema_cache.move_to_end(schema_id) | ||||||
| with _WRAPPED_SCHEMA_CACHE_LOCK: | ||||||
| _wrapped_schema_cache[schema_id] = (schema, wrapped_schema) | ||||||
| _wrapped_schema_cache.move_to_end(schema_id) | ||||||
|
|
||||||
| # Evict oldest entry if cache is too large | ||||||
| if len(_wrapped_schema_cache) > _WRAPPED_SCHEMA_CACHE_SIZE: | ||||||
| _wrapped_schema_cache.popitem(last=False) | ||||||
| # Evict oldest entry if cache is too large | ||||||
| if len(_wrapped_schema_cache) > _WRAPPED_SCHEMA_CACHE_SIZE: | ||||||
| _wrapped_schema_cache.popitem(last=False) | ||||||
|
|
||||||
| return wrapped_schema | ||||||
|
|
||||||
|
|
@@ -123,11 +141,12 @@ class SerializationProxy(Generic[T]): | |||||
| # Bounded cache for proxy types to prevent memory leaks | ||||||
| _PROXY_TYPE_CACHE_SIZE = 256 | ||||||
| _proxy_type_cache: OrderedDict[int, type["SerializationProxy"]] = OrderedDict() | ||||||
| _PROXY_TYPE_CACHE_LOCK = Lock() | ||||||
|
|
||||||
| def __init__( | ||||||
| self, | ||||||
| obj: T, | ||||||
| serialized: MappingProxyType, | ||||||
| serialized: Mapping | Iterable | object, | ||||||
| root_adapter: TypeAdapter, | ||||||
| ): | ||||||
| self.obj = obj | ||||||
|
|
@@ -136,42 +155,65 @@ def __init__( | |||||
| # Cache for accessed attributes to avoid rebuilding proxies | ||||||
| # Keys are either strings (for attributes) or tuples (for items) | ||||||
| self._attr_cache: dict[str | tuple, "SerializationProxy"] = {} | ||||||
| self._attr_cache_lock = RLock() | ||||||
|
|
||||||
| @classmethod | ||||||
| def _build( | ||||||
| cls, | ||||||
| obj: T, | ||||||
| serialized: MappingProxyType, | ||||||
| serialized: Mapping | Iterable | object, | ||||||
| adapter: TypeAdapter, | ||||||
| core_schema: CoreSchema, | ||||||
| ): | ||||||
| schema_id = id(core_schema) | ||||||
| # Normalize: wrap dicts to ensure immutability | ||||||
| if isinstance(serialized, dict) and not isinstance(serialized, MappingProxyType): | ||||||
| serialized = MappingProxyType(serialized) | ||||||
|
|
||||||
| # Check if we already have a cached proxy type for this schema (LRU) | ||||||
| if schema_id in cls._proxy_type_cache: | ||||||
| # Move to end (most recently used) | ||||||
| cls._proxy_type_cache.move_to_end(schema_id) | ||||||
| proxy_type = cls._proxy_type_cache[schema_id] | ||||||
| else: | ||||||
| # Build new proxy type | ||||||
| wrapped_core_schema = _wrap_core_schema(core_schema) | ||||||
| proxy_type = type( | ||||||
| f"SerializationProxy[{type(obj).__name__}]", | ||||||
| (cls,), | ||||||
| { | ||||||
| "core_schema": core_schema, | ||||||
| "__pydantic_serializer__": SchemaSerializer(wrapped_core_schema), | ||||||
| "__pydantic_core_schema__": wrapped_core_schema, | ||||||
| "__pydantic_validator__": adapter.validator, | ||||||
| }, | ||||||
| ) | ||||||
| # Cache the proxy type with LRU eviction | ||||||
| cls._proxy_type_cache[schema_id] = proxy_type | ||||||
| cls._proxy_type_cache.move_to_end(schema_id) | ||||||
| schema_id = id(core_schema) | ||||||
|
|
||||||
| # Evict oldest entry if cache is too large | ||||||
| if len(cls._proxy_type_cache) > cls._PROXY_TYPE_CACHE_SIZE: | ||||||
| cls._proxy_type_cache.popitem(last=False) | ||||||
| # Lock-free fast path: check cache without lock first | ||||||
| # Note: OrderedDict.get() can raise during concurrent modifications, so we catch any errors | ||||||
| try: | ||||||
| proxy_type = cls._proxy_type_cache.get(schema_id) | ||||||
| if proxy_type is not None and getattr(proxy_type, "core_schema", None) is core_schema: | ||||||
| return proxy_type(obj, serialized, adapter) | ||||||
| except (ValueError, RuntimeError): | ||||||
| # Race condition detected, fall through to locked path | ||||||
| pass | ||||||
|
|
||||||
| # Slow path: take lock and recheck | ||||||
| with cls._PROXY_TYPE_CACHE_LOCK: | ||||||
| proxy_type = cls._proxy_type_cache.get(schema_id) | ||||||
| # Guard against id() reuse by verifying identity | ||||||
| if proxy_type is not None and getattr(proxy_type, "core_schema", None) is core_schema: | ||||||
| cls._proxy_type_cache.move_to_end(schema_id) | ||||||
| return proxy_type(obj, serialized, adapter) | ||||||
|
|
||||||
| # Build new proxy type (outside lock to minimize critical section) | ||||||
| wrapped_core_schema = _wrap_core_schema(core_schema) | ||||||
| proxy_type = type( | ||||||
| f"SerializationProxy[{type(obj).__name__}]", | ||||||
| (cls,), | ||||||
| { | ||||||
| "core_schema": core_schema, | ||||||
| "__pydantic_serializer__": SchemaSerializer(wrapped_core_schema), | ||||||
| "__pydantic_core_schema__": wrapped_core_schema, | ||||||
| "__pydantic_validator__": adapter.validator, | ||||||
| }, | ||||||
| ) | ||||||
|
|
||||||
| # Publish under lock | ||||||
| with cls._PROXY_TYPE_CACHE_LOCK: | ||||||
| # Re-check if someone else beat us | ||||||
| existing = cls._proxy_type_cache.get(schema_id) | ||||||
| if existing is None or getattr(existing, "core_schema", None) is not core_schema: | ||||||
| cls._proxy_type_cache[schema_id] = proxy_type | ||||||
| cls._proxy_type_cache.move_to_end(schema_id) | ||||||
| # Evict oldest entry if cache is too large | ||||||
| if len(cls._proxy_type_cache) > cls._PROXY_TYPE_CACHE_SIZE: | ||||||
| cls._proxy_type_cache.popitem(last=False) | ||||||
| else: | ||||||
| proxy_type = existing | ||||||
|
|
||||||
| return proxy_type(obj, serialized, adapter) | ||||||
|
|
||||||
|
|
@@ -187,51 +229,77 @@ def build( | |||||
| if adapter is None: | ||||||
| adapter = TypeAdapter(type(obj)) | ||||||
| serialized = adapter.dump_python(obj) | ||||||
| # If it's a dict, make it read-only to prevent accidental mutation from other threads | ||||||
| if isinstance(serialized, dict): | ||||||
| serialized = MappingProxyType(serialized) | ||||||
| core_schema = adapter.core_schema | ||||||
| return cls._build(obj, serialized, adapter, core_schema) | ||||||
|
|
||||||
| def __getattr__(self, name: str): | ||||||
| # Check attribute cache first | ||||||
| if name in self._attr_cache: | ||||||
| return self._attr_cache[name] | ||||||
| with self._attr_cache_lock: | ||||||
| if name in self._attr_cache: | ||||||
| return self._attr_cache[name] | ||||||
|
|
||||||
| if isinstance(self.serialized, dict) and name in self.serialized: | ||||||
| if isinstance(self.serialized, (dict, MappingProxyType, Mapping)) and name in self.serialized: | ||||||
| sub_schema = _extract_subschema(self.core_schema, name) | ||||||
| child_ser = self.serialized[name] | ||||||
|
|
||||||
| # For primitive types (non-dict/list serialized values), return the serialized value directly | ||||||
| # This ensures field serializers (like PlainSerializer) are properly applied | ||||||
| if not isinstance(child_ser, (dict, list, tuple)): | ||||||
| return child_ser | ||||||
|
|
||||||
| # Wrap child dicts to prevent mutation | ||||||
| if isinstance(child_ser, dict): | ||||||
| child_ser = MappingProxyType(child_ser) | ||||||
| proxy = self._build( | ||||||
| getattr(self.obj, name), | ||||||
| self.serialized[name], | ||||||
| child_ser, | ||||||
| self.root_adapter, | ||||||
| sub_schema, | ||||||
| ) | ||||||
| # Cache the built proxy | ||||||
| self._attr_cache[name] = proxy | ||||||
| with self._attr_cache_lock: | ||||||
| self._attr_cache[name] = proxy | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This can be made more efficient and concise. If two threads access the same uncached attribute, both will build a proxy. By using
Suggested change
|
||||||
| return proxy | ||||||
| return getattr(self.obj, name) | ||||||
|
|
||||||
| def __getitem__(self, key): | ||||||
| # For getitem, we use a tuple for cache key to avoid collisions | ||||||
| cache_key = ("__item__", key) | ||||||
| if cache_key in self._attr_cache: | ||||||
| return self._attr_cache[cache_key] | ||||||
| with self._attr_cache_lock: | ||||||
| if cache_key in self._attr_cache: | ||||||
| return self._attr_cache[cache_key] | ||||||
|
|
||||||
| sub_schema = _extract_subschema(self.core_schema, key) | ||||||
| if type(self.serialized) is type(self.obj): | ||||||
| proxy = self._build( | ||||||
| self.obj[key], | ||||||
| self.serialized[key], | ||||||
| self.root_adapter, | ||||||
| sub_schema, | ||||||
| ) | ||||||
| else: | ||||||
| proxy = self._build( | ||||||
| self.serialized[key], | ||||||
| self.serialized[key], | ||||||
| self.root_adapter, | ||||||
| sub_schema, | ||||||
| ) | ||||||
| child_ser = self.serialized[key] | ||||||
|
|
||||||
| # For primitive types (non-dict/list serialized values), return the serialized value directly | ||||||
| # This ensures field serializers (like PlainSerializer) are properly applied | ||||||
| if not isinstance(child_ser, (dict, list, tuple)): | ||||||
| return child_ser | ||||||
|
|
||||||
| # Wrap child dicts to prevent mutation | ||||||
| if isinstance(child_ser, dict): | ||||||
| child_ser = MappingProxyType(child_ser) | ||||||
|
|
||||||
| # Try to keep the real underlying object if possible; otherwise fall back to serialized | ||||||
| try: | ||||||
| child_obj = self.obj[key] | ||||||
| except Exception: | ||||||
| child_obj = child_ser | ||||||
|
|
||||||
| proxy = self._build( | ||||||
| child_obj, | ||||||
| child_ser, | ||||||
| self.root_adapter, | ||||||
| sub_schema, | ||||||
| ) | ||||||
|
|
||||||
| # Cache the built proxy | ||||||
| self._attr_cache[cache_key] = proxy | ||||||
| with self._attr_cache_lock: | ||||||
| self._attr_cache[cache_key] = proxy | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similar to my other comments, you can make this caching more efficient in a concurrent environment by using
Suggested change
|
||||||
| return proxy | ||||||
|
|
||||||
| def __iter__(self): | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While adding a lock here prevents race conditions, this implementation of double-checked locking can be made more efficient. If two threads call this function with the same uncached schema, both will perform the expensive schema wrapping. The second thread to acquire the lock will simply overwrite the result of the first.
To avoid this redundant work, you can re-check the cache after acquiring the lock for writing. If another thread has populated the cache in the meantime, you can use its result and discard your own. This pattern is already correctly implemented in the
`_build` method. Applying it here would improve efficiency and consistency.