diff --git a/bec_ipython_client/tests/end-2-end/test_scans_e2e.py b/bec_ipython_client/tests/end-2-end/test_scans_e2e.py index 238883c11..0b357668e 100644 --- a/bec_ipython_client/tests/end-2-end/test_scans_e2e.py +++ b/bec_ipython_client/tests/end-2-end/test_scans_e2e.py @@ -114,9 +114,9 @@ def test_mv_scan_nested_device(capsys, bec_ipython_client_fixture): bec.metadata.update({"unit_test": "test_mv_scan_nested_device"}) dev = bec.device_manager.devices scans.mv(dev.hexapod.x, 10, dev.hexapod.y, 20, relative=False).wait() - if not bec.connector._messages_queue.empty(): + while not bec.connector._messages_queue.empty(): print("Waiting for messages to be processed") - time.sleep(0.5) + time.sleep(0.1) current_pos_hexapod_x = dev.hexapod.x.read(cached=True)["hexapod_x"]["value"] current_pos_hexapod_y = dev.hexapod.y.read(cached=True)["hexapod_y"]["value"] assert np.isclose( @@ -126,9 +126,9 @@ def test_mv_scan_nested_device(capsys, bec_ipython_client_fixture): current_pos_hexapod_y, 20, atol=dev.hexapod._config["deviceConfig"].get("tolerance", 0.5) ) scans.umv(dev.hexapod.x, 10, dev.hexapod.y, 20, relative=False) - if not bec.connector._messages_queue.empty(): + while not bec.connector._messages_queue.empty(): print("Waiting for messages to be processed") - time.sleep(0.5) + time.sleep(0.1) current_pos_hexapod_x = dev.hexapod.x.read(cached=True)["hexapod_x"]["value"] current_pos_hexapod_y = dev.hexapod.y.read(cached=True)["hexapod_y"]["value"] captured = capsys.readouterr() diff --git a/bec_lib/bec_lib/bec_serializable.py b/bec_lib/bec_lib/bec_serializable.py new file mode 100644 index 000000000..6a6d84f0c --- /dev/null +++ b/bec_lib/bec_lib/bec_serializable.py @@ -0,0 +1,115 @@ +# pylint: disable=too-many-lines +from __future__ import annotations + +import base64 +from io import BytesIO +from types import UnionType +from typing import Annotated, Any, Callable, ClassVar, Union, get_args, get_origin + +import numpy as np +from pydantic import ( + BaseModel, + ConfigDict, + PlainSerializer, + WithJsonSchema, + computed_field, + model_validator, +) + +_NDARRAY_TAG = b"__NP_NDARRAY__" +_NDARRAY_TAG_STR = _NDARRAY_TAG.decode() +_NDARRAY_TAG_OFFSET = len(_NDARRAY_TAG) + + +def ndarray_to_bytes(arr: np.ndarray) -> bytes: + if not isinstance(arr, np.ndarray): + return arr + out_buf = BytesIO() + np.save(out_buf, arr) + return _NDARRAY_TAG + base64.urlsafe_b64encode(out_buf.getvalue()) + + +def numpy_decode(input: str | bytes): + is_str = isinstance(input, str) + is_bytes = isinstance(input, bytes) + # let pydantic handle any other validation or coercion + if not (is_str or is_bytes): + return input + if is_str and not input.startswith(_NDARRAY_TAG_STR): + return input + if is_bytes and not input.startswith(_NDARRAY_TAG): + return input + # strip the tag, decode, and load + io = BytesIO(base64.urlsafe_b64decode(input[_NDARRAY_TAG_OFFSET:])) + return np.load(io) + + +NumpyField = Annotated[ + np.ndarray, + PlainSerializer(ndarray_to_bytes), + WithJsonSchema({"type": "string", "contentEncoding": "base64"}), +] + + +def serialize_type(cls: type): + return cls.__name__ + + +class BecCodecInfo(BaseModel): + type_name: str + + +class BECSerializable(BaseModel): + + _deserialization_registry: ClassVar[ + list[tuple[tuple[type | Annotated, ...], Callable[[Any], Any]]] + ] = [((NumpyField, np.ndarray), numpy_decode)] + + model_config = ConfigDict( + json_schema_serialization_defaults_required=True, + json_encoders={np.ndarray: ndarray_to_bytes}, + arbitrary_types_allowed=True, + ) + + @computed_field() + @property + def bec_codec(self) -> BecCodecInfo: + return BecCodecInfo(type_name=self.__class__.__name__) + + @classmethod + def _try_apply_registry(cls, anno: type, data: dict, field: str): + for entry, deserializer in cls._deserialization_registry: + if anno in entry: + data[field] = deserializer(data[field]) + + @model_validator(mode="before") + @classmethod + def deser_custom(cls, data: dict[str, Any]): + for field in data: + if (field_info := cls.model_fields.get(field)) is not None: + if field_info.annotation is None: + continue # No need to do anything for NoneType + if get_origin(field_info.annotation) in [UnionType, Union]: + for arg in get_args(field_info.annotation): + cls._try_apply_registry(arg, data, field) + else: + cls._try_apply_registry(field_info.annotation, data, field) + return data + + +class BecWrappedValue(BECSerializable): + data: np.ndarray # can be extended, must be in registry + + def __getattr__(self, name: str) -> Any: + if hasattr(self.data, name): + return getattr(self.data, name) + else: + raise AttributeError( + f"{self.__class__.__name__} wrapping data type {type(self.data)} has no attribute {name} on either itself or its data." + ) + + def __getitem__(self, item): + if hasattr(self.data, "__getitem__"): + return self.data.__getitem__(item) + else: + raise AttributeError(f"Wrapped data type {type(self.data)} has no __getitem__.") diff --git a/bec_lib/bec_lib/codecs.py b/bec_lib/bec_lib/codecs.py index 0df6fb31a..6372f882e 100644 --- a/bec_lib/bec_lib/codecs.py +++ b/bec_lib/bec_lib/codecs.py @@ -1,17 +1,9 @@ from __future__ import annotations -import enum from abc import ABC, abstractmethod from typing import Any, Type -import numpy as np -from pydantic import BaseModel - -from bec_lib import messages as messages_module -from bec_lib import numpy_encoder from bec_lib.device import DeviceBase -from bec_lib.endpoints import EndpointInfo -from bec_lib.messages import BECMessage, BECStatus class BECCodec(ABC): @@ -24,64 +16,9 @@ class BECCodec(ABC): def encode(obj: Any) -> Any: """Encode an object into a serializable format.""" - @staticmethod - @abstractmethod - def decode(type_name: str, data: dict): - """Decode data into an object.""" - - -class NumpyEncoder(BECCodec): - obj_type: list[Type] = [np.ndarray, np.bool_, np.number, complex] - - @staticmethod - def encode(obj: np.ndarray) -> dict: - return numpy_encoder.numpy_encode(obj) - - @staticmethod - def decode(type_name: str, data: dict) -> np.ndarray: - return numpy_encoder.numpy_decode(data) - - -class NumpyEncoderList(BECCodec): - obj_type: list[Type] = [np.ndarray, np.bool_, np.number, complex] - - @staticmethod - def encode(obj: np.ndarray) -> dict: - return numpy_encoder.numpy_encode_list(obj) - - @staticmethod - def decode(type_name: str, data: dict) -> np.ndarray: - return numpy_encoder.numpy_decode_list(data) - - -class BECMessageEncoder(BECCodec): - obj_type: Type = BECMessage - - @staticmethod - def encode(obj: BECMessage) -> dict: - return obj.__dict__ - - @staticmethod - def decode(type_name: str, data: dict) -> BECMessage: - return getattr(messages_module, type_name)(**data) - - -class EnumEncoder(BECCodec): - obj_type: Type = enum.Enum - - @staticmethod - def encode(obj: enum.Enum) -> Any: - return obj.value - - @staticmethod - def decode(type_name: str, data: Any) -> Any: - if type_name == "BECStatus": - return BECStatus(data) - return data - class BECDeviceEncoder(BECCodec): - obj_type: Type = DeviceBase + obj_type = DeviceBase @staticmethod def encode(obj: DeviceBase) -> str: @@ -89,72 +26,3 @@ def encode(obj: DeviceBase) -> str: # pylint: disable=protected-access return obj._compile_function_path() return obj.name - - @staticmethod - def decode(type_name: str, data: str) -> str: - """ - DeviceBase objects are encoded as strings. No decoding is necessary. - """ - return data - - -class PydanticEncoder(BECCodec): - obj_type: Type = BaseModel - - @staticmethod - def encode(obj: BaseModel) -> dict: - return obj.model_dump() - - @staticmethod - def decode(type_name: str, data: dict) -> dict: - return data - - -class EndpointInfoEncoder(BECCodec): - obj_type: Type = EndpointInfo - - @staticmethod - def encode(obj: EndpointInfo) -> dict: - return { - "endpoint": obj.endpoint, - "message_type": obj.message_type.__name__, - "message_op": obj.message_op, - } - - @staticmethod - def decode(type_name: str, data: dict) -> EndpointInfo: - return EndpointInfo( - endpoint=data["endpoint"], - message_type=getattr(messages_module, data["message_type"]), - message_op=data["message_op"], - ) - - -class SetEncoder(BECCodec): - obj_type: Type = set - - @staticmethod - def encode(obj: set) -> list: - return list(obj) - - @staticmethod - def decode(type_name: str, data: list) -> set: - return set(data) - - -class BECTypeEncoder(BECCodec): - obj_type: Type = type - - @staticmethod - def encode(obj: type) -> dict: - return {"type_name": obj.__name__, "module": obj.__module__} - - @staticmethod - def decode(type_name: str, data: dict) -> type: - if data["module"] == "builtins": - return __builtins__.get(data["type_name"]) - if data["module"] == "bec_lib.messages": - return getattr(messages_module, data["type_name"]) - if data["module"] == "numpy": - return getattr(np, data["type_name"]) - raise ValueError(f"Unknown type {data['type_name']} in module {data['module']}") diff --git a/bec_lib/bec_lib/device.py b/bec_lib/bec_lib/device.py index 3ccfbf40b..0fbf9a30b 100644 --- a/bec_lib/bec_lib/device.py +++ b/bec_lib/bec_lib/device.py @@ -24,6 +24,7 @@ from bec_lib.atlas_models import _DeviceModelCore from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger +from bec_lib.messages import DeviceMessage from bec_lib.queue_items import QueueItem from bec_lib.utils.import_utils import lazy_import @@ -1111,12 +1112,15 @@ def limits(self): """ Returns the device limits. """ - limit_msg = self.root.parent.connector.get(MessageEndpoints.device_limits(self.root.name)) + limit_msg: DeviceMessage = self.root.parent.connector.get( + MessageEndpoints.device_limits(self.root.name) + ) if not limit_msg: return [0, 0] + limits = [ - limit_msg.content["signals"].get("low", {}).get("value", 0), - limit_msg.content["signals"].get("high", {}).get("value", 0), + limit_msg.signals.get("low", {}).get("value", 0), + limit_msg.signals.get("high", {}).get("value", 0), ] return limits diff --git a/bec_lib/bec_lib/endpoints.py b/bec_lib/bec_lib/endpoints.py index 6b50e8408..a6bf09162 100644 --- a/bec_lib/bec_lib/endpoints.py +++ b/bec_lib/bec_lib/endpoints.py @@ -5,23 +5,18 @@ from __future__ import annotations import enum -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, TypeVar +from functools import lru_cache +from typing import Annotated, Any, Callable, ClassVar, Generic, TypeVar -from bec_lib.utils.import_utils import lazy_import +from pydantic import PlainSerializer + +from bec_lib import messages +from bec_lib.bec_serializable import BECSerializable, serialize_type # pylint: disable=too-many-public-methods # pylint: disable=too-many-lines -if TYPE_CHECKING: # pragma: no cover - from bec_lib import messages -else: - # TODO: put back normal import when Pydantic gets faster - # from bec_lib import messages - messages = lazy_import("bec_lib.messages") - - class EndpointType(str, enum.Enum): """Endpoint type enum""" @@ -49,8 +44,21 @@ class MessageOp(list[str], enum.Enum): MessageType = TypeVar("MessageType", bound="type[messages.BECMessage]") -@dataclass -class EndpointInfo(Generic[MessageType]): +@lru_cache() +def _resolve_msg_type(n: type[messages.BECMessage] | str): + if n is Any or n == "Any": + return Any + if isinstance(n, str): + return messages.__dict__.get(n, messages.RawMessage) + if issubclass(n, messages.BECMessage): + return n + raise TypeError(f"Invalid argument for _resolve_msg_type: {n}") + + +MesageAnno = Annotated[MessageType, PlainSerializer(serialize_type)] + + +class EndpointInfo(BECSerializable, Generic[MessageType]): """ Dataclass for endpoint info. @@ -60,8 +68,10 @@ class EndpointInfo(Generic[MessageType]): message_op (MessageOp): Message operation. """ + _deserialization_registry = [((MesageAnno,), _resolve_msg_type)] + endpoint: str - message_type: MessageType + message_type: MesageAnno | Any message_op: MessageOp diff --git a/bec_lib/bec_lib/logger.py b/bec_lib/bec_lib/logger.py index 201efb3a8..51a79de7a 100644 --- a/bec_lib/bec_lib/logger.py +++ b/bec_lib/bec_lib/logger.py @@ -16,8 +16,8 @@ # TODO: Importing bec_lib, instead of `from bec_lib.messages import LogMessage`, avoids potential # logger <-> messages circular import. But there could be a better solution. import bec_lib +import bec_lib.endpoints from bec_lib.bec_errors import ServiceConfigError -from bec_lib.endpoints import MessageEndpoints from bec_lib.utils.import_utils import lazy_import_from if TYPE_CHECKING: # pragma: no cover @@ -210,7 +210,7 @@ def _logger_callback(self, msg): msg["service_name"] = self.service_name try: self.connector.xadd( - topic=MessageEndpoints.log(), + topic=bec_lib.endpoints.MessageEndpoints.log(), msg_dict={ "data": bec_lib.messages.LogMessage( log_type=msg["record"]["level"]["name"].lower(), log_msg=msg diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index 23bed4c58..23fce1ee5 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -3,18 +3,30 @@ import time import uuid -import warnings from copy import deepcopy from enum import Enum, auto from importlib.metadata import PackageNotFoundError from importlib.metadata import version as importlib_version -from typing import Any, ClassVar, Literal, Self, Union +from typing import Any, ClassVar, Literal, Self from uuid import uuid4 import numpy as np -from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from bec_lib.metadata_schema import get_metadata_schema_for_scan +from bec_lib.bec_serializable import BECSerializable, NumpyField + +# TODO: restore when moved to external repo +# from bec_lib.metadata_schema import get_metadata_schema_for_scan + +logger = None + + +def lazy_ensure_logger(): + global logger + if logger is None: + from bec_lib.logger import bec_logger + + logger = bec_logger.logger class ProcedureWorkerStatus(Enum): @@ -34,7 +46,7 @@ class BECStatus(Enum): ERROR = -1 -class BECMessage(BaseModel): +class BECMessage(BECSerializable): """Base Model class for BEC Messages Args: @@ -43,19 +55,9 @@ class BECMessage(BaseModel): """ - msg_type: ClassVar[str] + msg_type: ClassVar[str] = "bec_message" metadata: dict = Field(default_factory=dict) - @field_validator("metadata") - @classmethod - def check_metadata(cls, v): - """Validate the metadata, return empty dict if None - - Args: - v (dict, None): Metadata dictionary - """ - return v or {} - @property def content(self): """Return the content of the message""" @@ -75,20 +77,6 @@ def __eq__(self, other): return self.msg_type == other.msg_type and self.metadata == other.metadata - def loads(self): - warnings.warn( - "BECMessage.loads() is deprecated and should not be used anymore. When calling Connector methods, it can be omitted. When a message needs to be deserialized call the appropriate function from bec_lib.serialization", - FutureWarning, - ) - return self - - def dumps(self): - warnings.warn( - "BECMessage.dumps() is deprecated and should not be used anymore. When calling Connector methods, it can be omitted. When a message needs to be serialized call the appropriate function from bec_lib.serialization", - FutureWarning, - ) - return self - def __hash__(self) -> int: return self.model_dump_json().__hash__() @@ -140,6 +128,7 @@ class ScanQueueMessage(BECMessage): """ msg_type: ClassVar[str] = "scan_queue_message" + scan_type: str parameter: dict queue: str = Field(default="primary") @@ -148,19 +137,20 @@ class ScanQueueMessage(BECMessage): description="Whether the server is allowed to restart the scan if needed. If False, only a ScanRestartMessage will be sent.", ) - @model_validator(mode="after") - @classmethod - def _validate_metadata(cls, data): - """Make sure the metadata conforms to the registered schema, but - leave it as a dict""" - schema = get_metadata_schema_for_scan(data.scan_type) - try: - schema.model_validate(data.metadata.get("user_metadata", {})) - except ValidationError as e: - raise ValueError( - f"Scan metadata {data.metadata} does not conform to registered schema {schema}. \n Errors: {str(e)}" - ) from e - return data + # TODO: restore when moved to external repo + # @model_validator(mode="after") + # @classmethod + # def _validate_metadata(cls, data): + # """Make sure the metadata conforms to the registered schema, but + # leave it as a dict""" + # schema = get_metadata_schema_for_scan(data.scan_type) + # try: + # schema.model_validate(data.metadata.get("user_metadata", {})) + # except ValidationError as e: + # raise ValueError( + # f"Scan metadata {data.metadata} does not conform to registered schema {schema}. \n Errors: {str(e)}" + # ) from e + # return data class ScanQueueHistoryMessage(BECMessage): @@ -230,12 +220,11 @@ class ScanStatusMessage(BECMessage): dict[Literal["monitored", "baseline", "async", "continuous", "on_request"], list[str]] | None ) = None - scan_parameters: dict[ - Literal["exp_time", "frames_per_trigger", "settling_time", "readout_time"] | str, Any - ] = Field(default_factory=dict) - request_inputs: dict[Literal["arg_bundle", "inputs", "kwargs"], Any] = Field( - default_factory=dict - ) + scan_parameters: ( + dict[Literal["exp_time", "frames_per_trigger", "settling_time", "readout_time"] | str, Any] + | None + ) = None + request_inputs: dict[Literal["arg_bundle", "inputs", "kwargs"], Any] | None = None info: dict timestamp: float = Field(default_factory=time.time) @@ -326,7 +315,7 @@ class ScanQueueOrderMessage(BECMessage): target_position: int | None = None -class RequestBlock(BaseModel): +class RequestBlock(BECSerializable): """ Model for a request block within a scan queue entry. It represents a single request in the scan queue, e.g. a single scan or rpc call. @@ -354,7 +343,7 @@ class RequestBlock(BaseModel): report_instructions: list[dict] | None = None -class QueueInfoEntry(BaseModel): +class QueueInfoEntry(BECSerializable): """ Model for scan queue information entries. It represents a single queue element within a scan queue but may contain multiple request blocks. @@ -390,7 +379,7 @@ class ScanQueueLock(BaseModel): identifier: str -class ScanQueueStatus(BaseModel): +class ScanQueueStatus(BECSerializable): """ Model for scan queue status information. It represents the status of a single queue, e.g. "primary" or "interception". @@ -553,7 +542,7 @@ class DeviceInstructionMessage(BECMessage): parameter: dict -class ErrorInfo(BaseModel): +class ErrorInfo(BECSerializable): id: str = Field(default_factory=lambda: str(uuid.uuid4())) error_message: str compact_error_message: str | None @@ -577,6 +566,34 @@ def _ensure_error_info_if_error(self): return self +class SignalReading(BECSerializable): + value: int | float | list[int] | list[float] | NumpyField | Any + timestamp: float | list[float] | None = None + + def to_dict(self): + lazy_ensure_logger() + logger.warning( + "Dictionary usage of SignalReading is deprecated; please replace it with a different access pattern." + ) + return {"value": self.value, "timestamp": self.timestamp} + + def get(self, item: Literal["value", "timestamp"], default=Any): + """Allow dictionary-style access for legacy reasons.""" + lazy_ensure_logger() + logger.warning( + "Get-access on SignalReading is deprecated; Just access the model.value field." + ) + if item not in ["value", "timestamp"]: + raise KeyError('SignalReading only has "value" and "timestamp" fields!') + return getattr(self, item) + + def __getitem__(self, item: str): + return self.get(item) + + def items(self): + return self.to_dict().items() + + class DeviceMessage(BECMessage): """Message type for sending device readings from the device server @@ -589,7 +606,7 @@ class DeviceMessage(BECMessage): """ msg_type: ClassVar[str] = "device_message" - signals: dict[str, dict[Literal["value", "timestamp"], Any]] + signals: dict[str, SignalReading] @field_validator("metadata") @classmethod @@ -606,7 +623,7 @@ def check_metadata(cls, v): return v -class DeviceAsyncUpdate(BaseModel): +class DeviceAsyncUpdate(BECSerializable): """Model for validating async update metadata sent with device data. The async update metadata controls how data is aggregated into datasets during a scan: @@ -697,6 +714,7 @@ class DeviceRPCMessage(BECMessage): """ msg_type: ClassVar[str] = "device_rpc_message" + device: str return_val: Any out: str | dict | ErrorInfo @@ -757,20 +775,15 @@ class DeviceMonitor2DMessage(BECMessage): Args: device (str): Device name. - data (np.ndarray): Numpy array data from the monitor + data (NumpyField): Numpy array data from the monitor metadata (dict, optional): Additional metadata. """ msg_type: ClassVar[str] = "device_monitor2d_message" device: str - data: np.ndarray timestamp: float = Field(default_factory=time.time) - - metadata: dict | None = Field(default_factory=dict) - - # Needed for pydantic to accept numpy arrays - model_config = ConfigDict(arbitrary_types_allowed=True) + data: NumpyField @field_validator("data") @classmethod @@ -778,7 +791,7 @@ def check_data(cls, v: np.ndarray): """Validate the entry in data. Has to be a 2D numpy array Args: - v (np.ndarray): data array + v (NumpyField): data array """ if not isinstance(v, np.ndarray): raise ValueError(f"Invalid array type: {type(v)}. Must be a numpy array.") @@ -798,20 +811,15 @@ class DeviceMonitor1DMessage(BECMessage): Args: device (str): Device name. - data (np.ndarray): Numpy array data from the monitor + data (NumpyField): Numpy array data from the monitor metadata (dict, optional): Additional metadata. """ msg_type: ClassVar[str] = "device_monitor1d_message" device: str - data: np.ndarray timestamp: float = Field(default_factory=time.time) - - metadata: dict | None = Field(default_factory=dict) - - # Needed for pydantic to accept numpy arrays - model_config = ConfigDict(arbitrary_types_allowed=True) + data: NumpyField @field_validator("data") @classmethod @@ -819,7 +827,7 @@ def check_data(cls, v: np.ndarray): """Validate the entry in data. Has to be a 2D numpy array Args: - v (np.ndarray): data array + v (NumpyField): data array """ if not isinstance(v, np.ndarray): raise ValueError(f"Invalid array type: {type(v)}. Must be a numpy array.") @@ -837,7 +845,7 @@ class DevicePreviewMessage(BECMessage): Args: device (str): Device name. signal (str): Signal name, e.g. "image", "data", "preview". - data (np.ndarray): Numpy array data from the preview. + data (NumpyField): Numpy array data from the preview. timestamp (float, optional): Timestamp of the message. Defaults to time.time(). metadata (dict, optional): Additional metadata. """ @@ -845,10 +853,8 @@ class DevicePreviewMessage(BECMessage): msg_type: ClassVar[str] = "device_preview_message" device: str signal: str - data: np.ndarray timestamp: float = Field(default_factory=time.time) - # Needed for pydantic to accept numpy arrays - model_config = ConfigDict(arbitrary_types_allowed=True) + data: NumpyField class DeviceUserROIMessage(BECMessage): @@ -898,8 +904,7 @@ class ScanHistoryMessage(BECMessage): scan_number (int): Scan number. dataset_number (int): Dataset number. file_path (str): Path to the file. - exit_status (Literal["closed", "aborted", "halted", "user_completed"]): Exit status of the scan. - reason (Literal["user", "alarm"] | None, optional): Reason for the exit status, if applicable. + exit_status (Literal["closed", "aborted", "halted"]): Exit status of the scan. start_time (float): Start time of the scan. end_time (float): End time of the scan. scan_name (str): Name of the scan. @@ -915,8 +920,7 @@ class ScanHistoryMessage(BECMessage): scan_number: int dataset_number: int file_path: str - exit_status: Literal["closed", "aborted", "halted", "user_completed"] - reason: Literal["user", "alarm"] | None = None + exit_status: Literal["closed", "aborted", "halted"] start_time: float end_time: float scan_name: str @@ -1031,7 +1035,7 @@ class AlarmMessage(BECMessage): info: ErrorInfo -class ServiceVersions(BaseModel): +class ServiceVersions(BECSerializable): _versions: ClassVar[Self | None] = None bec_lib: str @@ -1059,7 +1063,7 @@ def _get_safe_version(package: str) -> str: return cls._versions -class ServiceInfo(BaseModel): +class ServiceInfo(BECSerializable): user: str hostname: str timestamp: float = Field(default_factory=time.time) @@ -1220,7 +1224,7 @@ class DAPResponseMessage(BECMessage): success (bool): True if the request was successful data (tuple, optional): DAP data (tuple of data (dict) and metadata). Defaults to ({} , None). error (str, optional): DAP error. Defaults to None. - dap_request (BECMessage, None): DAP request. Defaults to None. + dap_request (DAPRequestMessage, None): DAP request. Defaults to None. metadata (dict, optional): Metadata. Defaults to None. """ @@ -1228,19 +1232,21 @@ class DAPResponseMessage(BECMessage): success: bool data: tuple | None = Field(default_factory=lambda: ({}, None)) error: str | None = None - dap_request: BECMessage | None = Field(default=None) + dap_request: DAPRequestMessage | None = Field(default=None) class AvailableResourceMessage(BECMessage): """Message for available resources such as scans, data processing plugins etc Args: - resource (dict, list[dict], BECMessage, list[BECMessage]): Resource description + resource (dict, list[dict]): Resource description metadata (dict, optional): Metadata. Defaults to None. """ msg_type: ClassVar[str] = "available_resource_message" - resource: dict | list[dict] | BECMessage | list[BECMessage] + resource: ( + dict | list[dict] | list[MessagingServiceConfig] + ) # | BECSerializable | list[BECSerializable] class ProgressMessage(BECMessage): @@ -1640,7 +1646,6 @@ class EndpointInfoMessage(BECMessage): msg_type: ClassVar[str] = "endpoint_info_message" endpoint: str - metadata: dict | None = Field(default_factory=dict) class ScriptExecutionInfoMessage(BECMessage): @@ -1673,8 +1678,6 @@ class MacroUpdateMessage(BECMessage): macro_name: str | None = None file_path: str | None = None - metadata: dict | None = Field(default_factory=dict) - @model_validator(mode="after") @classmethod def check_macro(cls, values): @@ -1707,6 +1710,8 @@ class MessagingServiceFileContent(BaseModel): data (bytes): File data """ + model_config = ConfigDict(ser_json_bytes="base64", val_json_bytes="base64") + filename: str mime_type: str data: bytes @@ -1745,13 +1750,13 @@ class MessagingServiceGiphyContent(BaseModel): giphy_url: str -MessagingServiceContent = Union[ - MessagingServiceTextContent, - MessagingServiceFileContent, - MessagingServiceTagsContent, - MessagingServiceStickerContent, - MessagingServiceGiphyContent, -] +MessagingServiceContent = ( + MessagingServiceTextContent + | MessagingServiceFileContent + | MessagingServiceTagsContent + | MessagingServiceStickerContent + | MessagingServiceGiphyContent +) class MessagingServiceMessage(BECMessage): @@ -1769,7 +1774,6 @@ class MessagingServiceMessage(BECMessage): service_name: Literal["signal", "teams", "scilog"] message: list[MessagingServiceContent] scope: str | list[str] | None = None - metadata: dict | None = Field(default_factory=dict) class MessagingServiceConfig(BECMessage): @@ -1788,4 +1792,3 @@ class MessagingServiceConfig(BECMessage): service_name: Literal["signal", "teams", "scilog"] scopes: list[str] enabled: bool - metadata: dict | None = Field(default_factory=dict) diff --git a/bec_lib/bec_lib/redis_connector.py b/bec_lib/bec_lib/redis_connector.py index 7cfeca3fc..f7cff6a45 100644 --- a/bec_lib/bec_lib/redis_connector.py +++ b/bec_lib/bec_lib/redis_connector.py @@ -1148,10 +1148,7 @@ def get(self, topic: str, pipe: Pipeline | None = None): if pipe: return data else: - try: - return MsgpackSerialization.loads(data) - except RuntimeError: - return data + return MsgpackSerialization.loads(data) def mget(self, topics: list[str], pipe: Pipeline | None = None): """retrieve multiple entries""" diff --git a/bec_lib/bec_lib/serialization.py b/bec_lib/bec_lib/serialization.py index f512c279e..b09f8a4fe 100644 --- a/bec_lib/bec_lib/serialization.py +++ b/bec_lib/bec_lib/serialization.py @@ -7,57 +7,57 @@ import contextlib import gc import json -from abc import abstractmethod +import types +from functools import lru_cache +from typing import Any -import msgpack as msgpack_module +import msgpack as msgpack_mod +from pydantic import BaseModel -from bec_lib import messages as messages_module -from bec_lib.logger import bec_logger -from bec_lib.messages import BECMessage +from bec_lib import bec_serializable, endpoints, messages +from bec_lib.bec_serializable import BECSerializable, BecWrappedValue +from bec_lib.device import DeviceBase +from bec_lib.messages import BECMessage, BundleMessage, RawMessage from bec_lib.serialization_registry import SerializationRegistry -logger = bec_logger.logger - -class SerializationInterface: - """Base class for message serialization""" - - @abstractmethod - def loads(self, msg, **kwargs) -> dict: - """load and de-serialize a message""" - - @abstractmethod - def dumps(self, msg, **kwargs) -> str: - """serialize a message""" - - -class BECMessagePack(SerializationRegistry): - """Encapsulates msgpack dumps/loads with extensions""" - - def dumps(self, obj): - """Pack object `obj` and return packed bytes.""" - return msgpack_module.packb(obj, default=self.encode) - - def loads(self, raw_bytes): - """Unpack bytes and return the decoded object.""" - out = msgpack_module.unpackb( - raw_bytes, raw=False, strict_map_key=True, object_hook=self.decode - ) - return out - - -class BECJson(SerializationRegistry): - """Encapsulates JSON dumps/loads with extensions""" - - use_json = True - - def dumps(self, obj, indent: int | None = None) -> str: - """Pack object `obj` and return packed bytes.""" - return json.dumps(obj, default=self.encode, indent=indent) - - def loads(self, raw_bytes): - """Unpack bytes and return the decoded object.""" - return json.loads(raw_bytes, object_hook=self.decode) +@lru_cache(maxsize=2048) +def _get_type(type_name: str) -> type[BECSerializable] | None: + for mod in BecSerializableCodec.registry.values(): + if (T := mod.__dict__.get(type_name)) is not None: + if not issubclass(T, BECSerializable): + raise RuntimeError( + f"BecSerializableCodec found type {T} in module {mod} from bec_codec type info '{type_name}'. Please ensure another type isn't shadowing the correct one." + ) + return T + + +class BecSerializableCodec: + # dicts are ordered by insertion, so later additions will not override these. + registry: dict[str, types.ModuleType] = { + "messages": messages, + "endpoints": endpoints, + "bec_serializable": bec_serializable, + } + + @classmethod + def register_module(cls, mod: types.ModuleType) -> None: + if mod.__name__ in cls.registry: + raise ValueError(f"A module named {mod.__name__} is already registered!") + cls.registry[mod.__name__] = mod + _get_type.cache_clear() + + @classmethod + def encode(cls, obj: BaseModel) -> dict: + return obj.model_dump(mode="json") + + @classmethod + def decode(cls, type_name: str, data: dict) -> BECSerializable | dict: + if (BecType := _get_type(type_name)) is not None: + if BecType is BecWrappedValue: + return BecType.model_validate(data).data + return BecType.model_validate(data) + return data @contextlib.contextmanager @@ -75,35 +75,67 @@ def pause_gc(): gc.enable() -class MsgpackSerialization(SerializationInterface): - """Message serialization using msgpack encoding""" +def _msg_object_hook(msg: dict): + bec_type_name: str | None = msg.get("bec_codec", {}).get("type_name") + if bec_type_name is None: + return msg + return BecSerializableCodec.decode(bec_type_name, msg) - ext_type_offset_to_data = {199: 3, 200: 4, 201: 6} + +def _one_way_encoding(val: Any) -> Any: + # TODO hacky fix, tidy up + try: + return msgpack.encode(val) + except NoCodec: + try: + return BecWrappedValue(data=val) + except: + ... + except Exception as e: + raise TypeError(f"Type {type(val)} not supported for serialization!") from e + + +class MsgpackSerialization(SerializationRegistry): + """Message serialization using msgpack encoding""" @staticmethod - def loads(msg) -> BECMessage | list[BECMessage]: + def loads(msg: bytes) -> BECMessage | list[BECMessage] | Any: + if msg is None: + return None with pause_gc(): try: - msg = msgpack.loads(msg) - except Exception as exception: + msg_ = msgpack_mod.loads(msg, object_hook=_msg_object_hook) + except Exception as e: try: - data = json.loads(msg) - return messages_module.RawMessage(data=data) + return RawMessage(data=json.loads(msg, object_hook=_msg_object_hook)) except Exception: - pass - raise RuntimeError("Failed to decode BECMessage") from exception + raise RuntimeError(f"Failed to decode BECMessage: {msg}") from e else: - if isinstance(msg, BECMessage): - if msg.msg_type == "bundle_message": - return msg.messages - return msg + if isinstance(msg_, BundleMessage): + return msg_.messages + return msg_ @staticmethod - def dumps(msg, version=None) -> str: - if version is None or version == 1.2: - return msgpack.dumps(msg) - raise RuntimeError(f"Unsupported BECMessage version {version}.") + def dumps(msg: BECMessage | Any) -> str: + if not isinstance(msg, BECSerializable): + return msgpack_mod.dumps(msg) # type: ignore + return msgpack_mod.dumps(msg.model_dump(mode="json", fallback=_one_way_encoding)) # type: ignore + +# TODO change name and tidy up +msgpack = MsgpackSerialization() -msgpack = BECMessagePack() -json_ext = BECJson() + +class json_ext: + """Message serialization using json encoding""" + + @staticmethod + def loads(msg) -> BECMessage | list[BECMessage] | Any: + with pause_gc(): + return json.loads(msg, object_hook=_msg_object_hook) + + @staticmethod + def dumps(msg: BECMessage | Any, indent: int = 0) -> str: + if not isinstance(msg, BECSerializable): + return json.dumps(msg, indent=indent) # type: ignore + return msg.model_dump_json(indent=indent) diff --git a/bec_lib/bec_lib/serialization_registry.py b/bec_lib/bec_lib/serialization_registry.py index e9aa923c8..87492776f 100644 --- a/bec_lib/bec_lib/serialization_registry.py +++ b/bec_lib/bec_lib/serialization_registry.py @@ -9,27 +9,16 @@ logger = bec_logger.logger +class NoCodec(Exception): ... + + class SerializationRegistry: """Registry for serialization codecs""" - use_json = False - def __init__(self): - self._registry: dict[str, tuple[Type, Callable, Callable]] = {} - self._legacy_codecs = [] # can be removed in future versions, see issue #516 + self._registry: dict[str, tuple[Type, Callable]] = {} - self.register_codec(bec_codecs.BECMessageEncoder) self.register_codec(bec_codecs.BECDeviceEncoder) - self.register_codec(bec_codecs.EndpointInfoEncoder) - self.register_codec(bec_codecs.SetEncoder) - self.register_codec(bec_codecs.BECTypeEncoder) - self.register_codec(bec_codecs.PydanticEncoder) - self.register_codec(bec_codecs.EnumEncoder) - - if self.use_json: - self.register_codec(bec_codecs.NumpyEncoderList) - else: - self.register_codec(bec_codecs.NumpyEncoder) def register_codec(self, codec: Type[bec_codecs.BECCodec]): """ @@ -44,27 +33,27 @@ def register_codec(self, codec: Type[bec_codecs.BECCodec]): """ if isinstance(codec.obj_type, list): for cls in codec.obj_type: - self.register(cls, codec.encode, codec.decode) + self.register(cls, codec.encode) else: - self.register(codec.obj_type, codec.encode, codec.decode) + self.register(codec.obj_type, codec.encode) - def register(self, cls: Type, encoder: Callable, decoder: Callable): + def register(self, cls: Type, encoder: Callable, *_): # hacky fix for BW compat """Register a codec for a specific type.""" if cls.__name__ in self._registry: raise ValueError(f"Codec for {cls} already registered.") - self._registry[cls.__name__] = (cls, encoder, decoder) + self._registry[cls.__name__] = (cls, encoder) self.get_codec.cache_clear() # Clear the cache when a new codec is registered @lru_cache(maxsize=2000) - def get_codec(self, cls: Type) -> tuple[Type, Callable, Callable] | None: + def get_codec(self, cls: Type) -> tuple[Type, Callable] | None: """Get the codec for a specific type.""" codec = self._registry.get(cls.__name__) if codec: return codec - for _, (registered_cls, encoder, decoder) in self._registry.items(): + for _, (registered_cls, encoder) in self._registry.items(): if issubclass(cls, registered_cls): - return registered_cls, encoder, decoder + return registered_cls, encoder return None def is_registered(self, cls: Type) -> bool: @@ -81,33 +70,11 @@ def encode(self, obj): """Encode an object using the registered codec.""" codec = self.get_codec(type(obj)) if not codec: - return obj # No codec registered for this type - cls, encoder, _ = codec - try: - return { - "__bec_codec__": { - "encoder_name": cls.__name__, - "type_name": obj.__class__.__name__, - "data": encoder(obj), - } - } - except Exception as e: - raise ValueError( - f"Serialization failed: Failed to encode {obj.__class__.__name__} with codec {encoder}: {e}" - ) from e - - def decode(self, data): - """Decode an object using the registered codec.""" - if not isinstance(data, dict) or "__bec_codec__" not in data: - return data - codec_info = data["__bec_codec__"] - codec_type = codec_info.pop("encoder_name") - if not codec_type or codec_type not in self._registry: - return data - _, _, decoder = self._registry[codec_type] + raise NoCodec() # No codec registered for this type try: - return decoder(**codec_info) + _, encoder = codec + return encoder(obj) except Exception as e: raise ValueError( - f"Deserialization failed: Failed to decode {codec_type} with codec {decoder}: {e}" + f"Serialization failed: Failed to encode {obj.__class__.__name__} with codec {codec}: {e}" ) from e diff --git a/bec_lib/bec_lib/tests/utils.py b/bec_lib/bec_lib/tests/utils.py index a318be1bb..43c3350aa 100644 --- a/bec_lib/bec_lib/tests/utils.py +++ b/bec_lib/bec_lib/tests/utils.py @@ -860,3 +860,7 @@ def redis_server_is_running(self): def get_last(self, topic, key): return None + + +def _endpoint_info(ep, msg_t, msg_op): + return EndpointInfo(endpoint=ep, message_type=msg_t, message_op=msg_op) diff --git a/bec_lib/tests/test_bec_messages.py b/bec_lib/tests/test_bec_messages.py index 8439b6332..3908a4b55 100644 --- a/bec_lib/tests/test_bec_messages.py +++ b/bec_lib/tests/test_bec_messages.py @@ -6,32 +6,17 @@ from bec_lib.serialization import MsgpackSerialization -@pytest.mark.parametrize("version", [1.0, 1.1, 1.2, None]) -def test_bec_message_msgpack_serialization_version(version): - msg = messages.DeviceInstructionMessage( - device="samx", action="set", parameter={"set": 0.5}, metadata={"RID": "1234"} - ) - if version is not None and version < 1.2: - with pytest.raises(RuntimeError) as exception: - MsgpackSerialization.dumps(msg, version=version) - assert "Unsupported BECMessage version" in str(exception.value) - else: - res = MsgpackSerialization.dumps(msg) - res_expected = b"\x81\xad__bec_codec__\x83\xacencoder_name\xaaBECMessage\xa9type_name\xb8DeviceInstructionMessage\xa4data\x84\xa8metadata\x81\xa3RID\xa41234\xa6device\xa4samx\xa6action\xa3set\xa9parameter\x81\xa3set\xcb?\xe0\x00\x00\x00\x00\x00\x00" - assert res == res_expected - res_loaded = MsgpackSerialization.loads(res) - assert res_loaded == msg - - -@pytest.mark.parametrize("version", [1.2, None]) -def test_bec_message_serialization_numpy_ndarray(version): - msg = messages.DeviceMessage( - signals={"samx": {"value": np.random.rand(20).astype(np.float32)}}, metadata={"RID": "1234"} +def test_bec_message_serialization_numpy_ndarray(): + msg = messages.DeviceMessage.model_validate( + { + "signals": {"samx": {"value": np.random.rand(20).astype(np.float32)}}, + "metadata": {"RID": "1234"}, + } ) res = MsgpackSerialization.dumps(msg) print(res) - res_loaded = MsgpackSerialization.loads(res) - np.testing.assert_equal(res_loaded.content, msg.content) + res_loaded: messages.DeviceMessage = MsgpackSerialization.loads(res) + np.testing.assert_equal(res_loaded.signals["samx"].value, msg.signals["samx"].value) assert res_loaded == msg @@ -40,9 +25,7 @@ def test_device_message_with_async_update(): signals={"samx": {"value": 5.2}}, metadata={ "RID": "1234", - "async_update": messages.DeviceAsyncUpdate( - type="add", max_shape=[None, 1024, 1024] - ).model_dump(), + "async_update": messages.DeviceAsyncUpdate(type="add", max_shape=[None, 1024, 1024]), }, ) res = MsgpackSerialization.dumps(msg) @@ -429,15 +412,13 @@ def test_DeviceInstructionMessage(): def test_DeviceMonitor2DMessage(): # Test 2D data - msg = messages.DeviceMonitor2DMessage( - device="eiger", data=np.random.rand(2, 100), metadata=None - ) + msg = messages.DeviceMonitor2DMessage(device="eiger", data=np.random.rand(2, 100)) res = MsgpackSerialization.dumps(msg) res_loaded = MsgpackSerialization.loads(res) assert res_loaded == msg assert res_loaded.metadata == {} # Test rgb image, i.e. image with 3 channels - msg = messages.DeviceMonitor2DMessage(device="eiger", data=np.random.rand(3, 3), metadata=None) + msg = messages.DeviceMonitor2DMessage(device="eiger", data=np.random.rand(3, 3)) res = MsgpackSerialization.dumps(msg) res_loaded = MsgpackSerialization.loads(res) assert res_loaded == msg @@ -454,7 +435,7 @@ def test_DeviceMonitor2DMessage(): def test_DeviceMonitor1DMessage(): # Test 2D data - msg = messages.DeviceMonitor1DMessage(device="eiger", data=np.random.rand(100), metadata=None) + msg = messages.DeviceMonitor1DMessage(device="eiger", data=np.random.rand(100)) res = MsgpackSerialization.dumps(msg) res_loaded = MsgpackSerialization.loads(res) assert res_loaded == msg diff --git a/bec_lib/tests/test_devices.py b/bec_lib/tests/test_devices.py index 02b449cd2..117e0ca5c 100644 --- a/bec_lib/tests/test_devices.py +++ b/bec_lib/tests/test_devices.py @@ -21,6 +21,7 @@ ) from bec_lib.devicemanager import DeviceContainer, DeviceManagerBase from bec_lib.endpoints import MessageEndpoints +from bec_lib.messages import SignalReading from bec_lib.tests.fixtures import device_manager_class from bec_lib.tests.utils import ClientMock, ConnectorMock, get_device_info_mock @@ -53,15 +54,21 @@ def test_read(dev: Any): res = dev.samx.read(cached=True) mock_get.assert_called_once_with(MessageEndpoints.device_readback("samx")) assert res == { - "samx": {"value": 0, "timestamp": 1701105880.1711318}, - "samx_setpoint": {"value": 0, "timestamp": 1701105880.1693492}, - "samx_motor_is_moving": {"value": 0, "timestamp": 1701105880.16935}, + "samx": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.1711318} + ), + "samx_setpoint": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.1693492} + ), + "samx_motor_is_moving": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.16935} + ), } def test_read_filtered_hints(dev: Any): with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get: - mock_get.return_value = messages.DeviceMessage( + msg = messages.DeviceMessage( signals={ "samx": {"value": 0, "timestamp": 1701105880.1711318}, "samx_setpoint": {"value": 0, "timestamp": 1701105880.1693492}, @@ -69,9 +76,10 @@ def test_read_filtered_hints(dev: Any): }, metadata={"scan_id": "scan_id", "scan_type": "scan_type"}, ) + mock_get.return_value = msg res = dev.samx.read(cached=True, filter_to_hints=True) mock_get.assert_called_once_with(MessageEndpoints.device_readback("samx")) - assert res == {"samx": {"value": 0, "timestamp": 1701105880.1711318}} + assert res == {"samx": msg.signals.get("samx")} def test_read_use_read(dev: Any): @@ -86,7 +94,7 @@ def test_read_use_read(dev: Any): ) res = dev.samx.read(cached=True, use_readback=False) mock_get.assert_called_once_with(MessageEndpoints.device_read("samx")) - assert res == data + assert res == {s: SignalReading.model_validate(sr) for s, sr in data.items()} def test_read_nested_device(dev: Any): @@ -103,7 +111,7 @@ def test_read_nested_device(dev: Any): ) res = dev.dyn_signals.messages.read(cached=True) mock_get.assert_called_once_with(MessageEndpoints.device_readback("dyn_signals")) - assert res == data + assert res == {s: SignalReading.model_validate(sr) for s, sr in data.items()} @pytest.mark.parametrize( @@ -131,7 +139,9 @@ def test_read_kind_hinted( if cached: mock_get.assert_called_once_with(MessageEndpoints.device_readback("samx")) mock_run.assert_not_called() - assert res == {"samx": {"value": 0, "timestamp": 1701105880.1711318}} + assert res == { + "samx": SignalReading.model_validate({"value": 0, "timestamp": 1701105880.1711318}) + } else: mock_run.assert_called_once_with(cached=False, fcn=dev.samx.readback.read) mock_get.assert_not_called() diff --git a/bec_lib/tests/test_generate_mesage_schema.py b/bec_lib/tests/test_generate_mesage_schema.py new file mode 100644 index 000000000..cb0ca65dd --- /dev/null +++ b/bec_lib/tests/test_generate_mesage_schema.py @@ -0,0 +1,15 @@ +from typing import Annotated + +import numpy as np +from pydantic import ConfigDict, WithJsonSchema + +from bec_lib.messages import BECMessage + + +class NumpyMessage(BECMessage): + model_config = ConfigDict(arbitrary_types_allowed=True) + important_value: Annotated[np.ndarray, WithJsonSchema({"type": "string"})] + + +def test_replace_numpy(): + schema = NumpyMessage.model_json_schema() diff --git a/bec_lib/tests/test_metadata_schema.py b/bec_lib/tests/test_metadata_schema.py index 5b10e71f4..abd0228b6 100644 --- a/bec_lib/tests/test_metadata_schema.py +++ b/bec_lib/tests/test_metadata_schema.py @@ -45,78 +45,80 @@ def test_required_fields_validate(): test_metadata.number_field = "string" -def test_creating_scan_queue_message_validates_metadata(): - with patch.dict(metadata_schema._METADATA_SCHEMA_REGISTRY, TEST_REGISTRY, clear=True): - with pytest.raises(ValidationError): - ScanQueueMessage(scan_type="fake_scan_with_extra_metadata") - with pytest.raises(ValidationError): - ScanQueueMessage( - scan_type="fake_scan_with_extra_metadata", - parameter={}, - metadata={"user_metadata": {"number_field": "string"}}, - ) - ScanQueueMessage( - scan_type="fake_scan_with_extra_metadata", - parameter={}, - metadata={"user_metadata": {"number_field": 123}}, - ) - msg_with_extra_keys = ScanQueueMessage( - scan_type="fake_scan_with_extra_metadata", - parameter={}, - metadata={"user_metadata": {"number_field": 123, "extra": "data"}}, - ) - assert msg_with_extra_keys.metadata["user_metadata"]["extra"] == "data" - - -def test_default_schema_is_used_as_fallback(): - with patch.dict(metadata_schema._METADATA_SCHEMA_REGISTRY, TEST_REGISTRY, clear=True): - metadata_schema.get_metadata_schema_for_scan("") # create cache before patching default - with patch.object(metadata_schema, "_DEFAULT_SCHEMA", BeamlineDefaultSchema): - - assert metadata_schema.get_default_schema() is BeamlineDefaultSchema - assert ( - metadata_schema.get_metadata_schema_for_scan("not associated with anything") - is BeamlineDefaultSchema - ) - - with pytest.raises(ValidationError): - _msg_not_matching_default_and_no_specified_schema = ScanQueueMessage( - scan_type="not associated with anything", - parameter={}, - metadata={"user_metadata": {"number_field": 123}}, - ) - with pytest.raises(ValidationError): - _msg_matching_default_but_with_specified_schema = ScanQueueMessage( - scan_type="fake_scan_with_extra_metadata", - parameter={}, - metadata={"user_metadata": {"sample_name_long": "long string of text"}}, - ) - _msg_matching_default_and_no_specified_schema = ScanQueueMessage( - scan_type="not associated with anything", - parameter={}, - metadata={"user_metadata": {"sample_name_long": "long string of text"}}, - ) - - -def test_prepare_scan_request_produces_conforming_message(): - with patch.dict(metadata_schema._METADATA_SCHEMA_REGISTRY, TEST_REGISTRY, clear=True): - with pytest.raises(ValidationError): - Scans.prepare_scan_request( - scan_name="fake_scan_with_extra_metadata", - scan_info={"required_kwargs": []}, - system_config={}, - ) - with pytest.raises(ValidationError): - Scans.prepare_scan_request( - scan_name="fake_scan_with_extra_metadata", - scan_info={"required_kwargs": []}, - system_config={}, - user_metadata={"number_field": "string"}, - ) - msg = Scans.prepare_scan_request( - scan_name="fake_scan_with_extra_metadata", - scan_info={"required_kwargs": []}, - system_config={}, - user_metadata={"number_field": 123}, - ) - assert msg.metadata["user_metadata"] == {"number_field": 123} +# TODO: reenable when message classes moved out and validator can be readded + +# def test_creating_scan_queue_message_validates_metadata(): +# with patch.dict(metadata_schema._METADATA_SCHEMA_REGISTRY, TEST_REGISTRY, clear=True): +# with pytest.raises(ValidationError): +# ScanQueueMessage(scan_type="fake_scan_with_extra_metadata") +# with pytest.raises(ValidationError): +# ScanQueueMessage( +# scan_type="fake_scan_with_extra_metadata", +# parameter={}, +# metadata={"user_metadata": {"number_field": "string"}}, +# ) +# ScanQueueMessage( +# scan_type="fake_scan_with_extra_metadata", +# parameter={}, +# metadata={"user_metadata": {"number_field": 123}}, +# ) +# msg_with_extra_keys = ScanQueueMessage( +# scan_type="fake_scan_with_extra_metadata", +# parameter={}, +# metadata={"user_metadata": {"number_field": 123, "extra": "data"}}, +# ) +# assert msg_with_extra_keys.metadata["user_metadata"]["extra"] == "data" + + +# def test_default_schema_is_used_as_fallback(): +# with patch.dict(metadata_schema._METADATA_SCHEMA_REGISTRY, TEST_REGISTRY, clear=True): +# metadata_schema.get_metadata_schema_for_scan("") # create cache before patching default +# with patch.object(metadata_schema, "_DEFAULT_SCHEMA", BeamlineDefaultSchema): + +# assert metadata_schema.get_default_schema() is BeamlineDefaultSchema +# assert ( +# metadata_schema.get_metadata_schema_for_scan("not associated with anything") +# is BeamlineDefaultSchema +# ) + +# with pytest.raises(ValidationError): +# _msg_not_matching_default_and_no_specified_schema = ScanQueueMessage( +# scan_type="not associated with anything", +# parameter={}, +# metadata={"user_metadata": {"number_field": 123}}, +# ) +# with pytest.raises(ValidationError): +# _msg_matching_default_but_with_specified_schema = ScanQueueMessage( +# scan_type="fake_scan_with_extra_metadata", +# parameter={}, +# metadata={"user_metadata": {"sample_name_long": "long string of text"}}, +# ) +# _msg_matching_default_and_no_specified_schema = ScanQueueMessage( +# scan_type="not associated with anything", +# parameter={}, +# metadata={"user_metadata": {"sample_name_long": "long string of text"}}, +# ) + + +# def test_prepare_scan_request_produces_conforming_message(): +# with patch.dict(metadata_schema._METADATA_SCHEMA_REGISTRY, TEST_REGISTRY, clear=True): +# with pytest.raises(ValidationError): +# Scans.prepare_scan_request( +# scan_name="fake_scan_with_extra_metadata", +# scan_info={"required_kwargs": []}, +# system_config={}, +# ) +# with pytest.raises(ValidationError): +# Scans.prepare_scan_request( +# scan_name="fake_scan_with_extra_metadata", +# scan_info={"required_kwargs": []}, +# system_config={}, +# user_metadata={"number_field": "string"}, +# ) +# msg = Scans.prepare_scan_request( +# scan_name="fake_scan_with_extra_metadata", +# scan_info={"required_kwargs": []}, +# system_config={}, +# user_metadata={"number_field": 123}, +# ) +# assert msg.metadata["user_metadata"] == {"number_field": 123} diff --git a/bec_lib/tests/test_redis_connector.py b/bec_lib/tests/test_redis_connector.py index 06d0fd039..1b6cce653 100644 --- a/bec_lib/tests/test_redis_connector.py +++ b/bec_lib/tests/test_redis_connector.py @@ -23,6 +23,7 @@ validate_endpoint, ) from bec_lib.serialization import MsgpackSerialization +from bec_lib.tests.utils import _endpoint_info # pylint: disable=protected-access # pylint: disable=missing-function-docstring @@ -305,7 +306,7 @@ def test_redis_connector_delete(connector, topic, use_pipe): @pytest.mark.parametrize("topic, use_pipe", [["topic1", True], ["topic2", False]]) def test_redis_connector_get(connector, topic, use_pipe): pipe = use_pipe_fcn(connector, use_pipe) - + connector._redis_conn.get.return_value = None ret = connector.get(topic, pipe) if pipe: connector.pipeline().get.assert_called_once_with(topic) @@ -393,7 +394,8 @@ def test_mget(connector): def test_validate_with_present_arg(): - endpoint = EndpointInfo("test", Any, ["method"]) # type: ignore + endpoint = _endpoint_info("test", Any, MessageOp.LIST) + endpoint.message_op = ["method"] # type: ignore @validate_endpoint("arg1") def method(_, arg1): @@ -412,7 +414,7 @@ def method(_, arg1): ... def test_validate_rejects_wrong_op(): - endpoint = EndpointInfo("test", Any, ["missing_ops"]) # type: ignore + endpoint = _endpoint_info("test", Any, MessageOp.LIST) # type: ignore @validate_endpoint("arg1") def not_in_list(_, arg1): ... @@ -441,7 +443,7 @@ def test_set_connector( connected_connector, ) -> Generator[tuple[RedisConnector, EndpointInfo, set[ProcedureExecutionMessage]], None, None]: - test_set_endpoint = EndpointInfo( + test_set_endpoint = _endpoint_info( f"{EndpointType.INFO}/procedures/active_procedures", ProcedureExecutionMessage, MessageOp.SET, @@ -481,7 +483,7 @@ def test_list_pop_to_sadd_adds_to_set( test_set_connector: tuple[RedisConnector, EndpointInfo, set[ProcedureExecutionMessage]], ): connected_connector, test_set_endpoint, test_set_messages = test_set_connector - test_list_endpoint = EndpointInfo( + test_list_endpoint = _endpoint_info( f"{EndpointType.INTERNAL}/procedures/procedure_execution/queue5", ProcedureExecutionMessage, MessageOp.LIST, @@ -515,7 +517,7 @@ def test_list_pop_to_sadd_rejects_wrong_message_for_set( test_set_connector: tuple[RedisConnector, EndpointInfo, set[ProcedureExecutionMessage]], ): connected_connector, test_set_endpoint, _ = test_set_connector - test_list_endpoint = EndpointInfo( + test_list_endpoint = _endpoint_info( f"{EndpointType.INTERNAL}/procedures/procedure_execution/queue5", ProcedureExecutionMessage, MessageOp.LIST, diff --git a/bec_lib/tests/test_redis_connector_fakeredis.py b/bec_lib/tests/test_redis_connector_fakeredis.py index 4baec31af..620fced44 100644 --- a/bec_lib/tests/test_redis_connector_fakeredis.py +++ b/bec_lib/tests/test_redis_connector_fakeredis.py @@ -10,6 +10,7 @@ from bec_lib.endpoints import EndpointInfo, MessageEndpoints, MessageOp from bec_lib.redis_connector import MessageObject, RedisConnector from bec_lib.serialization import MsgpackSerialization +from bec_lib.tests.utils import _endpoint_info from .test_redis_connector import TestMessage @@ -20,8 +21,8 @@ # pylint: disable=unused-argument -TestStreamEndpoint = EndpointInfo("test", TestMessage, MessageOp.STREAM) -TestStreamEndpoint2 = EndpointInfo("test2", TestMessage, MessageOp.STREAM) +TestStreamEndpoint = _endpoint_info("test", TestMessage, MessageOp.STREAM) +TestStreamEndpoint2 = _endpoint_info("test2", TestMessage, MessageOp.STREAM) @pytest.mark.parametrize( @@ -145,8 +146,8 @@ def test_redis_connector_register_identical(connected_connector): def test_redis_connector_unregister_cb_not_topic(connected_connector): connector = connected_connector - topic1 = EndpointInfo("topic1", TestMessage, MessageOp.SEND) - topic2 = EndpointInfo("topic2", TestMessage, MessageOp.SEND) + topic1 = _endpoint_info("topic1", TestMessage, MessageOp.SEND) + topic2 = _endpoint_info("topic2", TestMessage, MessageOp.SEND) received_event1 = mock.Mock(spec=[]) received_event2 = mock.Mock(spec=[]) @@ -188,8 +189,8 @@ def send_msgs_and_poll(timeout=None): connector = connected_connector - topic1 = EndpointInfo("topic1", TestMessage, MessageOp.SEND) - topic2 = EndpointInfo("topic2", TestMessage, MessageOp.SEND) + topic1 = _endpoint_info("topic1", TestMessage, MessageOp.SEND) + topic2 = _endpoint_info("topic2", TestMessage, MessageOp.SEND) received_event1 = mock.Mock(spec=[]) received_event2 = mock.Mock(spec=[]) diff --git a/bec_lib/tests/test_serializer.py b/bec_lib/tests/test_serializer.py index 93f5990de..9da85a7c3 100644 --- a/bec_lib/tests/test_serializer.py +++ b/bec_lib/tests/test_serializer.py @@ -6,14 +6,13 @@ from pydantic import BaseModel from bec_lib import messages -from bec_lib.codecs import BECCodec from bec_lib.device import DeviceBase from bec_lib.devicemanager import DeviceManagerBase from bec_lib.endpoints import MessageEndpoints -from bec_lib.serialization import MsgpackSerialization, json_ext, msgpack +from bec_lib.serialization import MsgpackSerialization -@pytest.fixture(params=[json_ext, msgpack, MsgpackSerialization]) +@pytest.fixture(params=[MsgpackSerialization]) def serializer(request): yield request.param @@ -31,8 +30,6 @@ class CustomEnum(enum.Enum): 1, 1.0, [1, 2, 3], - np.array([1, 2, 3]), - {1, 2, 3}, { "hroz": { "hroz": {"value": 0, "timestamp": 1708336264.5731058}, @@ -40,11 +37,7 @@ class CustomEnum(enum.Enum): } }, MessageEndpoints.progress("test"), - messages.DeviceMessage, - float, messages.RawMessage(data={"a": 1, "b": 2}), - messages.BECStatus.RUNNING, - np.uint32, messages.DeviceMessage( signals={ "hroz": { @@ -66,90 +59,14 @@ class CustomEnum(enum.Enum): ], ) def test_serialize(serializer, data): - res = serializer.loads(serializer.dumps(data)) == data + ser = serializer.dumps(data) + deser = serializer.loads(ser) + res = deser == data assert all(res) if isinstance(data, np.ndarray) else res -def test_serialize_model(serializer): - - class DummyModel(BaseModel): - a: int - b: int - - data = DummyModel(a=1, b=2) - converted_data = serializer.loads(serializer.dumps(data)) - assert data.model_dump() == converted_data - - -def test_device_serializer(serializer): - device_manager = mock.MagicMock(spec=DeviceManagerBase) - dummy = DeviceBase(name="dummy", parent=device_manager) - assert serializer.loads(serializer.dumps(dummy)) == "dummy" - - -def test_enum_serializer(serializer): - assert serializer.loads(serializer.dumps(CustomEnum.VALUE1)) == "value1" - - -def test_serializer_encoding_on_failure(): - """ - Test that an exception raised during serialization is caught and the original object is returned. - """ - - class DummyModel: - def __init__(self, a, b): - self.a = a - self.b = b - - def __eq__(self, other): - return isinstance(other, DummyModel) and self.a == other.a and self.b == other.b - - class RaiseEncoder(BECCodec): - obj_type = DummyModel - - @staticmethod - def encode(obj): - raise ValueError("Serialization failed") - - @staticmethod - def decode(type_name: str, data: dict): - raise ValueError("Deserialization failed") - - try: - msgpack.register_codec(RaiseEncoder) - data = DummyModel(a=1, b=2) - with pytest.raises(ValueError, match="Serialization failed"): - serialized_data = msgpack.dumps(data) - - serialized_data = msgpack.dumps( - {"__bec_codec__": {"encoder_name": "DummyModel", "type_name": "DummyModel", "data": {}}} - ) - with pytest.raises(ValueError, match="Deserialization failed"): - msgpack.loads(serialized_data) - finally: - # Unregister the codec to avoid side effects on other tests - msgpack._registry.pop("DummyModel") - - -def test_serializer_registry_cache_resets(): - """ - Test that adding a new codec resets the cache. - """ - - class DummyType: - pass - - class DummyCodec(BECCodec): - obj_type = DummyType - - @staticmethod - def encode(obj): - return {"dummy": "data"} - - @staticmethod - def decode(type_name: str, data: dict): - return DummyType() - - assert not msgpack.is_registered(DummyType) - msgpack.register_codec(DummyCodec) - assert msgpack.is_registered(DummyType) +# TODO: figure out what to do with this - works inside messages +# def test_device_serializer(serializer): +# device_manager = mock.MagicMock(spec=DeviceManagerBase) +# dummy = DeviceBase(name="dummy", parent=device_manager) +# assert serializer.loads(serializer.dumps(dummy)) == "dummy" diff --git a/bec_server/bec_server/file_writer/async_writer.py b/bec_server/bec_server/file_writer/async_writer.py index 2a506b1a3..10527a120 100644 --- a/bec_server/bec_server/file_writer/async_writer.py +++ b/bec_server/bec_server/file_writer/async_writer.py @@ -285,26 +285,9 @@ def write_data( else: signal_group = device_group[signal_name] - for key, value in signal_data.items(): - - if key == "value": - self.write_value_data(signal_group, value, async_update) - elif key == "timestamp": - self.write_timestamp_data(signal_group, value) - else: # pragma: no cover - # this should never happen as the keys are fixed in the pydantic model - msg = f"Unknown key: {key}. Data will not be written." - error_info = messages.ErrorInfo( - error_message=msg, - compact_error_message=msg, - exception_type="ValueError", - device=device_name, - ) - self.connector.raise_alarm( - severity=Alarms.WARNING, - info=error_info, - metadata={"scan_id": self.scan_id, "scan_number": self.scan_number}, - ) + self.write_value_data(signal_group, signal_data.value, async_update) + if signal_data.timestamp is not None: + self.write_timestamp_data(signal_group, signal_data.timestamp) if write_replace: for group_name, value in self.device_data_replace.items(): diff --git a/bec_server/bec_server/file_writer/file_writer.py b/bec_server/bec_server/file_writer/file_writer.py index b589d5418..f319b25f5 100644 --- a/bec_server/bec_server/file_writer/file_writer.py +++ b/bec_server/bec_server/file_writer/file_writer.py @@ -10,6 +10,7 @@ import h5py from bec_lib import messages, plugin_helper +from bec_lib.bec_serializable import BecWrappedValue from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger @@ -40,6 +41,16 @@ def __init__(self, storage_type: str = "group", data=None) -> None: self.attrs = {} self._data = data + @property + def data(self): + if isinstance(self._data, messages.SignalReading): + return self._data.to_dict() + if isinstance(self._data, list) and set(map(type, self._data)) == {messages.SignalReading}: + return list(map(lambda x: x.to_dict(), self._data)) + if isinstance(self._data, BecWrappedValue): + return self._data.data + return self._data + def create_group(self, name: str) -> HDF5Storage: """ Create a group in the HDF5 storage. @@ -116,7 +127,7 @@ def add_group(self, name: str, container: typing.Any, val: HDF5Storage): self.add_attribute(group, val.attrs) self.add_content(group, val._storage) - data = val._data + data = val.data if not data: return @@ -140,14 +151,13 @@ def add_group(self, name: str, container: typing.Any, val: HDF5Storage): group.create_dataset(name=key, data=value) def add_dataset(self, name: str, container: typing.Any, val: HDF5Storage): + data = val.data try: - if isinstance(val._data, dict): - self.add_group(name, container, val) - return - - data = val._data if data is None: return + if isinstance(data, dict): + self.add_group(name, container, val) + return if isinstance(data, list): if data and isinstance(data[0], dict): data = json.dumps(data) @@ -164,7 +174,7 @@ def add_dataset(self, name: str, container: typing.Any, val: HDF5Storage): container.attrs["signal"] = "value" except Exception: content = traceback.format_exc() - logger.error(f"Failed to write dataset {name}: {content}") + logger.error(f"Failed to write dataset {name} with {data=}: {content}") return def add_attribute(self, container: typing.Any, attributes: dict): diff --git a/bec_server/bec_server/file_writer/file_writer_manager.py b/bec_server/bec_server/file_writer/file_writer_manager.py index bb6e4fada..beb93c32c 100644 --- a/bec_server/bec_server/file_writer/file_writer_manager.py +++ b/bec_server/bec_server/file_writer/file_writer_manager.py @@ -236,9 +236,10 @@ def insert_to_scan_storage(self, msg: messages.ScanMessage) -> None: self.scan_storage[scan_id] = ScanStorage( scan_number=msg.metadata.get("scan_number"), scan_id=scan_id ) - self.scan_storage[scan_id].append( - point_id=msg.content.get("point_id"), data=msg.content.get("data") - ) + point_data = msg.content.get("data") + if isinstance(point_data, messages.SignalReading): + point_data = point_data.to_dict() + self.scan_storage[scan_id].append(point_id=msg.content.get("point_id"), data=point_data) logger.debug(msg.content.get("point_id")) self.check_storage_status(scan_id=scan_id) diff --git a/bec_server/tests/tests_file_writer/test_async_file_writer.py b/bec_server/tests/tests_file_writer/test_async_file_writer.py index 31b354ffd..f2dbc8fe1 100644 --- a/bec_server/tests/tests_file_writer/test_async_file_writer.py +++ b/bec_server/tests/tests_file_writer/test_async_file_writer.py @@ -272,11 +272,9 @@ def test_async_writer_add_slice_fixed_size_data_consistency(async_writer): assert out.shape == (2, 20) assert np.allclose( out[0, :], - np.hstack( - (data[0].signals["monitor_async"]["value"], data[1].signals["monitor_async"]["value"]) - ), + np.hstack((data[0].signals["monitor_async"].value, data[1].signals["monitor_async"].value)), ) - assert np.allclose(out[1, :10], data[2].signals["monitor_async"]["value"]) + assert np.allclose(out[1, :10], data[2].signals["monitor_async"].value) assert np.allclose(out[1, 10:], np.zeros(10)) @@ -389,7 +387,7 @@ def test_async_writer_replace(async_writer, data): out = f[async_writer.BASE_PATH]["monitor_async"]["monitor_async"]["value"][:] assert out.shape == (10,) - assert np.allclose(out, data[-1].signals["monitor_async"]["value"]) + assert np.allclose(out, data[-1].signals["monitor_async"].value) def test_async_writer_async_signal(async_writer): diff --git a/bec_server/tests/tests_scan_server/test_scan_worker.py b/bec_server/tests/tests_scan_server/test_scan_worker.py index 7e99f6037..80f8e7bee 100644 --- a/bec_server/tests/tests_scan_server/test_scan_worker.py +++ b/bec_server/tests/tests_scan_server/test_scan_worker.py @@ -101,7 +101,7 @@ def test_publish_data_as_read(scan_worker_mock): def test_publish_data_as_read_multiple(scan_worker_mock): worker = scan_worker_mock - data = [{"samx": {}}, {"samy": {}}] + data = [{"samx": {"value": 0}}, {"samy": {"value": 0}}] devices = ["samx", "samy"] instr = messages.DeviceInstructionMessage( device=devices, diff --git a/bec_server/tests/tests_scihub/test_atlas_forwarder.py b/bec_server/tests/tests_scihub/test_atlas_forwarder.py index 22d4815ff..bb29d8bcb 100644 --- a/bec_server/tests/tests_scihub/test_atlas_forwarder.py +++ b/bec_server/tests/tests_scihub/test_atlas_forwarder.py @@ -1,9 +1,9 @@ +import json import time import pytest from bec_lib import messages -from bec_lib.serialization import json_ext as json @pytest.fixture