diff --git a/bec_lib/bec_lib/bl_checks.py b/bec_lib/bec_lib/bl_checks.py deleted file mode 100644 index 8bdc7c3bf..000000000 --- a/bec_lib/bec_lib/bl_checks.py +++ /dev/null @@ -1,246 +0,0 @@ -""" -This module provides the BeamlineChecks class, which is used to perform beamline checks. It also provides the bl_check -decorator. -""" - -import builtins -import datetime -import functools -import threading -import time -from collections import deque -from uuid import uuid4 - -from typeguard import typechecked - -from bec_lib.bl_conditions import BeamlineCondition -from bec_lib.logger import bec_logger - -logger = bec_logger.logger - - -class BeamlineCheckError(Exception): - pass - - -class BeamlineCheckRepeat(Exception): - pass - - -def bl_check(fcn): - """Decorator to perform rpc calls.""" - - @functools.wraps(fcn) - def bl_check_wrapper(*args, **kwargs): - client = builtins.__dict__.get("bec") - bl_checks = client.bl_checks - _run_with_bl_checks(bl_checks, fcn, *args, **kwargs) - - return bl_check_wrapper - - -def _run_with_bl_checks(bl_checks, fcn, *args, **kwargs): - # pylint: disable=protected-access - chk = {"id": str(uuid4()), "fcn": fcn, "args": args, "kwargs": kwargs} - bl_checks._levels.append(chk) - nested_call = len(bl_checks._levels) > 1 - if bl_checks._is_paused and bl_checks._beamline_checks: - logger.warning( - "Beamline checks are currently paused. Use `bec.bl_checks.resume()` to reactivate them." - ) - try: - if nested_call: - # check if the beam was okay so far - if not bl_checks.beam_is_okay: - raise BeamlineCheckError("Beam is not okay.") - else: - bl_checks.reset() - bl_checks.wait_for_beamline_checks() - successful = False - while not successful: - try: - successful, res = _run_on_failure(bl_checks, fcn, *args, **kwargs) - - if not bl_checks.beam_is_okay: - successful = False - bl_checks.wait_for_beamline_checks() - except BeamlineCheckRepeat: - successful = False - return res - - finally: - bl_checks._levels.pop() - - -def _run_on_failure(bl_checks, fcn, *args, **kwargs) -> tuple: - try: - res = fcn(*args, **kwargs) - return (True, res) - except BeamlineCheckError: - bl_checks.wait_for_beamline_checks() - return (False, None) - - -class BeamlineChecks: - def __init__(self, client, *args, **kwargs): - super().__init__(*args, **kwargs) - self.client = client - self.send_to_scilog = True - self._beam_is_okay = True - self._beamline_checks = {} - self._stop_beam_check_event = threading.Event() - self._beam_check_thread = None - self._started = False - self._is_paused = False - self._check_msgs = [] - self._levels = deque() - - @typechecked - def register(self, check: BeamlineCondition): - """ - Register a beamline check. - - Args: - check (BeamlineCondition): The beamline check to register. - """ - self._beamline_checks[check.name] = check - setattr(self, check.name, check) - - def pause(self) -> None: - """ - Pause beamline checks. This will disable all checks. Use `resume` to - reactivate the checks. - """ - self._is_paused = True - - def resume(self) -> None: - """ - Resume all paused beamline checks. - """ - self._is_paused = False - - def available_checks(self) -> None: - """ - Print all available beamline checks - """ - for name, check in self._beamline_checks.items(): - enabled = f"ENABLED: {check.enabled}" - print(f"{name:<20} {enabled}") - - def disable_check(self, name: str) -> None: - """ - Disable a beamline check. - - Args: - name (str): The name of the beamline check to disable. - """ - if name not in self._beamline_checks: - raise ValueError(f"Beamline check {name} not registered.") - self._beamline_checks[name].enabled = False - - def enable_check(self, name: str) -> None: - """ - Enable a beamline check. - - Args: - name (str): The name of the beamline check to enable. - """ - if name not in self._beamline_checks: - raise ValueError(f"Beamline check {name} not registered.") - self._beamline_checks[name].enabled = True - - def disable_all_checks(self) -> None: - """ - Disable all beamline checks. - """ - for name in self._beamline_checks: - self.disable_check(name) - - def enable_all_checks(self) -> None: - """ - Enable all beamline checks. - """ - for name in self._beamline_checks: - self.enable_check(name) - - def _run_beamline_checks(self): - msgs = [] - for name, check in self._beamline_checks.items(): - if not check.enabled: - continue - if check.run(): - continue - msgs.append(check.on_failure_msg()) - self._beam_is_okay = False - return msgs - - def _check_beam(self): - while not self._stop_beam_check_event.wait(timeout=1): - self._check_msgs = self._run_beamline_checks() - - def start(self): - """Start the beamline checks.""" - if self._started: - return - self._beam_is_okay = True - - self._beam_check_thread = threading.Thread(target=self._check_beam, daemon=True) - self._beam_check_thread.start() - self._started = True - - def stop(self): - """Stop the beamline checks""" - if not self._started: - return - - self._stop_beam_check_event.set() - self._beam_check_thread.join() - - def reset(self): - self._beam_is_okay = True - self._check_msgs = [] - - @property - def beam_is_okay(self): - return self._beam_is_okay - - def _print_beamline_checks(self): - for msg in self._check_msgs: - logger.warning(msg) - - def wait_for_beamline_checks(self): - self._print_beamline_checks() - if self.send_to_scilog and not self.beam_is_okay: - self._send_to_scilog( - f"Beamline checks failed at {str(datetime.datetime.now())}: {''.join(self._check_msgs)}", - pen="red", - ) - - self._run_beamline_checks_until_okay() - - if self.send_to_scilog: - self._send_to_scilog( - f"Operation resumed at {str(datetime.datetime.now())}.", pen="green" - ) - - def _run_beamline_checks_until_okay(self): - while True: - if self._beam_status_is_okay(): - break - self._print_beamline_checks() - time.sleep(5) - - def _beam_status_is_okay(self) -> bool: - self._beam_is_okay = True - self._check_msgs = self._run_beamline_checks() - return self._beam_is_okay - - def _send_to_scilog(self, msg, pen="red"): - try: - msg = self.client.logbook.LogbookMessage() - msg.add_text(f"

{msg}

").add_tag( - ["BEC", "beam_check"] - ) - self.client.logbook.send_logbook_message(msg) - except Exception: - logger.warning("Failed to send update to SciLog.") diff --git a/bec_lib/bec_lib/bl_conditions.py b/bec_lib/bec_lib/bl_conditions.py deleted file mode 100644 index d1ae7fb98..000000000 --- a/bec_lib/bec_lib/bl_conditions.py +++ /dev/null @@ -1,89 +0,0 @@ -""" -This module contains classes for beamline checks, used to check the beamline status. -""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -if TYPE_CHECKING: # pragma: no cover - from bec_lib.device import Device - - -class BeamlineCondition(ABC): - """Abstract base class for beamline checks.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.enabled = True - - @property - @abstractmethod - def name(self) -> str: - """Return a name for the beamline check.""" - - @abstractmethod - def run(self) -> bool: - """Run the beamline check and return True if the beam is okay, False otherwise.""" - - @abstractmethod - def on_failure_msg(self) -> str: - """Return a message that will be displayed if the beamline check fails.""" - - -class ShutterCondition(BeamlineCondition): - """Check if the shutter is open.""" - - def __init__(self, shutter: Device): - super().__init__() - self.shutter = shutter - - @property - def name(self): - return "shutter" - - def run(self): - shutter_val = self.shutter.read(cached=True) - return shutter_val["value"].lower() == "open" - - def on_failure_msg(self): - return "Check beam failed: Shutter is closed." - - -class LightAvailableCondition(BeamlineCondition): - """Check if the light is available.""" - - def __init__(self, machine_status: Device): - super().__init__() - self.machine_status = machine_status - - @property - def name(self): - return "light_available" - - def run(self): - machine_status = self.machine_status.read(cached=True) - return machine_status["value"] in ["Light Available", "Light-Available"] - - def on_failure_msg(self): - return "Check beam failed: Light not available." - - -class FastOrbitFeedbackCondition(BeamlineCondition): - """Check if the fast orbit feedback is running.""" - - def __init__(self, sls_fast_orbit_feedback: Device): - super().__init__() - self.sls_fast_orbit_feedback = sls_fast_orbit_feedback - - @property - def name(self): - return "fast_orbit_feedback" - - def run(self): - fast_orbit_feedback = self.sls_fast_orbit_feedback.read(cached=True) - return fast_orbit_feedback["value"] == "running" - - def on_failure_msg(self): - return "Check beam failed: Fast orbit feedback is not running." diff --git a/bec_lib/bec_lib/bl_state_manager.py b/bec_lib/bec_lib/bl_state_manager.py new file mode 100644 index 000000000..9a75f9fbf --- /dev/null +++ b/bec_lib/bec_lib/bl_state_manager.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +import inspect +from inspect import Parameter, Signature +from typing import TYPE_CHECKING, TypedDict + +from pydantic import BaseModel +from rich.console import Console +from rich.table import Table + +from bec_lib import bl_states, messages +from bec_lib.endpoints import MessageEndpoints + +if TYPE_CHECKING: + from bec_lib.bl_states import BeamlineStateConfig + from bec_lib.client import BECClient + + +def build_signature_from_model(model: BaseModel, skip: set[str] | None = None) -> Signature: + """ + Build a function signature from a Pydantic model. The parameters of the signature will match the fields of the model. + """ + parameters = [] + skip = skip or set() + + for name, field in model.model_fields.items(): + if name in skip: + continue + annotation = field.annotation or inspect.Parameter.empty + parameters.append( + Parameter( + name=name, kind=Parameter.KEYWORD_ONLY, default=field.default, annotation=annotation + ) + ) + + return Signature(parameters) + + +class BeamlineStateGet(TypedDict): + """ + TypedDict for the return value of the get method of a beamline state client. + """ + + status: str + label: str + + +class BeamlineStateClientBase: + """Base class for beamline state clients.""" + + def __init__(self, manager: BeamlineStateManager, state: BeamlineStateConfig) -> None: + self._manager = manager + self._connector = manager._connector + self._state = state + self._skip_parameters = {"name"} + + # pylint: disable=unnecessary-lambda + self._run = lambda **kwargs: self._run_update(**kwargs) + self._update_signature() + + def _update_signature(self) -> None: + # Dynamically update the signature of the update_parameters method to match the parameters of the state config + setattr(self, "update_parameters", self._run) + setattr( + getattr(self, "update_parameters"), + "__signature__", + build_signature_from_model(self._state, skip=self._skip_parameters), + ) + + def _run_update(self, **kwargs) -> None: + if not kwargs: + return + if self._skip_parameters.intersection(kwargs): + raise ValueError(f"Invalid parameters: {self._skip_parameters.intersection(kwargs)}") + self._state = self._state.model_copy(update=kwargs) + self._manager._update_state(self._state) # pylint: disable=protected-access + + def get(self) -> BeamlineStateGet: + """ + Get the current status of the beamline state. Returns a dictionary with keys "status" and "label". + + Returns: + BeamlineStateGet: A dictionary containing the status and label of the beamline state. + """ + msg_container: dict[str, messages.BeamlineStateMessage] = self._connector.get_last( + MessageEndpoints.beamline_state(self._state.name) + ) + if not msg_container: + return {"status": "unknown", "label": "No state information available."} + msg = msg_container["data"] + return {"status": msg.status, "label": msg.label} + + def remove(self) -> None: + """ + Remove the current beamline state. + """ + self._manager.delete(self._state.name) + + +class BeamlineStateManager: + """Manager for beamline states.""" + + def __init__(self, client: BECClient) -> None: + self._client = client + self._connector = client.connector + self._states: dict[str, BeamlineStateConfig] = {} + self._connector.register( + MessageEndpoints.available_beamline_states(), + cb=self._on_state_update, + parent=self, + from_start=True, + ) + + @staticmethod + def _on_state_update(msg_dict: dict, *, parent: BeamlineStateManager, **_kwargs) -> None: + # type: ignore ; we know it's an AvailableBeamlineStatesMessage + msg: messages.AvailableBeamlineStatesMessage = msg_dict["data"] + parent._update_states(msg.states) # pylint: disable=protected-access + + def _update_state(self, state: BeamlineStateConfig) -> None: + if state.name in self._states: + self._states[state.name] = state + self._publish_states() + return + raise ValueError(f"State with name {state.name} not found") + + def _update_states(self, states: list[messages.BeamlineStateConfig]) -> None: + remove_state_names = set(self._states) - set(state.name for state in states) + + added_state_names = set(state.name for state in states) - set(self._states) + added_states = {state.name: state for state in states if state.name in added_state_names} + + for state_name in remove_state_names: + if hasattr(self, state_name): + delattr(self, state_name) + self._states.pop(state_name, None) + + for state_name, state in added_states.items(): + state_class = getattr(bl_states, state.state_type) + model_cls = state_class.CONFIG_CLASS + model_instance = model_cls(**state.parameters) + instance = BeamlineStateClientBase(manager=self, state=model_instance) + setattr(self, state.name, instance) + self._states[state.name] = model_instance + + def _publish_states(self) -> None: + bl_states_container = [ + messages.BeamlineStateConfig( + name=state.name, + title=state.title if state.title else state.name, + state_type=state.state_type, + parameters=state.model_dump(), + ) + for state in self._states.values() + ] + msg = messages.AvailableBeamlineStatesMessage(states=bl_states_container) + self._connector.xadd( + MessageEndpoints.available_beamline_states(), + {"data": msg}, + max_size=1, + approximate=False, + ) + + ########################## + ##### Public API ######### + ########################## + + def add(self, state: bl_states.BeamlineStateConfig) -> None: + """ + Add a new beamline state to the manager. + Args: + state (BeamlineStateConfig): The beamline state to add. + """ + + self._states[state.name] = state + self._publish_states() + + def delete(self, state_name: str) -> None: + """ + Delete a beamline state from the manager. + Args: + state_name (str): The name of the state to delete. + """ + if state_name in self._states: + del self._states[state_name] + self._publish_states() + + def show_all(self): + """ + Pretty print all beamline states using rich. + """ + + def _format_parameters(state_config: bl_states.BeamlineStateConfig) -> str: + parameter_dict = state_config.model_dump(exclude={"name", "title"}, exclude_none=True) + if not parameter_dict: + return "-" + return "\n".join(f"{key}={value}" for key, value in parameter_dict.items()) + + def _status_style(status_value: str) -> str: + status_styles = {"valid": "green3", "invalid": "red3", "warning": "yellow3"} + return status_styles.get(status_value.lower(), "grey50") + + console = Console() + table = Table(title="Beamline States", padding=(0, 1, 1, 1)) + table.add_column("Name", style="magenta", no_wrap=True) + table.add_column("Type", style="grey70") + table.add_column("Parameters", style="grey70") + table.add_column("Status") + table.add_column("Label") + + for state in self._states.values(): + params = _format_parameters(state) + status = ( + getattr(self, state.name).get() + if hasattr(self, state.name) + else {"status": "unknown", "label": "No state information available."} + ) + status_value = str(status.get("status", "")) + status_style = _status_style(status_value) + table.add_row( + str(state.name), + str(state.state_type), + str(params), + f"[{status_style}]{status_value}[/{status_style}]", + f"[{status_style}]{str(status.get('label', ''))}[/{status_style}]", + ) + + console.print(table) diff --git a/bec_lib/bec_lib/bl_states.py b/bec_lib/bec_lib/bl_states.py new file mode 100644 index 000000000..fb50daa4a --- /dev/null +++ b/bec_lib/bec_lib/bl_states.py @@ -0,0 +1,261 @@ +from __future__ import annotations + +import keyword +from abc import ABC, abstractmethod +from typing import ClassVar, Generic, Type, TypeVar + +from pydantic import BaseModel, field_validator + +from bec_lib import messages +from bec_lib.device import DeviceBase +from bec_lib.endpoints import MessageEndpoints +from bec_lib.redis_connector import MessageObject, RedisConnector + + +class BeamlineStateConfig(BaseModel): + """ + Base Configuration for a beamline state. + """ + + state_type: ClassVar[str] = "BeamlineState" + + name: str + title: str | None = None + + model_config = {"extra": "forbid", "arbitrary_types_allowed": True} + + @field_validator("name") + @classmethod + def validate_name(cls, v: str) -> str: + """ + Validate that the state name is a valid Python identifier and does not conflict with reserved method names. + """ + if not v.isidentifier(): + raise ValueError(f"State name '{v}' must be a valid Python identifier.") + if keyword.iskeyword(v): + raise ValueError(f"State name '{v}' cannot be a reserved Python keyword.") + if v in {"add", "remove", "show_all"}: + raise ValueError(f"State name '{v}' is reserved and cannot be used.") + return v + + +class DeviceStateConfig(BeamlineStateConfig): + """ + Configuration for a device-based beamline state. + """ + + state_type: ClassVar[str] = "DeviceState" + + device: DeviceBase | str + signal: DeviceBase | str | None = None + + @field_validator("device", "signal", mode="before") + @classmethod + def validate_device(cls, v: DeviceBase | str) -> str: + """ + Validate that the device is either a string or a DeviceBase instance. If it's a DeviceBase instance, return its name. + """ + if isinstance(v, DeviceBase): + return v.dotted_name + return v + + +class DeviceWithinLimitsStateConfig(DeviceStateConfig): + """ + Configuration for a device within limits beamline state. + """ + + state_type: ClassVar[str] = "DeviceWithinLimitsState" + + min_limit: float | None = None + max_limit: float | None = None + tolerance: float = 0.1 + + +C = TypeVar("C", bound=BeamlineStateConfig) +D = TypeVar("D", bound=DeviceStateConfig) + + +class BeamlineState(ABC, Generic[C]): + """Abstract base class for beamline states.""" + + CONFIG_CLASS: Type[C] + + def __init__( + self, config: C | None = None, redis_connector: RedisConnector | None = None, **kwargs + ) -> None: + self.config = config or self.CONFIG_CLASS(**kwargs) + self.connector = redis_connector + self._last_state: messages.BeamlineStateMessage | None = None + + def update_parameters(self, **kwargs) -> None: + """Update the configuration parameters of the state.""" + self.config = self.CONFIG_CLASS(**{**self.config.model_dump(), **kwargs}) + + @abstractmethod + def evaluate(self, *args, **kwargs) -> messages.BeamlineStateMessage | None: + """Evaluate the state and return its state.""" + + def start(self) -> None: + """Start monitoring the state if needed.""" + + def stop(self) -> None: + """Stop monitoring the state if needed.""" + + def restart(self) -> None: + """Restart the state monitoring.""" + self.stop() + self.start() + + def _emit_state(self, state_msg: messages.BeamlineStateMessage) -> None: + if self.connector is None: + return + is_different = ( + state_msg.model_dump(exclude={"timestamp"}) + != self._last_state.model_dump(exclude={"timestamp"}) + if self._last_state + else True + ) + if self._last_state is None: + is_different = True + if is_different: + self._last_state = state_msg + self.connector.xadd( + MessageEndpoints.beamline_state(self.config.name), + {"data": state_msg}, + max_size=1, + approximate=False, + ) + + +class DeviceBeamlineState(BeamlineState[D], Generic[D]): + """A beamline state that depends on a device reading.""" + + CONFIG_CLASS: Type[D] + + def __init__( + self, config: D | None = None, redis_connector: RedisConnector | None = None, **kwargs + ) -> None: + super().__init__(config, redis_connector, **kwargs) + + def start(self) -> None: + if self.connector is None: + raise RuntimeError("Redis connector is not set.") + msg = self.connector.get(MessageEndpoints.device_readback(self.config.device)) + if msg is not None: + self._update_device_state( + MessageObject( + topic=MessageEndpoints.device_readback(self.config.device).endpoint, value=msg + ), + parent=self, + ) + self.connector.register( + MessageEndpoints.device_readback(self.config.device), + cb=self._update_device_state, + parent=self, + ) + + def stop(self) -> None: + if self.connector is None: + return + self.connector.unregister( + MessageEndpoints.device_readback(self.config.device), cb=self._update_device_state + ) + + @staticmethod + def _update_device_state(msg_obj: MessageObject, parent: DeviceBeamlineState) -> None: + + # Since this is called from the Redis connector, we + assert parent.connector is not None + + msg: messages.DeviceMessage = msg_obj.value # type: ignore ; we know it's a DeviceMessage + out = parent.evaluate(msg) + if out is None: + return + parent._emit_state(out) + + +class ShutterState(DeviceBeamlineState[DeviceStateConfig]): + """ + A state that checks if the shutter is open. + + Example: + shutter_state = ShutterState(name="shutter_open") + shutter_state.configure(device="shutter1") + bec.beamline_states.add(shutter_state) + """ + + CONFIG_CLASS = DeviceStateConfig + + def evaluate( + self, msg: messages.DeviceMessage, *args, **kwargs + ) -> messages.BeamlineStateMessage: + val = msg.signals.get(self.config.signal, {}).get("value", "").lower() + if val == "open": + return messages.BeamlineStateMessage( + name=self.config.name, status="valid", label="Shutter is open." + ) + return messages.BeamlineStateMessage( + name=self.config.name, status="invalid", label="Shutter is closed." + ) + + +class DeviceWithinLimitsState(DeviceBeamlineState[DeviceWithinLimitsStateConfig]): + """ + A state that checks if a positioner is within limits. + + Example: + device_state = DeviceWithinLimitsState(name="sample_x_within_limits") + device_state.configure(device="sample_x", signal="sample_x_signal_name", min_limit=0.0, max_limit=10.0) + bec.beamline_states.add(device_state) + + """ + + CONFIG_CLASS = DeviceWithinLimitsStateConfig + + def evaluate( + self, msg: messages.DeviceMessage, *args, **kwargs + ) -> messages.BeamlineStateMessage: + """ + Evaluate if the positioner is within the defined limits. If it is outside the limits, + return an invalid state. Otherwise, return a valid state. If it is within 10% of the limits, + return a warning state. + """ + + if self.config.min_limit is None: + self.config.min_limit = float("-inf") + if self.config.max_limit is None: + self.config.max_limit = float("inf") + + signal_name = self.config.signal if self.config.signal is not None else self.config.device + + val = msg.signals.get(signal_name, {}).get("value", None) + if val is None: + return messages.BeamlineStateMessage( + name=self.config.name, + status="invalid", + label=f"Positioner {self.config.device} value not found.", + ) + + if val < self.config.min_limit or val > self.config.max_limit: + return messages.BeamlineStateMessage( + name=self.config.name, + status="invalid", + label=f"Positioner {self.config.device} out of limits", + ) + + min_warning_threshold = self.config.min_limit + self.config.tolerance + max_warning_threshold = self.config.max_limit - self.config.tolerance + + if val < min_warning_threshold or val > max_warning_threshold: + return messages.BeamlineStateMessage( + name=self.config.name, + status="warning", + label=f"Positioner {self.config.device} near limits", + ) + + return messages.BeamlineStateMessage( + name=self.config.name, + status="valid", + label=f"Positioner {self.config.device} within limits", + ) diff --git a/bec_lib/bec_lib/client.py b/bec_lib/bec_lib/client.py index 5a34ba8fe..7c9aaaa7b 100644 --- a/bec_lib/bec_lib/client.py +++ b/bec_lib/bec_lib/client.py @@ -20,7 +20,7 @@ from bec_lib.alarm_handler import AlarmHandler, Alarms from bec_lib.bec_service import BECService -from bec_lib.bl_checks import BeamlineChecks +from bec_lib.bl_state_manager import BeamlineStateManager from bec_lib.callback_handler import CallbackHandler, EventType from bec_lib.config_helper import ConfigHelperUser from bec_lib.dap_plugins import DAPPlugins @@ -150,7 +150,6 @@ def __init__( self._live_updates = None self.dap = None self.device_monitor = None - self.bl_checks = None self.scans_namespace = SimpleNamespace() self._hli_funcs = {} self.metadata = {} @@ -161,6 +160,7 @@ def __init__( self._initialized = True self._username = "" self._system_user = "" + self.beamline_states = None def __new__(cls, *args, forced=False, **kwargs): if forced or BECClient._client is None: @@ -235,10 +235,9 @@ def _start_services(self): self.config = ConfigHelperUser(self.device_manager) self.history = ScanHistory(client=self) self.dap = DAPPlugins(self) - self.bl_checks = BeamlineChecks(self) - self.bl_checks.start() self.device_monitor = DeviceMonitorPlugin(self.connector) self._update_username() + self.beamline_states = BeamlineStateManager(client=self) def alarms(self, severity=Alarms.WARNING): """get the next alarm with at least the specified severity""" @@ -328,8 +327,6 @@ def shutdown(self, per_thread_timeout_s: float | None = None): self.queue.shutdown() if self.alarm_handler: self.alarm_handler.shutdown() - if self.bl_checks: - self.bl_checks.stop() if self.history is not None: # pylint: disable=protected-access self.history._shutdown() diff --git a/bec_lib/bec_lib/endpoints.py b/bec_lib/bec_lib/endpoints.py index 89bd36d17..f99104650 100644 --- a/bec_lib/bec_lib/endpoints.py +++ b/bec_lib/bec_lib/endpoints.py @@ -1764,6 +1764,40 @@ def macro_update(): endpoint=endpoint, message_type=messages.MacroUpdateMessage, message_op=MessageOp.SEND ) + @staticmethod + def beamline_state(state_name: str): + """ + Endpoint for beamline state. This endpoint is used to publish the beamline state + using a messages.BeamlineStateMessage message. + + Args: + state_name (str): State name. + Returns: + EndpointInfo: Endpoint for beamline state. + """ + endpoint = f"{EndpointType.INFO.value}/beamline_state/{state_name}" + return EndpointInfo( + endpoint=endpoint, + message_type=messages.BeamlineStateMessage, + message_op=MessageOp.STREAM, + ) + + @staticmethod + def available_beamline_states(): + """ + Endpoint for updating the available beamline states. This endpoint is used to + publish beamline state updates using a messages.AvailableBeamlineStatesMessage message. + + Returns: + EndpointInfo: Endpoint for beamline state updates. + """ + endpoint = f"{EndpointType.INFO.value}/available_beamline_states" + return EndpointInfo( + endpoint=endpoint, + message_type=messages.AvailableBeamlineStatesMessage, + message_op=MessageOp.STREAM, + ) + @staticmethod def atlas_websocket_state(deployment_name: str, host_id: str): """ diff --git a/bec_lib/bec_lib/messages.py b/bec_lib/bec_lib/messages.py index e321a464e..2b561e68d 100644 --- a/bec_lib/bec_lib/messages.py +++ b/bec_lib/bec_lib/messages.py @@ -1953,3 +1953,49 @@ class GameLeaderboardMessage(BECMessage): msg_type: ClassVar[str] = "game_leaderboard_message" game_name: str leaderboard: list[GameScoreMessage] + + +class BeamlineStateMessage(BECMessage): + """ + Message for beamline state updates + + Args: + name (str): Name of the beamline state + status (Literal["valid", "invalid", "warning"]): Status of the beamline state + label (str): Description of the beamline state + """ + + msg_type: ClassVar[str] = "beamline_state_message" + name: str + status: Literal["valid", "invalid", "warning"] + label: str + timestamp: float = Field(default_factory=time.time) + + +class BeamlineStateConfig(BaseModel): + """ + Entry for beamline state update + + Args: + name (str): Name of the beamline state + title (str): Title of the beamline state + state_type (str): Type of the beamline state + parameters (dict, optional): Additional parameters for the state + """ + + name: str + title: str + state_type: str + parameters: dict = Field(default_factory=dict) + + +class AvailableBeamlineStatesMessage(BECMessage): + """ + Message for updating beamline states + + Args: + states (list[BeamlineStateConfig]): List of beamline state update entries + """ + + msg_type: ClassVar[str] = "beamline_state_update_message" + states: list[BeamlineStateConfig] diff --git a/bec_lib/tests/test_beamline_checks.py b/bec_lib/tests/test_beamline_checks.py deleted file mode 100644 index 4357bcc00..000000000 --- a/bec_lib/tests/test_beamline_checks.py +++ /dev/null @@ -1,167 +0,0 @@ -# pylint: skip-file -from unittest import mock - -import pytest - -from bec_lib.bl_checks import ( - BeamlineCheckError, - BeamlineChecks, - _run_on_failure, - _run_with_bl_checks, -) - - -def test_run_with_bl_checks(): - bl_checks = mock.MagicMock() - bl_checks._levels = [] - bl_checks._is_paused = False - _run_with_bl_checks(bl_checks, mock.MagicMock(), 1, 2, 3, a=4, b=5) - assert not bl_checks._levels - - -def test_bl_check_raises_on_failed_nested_calls(): - bl_checks = mock.MagicMock() - bl_checks._levels = [{"fcn": mock.MagicMock()}] - bl_checks._is_paused = False - bl_checks.beam_is_okay = False - with pytest.raises(BeamlineCheckError): - _run_with_bl_checks(bl_checks, mock.MagicMock(), 1, 2, 3, a=4, b=5) - - -def test_bl_check_run_on_failure(): - bl_checks = mock.MagicMock() - bl_checks._levels = [] - bl_checks._is_paused = False - bl_checks.beam_is_okay = False - fcn = mock.MagicMock() - fcn.side_effect = BeamlineCheckError - _run_on_failure(bl_checks, fcn, 1, 2, 3, a=4, b=5) - bl_checks.wait_for_beamline_checks.assert_called_once() - - -def test_bl_check_register(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - condition = mock.MagicMock() - condition.name = "test" - bl_check.register(condition) - assert bl_check.test == condition # pylint: disable=no-member - assert bl_check._beamline_checks["test"] == condition - - -def test_bl_check_pause(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check.pause() - assert bl_check._is_paused - - -def test_bl_check_resume(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._is_paused = True - bl_check.resume() - assert not bl_check._is_paused - - -def test_bl_check_reset(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._beam_is_okay = False - bl_check._check_msgs = ["test"] - bl_check.reset() - assert bl_check._beam_is_okay - assert not bl_check._check_msgs - - -def test_bl_check_disable_check(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._beamline_checks = {"test": mock.MagicMock()} - bl_check.disable_check("test") - assert not bl_check._beamline_checks["test"].enabled - - -def test_bl_check_enable_check(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._beamline_checks = {"test": mock.MagicMock()} - bl_check.enable_check("test") - assert bl_check._beamline_checks["test"].enabled - - -def test_bl_check_disable_all_checks(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._beamline_checks = {"test": mock.MagicMock()} - bl_check.disable_all_checks() - assert not bl_check._beamline_checks["test"].enabled - - -def test_bl_check_enable_all_checks(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._beamline_checks = {"test": mock.MagicMock()} - bl_check.enable_all_checks() - assert bl_check._beamline_checks["test"].enabled - - -def test_bl_check_run_beamline_checks(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._beamline_checks = {"test": mock.MagicMock()} - bl_check._beam_is_okay = True - bl_check._run_beamline_checks() - assert bl_check._beam_is_okay - bl_check._beamline_checks["test"].run.assert_called_once() - bl_check._beamline_checks["test"].on_failure_msg.assert_not_called() - - -def test_bl_check_send_to_scilog(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._send_to_scilog("test") - client.logbook.send_logbook_message.assert_called_once() - - -def test_bl_check_beam_status_is_okay(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - with mock.patch.object(bl_check, "_run_beamline_checks") as mock_run: - mock_run.return_value = [] - assert bl_check._beam_status_is_okay() is True - - -def test_bl_check_wait_for_beamline_checks(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._check_msgs = ["test"] - bl_check._beam_is_okay = False - with mock.patch.object(bl_check, "_print_beamline_checks") as mock_print: - with mock.patch.object(bl_check, "_send_to_scilog") as mock_send: - with mock.patch.object(bl_check, "_run_beamline_checks") as mock_run: - mock_run.return_value = ["test"] - bl_check.wait_for_beamline_checks() - mock_print.assert_called_once() - mock_send.assert_called() - - -def test_bl_check_start(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._started = False - bl_check._beam_is_okay = True - with mock.patch("bec_lib.bl_checks.threading.Thread") as mock_thread: - bl_check.start() - mock_thread.assert_called_once() - mock_thread.return_value.start.assert_called_once() - - -def test_bl_check_stop(): - client = mock.MagicMock() - bl_check = BeamlineChecks(client=client) - bl_check._started = True - bl_check._beam_is_okay = True - bl_check._beam_check_thread = mock.MagicMock() - bl_check.stop() - bl_check._beam_check_thread.join.assert_called_once() diff --git a/bec_lib/tests/test_beamline_states.py b/bec_lib/tests/test_beamline_states.py new file mode 100644 index 000000000..e230d4766 --- /dev/null +++ b/bec_lib/tests/test_beamline_states.py @@ -0,0 +1,317 @@ +from __future__ import annotations + +import inspect +from unittest import mock + +import pytest +from pydantic import BaseModel + +from bec_lib import bl_states, messages +from bec_lib.bl_state_manager import ( + BeamlineStateClientBase, + BeamlineStateManager, + build_signature_from_model, +) +from bec_lib.endpoints import MessageEndpoints +from bec_lib.redis_connector import MessageObject + + +@pytest.fixture +def state_manager(connected_connector): + client = mock.MagicMock() + client.connector = connected_connector + manager = BeamlineStateManager(client) + yield manager + + +class TestHelpers: + def test_build_signature_from_model(self): + class DemoConfig(BaseModel): + foo: int = 1 + bar: str = "abc" + + signature = build_signature_from_model(DemoConfig) + + assert list(signature.parameters) == ["foo", "bar"] + assert signature.parameters["foo"].kind == inspect.Parameter.KEYWORD_ONLY + assert signature.parameters["foo"].annotation is int + assert signature.parameters["bar"].default == "abc" + + +class TestConfigModels: + def test_beamline_state_config_valid_name(self): + config = bl_states.BeamlineStateConfig(name="shutter_open", title="Shutter") + assert config.name == "shutter_open" + + @pytest.mark.parametrize("invalid_name", ["state-name", "class", "add", "remove", "show_all"]) + def test_beamline_state_config_invalid_name(self, invalid_name): + with pytest.raises(ValueError): + bl_states.BeamlineStateConfig(name=invalid_name) + + def test_device_state_config_keeps_string_device_and_signal(self): + config = bl_states.DeviceStateConfig(name="state", device="samx", signal="samx") + assert config.device == "samx" + assert config.signal == "samx" + + +class TestBeamlineStateBase: + def test_beamline_state_initialization_and_update(self): + class ConcreteState(bl_states.BeamlineState[bl_states.BeamlineStateConfig]): + CONFIG_CLASS = bl_states.BeamlineStateConfig + + def evaluate(self, *args, **kwargs): + return messages.BeamlineStateMessage( + name=self.config.name, status="valid", label="ok" + ) + + state = ConcreteState(name="test_state") + + assert state.config.name == "test_state" + assert state.connector is None + assert state._last_state is None + + state.update_parameters(title="Test State") + assert state.config.title == "Test State" + + +class TestDeviceBeamlineState: + def test_start_requires_connector(self): + state = bl_states.ShutterState(name="shutter_open", device="shutter1", signal="shutter1") + + with pytest.raises(RuntimeError, match="Redis connector is not set"): + state.start() + + def test_start_registers_device_callback(self, connected_connector): + state = bl_states.ShutterState( + name="shutter_open", + device="shutter1", + signal="shutter1", + redis_connector=connected_connector, + ) + + with mock.patch.object(connected_connector, "register") as register: + state.start() + + register.assert_called_once_with( + MessageEndpoints.device_readback("shutter1"), + cb=state._update_device_state, + parent=state, + ) + + def test_stop_unregisters_device_callback(self, connected_connector): + state = bl_states.ShutterState( + name="shutter_open", + device="shutter1", + signal="shutter1", + redis_connector=connected_connector, + ) + + with mock.patch.object(connected_connector, "unregister") as unregister: + state.stop() + + unregister.assert_called_once_with( + MessageEndpoints.device_readback("shutter1"), cb=state._update_device_state + ) + + def test_update_device_state_publishes_when_state_changes(self, connected_connector): + state = bl_states.ShutterState( + name="shutter_open", + device="shutter1", + signal="shutter1", + redis_connector=connected_connector, + ) + + msg = messages.DeviceMessage( + signals={"shutter1": {"value": "open", "timestamp": 1.0}}, + metadata={"stream": "primary"}, + ) + msg_obj = MessageObject(value=msg, topic="test") + + state._update_device_state(msg_obj, parent=state) + + assert state._last_state is not None + assert state._last_state.status == "valid" + out = connected_connector.xread( + MessageEndpoints.beamline_state("shutter_open"), from_start=True + ) + assert out is not None + assert out[0]["data"].status == "valid" + + +class TestConcreteStates: + def test_shutter_state_open_and_closed(self, connected_connector): + state = bl_states.ShutterState( + name="shutter_open", + device="shutter1", + signal="shutter1", + redis_connector=connected_connector, + ) + + open_msg = messages.DeviceMessage( + signals={"shutter1": {"value": "OPEN", "timestamp": 1.0}}, + metadata={"stream": "primary"}, + ) + closed_msg = messages.DeviceMessage( + signals={"shutter1": {"value": "closed", "timestamp": 2.0}}, + metadata={"stream": "primary"}, + ) + + assert state.evaluate(open_msg).status == "valid" + assert state.evaluate(closed_msg).status == "invalid" + + def test_device_within_limits_state(self, connected_connector): + state = bl_states.DeviceWithinLimitsState( + name="sample_x_limits", + device="sample_x", + min_limit=0.0, + max_limit=10.0, + tolerance=0.1, + redis_connector=connected_connector, + ) + + valid = messages.DeviceMessage( + signals={"sample_x": {"value": 5.0, "timestamp": 1.0}}, metadata={"stream": "primary"} + ) + warning = messages.DeviceMessage( + signals={"sample_x": {"value": 0.05, "timestamp": 2.0}}, metadata={"stream": "primary"} + ) + invalid = messages.DeviceMessage( + signals={"sample_x": {"value": 11.0, "timestamp": 3.0}}, metadata={"stream": "primary"} + ) + missing = messages.DeviceMessage( + signals={"sample_x": {"timestamp": 4.0}}, metadata={"stream": "primary"} + ) + + assert state.evaluate(valid).status == "valid" + assert state.evaluate(warning).status == "warning" + assert state.evaluate(invalid).status == "invalid" + assert state.evaluate(missing).status == "invalid" + + +class TestBeamlineStateManager: + def test_manager_registers_for_state_updates(self, connected_connector): + client = mock.MagicMock() + client.connector = connected_connector + + with mock.patch.object(connected_connector, "register") as register: + BeamlineStateManager(client) + + register.assert_called_once_with( + MessageEndpoints.available_beamline_states(), + cb=mock.ANY, + parent=mock.ANY, + from_start=True, + ) + + def test_on_state_update_creates_client_attribute(self, state_manager): + config = messages.BeamlineStateConfig( + name="shutter_open", + title="Shutter Open", + state_type="ShutterState", + parameters={"name": "shutter_open", "title": "Shutter Open", "device": "shutter1"}, + ) + update = messages.AvailableBeamlineStatesMessage(states=[config]) + + state_manager._on_state_update({"data": update}, parent=state_manager) + + assert "shutter_open" in state_manager._states + assert isinstance(state_manager._states["shutter_open"], bl_states.DeviceStateConfig) + assert isinstance(getattr(state_manager, "shutter_open"), BeamlineStateClientBase) + + def test_update_parameters_from_client_updates_state_and_publishes(self, state_manager): + config = messages.BeamlineStateConfig( + name="limits", + title="Limits", + state_type="DeviceWithinLimitsState", + parameters={ + "name": "limits", + "title": "Limits", + "device": "samx", + "min_limit": 0.0, + "max_limit": 10.0, + }, + ) + update = messages.AvailableBeamlineStatesMessage(states=[config]) + state_manager._on_state_update({"data": update}, parent=state_manager) + + state_manager.limits.update_parameters(tolerance=0.25) + + assert state_manager._states["limits"].tolerance == 0.25 + + out = state_manager._connector.xread( + MessageEndpoints.available_beamline_states(), from_start=True + ) + assert out + assert isinstance(out[-1]["data"], messages.AvailableBeamlineStatesMessage) + + def test_client_get_returns_unknown_without_status_message(self, state_manager): + config = messages.BeamlineStateConfig( + name="shutter_open", + title="Shutter Open", + state_type="ShutterState", + parameters={"name": "shutter_open", "title": "Shutter Open", "device": "shutter1"}, + ) + update = messages.AvailableBeamlineStatesMessage(states=[config]) + state_manager._on_state_update({"data": update}, parent=state_manager) + + result = state_manager.shutter_open.get() + assert result == {"status": "unknown", "label": "No state information available."} + + def test_client_get_returns_latest_status_message(self, state_manager): + config = messages.BeamlineStateConfig( + name="shutter_open", + title="Shutter Open", + state_type="ShutterState", + parameters={"name": "shutter_open", "title": "Shutter Open", "device": "shutter1"}, + ) + update = messages.AvailableBeamlineStatesMessage(states=[config]) + state_manager._on_state_update({"data": update}, parent=state_manager) + + state_manager._connector.xadd( + MessageEndpoints.beamline_state("shutter_open"), + { + "data": messages.BeamlineStateMessage( + name="shutter_open", status="valid", label="ok" + ) + }, + max_size=1, + ) + + result = state_manager.shutter_open.get() + assert result == {"status": "valid", "label": "ok"} + + def test_add_and_delete_publish_updates(self, state_manager): + state = bl_states.DeviceStateConfig( + name="shutter_open", title="Shutter Open", device="shutter1" + ) + + state_manager.add(state) + assert "shutter_open" in state_manager._states + + state_manager.delete("shutter_open") + assert "shutter_open" not in state_manager._states + + def test_client_remove_state(self, state_manager): + config = messages.BeamlineStateConfig( + name="shutter_open", + title="Shutter Open", + state_type="ShutterState", + parameters={"name": "shutter_open", "title": "Shutter Open", "device": "shutter1"}, + ) + update = messages.AvailableBeamlineStatesMessage(states=[config]) + state_manager._on_state_update({"data": update}, parent=state_manager) + + state_manager.shutter_open.remove() + + assert "shutter_open" not in state_manager._states + + def test_show_all_prints_table(self, state_manager, capsys): + state = bl_states.DeviceStateConfig( + name="shutter_open", title="Shutter Open", device="shutter1" + ) + state_manager.add(state) + + state_manager.show_all() + + captured = capsys.readouterr() + assert "shutter_open" in (captured.out + captured.err) diff --git a/bec_lib/tests/test_bl_conditions.py b/bec_lib/tests/test_bl_conditions.py deleted file mode 100644 index 85e87df0b..000000000 --- a/bec_lib/tests/test_bl_conditions.py +++ /dev/null @@ -1,37 +0,0 @@ -from unittest import mock - -from bec_lib.bl_conditions import ( - FastOrbitFeedbackCondition, - LightAvailableCondition, - ShutterCondition, -) - - -def test_shutter_condition(): - device = mock.MagicMock() - shutter_condition = ShutterCondition(device) - shutter_condition.run() - device.read.assert_called_once() - assert shutter_condition.on_failure_msg() == "Check beam failed: Shutter is closed." - assert shutter_condition.name == "shutter" - - -def test_light_available_condition(): - device = mock.MagicMock() - light_available_condition = LightAvailableCondition(device) - light_available_condition.run() - device.read.assert_called_once() - assert light_available_condition.on_failure_msg() == "Check beam failed: Light not available." - assert light_available_condition.name == "light_available" - - -def test_fast_orbit_feedback_condition(): - device = mock.MagicMock() - fast_orbit_feedback_condition = FastOrbitFeedbackCondition(device) - fast_orbit_feedback_condition.run() - device.read.assert_called_once() - assert ( - fast_orbit_feedback_condition.on_failure_msg() - == "Check beam failed: Fast orbit feedback is not running." - ) - assert fast_orbit_feedback_condition.name == "fast_orbit_feedback" diff --git a/bec_server/bec_server/file_writer/default_writer.py b/bec_server/bec_server/file_writer/default_writer.py index 44b44c12c..f17d22b36 100644 --- a/bec_server/bec_server/file_writer/default_writer.py +++ b/bec_server/bec_server/file_writer/default_writer.py @@ -2,6 +2,9 @@ from typing import TYPE_CHECKING, Any +import h5py +import numpy as np + if TYPE_CHECKING: from bec_lib import messages from bec_lib.devicemanager import DeviceManagerBase @@ -20,6 +23,7 @@ def __init__( info_storage: dict, configuration: dict, file_references: dict[str, messages.FileMessage], + beamline_states: dict[str, list[messages.BeamlineStateMessage]], device_manager: DeviceManagerBase, ): self.storage = storage @@ -28,6 +32,7 @@ def __init__( self.file_references = file_references self.device_manager = device_manager self.info_storage = info_storage + self.beamline_states = beamline_states def get_storage_format(self) -> dict: """ @@ -105,6 +110,29 @@ def write_bec_entries(self) -> None: else: file_device.create_ext_link(name="data", target=msg.file_path, entry="/") + # create beamline states + states = {} + for state_name, state_values in self.beamline_states.items(): + dtype = np.dtype( + [ + ("label", h5py.string_dtype("utf-8")), + ("status", h5py.string_dtype("utf-8")), + ("timestamp", np.float64), + ] + ) + states[state_name] = np.array( + [ + (state_msg.label, state_msg.status, state_msg.timestamp) + for state_msg in state_values + ], + dtype=dtype, + ) + beamline_states_group = collection.create_group("states") + beamline_states_group.attrs["NX_class"] = "NXcollection" + for state_name, state_values in states.items(): + state_group = beamline_states_group.create_dataset(name=state_name, data=state_values) + state_group.attrs["NX_class"] = "NXcollection" + def format(self) -> None: """ Prepare the NeXus file format. diff --git a/bec_server/bec_server/file_writer/file_writer.py b/bec_server/bec_server/file_writer/file_writer.py index b589d5418..1ad1294a3 100644 --- a/bec_server/bec_server/file_writer/file_writer.py +++ b/bec_server/bec_server/file_writer/file_writer.py @@ -300,6 +300,7 @@ def write( info_storage=info_storage, configuration=configuration_data, file_references=data.file_references, + beamline_states=data.beamline_states, device_manager=self.file_writer_manager.device_manager, ).get_storage_format() diff --git a/bec_server/bec_server/file_writer/file_writer_manager.py b/bec_server/bec_server/file_writer/file_writer_manager.py index bb6e4fada..59bbbbec3 100644 --- a/bec_server/bec_server/file_writer/file_writer_manager.py +++ b/bec_server/bec_server/file_writer/file_writer_manager.py @@ -4,6 +4,7 @@ import threading import time import traceback +from collections import defaultdict from bec_lib import messages from bec_lib.alarm_handler import Alarms @@ -35,9 +36,10 @@ def __init__(self, scan_number: int, scan_id: str) -> None: self.status_msg: messages.ScanStatusMessage | None = None self.scan_segments = {} self.scan_finished = False - self.num_points = None + self.num_points: int | None = None self.baseline = {} - self.async_writer = None + self.async_writer: AsyncWriter | None = None + self.beamline_states: dict[str, list[messages.BeamlineStateMessage]] = defaultdict(list) self.metadata = {} self.file_references = {} self.start_time = None @@ -96,6 +98,9 @@ def __init__(self, config: ServiceConfig, connector_cls: type[RedisConnector]) - self.file_writer_config = self._service_config.config.get("file_writer") self._start_device_manager() self.device_configuration = {} + self.available_beamline_states: list[messages.BeamlineStateConfig] = [] + self.beamline_state_subscriptions: set[str] = set() + self.beamline_states: dict[str, messages.BeamlineStateMessage] = {} self.connector.register( MessageEndpoints.scan_segment(), cb=self._scan_segment_callback, parent=self ) @@ -107,8 +112,14 @@ def __init__(self, config: ServiceConfig, connector_cls: type[RedisConnector]) - cb=self._device_configuration_callback, parent=self, ) + self.connector.register( + MessageEndpoints.available_beamline_states(), + cb=self._update_available_beamline_states, + parent=self, + from_start=True, + ) self.async_writer = None - self.scan_storage = {} + self.scan_storage: dict[str, ScanStorage] = {} self.file_writer = HDF5FileWriter(self) self.status = messages.BECStatus.RUNNING self.refresh_device_configs() @@ -134,6 +145,61 @@ def _device_configuration_callback(msg, *, parent: FileWriterManager): device = topic.split("/")[-1] parent.update_device_configuration(device, msg) + @staticmethod + def _update_available_beamline_states( + msg: dict[str, messages.AvailableBeamlineStatesMessage], *, parent: FileWriterManager + ): + info = msg["data"] + parent.available_beamline_states = info.states + parent.update_beamline_state_subscriptions() + + def update_beamline_state_subscriptions(self): + """ + Update the beamline state subscriptions. + This method ensures that the file writer is subscribed to the beamline states that are currently available. + """ + current_subscriptions = set(self.beamline_state_subscriptions) + needed_subscriptions = set() + for state in self.available_beamline_states: + needed_subscriptions.add(state.name) + if state.name not in current_subscriptions: + self.connector.register( + MessageEndpoints.beamline_state(state.name), + cb=self._beamline_state_callback, + parent=self, + from_start=True, + ) + self.beamline_state_subscriptions.add(state.name) + # Unregister from states that are no longer needed + for topic in current_subscriptions - needed_subscriptions: + self.connector.unregister( + MessageEndpoints.beamline_state(topic), cb=self._beamline_state_callback + ) + self.beamline_state_subscriptions.remove(topic) + + @staticmethod + def _beamline_state_callback( + msg: dict[str, messages.BeamlineStateMessage], *, parent: FileWriterManager + ): + state_msg = msg["data"] + parent.update_beamline_state(state_msg) + + def update_beamline_state(self, state_msg: messages.BeamlineStateMessage): + """ + Update the beamline state in the file writer. + + Args: + state_msg (messages.BeamlineStateMessage): Beamline state message + """ + # We store the latest beamline state messages in the file writer so that + # they can be added to the file when writing, even if they do not change + # during the scan + self.beamline_states[state_msg.name] = state_msg + + for storage in self.scan_storage.values(): + if storage.metadata.get("status") == "open": + storage.beamline_states[state_msg.name].append(state_msg) + def _update_available_devices(self, *args) -> None: """ Update the available devices. @@ -173,6 +239,10 @@ def update_scan_storage_with_status(self, msg: messages.ScanStatusMessage) -> No scan_number=msg.content["info"].get("scan_number"), scan_id=scan_id ) + for state in self.beamline_states.values(): + if state.name not in self.scan_storage[scan_id].beamline_states: + self.scan_storage[scan_id].beamline_states[state.name].append(state) + # update the status message self.scan_storage[scan_id].status_msg = msg diff --git a/bec_server/bec_server/scan_server/beamline_state_manager.py b/bec_server/bec_server/scan_server/beamline_state_manager.py new file mode 100644 index 000000000..ecceb235b --- /dev/null +++ b/bec_server/bec_server/scan_server/beamline_state_manager.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import traceback + +from bec_lib import bl_states, messages +from bec_lib.alarm_handler import Alarms +from bec_lib.endpoints import MessageEndpoints +from bec_lib.messages import ErrorInfo +from bec_lib.redis_connector import RedisConnector + + +class BeamlineStateManager: + """Manager for beamline states.""" + + def __init__(self, connector: RedisConnector) -> None: + self.connector = connector + self._states: dict[str, bl_states.BeamlineState] = {} + self.connector.register( + MessageEndpoints.available_beamline_states(), + cb=self._handle_state_update, + parent=self, + from_start=True, + ) + + @staticmethod + def _handle_state_update(msg_dict: dict, *, parent: BeamlineStateManager, **_kwargs) -> None: + + msg: messages.AvailableBeamlineStatesMessage = msg_dict["data"] # type: ignore ; we know it's a AvailableBeamlineStatesMessage + try: + parent.update_states(msg) + except Exception as exc: + content = traceback.format_exc() + info = ErrorInfo( + exception_type=type(exc).__name__, + error_message=content, + compact_error_message="Error updating beamline states.", + ) + parent.connector.raise_alarm(severity=Alarms.WARNING, info=info) + + def update_states(self, msg: messages.AvailableBeamlineStatesMessage) -> None: + """ + Update the beamline states based on the received update message. + + Args: + msg (messages.AvailableBeamlineStatesMessage): The update message containing state updates. + """ + + # get the states that we need to remove + remove_state_names = set(self._states) - set(state.name for state in msg.states) + + added_state_names = set(state.name for state in msg.states) - set(self._states) + added_states = { + state.name: state for state in msg.states if state.name in added_state_names + } + + for state_name in remove_state_names: + if hasattr(self, state_name): + delattr(self, state_name) + self._states.pop(state_name, None) + + for state_name, state in added_states.items(): + state_class = getattr(bl_states, state.state_type) + if not issubclass(state_class, bl_states.BeamlineState): + raise ValueError(f"State type {state.state_type} not found in beamline states.") + model_cls = state_class.CONFIG_CLASS + model_instance = model_cls(**state.parameters) + state_instance = state_class(config=model_instance, redis_connector=self.connector) + state_instance.start() + self._states[state.name] = state_instance + + # Check if the config has changed for existing states and update them if needed + for state_msg in msg.states: + state = self._states.get(state_msg.name) + if state is None: + continue + if state.config.model_dump() != state_msg.parameters: + # The config has changed, we need to update the state + state.update_parameters(**state_msg.parameters) + state.restart() diff --git a/bec_server/bec_server/scan_server/scan_server.py b/bec_server/bec_server/scan_server/scan_server.py index 3af7de6dc..a63fac1c5 100644 --- a/bec_server/bec_server/scan_server/scan_server.py +++ b/bec_server/bec_server/scan_server/scan_server.py @@ -15,6 +15,7 @@ from bec_server.procedures.manager import ProcedureManager from bec_server.procedures.subprocess_worker import SubProcessWorker +from .beamline_state_manager import BeamlineStateManager from .scan_assembler import ScanAssembler from .scan_guard import ScanGuard from .scan_manager import ScanManager @@ -40,6 +41,8 @@ def __init__(self, config: ServiceConfig, connector_cls: type[RedisConnector]): use_subprocess_proc_worker=config.model.procedures.use_subprocess_worker ) self.status = messages.BECStatus.RUNNING + self.beamline_states = None + self._start_beamline_state_manager() def _start_device_manager(self): self.wait_for_service("DeviceServer") @@ -59,6 +62,9 @@ def _start_scan_assembler(self): def _start_scan_guard(self): self.scan_guard = ScanGuard(parent=self) + def _start_beamline_state_manager(self): + self.beamline_states = BeamlineStateManager(self.connector) + def _start_alarm_handler(self): self.connector.register(MessageEndpoints.alarm(), cb=self._alarm_callback, parent=self) diff --git a/bec_server/tests/tests_scan_server/test_beamline_state_manager.py b/bec_server/tests/tests_scan_server/test_beamline_state_manager.py new file mode 100644 index 000000000..5198220ba --- /dev/null +++ b/bec_server/tests/tests_scan_server/test_beamline_state_manager.py @@ -0,0 +1,127 @@ +import time +from unittest import mock + +import pytest + +from bec_lib import bl_states, messages +from bec_lib.endpoints import MessageEndpoints +from bec_server.scan_server import beamline_state_manager +from bec_server.scan_server.beamline_state_manager import BeamlineStateManager + + +@pytest.fixture +def state_manager(connected_connector): + manager = BeamlineStateManager(connected_connector) + yield manager + + +@pytest.fixture +def fake_bl_states(monkeypatch): + class FakeState(bl_states.BeamlineState[bl_states.DeviceStateConfig]): + CONFIG_CLASS = bl_states.DeviceStateConfig + + def __init__(self, config=None, redis_connector=None, **kwargs): + super().__init__(config=config, redis_connector=redis_connector, **kwargs) + self.started = False + self.restart_count = 0 + + def evaluate(self, *args, **kwargs): + return None + + def start(self): + self.started = True + + def restart(self): + self.restart_count += 1 + + monkeypatch.setattr(beamline_state_manager.bl_states, "ShutterState", FakeState) + return FakeState + + +def test_state_manager_fetches_states(): + """ + Test that the BeamlineStateManager fetches all available beamline states on initialization. + """ + + connector = mock.MagicMock() + state_manager = BeamlineStateManager(connector) + connector.register.assert_called_once_with( + MessageEndpoints.available_beamline_states(), + cb=state_manager._handle_state_update, + parent=state_manager, + from_start=True, + ) + + +@pytest.mark.timeout(5) +def test_state_manager_updates_states(state_manager, connected_connector, fake_bl_states): + """ + Test that the BeamlineStateManager updates its states correctly when receiving an update message. + """ + + # Initial state: no states + assert len(state_manager._states) == 0 + + msg = messages.AvailableBeamlineStatesMessage( + states=[ + messages.BeamlineStateConfig( + name="State1", + title="Shutter", + state_type="ShutterState", + parameters={"name": "State1", "title": "Shutter", "device": "shutter1"}, + ) + ] + ) + + connected_connector.xadd( + MessageEndpoints.available_beamline_states(), {"data": msg}, max_size=1 + ) + + # Give it some time to process + while len(state_manager._states) < 1: + time.sleep(0.1) + + msg = messages.AvailableBeamlineStatesMessage( + states=[ + messages.BeamlineStateConfig( + name="State1", + title="Shutter", + state_type="ShutterState", + parameters={"name": "State1", "title": "Shutter", "device": "shutter1"}, + ), + messages.BeamlineStateConfig( + name="State2", + title="Shutter2", + state_type="ShutterState", + parameters={"name": "State2", "title": "Shutter2", "device": "shutter2"}, + ), + ] + ) + + connected_connector.xadd( + MessageEndpoints.available_beamline_states(), {"data": msg}, max_size=1 + ) + + # Give it some time to process + while len(state_manager._states) < 2: + time.sleep(0.1) + + msg = messages.AvailableBeamlineStatesMessage( + states=[ + messages.BeamlineStateConfig( + name="State2", + title="Shutter2", + state_type="ShutterState", + parameters={"name": "State2", "title": "Shutter2", "device": "shutter2"}, + ) + ] + ) + connected_connector.xadd( + MessageEndpoints.available_beamline_states(), {"data": msg}, max_size=1 + ) + # Give it some time to process + while len(state_manager._states) > 1: + time.sleep(0.1) + + assert len(state_manager._states) == 1 + assert "State2" in state_manager._states