From eb36f00ff5f140638cc203fc4ce17d6ef835b337 Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Mon, 1 Dec 2025 16:46:48 +0100 Subject: [PATCH 1/6] feat: add beamline states --- bec_lib/bec_lib/bl_checks.py | 246 -------- bec_lib/bec_lib/bl_conditions.py | 89 --- bec_lib/bec_lib/bl_states.py | 338 +++++++++++ bec_lib/bec_lib/client.py | 9 +- bec_lib/bec_lib/endpoints.py | 34 ++ bec_lib/bec_lib/messages.py | 45 ++ bec_lib/tests/test_beamline_checks.py | 167 ------ bec_lib/tests/test_beamline_states.py | 564 ++++++++++++++++++ bec_lib/tests/test_bl_conditions.py | 37 -- .../scan_server/beamline_state_manager.py | 83 +++ .../bec_server/scan_server/scan_server.py | 6 + .../test_beamline_state_manager.py | 103 ++++ 12 files changed, 1176 insertions(+), 545 deletions(-) delete mode 100644 bec_lib/bec_lib/bl_checks.py delete mode 100644 bec_lib/bec_lib/bl_conditions.py create mode 100644 bec_lib/bec_lib/bl_states.py delete mode 100644 bec_lib/tests/test_beamline_checks.py create mode 100644 bec_lib/tests/test_beamline_states.py delete mode 100644 bec_lib/tests/test_bl_conditions.py create mode 100644 bec_server/bec_server/scan_server/beamline_state_manager.py create mode 100644 bec_server/tests/tests_scan_server/test_beamline_state_manager.py diff --git a/bec_lib/bec_lib/bl_checks.py b/bec_lib/bec_lib/bl_checks.py deleted file mode 100644 index 8bdc7c3bf..000000000 --- a/bec_lib/bec_lib/bl_checks.py +++ /dev/null @@ -1,246 +0,0 @@ -""" -This module provides the BeamlineChecks class, which is used to perform beamline checks. It also provides the bl_check -decorator. -""" - -import builtins -import datetime -import functools -import threading -import time -from collections import deque -from uuid import uuid4 - -from typeguard import typechecked - -from bec_lib.bl_conditions import BeamlineCondition -from bec_lib.logger import bec_logger - -logger = bec_logger.logger - - -class BeamlineCheckError(Exception): - pass - - -class BeamlineCheckRepeat(Exception): - pass - - -def bl_check(fcn): - """Decorator to perform rpc calls.""" - - @functools.wraps(fcn) - def bl_check_wrapper(*args, **kwargs): - client = builtins.__dict__.get("bec") - bl_checks = client.bl_checks - _run_with_bl_checks(bl_checks, fcn, *args, **kwargs) - - return bl_check_wrapper - - -def _run_with_bl_checks(bl_checks, fcn, *args, **kwargs): - # pylint: disable=protected-access - chk = {"id": str(uuid4()), "fcn": fcn, "args": args, "kwargs": kwargs} - bl_checks._levels.append(chk) - nested_call = len(bl_checks._levels) > 1 - if bl_checks._is_paused and bl_checks._beamline_checks: - logger.warning( - "Beamline checks are currently paused. Use `bec.bl_checks.resume()` to reactivate them." - ) - try: - if nested_call: - # check if the beam was okay so far - if not bl_checks.beam_is_okay: - raise BeamlineCheckError("Beam is not okay.") - else: - bl_checks.reset() - bl_checks.wait_for_beamline_checks() - successful = False - while not successful: - try: - successful, res = _run_on_failure(bl_checks, fcn, *args, **kwargs) - - if not bl_checks.beam_is_okay: - successful = False - bl_checks.wait_for_beamline_checks() - except BeamlineCheckRepeat: - successful = False - return res - - finally: - bl_checks._levels.pop() - - -def _run_on_failure(bl_checks, fcn, *args, **kwargs) -> tuple: - try: - res = fcn(*args, **kwargs) - return (True, res) - except BeamlineCheckError: - bl_checks.wait_for_beamline_checks() - return (False, None) - - -class BeamlineChecks: - def __init__(self, client, *args, **kwargs): - super().__init__(*args, **kwargs) - self.client = client - self.send_to_scilog = True - self._beam_is_okay = True - self._beamline_checks = {} - self._stop_beam_check_event = threading.Event() - self._beam_check_thread = None - self._started = False - self._is_paused = False - self._check_msgs = [] - self._levels = deque() - - @typechecked - def register(self, check: BeamlineCondition): - """ - Register a beamline check. - - Args: - check (BeamlineCondition): The beamline check to register. - """ - self._beamline_checks[check.name] = check - setattr(self, check.name, check) - - def pause(self) -> None: - """ - Pause beamline checks. This will disable all checks. Use `resume` to - reactivate the checks. - """ - self._is_paused = True - - def resume(self) -> None: - """ - Resume all paused beamline checks. - """ - self._is_paused = False - - def available_checks(self) -> None: - """ - Print all available beamline checks - """ - for name, check in self._beamline_checks.items(): - enabled = f"ENABLED: {check.enabled}" - print(f"{name:<20} {enabled}") - - def disable_check(self, name: str) -> None: - """ - Disable a beamline check. - - Args: - name (str): The name of the beamline check to disable. - """ - if name not in self._beamline_checks: - raise ValueError(f"Beamline check {name} not registered.") - self._beamline_checks[name].enabled = False - - def enable_check(self, name: str) -> None: - """ - Enable a beamline check. - - Args: - name (str): The name of the beamline check to enable. - """ - if name not in self._beamline_checks: - raise ValueError(f"Beamline check {name} not registered.") - self._beamline_checks[name].enabled = True - - def disable_all_checks(self) -> None: - """ - Disable all beamline checks. - """ - for name in self._beamline_checks: - self.disable_check(name) - - def enable_all_checks(self) -> None: - """ - Enable all beamline checks. - """ - for name in self._beamline_checks: - self.enable_check(name) - - def _run_beamline_checks(self): - msgs = [] - for name, check in self._beamline_checks.items(): - if not check.enabled: - continue - if check.run(): - continue - msgs.append(check.on_failure_msg()) - self._beam_is_okay = False - return msgs - - def _check_beam(self): - while not self._stop_beam_check_event.wait(timeout=1): - self._check_msgs = self._run_beamline_checks() - - def start(self): - """Start the beamline checks.""" - if self._started: - return - self._beam_is_okay = True - - self._beam_check_thread = threading.Thread(target=self._check_beam, daemon=True) - self._beam_check_thread.start() - self._started = True - - def stop(self): - """Stop the beamline checks""" - if not self._started: - return - - self._stop_beam_check_event.set() - self._beam_check_thread.join() - - def reset(self): - self._beam_is_okay = True - self._check_msgs = [] - - @property - def beam_is_okay(self): - return self._beam_is_okay - - def _print_beamline_checks(self): - for msg in self._check_msgs: - logger.warning(msg) - - def wait_for_beamline_checks(self): - self._print_beamline_checks() - if self.send_to_scilog and not self.beam_is_okay: - self._send_to_scilog( - f"Beamline checks failed at {str(datetime.datetime.now())}: {''.join(self._check_msgs)}", - pen="red", - ) - - self._run_beamline_checks_until_okay() - - if self.send_to_scilog: - self._send_to_scilog( - f"Operation resumed at {str(datetime.datetime.now())}.", pen="green" - ) - - def _run_beamline_checks_until_okay(self): - while True: - if self._beam_status_is_okay(): - break - self._print_beamline_checks() - time.sleep(5) - - def _beam_status_is_okay(self) -> bool: - self._beam_is_okay = True - self._check_msgs = self._run_beamline_checks() - return self._beam_is_okay - - def _send_to_scilog(self, msg, pen="red"): - try: - msg = self.client.logbook.LogbookMessage() - msg.add_text(f"

{msg}

").add_tag( - ["BEC", "beam_check"] - ) - self.client.logbook.send_logbook_message(msg) - except Exception: - logger.warning("Failed to send update to SciLog.") diff --git a/bec_lib/bec_lib/bl_conditions.py b/bec_lib/bec_lib/bl_conditions.py deleted file mode 100644 index d1ae7fb98..000000000 --- a/bec_lib/bec_lib/bl_conditions.py +++ /dev/null @@ -1,89 +0,0 @@ -""" -This module contains classes for beamline checks, used to check the beamline status. -""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -if TYPE_CHECKING: # pragma: no cover - from bec_lib.device import Device - - -class BeamlineCondition(ABC): - """Abstract base class for beamline checks.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.enabled = True - - @property - @abstractmethod - def name(self) -> str: - """Return a name for the beamline check.""" - - @abstractmethod - def run(self) -> bool: - """Run the beamline check and return True if the beam is okay, False otherwise.""" - - @abstractmethod - def on_failure_msg(self) -> str: - """Return a message that will be displayed if the beamline check fails.""" - - -class ShutterCondition(BeamlineCondition): - """Check if the shutter is open.""" - - def __init__(self, shutter: Device): - super().__init__() - self.shutter = shutter - - @property - def name(self): - return "shutter" - - def run(self): - shutter_val = self.shutter.read(cached=True) - return shutter_val["value"].lower() == "open" - - def on_failure_msg(self): - return "Check beam failed: Shutter is closed." - - -class LightAvailableCondition(BeamlineCondition): - """Check if the light is available.""" - - def __init__(self, machine_status: Device): - super().__init__() - self.machine_status = machine_status - - @property - def name(self): - return "light_available" - - def run(self): - machine_status = self.machine_status.read(cached=True) - return machine_status["value"] in ["Light Available", "Light-Available"] - - def on_failure_msg(self): - return "Check beam failed: Light not available." - - -class FastOrbitFeedbackCondition(BeamlineCondition): - """Check if the fast orbit feedback is running.""" - - def __init__(self, sls_fast_orbit_feedback: Device): - super().__init__() - self.sls_fast_orbit_feedback = sls_fast_orbit_feedback - - @property - def name(self): - return "fast_orbit_feedback" - - def run(self): - fast_orbit_feedback = self.sls_fast_orbit_feedback.read(cached=True) - return fast_orbit_feedback["value"] == "running" - - def on_failure_msg(self): - return "Check beam failed: Fast orbit feedback is not running." diff --git a/bec_lib/bec_lib/bl_states.py b/bec_lib/bec_lib/bl_states.py new file mode 100644 index 000000000..84d5c0c59 --- /dev/null +++ b/bec_lib/bec_lib/bl_states.py @@ -0,0 +1,338 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +from rich.console import Console +from rich.table import Table + +from bec_lib import messages +from bec_lib.device import DeviceBase +from bec_lib.endpoints import MessageEndpoints + +if TYPE_CHECKING: + from bec_lib.client import BECClient + from bec_lib.redis_connector import MessageObject, RedisConnector + + +class BeamlineStateManager: + """Manager for beamline states.""" + + def __init__(self, client: BECClient) -> None: + self._client = client + self._connector = client.connector + self._states: list[messages.BeamlineStateConfig] = [] + self._connector.register( + MessageEndpoints.available_beamline_states(), + cb=self._on_state_update, + parent=self, + from_start=True, + ) + + @staticmethod + def _on_state_update(msg_dict: dict, *, parent: BeamlineStateManager, **_kwargs) -> None: + # type: ignore ; we know it's an AvailableBeamlineStatesMessage + msg: messages.AvailableBeamlineStatesMessage = msg_dict["data"] + parent._states = msg.states + + def add(self, state: BeamlineState) -> None: + """ + Add a new beamline state to the manager. + Args: + state (BeamlineState): The beamline state to add. + """ + + if any(state.name == existing_state.name for existing_state in self._states): + return # state already exists + info: messages.BeamlineStateConfig = messages.BeamlineStateConfig( + name=state.name, + title=state.title, + state_type=state.__class__.__name__, + parameters=state.parameters(), + ) + cls = state.__class__ + + try: + condi = cls(name=state.name, redis_connector=self._connector) + condi.configure(**state.parameters()) + except Exception as e: + raise RuntimeError(f"Failed to add state {state.name}: {e}") from e + + if isinstance(state, DeviceBeamlineState): + self._verify_signal_exists(state) + + self._states.append(info) + msg = messages.AvailableBeamlineStatesMessage(states=self._states) + self._connector.xadd( + MessageEndpoints.available_beamline_states(), {"data": msg}, max_size=1 + ) + + def _verify_signal_exists(self, state: DeviceBeamlineState) -> None: + """ + Verify that the device and signal exist in the device manager. + + Args: + state (DeviceBeamlineState): The state to verify. + + Raises: RuntimeError if the device or signal does not exist. + """ + device = state.parameters().get("device") + signal = state.parameters().get("signal") + if isinstance(device, DeviceBase): + device = device.name + + if not self._client.device_manager.devices.get(device): + raise RuntimeError( + f"Device {device} not found in device manager. Cannot add state {state.name}." + ) + if signal is not None: + if signal not in self._client.device_manager.devices[device].read(): + raise RuntimeError( + f"Signal {signal} not found in device {device}. Cannot add state {state.name}." + ) + else: + hinted_signals = self._client.device_manager.devices[device]._hints + if hinted_signals: + signal = hinted_signals[0] + else: + signal = device + state.update_parameters(device=device, signal=signal) + + def remove(self, state_name: str) -> None: + """ + Remove a beamline state by name. + Args: + state_name (str): The name of the state to remove. + """ + if not any(state.name == state_name for state in self._states): + return # state does not exist + self._states = [state for state in self._states if state.name != state_name] + msg = messages.AvailableBeamlineStatesMessage(states=self._states) + self._connector.xadd( + MessageEndpoints.available_beamline_states(), {"data": msg}, max_size=1 + ) + + def show_all(self): + """ + Pretty print all beamline states using rich. + """ + console = Console() + table = Table(title="Beamline States") + table.add_column("Name", style="cyan", no_wrap=True) + table.add_column("Type", style="magenta") + table.add_column("Parameters", style="green") + + for state in self._states: + params = state.parameters if state.parameters else "-" + table.add_row(str(state.name), str(state.state_type), str(params)) + + console.print(table) + + +class BeamlineState(ABC): + """Abstract base class for beamline states.""" + + def __init__( + self, name: str, redis_connector: RedisConnector | None = None, title: str | None = None + ) -> None: + self.name = name + self.connector = redis_connector + self.title = title if title is not None else name + self._configured = False + self._last_state: messages.BeamlineStateMessage | None = None + + def configure(self, **kwargs) -> None: + """Configure the state with given parameters.""" + self._configured = True + + def parameters(self) -> dict: + """Return the configuration parameters of the state.""" + return {} + + def update_parameters(self, **kwargs) -> None: + """Update the configuration parameters of the state.""" + pass + + @abstractmethod + def evaluate(self, *args, **kwargs) -> messages.BeamlineStateMessage | None: + """Evaluate the state and return its state.""" + + def start(self) -> None: + """Start monitoring the state if needed.""" + + def stop(self) -> None: + """Stop monitoring the state if needed.""" + + +class DeviceBeamlineState(BeamlineState): + """A beamline state that depends on a device reading.""" + + def configure(self, device: str | DeviceBase, signal: str | None = None, **kwargs) -> None: + self.device = device if isinstance(device, str) else device.name + self.signal = signal + super().configure(**kwargs) + + def parameters(self) -> dict: + params = super().parameters() + params.update({"device": self.device, "signal": self.signal}) + return params + + def update_parameters(self, **kwargs) -> None: + if "device" in kwargs: + device = kwargs.pop("device") + self.device = device if isinstance(device, str) else device.name + if "signal" in kwargs: + self.signal = kwargs.pop("signal") + super().update_parameters(**kwargs) + + def start(self) -> None: + if not self._configured: + raise RuntimeError("State must be configured before starting.") + if self.connector is None: + raise RuntimeError("Redis connector is not set.") + self.connector.register( + MessageEndpoints.device_readback(self.device), cb=self._update_device_state, parent=self + ) + + def stop(self) -> None: + if not self._configured: + return + if self.connector is None: + return + self.connector.unregister( + MessageEndpoints.device_readback(self.device), cb=self._update_device_state + ) + + @staticmethod + def _update_device_state(msg_obj: MessageObject, parent: DeviceBeamlineState) -> None: + + # Since this is called from the Redis connector, we + assert parent.connector is not None + + msg: messages.DeviceMessage = msg_obj.value # type: ignore ; we know it's a DeviceMessage + out = parent.evaluate(msg) + if out is not None and out != parent._last_state: + parent._last_state = out + parent.connector.xadd( + MessageEndpoints.beamline_state(parent.name), {"data": out}, max_size=1 + ) + + +class ShutterState(DeviceBeamlineState): + """ + A state that checks if the shutter is open. + + Example: + shutter_state = ShutterState(name="shutter_open") + shutter_state.configure(device="shutter1") + bec.beamline_states.add(shutter_state) + """ + + def evaluate(self, msg: messages.DeviceMessage, **kwargs) -> messages.BeamlineStateMessage: + val = msg.signals.get(self.signal, {}).get("value", "").lower() + if val == "open": + return messages.BeamlineStateMessage( + name=self.name, status="valid", label="Shutter is open." + ) + return messages.BeamlineStateMessage( + name=self.name, status="invalid", label="Shutter is closed." + ) + + +class DeviceWithinLimitsState(DeviceBeamlineState): + """ + A state that checks if a positioner is within limits. + + Example: + device_state = DeviceWithinLimitsState(name="sample_x_within_limits") + device_state.configure(device="sample_x", signal="sample_x_signal_name", min_limit=0.0, max_limit=10.0) + bec.beamline_states.add(device_state) + + """ + + def configure( + self, + device: str, + min_limit: float | None = None, + max_limit: float | None = None, + tolerance: float = 0.1, + signal: str | None = None, + **kwargs, + ) -> None: + """ + Configure the positioner condition. + + Args: + device (str): The name of the positioner device. + min_limit (float | None): The minimum limit for the positioner. If None, no minimum limit is enforced. + max_limit (float | None): The maximum limit for the positioner. If None, no maximum limit is enforced. + tolerance (float): The tolerance for warning conditions (default is 0.1). When the positioner is within + 10% of the limits, a warning condition will be issued. Note that the tolerance is ignored + if one of the limits is None. + signal (str, optional): The name of the signal to monitor. If not provided, defaults to the device name. + """ + self.min_limit = min_limit + self.max_limit = max_limit + self.tolerance = tolerance + super().configure(device=device, signal=signal, **kwargs) + + def parameters(self) -> dict: + params = super().parameters() + params.update( + { + "device": self.device, + "min_limit": self.min_limit, + "max_limit": self.max_limit, + "tolerance": self.tolerance, + "signal": self.signal, + } + ) + return params + + def update_parameters(self, **kwargs) -> None: + if "min_limit" in kwargs: + self.min_limit = kwargs.pop("min_limit") + if "max_limit" in kwargs: + self.max_limit = kwargs.pop("max_limit") + if "tolerance" in kwargs: + self.tolerance = kwargs.pop("tolerance") + super().update_parameters(**kwargs) + + def evaluate(self, msg: messages.DeviceMessage, **kwargs) -> messages.BeamlineStateMessage: + """ + Evaluate if the positioner is within the defined limits. If it is outside the limits, + return an invalid state. Otherwise, return a valid state. If it is within 10% of the limits, + return a warning state. + """ + + if self.min_limit is None: + self.min_limit = float("-inf") + if self.max_limit is None: + self.max_limit = float("inf") + + signal_name = self.signal if self.signal is not None else self.device + + val = msg.signals.get(signal_name, {}).get("value", None) + if val is None: + return messages.BeamlineStateMessage( + name=self.name, status="invalid", label=f"Positioner {self.device} value not found." + ) + + if val < self.min_limit or val > self.max_limit: + return messages.BeamlineStateMessage( + name=self.name, status="invalid", label=f"Positioner {self.device} out of limits" + ) + + if self.min_limit == float("-inf") or self.max_limit == float("inf"): + self.tolerance = 0 + + min_warning_threshold = self.min_limit + self.tolerance * (self.max_limit - self.min_limit) + max_warning_threshold = self.max_limit - self.tolerance * (self.max_limit - self.min_limit) + if val < min_warning_threshold or val > max_warning_threshold: + return messages.BeamlineStateMessage( + name=self.name, status="warning", label=f"Positioner {self.device} near limits" + ) + + return messages.BeamlineStateMessage( + name=self.name, status="valid", label=f"Positioner {self.device} within limits" + ) diff --git a/bec_lib/bec_lib/client.py b/bec_lib/bec_lib/client.py index 5a34ba8fe..e71d38264 100644 --- a/bec_lib/bec_lib/client.py +++ b/bec_lib/bec_lib/client.py @@ -20,7 +20,7 @@ from bec_lib.alarm_handler import AlarmHandler, Alarms from bec_lib.bec_service import BECService -from bec_lib.bl_checks import BeamlineChecks +from bec_lib.bl_states import BeamlineStateManager from bec_lib.callback_handler import CallbackHandler, EventType from bec_lib.config_helper import ConfigHelperUser from bec_lib.dap_plugins import DAPPlugins @@ -150,7 +150,6 @@ def __init__( self._live_updates = None self.dap = None self.device_monitor = None - self.bl_checks = None self.scans_namespace = SimpleNamespace() self._hli_funcs = {} self.metadata = {} @@ -161,6 +160,7 @@ def __init__( self._initialized = True self._username = "" self._system_user = "" + self.beamline_states = None def __new__(cls, *args, forced=False, **kwargs): if forced or BECClient._client is None: @@ -235,10 +235,9 @@ def _start_services(self): self.config = ConfigHelperUser(self.device_manager) self.history = ScanHistory(client=self) self.dap = DAPPlugins(self) - self.bl_checks = BeamlineChecks(self) - self.bl_checks.start() self.device_monitor = DeviceMonitorPlugin(self.connector) self._update_username() + self.beamline_states = BeamlineStateManager(client=self) def alarms(self, severity=Alarms.WARNING): """get the next alarm with at least the specified severity""" @@ -328,8 +327,6 @@ def shutdown(self, per_thread_timeout_s: float | None = None): self.queue.shutdown() if self.alarm_handler: self.alarm_handler.shutdown() - if self.bl_checks: - self.bl_checks.stop() if self.history is not None: # pylint: disable=protected-access self.history._shutdown() diff --git a/bec_lib/bec_lib/endpoints.py b/bec_lib/bec_lib/endpoints.py index 89bd36d17..f99104650 100644 --- a/bec_lib/bec_lib/endpoints.py +++ b/bec_lib/bec_lib/endpoints.py @@ -1764,6 +1764,40 @@ def macro_update(): endpoint=endpoint, message_type=messages.MacroUpdateMessage, message_op=MessageOp.SEND ) + @staticmethod + def beamline_state(state_name: str): + """ + Endpoint for beamline state. This endpoint is used to publish the beamline state + using a messages.BeamlineStateMessage message. + + Args: + state_name (str): State name. + Returns: + EndpointInfo: Endpoint for beamline state. + """ + endpoint = f"{EndpointType.INFO.value}/beamline_state/{state_name}" + return EndpointInfo( + endpoint=endpoint, + message_type=messages.BeamlineStateMessage, + message_op=MessageOp.STREAM, + ) + + @staticmethod + def available_beamline_states(): + """ + Endpoint for updating the available beamline states. This endpoint is used to + publish beamline state updates using a messages.AvailableBeamlineStatesMessage message. + + Returns: + EndpointInfo: Endpoint for beamline state updates. + """ + endpoint = f"{EndpointType.INFO.value}/available_beamline_states" + return EndpointInfo( + endpoint=endpoint, + message_type=messages.AvailableBeamlineStatesMessage, + message_op=MessageOp.STREAM, + ) + @staticmethod def atlas_websocket_state(deployment_name: str, host_id: str): """ diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index e321a464e..f3b4d7ff7 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -1953,3 +1953,48 @@ class GameLeaderboardMessage(BECMessage): msg_type: ClassVar[str] = "game_leaderboard_message" game_name: str leaderboard: list[GameScoreMessage] + + +class BeamlineStateMessage(BECMessage): + """ + Message for beamline state updates + + Args: + name (str): Name of the beamline state + status (Literal["valid", "invalid", "warning"]): Status of the beamline state + label (str): Description of the beamline state + """ + + msg_type: ClassVar[str] = "beamline_state_message" + name: str + status: Literal["valid", "invalid", "warning"] + label: str + + +class BeamlineStateConfig(BaseModel): + """ + Entry for beamline state update + + Args: + name (str): Name of the beamline state + title (str): Title of the beamline state + state_type (str): Type of the beamline state + parameters (dict, optional): Additional parameters for the state + """ + + name: str + title: str + state_type: str + parameters: dict = Field(default_factory=dict) + + +class AvailableBeamlineStatesMessage(BECMessage): + """ + Message for updating beamline states + + Args: + states (list[BeamlineStateConfig]): List of beamline state update entries + """ + + msg_type: ClassVar[str] = "beamline_state_update_message" + states: list[BeamlineStateConfig] diff --git a/bec_lib/tests/test_beamline_checks.py b/bec_lib/tests/test_beamline_checks.py deleted file mode 100644 index 4357bcc00..000000000 --- a/bec_lib/tests/test_beamline_checks.py +++ /dev/null @@ -1,167 +0,0 @@ -# pylint: skip-file -from unittest import mock - -import pytest - -from bec_lib.bl_checks import ( - BeamlineCheckError, - BeamlineChecks, - _run_on_failure, - _run_with_bl_checks, -) - - -def test_run_with_bl_checks(): - bl_checks = mock.MagicMock() - bl_checks._levels = [] - bl_checks._is_paused = False - _run_with_bl_checks(bl_checks, mock.MagicMock(), 1, 2, 3, a=4, b=5) - assert not bl_checks._levels - - -def test_bl_check_raises_on_failed_nested_calls(): - bl_checks = mock.MagicMock() - bl_checks._levels = [{"fcn": mock.MagicMock()}] - bl_checks._is_paused = False - bl_checks.beam_is_okay = False - with pytest.raises(BeamlineCheckError): - _run_with_bl_checks(bl_checks, mock.MagicMock(), 1, 2, 3, a=4, b=5) - - -def test_bl_check_run_on_failure(): - bl_checks = mock.MagicMock() - bl_checks._levels = [] - bl_checks._is_paused = False - bl_checks.beam_is_okay = False - fcn = mock.MagicMock() - fcn.side_effect = BeamlineCheckError - _run_on_failure(bl_checks, fcn, 1, 2, 3, a=4, b=5) - bl_checks.wait_for_beamline_checks.assert_called_once() - - -def test_bl_check_register(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - condition = mock.MagicMock() - condition.name = "test" - bl_check.register(condition) - assert bl_check.test == condition # pylint: disable=no-member - assert bl_check._beamline_checks["test"] == condition - - -def test_bl_check_pause(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check.pause() - assert bl_check._is_paused - - -def test_bl_check_resume(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._is_paused = True - bl_check.resume() - assert not bl_check._is_paused - - -def test_bl_check_reset(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._beam_is_okay = False - bl_check._check_msgs = ["test"] - bl_check.reset() - assert bl_check._beam_is_okay - assert not bl_check._check_msgs - - -def test_bl_check_disable_check(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._beamline_checks = {"test": mock.MagicMock()} - bl_check.disable_check("test") - assert not bl_check._beamline_checks["test"].enabled - - -def test_bl_check_enable_check(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._beamline_checks = {"test": mock.MagicMock()} - bl_check.enable_check("test") - assert bl_check._beamline_checks["test"].enabled - - -def test_bl_check_disable_all_checks(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._beamline_checks = {"test": mock.MagicMock()} - bl_check.disable_all_checks() - assert not bl_check._beamline_checks["test"].enabled - - -def test_bl_check_enable_all_checks(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._beamline_checks = {"test": mock.MagicMock()} - bl_check.enable_all_checks() - assert bl_check._beamline_checks["test"].enabled - - -def test_bl_check_run_beamline_checks(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._beamline_checks = {"test": mock.MagicMock()} - bl_check._beam_is_okay = True - bl_check._run_beamline_checks() - assert bl_check._beam_is_okay - bl_check._beamline_checks["test"].run.assert_called_once() - bl_check._beamline_checks["test"].on_failure_msg.assert_not_called() - - -def test_bl_check_send_to_scilog(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._send_to_scilog("test") - client.logbook.send_logbook_message.assert_called_once() - - -def test_bl_check_beam_status_is_okay(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - with mock.patch.object(bl_check, "_run_beamline_checks") as mock_run: - mock_run.return_value = [] - assert bl_check._beam_status_is_okay() is True - - -def test_bl_check_wait_for_beamline_checks(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._check_msgs = ["test"] - bl_check._beam_is_okay = False - with mock.patch.object(bl_check, "_print_beamline_checks") as mock_print: - with mock.patch.object(bl_check, "_send_to_scilog") as mock_send: - with mock.patch.object(bl_check, "_run_beamline_checks") as mock_run: - mock_run.return_value = ["test"] - bl_check.wait_for_beamline_checks() - mock_print.assert_called_once() - mock_send.assert_called() - - -def test_bl_check_start(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._started = False - bl_check._beam_is_okay = True - with mock.patch("bec_lib.bl_checks.threading.Thread") as mock_thread: - bl_check.start() - mock_thread.assert_called_once() - mock_thread.return_value.start.assert_called_once() - - -def test_bl_check_stop(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._started = True - bl_check._beam_is_okay = True - bl_check._beam_check_thread = mock.MagicMock() - bl_check.stop() - bl_check._beam_check_thread.join.assert_called_once() diff --git a/bec_lib/tests/test_beamline_states.py b/bec_lib/tests/test_beamline_states.py new file mode 100644 index 000000000..e680f4b67 --- /dev/null +++ b/bec_lib/tests/test_beamline_states.py @@ -0,0 +1,564 @@ +import time +from unittest import mock + +import pytest + +from bec_lib import messages +from bec_lib.bl_states import ( + BeamlineState, + BeamlineStateManager, + DeviceBeamlineState, + DeviceWithinLimitsState, + ShutterState, +) +from bec_lib.endpoints import MessageEndpoints +from bec_lib.redis_connector import MessageObject + + +@pytest.fixture +def state_manager(connected_connector): + client = mock.MagicMock() + client.connector = connected_connector + client.device_manager = mock.MagicMock() + config = BeamlineStateManager(client) + yield config + + +# ============================================================================ +# BeamlineState tests +# ============================================================================ + + +class TestBeamlineState: + """Tests for the abstract BeamlineState base class.""" + + def test_beamline_state_initialization(self): + """Test basic initialization of a BeamlineState.""" + + class ConcreteState(BeamlineState): + def evaluate(self, *args, **kwargs): + return messages.BeamlineStateMessage(name=self.name, status="invalid", label="Test") + + state = ConcreteState(name="test_state", title="Test State") + assert state.name == "test_state" + assert state.title == "Test State" + assert state.connector is None + assert state._configured is False + assert state._last_state is None + + def test_beamline_state_default_title(self): + """Test that title defaults to name if not provided.""" + + class ConcreteState(BeamlineState): + def evaluate(self, *args, **kwargs): + return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") + + state = ConcreteState(name="test_state") + assert state.title == "test_state" + + def test_beamline_state_configure(self): + """Test that configure marks the condition as configured.""" + + class ConcreteState(BeamlineState): + def evaluate(self, *args, **kwargs): + return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") + + state = ConcreteState(name="test_state") + assert state._configured is False + state.configure() + assert state._configured is True + + def test_beamline_state_parameters(self): + """Test that parameters returns an empty dict by default.""" + + class ConcreteState(BeamlineState): + def evaluate(self, *args, **kwargs): + return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") + + state = ConcreteState(name="test_state") + assert state.parameters() == {} + + def test_beamline_state_with_connector(self, connected_connector): + """Test BeamlineState initialization with a connector.""" + + class ConcreteState(BeamlineState): + def evaluate(self, *args, **kwargs): + return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") + + state = ConcreteState(name="test_state", redis_connector=connected_connector) + assert state.connector == connected_connector + + +# ============================================================================ +# DeviceBeamlineState tests +# ============================================================================ + + +class TestDeviceBeamlineState: + """Tests for DeviceBeamlineState.""" + + def test_device_state_configure(self, connected_connector): + """Test DeviceBeamlineState configuration.""" + + class ConcreteDeviceState(DeviceBeamlineState): + def evaluate(self, *args, **kwargs): + return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") + + state = ConcreteDeviceState(name="device_test", redis_connector=connected_connector) + state.configure(device="samx", signal="samx_value") + assert state.device == "samx" + assert state.signal == "samx_value" + assert state._configured is True + + def test_device_state_configure_default_signal(self, connected_connector): + """Test that signal defaults to device name if not provided.""" + + class ConcreteDeviceState(DeviceBeamlineState): + def evaluate(self, *args, **kwargs): + return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") + + state = ConcreteDeviceState(name="device_test", redis_connector=connected_connector) + state.configure(device="samx", signal="samx") + assert state.device == "samx" + assert state.signal == "samx" + + def test_device_state_parameters(self, connected_connector): + """Test that parameters includes device and signal.""" + + class ConcreteDeviceState(DeviceBeamlineState): + def evaluate(self, *args, **kwargs): + return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") + + state = ConcreteDeviceState(name="device_test", redis_connector=connected_connector) + state.configure(device="samx", signal="samx_value") + params = state.parameters() + assert params["device"] == "samx" + assert params["signal"] == "samx_value" + + def test_device_state_start_not_configured(self, connected_connector): + """Test that start raises RuntimeError if state is not configured.""" + + class ConcreteDeviceState(DeviceBeamlineState): + def evaluate(self, *args, **kwargs): + return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") + + state = ConcreteDeviceState(name="device_test", redis_connector=connected_connector) + with pytest.raises(RuntimeError, match="State must be configured before starting"): + state.start() + + def test_device_state_start_no_connector(self): + """Test that start raises RuntimeError if connector is not set.""" + + class ConcreteDeviceState(DeviceBeamlineState): + def evaluate(self, *args, **kwargs): + return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") + + state = ConcreteDeviceState(name="device_test") + state.configure(device="samx") + with pytest.raises(RuntimeError, match="Redis connector is not set"): + state.start() + + def test_device_state_start_registers_callback(self, connected_connector): + """Test that start registers the callback with the connector.""" + + class ConcreteDeviceState(DeviceBeamlineState): + def evaluate(self, *args, **kwargs): + return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") + + state = ConcreteDeviceState(name="device_test", redis_connector=connected_connector) + state.configure(device="samx") + with mock.patch.object(connected_connector, "register") as mock_register: + state.start() + mock_register.assert_called_once() + call_args = mock_register.call_args + assert call_args[0][0] == MessageEndpoints.device_readback("samx") + + def test_device_state_stop(self, connected_connector): + """Test that stop unregisters the callback.""" + + class ConcreteDeviceState(DeviceBeamlineState): + def evaluate(self, *args, **kwargs): + return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") + + state = ConcreteDeviceState(name="device_test", redis_connector=connected_connector) + state.configure(device="samx") + + with mock.patch.object(connected_connector, "unregister") as mock_unregister: + state.stop() + mock_unregister.assert_called_once() + + def test_device_state_stop_not_configured(self, connected_connector): + """Test that stop doesn't raise an error if not configured.""" + + class ConcreteDeviceState(DeviceBeamlineState): + def evaluate(self, *args, **kwargs): + return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") + + state = ConcreteDeviceState(name="device_test", redis_connector=connected_connector) + # Should not raise an error + state.stop() + + def test_device_state_stop_no_connector(self): + """Test that stop doesn't raise an error if connector is not set.""" + + class ConcreteDeviceState(DeviceBeamlineState): + def evaluate(self, *args, **kwargs): + return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") + + state = ConcreteDeviceState(name="device_test") + state.configure(device="samx") + # Should not raise an error + state.stop() + + def test_device_state_update_device_state(self, connected_connector): + """Test that _update_device_state calls evaluate and updates _last_state.""" + + msg = messages.BeamlineStateMessage(name="device_test", status="valid", label="Test") + + class ConcreteDeviceState(DeviceBeamlineState): + def evaluate(self, *args, **kwargs): + return msg + + state = ConcreteDeviceState(name="device_test", redis_connector=connected_connector) + state.configure(device="samx") + + msg_obj = MessageObject(value=msg, topic="test_topic") + state._update_device_state(msg_obj, parent=state) + assert state._last_state == msg + out = state.connector.xread(MessageEndpoints.beamline_state("device_test"), from_start=True) + assert out is not None + assert out[0]["data"] == msg + + +# ============================================================================ +# ShutterState tests +# ============================================================================ + + +class TestShutterState: + """Tests for ShutterState.""" + + def test_shutter_open(self, connected_connector): + """Test evaluation when shutter is open.""" + state = ShutterState(name="shutter_open", redis_connector=connected_connector) + state.configure(device="shutter1", signal="shutter1") + + msg = messages.DeviceMessage( + signals={"shutter1": {"value": "open", "timestamp": 1234567890.0}}, + metadata={"stream": "primary"}, + ) + + result = state.evaluate(msg) + assert result.name == "shutter_open" + assert result.status == "valid" + assert result.label == "Shutter is open." + + def test_shutter_open_uppercase(self, connected_connector): + """Test evaluation when shutter value is uppercase and gets lowercased.""" + state = ShutterState(name="shutter_open", redis_connector=connected_connector) + state.configure(device="shutter1", signal="shutter1") + + msg = messages.DeviceMessage( + signals={"shutter1": {"value": "OPEN", "timestamp": 1234567890.0}}, + metadata={"stream": "primary"}, + ) + + result = state.evaluate(msg) + assert result.status == "valid" + assert result.label == "Shutter is open." + + def test_shutter_closed(self, connected_connector): + """Test evaluation when shutter is closed.""" + state = ShutterState(name="shutter_open", redis_connector=connected_connector) + state.configure(device="shutter1") + + msg = messages.DeviceMessage( + signals={"shutter1": {"value": "closed", "timestamp": 1234567890.0}}, + metadata={"stream": "primary"}, + ) + + result = state.evaluate(msg) + assert result.name == "shutter_open" + assert result.status == "invalid" + assert result.label == "Shutter is closed." + + def test_shutter_missing_value(self, connected_connector): + """Test evaluation when value is missing.""" + state = ShutterState(name="shutter_open", redis_connector=connected_connector) + state.configure(device="shutter1") + + msg = messages.DeviceMessage( + signals={"shutter1": {"timestamp": 1234567890.0}}, metadata={"stream": "primary"} + ) + + result = state.evaluate(msg) + assert result.status == "invalid" + assert result.label == "Shutter is closed." + + +# ============================================================================ +# DeviceWithinLimitsState tests +# ============================================================================ + + +class TestDeviceWithinLimitsState: + """Tests for DeviceWithinLimitsState.""" + + def test_within_limits_configure(self, connected_connector): + """Test configuration of DeviceWithinLimitsState.""" + state = DeviceWithinLimitsState(name="sample_x_limits", redis_connector=connected_connector) + state.configure(device="sample_x", min_limit=0.0, max_limit=10.0) + + assert state.device == "sample_x" + assert state.min_limit == 0.0 + assert state.max_limit == 10.0 + assert state.tolerance == 0.1 + + def test_within_limits_configure_custom_tolerance(self, connected_connector): + """Test configuration with custom tolerance.""" + state = DeviceWithinLimitsState(name="sample_x_limits", redis_connector=connected_connector) + state.configure(device="sample_x", min_limit=0.0, max_limit=10.0, tolerance=0.2) + + assert state.tolerance == 0.2 + + def test_within_limits_value_inside(self, connected_connector): + """Test evaluation when value is within limits.""" + state = DeviceWithinLimitsState(name="sample_x_limits", redis_connector=connected_connector) + state.configure(device="sample_x", min_limit=0.0, max_limit=10.0) + + msg = messages.DeviceMessage( + signals={"sample_x": {"value": 5.0, "timestamp": 1234567890.0}}, + metadata={"stream": "primary"}, + ) + + result = state.evaluate(msg) + assert result.status == "valid" + assert result.label == "Positioner sample_x within limits" + + def test_within_limits_value_outside_low(self, connected_connector): + """Test evaluation when value is below minimum limit.""" + state = DeviceWithinLimitsState(name="sample_x_limits", redis_connector=connected_connector) + state.configure(device="sample_x", min_limit=0.0, max_limit=10.0) + + msg = messages.DeviceMessage( + signals={"sample_x": {"value": -1.0, "timestamp": 1234567890.0}}, + metadata={"stream": "primary"}, + ) + + result = state.evaluate(msg) + assert result.status == "invalid" + assert result.label == "Positioner sample_x out of limits" + + def test_within_limits_value_outside_high(self, connected_connector): + """Test evaluation when value is above maximum limit.""" + state = DeviceWithinLimitsState(name="sample_x_limits", redis_connector=connected_connector) + state.configure(device="sample_x", min_limit=0.0, max_limit=10.0) + + msg = messages.DeviceMessage( + signals={"sample_x": {"value": 11.0, "timestamp": 1234567890.0}}, + metadata={"stream": "primary"}, + ) + + result = state.evaluate(msg) + assert result.status == "invalid" + assert result.label == "Positioner sample_x out of limits" + + def test_within_limits_value_near_min(self, connected_connector): + """Test evaluation when value is near minimum limit (within tolerance).""" + state = DeviceWithinLimitsState(name="sample_x_limits", redis_connector=connected_connector) + state.configure(device="sample_x", min_limit=0.0, max_limit=10.0, tolerance=0.1) + + # 10% of (10 - 0) = 1.0, so near min is < 1.0 + msg = messages.DeviceMessage( + signals={"sample_x": {"value": 0.5, "timestamp": 1234567890.0}}, + metadata={"stream": "primary"}, + ) + + result = state.evaluate(msg) + assert result.status == "warning" + assert result.label == "Positioner sample_x near limits" + + def test_within_limits_value_near_max(self, connected_connector): + """Test evaluation when value is near maximum limit (within tolerance).""" + state = DeviceWithinLimitsState(name="sample_x_limits", redis_connector=connected_connector) + state.configure(device="sample_x", min_limit=0.0, max_limit=10.0, tolerance=0.1) + + # 10% of (10 - 0) = 1.0, so near max is > 9.0 + msg = messages.DeviceMessage( + signals={"sample_x": {"value": 9.5, "timestamp": 1234567890.0}}, + metadata={"stream": "primary"}, + ) + + result = state.evaluate(msg) + assert result.status == "warning" + assert result.label == "Positioner sample_x near limits" + + def test_within_limits_missing_value(self, connected_connector): + """Test evaluation when value is missing.""" + state = DeviceWithinLimitsState(name="sample_x_limits", redis_connector=connected_connector) + state.configure(device="sample_x", min_limit=0.0, max_limit=10.0) + + msg = messages.DeviceMessage( + signals={"sample_x": {"timestamp": 1234567890.0}}, metadata={"stream": "primary"} + ) + + result = state.evaluate(msg) + assert result.status == "invalid" + assert "value not found" in result.label + + def test_within_limits_parameters(self, connected_connector): + """Test that parameters includes all configuration.""" + state = DeviceWithinLimitsState(name="sample_x_limits", redis_connector=connected_connector) + state.configure(device="sample_x", min_limit=0.0, max_limit=10.0, signal="x_readback") + + params = state.parameters() + assert params["device"] == "sample_x" + assert params["min_limit"] == 0.0 + assert params["max_limit"] == 10.0 + assert params["tolerance"] == 0.1 + assert params["signal"] == "x_readback" + + +# ============================================================================ +# BeamlineStateConfig tests +# ============================================================================ + + +class TestBeamlineStateConfig: + """Tests for BeamlineStateConfig manager.""" + + @pytest.mark.timeout(5) + def test_add_state(self, state_manager): + """Test adding a state.""" + state = ShutterState(name="shutter_open", title="Shutter Open") + state.configure(device="shutter1") + + # Setup device manager mock - the signal should match the device name when no signal is provided + state_manager._client.device_manager.devices = {"shutter1": mock.MagicMock()} + state_manager._client.device_manager.devices["shutter1"].read.return_value = { + "shutter1": {"value": "open"} + } + + state_manager.add(state) + while True: + if any(c.name == "shutter_open" for c in state_manager._states): + break + time.sleep(0.1) + # Check that the state was added + assert any(c.name == "shutter_open" for c in state_manager._states) + + @pytest.mark.timeout(5) + def test_add_state_already_exists(self, state_manager): + """Test that adding a duplicate state is ignored.""" + state = ShutterState(name="shutter_open", title="Shutter Open") + state.configure(device="shutter1") + + # Setup device manager mock + state_manager._client.device_manager.devices = {"shutter1": mock.MagicMock()} + state_manager._client.device_manager.devices["shutter1"].read.return_value = { + "shutter1": {"value": "open"} + } + + # Add the state once + state_manager.add(state) + while True: + if any(c.name == "shutter_open" for c in state_manager._states): + break + time.sleep(0.1) + initial_count = len(state_manager._states) + + # Add the same state again + state_manager.add(state) + time.sleep(0.5) + # Count should not increase + assert len(state_manager._states) == initial_count + + def test_add_state_device_not_found(self, state_manager): + """Test that adding a state with invalid device raises RuntimeError.""" + state = ShutterState(name="shutter_open", title="Shutter Open") + state.configure(device="nonexistent_shutter") + + state_manager._client.device_manager.devices = {} + + with pytest.raises(RuntimeError, match="Device nonexistent_shutter not found"): + state_manager.add(state) + + def test_add_state_signal_not_found(self, state_manager): + """Test that adding a state with invalid signal raises RuntimeError.""" + state = ShutterState(name="shutter_open", title="Shutter Open") + # Setup device manager mock with device but without the signal + mock_device = mock.MagicMock() + mock_device.read.return_value = {"other_signal": {"value": "open"}} + state_manager._client.device_manager.devices = {"shutter1": mock_device} + + state.configure(device="shutter1", signal="value") + + with pytest.raises(RuntimeError, match="Signal value not found in device shutter1"): + state_manager.add(state) + + @pytest.mark.timeout(5) + def test_remove_state(self, state_manager): + """Test removing a state.""" + state = ShutterState(name="shutter_open", title="Shutter Open") + state.configure(device="shutter1") + + # Setup device manager mock + state_manager._client.device_manager.devices = {"shutter1": mock.MagicMock()} + state_manager._client.device_manager.devices["shutter1"].read.return_value = { + "shutter1": {"value": "open"} + } + + # Add and then remove + state_manager.add(state) + while True: + if any(c.name == "shutter_open" for c in state_manager._states): + break + time.sleep(0.1) + + state_manager.remove("shutter_open") + while True: + if not any(c.name == "shutter_open" for c in state_manager._states): + break + time.sleep(0.1) + + def test_remove_nonexistent_state(self, state_manager): + """Test removing a state that doesn't exist.""" + # Should not raise an error + state_manager.remove("nonexistent") + assert len(state_manager._states) == 0 + + @pytest.mark.timeout(5) + def test_show_all(self, state_manager, capsys): + """Test that show_all displays states in a table.""" + state = ShutterState(name="shutter_open", title="Shutter Open") + state.configure(device="shutter1") + + # Setup device manager mock + state_manager._client.device_manager.devices = {"shutter1": mock.MagicMock()} + state_manager._client.device_manager.devices["shutter1"].read.return_value = { + "shutter1": {"value": "open"} + } + + state_manager.add(state) + while True: + if any(c.name == "shutter_open" for c in state_manager._states): + break + time.sleep(0.1) + state_manager.show_all() + + # The output should be printed (checked via capsys) + captured = capsys.readouterr() + # Check that the state name appears in the output + assert "shutter_open" in captured.out or "shutter_open" in captured.err + + def test_on_state_update(self, state_manager): + """Test that _on_state_update updates the states list.""" + update_entry = messages.BeamlineStateConfig( + name="test_state", title="Test State", state_type="ShutterState", parameters={} + ) + msg = messages.AvailableBeamlineStatesMessage(states=[update_entry]) + + state_manager._on_state_update({"data": msg}, parent=state_manager) + + assert len(state_manager._states) == 1 + assert state_manager._states[0].name == "test_state" diff --git a/bec_lib/tests/test_bl_conditions.py b/bec_lib/tests/test_bl_conditions.py deleted file mode 100644 index 85e87df0b..000000000 --- a/bec_lib/tests/test_bl_conditions.py +++ /dev/null @@ -1,37 +0,0 @@ -from unittest import mock - -from bec_lib.bl_conditions import ( - FastOrbitFeedbackCondition, - LightAvailableCondition, - ShutterCondition, -) - - -def test_shutter_condition(): - device = mock.MagicMock() - shutter_condition = ShutterCondition(device) - shutter_condition.run() - device.read.assert_called_once() - assert shutter_condition.on_failure_msg() == "Check beam failed: Shutter is closed." - assert shutter_condition.name == "shutter" - - -def test_light_available_condition(): - device = mock.MagicMock() - light_available_condition = LightAvailableCondition(device) - light_available_condition.run() - device.read.assert_called_once() - assert light_available_condition.on_failure_msg() == "Check beam failed: Light not available." - assert light_available_condition.name == "light_available" - - -def test_fast_orbit_feedback_condition(): - device = mock.MagicMock() - fast_orbit_feedback_condition = FastOrbitFeedbackCondition(device) - fast_orbit_feedback_condition.run() - device.read.assert_called_once() - assert ( - fast_orbit_feedback_condition.on_failure_msg() - == "Check beam failed: Fast orbit feedback is not running." - ) - assert fast_orbit_feedback_condition.name == "fast_orbit_feedback" diff --git a/bec_server/bec_server/scan_server/beamline_state_manager.py b/bec_server/bec_server/scan_server/beamline_state_manager.py new file mode 100644 index 000000000..8c6a861f6 --- /dev/null +++ b/bec_server/bec_server/scan_server/beamline_state_manager.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from bec_lib import bl_states, messages +from bec_lib.alarm_handler import Alarms +from bec_lib.endpoints import MessageEndpoints +from bec_lib.redis_connector import RedisConnector + + +class BeamlineStateManager: + """Manager for beamline states.""" + + def __init__(self, connector: RedisConnector) -> None: + self.connector = connector + self.states: list[bl_states.BeamlineState] = [] + self.connector.register( + MessageEndpoints.available_beamline_states(), + cb=self._handle_state_update, + parent=self, + from_start=True, + ) + + @staticmethod + def _handle_state_update(msg_dict: dict, *, parent: BeamlineStateManager, **_kwargs) -> None: + + msg: messages.AvailableBeamlineStatesMessage = msg_dict["data"] # type: ignore ; we know it's a AvailableBeamlineStatesMessage + parent.update_states(msg) + + def update_states(self, msg: messages.AvailableBeamlineStatesMessage) -> None: + """ + Update the beamline states based on the received update message. + + Args: + msg (messages.AvailableBeamlineStatesMessage): The update message containing state updates. + """ + + # get the states that we need to remove + states_in_msg = {state.name for state in msg.states} + current_states = {state.name for state in self.states} + states_to_remove = current_states - states_in_msg + # remove states that are no longer needed + for state_name in states_to_remove: + state = next((s for s in self.states if s.name == state_name), None) + if state: + state.stop() + self.states.remove(state) + # filter out existing states from the message + new_states = [state for state in msg.states if state.name not in current_states] + # add new states + for state in new_states: + self.states.append(self.create_state_from_message(state)) + + def create_state_from_message( + self, state_info: messages.BeamlineStateConfig + ) -> bl_states.BeamlineState: + """ + Create a BeamlineState instance from a BeamlineStateConfig message. + + Args: + state_info (messages.BeamlineStateConfig): The state config message. + Returns: + BeamlineState: The created BeamlineState instance. + """ + try: + cls = getattr(bl_states, state_info.state_type, None) + if cls is None or not issubclass(cls, bl_states.BeamlineState): + raise ValueError( + f"State type {state_info.state_type} not found in beamline states." + ) + state = cls( + name=state_info.name, redis_connector=self.connector, title=state_info.title + ) + state.configure(**state_info.parameters) + state.start() + except Exception as exc: + self.connector.raise_alarm( + severity=Alarms.WARNING, + info=messages.ErrorInfo( + error_message=f"Failed to create beamline state {state_info.name}: {exc}", + compact_error_message=f"Failed to create beamline state {state_info.name}", + exception_type=type(exc).__name__, + ), + ) + return state diff --git a/bec_server/bec_server/scan_server/scan_server.py b/bec_server/bec_server/scan_server/scan_server.py index 3af7de6dc..a63fac1c5 100644 --- a/bec_server/bec_server/scan_server/scan_server.py +++ b/bec_server/bec_server/scan_server/scan_server.py @@ -15,6 +15,7 @@ from bec_server.procedures.manager import ProcedureManager from bec_server.procedures.subprocess_worker import SubProcessWorker +from .beamline_state_manager import BeamlineStateManager from .scan_assembler import ScanAssembler from .scan_guard import ScanGuard from .scan_manager import ScanManager @@ -40,6 +41,8 @@ def __init__(self, config: ServiceConfig, connector_cls: type[RedisConnector]): use_subprocess_proc_worker=config.model.procedures.use_subprocess_worker ) self.status = messages.BECStatus.RUNNING + self.beamline_states = None + self._start_beamline_state_manager() def _start_device_manager(self): self.wait_for_service("DeviceServer") @@ -59,6 +62,9 @@ def _start_scan_assembler(self): def _start_scan_guard(self): self.scan_guard = ScanGuard(parent=self) + def _start_beamline_state_manager(self): + self.beamline_states = BeamlineStateManager(self.connector) + def _start_alarm_handler(self): self.connector.register(MessageEndpoints.alarm(), cb=self._alarm_callback, parent=self) diff --git a/bec_server/tests/tests_scan_server/test_beamline_state_manager.py b/bec_server/tests/tests_scan_server/test_beamline_state_manager.py new file mode 100644 index 000000000..9bb943620 --- /dev/null +++ b/bec_server/tests/tests_scan_server/test_beamline_state_manager.py @@ -0,0 +1,103 @@ +import time +from unittest import mock + +import pytest + +from bec_lib import messages +from bec_lib.endpoints import MessageEndpoints +from bec_server.scan_server.beamline_state_manager import BeamlineStateManager + + +@pytest.fixture +def state_manager(connected_connector): + manager = BeamlineStateManager(connected_connector) + yield manager + + +def test_state_manager_fetches_states(): + """ + Test that the BeamlineStateManager fetches all available beamline states on initialization. + """ + + connector = mock.MagicMock() + state_manager = BeamlineStateManager(connector) + connector.register.assert_called_once_with( + MessageEndpoints.available_beamline_states(), + cb=state_manager._handle_state_update, + parent=state_manager, + from_start=True, + ) + + +@pytest.mark.timeout(5) +def test_state_manager_updates_states(state_manager, connected_connector): + """ + Test that the BeamlineStateManager updates its states correctly when receiving an update message. + """ + + # Initial state: no states + assert len(state_manager.states) == 0 + + msg = messages.AvailableBeamlineStatesMessage( + states=[ + messages.BeamlineStateConfig( + name="State1", + title="Shutter", + state_type="ShutterState", + parameters={"device": "shutter1"}, + ) + ] + ) + + connected_connector.xadd( + MessageEndpoints.available_beamline_states(), {"data": msg}, max_size=1 + ) + + # Give it some time to process + while len(state_manager.states) < 1: + time.sleep(0.1) + + msg = messages.AvailableBeamlineStatesMessage( + states=[ + messages.BeamlineStateConfig( + name="State1", + title="Shutter", + state_type="ShutterState", + parameters={"device": "shutter1"}, + ), + messages.BeamlineStateConfig( + name="State2", + title="Shutter2", + state_type="ShutterState", + parameters={"device": "shutter2"}, + ), + ] + ) + + connected_connector.xadd( + MessageEndpoints.available_beamline_states(), {"data": msg}, max_size=1 + ) + + # Give it some time to process + while len(state_manager.states) < 2: + time.sleep(0.1) + + msg = messages.AvailableBeamlineStatesMessage( + states=[ + messages.BeamlineStateConfig( + name="State2", + title="Shutter2", + state_type="ShutterState", + parameters={"device": "shutter2"}, + ) + ] + ) + connected_connector.xadd( + MessageEndpoints.available_beamline_states(), {"data": msg}, max_size=1 + ) + # Give it some time to process + while len(state_manager.states) > 1: + time.sleep(0.1) + + assert len(state_manager.states) == 1 + assert state_manager.states[0].name == "State2" From 8e911d4a9e57cb4ae6deec64391154d1351796f1 Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Sat, 21 Feb 2026 15:41:23 +0100 Subject: [PATCH 2/6] f - wip: bl state rewrite --- bec_lib/bec_lib/bl_state_manager.py | 220 ++++++ bec_lib/bec_lib/bl_states.py | 313 +++----- bec_lib/bec_lib/client.py | 2 +- bec_lib/tests/test_beamline_states.py | 707 ++++++------------ .../scan_server/beamline_state_manager.py | 89 ++- .../test_beamline_state_manager.py | 48 +- 6 files changed, 636 insertions(+), 743 deletions(-) create mode 100644 bec_lib/bec_lib/bl_state_manager.py diff --git a/bec_lib/bec_lib/bl_state_manager.py b/bec_lib/bec_lib/bl_state_manager.py new file mode 100644 index 000000000..3213f6817 --- /dev/null +++ b/bec_lib/bec_lib/bl_state_manager.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +import inspect +from inspect import Parameter, Signature +from typing import TYPE_CHECKING, TypedDict + +from pydantic import BaseModel +from rich.console import Console +from rich.table import Table + +from bec_lib import bl_states, messages +from bec_lib.endpoints import MessageEndpoints + +if TYPE_CHECKING: + from bec_lib.bl_states import BeamlineStateConfig + from bec_lib.client import BECClient + + +def build_signature_from_model(model: BaseModel) -> Signature: + """ + Build a function signature from a Pydantic model. The parameters of the signature will match the fields of the model. + """ + parameters = [] + + for name, field in model.model_fields.items(): + annotation = field.annotation or inspect.Parameter.empty + parameters.append( + Parameter( + name=name, kind=Parameter.KEYWORD_ONLY, default=field.default, annotation=annotation + ) + ) + + return Signature(parameters) + + +class BeamlineStateGet(TypedDict): + """ + TypedDict for the return value of the get method of a beamline state client. + """ + + status: str + label: str + + +class BeamlineStateClientBase: + """Base class for beamline state clients.""" + + def __init__(self, manager: BeamlineStateManager, state: BeamlineStateConfig) -> None: + self._manager = manager + self._connector = manager._connector + self._state = state + + # pylint: disable=unnecessary-lambda + self._run = lambda **kwargs: self._run_update(**kwargs) + self._update_signature() + + def _update_signature(self) -> None: + # Dynamically update the signature of the update_parameters method to match the parameters of the state config + setattr(self, "update_parameters", self._run) + setattr( + getattr(self, "update_parameters"), + "__signature__", + build_signature_from_model(self._state), + ) + + def _run_update(self, **kwargs) -> None: + self._state = self._state.model_copy(update=kwargs) + self._manager._update_state(self._state) # pylint: disable=protected-access + + def get(self) -> BeamlineStateGet: + """ + Get the current status of the beamline state. Returns a dictionary with keys "status" and "label". + + Returns: + BeamlineStateGet: A dictionary containing the status and label of the beamline state. + """ + msg_container: dict[str, messages.BeamlineStateMessage] = self._connector.get_last( + MessageEndpoints.beamline_state(self._state.name) + ) + if not msg_container: + return {"status": "unknown", "label": "No state information available."} + msg = msg_container["data"] + return {"status": msg.status, "label": msg.label} + + def delete(self) -> None: + """ + Delete the current beamline state. + """ + self._manager.remove(self._state.name) + + +class BeamlineStateManager: + """Manager for beamline states.""" + + def __init__(self, client: BECClient) -> None: + self._client = client + self._connector = client.connector + self._states: dict[str, BeamlineStateConfig] = {} + self._connector.register( + MessageEndpoints.available_beamline_states(), + cb=self._on_state_update, + parent=self, + from_start=True, + ) + + @staticmethod + def _on_state_update(msg_dict: dict, *, parent: BeamlineStateManager, **_kwargs) -> None: + # type: ignore ; we know it's an AvailableBeamlineStatesMessage + msg: messages.AvailableBeamlineStatesMessage = msg_dict["data"] + parent._update_states(msg.states) # pylint: disable=protected-access + + def _update_state(self, state: BeamlineStateConfig) -> None: + if state.name in self._states: + self._states[state.name] = state + self._publish_states() + return + raise ValueError(f"State with name {state.name} not found") + + def _update_states(self, states: list[messages.BeamlineStateConfig]) -> None: + remove_state_names = set(self._states) - set(state.name for state in states) + + added_state_names = set(state.name for state in states) - set(self._states) + added_states = {state.name: state for state in states if state.name in added_state_names} + + for state_name in remove_state_names: + if hasattr(self, state_name): + delattr(self, state_name) + self._states.pop(state_name, None) + + for state_name, state in added_states.items(): + state_class = getattr(bl_states, state.state_type) + model_cls = state_class.CONFIG_CLASS + model_instance = model_cls(**state.parameters) + instance = BeamlineStateClientBase(manager=self, state=model_instance) + setattr(self, state.name, instance) + self._states[state.name] = model_instance + + def _publish_states(self) -> None: + bl_states_container = [ + messages.BeamlineStateConfig( + name=state.name, + title=state.title if state.title else state.name, + state_type=state.state_type, + parameters=state.model_dump(), + ) + for state in self._states.values() + ] + msg = messages.AvailableBeamlineStatesMessage(states=bl_states_container) + self._connector.xadd( + MessageEndpoints.available_beamline_states(), + {"data": msg}, + max_size=1, + approximate=False, + ) + + ########################## + ##### Public API ######### + ########################## + + def add(self, state: bl_states.BeamlineStateConfig) -> None: + """ + Add a new beamline state to the manager. + Args: + state (BeamlineStateConfig): The beamline state to add. + """ + + self._states[state.name] = state + self._publish_states() + + def remove(self, state_name: str) -> None: + """ + Remove a beamline state by name. + Args: + state_name (str): The name of the state to remove. + """ + if state_name in self._states: + del self._states[state_name] + self._publish_states() + + def show_all(self): + """ + Pretty print all beamline states using rich. + """ + + def _format_parameters(state_config: bl_states.BeamlineStateConfig) -> str: + parameter_dict = state_config.model_dump(exclude={"name", "title"}, exclude_none=True) + if not parameter_dict: + return "-" + return "\n".join(f"{key}={value}" for key, value in parameter_dict.items()) + + def _status_style(status_value: str) -> str: + status_styles = {"valid": "green3", "invalid": "red3", "warning": "yellow3"} + return status_styles.get(status_value.lower(), "grey50") + + console = Console() + table = Table(title="Beamline States", padding=(0, 1, 1, 1)) + table.add_column("Name", style="magenta", no_wrap=True) + table.add_column("Type", style="grey70") + table.add_column("Parameters", style="grey70") + table.add_column("Status") + table.add_column("Label") + + for state in self._states.values(): + params = _format_parameters(state) + status = ( + getattr(self, state.name).get() + if hasattr(self, state.name) + else {"status": "unknown", "label": "No state information available."} + ) + status_value = str(status.get("status", "")) + status_style = _status_style(status_value) + table.add_row( + str(state.name), + str(state.state_type), + str(params), + f"[{status_style}]{status_value}[/{status_style}]", + f"[{status_style}]{str(status.get('label', ''))}[/{status_style}]", + ) + + console.print(table) diff --git a/bec_lib/bec_lib/bl_states.py b/bec_lib/bec_lib/bl_states.py index 84d5c0c59..63972b32c 100644 --- a/bec_lib/bec_lib/bl_states.py +++ b/bec_lib/bec_lib/bl_states.py @@ -1,157 +1,99 @@ from __future__ import annotations +import keyword from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar, Generic, Type, TypeVar -from rich.console import Console -from rich.table import Table +from pydantic import BaseModel, field_validator from bec_lib import messages from bec_lib.device import DeviceBase from bec_lib.endpoints import MessageEndpoints +from bec_lib.redis_connector import RedisConnector if TYPE_CHECKING: - from bec_lib.client import BECClient from bec_lib.redis_connector import MessageObject, RedisConnector -class BeamlineStateManager: - """Manager for beamline states.""" +class BeamlineStateConfig(BaseModel): + """ + Base Configuration for a beamline state. + """ - def __init__(self, client: BECClient) -> None: - self._client = client - self._connector = client.connector - self._states: list[messages.BeamlineStateConfig] = [] - self._connector.register( - MessageEndpoints.available_beamline_states(), - cb=self._on_state_update, - parent=self, - from_start=True, - ) + state_type: ClassVar[str] = "BeamlineState" - @staticmethod - def _on_state_update(msg_dict: dict, *, parent: BeamlineStateManager, **_kwargs) -> None: - # type: ignore ; we know it's an AvailableBeamlineStatesMessage - msg: messages.AvailableBeamlineStatesMessage = msg_dict["data"] - parent._states = msg.states + name: str + title: str | None = None - def add(self, state: BeamlineState) -> None: + model_config = {"extra": "forbid", "arbitrary_types_allowed": True} + + @field_validator("name") + @classmethod + def validate_name(cls, v: str) -> str: """ - Add a new beamline state to the manager. - Args: - state (BeamlineState): The beamline state to add. + Validate that the state name is a valid Python identifier and does not conflict with reserved method names. """ + if not v.isidentifier(): + raise ValueError(f"State name '{v}' must be a valid Python identifier.") + if keyword.iskeyword(v): + raise ValueError(f"State name '{v}' cannot be a reserved Python keyword.") + if v in {"add", "remove", "show_all"}: + raise ValueError(f"State name '{v}' is reserved and cannot be used.") + return v - if any(state.name == existing_state.name for existing_state in self._states): - return # state already exists - info: messages.BeamlineStateConfig = messages.BeamlineStateConfig( - name=state.name, - title=state.title, - state_type=state.__class__.__name__, - parameters=state.parameters(), - ) - cls = state.__class__ - try: - condi = cls(name=state.name, redis_connector=self._connector) - condi.configure(**state.parameters()) - except Exception as e: - raise RuntimeError(f"Failed to add state {state.name}: {e}") from e +class DeviceStateConfig(BeamlineStateConfig): + """ + Configuration for a device-based beamline state. + """ - if isinstance(state, DeviceBeamlineState): - self._verify_signal_exists(state) + state_type: ClassVar[str] = "DeviceState" - self._states.append(info) - msg = messages.AvailableBeamlineStatesMessage(states=self._states) - self._connector.xadd( - MessageEndpoints.available_beamline_states(), {"data": msg}, max_size=1 - ) + device: DeviceBase | str + signal: DeviceBase | str | None = None - def _verify_signal_exists(self, state: DeviceBeamlineState) -> None: + @field_validator("device", "signal", mode="before") + @classmethod + def validate_device(cls, v: DeviceBase | str) -> str: """ - Verify that the device and signal exist in the device manager. + Validate that the device is either a string or a DeviceBase instance. If it's a DeviceBase instance, return its name. + """ + if isinstance(v, DeviceBase): + return v.dotted_name + return v - Args: - state (DeviceBeamlineState): The state to verify. - Raises: RuntimeError if the device or signal does not exist. - """ - device = state.parameters().get("device") - signal = state.parameters().get("signal") - if isinstance(device, DeviceBase): - device = device.name - - if not self._client.device_manager.devices.get(device): - raise RuntimeError( - f"Device {device} not found in device manager. Cannot add state {state.name}." - ) - if signal is not None: - if signal not in self._client.device_manager.devices[device].read(): - raise RuntimeError( - f"Signal {signal} not found in device {device}. Cannot add state {state.name}." - ) - else: - hinted_signals = self._client.device_manager.devices[device]._hints - if hinted_signals: - signal = hinted_signals[0] - else: - signal = device - state.update_parameters(device=device, signal=signal) - - def remove(self, state_name: str) -> None: - """ - Remove a beamline state by name. - Args: - state_name (str): The name of the state to remove. - """ - if not any(state.name == state_name for state in self._states): - return # state does not exist - self._states = [state for state in self._states if state.name != state_name] - msg = messages.AvailableBeamlineStatesMessage(states=self._states) - self._connector.xadd( - MessageEndpoints.available_beamline_states(), {"data": msg}, max_size=1 - ) +class DeviceWithinLimitsStateConfig(DeviceStateConfig): + """ + Configuration for a device within limits beamline state. + """ - def show_all(self): - """ - Pretty print all beamline states using rich. - """ - console = Console() - table = Table(title="Beamline States") - table.add_column("Name", style="cyan", no_wrap=True) - table.add_column("Type", style="magenta") - table.add_column("Parameters", style="green") + state_type: ClassVar[str] = "DeviceWithinLimitsState" + + min_limit: float | None = None + max_limit: float | None = None + tolerance: float = 0.1 - for state in self._states: - params = state.parameters if state.parameters else "-" - table.add_row(str(state.name), str(state.state_type), str(params)) - console.print(table) +C = TypeVar("C", bound=BeamlineStateConfig) +D = TypeVar("D", bound=DeviceStateConfig) -class BeamlineState(ABC): +class BeamlineState(ABC, Generic[C]): """Abstract base class for beamline states.""" + CONFIG_CLASS: Type[C] + def __init__( - self, name: str, redis_connector: RedisConnector | None = None, title: str | None = None + self, config: C | None = None, redis_connector: RedisConnector | None = None, **kwargs ) -> None: - self.name = name + self.config = config or self.CONFIG_CLASS(**kwargs) self.connector = redis_connector - self.title = title if title is not None else name - self._configured = False self._last_state: messages.BeamlineStateMessage | None = None - def configure(self, **kwargs) -> None: - """Configure the state with given parameters.""" - self._configured = True - - def parameters(self) -> dict: - """Return the configuration parameters of the state.""" - return {} - def update_parameters(self, **kwargs) -> None: """Update the configuration parameters of the state.""" - pass + self.config = self.CONFIG_CLASS(**{**self.config.model_dump(), **kwargs}) @abstractmethod def evaluate(self, *args, **kwargs) -> messages.BeamlineStateMessage | None: @@ -163,44 +105,37 @@ def start(self) -> None: def stop(self) -> None: """Stop monitoring the state if needed.""" + def restart(self) -> None: + """Restart the state monitoring.""" + self.stop() + self.start() -class DeviceBeamlineState(BeamlineState): - """A beamline state that depends on a device reading.""" - def configure(self, device: str | DeviceBase, signal: str | None = None, **kwargs) -> None: - self.device = device if isinstance(device, str) else device.name - self.signal = signal - super().configure(**kwargs) +class DeviceBeamlineState(BeamlineState[D], Generic[D]): + """A beamline state that depends on a device reading.""" - def parameters(self) -> dict: - params = super().parameters() - params.update({"device": self.device, "signal": self.signal}) - return params + CONFIG_CLASS: Type[D] - def update_parameters(self, **kwargs) -> None: - if "device" in kwargs: - device = kwargs.pop("device") - self.device = device if isinstance(device, str) else device.name - if "signal" in kwargs: - self.signal = kwargs.pop("signal") - super().update_parameters(**kwargs) + def __init__( + self, config: D | None = None, redis_connector: RedisConnector | None = None, **kwargs + ) -> None: + super().__init__(config, redis_connector, **kwargs) + self._last_value = None def start(self) -> None: - if not self._configured: - raise RuntimeError("State must be configured before starting.") if self.connector is None: raise RuntimeError("Redis connector is not set.") self.connector.register( - MessageEndpoints.device_readback(self.device), cb=self._update_device_state, parent=self + MessageEndpoints.device_readback(self.config.device), + cb=self._update_device_state, + parent=self, ) def stop(self) -> None: - if not self._configured: - return if self.connector is None: return self.connector.unregister( - MessageEndpoints.device_readback(self.device), cb=self._update_device_state + MessageEndpoints.device_readback(self.config.device), cb=self._update_device_state ) @staticmethod @@ -214,11 +149,11 @@ def _update_device_state(msg_obj: MessageObject, parent: DeviceBeamlineState) -> if out is not None and out != parent._last_state: parent._last_state = out parent.connector.xadd( - MessageEndpoints.beamline_state(parent.name), {"data": out}, max_size=1 + MessageEndpoints.beamline_state(parent.config.name), {"data": out}, max_size=1 ) -class ShutterState(DeviceBeamlineState): +class ShutterState(DeviceBeamlineState[DeviceStateConfig]): """ A state that checks if the shutter is open. @@ -228,18 +163,22 @@ class ShutterState(DeviceBeamlineState): bec.beamline_states.add(shutter_state) """ - def evaluate(self, msg: messages.DeviceMessage, **kwargs) -> messages.BeamlineStateMessage: - val = msg.signals.get(self.signal, {}).get("value", "").lower() + CONFIG_CLASS = DeviceStateConfig + + def evaluate( + self, msg: messages.DeviceMessage, *args, **kwargs + ) -> messages.BeamlineStateMessage: + val = msg.signals.get(self.config.signal, {}).get("value", "").lower() if val == "open": return messages.BeamlineStateMessage( - name=self.name, status="valid", label="Shutter is open." + name=self.config.name, status="valid", label="Shutter is open." ) return messages.BeamlineStateMessage( - name=self.name, status="invalid", label="Shutter is closed." + name=self.config.name, status="invalid", label="Shutter is closed." ) -class DeviceWithinLimitsState(DeviceBeamlineState): +class DeviceWithinLimitsState(DeviceBeamlineState[DeviceWithinLimitsStateConfig]): """ A state that checks if a positioner is within limits. @@ -250,89 +189,51 @@ class DeviceWithinLimitsState(DeviceBeamlineState): """ - def configure( - self, - device: str, - min_limit: float | None = None, - max_limit: float | None = None, - tolerance: float = 0.1, - signal: str | None = None, - **kwargs, - ) -> None: - """ - Configure the positioner condition. - - Args: - device (str): The name of the positioner device. - min_limit (float | None): The minimum limit for the positioner. If None, no minimum limit is enforced. - max_limit (float | None): The maximum limit for the positioner. If None, no maximum limit is enforced. - tolerance (float): The tolerance for warning conditions (default is 0.1). When the positioner is within - 10% of the limits, a warning condition will be issued. Note that the tolerance is ignored - if one of the limits is None. - signal (str, optional): The name of the signal to monitor. If not provided, defaults to the device name. - """ - self.min_limit = min_limit - self.max_limit = max_limit - self.tolerance = tolerance - super().configure(device=device, signal=signal, **kwargs) - - def parameters(self) -> dict: - params = super().parameters() - params.update( - { - "device": self.device, - "min_limit": self.min_limit, - "max_limit": self.max_limit, - "tolerance": self.tolerance, - "signal": self.signal, - } - ) - return params + CONFIG_CLASS = DeviceWithinLimitsStateConfig - def update_parameters(self, **kwargs) -> None: - if "min_limit" in kwargs: - self.min_limit = kwargs.pop("min_limit") - if "max_limit" in kwargs: - self.max_limit = kwargs.pop("max_limit") - if "tolerance" in kwargs: - self.tolerance = kwargs.pop("tolerance") - super().update_parameters(**kwargs) - - def evaluate(self, msg: messages.DeviceMessage, **kwargs) -> messages.BeamlineStateMessage: + def evaluate( + self, msg: messages.DeviceMessage, *args, **kwargs + ) -> messages.BeamlineStateMessage: """ Evaluate if the positioner is within the defined limits. If it is outside the limits, return an invalid state. Otherwise, return a valid state. If it is within 10% of the limits, return a warning state. """ - if self.min_limit is None: - self.min_limit = float("-inf") - if self.max_limit is None: - self.max_limit = float("inf") + if self.config.min_limit is None: + self.config.min_limit = float("-inf") + if self.config.max_limit is None: + self.config.max_limit = float("inf") - signal_name = self.signal if self.signal is not None else self.device + signal_name = self.config.signal if self.config.signal is not None else self.config.device val = msg.signals.get(signal_name, {}).get("value", None) if val is None: return messages.BeamlineStateMessage( - name=self.name, status="invalid", label=f"Positioner {self.device} value not found." + name=self.config.name, + status="invalid", + label=f"Positioner {self.config.device} value not found.", ) - if val < self.min_limit or val > self.max_limit: + if val < self.config.min_limit or val > self.config.max_limit: return messages.BeamlineStateMessage( - name=self.name, status="invalid", label=f"Positioner {self.device} out of limits" + name=self.config.name, + status="invalid", + label=f"Positioner {self.config.device} out of limits", ) - if self.min_limit == float("-inf") or self.max_limit == float("inf"): - self.tolerance = 0 + min_warning_threshold = self.config.min_limit + self.config.tolerance + max_warning_threshold = self.config.max_limit - self.config.tolerance - min_warning_threshold = self.min_limit + self.tolerance * (self.max_limit - self.min_limit) - max_warning_threshold = self.max_limit - self.tolerance * (self.max_limit - self.min_limit) if val < min_warning_threshold or val > max_warning_threshold: return messages.BeamlineStateMessage( - name=self.name, status="warning", label=f"Positioner {self.device} near limits" + name=self.config.name, + status="warning", + label=f"Positioner {self.config.device} near limits", ) return messages.BeamlineStateMessage( - name=self.name, status="valid", label=f"Positioner {self.device} within limits" + name=self.config.name, + status="valid", + label=f"Positioner {self.config.device} within limits", ) diff --git a/bec_lib/bec_lib/client.py b/bec_lib/bec_lib/client.py index e71d38264..7c9aaaa7b 100644 --- a/bec_lib/bec_lib/client.py +++ b/bec_lib/bec_lib/client.py @@ -20,7 +20,7 @@ from bec_lib.alarm_handler import AlarmHandler, Alarms from bec_lib.bec_service import BECService -from bec_lib.bl_states import BeamlineStateManager +from bec_lib.bl_state_manager import BeamlineStateManager from bec_lib.callback_handler import CallbackHandler, EventType from bec_lib.config_helper import ConfigHelperUser from bec_lib.dap_plugins import DAPPlugins diff --git a/bec_lib/tests/test_beamline_states.py b/bec_lib/tests/test_beamline_states.py index e680f4b67..3490a0e41 100644 --- a/bec_lib/tests/test_beamline_states.py +++ b/bec_lib/tests/test_beamline_states.py @@ -1,15 +1,16 @@ -import time +from __future__ import annotations + +import inspect from unittest import mock import pytest +from pydantic import BaseModel -from bec_lib import messages -from bec_lib.bl_states import ( - BeamlineState, +from bec_lib import bl_states, messages +from bec_lib.bl_state_manager import ( + BeamlineStateClientBase, BeamlineStateManager, - DeviceBeamlineState, - DeviceWithinLimitsState, - ShutterState, + build_signature_from_model, ) from bec_lib.endpoints import MessageEndpoints from bec_lib.redis_connector import MessageObject @@ -19,546 +20,298 @@ def state_manager(connected_connector): client = mock.MagicMock() client.connector = connected_connector - client.device_manager = mock.MagicMock() - config = BeamlineStateManager(client) - yield config + manager = BeamlineStateManager(client) + yield manager -# ============================================================================ -# BeamlineState tests -# ============================================================================ +class TestHelpers: + def test_build_signature_from_model(self): + class DemoConfig(BaseModel): + foo: int = 1 + bar: str = "abc" + signature = build_signature_from_model(DemoConfig) -class TestBeamlineState: - """Tests for the abstract BeamlineState base class.""" + assert list(signature.parameters) == ["foo", "bar"] + assert signature.parameters["foo"].kind == inspect.Parameter.KEYWORD_ONLY + assert signature.parameters["foo"].annotation is int + assert signature.parameters["bar"].default == "abc" - def test_beamline_state_initialization(self): - """Test basic initialization of a BeamlineState.""" - class ConcreteState(BeamlineState): - def evaluate(self, *args, **kwargs): - return messages.BeamlineStateMessage(name=self.name, status="invalid", label="Test") +class TestConfigModels: + def test_beamline_state_config_valid_name(self): + config = bl_states.BeamlineStateConfig(name="shutter_open", title="Shutter") + assert config.name == "shutter_open" - state = ConcreteState(name="test_state", title="Test State") - assert state.name == "test_state" - assert state.title == "Test State" - assert state.connector is None - assert state._configured is False - assert state._last_state is None + @pytest.mark.parametrize("invalid_name", ["state-name", "class", "add", "remove", "show_all"]) + def test_beamline_state_config_invalid_name(self, invalid_name): + with pytest.raises(ValueError): + bl_states.BeamlineStateConfig(name=invalid_name) - def test_beamline_state_default_title(self): - """Test that title defaults to name if not provided.""" + def test_device_state_config_keeps_string_device_and_signal(self): + config = bl_states.DeviceStateConfig(name="state", device="samx", signal="samx") + assert config.device == "samx" + assert config.signal == "samx" - class ConcreteState(BeamlineState): - def evaluate(self, *args, **kwargs): - return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") - state = ConcreteState(name="test_state") - assert state.title == "test_state" - - def test_beamline_state_configure(self): - """Test that configure marks the condition as configured.""" +class TestBeamlineStateBase: + def test_beamline_state_initialization_and_update(self): + class ConcreteState(bl_states.BeamlineState[bl_states.BeamlineStateConfig]): + CONFIG_CLASS = bl_states.BeamlineStateConfig - class ConcreteState(BeamlineState): def evaluate(self, *args, **kwargs): - return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") + return messages.BeamlineStateMessage( + name=self.config.name, status="valid", label="ok" + ) state = ConcreteState(name="test_state") - assert state._configured is False - state.configure() - assert state._configured is True - - def test_beamline_state_parameters(self): - """Test that parameters returns an empty dict by default.""" - - class ConcreteState(BeamlineState): - def evaluate(self, *args, **kwargs): - return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") - - state = ConcreteState(name="test_state") - assert state.parameters() == {} - - def test_beamline_state_with_connector(self, connected_connector): - """Test BeamlineState initialization with a connector.""" - - class ConcreteState(BeamlineState): - def evaluate(self, *args, **kwargs): - return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") - - state = ConcreteState(name="test_state", redis_connector=connected_connector) - assert state.connector == connected_connector + assert state.config.name == "test_state" + assert state.connector is None + assert state._last_state is None -# ============================================================================ -# DeviceBeamlineState tests -# ============================================================================ + state.update_parameters(title="Test State") + assert state.config.title == "Test State" class TestDeviceBeamlineState: - """Tests for DeviceBeamlineState.""" - - def test_device_state_configure(self, connected_connector): - """Test DeviceBeamlineState configuration.""" - - class ConcreteDeviceState(DeviceBeamlineState): - def evaluate(self, *args, **kwargs): - return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") - - state = ConcreteDeviceState(name="device_test", redis_connector=connected_connector) - state.configure(device="samx", signal="samx_value") - assert state.device == "samx" - assert state.signal == "samx_value" - assert state._configured is True - - def test_device_state_configure_default_signal(self, connected_connector): - """Test that signal defaults to device name if not provided.""" - - class ConcreteDeviceState(DeviceBeamlineState): - def evaluate(self, *args, **kwargs): - return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") - - state = ConcreteDeviceState(name="device_test", redis_connector=connected_connector) - state.configure(device="samx", signal="samx") - assert state.device == "samx" - assert state.signal == "samx" - - def test_device_state_parameters(self, connected_connector): - """Test that parameters includes device and signal.""" - - class ConcreteDeviceState(DeviceBeamlineState): - def evaluate(self, *args, **kwargs): - return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") - - state = ConcreteDeviceState(name="device_test", redis_connector=connected_connector) - state.configure(device="samx", signal="samx_value") - params = state.parameters() - assert params["device"] == "samx" - assert params["signal"] == "samx_value" - - def test_device_state_start_not_configured(self, connected_connector): - """Test that start raises RuntimeError if state is not configured.""" - - class ConcreteDeviceState(DeviceBeamlineState): - def evaluate(self, *args, **kwargs): - return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") + def test_start_requires_connector(self): + state = bl_states.ShutterState(name="shutter_open", device="shutter1", signal="shutter1") - state = ConcreteDeviceState(name="device_test", redis_connector=connected_connector) - with pytest.raises(RuntimeError, match="State must be configured before starting"): - state.start() - - def test_device_state_start_no_connector(self): - """Test that start raises RuntimeError if connector is not set.""" - - class ConcreteDeviceState(DeviceBeamlineState): - def evaluate(self, *args, **kwargs): - return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") - - state = ConcreteDeviceState(name="device_test") - state.configure(device="samx") with pytest.raises(RuntimeError, match="Redis connector is not set"): state.start() - def test_device_state_start_registers_callback(self, connected_connector): - """Test that start registers the callback with the connector.""" - - class ConcreteDeviceState(DeviceBeamlineState): - def evaluate(self, *args, **kwargs): - return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") + def test_start_registers_device_callback(self, connected_connector): + state = bl_states.ShutterState( + name="shutter_open", + device="shutter1", + signal="shutter1", + redis_connector=connected_connector, + ) - state = ConcreteDeviceState(name="device_test", redis_connector=connected_connector) - state.configure(device="samx") - with mock.patch.object(connected_connector, "register") as mock_register: + with mock.patch.object(connected_connector, "register") as register: state.start() - mock_register.assert_called_once() - call_args = mock_register.call_args - assert call_args[0][0] == MessageEndpoints.device_readback("samx") - def test_device_state_stop(self, connected_connector): - """Test that stop unregisters the callback.""" - - class ConcreteDeviceState(DeviceBeamlineState): - def evaluate(self, *args, **kwargs): - return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") + register.assert_called_once_with( + MessageEndpoints.device_readback("shutter1"), + cb=state._update_device_state, + parent=state, + ) - state = ConcreteDeviceState(name="device_test", redis_connector=connected_connector) - state.configure(device="samx") + def test_stop_unregisters_device_callback(self, connected_connector): + state = bl_states.ShutterState( + name="shutter_open", + device="shutter1", + signal="shutter1", + redis_connector=connected_connector, + ) - with mock.patch.object(connected_connector, "unregister") as mock_unregister: + with mock.patch.object(connected_connector, "unregister") as unregister: state.stop() - mock_unregister.assert_called_once() - - def test_device_state_stop_not_configured(self, connected_connector): - """Test that stop doesn't raise an error if not configured.""" - - class ConcreteDeviceState(DeviceBeamlineState): - def evaluate(self, *args, **kwargs): - return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") - - state = ConcreteDeviceState(name="device_test", redis_connector=connected_connector) - # Should not raise an error - state.stop() - - def test_device_state_stop_no_connector(self): - """Test that stop doesn't raise an error if connector is not set.""" - - class ConcreteDeviceState(DeviceBeamlineState): - def evaluate(self, *args, **kwargs): - return messages.BeamlineStateMessage(name=self.name, status="valid", label="Test") - - state = ConcreteDeviceState(name="device_test") - state.configure(device="samx") - # Should not raise an error - state.stop() - - def test_device_state_update_device_state(self, connected_connector): - """Test that _update_device_state calls evaluate and updates _last_state.""" - - msg = messages.BeamlineStateMessage(name="device_test", status="valid", label="Test") - - class ConcreteDeviceState(DeviceBeamlineState): - def evaluate(self, *args, **kwargs): - return msg - - state = ConcreteDeviceState(name="device_test", redis_connector=connected_connector) - state.configure(device="samx") - - msg_obj = MessageObject(value=msg, topic="test_topic") - state._update_device_state(msg_obj, parent=state) - assert state._last_state == msg - out = state.connector.xread(MessageEndpoints.beamline_state("device_test"), from_start=True) - assert out is not None - assert out[0]["data"] == msg - - -# ============================================================================ -# ShutterState tests -# ============================================================================ - - -class TestShutterState: - """Tests for ShutterState.""" - def test_shutter_open(self, connected_connector): - """Test evaluation when shutter is open.""" - state = ShutterState(name="shutter_open", redis_connector=connected_connector) - state.configure(device="shutter1", signal="shutter1") - - msg = messages.DeviceMessage( - signals={"shutter1": {"value": "open", "timestamp": 1234567890.0}}, - metadata={"stream": "primary"}, + unregister.assert_called_once_with( + MessageEndpoints.device_readback("shutter1"), cb=state._update_device_state ) - result = state.evaluate(msg) - assert result.name == "shutter_open" - assert result.status == "valid" - assert result.label == "Shutter is open." - - def test_shutter_open_uppercase(self, connected_connector): - """Test evaluation when shutter value is uppercase and gets lowercased.""" - state = ShutterState(name="shutter_open", redis_connector=connected_connector) - state.configure(device="shutter1", signal="shutter1") - - msg = messages.DeviceMessage( - signals={"shutter1": {"value": "OPEN", "timestamp": 1234567890.0}}, - metadata={"stream": "primary"}, + def test_update_device_state_publishes_when_state_changes(self, connected_connector): + state = bl_states.ShutterState( + name="shutter_open", + device="shutter1", + signal="shutter1", + redis_connector=connected_connector, ) - result = state.evaluate(msg) - assert result.status == "valid" - assert result.label == "Shutter is open." - - def test_shutter_closed(self, connected_connector): - """Test evaluation when shutter is closed.""" - state = ShutterState(name="shutter_open", redis_connector=connected_connector) - state.configure(device="shutter1") - msg = messages.DeviceMessage( - signals={"shutter1": {"value": "closed", "timestamp": 1234567890.0}}, + signals={"shutter1": {"value": "open", "timestamp": 1.0}}, metadata={"stream": "primary"}, ) + msg_obj = MessageObject(value=msg, topic="test") - result = state.evaluate(msg) - assert result.name == "shutter_open" - assert result.status == "invalid" - assert result.label == "Shutter is closed." - - def test_shutter_missing_value(self, connected_connector): - """Test evaluation when value is missing.""" - state = ShutterState(name="shutter_open", redis_connector=connected_connector) - state.configure(device="shutter1") + state._update_device_state(msg_obj, parent=state) - msg = messages.DeviceMessage( - signals={"shutter1": {"timestamp": 1234567890.0}}, metadata={"stream": "primary"} + assert state._last_state is not None + assert state._last_state.status == "valid" + out = connected_connector.xread( + MessageEndpoints.beamline_state("shutter_open"), from_start=True ) + assert out is not None + assert out[0]["data"].status == "valid" - result = state.evaluate(msg) - assert result.status == "invalid" - assert result.label == "Shutter is closed." - - -# ============================================================================ -# DeviceWithinLimitsState tests -# ============================================================================ - - -class TestDeviceWithinLimitsState: - """Tests for DeviceWithinLimitsState.""" - - def test_within_limits_configure(self, connected_connector): - """Test configuration of DeviceWithinLimitsState.""" - state = DeviceWithinLimitsState(name="sample_x_limits", redis_connector=connected_connector) - state.configure(device="sample_x", min_limit=0.0, max_limit=10.0) - - assert state.device == "sample_x" - assert state.min_limit == 0.0 - assert state.max_limit == 10.0 - assert state.tolerance == 0.1 - - def test_within_limits_configure_custom_tolerance(self, connected_connector): - """Test configuration with custom tolerance.""" - state = DeviceWithinLimitsState(name="sample_x_limits", redis_connector=connected_connector) - state.configure(device="sample_x", min_limit=0.0, max_limit=10.0, tolerance=0.2) - - assert state.tolerance == 0.2 - - def test_within_limits_value_inside(self, connected_connector): - """Test evaluation when value is within limits.""" - state = DeviceWithinLimitsState(name="sample_x_limits", redis_connector=connected_connector) - state.configure(device="sample_x", min_limit=0.0, max_limit=10.0) - msg = messages.DeviceMessage( - signals={"sample_x": {"value": 5.0, "timestamp": 1234567890.0}}, - metadata={"stream": "primary"}, +class TestConcreteStates: + def test_shutter_state_open_and_closed(self, connected_connector): + state = bl_states.ShutterState( + name="shutter_open", + device="shutter1", + signal="shutter1", + redis_connector=connected_connector, ) - result = state.evaluate(msg) - assert result.status == "valid" - assert result.label == "Positioner sample_x within limits" - - def test_within_limits_value_outside_low(self, connected_connector): - """Test evaluation when value is below minimum limit.""" - state = DeviceWithinLimitsState(name="sample_x_limits", redis_connector=connected_connector) - state.configure(device="sample_x", min_limit=0.0, max_limit=10.0) - - msg = messages.DeviceMessage( - signals={"sample_x": {"value": -1.0, "timestamp": 1234567890.0}}, + open_msg = messages.DeviceMessage( + signals={"shutter1": {"value": "OPEN", "timestamp": 1.0}}, metadata={"stream": "primary"}, ) - - result = state.evaluate(msg) - assert result.status == "invalid" - assert result.label == "Positioner sample_x out of limits" - - def test_within_limits_value_outside_high(self, connected_connector): - """Test evaluation when value is above maximum limit.""" - state = DeviceWithinLimitsState(name="sample_x_limits", redis_connector=connected_connector) - state.configure(device="sample_x", min_limit=0.0, max_limit=10.0) - - msg = messages.DeviceMessage( - signals={"sample_x": {"value": 11.0, "timestamp": 1234567890.0}}, + closed_msg = messages.DeviceMessage( + signals={"shutter1": {"value": "closed", "timestamp": 2.0}}, metadata={"stream": "primary"}, ) - result = state.evaluate(msg) - assert result.status == "invalid" - assert result.label == "Positioner sample_x out of limits" - - def test_within_limits_value_near_min(self, connected_connector): - """Test evaluation when value is near minimum limit (within tolerance).""" - state = DeviceWithinLimitsState(name="sample_x_limits", redis_connector=connected_connector) - state.configure(device="sample_x", min_limit=0.0, max_limit=10.0, tolerance=0.1) - - # 10% of (10 - 0) = 1.0, so near min is < 1.0 - msg = messages.DeviceMessage( - signals={"sample_x": {"value": 0.5, "timestamp": 1234567890.0}}, - metadata={"stream": "primary"}, + assert state.evaluate(open_msg).status == "valid" + assert state.evaluate(closed_msg).status == "invalid" + + def test_device_within_limits_state(self, connected_connector): + state = bl_states.DeviceWithinLimitsState( + name="sample_x_limits", + device="sample_x", + min_limit=0.0, + max_limit=10.0, + tolerance=0.1, + redis_connector=connected_connector, ) - result = state.evaluate(msg) - assert result.status == "warning" - assert result.label == "Positioner sample_x near limits" - - def test_within_limits_value_near_max(self, connected_connector): - """Test evaluation when value is near maximum limit (within tolerance).""" - state = DeviceWithinLimitsState(name="sample_x_limits", redis_connector=connected_connector) - state.configure(device="sample_x", min_limit=0.0, max_limit=10.0, tolerance=0.1) - - # 10% of (10 - 0) = 1.0, so near max is > 9.0 - msg = messages.DeviceMessage( - signals={"sample_x": {"value": 9.5, "timestamp": 1234567890.0}}, - metadata={"stream": "primary"}, + valid = messages.DeviceMessage( + signals={"sample_x": {"value": 5.0, "timestamp": 1.0}}, metadata={"stream": "primary"} ) - - result = state.evaluate(msg) - assert result.status == "warning" - assert result.label == "Positioner sample_x near limits" - - def test_within_limits_missing_value(self, connected_connector): - """Test evaluation when value is missing.""" - state = DeviceWithinLimitsState(name="sample_x_limits", redis_connector=connected_connector) - state.configure(device="sample_x", min_limit=0.0, max_limit=10.0) - - msg = messages.DeviceMessage( - signals={"sample_x": {"timestamp": 1234567890.0}}, metadata={"stream": "primary"} + warning = messages.DeviceMessage( + signals={"sample_x": {"value": 0.05, "timestamp": 2.0}}, metadata={"stream": "primary"} + ) + invalid = messages.DeviceMessage( + signals={"sample_x": {"value": 11.0, "timestamp": 3.0}}, metadata={"stream": "primary"} ) + missing = messages.DeviceMessage( + signals={"sample_x": {"timestamp": 4.0}}, metadata={"stream": "primary"} + ) + + assert state.evaluate(valid).status == "valid" + assert state.evaluate(warning).status == "warning" + assert state.evaluate(invalid).status == "invalid" + assert state.evaluate(missing).status == "invalid" - result = state.evaluate(msg) - assert result.status == "invalid" - assert "value not found" in result.label - def test_within_limits_parameters(self, connected_connector): - """Test that parameters includes all configuration.""" - state = DeviceWithinLimitsState(name="sample_x_limits", redis_connector=connected_connector) - state.configure(device="sample_x", min_limit=0.0, max_limit=10.0, signal="x_readback") +class TestBeamlineStateManager: + def test_manager_registers_for_state_updates(self, connected_connector): + client = mock.MagicMock() + client.connector = connected_connector - params = state.parameters() - assert params["device"] == "sample_x" - assert params["min_limit"] == 0.0 - assert params["max_limit"] == 10.0 - assert params["tolerance"] == 0.1 - assert params["signal"] == "x_readback" + with mock.patch.object(connected_connector, "register") as register: + BeamlineStateManager(client) + register.assert_called_once_with( + MessageEndpoints.available_beamline_states(), + cb=mock.ANY, + parent=mock.ANY, + from_start=True, + ) -# ============================================================================ -# BeamlineStateConfig tests -# ============================================================================ + def test_on_state_update_creates_client_attribute(self, state_manager): + config = messages.BeamlineStateConfig( + name="shutter_open", + title="Shutter Open", + state_type="ShutterState", + parameters={"name": "shutter_open", "title": "Shutter Open", "device": "shutter1"}, + ) + update = messages.AvailableBeamlineStatesMessage(states=[config]) + + state_manager._on_state_update({"data": update}, parent=state_manager) + + assert "shutter_open" in state_manager._states + assert isinstance(state_manager._states["shutter_open"], bl_states.DeviceStateConfig) + assert isinstance(getattr(state_manager, "shutter_open"), BeamlineStateClientBase) + + def test_update_parameters_from_client_updates_state_and_publishes(self, state_manager): + config = messages.BeamlineStateConfig( + name="limits", + title="Limits", + state_type="DeviceWithinLimitsState", + parameters={ + "name": "limits", + "title": "Limits", + "device": "samx", + "min_limit": 0.0, + "max_limit": 10.0, + }, + ) + update = messages.AvailableBeamlineStatesMessage(states=[config]) + state_manager._on_state_update({"data": update}, parent=state_manager) + state_manager.limits.update_parameters(tolerance=0.25) -class TestBeamlineStateConfig: - """Tests for BeamlineStateConfig manager.""" + assert state_manager._states["limits"].tolerance == 0.25 - @pytest.mark.timeout(5) - def test_add_state(self, state_manager): - """Test adding a state.""" - state = ShutterState(name="shutter_open", title="Shutter Open") - state.configure(device="shutter1") + out = state_manager._connector.xread( + MessageEndpoints.available_beamline_states(), from_start=True + ) + assert out + assert isinstance(out[-1]["data"], messages.AvailableBeamlineStatesMessage) + + def test_client_get_returns_unknown_without_status_message(self, state_manager): + config = messages.BeamlineStateConfig( + name="shutter_open", + title="Shutter Open", + state_type="ShutterState", + parameters={"name": "shutter_open", "title": "Shutter Open", "device": "shutter1"}, + ) + update = messages.AvailableBeamlineStatesMessage(states=[config]) + state_manager._on_state_update({"data": update}, parent=state_manager) + + result = state_manager.shutter_open.get() + assert result == {"status": "unknown", "label": "No state information available."} + + def test_client_get_returns_latest_status_message(self, state_manager): + config = messages.BeamlineStateConfig( + name="shutter_open", + title="Shutter Open", + state_type="ShutterState", + parameters={"name": "shutter_open", "title": "Shutter Open", "device": "shutter1"}, + ) + update = messages.AvailableBeamlineStatesMessage(states=[config]) + state_manager._on_state_update({"data": update}, parent=state_manager) + + state_manager._connector.xadd( + MessageEndpoints.beamline_state("shutter_open"), + { + "data": messages.BeamlineStateMessage( + name="shutter_open", status="valid", label="ok" + ) + }, + max_size=1, + ) - # Setup device manager mock - the signal should match the device name when no signal is provided - state_manager._client.device_manager.devices = {"shutter1": mock.MagicMock()} - state_manager._client.device_manager.devices["shutter1"].read.return_value = { - "shutter1": {"value": "open"} - } + result = state_manager.shutter_open.get() + assert result == {"status": "valid", "label": "ok"} - state_manager.add(state) - while True: - if any(c.name == "shutter_open" for c in state_manager._states): - break - time.sleep(0.1) - # Check that the state was added - assert any(c.name == "shutter_open" for c in state_manager._states) - - @pytest.mark.timeout(5) - def test_add_state_already_exists(self, state_manager): - """Test that adding a duplicate state is ignored.""" - state = ShutterState(name="shutter_open", title="Shutter Open") - state.configure(device="shutter1") - - # Setup device manager mock - state_manager._client.device_manager.devices = {"shutter1": mock.MagicMock()} - state_manager._client.device_manager.devices["shutter1"].read.return_value = { - "shutter1": {"value": "open"} - } - - # Add the state once - state_manager.add(state) - while True: - if any(c.name == "shutter_open" for c in state_manager._states): - break - time.sleep(0.1) - initial_count = len(state_manager._states) + def test_add_and_remove_publish_updates(self, state_manager): + state = bl_states.DeviceStateConfig( + name="shutter_open", title="Shutter Open", device="shutter1" + ) - # Add the same state again state_manager.add(state) - time.sleep(0.5) - # Count should not increase - assert len(state_manager._states) == initial_count - - def test_add_state_device_not_found(self, state_manager): - """Test that adding a state with invalid device raises RuntimeError.""" - state = ShutterState(name="shutter_open", title="Shutter Open") - state.configure(device="nonexistent_shutter") - - state_manager._client.device_manager.devices = {} - - with pytest.raises(RuntimeError, match="Device nonexistent_shutter not found"): - state_manager.add(state) - - def test_add_state_signal_not_found(self, state_manager): - """Test that adding a state with invalid signal raises RuntimeError.""" - state = ShutterState(name="shutter_open", title="Shutter Open") - # Setup device manager mock with device but without the signal - mock_device = mock.MagicMock() - mock_device.read.return_value = {"other_signal": {"value": "open"}} - state_manager._client.device_manager.devices = {"shutter1": mock_device} - - state.configure(device="shutter1", signal="value") - - with pytest.raises(RuntimeError, match="Signal value not found in device shutter1"): - state_manager.add(state) - - @pytest.mark.timeout(5) - def test_remove_state(self, state_manager): - """Test removing a state.""" - state = ShutterState(name="shutter_open", title="Shutter Open") - state.configure(device="shutter1") - - # Setup device manager mock - state_manager._client.device_manager.devices = {"shutter1": mock.MagicMock()} - state_manager._client.device_manager.devices["shutter1"].read.return_value = { - "shutter1": {"value": "open"} - } - - # Add and then remove - state_manager.add(state) - while True: - if any(c.name == "shutter_open" for c in state_manager._states): - break - time.sleep(0.1) + assert "shutter_open" in state_manager._states state_manager.remove("shutter_open") - while True: - if not any(c.name == "shutter_open" for c in state_manager._states): - break - time.sleep(0.1) - - def test_remove_nonexistent_state(self, state_manager): - """Test removing a state that doesn't exist.""" - # Should not raise an error - state_manager.remove("nonexistent") - assert len(state_manager._states) == 0 - - @pytest.mark.timeout(5) - def test_show_all(self, state_manager, capsys): - """Test that show_all displays states in a table.""" - state = ShutterState(name="shutter_open", title="Shutter Open") - state.configure(device="shutter1") - - # Setup device manager mock - state_manager._client.device_manager.devices = {"shutter1": mock.MagicMock()} - state_manager._client.device_manager.devices["shutter1"].read.return_value = { - "shutter1": {"value": "open"} - } + assert "shutter_open" not in state_manager._states + + def test_client_delete_removes_state(self, state_manager): + config = messages.BeamlineStateConfig( + name="shutter_open", + title="Shutter Open", + state_type="ShutterState", + parameters={"name": "shutter_open", "title": "Shutter Open", "device": "shutter1"}, + ) + update = messages.AvailableBeamlineStatesMessage(states=[config]) + state_manager._on_state_update({"data": update}, parent=state_manager) - state_manager.add(state) - while True: - if any(c.name == "shutter_open" for c in state_manager._states): - break - time.sleep(0.1) - state_manager.show_all() + state_manager.shutter_open.delete() - # The output should be printed (checked via capsys) - captured = capsys.readouterr() - # Check that the state name appears in the output - assert "shutter_open" in captured.out or "shutter_open" in captured.err + assert "shutter_open" not in state_manager._states - def test_on_state_update(self, state_manager): - """Test that _on_state_update updates the states list.""" - update_entry = messages.BeamlineStateConfig( - name="test_state", title="Test State", state_type="ShutterState", parameters={} + def test_show_all_prints_table(self, state_manager, capsys): + state = bl_states.DeviceStateConfig( + name="shutter_open", title="Shutter Open", device="shutter1" ) - msg = messages.AvailableBeamlineStatesMessage(states=[update_entry]) + state_manager.add(state) - state_manager._on_state_update({"data": msg}, parent=state_manager) + state_manager.show_all() - assert len(state_manager._states) == 1 - assert state_manager._states[0].name == "test_state" + captured = capsys.readouterr() + assert "shutter_open" in (captured.out + captured.err) diff --git a/bec_server/bec_server/scan_server/beamline_state_manager.py b/bec_server/bec_server/scan_server/beamline_state_manager.py index 8c6a861f6..f9b6f2e67 100644 --- a/bec_server/bec_server/scan_server/beamline_state_manager.py +++ b/bec_server/bec_server/scan_server/beamline_state_manager.py @@ -1,8 +1,11 @@ from __future__ import annotations +import traceback + from bec_lib import bl_states, messages from bec_lib.alarm_handler import Alarms from bec_lib.endpoints import MessageEndpoints +from bec_lib.messages import ErrorInfo from bec_lib.redis_connector import RedisConnector @@ -11,7 +14,7 @@ class BeamlineStateManager: def __init__(self, connector: RedisConnector) -> None: self.connector = connector - self.states: list[bl_states.BeamlineState] = [] + self._states: dict[str, bl_states.BeamlineState] = {} self.connector.register( MessageEndpoints.available_beamline_states(), cb=self._handle_state_update, @@ -23,7 +26,16 @@ def __init__(self, connector: RedisConnector) -> None: def _handle_state_update(msg_dict: dict, *, parent: BeamlineStateManager, **_kwargs) -> None: msg: messages.AvailableBeamlineStatesMessage = msg_dict["data"] # type: ignore ; we know it's a AvailableBeamlineStatesMessage - parent.update_states(msg) + try: + parent.update_states(msg) + except Exception as exc: + content = traceback.format_exc() + info = ErrorInfo( + exception_type=type(exc).__name__, + error_message=content, + compact_error_message="Error updating beamline states.", + ) + parent.connector.raise_alarm(severity=Alarms.WARNING, info=info) def update_states(self, msg: messages.AvailableBeamlineStatesMessage) -> None: """ @@ -34,50 +46,33 @@ def update_states(self, msg: messages.AvailableBeamlineStatesMessage) -> None: """ # get the states that we need to remove - states_in_msg = {state.name for state in msg.states} - current_states = {state.name for state in self.states} - states_to_remove = current_states - states_in_msg - # remove states that are no longer needed - for state_name in states_to_remove: - state = next((s for s in self.states if s.name == state_name), None) - if state: - state.stop() - self.states.remove(state) - # filter out existing states from the message - new_states = [state for state in msg.states if state.name not in current_states] - # add new states - for state in new_states: - self.states.append(self.create_state_from_message(state)) + remove_state_names = set(self._states) - set(state.name for state in msg.states) - def create_state_from_message( - self, state_info: messages.BeamlineStateConfig - ) -> bl_states.BeamlineState: - """ - Create a BeamlineState instance from a BeamlineStateConfig message. + added_state_names = set(state.name for state in msg.states) - set(self._states) + added_states = { + state.name: state for state in msg.states if state.name in added_state_names + } - Args: - state_info (messages.BeamlineStateConfig): The state config message. - Returns: - BeamlineState: The created BeamlineState instance. - """ - try: - cls = getattr(bl_states, state_info.state_type, None) - if cls is None or not issubclass(cls, bl_states.BeamlineState): - raise ValueError( - f"State type {state_info.state_type} not found in beamline states." - ) - state = cls( - name=state_info.name, redis_connector=self.connector, title=state_info.title - ) - state.configure(**state_info.parameters) - state.start() - except Exception as exc: - self.connector.raise_alarm( - severity=Alarms.WARNING, - info=messages.ErrorInfo( - error_message=f"Failed to create beamline state {state_info.name}: {exc}", - compact_error_message=f"Failed to create beamline state {state_info.name}", - exception_type=type(exc).__name__, - ), - ) - return state + for state_name in remove_state_names: + if hasattr(self, state_name): + delattr(self, state_name) + self._states.pop(state_name, None) + + for state_name, state in added_states.items(): + state_class = getattr(bl_states, state.state_type) + if not issubclass(state_class, bl_states.BeamlineState): + raise ValueError(f"State type {state.state_type} not found in beamline states.") + model_cls = state_class.CONFIG_CLASS + model_instance = model_cls(**state.parameters) + state_instance = state_class(config=model_instance, redis_connector=self.connector) + state_instance.start() + self._states[state.name] = state_instance + + # Check if the config has changed for existing states and update them if needed + for state_msg in msg.states: + state = self._states.get(state_msg.name) + if state is None: + continue + if state.config.model_dump() != state_msg.parameters: + # The config has changed, we need to update the state + state.restart() diff --git a/bec_server/tests/tests_scan_server/test_beamline_state_manager.py b/bec_server/tests/tests_scan_server/test_beamline_state_manager.py index 9bb943620..5198220ba 100644 --- a/bec_server/tests/tests_scan_server/test_beamline_state_manager.py +++ b/bec_server/tests/tests_scan_server/test_beamline_state_manager.py @@ -3,8 +3,9 @@ import pytest -from bec_lib import messages +from bec_lib import bl_states, messages from bec_lib.endpoints import MessageEndpoints +from bec_server.scan_server import beamline_state_manager from bec_server.scan_server.beamline_state_manager import BeamlineStateManager @@ -14,6 +15,29 @@ def state_manager(connected_connector): yield manager +@pytest.fixture +def fake_bl_states(monkeypatch): + class FakeState(bl_states.BeamlineState[bl_states.DeviceStateConfig]): + CONFIG_CLASS = bl_states.DeviceStateConfig + + def __init__(self, config=None, redis_connector=None, **kwargs): + super().__init__(config=config, redis_connector=redis_connector, **kwargs) + self.started = False + self.restart_count = 0 + + def evaluate(self, *args, **kwargs): + return None + + def start(self): + self.started = True + + def restart(self): + self.restart_count += 1 + + monkeypatch.setattr(beamline_state_manager.bl_states, "ShutterState", FakeState) + return FakeState + + def test_state_manager_fetches_states(): """ Test that the BeamlineStateManager fetches all available beamline states on initialization. @@ -30,13 +54,13 @@ def test_state_manager_fetches_states(): @pytest.mark.timeout(5) -def test_state_manager_updates_states(state_manager, connected_connector): +def test_state_manager_updates_states(state_manager, connected_connector, fake_bl_states): """ Test that the BeamlineStateManager updates its states correctly when receiving an update message. """ # Initial state: no states - assert len(state_manager.states) == 0 + assert len(state_manager._states) == 0 msg = messages.AvailableBeamlineStatesMessage( states=[ @@ -44,7 +68,7 @@ def test_state_manager_updates_states(state_manager, connected_connector): name="State1", title="Shutter", state_type="ShutterState", - parameters={"device": "shutter1"}, + parameters={"name": "State1", "title": "Shutter", "device": "shutter1"}, ) ] ) @@ -54,7 +78,7 @@ def test_state_manager_updates_states(state_manager, connected_connector): ) # Give it some time to process - while len(state_manager.states) < 1: + while len(state_manager._states) < 1: time.sleep(0.1) msg = messages.AvailableBeamlineStatesMessage( @@ -63,13 +87,13 @@ def test_state_manager_updates_states(state_manager, connected_connector): name="State1", title="Shutter", state_type="ShutterState", - parameters={"device": "shutter1"}, + parameters={"name": "State1", "title": "Shutter", "device": "shutter1"}, ), messages.BeamlineStateConfig( name="State2", title="Shutter2", state_type="ShutterState", - parameters={"device": "shutter2"}, + parameters={"name": "State2", "title": "Shutter2", "device": "shutter2"}, ), ] ) @@ -79,7 +103,7 @@ def test_state_manager_updates_states(state_manager, connected_connector): ) # Give it some time to process - while len(state_manager.states) < 2: + while len(state_manager._states) < 2: time.sleep(0.1) msg = messages.AvailableBeamlineStatesMessage( @@ -88,7 +112,7 @@ def test_state_manager_updates_states(state_manager, connected_connector): name="State2", title="Shutter2", state_type="ShutterState", - parameters={"device": "shutter2"}, + parameters={"name": "State2", "title": "Shutter2", "device": "shutter2"}, ) ] ) @@ -96,8 +120,8 @@ def test_state_manager_updates_states(state_manager, connected_connector): MessageEndpoints.available_beamline_states(), {"data": msg}, max_size=1 ) # Give it some time to process - while len(state_manager.states) > 1: + while len(state_manager._states) > 1: time.sleep(0.1) - assert len(state_manager.states) == 1 - assert state_manager.states[0].name == "State2" + assert len(state_manager._states) == 1 + assert "State2" in state_manager._states From d762af09920768ac97dc4e36a0e692c958e0df03 Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Sat, 21 Feb 2026 16:13:54 +0100 Subject: [PATCH 3/6] f - wip --- bec_lib/bec_lib/bl_states.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/bec_lib/bec_lib/bl_states.py b/bec_lib/bec_lib/bl_states.py index 63972b32c..88a7796ed 100644 --- a/bec_lib/bec_lib/bl_states.py +++ b/bec_lib/bec_lib/bl_states.py @@ -110,6 +110,26 @@ def restart(self) -> None: self.stop() self.start() + def _emit_state(self, state_msg: messages.BeamlineStateMessage) -> None: + if self.connector is None: + return + is_different = ( + state_msg.model_dump(exclude={"timestamp"}) + != self._last_state.model_dump(exclude={"timestamp"}) + if self._last_state + else True + ) + if self._last_state is None: + is_different = True + if is_different: + self._last_state = state_msg + self.connector.xadd( + MessageEndpoints.beamline_state(self.config.name), + {"data": state_msg}, + max_size=1, + approximate=False, + ) + class DeviceBeamlineState(BeamlineState[D], Generic[D]): """A beamline state that depends on a device reading.""" @@ -120,7 +140,6 @@ def __init__( self, config: D | None = None, redis_connector: RedisConnector | None = None, **kwargs ) -> None: super().__init__(config, redis_connector, **kwargs) - self._last_value = None def start(self) -> None: if self.connector is None: @@ -146,11 +165,9 @@ def _update_device_state(msg_obj: MessageObject, parent: DeviceBeamlineState) -> msg: messages.DeviceMessage = msg_obj.value # type: ignore ; we know it's a DeviceMessage out = parent.evaluate(msg) - if out is not None and out != parent._last_state: - parent._last_state = out - parent.connector.xadd( - MessageEndpoints.beamline_state(parent.config.name), {"data": out}, max_size=1 - ) + if out is None: + return + parent._emit_state(out) class ShutterState(DeviceBeamlineState[DeviceStateConfig]): From 6f71add9e8cda5f1d6ab73cfd7bb869fc81821ca Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Sat, 21 Feb 2026 19:13:45 +0100 Subject: [PATCH 4/6] f - file writer and bug fixes --- bec_lib/bec_lib/bl_state_manager.py | 12 ++- bec_lib/bec_lib/bl_states.py | 15 ++-- bec_lib/bec_lib/messages.py | 1 + .../bec_server/file_writer/default_writer.py | 28 +++++++ .../bec_server/file_writer/file_writer.py | 1 + .../file_writer/file_writer_manager.py | 76 ++++++++++++++++++- .../scan_server/beamline_state_manager.py | 1 + 7 files changed, 124 insertions(+), 10 deletions(-) diff --git a/bec_lib/bec_lib/bl_state_manager.py b/bec_lib/bec_lib/bl_state_manager.py index 3213f6817..acb960c0b 100644 --- a/bec_lib/bec_lib/bl_state_manager.py +++ b/bec_lib/bec_lib/bl_state_manager.py @@ -16,13 +16,16 @@ from bec_lib.client import BECClient -def build_signature_from_model(model: BaseModel) -> Signature: +def build_signature_from_model(model: BaseModel, skip: set[str] | None = None) -> Signature: """ Build a function signature from a Pydantic model. The parameters of the signature will match the fields of the model. """ parameters = [] + skip = skip or set() for name, field in model.model_fields.items(): + if name in skip: + continue annotation = field.annotation or inspect.Parameter.empty parameters.append( Parameter( @@ -49,6 +52,7 @@ def __init__(self, manager: BeamlineStateManager, state: BeamlineStateConfig) -> self._manager = manager self._connector = manager._connector self._state = state + self._skip_parameters = {"name"} # pylint: disable=unnecessary-lambda self._run = lambda **kwargs: self._run_update(**kwargs) @@ -60,10 +64,14 @@ def _update_signature(self) -> None: setattr( getattr(self, "update_parameters"), "__signature__", - build_signature_from_model(self._state), + build_signature_from_model(self._state, skip=self._skip_parameters), ) def _run_update(self, **kwargs) -> None: + if not kwargs: + return + if self._skip_parameters.intersection(kwargs): + raise ValueError(f"Invalid parameters: {self._skip_parameters.intersection(kwargs)}") self._state = self._state.model_copy(update=kwargs) self._manager._update_state(self._state) # pylint: disable=protected-access diff --git a/bec_lib/bec_lib/bl_states.py b/bec_lib/bec_lib/bl_states.py index 88a7796ed..fb50daa4a 100644 --- a/bec_lib/bec_lib/bl_states.py +++ b/bec_lib/bec_lib/bl_states.py @@ -2,17 +2,14 @@ import keyword from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, ClassVar, Generic, Type, TypeVar +from typing import ClassVar, Generic, Type, TypeVar from pydantic import BaseModel, field_validator from bec_lib import messages from bec_lib.device import DeviceBase from bec_lib.endpoints import MessageEndpoints -from bec_lib.redis_connector import RedisConnector - -if TYPE_CHECKING: - from bec_lib.redis_connector import MessageObject, RedisConnector +from bec_lib.redis_connector import MessageObject, RedisConnector class BeamlineStateConfig(BaseModel): @@ -144,6 +141,14 @@ def __init__( def start(self) -> None: if self.connector is None: raise RuntimeError("Redis connector is not set.") + msg = self.connector.get(MessageEndpoints.device_readback(self.config.device)) + if msg is not None: + self._update_device_state( + MessageObject( + topic=MessageEndpoints.device_readback(self.config.device).endpoint, value=msg + ), + parent=self, + ) self.connector.register( MessageEndpoints.device_readback(self.config.device), cb=self._update_device_state, diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index f3b4d7ff7..2b561e68d 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -1969,6 +1969,7 @@ class BeamlineStateMessage(BECMessage): name: str status: Literal["valid", "invalid", "warning"] label: str + timestamp: float = Field(default_factory=time.time) class BeamlineStateConfig(BaseModel): diff --git a/bec_server/bec_server/file_writer/default_writer.py b/bec_server/bec_server/file_writer/default_writer.py index 44b44c12c..f17d22b36 100644 --- a/bec_server/bec_server/file_writer/default_writer.py +++ b/bec_server/bec_server/file_writer/default_writer.py @@ -2,6 +2,9 @@ from typing import TYPE_CHECKING, Any +import h5py +import numpy as np + if TYPE_CHECKING: from bec_lib import messages from bec_lib.devicemanager import DeviceManagerBase @@ -20,6 +23,7 @@ def __init__( info_storage: dict, configuration: dict, file_references: dict[str, messages.FileMessage], + beamline_states: dict[str, list[messages.BeamlineStateMessage]], device_manager: DeviceManagerBase, ): self.storage = storage @@ -28,6 +32,7 @@ def __init__( self.file_references = file_references self.device_manager = device_manager self.info_storage = info_storage + self.beamline_states = beamline_states def get_storage_format(self) -> dict: """ @@ -105,6 +110,29 @@ def write_bec_entries(self) -> None: else: file_device.create_ext_link(name="data", target=msg.file_path, entry="/") + # create beamline states + states = {} + for state_name, state_values in self.beamline_states.items(): + dtype = np.dtype( + [ + ("label", h5py.string_dtype("utf-8")), + ("status", h5py.string_dtype("utf-8")), + ("timestamp", np.float64), + ] + ) + states[state_name] = np.array( + [ + (state_msg.label, state_msg.status, state_msg.timestamp) + for state_msg in state_values + ], + dtype=dtype, + ) + beamline_states_group = collection.create_group("states") + beamline_states_group.attrs["NX_class"] = "NXcollection" + for state_name, state_values in states.items(): + state_group = beamline_states_group.create_dataset(name=state_name, data=state_values) + state_group.attrs["NX_class"] = "NXcollection" + def format(self) -> None: """ Prepare the NeXus file format. diff --git a/bec_server/bec_server/file_writer/file_writer.py b/bec_server/bec_server/file_writer/file_writer.py index b589d5418..1ad1294a3 100644 --- a/bec_server/bec_server/file_writer/file_writer.py +++ b/bec_server/bec_server/file_writer/file_writer.py @@ -300,6 +300,7 @@ def write( info_storage=info_storage, configuration=configuration_data, file_references=data.file_references, + beamline_states=data.beamline_states, device_manager=self.file_writer_manager.device_manager, ).get_storage_format() diff --git a/bec_server/bec_server/file_writer/file_writer_manager.py b/bec_server/bec_server/file_writer/file_writer_manager.py index bb6e4fada..59bbbbec3 100644 --- a/bec_server/bec_server/file_writer/file_writer_manager.py +++ b/bec_server/bec_server/file_writer/file_writer_manager.py @@ -4,6 +4,7 @@ import threading import time import traceback +from collections import defaultdict from bec_lib import messages from bec_lib.alarm_handler import Alarms @@ -35,9 +36,10 @@ def __init__(self, scan_number: int, scan_id: str) -> None: self.status_msg: messages.ScanStatusMessage | None = None self.scan_segments = {} self.scan_finished = False - self.num_points = None + self.num_points: int | None = None self.baseline = {} - self.async_writer = None + self.async_writer: AsyncWriter | None = None + self.beamline_states: dict[str, list[messages.BeamlineStateMessage]] = defaultdict(list) self.metadata = {} self.file_references = {} self.start_time = None @@ -96,6 +98,9 @@ def __init__(self, config: ServiceConfig, connector_cls: type[RedisConnector]) - self.file_writer_config = self._service_config.config.get("file_writer") self._start_device_manager() self.device_configuration = {} + self.available_beamline_states: list[messages.BeamlineStateConfig] = [] + self.beamline_state_subscriptions: set[str] = set() + self.beamline_states: dict[str, messages.BeamlineStateMessage] = {} self.connector.register( MessageEndpoints.scan_segment(), cb=self._scan_segment_callback, parent=self ) @@ -107,8 +112,14 @@ def __init__(self, config: ServiceConfig, connector_cls: type[RedisConnector]) - cb=self._device_configuration_callback, parent=self, ) + self.connector.register( + MessageEndpoints.available_beamline_states(), + cb=self._update_available_beamline_states, + parent=self, + from_start=True, + ) self.async_writer = None - self.scan_storage = {} + self.scan_storage: dict[str, ScanStorage] = {} self.file_writer = HDF5FileWriter(self) self.status = messages.BECStatus.RUNNING self.refresh_device_configs() @@ -134,6 +145,61 @@ def _device_configuration_callback(msg, *, parent: FileWriterManager): device = topic.split("/")[-1] parent.update_device_configuration(device, msg) + @staticmethod + def _update_available_beamline_states( + msg: dict[str, messages.AvailableBeamlineStatesMessage], *, parent: FileWriterManager + ): + info = msg["data"] + parent.available_beamline_states = info.states + parent.update_beamline_state_subscriptions() + + def update_beamline_state_subscriptions(self): + """ + Update the beamline state subscriptions. + This method ensures that the file writer is subscribed to the beamline states that are currently available. + """ + current_subscriptions = set(self.beamline_state_subscriptions) + needed_subscriptions = set() + for state in self.available_beamline_states: + needed_subscriptions.add(state.name) + if state.name not in current_subscriptions: + self.connector.register( + MessageEndpoints.beamline_state(state.name), + cb=self._beamline_state_callback, + parent=self, + from_start=True, + ) + self.beamline_state_subscriptions.add(state.name) + # Unregister from states that are no longer needed + for topic in current_subscriptions - needed_subscriptions: + self.connector.unregister( + MessageEndpoints.beamline_state(topic), cb=self._beamline_state_callback + ) + self.beamline_state_subscriptions.remove(topic) + + @staticmethod + def _beamline_state_callback( + msg: dict[str, messages.BeamlineStateMessage], *, parent: FileWriterManager + ): + state_msg = msg["data"] + parent.update_beamline_state(state_msg) + + def update_beamline_state(self, state_msg: messages.BeamlineStateMessage): + """ + Update the beamline state in the file writer. + + Args: + state_msg (messages.BeamlineStateMessage): Beamline state message + """ + # We store the latest beamline state messages in the file writer so that + # they can be added to the file when writing, even if they do not change + # during the scan + self.beamline_states[state_msg.name] = state_msg + + for storage in self.scan_storage.values(): + if storage.metadata.get("status") == "open": + storage.beamline_states[state_msg.name].append(state_msg) + def _update_available_devices(self, *args) -> None: """ Update the available devices. @@ -173,6 +239,10 @@ def update_scan_storage_with_status(self, msg: messages.ScanStatusMessage) -> No scan_number=msg.content["info"].get("scan_number"), scan_id=scan_id ) + for state in self.beamline_states.values(): + if state.name not in self.scan_storage[scan_id].beamline_states: + self.scan_storage[scan_id].beamline_states[state.name].append(state) + # update the status message self.scan_storage[scan_id].status_msg = msg diff --git a/bec_server/bec_server/scan_server/beamline_state_manager.py b/bec_server/bec_server/scan_server/beamline_state_manager.py index f9b6f2e67..ecceb235b 100644 --- a/bec_server/bec_server/scan_server/beamline_state_manager.py +++ b/bec_server/bec_server/scan_server/beamline_state_manager.py @@ -75,4 +75,5 @@ def update_states(self, msg: messages.AvailableBeamlineStatesMessage) -> None: continue if state.config.model_dump() != state_msg.parameters: # The config has changed, we need to update the state + state.update_parameters(**state_msg.parameters) state.restart() From d3f3f1136f087119c9c5ccce8d142a604067fcc8 Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Mon, 23 Feb 2026 12:38:27 +0100 Subject: [PATCH 5/6] f - rename delete to remove --- bec_lib/bec_lib/bl_state_manager.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/bec_lib/bec_lib/bl_state_manager.py b/bec_lib/bec_lib/bl_state_manager.py index acb960c0b..9a75f9fbf 100644 --- a/bec_lib/bec_lib/bl_state_manager.py +++ b/bec_lib/bec_lib/bl_state_manager.py @@ -90,11 +90,11 @@ def get(self) -> BeamlineStateGet: msg = msg_container["data"] return {"status": msg.status, "label": msg.label} - def delete(self) -> None: + def remove(self) -> None: """ - Delete the current beamline state. + Remove the current beamline state. """ - self._manager.remove(self._state.name) + self._manager.delete(self._state.name) class BeamlineStateManager: @@ -175,11 +175,11 @@ def add(self, state: bl_states.BeamlineStateConfig) -> None: self._states[state.name] = state self._publish_states() - def remove(self, state_name: str) -> None: + def delete(self, state_name: str) -> None: """ - Remove a beamline state by name. + Delete a beamline state from the manager. Args: - state_name (str): The name of the state to remove. + state_name (str): The name of the state to delete. """ if state_name in self._states: del self._states[state_name] From c778b83b60da9c8c9c2723a45b9003413dc06bd1 Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Mon, 23 Feb 2026 12:47:36 +0100 Subject: [PATCH 6/6] f - fix tests --- bec_lib/tests/test_beamline_states.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bec_lib/tests/test_beamline_states.py b/bec_lib/tests/test_beamline_states.py index 3490a0e41..e230d4766 100644 --- a/bec_lib/tests/test_beamline_states.py +++ b/bec_lib/tests/test_beamline_states.py @@ -280,7 +280,7 @@ def test_client_get_returns_latest_status_message(self, state_manager): result = state_manager.shutter_open.get() assert result == {"status": "valid", "label": "ok"} - def test_add_and_remove_publish_updates(self, state_manager): + def test_add_and_delete_publish_updates(self, state_manager): state = bl_states.DeviceStateConfig( name="shutter_open", title="Shutter Open", device="shutter1" ) @@ -288,10 +288,10 @@ def test_add_and_remove_publish_updates(self, state_manager): state_manager.add(state) assert "shutter_open" in state_manager._states - state_manager.remove("shutter_open") + state_manager.delete("shutter_open") assert "shutter_open" not in state_manager._states - def test_client_delete_removes_state(self, state_manager): + def test_client_remove_state(self, state_manager): config = messages.BeamlineStateConfig( name="shutter_open", title="Shutter Open", @@ -301,7 +301,7 @@ def test_client_delete_removes_state(self, state_manager): update = messages.AvailableBeamlineStatesMessage(states=[config]) state_manager._on_state_update({"data": update}, parent=state_manager) - state_manager.shutter_open.delete() + state_manager.shutter_open.remove() assert "shutter_open" not in state_manager._states