From ce76ee20de7c8aaf5881b95fc916025545f5969e Mon Sep 17 00:00:00 2001 From: perl_d Date: Tue, 10 Feb 2026 15:09:13 +0100 Subject: [PATCH] refactor: mark all message dicts as jsonable --- .../tests/client_tests/test_live_table.py | 11 +- bec_lib/bec_lib/bec_service.py | 1 + bec_lib/bec_lib/messages.py | 151 +++++++++++++----- bec_lib/bec_lib/one_way_registry.py | 97 +++++++++++ bec_lib/bec_lib/serialization_registry.py | 1 - bec_lib/tests/test_bec_messages.py | 14 +- bec_lib/tests/test_config_helper.py | 3 +- bec_lib/tests/test_serializer.py | 6 +- .../tests_scan_server/test_scan_worker.py | 2 +- .../tests/tests_scan_server/test_scans.py | 6 +- 10 files changed, 235 insertions(+), 57 deletions(-) create mode 100644 bec_lib/bec_lib/one_way_registry.py diff --git a/bec_ipython_client/tests/client_tests/test_live_table.py b/bec_ipython_client/tests/client_tests/test_live_table.py index 891803831..53c8ccf56 100644 --- a/bec_ipython_client/tests/client_tests/test_live_table.py +++ b/bec_ipython_client/tests/client_tests/test_live_table.py @@ -303,8 +303,11 @@ def test_print_table_data_hinted_value_with_precision( @pytest.mark.parametrize( "value,expected", [ - (np.int32(1), "1.00"), - (np.float64(1.00000), "1.00"), + # The commented-out cases are not supported in unstructured serialized data, because msgpack doesn't distinguish + # lists, tuples, or sets. To support them, ScanMessage must be refactored to carry the type information directly. + # Numpy arrays are currently special-cased, but that special-casing will be removed in a future refactor. 
+ # (np.int32(1), "1.00"), + # (np.float64(1.00000), "1.00"), (0, "0.00"), (1, "1.00"), (0.000, "0.00"), @@ -314,10 +317,10 @@ def test_print_table_data_hinted_value_with_precision( ("False", "False"), ("0", "0"), ("1", "1"), - ((0, 1), "(0, 1)"), + # ((0, 1), "(0, 1)"), ({"value": 0}, "{'value': 0}"), (np.array([0, 1]), "[0 1]"), - ({1, 2}, "{1, 2}"), + # ({1, 2}, "{1, 2}"), ], ) def test_print_table_data_variants(self, client_with_grid_scan, value, expected): diff --git a/bec_lib/bec_lib/bec_service.py b/bec_lib/bec_lib/bec_service.py index 1bdc33d34..31e935602 100644 --- a/bec_lib/bec_lib/bec_service.py +++ b/bec_lib/bec_lib/bec_service.py @@ -259,6 +259,7 @@ def _update_existing_services(self) -> None: msgs = [ self.connector.get(MessageEndpoints.service_status(service)) for service in services ] + print(msgs) self._services_info = {msg.content["name"]: msg for msg in msgs if msg is not None} msgs = [self.connector.get(MessageEndpoints.metrics(service)) for service in services] self._services_metric = {msg.content["name"]: msg for msg in msgs if msg is not None} diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index 23bed4c58..5b9e57c73 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -8,13 +8,92 @@ from enum import Enum, auto from importlib.metadata import PackageNotFoundError from importlib.metadata import version as importlib_version -from typing import Any, ClassVar, Literal, Self, Union +from types import NoneType +from typing import Annotated, Any, ClassVar, Literal, Self, Union from uuid import uuid4 +import msgpack import numpy as np -from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator +from pydantic import ( + BaseModel, + BeforeValidator, + ConfigDict, + Field, + ValidationError, + WithJsonSchema, + field_validator, + model_validator, +) +from typing_extensions import TypeAliasType from bec_lib.metadata_schema import get_metadata_schema_for_scan +from 
bec_lib.one_way_registry import OneWaySerializationRegistry + +_one_way_registry = OneWaySerializationRegistry() + + +def _sanitize_one_way(data: Any) -> Any: + # TODO: Temporary fix for standardizing message structure, will be replaced + # by encoders in a future iteration + if isinstance(data, np.ndarray): + return data + if isinstance(data, np.bool_): + return bool(data) + if isinstance(data, (np.float16, np.float32, np.float64)): + return float(data) + if isinstance(data, (np.int16, np.int32, np.int64, np.uint16, np.uint32, np.uint64)): + return int(data) + if isinstance(data, (list, tuple, set)): + return [_sanitize_one_way(x) for x in data] + if isinstance(data, dict): + return {_sanitize_one_way(k): _sanitize_one_way(v) for k, v in data.items()} + return _one_way_registry.encode(data) + + +def _ignore_ndarray(data: Any) -> Any: + if isinstance(data, np.ndarray): + return [] + raise ValueError(f"Cannot serialize unknown type for {data}: {type(data)}") + + +def _test_packable(data: Any): + try: + msgpack.dumps(data, default=_ignore_ndarray) + except Exception as e: + raise ValueError(f"Non-JSONable/msgpackable data in {data}!") from e + + +def _validate_packable(data: Any) -> Any: + # Skip sanitization if the data is already valid + if isinstance(data, int | float | str | bool | NoneType): + return data + if isinstance(data, np.bool_): + return bool(data) + try: + _test_packable(data) + return data + # Recursively check if we should replace anything which is not supposed to be decoded to a custom + # type on the other end + except ValueError: + data = _sanitize_one_way(data) + _test_packable(data) + return data + + +Jsonable = TypeAliasType( + "Jsonable", + Annotated[ + int | float | str | bool | None | list["Jsonable"] | dict[str, "Jsonable"] | np.ndarray, + BeforeValidator(_validate_packable), + ], +) + +JsonableDict = TypeAliasType( + "JsonableDict", + Annotated[ + dict[str, Jsonable], BeforeValidator(_validate_packable), WithJsonSchema({"type": "object"}) 
+ ], +) class ProcedureWorkerStatus(Enum): @@ -43,8 +122,9 @@ class BECMessage(BaseModel): """ + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") msg_type: ClassVar[str] - metadata: dict = Field(default_factory=dict) + metadata: JsonableDict = Field(default_factory=dict) @field_validator("metadata") @classmethod @@ -141,7 +221,7 @@ class ScanQueueMessage(BECMessage): msg_type: ClassVar[str] = "scan_queue_message" scan_type: str - parameter: dict + parameter: JsonableDict queue: str = Field(default="primary") allow_restart: bool = Field( default=True, @@ -225,18 +305,18 @@ class ScanStatusMessage(BECMessage): scan_type: Literal["step", "fly"] | None = Field(default=None, description="Type of scan") dataset_number: int | None = None scan_report_devices: list[str] | None = None - user_metadata: dict | None = None + user_metadata: JsonableDict | None = None readout_priority: ( dict[Literal["monitored", "baseline", "async", "continuous", "on_request"], list[str]] | None ) = None scan_parameters: dict[ - Literal["exp_time", "frames_per_trigger", "settling_time", "readout_time"] | str, Any + Literal["exp_time", "frames_per_trigger", "settling_time", "readout_time"] | str, Jsonable ] = Field(default_factory=dict) - request_inputs: dict[Literal["arg_bundle", "inputs", "kwargs"], Any] = Field( + request_inputs: dict[Literal["arg_bundle", "inputs", "kwargs"], Jsonable] = Field( default_factory=dict ) - info: dict + info: JsonableDict timestamp: float = Field(default_factory=time.time) def __str__(self): @@ -302,7 +382,7 @@ class ScanQueueModificationMessage(BECMessage): "release_lock", "user_completed", ] - parameter: dict + parameter: JsonableDict queue: str = Field(default="primary") @@ -550,7 +630,7 @@ class DeviceInstructionMessage(BECMessage): "publish_data_as_read", "close_scan_group", ] - parameter: dict + parameter: JsonableDict class ErrorInfo(BaseModel): @@ -747,7 +827,7 @@ class DeviceInfoMessage(BECMessage): msg_type: ClassVar[str] = 
"device_info_message" device: str - info: dict + info: JsonableDict class DeviceMonitor2DMessage(BECMessage): @@ -767,8 +847,6 @@ class DeviceMonitor2DMessage(BECMessage): data: np.ndarray timestamp: float = Field(default_factory=time.time) - metadata: dict | None = Field(default_factory=dict) - # Needed for pydantic to accept numpy arrays model_config = ConfigDict(arbitrary_types_allowed=True) @@ -808,8 +886,6 @@ class DeviceMonitor1DMessage(BECMessage): data: np.ndarray timestamp: float = Field(default_factory=time.time) - metadata: dict | None = Field(default_factory=dict) - # Needed for pydantic to accept numpy arrays model_config = ConfigDict(arbitrary_types_allowed=True) @@ -867,7 +943,7 @@ class DeviceUserROIMessage(BECMessage): device: str signal: str roi_type: str = Field(description="Type of the ROI, e.g. 'rectangle', 'circle', 'polygon'") - roi: dict = Field( + roi: JsonableDict = Field( description="Dictionary containing the ROI information, e.g. {'x': 100, 'y': 200, 'width': 50, 'height': 50}" ) timestamp: float = Field(default_factory=time.time) @@ -887,7 +963,7 @@ class ScanMessage(BECMessage): msg_type: ClassVar[str] = "scan_message" point_id: int scan_id: str - data: dict + data: JsonableDict class ScanHistoryMessage(BECMessage): @@ -921,7 +997,7 @@ class ScanHistoryMessage(BECMessage): end_time: float scan_name: str num_points: int - request_inputs: dict | None = None + request_inputs: JsonableDict | None = None stored_data_info: dict[str, dict[str, _StoredDataInfo]] | None = None @@ -949,7 +1025,7 @@ class ScanBaselineMessage(BECMessage): msg_type: ClassVar[str] = "scan_baseline_message" scan_id: str - data: dict + data: JsonableDict ConfigAction = Literal["add", "set", "update", "reload", "remove", "reset", "cancel"] @@ -967,7 +1043,7 @@ class DeviceConfigMessage(BECMessage): msg_type: ClassVar[str] = "device_config_message" action: ConfigAction | None = Field(default=None, validate_default=True) - config: dict | None = Field(default=None) + 
config: JsonableDict | None = Field(default=None) @model_validator(mode="after") @classmethod @@ -1013,7 +1089,7 @@ class LogMessage(BECMessage): log_type: Literal[ "trace", "debug", "info", "success", "warning", "error", "critical", "console_log" ] - log_msg: dict | str + log_msg: JsonableDict | str class AlarmMessage(BECMessage): @@ -1128,8 +1204,8 @@ class FileContentMessage(BECMessage): msg_type: ClassVar[str] = "file_content_message" file_path: str - data: dict - scan_info: dict + data: JsonableDict + scan_info: JsonableDict class VariableMessage(BECMessage): @@ -1170,7 +1246,7 @@ class ServiceMetricMessage(BECMessage): msg_type: ClassVar[str] = "service_metric_message" name: str - metrics: dict + metrics: JsonableDict class ProcessedDataMessage(BECMessage): @@ -1182,7 +1258,7 @@ class ProcessedDataMessage(BECMessage): """ msg_type: ClassVar[str] = "processed_data_message" - data: dict | list[dict] + data: JsonableDict | list[JsonableDict] class DAPConfigMessage(BECMessage): @@ -1194,7 +1270,7 @@ class DAPConfigMessage(BECMessage): """ msg_type: ClassVar[str] = "dap_config_message" - config: dict + config: JsonableDict class DAPRequestMessage(BECMessage): @@ -1210,7 +1286,7 @@ class DAPRequestMessage(BECMessage): msg_type: ClassVar[str] = "dap_request_message" dap_cls: str dap_type: Literal["continuous", "on_demand"] - config: dict + config: JsonableDict class DAPResponseMessage(BECMessage): @@ -1240,7 +1316,7 @@ class AvailableResourceMessage(BECMessage): """ msg_type: ClassVar[str] = "available_resource_message" - resource: dict | list[dict] | BECMessage | list[BECMessage] + resource: JsonableDict | list[JsonableDict] | BECMessage | list[BECMessage] class ProgressMessage(BECMessage): @@ -1268,7 +1344,7 @@ class GUIConfigMessage(BECMessage): """ msg_type: ClassVar[str] = "gui_config_message" - config: dict + config: JsonableDict class GUIDataMessage(BECMessage): @@ -1280,7 +1356,7 @@ class GUIDataMessage(BECMessage): """ msg_type: ClassVar[str] = 
"gui_data_message" - data: dict + data: JsonableDict class GUIInstructionMessage(BECMessage): @@ -1293,7 +1369,7 @@ class GUIInstructionMessage(BECMessage): msg_type: ClassVar[str] = "gui_instruction_message" action: str - parameter: dict + parameter: JsonableDict class GUIAutoUpdateConfigMessage(BECMessage): @@ -1329,7 +1405,7 @@ class GUIRegistryStateMessage(BECMessage): "__rpc__", "container_proxy", ], - str | bool | dict | None, + str | bool | JsonableDict | None, ], ] @@ -1343,7 +1419,7 @@ class ServiceResponseMessage(BECMessage): """ msg_type: ClassVar[str] = "service_response_message" - response: dict + response: JsonableDict class CredentialsMessage(BECMessage): @@ -1355,7 +1431,7 @@ class CredentialsMessage(BECMessage): """ msg_type: ClassVar[str] = "credentials_message" - credentials: dict + credentials: JsonableDict class RawMessage(BECMessage): @@ -1368,7 +1444,7 @@ class RawMessage(BECMessage): """ msg_type: ClassVar[str] = "raw_message" - data: Any + data: Jsonable model_config = ConfigDict(arbitrary_types_allowed=True) @@ -1640,7 +1716,6 @@ class EndpointInfoMessage(BECMessage): msg_type: ClassVar[str] = "endpoint_info_message" endpoint: str - metadata: dict | None = Field(default_factory=dict) class ScriptExecutionInfoMessage(BECMessage): @@ -1673,8 +1748,6 @@ class MacroUpdateMessage(BECMessage): macro_name: str | None = None file_path: str | None = None - metadata: dict | None = Field(default_factory=dict) - @model_validator(mode="after") @classmethod def check_macro(cls, values): @@ -1769,7 +1842,6 @@ class MessagingServiceMessage(BECMessage): service_name: Literal["signal", "teams", "scilog"] message: list[MessagingServiceContent] scope: str | list[str] | None = None - metadata: dict | None = Field(default_factory=dict) class MessagingServiceConfig(BECMessage): @@ -1788,4 +1860,3 @@ class MessagingServiceConfig(BECMessage): service_name: Literal["signal", "teams", "scilog"] scopes: list[str] enabled: bool - metadata: dict | None = 
Field(default_factory=dict) diff --git a/bec_lib/bec_lib/one_way_registry.py b/bec_lib/bec_lib/one_way_registry.py new file mode 100644 index 000000000..d1f92a90e --- /dev/null +++ b/bec_lib/bec_lib/one_way_registry.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import Any, Callable, Type + +from bec_lib.device import DeviceBase + + +class OneWayBECCodec(ABC): + """Abstract base class for custom encoders""" + + obj_type: Type | list[Type] + + @staticmethod + @abstractmethod + def encode(obj: Any) -> Any: + """Encode an object into a serializable format.""" + + +class BECDeviceEncoder(OneWayBECCodec): + obj_type = DeviceBase + + @staticmethod + def encode(obj: DeviceBase) -> str: + if hasattr(obj, "_compile_function_path"): + # pylint: disable=protected-access + return obj._compile_function_path() + return obj.name + + +class OneWaySerializationRegistry: + """Registry for one-way (encode-only) serialization codecs""" + + def __init__(self): + self._registry: dict[str, tuple[Type, Callable]] = {} + + self.register_codec(BECDeviceEncoder) + + def register_codec(self, codec: Type[OneWayBECCodec]): + """ + Register the encoder of a OneWayBECCodec subclass. + This method allows for easy registration of custom encode-only codecs + for DeviceBase and other types. + + Args: + codec: A subclass of OneWayBECCodec that implements the encode method. + Raises: + ValueError: If a codec for the specified type is already registered. 
+ """ + if isinstance(codec.obj_type, list): + for cls in codec.obj_type: + self.register(cls, codec.encode) + else: + self.register(codec.obj_type, codec.encode) + + def register(self, cls: Type, encoder: Callable): + """Register a codec for a specific type.""" + + if cls.__name__ in self._registry: + raise ValueError(f"Codec for {cls} already registered.") + self._registry[cls.__name__] = (cls, encoder) + self.get_codec.cache_clear() # Clear the cache when a new codec is registered + + @lru_cache(maxsize=2000) + def get_codec(self, cls: Type) -> tuple[Type, Callable] | None: + """Get the codec for a specific type.""" + codec = self._registry.get(cls.__name__) + if codec: + return codec + for _, (registered_cls, encoder) in self._registry.items(): + if issubclass(cls, registered_cls): + return registered_cls, encoder + return None + + def is_registered(self, cls: Type) -> bool: + """ + Check if a codec is registered for a specific type. + Args: + cls: The class type to check for a registered codec. + Returns: + bool: True if a codec is registered for the type, False otherwise. 
+ """ + return self.get_codec(cls) is not None + + def encode(self, obj): + """Encode an object using the registered codec.""" + codec = self.get_codec(type(obj)) + if not codec: + return obj # No codec registered for this type + _, encoder = codec + try: + return encoder(obj) + except Exception as e: + raise ValueError( + f"Serialization failed: Failed to encode {obj.__class__.__name__} with codec {encoder}: {e}" + ) from e diff --git a/bec_lib/bec_lib/serialization_registry.py b/bec_lib/bec_lib/serialization_registry.py index e9aa923c8..7bcbfefcc 100644 --- a/bec_lib/bec_lib/serialization_registry.py +++ b/bec_lib/bec_lib/serialization_registry.py @@ -19,7 +19,6 @@ def __init__(self): self._legacy_codecs = [] # can be removed in future versions, see issue #516 self.register_codec(bec_codecs.BECMessageEncoder) - self.register_codec(bec_codecs.BECDeviceEncoder) self.register_codec(bec_codecs.EndpointInfoEncoder) self.register_codec(bec_codecs.SetEncoder) self.register_codec(bec_codecs.BECTypeEncoder) diff --git a/bec_lib/tests/test_bec_messages.py b/bec_lib/tests/test_bec_messages.py index 8439b6332..ac1c839df 100644 --- a/bec_lib/tests/test_bec_messages.py +++ b/bec_lib/tests/test_bec_messages.py @@ -429,15 +429,13 @@ def test_DeviceInstructionMessage(): def test_DeviceMonitor2DMessage(): # Test 2D data - msg = messages.DeviceMonitor2DMessage( - device="eiger", data=np.random.rand(2, 100), metadata=None - ) + msg = messages.DeviceMonitor2DMessage(device="eiger", data=np.random.rand(2, 100)) res = MsgpackSerialization.dumps(msg) res_loaded = MsgpackSerialization.loads(res) assert res_loaded == msg assert res_loaded.metadata == {} # Test rgb image, i.e. 
image with 3 channels - msg = messages.DeviceMonitor2DMessage(device="eiger", data=np.random.rand(3, 3), metadata=None) + msg = messages.DeviceMonitor2DMessage(device="eiger", data=np.random.rand(3, 3)) res = MsgpackSerialization.dumps(msg) res_loaded = MsgpackSerialization.loads(res) assert res_loaded == msg @@ -454,7 +452,7 @@ def test_DeviceMonitor2DMessage(): def test_DeviceMonitor1DMessage(): # Test 2D data - msg = messages.DeviceMonitor1DMessage(device="eiger", data=np.random.rand(100), metadata=None) + msg = messages.DeviceMonitor1DMessage(device="eiger", data=np.random.rand(100)) res = MsgpackSerialization.dumps(msg) res_loaded = MsgpackSerialization.loads(res) assert res_loaded == msg @@ -678,3 +676,9 @@ def test_valid_add_slice_various_indices(self, index): """Test various valid index values for add_slice type""" update = messages.DeviceAsyncUpdate(type="add_slice", max_shape=[None, 1024], index=index) assert update.index == index + + +def test_message_with_np_array_in_dict(): + arr = np.zeros(5) + msg = messages.ScanMessage(point_id=0, scan_id="", data={"device": {"value": arr}}, metadata={}) + assert isinstance(msg.data["device"]["value"], np.ndarray) diff --git a/bec_lib/tests/test_config_helper.py b/bec_lib/tests/test_config_helper.py index 72a9431e9..8cf830815 100644 --- a/bec_lib/tests/test_config_helper.py +++ b/bec_lib/tests/test_config_helper.py @@ -419,7 +419,8 @@ def test_config_helper_get_config_conflicts( config.update(dev_cfg) config_in_redis.append(config) with mock.patch.object(config_helper._device_manager.connector, "get") as mock_get: - mock_get.return_value = messages.AvailableResourceMessage(resource=config_in_redis) + available_resource_message = messages.AvailableResourceMessage(resource=config_in_redis) + mock_get.return_value = available_resource_message conflicts = config_helper._get_config_conflicts(new_config) assert conflicts == expected_conflicts diff --git a/bec_lib/tests/test_serializer.py b/bec_lib/tests/test_serializer.py 
index 93f5990de..e1a254e7b 100644 --- a/bec_lib/tests/test_serializer.py +++ b/bec_lib/tests/test_serializer.py @@ -10,6 +10,7 @@ from bec_lib.device import DeviceBase from bec_lib.devicemanager import DeviceManagerBase from bec_lib.endpoints import MessageEndpoints +from bec_lib.one_way_registry import OneWaySerializationRegistry from bec_lib.serialization import MsgpackSerialization, json_ext, msgpack @@ -81,10 +82,11 @@ class DummyModel(BaseModel): assert data.model_dump() == converted_data -def test_device_serializer(serializer): +def test_device_serializer(): + serializer = OneWaySerializationRegistry() device_manager = mock.MagicMock(spec=DeviceManagerBase) dummy = DeviceBase(name="dummy", parent=device_manager) - assert serializer.loads(serializer.dumps(dummy)) == "dummy" + assert serializer.encode(dummy) == "dummy" def test_enum_serializer(serializer): diff --git a/bec_server/tests/tests_scan_server/test_scan_worker.py b/bec_server/tests/tests_scan_server/test_scan_worker.py index 7e99f6037..615d9b02e 100644 --- a/bec_server/tests/tests_scan_server/test_scan_worker.py +++ b/bec_server/tests/tests_scan_server/test_scan_worker.py @@ -293,7 +293,7 @@ def test_initialize_scan_info(scan_worker_mock, msg): assert worker.current_scan_info["scan_msgs"] == [] assert worker.current_scan_info["monitor_sync"] == "bec" assert worker.current_scan_info["frames_per_trigger"] == 1 - assert worker.current_scan_info["args"] == {"samx": (-5, 5, 5), "samy": (-1, 1, 2)} + assert worker.current_scan_info["args"] == {"samx": [-5, 5, 5], "samy": [-1, 1, 2]} assert worker.current_scan_info["kwargs"] == msg.parameter["kwargs"] assert "samx" in worker.current_scan_info["readout_priority"]["monitored"] assert "samy" in worker.current_scan_info["readout_priority"]["baseline"] diff --git a/bec_server/tests/tests_scan_server/test_scans.py b/bec_server/tests/tests_scan_server/test_scans.py index 72ce9045a..902d78cb7 100644 --- a/bec_server/tests/tests_scan_server/test_scans.py +++ 
b/bec_server/tests/tests_scan_server/test_scans.py @@ -167,7 +167,7 @@ def offset_mock(): "RID": "0bab7ee3-b384-4571-b...0fff984c05", "devices": ["samx", "samy"], "start": [0, 0], - "end": np.array([1.0, 2.0]), + "end": [1.0, 2.0], } }, metadata={ @@ -211,7 +211,7 @@ def offset_mock(): "RID": "0bab7ee3-b384-4571-b...0fff984c05", "devices": ["samx", "samy", "samz"], "start": [0, 0, 0], - "end": np.array([1.0, 2.0, 3.0]), + "end": [1.0, 2.0, 3.0], } }, metadata={ @@ -264,7 +264,7 @@ def offset_mock(): "RID": "0bab7ee3-b384-4571-b...0fff984c05", "devices": ["samx"], "start": [0], - "end": np.array([1.0]), + "end": [1.0], } }, metadata={