diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 3718a29c5..8335cbd18 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -2,11 +2,17 @@ from . import agent, models, telemetry, types from .agent.agent import Agent +from .agent.serializers import JSONSerializer, PickleSerializer, StateSerializer +from .agent.state import AgentState from .tools.decorator import tool from .types.tools import ToolContext __all__ = [ "Agent", + "AgentState", + "JSONSerializer", + "PickleSerializer", + "StateSerializer", "agent", "models", "tool", diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index 6618d3328..017dc8a79 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -4,6 +4,7 @@ - Agent: The main interface for interacting with AI models and tools - ConversationManager: Classes for managing conversation history and context windows +- Serializers: Pluggable serialization strategies for agent state (JSONSerializer, PickleSerializer) """ from .agent import Agent @@ -14,12 +15,18 @@ SlidingWindowConversationManager, SummarizingConversationManager, ) +from .serializers import JSONSerializer, PickleSerializer, StateSerializer +from .state import AgentState __all__ = [ "Agent", "AgentResult", + "AgentState", "ConversationManager", + "JSONSerializer", "NullConversationManager", + "PickleSerializer", "SlidingWindowConversationManager", + "StateSerializer", "SummarizingConversationManager", ] diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 9e726ca0b..762d25012 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -66,6 +66,7 @@ ConversationManager, SlidingWindowConversationManager, ) +from .serializers import StateSerializer from .state import AgentState logger = logging.getLogger(__name__) @@ -121,6 +122,7 @@ def __init__( name: Optional[str] = None, description: Optional[str] = None, state: Optional[Union[AgentState, dict]] = None, + state_serializer: Optional[StateSerializer] = None, hooks: Optional[list[HookProvider]] = None, session_manager: Optional[SessionManager] = None, tool_executor: Optional[ToolExecutor] = None, @@ -168,6 +170,9 @@ def __init__( Defaults to None. state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict. Defaults to an empty AgentState object. + state_serializer: Serializer for state persistence (e.g., JSONSerializer, PickleSerializer). + Cannot be provided together with an AgentState object in 'state' parameter. + Defaults to JSONSerializer for backward compatibility. hooks: hooks to be added to the agent hook registry Defaults to None. session_manager: Manager for handling agent sessions including conversation history and state. @@ -175,7 +180,8 @@ def __init__( tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). Raises: - ValueError: If agent id contains path separators. + ValueError: If agent id contains path separators, or if both state (AgentState) and state_serializer + are provided. """ self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model self.messages = messages if messages is not None else [] @@ -231,13 +237,18 @@ def __init__( # Initialize agent state management if state is not None: if isinstance(state, dict): - self.state = AgentState(state) + self.state = AgentState(state, serializer=state_serializer) elif isinstance(state, AgentState): + if state_serializer is not None: + raise ValueError( + "Cannot provide both state (AgentState) and state_serializer. " + "Configure serializer on the AgentState object instead." + ) self.state = state else: raise ValueError("state must be an AgentState object or a dict") else: - self.state = AgentState() + self.state = AgentState(serializer=state_serializer) self.tool_caller = _ToolCaller(self) diff --git a/src/strands/agent/serializers.py b/src/strands/agent/serializers.py new file mode 100644 index 000000000..2208c5e84 --- /dev/null +++ b/src/strands/agent/serializers.py @@ -0,0 +1,150 @@ +"""State serializers for agent state management. + +This module provides pluggable serialization strategies for AgentState: +- JSONSerializer: Default serializer, backward compatible, validates on set() +- PickleSerializer: Supports any Python object, no validation on set() +- StateSerializer: Protocol for custom serializers +""" + +import copy +import json +import pickle +from typing import Any, Protocol, runtime_checkable + + +@runtime_checkable +class StateSerializer(Protocol): + """Protocol for state serializers. + + Custom serializers can implement this protocol to provide + alternative serialization strategies for agent state. + """ + + def serialize(self, data: dict[str, Any]) -> bytes: + """Serialize state dict to bytes. + + Args: + data: Dictionary of state data to serialize + + Returns: + Serialized state as bytes + """ + ... + + def deserialize(self, data: bytes) -> dict[str, Any]: + """Deserialize bytes back to state dict. + + Args: + data: Serialized state bytes + + Returns: + Deserialized state dictionary + """ + ... + + def validate(self, value: Any) -> None: + """Validate a value can be serialized. + + Serializers that accept any value should implement this as a no-op. + + Args: + value: The value to validate + + Raises: + ValueError: If value cannot be serialized by this serializer + """ + ... + + +class JSONSerializer: + """JSON-based state serializer. + + Default serializer that provides: + - Human-readable serialization format + - Validation on set() to maintain current behavior + - Backward compatibility with existing code + """ + + def serialize(self, data: dict[str, Any]) -> bytes: + """Serialize state dict to JSON bytes. + + Args: + data: Dictionary of state data to serialize + + Returns: + JSON serialized state as bytes + """ + return json.dumps(data).encode("utf-8") + + def deserialize(self, data: bytes) -> dict[str, Any]: + """Deserialize JSON bytes back to state dict. + + Args: + data: JSON serialized state bytes + + Returns: + Deserialized state dictionary + """ + result: dict[str, Any] = json.loads(data.decode("utf-8")) + return result + + def validate(self, value: Any) -> None: + """Validate that a value is JSON serializable. + + Args: + value: The value to validate + + Raises: + ValueError: If value is not JSON serializable + """ + try: + json.dumps(value) + except (TypeError, ValueError) as e: + raise ValueError( + f"Value is not JSON serializable: {type(value).__name__}. " + f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed." + ) from e + + +class PickleSerializer: + """Pickle-based state serializer. + + Provides: + - Support for any Python object (datetime, UUID, dataclass, Pydantic models, etc.) + - No validation on set() (accepts anything) + + Security Warning: + Pickle can execute arbitrary code during deserialization. + Only unpickle data from trusted sources. + """ + + def serialize(self, data: dict[str, Any]) -> bytes: + """Serialize state dict using pickle. + + Args: + data: Dictionary of state data to serialize + + Returns: + Pickle serialized state as bytes + """ + return pickle.dumps(copy.deepcopy(data)) + + def deserialize(self, data: bytes) -> dict[str, Any]: + """Deserialize pickle bytes back to state dict. + + Args: + data: Pickle serialized state bytes + + Returns: + Deserialized state dictionary + """ + result: dict[str, Any] = pickle.loads(data) # noqa: S301 + return result + + def validate(self, value: Any) -> None: + """No-op validation - pickle accepts any Python object. + + Args: + value: The value to validate (ignored) + """ + pass diff --git a/src/strands/agent/state.py b/src/strands/agent/state.py index c323041a3..979bb3993 100644 --- a/src/strands/agent/state.py +++ b/src/strands/agent/state.py @@ -1,6 +1,190 @@ -"""Agent state management.""" +"""Agent state management. -from ..types.json_dict import JSONSerializableDict +Provides flexible state container with pluggable serialization and transient state support. +""" -# Type alias for agent state -AgentState = JSONSerializableDict +import copy +from typing import Any + +from .serializers import JSONSerializer, StateSerializer + + +class AgentState: + """Flexible state container with pluggable serialization and transient state support. + + AgentState provides a key-value store for agent state with: + - Pluggable serialization (JSON by default, Pickle for rich types) + - Transient state support for runtime-only resources (persist=False) + - Backward compatible API with existing code + + Example: + Basic usage (backward compatible): + ```python + state = AgentState() + state.set("count", 42) # Persistent by default + state.get("count") # Returns 42 + ``` + + Rich types with PickleSerializer: + ```python + from strands.agent.serializers import PickleSerializer + from datetime import datetime + + state = AgentState(serializer=PickleSerializer()) + state.set("created_at", datetime.now()) # Works with Pickle + ``` + + Transient state for runtime resources: + ```python + state.set("db_connection", connection, persist=False) # Not serialized + state.get("db_connection") # Returns the connection + state.is_transient("db_connection") # Returns True + ``` + """ + + def __init__( + self, + initial_state: dict[str, Any] | None = None, + serializer: StateSerializer | None = None, + ): + """Initialize AgentState. + + Args: + initial_state: Optional initial state dictionary + serializer: Serializer to use for state persistence. + Defaults to JSONSerializer for backward compatibility. + + Raises: + ValueError: If initial_state contains non-serializable values (with JSONSerializer) + """ + self._serializer = serializer if serializer is not None else JSONSerializer() + self._transient_keys: set[str] = set() + self._data: dict[str, Any] + + if initial_state: + # Validate initial state + self._serializer.validate(initial_state) + self._data = copy.deepcopy(initial_state) + else: + self._data = {} + + @property + def serializer(self) -> StateSerializer: + """Get the current serializer. + + Returns: + The serializer used for state persistence + """ + return self._serializer + + @serializer.setter + def serializer(self, value: StateSerializer) -> None: + """Set the serializer. + + Args: + value: New serializer to use for state persistence + """ + self._serializer = value + + def set(self, key: str, value: Any, *, persist: bool = True) -> None: + """Set a value in the store. + + Args: + key: The key to store the value under + value: The value to store + persist: If False, value is transient (not serialized). Default True. + + Raises: + ValueError: If key is invalid, or if value is not serializable + (only when persist=True) + """ + self._validate_key(key) + + if persist: + # Validate serializable + self._serializer.validate(value) + self._transient_keys.discard(key) + else: + # Mark as transient - skip validation + self._transient_keys.add(key) + + self._data[key] = copy.deepcopy(value) + + def get(self, key: str | None = None) -> Any: + """Get a value or entire data. + + Works uniformly for both persistent and transient values. + + Args: + key: The key to retrieve (if None, returns entire data dict) + + Returns: + The stored value, entire data dict, or None if not found + """ + if key is None: + return copy.deepcopy(self._data) + else: + return copy.deepcopy(self._data.get(key)) + + def delete(self, key: str) -> None: + """Delete a specific key from the store. + + Args: + key: The key to delete + """ + self._validate_key(key) + self._data.pop(key, None) + self._transient_keys.discard(key) + + def is_transient(self, key: str) -> bool: + """Check if a key is transient (not persisted). + + Args: + key: The key to check + + Returns: + True if the key is transient, False otherwise + """ + return key in self._transient_keys + + def serialize(self) -> bytes: + """Serialize only persistent keys. + + Returns: + Serialized state as bytes (excludes transient keys) + """ + persistent_data = {k: v for k, v in self._data.items() if k not in self._transient_keys} + return self._serializer.serialize(persistent_data) + + def deserialize(self, data: bytes) -> dict[str, Any]: + """Deserialize persistent state. + + Transient keys are preserved if already in memory. + + Args: + data: Serialized state bytes to restore + + Returns: + The complete state dictionary (including preserved transient keys) + """ + persistent_data = self._serializer.deserialize(data) + # Keep transient keys in memory, replace persistent + transient_data = {k: v for k, v in self._data.items() if k in self._transient_keys} + self._data = {**persistent_data, **transient_data} + return self._data + + def _validate_key(self, key: str) -> None: + """Validate that a key is valid. + + Args: + key: The key to validate + + Raises: + ValueError: If key is invalid + """ + if key is None: + raise ValueError("Key cannot be None") + if not isinstance(key, str): + raise ValueError("Key must be a string") + if not key.strip(): + raise ValueError("Key cannot be empty") diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 5ddb181ea..0ec874c19 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -19,6 +19,7 @@ from .... import _identifier from ....agent.state import AgentState +from ....agent.serializers import StateSerializer from ....hooks import HookProvider, HookRegistry from ....interrupt import _InterruptState from ....tools._caller import _ToolCaller @@ -73,6 +74,7 @@ def __init__( description: str | None = None, hooks: list[HookProvider] | None = None, state: AgentState | dict | None = None, + state_serializer: StateSerializer | None = None, session_manager: "SessionManager | None" = None, tool_executor: ToolExecutor | None = None, **kwargs: Any, @@ -91,13 +93,16 @@ def __init__( description: Description of what the Agent does. hooks: Optional list of hook providers to register for lifecycle events. state: Stateful information for the agent. Can be either an AgentState object, or a json serializable dict. + state_serializer: Serializer for state persistence (e.g., JSONSerializer, PickleSerializer). + Cannot be provided together with an AgentState object in 'state' parameter. session_manager: Manager for handling agent sessions including conversation history and state. If provided, enables session-based persistence and state management. tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). **kwargs: Additional configuration for future extensibility. Raises: - ValueError: If model configuration is invalid or state is invalid type. + ValueError: If model configuration is invalid, state is invalid type, or both state (AgentState) and + state_serializer are provided. TypeError: If model type is unsupported. """ self.model = ( @@ -134,13 +139,16 @@ def __init__( # Initialize agent state management if state is not None: if isinstance(state, dict): - self.state = AgentState(state) + self.state = AgentState(state, serializer=state_serializer) elif isinstance(state, AgentState): + if state_serializer is not None: + raise ValueError("Cannot provide both state (AgentState) and state_serializer. " + "Configure serializer on the AgentState object instead.") self.state = state else: raise ValueError("state must be an AgentState object or a dict") else: - self.state = AgentState() + self.state = AgentState(serializer=state_serializer) # Initialize other components self._tool_caller = _ToolCaller(self) diff --git a/tests/strands/agent/test_agent_state.py b/tests/strands/agent/test_agent_state.py index bc2321a56..2933fd1c6 100644 --- a/tests/strands/agent/test_agent_state.py +++ b/tests/strands/agent/test_agent_state.py @@ -1,8 +1,11 @@ """Tests for AgentState class.""" +from datetime import datetime +from uuid import uuid4 + import pytest -from strands import Agent, tool +from strands import Agent, JSONSerializer, PickleSerializer, tool from strands.agent.state import AgentState from strands.types.content import Messages @@ -143,3 +146,208 @@ def update_state(agent: Agent): assert agent.state.get("hello") == "world" assert agent.state.get("foo") == "baz" + + +# Serializer Tests + + +def test_default_serializer_is_json(): + """Test that default serializer is JSONSerializer.""" + state = AgentState() + assert isinstance(state.serializer, JSONSerializer) + + +def test_pickle_serializer_allows_rich_types(): + """Test that PickleSerializer allows datetime, UUID, and other rich types.""" + state = AgentState(serializer=PickleSerializer()) + + # Rich types that don't work with JSONSerializer + now = datetime.now() + user_id = uuid4() + + state.set("created_at", now) + state.set("user_id", user_id) + state.set("config", {"nested": now}) + + assert state.get("created_at") == now + assert state.get("user_id") == user_id + + +def test_json_serializer_rejects_rich_types(): + """Test that JSONSerializer rejects datetime and other non-JSON types.""" + state = AgentState(serializer=JSONSerializer()) + + with pytest.raises(ValueError, match="not JSON serializable"): + state.set("created_at", datetime.now()) + + +def test_serialize_deserialize_json(): + """Test serialize and deserialize with JSONSerializer.""" + state = AgentState(serializer=JSONSerializer()) + state.set("name", "test") + state.set("count", 42) + + # Serialize + data = state.serialize() + assert isinstance(data, bytes) + + # Deserialize into new state + new_state = AgentState(serializer=JSONSerializer()) + new_state.deserialize(data) + + assert new_state.get("name") == "test" + assert new_state.get("count") == 42 + + +def test_serialize_deserialize_pickle(): + """Test serialize and deserialize with PickleSerializer.""" + state = AgentState(serializer=PickleSerializer()) + now = datetime.now() + user_id = uuid4() + state.set("created_at", now) + state.set("user_id", user_id) + + # Serialize + data = state.serialize() + assert isinstance(data, bytes) + + # Deserialize into new state + new_state = AgentState(serializer=PickleSerializer()) + new_state.deserialize(data) + + assert new_state.get("created_at") == now + assert new_state.get("user_id") == user_id + + +def test_transient_state_not_serialized(): + """Test that transient values are not serialized.""" + state = AgentState(serializer=JSONSerializer()) + + # Persistent value + state.set("persistent_key", "persistent_value") + + # Transient value (not serializable, but persist=False so no validation) + state.set("transient_key", lambda: "function", persist=False) + + # Check transient flag + assert state.is_transient("transient_key") is True + assert state.is_transient("persistent_key") is False + + # Get works for both + assert state.get("persistent_key") == "persistent_value" + assert state.get("transient_key") is not None # Lambda exists + + # Serialize excludes transient + data = state.serialize() + new_state = AgentState(serializer=JSONSerializer()) + new_state.deserialize(data) + + assert new_state.get("persistent_key") == "persistent_value" + assert new_state.get("transient_key") is None # Transient not restored + + +def test_transient_preserved_after_deserialize(): + """Test that transient values in memory are preserved after deserialize.""" + state = AgentState(serializer=JSONSerializer()) + + # Set transient value first + state.set("runtime_db", "connection_object", persist=False) + + # Set persistent and serialize + state.set("user_id", "123") + data = state.serialize() + + # Modify persistent value after serialization + state.set("user_id", "999") + + # Deserialize - should restore persistent but keep transient + state.deserialize(data) + + assert state.get("user_id") == "123" # Restored from serialized + assert state.get("runtime_db") == "connection_object" # Preserved in memory + + +def test_delete_removes_transient_flag(): + """Test that delete also removes the transient flag.""" + state = AgentState(serializer=JSONSerializer()) + state.set("key", "value", persist=False) + assert state.is_transient("key") is True + + state.delete("key") + assert state.is_transient("key") is False + + +def test_serializer_property(): + """Test serializer property getter and setter.""" + state = AgentState(serializer=JSONSerializer()) + assert isinstance(state.serializer, JSONSerializer) + + # Change serializer + state.serializer = PickleSerializer() + assert isinstance(state.serializer, PickleSerializer) + + +def test_agent_with_state_serializer(): + """Test Agent constructor with state_serializer parameter.""" + agent_messages: Messages = [ + {"role": "assistant", "content": [{"text": "Hello!"}]}, + ] + mocked_model_provider = MockedModelProvider(agent_messages) + + agent = Agent( + model=mocked_model_provider, + state_serializer=PickleSerializer(), + ) + + assert isinstance(agent.state.serializer, PickleSerializer) + + +def test_agent_with_state_dict_and_serializer(): + """Test Agent with dict state and state_serializer parameter.""" + agent_messages: Messages = [ + {"role": "assistant", "content": [{"text": "Hello!"}]}, + ] + mocked_model_provider = MockedModelProvider(agent_messages) + + agent = Agent( + model=mocked_model_provider, + state={"key": "value"}, + state_serializer=PickleSerializer(), + ) + + assert isinstance(agent.state.serializer, PickleSerializer) + assert agent.state.get("key") == "value" + + +def test_agent_with_agent_state_and_serializer_raises(): + """Test that providing both AgentState and state_serializer raises error.""" + agent_messages: Messages = [ + {"role": "assistant", "content": [{"text": "Hello!"}]}, + ] + mocked_model_provider = MockedModelProvider(agent_messages) + + existing_state = AgentState(serializer=JSONSerializer()) + + with pytest.raises(ValueError, match="Cannot provide both state.*and state_serializer"): + Agent( + model=mocked_model_provider, + state=existing_state, + state_serializer=PickleSerializer(), + ) + + +def test_agent_state_with_pickle_allows_datetime(): + """Test using datetime in agent state with PickleSerializer.""" + agent_messages: Messages = [ + {"role": "assistant", "content": [{"text": "Hello!"}]}, + ] + mocked_model_provider = MockedModelProvider(agent_messages) + + now = datetime.now() + agent = Agent( + model=mocked_model_provider, + state_serializer=PickleSerializer(), + ) + + agent.state.set("created_at", now) + assert agent.state.get("created_at") == now