diff --git a/bec_ipython_client/tests/client_tests/test_ipython_live_updates.py b/bec_ipython_client/tests/client_tests/test_ipython_live_updates.py index f26808044..f74224009 100644 --- a/bec_ipython_client/tests/client_tests/test_ipython_live_updates.py +++ b/bec_ipython_client/tests/client_tests/test_ipython_live_updates.py @@ -23,7 +23,7 @@ def queue_elements(bec_client_mock): client = bec_client_mock request_msg = messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -52,7 +52,7 @@ def queue_elements(bec_client_mock): def sample_request_msg(): return messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -232,7 +232,7 @@ def test_available_req_blocks_multiple_blocks(bec_client_mock): request_msg = messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "test_rid"}, ) diff --git a/bec_ipython_client/tests/client_tests/test_live_table.py b/bec_ipython_client/tests/client_tests/test_live_table.py index 891803831..d66e51228 100644 --- a/bec_ipython_client/tests/client_tests/test_live_table.py +++ b/bec_ipython_client/tests/client_tests/test_live_table.py @@ -50,7 +50,7 @@ def client_with_grid_scan(bec_client_mock): client = bec_client_mock request_msg = messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -88,7 +88,7 @@ def test_sort_devices(self): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ), @@ -134,7 +134,7 @@ def test_wait_for_request_acceptance(self, client_with_grid_scan): def test_run_update(self, bec_client_mock, scan_item): request_msg = messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -161,7 +161,7 @@ def test_run_update(self, bec_client_mock, scan_item): def test_run_update_without_monitored_devices(self, bec_client_mock, scan_item): request_msg = messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -303,8 +303,10 @@ def test_print_table_data_hinted_value_with_precision( @pytest.mark.parametrize( "value,expected", [ - (np.int32(1), "1.00"), - (np.float64(1.00000), "1.00"), + # Commented out cases are not supported in unstructured serialized data, because msgpack doesn't distinguish + # lists, tuples, or sets. To support this, ScanMessage must be refactored to support the type information directly + # (np.int32(1), "1.00"), + # (np.float64(1.00000), "1.00"), (0, "0.00"), (1, "1.00"), (0.000, "0.00"), @@ -314,10 +316,10 @@ def test_print_table_data_hinted_value_with_precision( ("False", "False"), ("0", "0"), ("1", "1"), - ((0, 1), "(0, 1)"), + # ((0, 1), "(0, 1)"), ({"value": 0}, "{'value': 0}"), - (np.array([0, 1]), "[0 1]"), - ({1, 2}, "{1, 2}"), + # (np.array([0, 1]), "[0 1]"), + # ({1, 2}, "{1, 2}"), ], ) def test_print_table_data_variants(self, client_with_grid_scan, value, expected): diff --git a/bec_ipython_client/tests/end-2-end/test_scans_e2e.py b/bec_ipython_client/tests/end-2-end/test_scans_e2e.py index 238883c11..0961c2667 100644 --- a/bec_ipython_client/tests/end-2-end/test_scans_e2e.py +++ b/bec_ipython_client/tests/end-2-end/test_scans_e2e.py @@ -918,7 +918,7 @@ def test_scan_repeat_decorator(bec_ipython_client_fixture): "update_frequency": 400, }, "readoutPriority": "baseline", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, } diff --git a/bec_ipython_client/tests/end-2-end/test_scans_lib_e2e.py b/bec_ipython_client/tests/end-2-end/test_scans_lib_e2e.py index 47d28e590..581289f42 100644 --- a/bec_ipython_client/tests/end-2-end/test_scans_lib_e2e.py +++ b/bec_ipython_client/tests/end-2-end/test_scans_lib_e2e.py @@ -217,7 +217,7 @@ def test_dap_fit(bec_client_lib): "hexapod": { "deviceClass": "ophyd_devices.SynDeviceOPAAS", "deviceConfig": {}, - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "readoutPriority": "baseline", "enabled": True, "readOnly": False, @@ -230,7 +230,7 @@ def test_dap_fit(bec_client_lib): "tolerance": 0.01, "update_frequency": 400, }, - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, }, @@ -244,7 +244,7 @@ def test_dap_fit(bec_client_lib): "hexapod": { "deviceClass": "ophyd_devices.SynDeviceOPAAS", "deviceConfig": {}, - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "readoutPriority": "baseline", "enabled": True, "readOnly": False, @@ -258,7 +258,7 @@ def test_dap_fit(bec_client_lib): "update_frequency": 400, }, "readoutPriority": "baseline", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, }, @@ -272,7 +272,7 @@ def test_dap_fit(bec_client_lib): "hexapod": { "deviceClass": "ophyd_devices.SynDeviceOPAAS", "deviceConfig": {}, - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "readoutPriority": "baseline", "enabled": True, "readOnly": False, @@ -281,7 +281,7 @@ def test_dap_fit(bec_client_lib): "deviceClass": "ophyd_devices.utils.bec_utils.DeviceClassConnectionError", "deviceConfig": {}, "readoutPriority": "baseline", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, }, @@ -295,7 +295,7 @@ def test_dap_fit(bec_client_lib): "hexapod": { "deviceClass": "SynDeviceOPAAS", "deviceConfig": {}, - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "readoutPriority": "baseline", "enabled": True, "readOnly": False, @@ -304,7 +304,7 @@ def test_dap_fit(bec_client_lib): "deviceClass": "ophyd_devices.utils.bec_utils.DeviceClassInitError", "deviceConfig": {}, "readoutPriority": "baseline", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, }, @@ -318,7 +318,7 @@ def test_dap_fit(bec_client_lib): "hexapod": { "deviceClass": "SynDeviceOPAAS", "deviceConfig": {}, - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "readoutPriority": "baseline", "enabled": True, "readOnly": False, @@ -327,7 +327,7 @@ def test_dap_fit(bec_client_lib): "deviceClass": "ophyd_devices.WrongDeviceClass", "deviceConfig": {}, "readoutPriority": "baseline", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, }, @@ -383,7 +383,7 @@ def test_config_reload_with_describe_failure(bec_test_config_file_path, bec_clie "hexapod": { "deviceClass": "ophyd_devices.sim.sim_test_devices.SimPositionerWithDescribeFailure", "deviceConfig": {}, - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "readoutPriority": "baseline", "enabled": True, "readOnly": False, @@ -397,7 +397,7 @@ def test_config_reload_with_describe_failure(bec_test_config_file_path, bec_clie "update_frequency": 400, }, "readoutPriority": "baseline", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, }, @@ -445,7 +445,7 @@ def test_config_add_remove_device(bec_client_lib): "update_frequency": 400, }, "readoutPriority": "baseline", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, } diff --git a/bec_lib/bec_lib/bec_serializable.py b/bec_lib/bec_lib/bec_serializable.py new file mode 100644 index 000000000..a3180a61d --- /dev/null +++ b/bec_lib/bec_lib/bec_serializable.py @@ -0,0 +1,33 @@ +import numpy as np +from pydantic import BaseModel, ConfigDict, computed_field + + +class BecCodecInfo(BaseModel): + type_name: str + + +class BECSerializable(BaseModel): + """A base class for serializable BEC objects, especially BEC messages. + Fields in subclasses which use non-primitive types must be in structured, + type-hinted objects, and their encoders and JSON schema should be defined in + this class.""" + + model_config = ConfigDict( + json_schema_serialization_defaults_required=True, + arbitrary_types_allowed=True, + extra="forbid", + ) + + @computed_field() + @property + def bec_codec(self) -> BecCodecInfo: + return BecCodecInfo(type_name=self.__class__.__name__) + + def __eq__(self, other): + if type(other) is not type(self): + return False + try: + np.testing.assert_equal(self.model_dump(), other.model_dump()) + return True + except AssertionError: + return False diff --git a/bec_lib/bec_lib/device.py b/bec_lib/bec_lib/device.py index 3ccfbf40b..a34f058ed 100644 --- a/bec_lib/bec_lib/device.py +++ b/bec_lib/bec_lib/device.py @@ -308,7 +308,7 @@ def _prepare_rpc_msg( client: BECClient = self.root.parent.parent msg = messages.ScanQueueMessage( scan_type="device_rpc", - parameter=params, + parameter=messages.sanitize_one_way_encodable(params), queue=client.queue.get_default_scan_queue(), # type: ignore metadata={"RID": request_id, "response": True}, ) @@ -1115,8 +1115,8 @@ def limits(self): if not limit_msg: return [0, 0] limits = [ - limit_msg.content["signals"].get("low", {}).get("value", 0), - limit_msg.content["signals"].get("high", {}).get("value", 0), + limit_msg.signals.get("low", {}).get("value", 0), + limit_msg.signals.get("high", {}).get("value", 0), ] return limits diff --git a/bec_lib/bec_lib/devicemanager.py b/bec_lib/bec_lib/devicemanager.py index 34feba6e0..5a3502e78 100644 --- a/bec_lib/bec_lib/devicemanager.py +++ b/bec_lib/bec_lib/devicemanager.py @@ -667,9 +667,18 @@ def _get_redis_device_config(self) -> list: def _add_multiple_devices_with_log(self, devices: Iterable[tuple[dict, DeviceInfoMessage]]): logs = (self._add_device(*conf_msg) for conf_msg in devices if conf_msg is not None) - logger.info(f"Adding new devices:\n" + ", ".join(f"{name}: {t}" for name, t in logs)) # type: ignore # filtered + if set(logs) == {None}: + logger.warning("No devices added!") + return + logger.info( + f"Adding new devices:\n" + + ", ".join(f"{log[0]}: {log[1]}" for log in logs if log is not None) + ) def _add_device(self, dev: dict, msg: DeviceInfoMessage) -> tuple[str, str] | None: + if msg is None: + logger.error(f"No device info in Redis for: {dev}") + return None name = msg.content["device"] info = msg.content["info"] diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index 23bed4c58..584421615 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -5,16 +5,66 @@ import uuid import warnings from copy import deepcopy -from enum import Enum, auto +from enum import Enum, StrEnum, auto from importlib.metadata import PackageNotFoundError from importlib.metadata import version as importlib_version -from typing import Any, ClassVar, Literal, Self, Union +from types import NoneType +from typing import Annotated, Any, ClassVar, Literal, Mapping, Self, TypeVar, Union from uuid import uuid4 +import msgpack import numpy as np -from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator - +from pydantic import ( + BaseModel, + BeforeValidator, + ConfigDict, + FailFast, + Field, + Strict, + StrictBool, + StrictFloat, + StrictInt, + StrictStr, + ValidationError, + WithJsonSchema, + field_validator, + model_validator, +) +from typing_extensions import TypeAliasType + +from bec_lib.bec_serializable import BECSerializable from bec_lib.metadata_schema import get_metadata_schema_for_scan +from bec_lib.one_way_registry import OneWaySerializationRegistry + +_one_way_registry = OneWaySerializationRegistry() + + +def sanitize_one_way_encodable(data: Any) -> Any: + """Sanitize any data which can be serialized in a json-compatible format and is not supposed to be decoded, + for example, a parameter dict containing devices""" + if isinstance(data, (list, tuple, set)): + return [sanitize_one_way_encodable(x) for x in data] + if isinstance(data, Mapping): + return { + sanitize_one_way_encodable(k): sanitize_one_way_encodable(v) for k, v in data.items() + } + return _one_way_registry.encode(data) + + +JsonableScalar = TypeAliasType("JsonableScalar", StrictInt | StrictFloat | StrictStr | StrictBool) + +Jsonable = TypeAliasType( + "Jsonable", + JsonableScalar + | None + | Annotated[list["Jsonable"], Strict(), FailFast()] + | Annotated[dict[StrictStr, "Jsonable"], Strict()], +) + +JsonableDict = TypeAliasType( + "JsonableDict", + Annotated[dict[StrictStr, Jsonable], WithJsonSchema({"type": "object"}), Strict()], +) class ProcedureWorkerStatus(Enum): @@ -34,7 +84,7 @@ class BECStatus(Enum): ERROR = -1 -class BECMessage(BaseModel): +class BECMessage(BECSerializable): """Base Model class for BEC Messages Args: @@ -44,17 +94,14 @@ class BECMessage(BaseModel): """ msg_type: ClassVar[str] - metadata: dict = Field(default_factory=dict) + metadata: JsonableDict = Field(default_factory=dict) - @field_validator("metadata") + @model_validator(mode="before") @classmethod - def check_metadata(cls, v): - """Validate the metadata, return empty dict if None - - Args: - v (dict, None): Metadata dictionary - """ - return v or {} + def _strip_codec_info(cls, data: Any): + if isinstance(data, dict): + data.pop("bec_codec", None) + return data @property def content(self): @@ -63,18 +110,6 @@ def content(self): content.pop("metadata", None) return content - def __eq__(self, other): - if not isinstance(other, BECMessage): - # don't attempt to compare against unrelated types - return False - - try: - np.testing.assert_equal(self.model_dump(), other.model_dump()) - except AssertionError: - return False - - return self.msg_type == other.msg_type and self.metadata == other.metadata - def loads(self): warnings.warn( "BECMessage.loads() is deprecated and should not be used anymore. When calling Connector methods, it can be omitted. When a message needs to be deserialized call the appropriate function from bec_lib.serialization", @@ -93,6 +128,11 @@ def __hash__(self) -> int: return self.model_dump_json().__hash__() +# To correctly encode a message in another message, pydantic should know it is to be dumped +# as the concrete type it is, and not only the fields from BECMessage +SpecificMessageType = TypeVar("SpecificMessageType", bound=BECMessage) + + class BundleMessage(BECMessage): """Message type to send a bundle of BECMessages. @@ -108,7 +148,7 @@ class BundleMessage(BECMessage): """ msg_type: ClassVar[str] = "bundle_message" - messages: list = Field(default_factory=list[BECMessage]) + messages: Annotated[list[SpecificMessageType], Field(default_factory=list)] def append(self, msg: BECMessage): """Append a new BECMessage to the bundle""" @@ -141,7 +181,7 @@ class ScanQueueMessage(BECMessage): msg_type: ClassVar[str] = "scan_queue_message" scan_type: str - parameter: dict + parameter: JsonableDict queue: str = Field(default="primary") allow_restart: bool = Field( default=True, @@ -225,18 +265,18 @@ class ScanStatusMessage(BECMessage): scan_type: Literal["step", "fly"] | None = Field(default=None, description="Type of scan") dataset_number: int | None = None scan_report_devices: list[str] | None = None - user_metadata: dict | None = None + user_metadata: JsonableDict | None = None readout_priority: ( dict[Literal["monitored", "baseline", "async", "continuous", "on_request"], list[str]] | None ) = None scan_parameters: dict[ - Literal["exp_time", "frames_per_trigger", "settling_time", "readout_time"] | str, Any + Literal["exp_time", "frames_per_trigger", "settling_time", "readout_time"] | str, Jsonable ] = Field(default_factory=dict) - request_inputs: dict[Literal["arg_bundle", "inputs", "kwargs"], Any] = Field( + request_inputs: dict[Literal["arg_bundle", "inputs", "kwargs"], Jsonable] = Field( default_factory=dict ) - info: dict + info: JsonableDict timestamp: float = Field(default_factory=time.time) def __str__(self): @@ -302,7 +342,7 @@ class ScanQueueModificationMessage(BECMessage): "release_lock", "user_completed", ] - parameter: dict + parameter: JsonableDict queue: str = Field(default="primary") @@ -550,7 +590,7 @@ class DeviceInstructionMessage(BECMessage): "publish_data_as_read", "close_scan_group", ] - parameter: dict + parameter: JsonableDict class ErrorInfo(BaseModel): @@ -577,6 +617,46 @@ def _ensure_error_info_if_error(self): return self +# TODO: remove when deprecated usages of SignalReading are cleaned up +logger = None + + +def lazy_ensure_logger(): + global logger + if logger is None: + from bec_lib.logger import bec_logger + + logger = bec_logger.logger + + +class SignalReading(BECSerializable): + value: int | float | list[int] | list[float] | np.ndarray | None | str + timestamp: float | list[float] | None = None + + def keys(self): + lazy_ensure_logger() + logger.warning( + "Dictionary usage of SignalReading is deprecated; please replace it with a different access pattern." + ) + return ["value", "timestamp"] + + def get(self, item: Literal["value", "timestamp"], default=Any): + """Allow dictionary-style access for legacy reasons.""" + lazy_ensure_logger() + logger.warning( + "Get-access on SignalReading is deprecated; Just access the model.value field." + ) + if item not in ["value", "timestamp"]: + raise KeyError('SignalReading only has "value" and "timestamp" fields!') + return getattr(self, item) + + def __getitem__(self, item: str): + return self.get(item) + + def items(self): + return dict(self).items() + + class DeviceMessage(BECMessage): """Message type for sending device readings from the device server @@ -589,7 +669,7 @@ class DeviceMessage(BECMessage): """ msg_type: ClassVar[str] = "device_message" - signals: dict[str, dict[Literal["value", "timestamp"], Any]] + signals: dict[str, SignalReading] @field_validator("metadata") @classmethod @@ -747,7 +827,7 @@ class DeviceInfoMessage(BECMessage): msg_type: ClassVar[str] = "device_info_message" device: str - info: dict + info: JsonableDict class DeviceMonitor2DMessage(BECMessage): @@ -767,8 +847,6 @@ class DeviceMonitor2DMessage(BECMessage): data: np.ndarray timestamp: float = Field(default_factory=time.time) - metadata: dict | None = Field(default_factory=dict) - # Needed for pydantic to accept numpy arrays model_config = ConfigDict(arbitrary_types_allowed=True) @@ -808,8 +886,6 @@ class DeviceMonitor1DMessage(BECMessage): data: np.ndarray timestamp: float = Field(default_factory=time.time) - metadata: dict | None = Field(default_factory=dict) - # Needed for pydantic to accept numpy arrays model_config = ConfigDict(arbitrary_types_allowed=True) @@ -867,7 +943,7 @@ class DeviceUserROIMessage(BECMessage): device: str signal: str roi_type: str = Field(description="Type of the ROI, e.g. 'rectangle', 'circle', 'polygon'") - roi: dict = Field( + roi: JsonableDict = Field( description="Dictionary containing the ROI information, e.g. {'x': 100, 'y': 200, 'width': 50, 'height': 50}" ) timestamp: float = Field(default_factory=time.time) @@ -887,7 +963,7 @@ class ScanMessage(BECMessage): msg_type: ClassVar[str] = "scan_message" point_id: int scan_id: str - data: dict + data: JsonableDict class ScanHistoryMessage(BECMessage): @@ -921,7 +997,7 @@ class ScanHistoryMessage(BECMessage): end_time: float scan_name: str num_points: int - request_inputs: dict | None = None + request_inputs: JsonableDict | None = None stored_data_info: dict[str, dict[str, _StoredDataInfo]] | None = None @@ -949,7 +1025,7 @@ class ScanBaselineMessage(BECMessage): msg_type: ClassVar[str] = "scan_baseline_message" scan_id: str - data: dict + data: JsonableDict ConfigAction = Literal["add", "set", "update", "reload", "remove", "reset", "cancel"] @@ -967,7 +1043,7 @@ class DeviceConfigMessage(BECMessage): msg_type: ClassVar[str] = "device_config_message" action: ConfigAction | None = Field(default=None, validate_default=True) - config: dict | None = Field(default=None) + config: JsonableDict | None = Field(default=None) @model_validator(mode="after") @classmethod @@ -1013,7 +1089,7 @@ class LogMessage(BECMessage): log_type: Literal[ "trace", "debug", "info", "success", "warning", "error", "critical", "console_log" ] - log_msg: dict | str + log_msg: JsonableDict | str class AlarmMessage(BECMessage): @@ -1128,8 +1204,8 @@ class FileContentMessage(BECMessage): msg_type: ClassVar[str] = "file_content_message" file_path: str - data: dict - scan_info: dict + data: JsonableDict + scan_info: JsonableDict class VariableMessage(BECMessage): @@ -1170,7 +1246,7 @@ class ServiceMetricMessage(BECMessage): msg_type: ClassVar[str] = "service_metric_message" name: str - metrics: dict + metrics: JsonableDict class ProcessedDataMessage(BECMessage): @@ -1182,7 +1258,7 @@ class ProcessedDataMessage(BECMessage): """ msg_type: ClassVar[str] = "processed_data_message" - data: dict | list[dict] + data: JsonableDict | list[JsonableDict] class DAPConfigMessage(BECMessage): @@ -1194,7 +1270,7 @@ class DAPConfigMessage(BECMessage): """ msg_type: ClassVar[str] = "dap_config_message" - config: dict + config: JsonableDict class DAPRequestMessage(BECMessage): @@ -1210,7 +1286,7 @@ class DAPRequestMessage(BECMessage): msg_type: ClassVar[str] = "dap_request_message" dap_cls: str dap_type: Literal["continuous", "on_demand"] - config: dict + config: JsonableDict class DAPResponseMessage(BECMessage): @@ -1228,19 +1304,48 @@ class DAPResponseMessage(BECMessage): success: bool data: tuple | None = Field(default_factory=lambda: ({}, None)) error: str | None = None - dap_request: BECMessage | None = Field(default=None) + dap_request: SpecificMessageType | None = Field(default=None) + + +class ScanArgType(StrEnum): + DEVICE = "device" + FLOAT = "float" + INT = "int" + BOOL = "boolean" + STR = "str" + LIST = "list" + DICT = "dict" + + +class AvailableScan(BECMessage): + """Information about an available scan""" + + class_name: str + base_class: str + arg_input: dict[str, Jsonable | ScanArgType] + gui_config: JsonableDict + required_kwargs: list[str] | dict[str, ScanArgType] + arg_bundle_size: JsonableDict + doc: str | None = None + signature: list[JsonableDict] class AvailableResourceMessage(BECMessage): """Message for available resources such as scans, data processing plugins etc Args: - resource (dict, list[dict], BECMessage, list[BECMessage]): Resource description + resource (dict, list[dict], BECMessage, list[BECMessage]): Resource description - may contain only one type of BECMessage metadata (dict, optional): Metadata. Defaults to None. """ msg_type: ClassVar[str] = "available_resource_message" - resource: dict | list[dict] | BECMessage | list[BECMessage] + resource: ( + JsonableDict + | list[JsonableDict] + | SpecificMessageType + | list[SpecificMessageType] + | dict[str, SpecificMessageType] + ) class ProgressMessage(BECMessage): @@ -1268,7 +1373,7 @@ class GUIConfigMessage(BECMessage): """ msg_type: ClassVar[str] = "gui_config_message" - config: dict + config: JsonableDict class GUIDataMessage(BECMessage): @@ -1280,7 +1385,7 @@ class GUIDataMessage(BECMessage): """ msg_type: ClassVar[str] = "gui_data_message" - data: dict + data: JsonableDict class GUIInstructionMessage(BECMessage): @@ -1293,7 +1398,7 @@ class GUIInstructionMessage(BECMessage): msg_type: ClassVar[str] = "gui_instruction_message" action: str - parameter: dict + parameter: JsonableDict class GUIAutoUpdateConfigMessage(BECMessage): @@ -1329,7 +1434,7 @@ class GUIRegistryStateMessage(BECMessage): "__rpc__", "container_proxy", ], - str | bool | dict | None, + str | bool | JsonableDict | None, ], ] @@ -1343,7 +1448,7 @@ class ServiceResponseMessage(BECMessage): """ msg_type: ClassVar[str] = "service_response_message" - response: dict + response: JsonableDict class CredentialsMessage(BECMessage): @@ -1355,7 +1460,7 @@ class CredentialsMessage(BECMessage): """ msg_type: ClassVar[str] = "credentials_message" - credentials: dict + credentials: JsonableDict class RawMessage(BECMessage): @@ -1368,7 +1473,7 @@ class RawMessage(BECMessage): """ msg_type: ClassVar[str] = "raw_message" - data: Any + data: Jsonable model_config = ConfigDict(arbitrary_types_allowed=True) @@ -1640,7 +1745,6 @@ class EndpointInfoMessage(BECMessage): msg_type: ClassVar[str] = "endpoint_info_message" endpoint: str - metadata: dict | None = Field(default_factory=dict) class ScriptExecutionInfoMessage(BECMessage): @@ -1673,8 +1777,6 @@ class MacroUpdateMessage(BECMessage): macro_name: str | None = None file_path: str | None = None - metadata: dict | None = Field(default_factory=dict) - @model_validator(mode="after") @classmethod def check_macro(cls, values): @@ -1769,7 +1871,6 @@ class MessagingServiceMessage(BECMessage): service_name: Literal["signal", "teams", "scilog"] message: list[MessagingServiceContent] scope: str | list[str] | None = None - metadata: dict | None = Field(default_factory=dict) class MessagingServiceConfig(BECMessage): @@ -1788,4 +1889,3 @@ class MessagingServiceConfig(BECMessage): service_name: Literal["signal", "teams", "scilog"] scopes: list[str] enabled: bool - metadata: dict | None = Field(default_factory=dict) diff --git a/bec_lib/bec_lib/one_way_registry.py b/bec_lib/bec_lib/one_way_registry.py new file mode 100644 index 000000000..d1f92a90e --- /dev/null +++ b/bec_lib/bec_lib/one_way_registry.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import Any, Callable, Type + +from bec_lib.device import DeviceBase + + +class OneWayBECCodec(ABC): + """Abstract base class for custom encoders""" + + obj_type: Type | list[Type] + + @staticmethod + @abstractmethod + def encode(obj: Any) -> Any: + """Encode an object into a serializable format.""" + + +class BECDeviceEncoder(OneWayBECCodec): + obj_type = DeviceBase + + @staticmethod + def encode(obj: DeviceBase) -> str: + if hasattr(obj, "_compile_function_path"): + # pylint: disable=protected-access + return obj._compile_function_path() + return obj.name + + +class OneWaySerializationRegistry: + """Registry for serialization codecs""" + + def __init__(self): + self._registry: dict[str, tuple[Type, Callable]] = {} + + self.register_codec(BECDeviceEncoder) + + def register_codec(self, codec: Type[OneWayBECCodec]): + """ + Register a codec for a specific BECCodec subclass. + This method allows for easy registration of custom encoders and decoders + for BECMessage and other types. + + Args: + codec: A subclass of BECCodec that implements encode and decode methods. + Raises: + ValueError: If a codec for the specified type is already registered. + """ + if isinstance(codec.obj_type, list): + for cls in codec.obj_type: + self.register(cls, codec.encode) + else: + self.register(codec.obj_type, codec.encode) + + def register(self, cls: Type, encoder: Callable): + """Register a codec for a specific type.""" + + if cls.__name__ in self._registry: + raise ValueError(f"Codec for {cls} already registered.") + self._registry[cls.__name__] = (cls, encoder) + self.get_codec.cache_clear() # Clear the cache when a new codec is registered + + @lru_cache(maxsize=2000) + def get_codec(self, cls: Type) -> tuple[Type, Callable] | None: + """Get the codec for a specific type.""" + codec = self._registry.get(cls.__name__) + if codec: + return codec + for _, (registered_cls, encoder) in self._registry.items(): + if issubclass(cls, registered_cls): + return registered_cls, encoder + return None + + def is_registered(self, cls: Type) -> bool: + """ + Check if a codec is registered for a specific type. + Args: + cls: The class type to check for a registered codec. + Returns: + bool: True if a codec is registered for the type, False otherwise. + """ + return self.get_codec(cls) is not None + + def encode(self, obj): + """Encode an object using the registered codec.""" + codec = self.get_codec(type(obj)) + if not codec: + return obj # No codec registered for this type + _, encoder = codec + try: + return encoder(obj) + except Exception as e: + raise ValueError( + f"Serialization failed: Failed to encode {obj.__class__.__name__} with codec {encoder}: {e}" + ) from e diff --git a/bec_lib/bec_lib/scans.py b/bec_lib/bec_lib/scans.py index 63b6a1ba1..92ad77814 100644 --- a/bec_lib/bec_lib/scans.py +++ b/bec_lib/bec_lib/scans.py @@ -20,6 +20,7 @@ from bec_lib.device import DeviceBase from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger +from bec_lib.messages import ScanArgType # moved from here to messages - for compat with plugins from bec_lib.scan_repeat import _scan_repeat_depth from bec_lib.scan_report import ScanReport from bec_lib.signature_serializer import dict_to_signature @@ -291,14 +292,14 @@ def prepare_scan_request( return messages.ScanQueueMessage( scan_type=scan_name, - parameter=params, + parameter=messages.sanitize_one_way_encodable(params), queue=scan_queue, metadata=metadata, allow_restart=allow_restart, ) @staticmethod - def _parameter_bundler(args: tuple, bundle_size: int) -> tuple | dict: + def _parameter_bundler(args: tuple, bundle_size: int) -> list | dict: """ Bundle the arguments into the correct format for the scan server. If the bundle size is 0, return the arguments as is. @@ -309,11 +310,11 @@ def _parameter_bundler(args: tuple, bundle_size: int) -> tuple | dict: bundle_size: number of parameters per bundle Returns: - tuple | dict: bundled arguments + list | dict: bundled arguments """ if not bundle_size: - return args + return list(args) params = {} for cmds in partition(bundle_size, args): params[cmds[0]] = list(cmds[1:]) diff --git a/bec_lib/bec_lib/serialization.py b/bec_lib/bec_lib/serialization.py index f512c279e..fe3ffce1a 100644 --- a/bec_lib/bec_lib/serialization.py +++ b/bec_lib/bec_lib/serialization.py @@ -36,6 +36,8 @@ class BECMessagePack(SerializationRegistry): def dumps(self, obj): """Pack object `obj` and return packed bytes.""" + if isinstance(obj, BECMessage): + obj = obj.model_dump(mode="python", fallback=self.encode) return msgpack_module.packb(obj, default=self.encode) def loads(self, raw_bytes): diff --git a/bec_lib/bec_lib/serialization_registry.py b/bec_lib/bec_lib/serialization_registry.py index e9aa923c8..e2e6c8d81 100644 --- a/bec_lib/bec_lib/serialization_registry.py +++ b/bec_lib/bec_lib/serialization_registry.py @@ -4,6 +4,7 @@ from typing import Callable, Type from bec_lib import codecs as bec_codecs +from bec_lib import messages from bec_lib.logger import bec_logger logger = bec_logger.logger @@ -18,8 +19,6 @@ def __init__(self): self._registry: dict[str, tuple[Type, Callable, Callable]] = {} self._legacy_codecs = [] # can be removed in future versions, see issue #516 - self.register_codec(bec_codecs.BECMessageEncoder) - self.register_codec(bec_codecs.BECDeviceEncoder) self.register_codec(bec_codecs.EndpointInfoEncoder) self.register_codec(bec_codecs.SetEncoder) self.register_codec(bec_codecs.BECTypeEncoder) @@ -98,6 +97,11 @@ def encode(self, obj): def decode(self, data): """Decode an object using the registered codec.""" + if isinstance(data, dict) and "bec_codec" in data: + codec_info = data.pop("bec_codec") + msg_cls = messages.__dict__.get(codec_info.get("type_name")) + if msg_cls is not None: + return msg_cls.model_validate(data) if not isinstance(data, dict) or "__bec_codec__" not in data: return data codec_info = data["__bec_codec__"] diff --git a/bec_lib/bec_lib/signature_serializer.py b/bec_lib/bec_lib/signature_serializer.py index 94e83b01b..bf1a6dfad 100644 --- a/bec_lib/bec_lib/signature_serializer.py +++ b/bec_lib/bec_lib/signature_serializer.py @@ -44,7 +44,7 @@ def _merge_literals(vals: Generator[str | dict, None, None]) -> Generator[str | if _literal_args == [None]: yield "NoneType" elif _literal_args: - yield {"Literal": tuple(_literal_args)} + yield {"Literal": list(_literal_args)} def serialize_dtype(dtype: type) -> list[str | dict] | str | dict: diff --git a/bec_lib/tests/test_bec_messages.py b/bec_lib/tests/test_bec_messages.py index 8439b6332..48b3a8b7b 100644 --- a/bec_lib/tests/test_bec_messages.py +++ b/bec_lib/tests/test_bec_messages.py @@ -17,7 +17,7 @@ def test_bec_message_msgpack_serialization_version(version): assert "Unsupported BECMessage version" in str(exception.value) else: res = MsgpackSerialization.dumps(msg) - res_expected = b"\x81\xad__bec_codec__\x83\xacencoder_name\xaaBECMessage\xa9type_name\xb8DeviceInstructionMessage\xa4data\x84\xa8metadata\x81\xa3RID\xa41234\xa6device\xa4samx\xa6action\xa3set\xa9parameter\x81\xa3set\xcb?\xe0\x00\x00\x00\x00\x00\x00" + res_expected = b"\x85\xa8metadata\x81\xa3RID\xa41234\xa6device\xa4samx\xa6action\xa3set\xa9parameter\x81\xa3set\xcb?\xe0\x00\x00\x00\x00\x00\x00\xa9bec_codec\x81\xa9type_name\xb8DeviceInstructionMessage" assert res == res_expected res_loaded = MsgpackSerialization.loads(res) assert res_loaded == msg @@ -429,15 +429,13 @@ def test_DeviceInstructionMessage(): def test_DeviceMonitor2DMessage(): # Test 2D data - msg = messages.DeviceMonitor2DMessage( - device="eiger", data=np.random.rand(2, 100), metadata=None - ) + msg = messages.DeviceMonitor2DMessage(device="eiger", data=np.random.rand(2, 100)) res = MsgpackSerialization.dumps(msg) res_loaded = MsgpackSerialization.loads(res) assert res_loaded == msg assert res_loaded.metadata == {} # Test rgb image, i.e. image with 3 channels - msg = messages.DeviceMonitor2DMessage(device="eiger", data=np.random.rand(3, 3), metadata=None) + msg = messages.DeviceMonitor2DMessage(device="eiger", data=np.random.rand(3, 3)) res = MsgpackSerialization.dumps(msg) res_loaded = MsgpackSerialization.loads(res) assert res_loaded == msg @@ -454,7 +452,7 @@ def test_DeviceMonitor2DMessage(): def test_DeviceMonitor1DMessage(): # Test 2D data - msg = messages.DeviceMonitor1DMessage(device="eiger", data=np.random.rand(100), metadata=None) + msg = messages.DeviceMonitor1DMessage(device="eiger", data=np.random.rand(100)) res = MsgpackSerialization.dumps(msg) res_loaded = MsgpackSerialization.loads(res) assert res_loaded == msg @@ -678,3 +676,22 @@ def test_valid_add_slice_various_indices(self, index): """Test various valid index values for add_slice type""" update = messages.DeviceAsyncUpdate(type="add_slice", max_shape=[None, 1024], index=index) assert update.index == index + + +def test_message_with_np_array_in_dict(): + arr = np.zeros(5) + with pytest.raises(pydantic.ValidationError) as e: + msg = messages.BECMessage(metadata={"value": arr}) + assert e.match("metadata.value") + assert e.match("should be a valid") + + +def test_message_service_config(): + msg = messages.MessagingServiceConfig( + metadata={}, service_name="signal", scopes=["*"], enabled=True + ) + dump = msg.model_dump(mode="python") + assert dump["service_name"] == "signal" + resource_msg = messages.AvailableResourceMessage(resource=[msg]) + resource_msg_dump = resource_msg.model_dump(mode="python") + assert resource_msg_dump["resource"][0]["service_name"] == "signal" diff --git a/bec_lib/tests/test_config_helper.py b/bec_lib/tests/test_config_helper.py index 72a9431e9..fc0b10014 100644 --- a/bec_lib/tests/test_config_helper.py +++ b/bec_lib/tests/test_config_helper.py @@ -72,7 +72,7 @@ def test_config_helper_save_current_session(config_helper): "enabled": True, "readOnly": False, "deviceClass": "SimPositioner", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "deviceConfig": { "delay": 1, "labels": "pinz", @@ -93,7 +93,7 @@ def test_config_helper_save_current_session(config_helper): "enabled": True, "readOnly": False, "deviceClass": "SimMonitor", - "deviceTags": {"beamline"}, + "deviceTags": ["beamline"], "deviceConfig": {"labels": "transd", "name": "transd", "tolerance": 0.5}, "readoutPriority": "monitored", "onFailure": "retry", @@ -238,7 +238,7 @@ def test_update_base_path_recovery(config_helper_plain): { "pinz": { "deviceClass": "SimPositioner", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, "deviceConfig": { @@ -254,7 +254,7 @@ def test_update_base_path_recovery(config_helper_plain): }, "transd": { "deviceClass": "SimMonitor", - "deviceTags": {"beamline"}, + "deviceTags": ["beamline"], "enabled": True, "readOnly": False, "deviceConfig": {"labels": "transd", "name": "transd", "tolerance": 0.5}, @@ -265,7 +265,7 @@ def test_update_base_path_recovery(config_helper_plain): { "pinz": { "deviceClass": "SimPositioner", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": True, "deviceConfig": { @@ -281,7 +281,7 @@ def test_update_base_path_recovery(config_helper_plain): }, "transd": { "deviceClass": "SimMonitor", - "deviceTags": {"beamline"}, + "deviceTags": ["beamline"], "enabled": True, "readOnly": False, "deviceConfig": {"labels": "transd", "name": "transd", "tolerance": 0.5}, @@ -295,7 +295,7 @@ def test_update_base_path_recovery(config_helper_plain): { "pinz": { "deviceClass": "SimPositioner", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, "deviceConfig": { @@ -311,7 +311,7 @@ def test_update_base_path_recovery(config_helper_plain): }, "transd": { "deviceClass": "SimMonitor", - "deviceTags": {"beamline"}, + "deviceTags": ["beamline"], "enabled": True, "readOnly": False, "deviceConfig": {"labels": "transd", "name": "transd", "tolerance": 0.5}, @@ -322,7 +322,7 @@ def test_update_base_path_recovery(config_helper_plain): { "pinz": { "deviceClass": "SimPositioner", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, "deviceConfig": { @@ -338,7 +338,7 @@ def test_update_base_path_recovery(config_helper_plain): }, "transd": { "deviceClass": "SimMonitor", - "deviceTags": {"beamline"}, + "deviceTags": ["beamline"], "enabled": True, "readOnly": False, "deviceConfig": {"labels": "transd", "name": "transd", "tolerance": 0.5}, @@ -352,7 +352,7 @@ def test_update_base_path_recovery(config_helper_plain): { "pinz": { "deviceClass": "SimPositioner", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, "deviceConfig": { @@ -368,7 +368,7 @@ def test_update_base_path_recovery(config_helper_plain): }, "transd": { "deviceClass": "SimMonitor", - "deviceTags": {"beamline"}, + "deviceTags": ["beamline"], "enabled": True, "readOnly": False, "deviceConfig": {"labels": "transd", "name": "transd", "tolerance": 0.5}, @@ -379,7 +379,7 @@ def test_update_base_path_recovery(config_helper_plain): { "pinz": { "deviceClass": "SimPositioner", - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, "deviceConfig": { @@ -395,7 +395,7 @@ def test_update_base_path_recovery(config_helper_plain): }, "transd": { "deviceClass": "SimMonitor", - "deviceTags": {"beamline"}, + "deviceTags": ["beamline"], "enabled": True, "readOnly": False, "deviceConfig": {"labels": "transd", "name": "transd", "tolerance": 0.5}, @@ -419,7 +419,8 @@ def test_config_helper_get_config_conflicts( config.update(dev_cfg) config_in_redis.append(config) with mock.patch.object(config_helper._device_manager.connector, "get") as mock_get: - mock_get.return_value = messages.AvailableResourceMessage(resource=config_in_redis) + available_resource_message = messages.AvailableResourceMessage(resource=config_in_redis) + mock_get.return_value = available_resource_message conflicts = config_helper._get_config_conflicts(new_config) assert conflicts == expected_conflicts diff --git a/bec_lib/tests/test_devices.py b/bec_lib/tests/test_devices.py index 02b449cd2..bf620d402 100644 --- a/bec_lib/tests/test_devices.py +++ b/bec_lib/tests/test_devices.py @@ -53,9 +53,15 @@ def test_read(dev: Any): res = dev.samx.read(cached=True) mock_get.assert_called_once_with(MessageEndpoints.device_readback("samx")) assert res == { - "samx": {"value": 0, "timestamp": 1701105880.1711318}, - "samx_setpoint": {"value": 0, "timestamp": 1701105880.1693492}, - "samx_motor_is_moving": {"value": 0, "timestamp": 1701105880.16935}, + "samx": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.1711318} + ), + "samx_setpoint": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.1693492} + ), + "samx_motor_is_moving": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.16935} + ), } @@ -71,15 +77,25 @@ def test_read_filtered_hints(dev: Any): ) res = dev.samx.read(cached=True, filter_to_hints=True) mock_get.assert_called_once_with(MessageEndpoints.device_readback("samx")) - assert res == {"samx": {"value": 0, "timestamp": 1701105880.1711318}} + assert res == { + "samx": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.1711318} + ) + } def test_read_use_read(dev: Any): with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get: data = { - "samx": {"value": 0, "timestamp": 1701105880.1711318}, - "samx_setpoint": {"value": 0, "timestamp": 1701105880.1693492}, - "samx_motor_is_moving": {"value": 0, "timestamp": 1701105880.16935}, + "samx": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.1711318} + ), + "samx_setpoint": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.1693492} + ), + "samx_motor_is_moving": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.16935} + ), } mock_get.return_value = messages.DeviceMessage( signals=data, metadata={"scan_id": "scan_id", "scan_type": "scan_type"} @@ -92,11 +108,21 @@ def test_read_use_read(dev: Any): def test_read_nested_device(dev: Any): with mock.patch.object(dev.dyn_signals.root.parent.connector, "get") as mock_get: data = { - "dyn_signals_messages_message1": {"value": 0, "timestamp": 1701105880.0716832}, - "dyn_signals_messages_message2": {"value": 0, "timestamp": 1701105880.071722}, - "dyn_signals_messages_message3": {"value": 0, "timestamp": 1701105880.071739}, - "dyn_signals_messages_message4": {"value": 0, "timestamp": 1701105880.071753}, - "dyn_signals_messages_message5": {"value": 0, "timestamp": 1701105880.071766}, + "dyn_signals_messages_message1": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.0716832} + ), + "dyn_signals_messages_message2": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.071722} + ), + "dyn_signals_messages_message3": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.071739} + ), + "dyn_signals_messages_message4": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.071753} + ), + "dyn_signals_messages_message5": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.071766} + ), } mock_get.return_value = messages.DeviceMessage( signals=data, metadata={"scan_id": "scan_id", "scan_type": "scan_type"} @@ -131,7 +157,11 @@ def test_read_kind_hinted( if cached: mock_get.assert_called_once_with(MessageEndpoints.device_readback("samx")) mock_run.assert_not_called() - assert res == {"samx": {"value": 0, "timestamp": 1701105880.1711318}} + assert res == { + "samx": messages.SignalReading.model_validate( + {"value": 0, "timestamp": 1701105880.1711318} + ) + } else: mock_run.assert_called_once_with(cached=False, fcn=dev.samx.readback.read) mock_get.assert_not_called() @@ -199,10 +229,10 @@ def test_read_configuration_cached( @pytest.mark.parametrize( ["mock_rpc", "method", "args", "kwargs", "expected_call"], [ - ("_get_rpc_response", "set", (1,), {}, (mock.ANY, mock.ANY)), - ("_run_rpc_call", "set", (1,), {}, ("samx", "setpoint.set", 1)), - ("_run_rpc_call", "put", (1,), {"wait": True}, ("samx", "setpoint.set", 1)), - ("_run_rpc_call", "put", (1,), {}, ("samx", "setpoint.put", 1)), + ("_get_rpc_response", "set", [1], {}, (mock.ANY, mock.ANY)), + ("_run_rpc_call", "set", [1], {}, ("samx", "setpoint.set", 1)), + ("_run_rpc_call", "put", [1], {"wait": True}, ("samx", "setpoint.set", 1)), + ("_run_rpc_call", "put", [1], {}, ("samx", "setpoint.put", 1)), ], ) def test_run_rpc_call(dev: Any, mock_rpc, method, args, kwargs, expected_call): @@ -326,7 +356,7 @@ def device_config(): "readoutPriority": "monitored", "deviceClass": "SimCamera", "deviceConfig": {"device_access": True, "labels": "eiger", "name": "eiger"}, - "deviceTags": {"detector"}, + "deviceTags": ["detector"], } @@ -360,7 +390,12 @@ def device_obj(device_config: dict[str, Any]): def test_create_device_saves_config( device_obj: DeviceBaseWithConfig, device_config: dict[str, Any] ): - assert {k: v for k, v in device_obj._config.items() if k in device_config} == device_config + assert ( + messages.sanitize_one_way_encodable( + {k: v for k, v in device_obj._config.items() if k in device_config} + ) + == device_config + ) def test_device_enabled(device_obj: DeviceBaseWithConfig, device_config: dict[str, Any]): @@ -454,7 +489,7 @@ def test_status_wait(): @pytest.fixture def device_w_tags(dev_w_config: Callable[..., DeviceBaseWithConfig]): - yield dev_w_config({"deviceTags": {"tag1", "tag2"}}) + yield dev_w_config({"deviceTags": ["tag1", "tag2"]}) @pytest.mark.parametrize( @@ -492,7 +527,7 @@ def test_properties(dev_w_config: Callable[..., DeviceBaseWithConfig], config, a @pytest.mark.parametrize( ["config", "method", "value"], - [({"deviceTags": {"tag1", "tag2"}}, "get_device_tags", {"tag1", "tag2"})], + [({"deviceTags": ["tag1", "tag2"]}, "get_device_tags", {"tag1", "tag2"})], ) def test_methods(dev_w_config: Callable[..., DeviceBaseWithConfig], config, method, value): assert getattr(dev_w_config(config), method)() == value @@ -591,7 +626,7 @@ def test_show_all(): "readOnly": False, "deviceClass": "Class1", "readoutPriority": "monitored", - "deviceTags": {"tag1", "tag2"}, + "deviceTags": ["tag1", "tag2"], }, parent=parent, ) @@ -603,7 +638,7 @@ def test_show_all(): "readOnly": True, "deviceClass": "Class2", "readoutPriority": "baseline", - "deviceTags": {"tag3", "tag4"}, + "deviceTags": ["tag3", "tag4"], }, parent=parent, ) diff --git a/bec_lib/tests/test_file_utils.py b/bec_lib/tests/test_file_utils.py index b27b6274b..ab27c225d 100644 --- a/bec_lib/tests/test_file_utils.py +++ b/bec_lib/tests/test_file_utils.py @@ -40,7 +40,7 @@ def scan_msg(): yield ScanStatusMessage( scan_id="1111", scan_parameters={"system_config": {"file_directory": None, "file_suffix": None}}, - info={"scan_number": 5, "file_components": ("S00000-00999/S00005/S00005", "h5")}, + info={"scan_number": 5, "file_components": ["S00000-00999/S00005/S00005", "h5"]}, status="closed", ) @@ -202,7 +202,7 @@ def test_compile_file_components(): ScanStatusMessage( scan_id="1111", scan_parameters={"system_config": {"file_directory": None, "file_suffix": None}}, - info={"scan_number": 5, "file_components": ("S00000-00999/S00005/S00005", "h5")}, + info={"scan_number": 5, "file_components": ["S00000-00999/S00005/S00005", "h5"]}, status="closed", ) ), @@ -212,7 +212,7 @@ def test_compile_file_components(): scan_parameters={ "system_config": {"file_directory": "/my_dir", "file_suffix": None} }, - info={"scan_number": 5, "file_components": ("/my_dir/S00005", "h5")}, + info={"scan_number": 5, "file_components": ["/my_dir/S00005", "h5"]}, status="closed", ) ), @@ -224,7 +224,7 @@ def test_compile_file_components(): }, info={ "scan_number": 5, - "file_components": ("S00000-00999/S00005_sampleA/S00005", "h5"), + "file_components": ["S00000-00999/S00005_sampleA/S00005", "h5"], }, status="closed", ) diff --git a/bec_lib/tests/test_scan_context.py b/bec_lib/tests/test_scan_context.py index 7c04584cd..afcb2ed43 100644 --- a/bec_lib/tests/test_scan_context.py +++ b/bec_lib/tests/test_scan_context.py @@ -136,7 +136,7 @@ def test_parameter_bundler(bec_client_mock): assert res == {dev.samx: [-5, 5, 5]} res = client.scans._parameter_bundler((-5, 5, 5), 0) - assert res == (-5, 5, 5) + assert res == [-5, 5, 5] @pytest.mark.parametrize( diff --git a/bec_lib/tests/test_serializer.py b/bec_lib/tests/test_serializer.py index 93f5990de..e1a254e7b 100644 --- a/bec_lib/tests/test_serializer.py +++ b/bec_lib/tests/test_serializer.py @@ -10,6 +10,7 @@ from bec_lib.device import DeviceBase from bec_lib.devicemanager import DeviceManagerBase from bec_lib.endpoints import MessageEndpoints +from bec_lib.one_way_registry import OneWaySerializationRegistry from bec_lib.serialization import MsgpackSerialization, json_ext, msgpack @@ -81,10 +82,11 @@ class DummyModel(BaseModel): assert data.model_dump() == converted_data -def test_device_serializer(serializer): +def test_device_serializer(): + serializer = OneWaySerializationRegistry() device_manager = mock.MagicMock(spec=DeviceManagerBase) dummy = DeviceBase(name="dummy", parent=device_manager) - assert serializer.loads(serializer.dumps(dummy)) == "dummy" + assert serializer.encode(dummy) == "dummy" def test_enum_serializer(serializer): diff --git a/bec_lib/tests/test_signature_serializer.py b/bec_lib/tests/test_signature_serializer.py index f2b0a5da9..b0418b291 100644 --- a/bec_lib/tests/test_signature_serializer.py +++ b/bec_lib/tests/test_signature_serializer.py @@ -41,7 +41,7 @@ def test_func(a: Literal[1, 2, 3] | None = None): "name": "a", "kind": "POSITIONAL_OR_KEYWORD", "default": None, - "annotation": {"Literal": (1, 2, 3, None)}, + "annotation": {"Literal": [1, 2, 3, None]}, } ] @@ -57,7 +57,7 @@ def test_func(a, b: Literal["test", None], *args, **kwargs): "name": "b", "kind": "POSITIONAL_OR_KEYWORD", "default": "_empty", - "annotation": {"Literal": ("test", None)}, + "annotation": {"Literal": ["test", None]}, }, {"name": "args", "kind": "VAR_POSITIONAL", "default": "_empty", "annotation": "_empty"}, {"name": "kwargs", "kind": "VAR_KEYWORD", "default": "_empty", "annotation": "_empty"}, @@ -81,13 +81,13 @@ def test_func( "name": "b", "kind": "POSITIONAL_OR_KEYWORD", "default": "_empty", - "annotation": {"Literal": ("test", None)}, + "annotation": {"Literal": ["test", None]}, }, { "name": "c", "kind": "POSITIONAL_OR_KEYWORD", "default": 1, - "annotation": {"Literal": (1, 2, 3)}, + "annotation": {"Literal": [1, 2, 3]}, }, { "name": "d", @@ -115,7 +115,7 @@ def test_func( (float, "float"), (bool, "bool"), (inspect._empty, "_empty"), - (Literal[1, 2, 3], {"Literal": (1, 2, 3)}), + (Literal[1, 2, 3], {"Literal": [1, 2, 3]}), (Union[int, str], ["int", "str"]), (Optional[str], ["str", "NoneType"]), (DeviceBase, "DeviceBase"), @@ -135,7 +135,7 @@ def test_serialize_dtype(dtype_in, dtype_out): ("float", float), ("bool", bool), ("_empty", inspect._empty), - ({"Literal": (1, 2, 3)}, Literal[1, 2, 3]), + ({"Literal": [1, 2, 3]}, Literal[1, 2, 3]), (["int", "str"], Union[int, str]), (["str", "NoneType"], Optional[str]), ("NoneType", None), diff --git a/bec_lib/tests/test_signature_serializer_with_future_import.py b/bec_lib/tests/test_signature_serializer_with_future_import.py index d895aac22..038245cfd 100644 --- a/bec_lib/tests/test_signature_serializer_with_future_import.py +++ b/bec_lib/tests/test_signature_serializer_with_future_import.py @@ -19,7 +19,7 @@ def test_func(a: Literal[1, 2, 3] | None = None): "name": "a", "kind": "POSITIONAL_OR_KEYWORD", "default": None, - "annotation": {"Literal": (1, 2, 3, None)}, + "annotation": {"Literal": [1, 2, 3, None]}, } ] @@ -34,7 +34,7 @@ def test_func(a: Literal[1, 2, 3] | None | Literal["a", "b", "c"]): "name": "a", "kind": "POSITIONAL_OR_KEYWORD", "default": "_empty", - "annotation": {"Literal": (1, 2, 3, None, "a", "b", "c")}, + "annotation": {"Literal": [1, 2, 3, None, "a", "b", "c"]}, } ] @@ -52,7 +52,7 @@ def test_func(a: Literal[1, 2, 3] | "SomeUnknownType" | Literal["a", "b", "c"]): "name": "a", "kind": "POSITIONAL_OR_KEYWORD", "default": "_empty", - "annotation": {"Literal": (1, 2, 3, "a", "b", "c")}, + "annotation": {"Literal": [1, 2, 3, "a", "b", "c"]}, } ] @@ -61,7 +61,7 @@ def test_serialize_dtype_imported_imported_func_arg(): sig = inspect.signature(literal_union_test_func) anno = sig.parameters["a"].annotation assert serialize_dtype(anno) == serialize_dtype(Union[Literal["a", "b", "c"], EnumTest]) - assert serialize_dtype(anno) == {"Literal": ("a", "b", "c")} + assert serialize_dtype(anno) == {"Literal": ["a", "b", "c"]} def test_signature_serializer_parses_untion_on_imported_func(): @@ -71,7 +71,7 @@ def test_signature_serializer_parses_untion_on_imported_func(): "name": "a", "kind": "POSITIONAL_OR_KEYWORD", "default": "_empty", - "annotation": {"Literal": ("a", "b", "c")}, + "annotation": {"Literal": ["a", "b", "c"]}, } ] diff --git a/bec_server/bec_server/device_server/device_server.py b/bec_server/bec_server/device_server/device_server.py index f47e275c0..a85b423c1 100644 --- a/bec_server/bec_server/device_server/device_server.py +++ b/bec_server/bec_server/device_server/device_server.py @@ -19,7 +19,7 @@ from bec_lib.device import OnFailure from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger -from bec_lib.messages import BECStatus +from bec_lib.messages import BECStatus, sanitize_one_way_encodable from bec_lib.serialization import json_ext from bec_lib.utils.rpc_utils import rgetattr from bec_server.device_server.devices.devicemanager import DeviceManagerDS @@ -774,7 +774,7 @@ def status_callback(self, status): def _update_read_configuration(self, obj: OphydObject, metadata: dict, pipe) -> None: dev_config_msg = messages.DeviceMessage( - signals=obj.root.read_configuration(), metadata=metadata + signals=sanitize_one_way_encodable(obj.root.read_configuration()), metadata=metadata ) self.connector.set_and_publish( MessageEndpoints.device_read_configuration(obj.root.name), dev_config_msg, pipe diff --git a/bec_server/bec_server/device_server/devices/device_serializer.py b/bec_server/bec_server/device_server/devices/device_serializer.py index 84b9ad626..e2cd18e35 100644 --- a/bec_server/bec_server/device_server/devices/device_serializer.py +++ b/bec_server/bec_server/device_server/devices/device_serializer.py @@ -12,6 +12,7 @@ from ophyd_devices import BECDeviceBase, ComputedSignal from ophyd_devices.utils.bec_signals import BECMessageSignal +from bec_lib import messages from bec_lib.bec_errors import DeviceConfigError from bec_lib.device import DeviceBaseWithConfig from bec_lib.logger import bec_logger @@ -183,7 +184,9 @@ def get_device_info( "kind_int": kind, "kind_str": Kind(kind).name, "doc": doc, - "describe": signal_obj.describe().get(signal_obj.name, {}), + "describe": messages.sanitize_one_way_encodable( + signal_obj.describe().get(signal_obj.name, {}) + ), # pylint: disable=protected-access "metadata": signal_obj._metadata, } @@ -200,7 +203,9 @@ def get_device_info( "kind_int": signal_obj.kind.value, "kind_str": signal_obj.kind.name, "doc": doc, - "describe": signal_obj.describe().get(signal_obj.name, {}), + "describe": messages.sanitize_one_way_encodable( + signal_obj.describe().get(signal_obj.name, {}) + ), # pylint: disable=protected-access "metadata": signal_obj._metadata, } diff --git a/bec_server/bec_server/device_server/devices/devicemanager.py b/bec_server/bec_server/device_server/devices/devicemanager.py index 0a0f5bf42..6437eaef2 100644 --- a/bec_server/bec_server/device_server/devices/devicemanager.py +++ b/bec_server/bec_server/device_server/devices/devicemanager.py @@ -117,7 +117,8 @@ def initialize_device_buffer(self, connector): if not isinstance(self.obj, ophyd.Signal): # signals have the same read and read_configuration values; no need to publish twice dev_config_msg = messages.DeviceMessage( - signals=self.obj.read_configuration(), metadata={} + signals=messages.sanitize_one_way_encodable(self.obj.read_configuration()), + metadata={}, ) connector.set_and_publish( MessageEndpoints.device_read_configuration(self.name), dev_config_msg, pipe=pipe diff --git a/bec_server/bec_server/procedures/manager.py b/bec_server/bec_server/procedures/manager.py index 7e2c15f84..b0e531786 100644 --- a/bec_server/bec_server/procedures/manager.py +++ b/bec_server/bec_server/procedures/manager.py @@ -50,6 +50,7 @@ def _log_on_end(future: Future): def _resolve_dict(msg: dict[str, Any] | _T, MsgType: type[_T]) -> _T: if isinstance(msg, dict): + msg.pop("bec_codec", None) return MsgType.model_validate(msg) return msg @@ -95,7 +96,7 @@ def __init__(self, redis: str, worker_type: type[ProcedureWorker]): MessageEndpoints.available_procedures(), AvailableResourceMessage( resource={ - name: procedure_registry.get_info(name) + name: list(procedure_registry.get_info(name)) for name in procedure_registry.available() } ), diff --git a/bec_server/bec_server/scan_server/scan_assembler.py b/bec_server/bec_server/scan_server/scan_assembler.py index 57b0eb1e2..9b1bc525d 100644 --- a/bec_server/bec_server/scan_server/scan_assembler.py +++ b/bec_server/bec_server/scan_server/scan_assembler.py @@ -34,8 +34,8 @@ def is_scan_message(self, msg: messages.ScanQueueMessage) -> bool: Returns: bool: True if the message is a scan message, False otherwise """ - scan = msg.content.get("scan_type") - cls_name = self.scan_manager.available_scans[scan]["class"] + scan = msg.scan_type + cls_name = self.scan_manager.available_scans[scan].class_name scan_cls = self.scan_manager.scan_dict[cls_name] return issubclass(scan_cls, ScanBase) @@ -55,8 +55,8 @@ def assemble_device_instructions( Returns: RequestBase: Scan instance of the initialized scan class """ - scan = msg.content.get("scan_type") - cls_name = self.scan_manager.available_scans[scan]["class"] + scan = msg.scan_type + cls_name = self.scan_manager.available_scans[scan].class_name scan_cls = self.scan_manager.scan_dict[cls_name] logger.info(f"Preparing instructions of request of type {scan} / {scan_cls.__name__}") diff --git a/bec_server/bec_server/scan_server/scan_gui_models.py b/bec_server/bec_server/scan_server/scan_gui_models.py index d560af472..7dec7a428 100644 --- a/bec_server/bec_server/scan_server/scan_gui_models.py +++ b/bec_server/bec_server/scan_server/scan_gui_models.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator from pydantic_core import PydanticCustomError +from bec_lib.messages import Jsonable, sanitize_one_way_encodable from bec_lib.signature_serializer import signature_to_dict from bec_server.scan_server.scans import ScanBase @@ -24,9 +25,9 @@ class GUIInput(BaseModel): arg: bool = Field(False) name: str = Field(None, validate_default=True) - type: Optional[ - Literal["DeviceBase", "device", "float", "int", "bool", "str", "list", "dict"] - ] = Field(None, validate_default=True) + type: ( + Literal["DeviceBase", "device", "float", "int", "bool", "str", "list", "dict"] | Jsonable + ) = Field(None, validate_default=True) display_name: Optional[str] = Field(None, validate_default=True) tooltip: Optional[str] = Field(None, validate_default=True) default: Optional[Any] = Field(None, validate_default=True) @@ -53,7 +54,7 @@ def validate_name(cls, v, values): def validate_field(cls, v, values): # args cannot be validated with the current implementation of signature of scans if values.data["arg"]: - return v + return sanitize_one_way_encodable(v) signature = context_signature.get() if v is None: name = values.data.get("name", None) @@ -66,7 +67,7 @@ def validate_field(cls, v, values): for entry in signature: if entry["name"] == name: v = entry["annotation"] - return v + return sanitize_one_way_encodable(v) @field_validator("tooltip") @classmethod @@ -187,7 +188,7 @@ class GUIConfig(BaseModel): scan_class_name: str arg_group: Optional[GUIArgGroup] = Field(None) - kwarg_groups: list[GUIGroup] = Field(None) + kwarg_groups: list[GUIGroup] | None = Field(None) signature: list[dict] = Field(..., exclude=True) docstring: str = Field(..., exclude=True) diff --git a/bec_server/bec_server/scan_server/scan_manager.py b/bec_server/bec_server/scan_server/scan_manager.py index 2198e629a..59cc24ebc 100644 --- a/bec_server/bec_server/scan_server/scan_manager.py +++ b/bec_server/bec_server/scan_server/scan_manager.py @@ -4,11 +4,10 @@ import inspect -from bec_lib import plugin_helper +from bec_lib import messages, plugin_helper from bec_lib.device import DeviceBase from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger -from bec_lib.messages import AvailableResourceMessage from bec_lib.signature_serializer import signature_to_dict from bec_server.scan_server.scan_gui_models import GUIConfig @@ -27,7 +26,7 @@ def __init__(self, *, parent): Scan Manager loads and manages the available scans. """ self.parent = parent - self.available_scans = {} + self.available_scans: dict[str, messages.AvailableScan] = {} self.scan_dict: dict[str, type[scans_module.RequestBase]] = {} self._plugins = {} self.load_plugins() @@ -77,16 +76,18 @@ def update_available_scans(self): base_cls = report_cls.__name__ self.scan_dict[scan_cls.__name__] = scan_cls gui_config = self.validate_gui_config(scan_cls) - self.available_scans[scan_cls.scan_name] = { - "class": scan_cls.__name__, - "base_class": base_cls, - "arg_input": self.convert_arg_input(scan_cls.arg_input), - "gui_config": gui_config, - "required_kwargs": scan_cls.required_kwargs, - "arg_bundle_size": scan_cls.arg_bundle_size, - "doc": scan_cls.__doc__ or scan_cls.__init__.__doc__, - "signature": signature_to_dict(scan_cls.__init__), - } + self.available_scans[scan_cls.scan_name] = messages.AvailableScan.model_validate( + { + "class_name": scan_cls.__name__, + "base_class": base_cls, + "arg_input": self.convert_arg_input(scan_cls.arg_input), + "gui_config": gui_config, + "required_kwargs": scan_cls.required_kwargs, + "arg_bundle_size": scan_cls.arg_bundle_size, + "doc": scan_cls.__doc__ or scan_cls.__init__.__doc__, + "signature": signature_to_dict(scan_cls.__init__), + } + ) def validate_gui_config(self, scan_cls) -> dict: """ @@ -142,5 +143,5 @@ def publish_available_scans(self): """send all available scans to the broker""" self.parent.connector.set( MessageEndpoints.available_scans(), - AvailableResourceMessage(resource=self.available_scans), + messages.AvailableResourceMessage(resource=self.available_scans), ) diff --git a/bec_server/bec_server/scan_server/scan_plugins/otf_scan.py b/bec_server/bec_server/scan_server/scan_plugins/otf_scan.py index 93f3b4e75..faac453af 100644 --- a/bec_server/bec_server/scan_server/scan_plugins/otf_scan.py +++ b/bec_server/bec_server/scan_server/scan_plugins/otf_scan.py @@ -1,7 +1,8 @@ import time from bec_lib.logger import bec_logger -from bec_server.scan_server.scans import ScanArgType, ScanBase, SyncFlyScanBase +from bec_lib.messages import ScanArgType +from bec_server.scan_server.scans import ScanBase, SyncFlyScanBase logger = bec_logger.logger diff --git a/bec_server/bec_server/scan_server/scan_stubs.py b/bec_server/bec_server/scan_server/scan_stubs.py index 3f5077372..7efbabe23 100644 --- a/bec_server/bec_server/scan_server/scan_stubs.py +++ b/bec_server/bec_server/scan_server/scan_stubs.py @@ -315,7 +315,9 @@ def _exclude_nones(input_dict: dict): def _device_msg(self, **kwargs) -> messages.DeviceInstructionMessage: """""" - msg = messages.DeviceInstructionMessage(**kwargs) + msg = messages.DeviceInstructionMessage.model_validate( + messages.sanitize_one_way_encodable(kwargs) + ) msg.metadata = {**self.device_msg_metadata(), **msg.metadata} return msg diff --git a/bec_server/bec_server/scan_server/scans.py b/bec_server/bec_server/scan_server/scans.py index 1549429f0..b83bc5706 100644 --- a/bec_server/bec_server/scan_server/scans.py +++ b/bec_server/bec_server/scan_server/scans.py @@ -1,7 +1,6 @@ from __future__ import annotations import ast -import enum import threading import time import uuid @@ -16,6 +15,7 @@ from bec_lib.devicemanager import DeviceManagerBase from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger +from bec_lib.messages import ScanArgType from bec_server.scan_server.instruction_handler import InstructionHandler from .errors import LimitError, ScanAbortion @@ -25,16 +25,6 @@ logger = bec_logger.logger -class ScanArgType(str, enum.Enum): - DEVICE = "device" - FLOAT = "float" - INT = "int" - BOOL = "boolean" - STR = "str" - LIST = "list" - DICT = "dict" - - def unpack_scan_args(scan_args: dict[str, Any]) -> list: """unpack_scan_args unpacks the scan arguments and returns them as a tuple. @@ -941,7 +931,7 @@ def scan_report_instructions(self): "RID": self.metadata["RID"], "devices": self.scan_motors, "start": self.start_pos, - "end": self.positions[0], + "end": list(self.positions[0]), } } ) diff --git a/bec_server/tests/tests_device_server/test_config_handler.py b/bec_server/tests/tests_device_server/test_config_handler.py index 192892141..8cddb920c 100644 --- a/bec_server/tests/tests_device_server/test_config_handler.py +++ b/bec_server/tests/tests_device_server/test_config_handler.py @@ -128,7 +128,7 @@ def test_parse_config_request_add_remove(dm_with_devices): "tolerance": 0.01, "update_frequency": 400, }, - "deviceTags": {"user motors"}, + "deviceTags": ["user motors"], "enabled": True, "readOnly": False, "name": "new_device", diff --git a/bec_server/tests/tests_device_server/test_rpc_handler.py b/bec_server/tests/tests_device_server/test_rpc_handler.py index 172d0ce86..b1d2948e5 100644 --- a/bec_server/tests/tests_device_server/test_rpc_handler.py +++ b/bec_server/tests/tests_device_server/test_rpc_handler.py @@ -58,7 +58,7 @@ def test_execute_rpc_call(rpc_cls: RPCHandler, instr_params): msg = messages.DeviceInstructionMessage( device="device", action="rpc", - parameter=instr_params, + parameter=messages.sanitize_one_way_encodable(instr_params), metadata={"RID": "RID", "device_instr_id": "diid"}, ) out = rpc_cls._execute_rpc_call(rpc_var=rpc_var, instr=msg) @@ -80,7 +80,7 @@ def test_execute_rpc_call_var(rpc_cls: RPCHandler, instr_params: dict): msg = messages.DeviceInstructionMessage( device="device", action="rpc", - parameter=instr_params, + parameter=messages.sanitize_one_way_encodable(instr_params), metadata={"RID": "RID", "device_instr_id": "diid"}, ) out = rpc_cls._execute_rpc_call(rpc_var=rpc_var, instr=msg) diff --git a/bec_server/tests/tests_file_writer/test_async_file_writer.py b/bec_server/tests/tests_file_writer/test_async_file_writer.py index 31b354ffd..a79e77bf8 100644 --- a/bec_server/tests/tests_file_writer/test_async_file_writer.py +++ b/bec_server/tests/tests_file_writer/test_async_file_writer.py @@ -549,7 +549,7 @@ def test_async_writer_raises_on_wrong_data_type(async_writer): # Send invalid data (not a DeviceMessage) invalid_data = messages.DeviceMessage( - signals={"monitor_async": {"value": {"data": None}, "timestamp": 1}}, + signals={"monitor_async": {"value": None, "timestamp": 1}}, metadata={"async_update": {"type": "add", "max_shape": [None]}}, ) diff --git a/bec_server/tests/tests_scan_server/test_scan_assembler.py b/bec_server/tests/tests_scan_server/test_scan_assembler.py index 0156fd57e..b0a7d4f41 100644 --- a/bec_server/tests/tests_scan_server/test_scan_assembler.py +++ b/bec_server/tests/tests_scan_server/test_scan_assembler.py @@ -39,7 +39,7 @@ def run(self): # Fermat scan with args and kwargs, matching the FermatSpiralScan signature messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"steps": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"steps": 3}}, queue="primary", ), { @@ -120,7 +120,7 @@ def run(self): # Line scan with arg bundle messages.ScanQueueMessage( scan_type="line_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"steps": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"steps": 3}}, queue="primary", ), {"arg_bundle": ["samx", -5, 5, "samy", -5, 5], "inputs": {}, "kwargs": {"steps": 3}}, @@ -147,12 +147,23 @@ def run(self): ) def test_scan_assembler_request_inputs(msg, request_inputs_expected, scan_assembler): + def _available_scan(clss: str): + return messages.AvailableScan( + class_name=clss, + base_class="", + arg_input={}, + gui_config={}, + required_kwargs=[], + arg_bundle_size={}, + signature=[], + ) + class MockScanManager: available_scans = { - "fermat_scan": {"class": "FermatSpiralScan"}, - "line_scan": {"class": "LineScan"}, - "custom_scan": {"class": "CustomScan"}, - "custom_scan2": {"class": "CustomScan2"}, + "fermat_scan": _available_scan("FermatSpiralScan"), + "line_scan": _available_scan("LineScan"), + "custom_scan": _available_scan("CustomScan"), + "custom_scan2": _available_scan("CustomScan2"), } scan_dict = { "FermatSpiralScan": FermatSpiralScan, diff --git a/bec_server/tests/tests_scan_server/test_scan_guard.py b/bec_server/tests/tests_scan_server/test_scan_guard.py index 3c544ac35..fceb2d43c 100644 --- a/bec_server/tests/tests_scan_server/test_scan_guard.py +++ b/bec_server/tests/tests_scan_server/test_scan_guard.py @@ -24,7 +24,7 @@ def scan_guard_mock(scan_server_mock): ( messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) ), @@ -73,7 +73,7 @@ def test_device_rpc_is_valid(scan_guard_mock, device, func, is_valid): ( messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ), True, @@ -122,7 +122,7 @@ def test_check_valid_scan_raises_for_unknown_scan(scan_guard_mock): request = messages.ScanQueueMessage( scan_type="unknown_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) @@ -139,7 +139,7 @@ def test_check_valid_scan_accepts_known_scan(scan_guard_mock): request = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) @@ -206,7 +206,7 @@ def test_append_to_scan_queue(scan_guard_mock): sg = scan_guard_mock msg = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) with mock.patch.object(sg.device_manager.connector, "send") as send: @@ -218,7 +218,7 @@ def test_scan_queue_request_callback(scan_guard_mock): sg = scan_guard_mock msg = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) msg_obj = MessageObject(MessageEndpoints.scan_queue_request(), msg) @@ -252,7 +252,7 @@ def test_handle_scan_request(scan_guard_mock): sg = scan_guard_mock msg = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) with mock.patch.object(sg, "_is_valid_scan_request") as valid: @@ -333,7 +333,7 @@ def test_handle_scan_request_rejected(scan_guard_mock): sg = scan_guard_mock msg = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) with mock.patch.object(sg, "_is_valid_scan_request") as valid: @@ -347,7 +347,7 @@ def test_is_valid_scan_request_returns_scan_status_on_error(scan_guard_mock): sg = scan_guard_mock msg = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) with mock.patch.object(sg, "_check_valid_scan") as valid: diff --git a/bec_server/tests/tests_scan_server/test_scan_gui_models.py b/bec_server/tests/tests_scan_server/test_scan_gui_models.py index 7ab1b0df0..784bb78d2 100644 --- a/bec_server/tests/tests_scan_server/test_scan_gui_models.py +++ b/bec_server/tests/tests_scan_server/test_scan_gui_models.py @@ -3,8 +3,9 @@ import pytest from pydantic import ValidationError +from bec_lib.messages import ScanArgType from bec_server.scan_server.scan_gui_models import GUIConfig -from bec_server.scan_server.scans import ScanArgType, ScanBase +from bec_server.scan_server.scans import ScanBase class GoodScan(ScanBase): # pragma: no cover @@ -193,7 +194,7 @@ def test_gui_config_good_scan_dump(): "expert": False, "name": "optim_trajectory", "tooltip": None, - "type": {"Literal": ("path", None)}, + "type": {"Literal": ["path", None]}, }, ], } diff --git a/bec_server/tests/tests_scan_server/test_scan_server_queue.py b/bec_server/tests/tests_scan_server/test_scan_server_queue.py index f17740945..42a298ef1 100644 --- a/bec_server/tests/tests_scan_server/test_scan_server_queue.py +++ b/bec_server/tests/tests_scan_server/test_scan_server_queue.py @@ -78,7 +78,7 @@ def test_queuemanager_add_to_queue(queuemanager_mock, queue): queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue=queue, metadata={"RID": "something"}, ) @@ -120,7 +120,7 @@ def test_queuemanager_add_to_queue_restarts_queue_if_worker_is_dead(queuemanager msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -135,7 +135,7 @@ def test_queuemanager_add_to_queue_error_send_alarm(queuemanager_mock): queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -151,7 +151,7 @@ def test_queuemanager_scan_queue_callback(queuemanager_mock): queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -227,7 +227,7 @@ def test_set_pause(queuemanager_mock): # Add a queue item so worker_status has something to operate on msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -251,7 +251,7 @@ def test_set_pause_does_not_change_non_running_worker(queuemanager_mock): # Add a queue item msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -298,7 +298,7 @@ def test_set_abort(queuemanager_mock): queue_manager.connector.message_sent = [] msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -327,7 +327,7 @@ def test_set_abort_with_scan_id(queuemanager_mock): queue_manager.connector.message_sent = [] msg = messages.ScanQueueMessage( scan_type="line_scan", - parameter={"args": {"samx": (-1, 1)}, "kwargs": {"steps": 10, "relative": False}}, + parameter={"args": {"samx": [-1, 1]}, "kwargs": {"steps": 10, "relative": False}}, queue="primary", metadata={"RID": "something"}, ) @@ -356,7 +356,7 @@ def test_set_abort_with_scan_id_not_active(queuemanager_mock): queue_manager.connector.message_sent = [] msg = messages.ScanQueueMessage( scan_type="line_scan", - parameter={"args": {"samx": (-1, 1)}, "kwargs": {"steps": 10, "relative": False}}, + parameter={"args": {"samx": [-1, 1]}, "kwargs": {"steps": 10, "relative": False}}, queue="primary", metadata={"RID": "something"}, ) @@ -379,7 +379,7 @@ def test_set_abort_with_wrong_scan_id(queuemanager_mock): queue_manager.connector.message_sent = [] msg = messages.ScanQueueMessage( scan_type="line_scan", - parameter={"args": {"samx": (-1, 1)}, "kwargs": {"steps": 10, "relative": False}}, + parameter={"args": {"samx": [-1, 1]}, "kwargs": {"steps": 10, "relative": False}}, queue="primary", metadata={"RID": "something"}, ) @@ -431,7 +431,7 @@ def test_set_restart(queuemanager_mock): queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -462,7 +462,7 @@ def test_set_restart_no_active_scan(queuemanager_mock): queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -488,7 +488,7 @@ def test_set_user_completed(queuemanager_mock): queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="line_scan", - parameter={"args": {"samx": (-1, 1)}, "kwargs": {"steps": 10, "relative": False}}, + parameter={"args": {"samx": [-1, 1]}, "kwargs": {"steps": 10, "relative": False}}, queue="primary", metadata={"RID": "something"}, ) @@ -509,7 +509,7 @@ def test_request_block(scan_server_mock): scan_server = scan_server_mock msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -524,7 +524,7 @@ def test_request_block(scan_server_mock): ( messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -532,7 +532,7 @@ def test_request_block(scan_server_mock): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -558,7 +558,7 @@ def test_remove_queue_item(queuemanager_mock): queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -572,7 +572,7 @@ def test_invalid_scan_specified_in_message(queuemanager_mock): queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="fake test scan which does not exist!", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -587,7 +587,7 @@ def test_set_clear(queuemanager_mock): queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -679,7 +679,7 @@ def test_request_block_queue_append(): req_block_queue = RequestBlockQueue(mock.MagicMock(), mock.MagicMock()) msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -697,7 +697,7 @@ def test_request_block_queue_append(): ( messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ), @@ -706,7 +706,7 @@ def test_request_block_queue_append(): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something", "scan_def_id": "something"}, ), @@ -715,7 +715,7 @@ def test_request_block_queue_append(): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something", "scan_def_id": "existing_scan_def_id"}, ), @@ -752,7 +752,7 @@ def test_append_request_block(): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something", "scan_def_id": "something"}, ), @@ -761,7 +761,7 @@ def test_append_request_block(): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something", "scan_def_id": "existing_scan_def_id"}, ), @@ -792,7 +792,7 @@ def test_update_point_id(scan_queue_msg, scan_id): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something", "scan_def_id": "existing_scan_def_id"}, ), @@ -819,7 +819,7 @@ def test_update_point_id_takes_max(scan_queue_msg, scan_id): ( messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ), @@ -828,7 +828,7 @@ def test_update_point_id_takes_max(scan_queue_msg, scan_id): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ), @@ -837,7 +837,7 @@ def test_update_point_id_takes_max(scan_queue_msg, scan_id): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something", "scan_def_id": "existing_scan_def_id"}, ), @@ -846,7 +846,7 @@ def test_update_point_id_takes_max(scan_queue_msg, scan_id): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something", "dataset_id_on_hold": True}, ), @@ -877,7 +877,7 @@ def test_pull_request_block_non_empyt_rb(): req_block_queue = RequestBlockQueue(mock.MagicMock(), mock.MagicMock()) scan_queue_msg = messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -934,7 +934,7 @@ def test_queue_manager_get_active_scan_id(queuemanager_mock): queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -953,7 +953,7 @@ def test_queue_manager_get_active_scan_id_wo_rbl_returns_None(queuemanager_mock) queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -965,7 +965,7 @@ def test_request_block_queue_next(): req_block_queue = RequestBlockQueue(mock.MagicMock(), mock.MagicMock()) msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -981,7 +981,7 @@ def test_request_block_queue_next_raises_stopiteration(): req_block_queue = RequestBlockQueue(mock.MagicMock(), mock.MagicMock()) msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) @@ -998,7 +998,7 @@ def test_request_block_queue_next_updates_point_id(): req_block_queue = RequestBlockQueue(mock.MagicMock(), mock.MagicMock()) msg = messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "something", "scan_def_id": "scan_def_id"}, ) @@ -1059,7 +1059,7 @@ def test_queue_order_change(queuemanager_mock, order_msg, position): queue_manager = queuemanager_mock() msg = messages.ScanQueueMessage( scan_type="line_scan", - parameter={"args": {"samx": (-5, 5)}, "kwargs": {"steps": 3}}, + parameter={"args": {"samx": [-5, 5]}, "kwargs": {"steps": 3}}, queue="primary", metadata={"RID": "something"}, ) diff --git a/bec_server/tests/tests_scan_server/test_scan_server_scan_manager.py b/bec_server/tests/tests_scan_server/test_scan_server_scan_manager.py index 604e88e86..35c0bc725 100644 --- a/bec_server/tests/tests_scan_server/test_scan_server_scan_manager.py +++ b/bec_server/tests/tests_scan_server/test_scan_server_scan_manager.py @@ -3,8 +3,8 @@ import pytest from bec_lib.device import Device, DeviceBase, Positioner +from bec_lib.messages import ScanArgType from bec_server.scan_server.scan_manager import ScanManager -from bec_server.scan_server.scans import ScanArgType @pytest.fixture diff --git a/bec_server/tests/tests_scan_server/test_scan_worker.py b/bec_server/tests/tests_scan_server/test_scan_worker.py index 7e99f6037..d4bae3843 100644 --- a/bec_server/tests/tests_scan_server/test_scan_worker.py +++ b/bec_server/tests/tests_scan_server/test_scan_worker.py @@ -101,7 +101,7 @@ def test_publish_data_as_read(scan_worker_mock): def test_publish_data_as_read_multiple(scan_worker_mock): worker = scan_worker_mock - data = [{"samx": {}}, {"samy": {}}] + data = [{"samx": {"value": None}}, {"samy": {"value": None}}] devices = ["samx", "samy"] instr = messages.DeviceInstructionMessage( device=devices, @@ -217,7 +217,7 @@ def test_open_scan(scan_worker_mock, instr, corr_num_points, scan_id): messages.ScanQueueMessage( scan_type="grid_scan", parameter={ - "args": {"samx": (-5, 5, 5), "samy": (-1, 1, 2)}, + "args": {"samx": [-5, 5, 5], "samy": [-1, 1, 2]}, "kwargs": { "exp_time": 1, "relative": True, @@ -234,7 +234,7 @@ def test_open_scan(scan_worker_mock, instr, corr_num_points, scan_id): messages.ScanQueueMessage( scan_type="grid_scan", parameter={ - "args": {"samx": (-5, 5, 5), "samy": (-1, 1, 2)}, + "args": {"samx": [-5, 5, 5], "samy": [-1, 1, 2]}, "kwargs": { "exp_time": 1, "relative": True, @@ -251,7 +251,7 @@ def test_open_scan(scan_worker_mock, instr, corr_num_points, scan_id): messages.ScanQueueMessage( scan_type="grid_scan", parameter={ - "args": {"samx": (-5, 5, 5), "samy": (-1, 1, 2)}, + "args": {"samx": [-5, 5, 5], "samy": [-1, 1, 2]}, "kwargs": { "exp_time": 1, "relative": True, @@ -293,7 +293,7 @@ def test_initialize_scan_info(scan_worker_mock, msg): assert worker.current_scan_info["scan_msgs"] == [] assert worker.current_scan_info["monitor_sync"] == "bec" assert worker.current_scan_info["frames_per_trigger"] == 1 - assert worker.current_scan_info["args"] == {"samx": (-5, 5, 5), "samy": (-1, 1, 2)} + assert worker.current_scan_info["args"] == {"samx": [-5, 5, 5], "samy": [-1, 1, 2]} assert worker.current_scan_info["kwargs"] == msg.parameter["kwargs"] assert "samx" in worker.current_scan_info["readout_priority"]["monitored"] assert "samy" in worker.current_scan_info["readout_priority"]["baseline"] diff --git a/bec_server/tests/tests_scan_server/test_scans.py b/bec_server/tests/tests_scan_server/test_scans.py index 72ce9045a..be03432b8 100644 --- a/bec_server/tests/tests_scan_server/test_scans.py +++ b/bec_server/tests/tests_scan_server/test_scans.py @@ -69,7 +69,7 @@ def test_unpack_scan_args_valid_input(): ( messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,), "samy": (2,)}, "kwargs": {}}, + parameter={"args": {"samx": [1], "samy": [2]}, "kwargs": {}}, queue="primary", ), [ @@ -90,7 +90,7 @@ def test_unpack_scan_args_valid_input(): ( messages.ScanQueueMessage( scan_type="mv", - parameter={"args": {"samx": (1,), "samy": (2,), "samz": (3,)}, "kwargs": {}}, + parameter={"args": {"samx": [1], "samy": [2], "samz": [3]}, "kwargs": {}}, queue="primary", ), [ @@ -116,7 +116,7 @@ def test_unpack_scan_args_valid_input(): ), ( messages.ScanQueueMessage( - scan_type="mv", parameter={"args": {"samx": (1,)}, "kwargs": {}}, queue="primary" + scan_type="mv", parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary" ), [ messages.DeviceInstructionMessage( @@ -154,7 +154,7 @@ def offset_mock(): ( messages.ScanQueueMessage( scan_type="umv", - parameter={"args": {"samx": (1,), "samy": (2,)}, "kwargs": {}}, + parameter={"args": {"samx": [1], "samy": [2]}, "kwargs": {}}, queue="primary", metadata={"RID": "0bab7ee3-b384-4571-b...0fff984c05"}, ), @@ -167,7 +167,7 @@ def offset_mock(): "RID": "0bab7ee3-b384-4571-b...0fff984c05", "devices": ["samx", "samy"], "start": [0, 0], - "end": np.array([1.0, 2.0]), + "end": [1.0, 2.0], } }, metadata={ @@ -198,7 +198,7 @@ def offset_mock(): ( messages.ScanQueueMessage( scan_type="umv", - parameter={"args": {"samx": (1,), "samy": (2,), "samz": (3,)}, "kwargs": {}}, + parameter={"args": {"samx": [1], "samy": [2], "samz": [3]}, "kwargs": {}}, queue="primary", metadata={"RID": "0bab7ee3-b384-4571-b...0fff984c05"}, ), @@ -211,7 +211,7 @@ def offset_mock(): "RID": "0bab7ee3-b384-4571-b...0fff984c05", "devices": ["samx", "samy", "samz"], "start": [0, 0, 0], - "end": np.array([1.0, 2.0, 3.0]), + "end": [1.0, 2.0, 3.0], } }, metadata={ @@ -251,7 +251,7 @@ def offset_mock(): ( messages.ScanQueueMessage( scan_type="umv", - parameter={"args": {"samx": (1,)}, "kwargs": {}}, + parameter={"args": {"samx": [1]}, "kwargs": {}}, queue="primary", metadata={"RID": "0bab7ee3-b384-4571-b...0fff984c05"}, ), @@ -264,7 +264,7 @@ def offset_mock(): "RID": "0bab7ee3-b384-4571-b...0fff984c05", "devices": ["samx"], "start": [0], - "end": np.array([1.0]), + "end": [1.0], } }, metadata={ @@ -297,8 +297,8 @@ def test_scan_updated_move(mv_msg, reference_msg_list, scan_assembler, ScanStubS mock_get_from_rpc.return_value = { dev: {"value": value} for dev, value in zip( - reference_msg_list[0].content["parameter"]["readback"]["devices"], - reference_msg_list[0].content["parameter"]["readback"]["start"], + reference_msg_list[0].parameter["readback"]["devices"], + reference_msg_list[0].parameter["readback"]["start"], ) } @@ -322,7 +322,7 @@ def mock_rpc_func(*args, **kwargs): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 3)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 3]}, "kwargs": {}}, queue="primary", ), [ @@ -473,7 +473,7 @@ def offset_mock(): ( messages.ScanQueueMessage( scan_type="grid_scan", - parameter={"args": {"samx": (-5, 5, 2), "samy": (-5, 5, 2)}, "kwargs": {}}, + parameter={"args": {"samx": [-5, 5, 2], "samy": [-5, 5, 2]}, "kwargs": {}}, queue="primary", ), [ @@ -648,7 +648,7 @@ def offset_mock(): ( messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ), [ @@ -668,7 +668,7 @@ def offset_mock(): messages.ScanQueueMessage( scan_type="fermat_scan", parameter={ - "args": {"samx": (-5, 5), "samy": (-5, 5)}, + "args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3, "spiral_type": 1}, }, queue="primary", @@ -721,7 +721,7 @@ def offset_mock(): }, scan_type="cont_line_scan", parameter={ - "args": ("samx", -1, 1), + "args": ["samx", -1, 1], "kwargs": { "steps": 3, "exp_time": 0.1, @@ -744,7 +744,7 @@ def offset_mock(): metadata={"readout_priority": "monitored"}, device="samx", action="rpc", - parameter={"device": "samx", "func": "velocity.get", "args": (), "kwargs": {}}, + parameter={"device": "samx", "func": "velocity.get", "args": [], "kwargs": {}}, ), messages.DeviceInstructionMessage( metadata={"readout_priority": "monitored"}, @@ -753,7 +753,7 @@ def offset_mock(): parameter={ "device": "samx", "func": "acceleration.get", - "args": (), + "args": [], "kwargs": {}, }, ), @@ -761,7 +761,7 @@ def offset_mock(): metadata={"readout_priority": "monitored"}, device="samx", action="rpc", - parameter={"device": "samx", "func": "read", "args": (), "kwargs": {}}, + parameter={"device": "samx", "func": "read", "args": [], "kwargs": {}}, ), messages.DeviceInstructionMessage( metadata={"readout_priority": "monitored"}, @@ -1053,7 +1053,7 @@ def pre_scan_macro(devices: dict, request: RequestBase): macros = inspect.getsource(pre_scan_macro).encode() scan_msg = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) args = unpack_scan_args(scan_msg.content.get("parameter", {}).get("args", [])) @@ -1077,7 +1077,7 @@ def pre_scan_macro(devices: dict, request: RequestBase): # device_manager = DMMock() # device_manager.add_device("samx") # parameter = { -# "args": {"samx": (-5, 5), "samy": (-5, 5)}, +# "args": {"samx": [-5, 5], "samy": [-5, 5]}, # "kwargs": {"step": 3}, # } # request = RequestBase(device_manager=device_manager, parameter=parameter) @@ -1099,7 +1099,7 @@ def test_round_roi_scan(): scan_msg = messages.ScanQueueMessage( scan_type="round_roi_scan", parameter={ - "args": {"samx": (10,), "samy": (10,)}, + "args": {"samx": [10], "samy": [10]}, "kwargs": {"dr": 2, "nth": 4, "exp_time": 2, "relative": True}, }, queue="primary", @@ -1211,7 +1211,7 @@ def pre_scan_macro(devices: dict, request: RequestBase): macros = [inspect.getsource(pre_scan_macro).encode()] scan_msg = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) args = unpack_scan_args(scan_msg.content.get("parameter", {}).get("args", [])) @@ -1228,7 +1228,7 @@ def test_scan_report_devices(): device_manager.add_device("samy") scan_msg = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) args = unpack_scan_args(scan_msg.content.get("parameter", {}).get("args", [])) @@ -1253,7 +1253,7 @@ def run(self): device_manager.add_device("samy") scan_msg = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) request = RequestBaseMock( @@ -1305,7 +1305,7 @@ def run(self): device_manager.add_device("samz") scan_msg = messages.ScanQueueMessage( scan_type="fermat_scan", - parameter={"args": {"samx": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) request = RequestBaseMock( @@ -1318,7 +1318,7 @@ def run(self): assert request.scan_motors == ["samx"] request.arg_bundle_size = {"bundle": 2, "min": None, "max": None} - request.caller_args = {"samz": (-2, 2), "samy": (-1, 2)} + request.caller_args = {"samz": [-2, 2], "samy": [-1, 2]} request.update_scan_motors() assert request.scan_motors == ["samz", "samy"] @@ -1340,7 +1340,7 @@ def _calculate_positions(self): scan_msg = messages.ScanQueueMessage( scan_type="", - parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + parameter={"args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3}}, queue="primary", ) with pytest.raises(ValueError) as exc_info: @@ -1358,7 +1358,7 @@ def test_scan_base_set_position_offset(): scan_msg = messages.ScanQueueMessage( scan_type="fermat_scan", parameter={ - "args": {"samx": (-5, 5), "samy": (-5, 5)}, + "args": {"samx": [-5, 5], "samy": [-5, 5]}, "kwargs": {"step": 3, "relative": False}, }, queue="primary", @@ -1385,7 +1385,7 @@ def test_round_scan_fly_simupdate_scan_motors(): device_manager.add_device("flyer_sim") scan_msg = messages.ScanQueueMessage( scan_type="round_scan_fly", - parameter={"args": {"flyer_sim": (0, 50, 5, 3)}, "kwargs": {"realtive": True}}, + parameter={"args": {"flyer_sim": [0, 50, 5, 3]}, "kwargs": {"realtive": True}}, queue="primary", ) request = RoundScanFlySim( @@ -1408,7 +1408,7 @@ def test_round_scan_fly_sim_prepare_positions(): device_manager.add_device("flyer_sim") scan_msg = messages.ScanQueueMessage( scan_type="round_scan_fly", - parameter={"args": {"flyer_sim": (0, 50, 5, 3)}, "kwargs": {"realtive": True}}, + parameter={"args": {"flyer_sim": [0, 50, 5, 3]}, "kwargs": {"realtive": True}}, queue="primary", ) request = RoundScanFlySim( @@ -1433,7 +1433,7 @@ def test_round_scan_fly_sim_prepare_positions(): @pytest.mark.parametrize( - "in_args,reference_positions", [((1, 5, 1, 1), [[0, -3], [0, -7], [0, 7]])] + "in_args,reference_positions", [([1, 5, 1, 1], [[0, -3], [0, -7], [0, 7]])] ) def test_round_scan_fly_sim_calculate_positions(in_args, reference_positions): device_manager = DMMock() @@ -1458,7 +1458,7 @@ def test_round_scan_fly_sim_calculate_positions(in_args, reference_positions): @pytest.mark.parametrize( - "in_args,reference_positions", [((1, 5, 1, 1), [[0, -3], [0, -7], [0, 7]])] + "in_args,reference_positions", [([1, 5, 1, 1], [[0, -3], [0, -7], [0, 7]])] ) def test_round_scan_fly_sim_scan_core(in_args, reference_positions, scan_assembler): scan_msg = messages.ScanQueueMessage( @@ -2139,7 +2139,7 @@ def fake_set(*args, **kwargs): "device": "samx", "func": "read", "rpc_id": "rpc_id", - "args": (), + "args": [], "kwargs": {}, }, ),