from __future__ import annotations

from typing import Any, Callable

from bec_lib.logger import bec_logger

from .plugins import BECLiveDataPlugin, DataAPIPlugin

logger = bec_logger.logger


class DataSubscription:
    """
    Manage synchronized data updates for a set of device/signal pairs of one scan.

    Device/signal pairs can be added and removed dynamically; ``reload`` forces a
    resubscription of all pairs. Cleanup happens automatically when the object is
    garbage collected, when it is used as a context manager, or when
    :meth:`close` is called explicitly.

    Example:
        >>> subscription = data_api.create_subscription(scan_id="my_scan")
        >>> subscription.add_device("samx", "samx")
        >>> subscription.add_device("detector1", "async_sig1")
        >>> subscription.set_callback(my_callback_function)
        >>> # Later: update the device list
        >>> subscription.remove_device("samx", "samx")
        >>> subscription.add_device("samy", "samy")
        >>> # Reload all data
        >>> subscription.reload()
        >>> # Cleanup happens automatically, or explicitly:
        >>> subscription.close()
    """

    def __init__(self, data_api: DataAPI, scan_id: str, buffered: bool = False):
        """
        Initialize a data subscription.

        Args:
            data_api: The DataAPI instance that manages this subscription.
            scan_id: Identifier for the scan.
            buffered: If True, re-emit the entire accumulated data buffer on each
                update. If False (default), only emit new synchronized data blocks.
        """
        self._data_api = data_api
        self._scan_id = scan_id
        # (device_name, device_entry) -> subscription id; None while no callback is set
        self._devices: dict[tuple[str, str], str | None] = {}
        self._callback: Callable[[dict, dict], Any] | None = None
        self._user_callback: Callable[[dict, dict], Any] | None = None
        self._is_closed = False
        self._buffered = buffered
        # device_name -> device_entry -> list of {value, timestamp} points
        self._data_buffer: dict[str, dict[str, list[dict]]] = {}

    @property
    def scan_id(self) -> str:
        """Get the scan ID for this subscription."""
        return self._scan_id

    @property
    def devices(self) -> list[tuple[str, str]]:
        """Get the list of subscribed (device_name, device_entry) pairs."""
        return list(self._devices.keys())

    @property
    def buffered(self) -> bool:
        """Get whether this subscription is in buffered mode."""
        return self._buffered

    def set_buffered(self, buffered: bool) -> DataSubscription:
        """
        Change the buffering mode of the subscription.

        Args:
            buffered: If True, re-emit the entire accumulated data buffer on each
                update. If False, only emit new synchronized data blocks.

        Returns:
            self for method chaining.

        Raises:
            RuntimeError: If the subscription has been closed.
        """
        if self._is_closed:
            raise RuntimeError("Cannot change buffered mode on a closed subscription")

        if buffered != self._buffered:
            self._buffered = buffered
            if not buffered:
                # Leaving buffered mode invalidates the accumulated history.
                self._data_buffer.clear()

        return self

    def _buffering_callback(self, data: dict, metadata: dict) -> None:
        """
        Internal wrapper around the user callback implementing the buffering logic.

        In non-buffered mode the update is forwarded unchanged. In buffered mode
        the update is appended to the per-device buffer and the entire
        accumulated buffer is re-emitted.

        Args:
            data: Data dictionary from the plugin (device -> entry -> point).
            metadata: Metadata dictionary from the plugin.
        """
        if self._user_callback is None:
            return

        if not self._buffered:
            # Pass through directly without buffering.
            self._user_callback(data, metadata)
            return

        for device_name, device_data in data.items():
            device_buffer = self._data_buffer.setdefault(device_name, {})
            for device_entry, signal_data in device_data.items():
                device_buffer.setdefault(device_entry, []).append(signal_data)

        # Emit shallow copies of the point lists so the receiver cannot corrupt
        # the internal buffer by mutating what it is handed.
        buffered_data = {
            device_name: {entry: list(points) for entry, points in entries.items()}
            for device_name, entries in self._data_buffer.items()
        }
        self._user_callback(buffered_data, metadata)

    def set_scan_id(self, scan_id: str) -> DataSubscription:
        """
        Update the scan ID and resubscribe all devices to the new scan.

        Args:
            scan_id: New scan identifier.

        Returns:
            self for method chaining.

        Raises:
            RuntimeError: If the subscription has been closed.
        """
        if self._is_closed:
            raise RuntimeError("Cannot change scan_id on a closed subscription")

        if scan_id == self._scan_id:
            return self

        old_scan_id = self._scan_id
        self._scan_id = scan_id

        # Data from the previous scan must not leak into the new one.
        self._data_buffer.clear()

        if self._devices and self._callback is not None:
            logger.info(
                f"Changing scan_id from {old_scan_id} to {scan_id}, resubscribing all devices"
            )
            self._resubscribe_all()

        return self

    def set_callback(self, callback: Callable[[dict, dict], Any]) -> DataSubscription:
        """
        Set or update the callback function for data updates.

        Setting the first callback also subscribes any devices that were added
        before a callback was available (they are queued by :meth:`add_device`).

        Args:
            callback: Function called as ``callback(data, metadata)``. In
                non-buffered mode it receives individual synchronized data blocks;
                in buffered mode it receives the entire accumulated buffer.

        Returns:
            self for method chaining.

        Raises:
            RuntimeError: If the subscription has been closed.
        """
        if self._is_closed:
            raise RuntimeError("Cannot set callback on a closed subscription")

        old_user_callback = self._user_callback
        first_callback = self._callback is None
        self._user_callback = callback
        # Plugins always invoke the buffering wrapper, which forwards to the
        # current user callback.
        self._callback = self._buffering_callback

        # Subscribe queued devices on the first callback, or resubscribe when the
        # user callback actually changed. (Previously, devices added before the
        # first set_callback were never subscribed at all.)
        if self._devices and (first_callback or old_user_callback != callback):
            self._resubscribe_all()

        return self

    def add_device(self, device_name: str, device_entry: str) -> DataSubscription:
        """
        Add a device/signal pair to the synchronized subscription.

        Args:
            device_name: Name of the device.
            device_entry: Specific entry/signal of the device.

        Returns:
            self for method chaining.

        Raises:
            RuntimeError: If the subscription has been closed.
        """
        if self._is_closed:
            raise RuntimeError("Cannot add device to a closed subscription")

        key = (device_name, device_entry)
        if key in self._devices:
            logger.debug(f"Device {device_name}/{device_entry} already subscribed")
            return self

        if self._callback is None:
            # No callback yet: queue the device; it is subscribed by set_callback.
            self._devices[key] = None
            logger.debug(f"Device {device_name}/{device_entry} queued, waiting for callback")
        else:
            sub_id = self._data_api.subscribe(
                device_name, device_entry, self._scan_id, self._callback
            )
            self._devices[key] = sub_id

        return self

    def remove_device(self, device_name: str, device_entry: str) -> DataSubscription:
        """
        Remove a device/signal pair from the subscription.

        Args:
            device_name: Name of the device.
            device_entry: Specific entry/signal of the device.

        Returns:
            self for method chaining.

        Raises:
            RuntimeError: If the subscription has been closed.
        """
        if self._is_closed:
            raise RuntimeError("Cannot remove device from a closed subscription")

        sub_id = self._devices.pop((device_name, device_entry), None)
        if sub_id is not None:
            self._data_api.unsubscribe(subscription_id=sub_id)

        return self

    def reload(self) -> DataSubscription:
        """
        Reload data for all subscribed devices by resubscribing.

        Returns:
            self for method chaining.

        Raises:
            RuntimeError: If the subscription has been closed.
        """
        if self._is_closed:
            raise RuntimeError("Cannot reload a closed subscription")

        if self._callback is None:
            logger.warning("Cannot reload without a callback set")
            return self

        self._resubscribe_all()
        return self

    def close(self) -> None:
        """
        Close the subscription and unsubscribe from all devices.

        Idempotent; called automatically when the object is destroyed.
        """
        if self._is_closed:
            return

        for sub_id in self._devices.values():
            if sub_id is not None:
                self._data_api.unsubscribe(subscription_id=sub_id)

        self._devices.clear()
        self._callback = None
        self._user_callback = None
        self._data_buffer.clear()
        self._is_closed = True

    def _resubscribe_all(self) -> None:
        """Resubscribe all devices (used on callback change, reload, scan change)."""
        if self._callback is None:
            return

        for sub_id in self._devices.values():
            if sub_id is not None:
                self._data_api.unsubscribe(subscription_id=sub_id)

        for device_name, device_entry in list(self._devices.keys()):
            sub_id = self._data_api.subscribe(
                device_name, device_entry, self._scan_id, self._callback
            )
            self._devices[(device_name, device_entry)] = sub_id

    def __del__(self):
        """Best-effort cleanup when the object is garbage collected."""
        try:
            self.close()
        except Exception:
            # During interpreter shutdown the data_api may already be torn down.
            pass

    def __enter__(self) -> DataSubscription:
        """Support context manager protocol."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Support context manager protocol; never suppresses exceptions."""
        self.close()
        return False
class DataAPI:
    """
    Entry point that routes data subscriptions to registered plugins.

    The class is a singleton: the first instantiation creates the global
    instance; later instantiations return it unchanged (their ``client``
    argument is ignored).
    """

    _instance: DataAPI | None = None

    def __new__(cls, client):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, client):
        # The shared instance must not be re-initialized on repeat construction.
        if hasattr(self, "_initialized"):
            return

        self.client = client
        self.plugins: list[DataAPIPlugin] = []
        self._initialized = True
        self.register_plugin(BECLiveDataPlugin(self.client))

    @classmethod
    def clear_instance(cls) -> None:
        """Clear the singleton instance. Useful for testing."""
        cls._instance = None

    def register_plugin(self, plugin: DataAPIPlugin) -> None:
        """Register a new plugin and keep the plugin list sorted by priority."""
        plugin.connect()
        self.plugins.append(plugin)
        # Lower priority value = asked first; plugins without info default to 100.
        self.plugins.sort(key=lambda p: p.get_info().get("priority", 100))

    def create_subscription(self, scan_id: str, buffered: bool = False) -> DataSubscription:
        """
        Create a new subscription object for synchronized data updates.

        This is the recommended way to subscribe to data updates as it provides
        automatic lifecycle management and synchronization across multiple devices.

        Args:
            scan_id: Identifier for the scan.
            buffered: If True, re-emit the entire accumulated data buffer on each
                update. If False (default), only emit new synchronized data blocks.

        Returns:
            A DataSubscription object that manages the subscription lifecycle.

        Example:
            >>> sub = data_api.create_subscription("my_scan")
            >>> sub.add_device("samx", "samx").add_device("detector1", "async_sig1")
            >>> sub.set_callback(my_callback)
            >>> sub.close()  # or: with data_api.create_subscription(...) as sub:
        """
        return DataSubscription(self, scan_id, buffered=buffered)

    def subscribe(
        self,
        device_name: str,
        device_entry: str,
        scan_id: str,
        callback: Callable[[dict, dict], Any],
    ) -> str | None:
        """
        Subscribe to data updates for a specific device and entry.

        Args:
            device_name: Name of the device.
            device_entry: Specific entry of the device.
            scan_id: Identifier for the scan.
            callback: Function to call on data update.

        Returns:
            The subscription ID from the first plugin that can provide the data,
            or None if no plugin can provide it.
        """
        for plugin in self.plugins:
            if plugin.can_provide(device_name, device_entry, scan_id):
                return plugin.subscribe(device_name, device_entry, scan_id, callback)
        logger.warning(
            f"No plugin available to provide data for device '{device_name}', entry '{device_entry}', scan_id '{scan_id}'"
        )
        return None

    def unsubscribe(
        self,
        subscription_id: str | None = None,
        scan_id: str | None = None,
        callback: Callable[[dict, dict], Any] | None = None,
    ) -> None:
        """
        Unsubscribe from data updates by either subscription ID, scan ID and callback, or both.

        Args:
            subscription_id: The ID of the subscription to cancel.
            scan_id: Identifier for the scan.
            callback: Function that was used for subscription.
        """
        # Every plugin ignores criteria it has no matching subscription for.
        for plugin in self.plugins:
            plugin.unsubscribe(subscription_id, scan_id, callback)
from __future__ import annotations

import uuid
import weakref
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, Callable, Literal, Tuple

import louie
from pydantic import BaseModel, ConfigDict

from bec_lib import messages
from bec_lib.client import BECClient
from bec_lib.endpoints import MessageEndpoints
from bec_lib.messages import DeviceAsyncUpdate

CallbackRef = louie.saferef.BoundMethodWeakref | weakref.ReferenceType[Callable[[dict, dict], Any]]


class DataAPIPlugin(ABC):
    """Abstract interface every DataAPI data provider must implement."""

    def connect(self) -> None:
        """Set up any resources the plugin needs. No-op by default."""

    def disconnect(self) -> None:
        """Tear down the plugin's resources. No-op by default."""

    @abstractmethod
    def has_scan_data(self, scan_id: str) -> bool:
        """
        Report whether this plugin holds any data for the given scan.

        Args:
            scan_id: Identifier for the scan.

        Returns:
            True if the plugin has data for the scan ID, False otherwise.
        """

    @abstractmethod
    def can_provide(self, device_name: str, device_entry: str, scan_id: str) -> bool:
        """
        Report whether this plugin can serve data for a device entry of a scan.

        Args:
            device_name: Name of the device.
            device_entry: Specific entry of the device.
            scan_id: Identifier for the scan.

        Returns:
            True if the plugin can provide the data, False otherwise.
        """

    def get_info(self) -> dict:
        """Return plugin metadata such as name and priority. Empty by default."""
        return {}

    @abstractmethod
    def subscribe(
        self,
        device_name: str,
        device_entry: str,
        scan_id: str,
        callback: Callable[[dict, dict], Any],
    ) -> str:
        """
        Subscribe to data updates.

        Args:
            device_name: Name of the device.
            device_entry: Specific entry of the device.
            scan_id: Identifier for the scan.
            callback: Called as ``callback(data, metadata)`` on every update.

        Returns:
            A unique subscription ID.
        """

    @abstractmethod
    def unsubscribe(
        self,
        subscription_id: str | None = None,
        scan_id: str | None = None,
        callback: Callable[[dict, dict], Any] | None = None,
    ) -> None:
        """
        Cancel subscriptions by subscription ID, scan ID, callback, or a combination.

        Args:
            subscription_id: The ID of the subscription to cancel.
            scan_id: Identifier for the scan.
            callback: Function that was used for subscription.
        """


class _MonitoredSubscription(BaseModel):
    """One callback's group of monitored devices within a scan."""

    scan_id: str
    callback_ref: CallbackRef
    devices: list[Tuple[str, str]]  # (device_name, device_entry) pairs

    model_config = ConfigDict(arbitrary_types_allowed=True)


class _AsyncSubscription(BaseModel):
    """One async-signal redis registration shared by multiple callbacks."""

    scan_id: str
    device_name: str
    device_entry: str
    callback_refs: list[CallbackRef]
    connector_id: Any  # handle returned by client.connector.register

    model_config = ConfigDict(arbitrary_types_allowed=True)


class _DataBuffer(BaseModel):
    """Per-device buffer holding updates until all sources are synchronized."""

    device_name: str
    device_entry: str
    data: list[dict]  # data points, each with value and timestamp
    source_type: Literal["monitored", "async_signal"]

    model_config = ConfigDict(arbitrary_types_allowed=True)


class _CallbackBuffer(BaseModel):
    """Buffered data for one callback across all of its subscribed devices."""

    callback_ref: CallbackRef
    scan_id: str
    buffers: dict[tuple[str, str], _DataBuffer]  # (device_name, device_entry) -> buffer
    min_length: int = 0  # shortest buffer length across all devices of this callback
    # last processed live-data index per monitored device
    monitored_indices: dict[tuple[str, str], int] = {}

    model_config = ConfigDict(arbitrary_types_allowed=True)


class _SubscriptionInfo(BaseModel):
    """Bookkeeping record for a single subscription ID."""

    subscription_id: str
    scan_id: str
    device_name: str
    device_entry: str
    callback_ref: CallbackRef
    subscription_type: Literal["monitored", "async_signal"]

    model_config = ConfigDict(arbitrary_types_allowed=True)


class BECLiveDataPlugin(DataAPIPlugin):
    """
    Plugin to access live BEC data.

    Provides real-time data from the BEC client if available. It fetches live
    data from the storage of the BEC client as well as from async updates.
    """
+ """ + + def __init__(self, client: BECClient): + self.client = client + # Subscription tracking: sub_id -> subscription info + self._subscriptions: dict[str, _SubscriptionInfo] = {} + # Scan-level grouping for monitored devices: scan_id -> {callback_ref -> devices} + self._monitored_subscriptions: dict[str, dict[CallbackRef, _MonitoredSubscription]] = {} + # Async signal grouping: (scan_id, device_name, device_entry) -> _AsyncSubscription + self._async_subscriptions: dict[tuple[str, str, str], _AsyncSubscription] = {} + # Data buffers for synchronization: callback_ref -> _CallbackBuffer + self._callback_buffers: dict[CallbackRef, _CallbackBuffer] = {} + self._connect_id = None + + def connect(self): + """Connect to client signals for live data updates.""" + self._connect_id = self.client.callbacks.register( + "scan_segment", self._handle_scan_segment_update + ) + + def disconnect(self): + """Disconnect from client signals.""" + if self._connect_id is not None: + self.client.callbacks.remove(self._connect_id) + self._connect_id = None + + # Unregister all async signal subscriptions from redis connector + for async_sub in self._async_subscriptions.values(): + self.client.connector.unregister(async_sub.connector_id) + self._async_subscriptions.clear() + + def has_scan_data(self, scan_id: str) -> bool: + """ + Check if live data is available for the given scan ID. + + Args: + scan_id: Identifier for the scan. + Returns: + True if live data is available, False otherwise. 
+ """ + if not self.client.started: + return False + if self.client.queue is None: + return False + + scan_item = self.client.queue.scan_storage.find_scan_by_ID(scan_id) + if scan_item is None: + return False + + if scan_item.status in ["closed", "aborted", "halted"]: + # We skip closed scans and instead rely on historical data plugin + return False + return True + + def can_provide(self, device_name: str, device_entry: str, scan_id: str) -> bool: + """ + Check if live data is available for the given device and entry. + + Args: + device_name: Name of the device. + device_entry: Specific entry of the device. + scan_id: Identifier for the scan. + + Returns: + True if live data is available, False otherwise. + """ + mode = self._get_device_mode(device_name, device_entry, scan_id) + return mode is not None + + @lru_cache(maxsize=128) + def _get_device_mode( + self, device_name: str, device_entry: str, scan_id: str + ) -> Literal["monitored", "async_signal", None]: + """ + Get the mode of the device entry for the given scan ID. + As the mode does not change during a scan, we cache the results for performance. + + Args: + device_name: Name of the device. + device_entry: Specific entry of the device. + scan_id: Identifier for the scan. + + Returns: + "monitored" if live data is available as monitored device, + "async_signal" if live data is available as async signal, + None otherwise. 
+ """ + # Pre-checks; mostly for type checks + if not self.client.started or self.client.queue is None: + return None + + scan_item = self.client.queue.scan_storage.find_scan_by_ID(scan_id) + if scan_item is None: + return None + + if self._device_entry_is_monitored(device_name, device_entry, scan_item): + return "monitored" + + if self._device_entry_is_async_signal(device_name, device_entry): + return "async_signal" + return None + + def _device_entry_is_monitored(self, device_name: str, device_entry: str, scan_item) -> bool: + """ + Check if the device entry is a monitored devices in the scan item. + + Args: + device_name: Name of the device. + device_entry: Specific entry of the device. + scan_item: The scan item to check against. + + Returns: + True if the device entry is monitored, False otherwise. + """ + if scan_item.status_message is None: + return False + + readout_priority = scan_item.status_message.readout_priority or {} + if device_name in readout_priority.get("monitored", []): + return True + + # FIXME: we should also check that the device_entry is actually part of the monitored device + return False + + def _device_entry_is_async_signal(self, device_name: str, device_entry: str) -> bool: + """ + Check if the device entry is an async signal. + + Args: + device_name: Name of the device. + device_entry: Specific entry of the device. + Returns: + True if the device entry is an async signal, False otherwise. + """ + if not self.client.device_manager: + return False + async_signals = self.client.device_manager.get_bec_signals("AsyncSignal") + for entry_name, _, entry_data in async_signals: + if entry_name == device_entry and entry_data.get("device_name") == device_name: + return True + return False + + def subscribe( + self, + device_name: str, + device_entry: str, + scan_id: str, + callback: Callable[[dict, dict], Any], + ) -> str: + """ + Subscribe to live data updates for the given device and entry. + + Args: + device_name: Name of the device. 
+ device_entry: Specific entry of the device. + scan_id: Identifier for the scan. + callback: Function to call on data update. The function should accept two dicts: + one for the data and one for the metadata. + Returns: + A unique subscription ID. + """ + + match self._get_device_mode(device_name, device_entry, scan_id): + case "monitored": + return self._subscribe_to_monitored_device( + device_name, device_entry, scan_id, callback + ) + + case "async_signal": + return self._subscribe_to_async_signal(device_name, device_entry, scan_id, callback) + case None: + raise ValueError( + f"Cannot subscribe to device '{device_name}' entry '{device_entry}' for scan '{scan_id}'." + ) + case _: + raise ValueError( + f"Cannot subscribe to device '{device_name}' entry '{device_entry}' for scan '{scan_id}': unknown mode." + ) + + def unsubscribe( + self, + subscription_id: str | None = None, + scan_id: str | None = None, + callback: Callable[[dict, dict], Any] | None = None, + ) -> None: + """ + Unsubscribe from live data updates by either subscription ID, scan ID and callback, or both. + + Args: + subscription_id: The ID of the subscription to cancel. + scan_id: Identifier for the scan. + callback: Function that was used for subscription. 
+ """ + + if subscription_id is not None: + self._unsubscribe_by_id(subscription_id) + return + if scan_id is not None and callback is not None: + # find all subscriptions matching scan_id and callback + callback_ref = louie.saferef.safe_ref(callback) + to_remove = [] + for sub_id, sub_info in self._subscriptions.items(): + if sub_info.scan_id == scan_id and sub_info.callback_ref == callback_ref: + to_remove.append(sub_id) + for sub_id in to_remove: + self._unsubscribe_by_id(sub_id) + return + if scan_id is not None: + # find all subscriptions matching scan_id + to_remove = [] + for sub_id, sub_info in self._subscriptions.items(): + if sub_info.scan_id == scan_id: + to_remove.append(sub_id) + for sub_id in to_remove: + self._unsubscribe_by_id(sub_id) + return + if callback is not None: + # find all subscriptions matching callback + callback_ref = louie.saferef.safe_ref(callback) + to_remove = [] + for sub_id, sub_info in self._subscriptions.items(): + if sub_info.callback_ref == callback_ref: + to_remove.append(sub_id) + for sub_id in to_remove: + self._unsubscribe_by_id(sub_id) + return + + def _unsubscribe_by_id(self, subscription_id: str) -> None: + """ + Unsubscribe from live data updates by subscription ID. + Args: + subscription_id: The ID of the subscription to cancel. 
+ """ + + # Look up subscription info + if subscription_id not in self._subscriptions: + return + + sub_info = self._subscriptions[subscription_id] + + # Handle based on subscription type + if sub_info.subscription_type == "monitored": + self._unsubscribe_monitored(sub_info) + elif sub_info.subscription_type == "async_signal": + self._unsubscribe_async_signal(sub_info) + + # Remove from main subscription tracking + del self._subscriptions[subscription_id] + + def _unsubscribe_monitored(self, sub_info: _SubscriptionInfo) -> None: + """Unsubscribe from monitored device updates.""" + scan_id = sub_info.scan_id + device_name = sub_info.device_name + device_entry = sub_info.device_entry + callback_ref = sub_info.callback_ref + + if scan_id not in self._monitored_subscriptions: + return + + subscriptions = self._monitored_subscriptions[scan_id] + + # Find the subscription for this callback + if callback_ref in subscriptions: + sub = subscriptions[callback_ref] + + # Remove the device from the callback's device list + if (device_name, device_entry) in sub.devices: + sub.devices.remove((device_name, device_entry)) + + # Remove from buffer if exists + if callback_ref in self._callback_buffers: + buffer_key = (device_name, device_entry) + if buffer_key in self._callback_buffers[callback_ref].buffers: + del self._callback_buffers[callback_ref].buffers[buffer_key] + + # Clean up callback buffer if empty + if not self._callback_buffers[callback_ref].buffers: + del self._callback_buffers[callback_ref] + + # If callback has no more devices, remove it entirely + if not sub.devices: + del subscriptions[callback_ref] + + # Clean up scan if no more subscriptions + if not subscriptions: + del self._monitored_subscriptions[scan_id] + + def _unsubscribe_async_signal(self, sub_info: _SubscriptionInfo) -> None: + """Unsubscribe from async signal updates.""" + scan_id = sub_info.scan_id + device_name = sub_info.device_name + device_entry = sub_info.device_entry + callback_ref = 
sub_info.callback_ref + + key = (scan_id, device_name, device_entry) + + if key not in self._async_subscriptions: + return + + async_sub = self._async_subscriptions[key] + + # Remove the callback from the list + if callback_ref in async_sub.callback_refs: + async_sub.callback_refs.remove(callback_ref) + + # Remove from buffer if exists + if callback_ref in self._callback_buffers: + buffer_key = (device_name, device_entry) + if buffer_key in self._callback_buffers[callback_ref].buffers: + del self._callback_buffers[callback_ref].buffers[buffer_key] + + # Clean up callback buffer if empty + if not self._callback_buffers[callback_ref].buffers: + del self._callback_buffers[callback_ref] + + # If no more callbacks, unregister from redis connector and clean up + if not async_sub.callback_refs: + self.client.connector.unregister(async_sub.connector_id) + del self._async_subscriptions[key] + + def _subscribe_to_monitored_device( + self, + device_name: str, + device_entry: str, + scan_id: str, + callback: Callable[[dict, dict], Any], + ) -> str: + """ + Subscribe to monitored device data updates. + + Args: + device_name: Name of the device. + device_entry: Specific entry of the device. + scan_id: Identifier for the scan. + callback: Function to call on data update. + + Returns: + A unique subscription ID. 
+ """ + # Generate unique subscription ID + sub_id = str(uuid.uuid4()) + + callback_ref = louie.saferef.safe_ref(callback) + + # Store subscription info + sub_info = _SubscriptionInfo( + subscription_id=sub_id, + scan_id=scan_id, + device_name=device_name, + device_entry=device_entry, + callback_ref=callback_ref, + subscription_type="monitored", + ) + self._subscriptions[sub_id] = sub_info + + # Update monitored subscriptions grouping + available_subscriptions = self._monitored_subscriptions.get(scan_id) + + if available_subscriptions is None: + sub = _MonitoredSubscription( + scan_id=scan_id, callback_ref=callback_ref, devices=[(device_name, device_entry)] + ) + self._monitored_subscriptions[scan_id] = {callback_ref: sub} + return sub_id + + for callback_ref_existing, sub in available_subscriptions.items(): + if callback_ref_existing == callback_ref: + # Found existing subscription for this callback + if (device_name, device_entry) not in sub.devices: + sub.devices.append((device_name, device_entry)) + return sub_id + + # New callback for this scan + sub = _MonitoredSubscription( + scan_id=scan_id, callback_ref=callback_ref, devices=[(device_name, device_entry)] + ) + self._monitored_subscriptions[scan_id][callback_ref] = sub + return sub_id + + def _handle_scan_segment_update(self, _scan_segment: dict, metadata: dict) -> None: + """ + Handle scan segment updates from the client. We do not use the scan_segment directly, + but use it as a trigger to fetch data for all subscribed monitored devices from the live update storage. + + Args: + scan_segment: The scan segment data (content from ScanMessage). + metadata: Metadata associated with the scan segment. 
+ """ + scan_id = _scan_segment.get("scan_id") + if scan_id is None: + return + + if scan_id not in self._monitored_subscriptions: + return + + if self.client.queue is None: + return + + scan_item = self.client.queue.scan_storage.find_scan_by_ID(scan_id) + if scan_item is None: + return + + subscriptions = self._monitored_subscriptions[scan_id] + + for callback_ref, sub in subscriptions.items(): + callback = callback_ref() + if callback is None: + continue + + # Get or initialize callback buffer + if callback_ref not in self._callback_buffers: + self._callback_buffers[callback_ref] = _CallbackBuffer( + callback_ref=callback_ref, scan_id=scan_id, buffers={} + ) + + callback_buffer = self._callback_buffers[callback_ref] + + # Prepare data for this subscription + for device_name, device_entry in sub.devices: + # live_data returns lists of all values and timestamps + values = ( + scan_item.live_data.get(device_name, {}).get(device_entry, {}).get("val", None) + ) + timestamps = ( + scan_item.live_data.get(device_name, {}) + .get(device_entry, {}) + .get("timestamp", None) + ) + if values is None and timestamps is None: + continue + + if not isinstance(values, list): + values = [values] + if not isinstance(timestamps, list): + timestamps = [timestamps] + + # Track which index we've already processed for this device + key = (device_name, device_entry) + last_processed_index = callback_buffer.monitored_indices.get(key, 0) + + # Only add new data points we haven't processed yet + for idx in range(last_processed_index, len(values)): + if idx < len(timestamps): + self._add_to_buffer( + callback_ref, + scan_id, + device_name, + device_entry, + values[idx], + timestamps[idx], + "monitored", + ) + + # Update the last processed index + callback_buffer.monitored_indices[key] = len(values) + + # Check if we can emit synchronized data + self._check_and_emit_synchronized_data(callback_ref, scan_id) + + def _subscribe_to_async_signal( + self, + device_name: str, + device_entry: str, + 
scan_id: str, + callback: Callable[[dict, dict], Any], + ) -> str: + """ + Subscribe to async signal data updates. + + Args: + device_name: Name of the device. + device_entry: Specific entry of the device. + scan_id: Identifier for the scan. + callback: Function to call on data update. + + Returns: + A unique subscription ID. + """ + # Generate unique subscription ID + sub_id = str(uuid.uuid4()) + + callback_ref = louie.saferef.safe_ref(callback) + + # Store subscription info + sub_info = _SubscriptionInfo( + subscription_id=sub_id, + scan_id=scan_id, + device_name=device_name, + device_entry=device_entry, + callback_ref=callback_ref, + subscription_type="async_signal", + ) + self._subscriptions[sub_id] = sub_info + + # Check if we already have a subscription for this device/entry/scan combination + key = (scan_id, device_name, device_entry) + + if key in self._async_subscriptions: + # Reuse existing subscription, just add the callback + async_sub = self._async_subscriptions[key] + if callback_ref not in async_sub.callback_refs: + async_sub.callback_refs.append(callback_ref) + else: + # Create new redis connector subscription + connector_id = self.client.connector.register( + MessageEndpoints.device_async_signal( + scan_id=scan_id, device=device_name, signal=device_entry + ), + cb=self._async_signal_sync_callback, + from_start=True, + parent=self, + scan_id=scan_id, + device_name=device_name, + device_entry=device_entry, + ) + + # Create subscription tracking entry + async_sub = _AsyncSubscription( + scan_id=scan_id, + device_name=device_name, + device_entry=device_entry, + callback_refs=[callback_ref], + connector_id=connector_id, + ) + self._async_subscriptions[key] = async_sub + + return sub_id + + @staticmethod + def _async_signal_sync_callback( + msg: dict, parent: BECLiveDataPlugin, scan_id: str, device_name: str, device_entry: str + ): + """Callback for async signal updates from the client. 
Broadcasts to all subscribers.""" + + msg_obj = msg.get("data") + if not isinstance(msg_obj, messages.DeviceMessage): + return + + signals = msg_obj.signals + timestamp = msg_obj.metadata.get("timestamp") + + # Get all callbacks for this device/entry/scan combination + key = (scan_id, device_name, device_entry) + if key not in parent._async_subscriptions: + return + + async_sub = parent._async_subscriptions[key] + + # Add data to buffer for each subscriber + for callback_ref in async_sub.callback_refs: + callback = callback_ref() + if callback is None: + continue + + # Add to buffer + parent._add_to_buffer( + callback_ref, scan_id, device_name, device_entry, signals, timestamp, "async_signal" + ) + + # Check if we can emit synchronized data + parent._check_and_emit_synchronized_data(callback_ref, scan_id) + + def _add_to_buffer( + self, + callback_ref: CallbackRef, + scan_id: str, + device_name: str, + device_entry: str, + value: Any, + timestamp: Any, + source_type: Literal["monitored", "async_signal"], + ) -> None: + """ + Add data to the buffer for a specific callback and device. 
+ + Args: + callback_ref: Weak reference to the callback + scan_id: Scan identifier + device_name: Name of the device + device_entry: Device entry + value: Data value + timestamp: Data timestamp + source_type: Type of data source (monitored or async_signal) + """ + # Initialize callback buffer if not exists + if callback_ref not in self._callback_buffers: + self._callback_buffers[callback_ref] = _CallbackBuffer( + callback_ref=callback_ref, scan_id=scan_id, buffers={} + ) + + callback_buffer = self._callback_buffers[callback_ref] + key = (device_name, device_entry) + + # Initialize device buffer if not exists + if key not in callback_buffer.buffers: + callback_buffer.buffers[key] = _DataBuffer( + device_name=device_name, device_entry=device_entry, data=[], source_type=source_type + ) + + # Add data point to buffer + data_point = {"value": value, "timestamp": timestamp} + callback_buffer.buffers[key].data.append(data_point) + + def _get_expected_device_count(self, callback_ref: CallbackRef, scan_id: str) -> int: + """ + Get the total number of devices (monitored + async) that this callback is subscribed to. + + Args: + callback_ref: Weak reference to the callback + scan_id: Scan identifier + + Returns: + Total count of subscribed devices for this callback + """ + count = 0 + + # Count monitored devices + if scan_id in self._monitored_subscriptions: + if callback_ref in self._monitored_subscriptions[scan_id]: + count += len(self._monitored_subscriptions[scan_id][callback_ref].devices) + + # Count async signals + for ( + sub_scan_id, + device_name, + device_entry, + ), async_sub in self._async_subscriptions.items(): + if sub_scan_id == scan_id and callback_ref in async_sub.callback_refs: + count += 1 + + return count + + def _check_and_emit_synchronized_data(self, callback_ref: CallbackRef, scan_id: str) -> None: + """ + Check if all buffers for a callback have data of equal length and emit synchronized data. 
+ + Args: + callback_ref: Weak reference to the callback + scan_id: Scan identifier + """ + if callback_ref not in self._callback_buffers: + return + + callback_buffer = self._callback_buffers[callback_ref] + + if not callback_buffer.buffers: + return + + # Determine how many devices this callback is subscribed to + expected_device_count = self._get_expected_device_count(callback_ref, scan_id) + + # Wait until all subscribed devices have buffers + if len(callback_buffer.buffers) < expected_device_count: + return + + # Find minimum length across all buffers + min_length = min(len(buffer.data) for buffer in callback_buffer.buffers.values()) + + # If no data is available in all buffers yet, return + if min_length == 0: + return + + # Only emit new data (data beyond min_length we've already emitted) + if min_length <= callback_buffer.min_length: + return + + callback = callback_ref() + if callback is None: + return + + # Emit data from min_length onward up to the new min_length + for idx in range(callback_buffer.min_length, min_length): + data = {} + for (device_name, device_entry), buffer in callback_buffer.buffers.items(): + data_point = buffer.data[idx] + if device_name not in data: + data[device_name] = {} + data[device_name][device_entry] = { + "value": data_point["value"], + "timestamp": data_point["timestamp"], + } + + # Call the callback with synchronized data + callback( + data, + { + "scan_id": scan_id, + "async_update": DeviceAsyncUpdate(type="replace").model_dump(), + }, + ) + + # Update the min_length to track what we've already emitted + callback_buffer.min_length = min_length + + +""" +NOTES + +- AsyncSignal subscriptions should be shared between multiple subscribers to avoid redundant subscriptions. +- Whenever the redis connector triggers the callback, we broadcast to all subscribers of that device/entry/scan combination. +- When the user subscribed to an AsyncSignal, we check if there is already a subscription for that device/entry/scan combination. 
+- If yes, we just add the callback to the list of callbacks for that subscription. +- When the user subscribes to multiple AsyncSignals, we synchronize the data length and only broadcast the data of equal length. Same for mixtures + of monitored devices and AsyncSignals. + +""" diff --git a/bec_lib/tests/test_data_api.py b/bec_lib/tests/test_data_api.py new file mode 100644 index 000000000..ad5e01111 --- /dev/null +++ b/bec_lib/tests/test_data_api.py @@ -0,0 +1,1272 @@ +""" +Tests for the data_api module and plugins. + +These tests verify the functionality of the DataAPI and BECLiveDataPlugin classes, +including subscription management, data synchronization, and proper handling of +BECMessages (ScanMessage and DeviceMessage). +""" + +import copy +from unittest import mock + +import louie +import pytest + +from bec_lib import messages +from bec_lib.client import BECClient +from bec_lib.data_api.data_api import DataAPI +from bec_lib.data_api.plugins import BECLiveDataPlugin, _AsyncSubscription +from bec_lib.live_scan_data import LiveScanData +from bec_lib.scan_items import ScanItem + +# pylint: disable=protected-access +# pylint: disable=missing-function-docstring +# pylint: disable=redefined-outer-name + + +@pytest.fixture +def mock_client(connected_connector): + """Create a mock BECClient with necessary attributes.""" + client = mock.MagicMock(spec=BECClient) + client.started = True + client.connector = connected_connector + client.callbacks = mock.MagicMock() + client.callbacks.register = mock.MagicMock(return_value="callback_id") + client.callbacks.remove = mock.MagicMock() + + # Setup queue and scan storage + client.queue = mock.MagicMock() + client.queue.scan_storage = mock.MagicMock() + + # Setup device manager + client.device_manager = mock.MagicMock() + + return client + + +@pytest.fixture +def mock_callback(): + """Create a real callback function for testing (needed for louie.saferef).""" + calls = [] + + def callback(data, metadata): + calls.append((data, 
metadata)) + + callback.calls = calls + callback.reset = lambda: calls.clear() + return callback + + +@pytest.fixture +def data_api(mock_client): + """Create a DataAPI instance for testing.""" + # Clear singleton before creating instance + DataAPI.clear_instance() + api = DataAPI(mock_client) + yield api + # Clean up after test + DataAPI.clear_instance() + + +@pytest.fixture +def live_plugin(mock_client): + """Create a BECLiveDataPlugin instance for testing.""" + plugin = BECLiveDataPlugin(mock_client) + plugin.connect() + yield plugin + plugin.disconnect() + + +@pytest.fixture +def scan_item_with_monitored_devices(mock_client): + """Create a scan item with monitored devices configured.""" + scan_item = ScanItem( + queue_id="test_queue", scan_number=[1], scan_id=["test_scan_id"], status="open" + ) + + # Setup status message with monitored devices + scan_item.status_message = messages.ScanStatusMessage( + scan_id="test_scan_id", status="open", info={} + ) + scan_item.status_message.readout_priority = { + "monitored": ["samx", "samy"], + "baseline": ["samz"], + } + + # Initialize live_data with proper LiveScanData instance + scan_item.live_data = LiveScanData() + + mock_client.queue.scan_storage.find_scan_by_ID.return_value = scan_item + return scan_item + + +class TestDataAPI: + """Tests for the DataAPI class.""" + + def test_singleton_pattern(self, mock_client): + """Test that DataAPI follows singleton pattern.""" + DataAPI.clear_instance() + api1 = DataAPI(mock_client) + api2 = DataAPI(mock_client) + assert api1 is api2 + DataAPI.clear_instance() + + def test_clear_instance(self, mock_client): + """Test that clear_instance properly resets singleton.""" + DataAPI.clear_instance() + api1 = DataAPI(mock_client) + DataAPI.clear_instance() + api2 = DataAPI(mock_client) + assert api1 is not api2 + DataAPI.clear_instance() + + def test_register_plugin(self, data_api): + """Test plugin registration.""" + mock_plugin = mock.MagicMock() + mock_plugin.get_info.return_value = 
{"priority": 50} + + data_api.register_plugin(mock_plugin) + + mock_plugin.connect.assert_called_once() + assert mock_plugin in data_api.plugins + + def test_plugin_priority_sorting(self, data_api): + """Test that plugins are sorted by priority.""" + # Remove default plugin + data_api.plugins = [] + + plugin1 = mock.MagicMock() + plugin1.get_info.return_value = {"priority": 100} + plugin2 = mock.MagicMock() + plugin2.get_info.return_value = {"priority": 50} + plugin3 = mock.MagicMock() + plugin3.get_info.return_value = {"priority": 75} + + data_api.register_plugin(plugin1) + data_api.register_plugin(plugin2) + data_api.register_plugin(plugin3) + + assert data_api.plugins[0] is plugin2 # priority 50 + assert data_api.plugins[1] is plugin3 # priority 75 + assert data_api.plugins[2] is plugin1 # priority 100 + + def test_subscribe_with_capable_plugin(self, data_api): + """Test subscription when a plugin can provide data.""" + mock_plugin = mock.MagicMock() + mock_plugin.can_provide.return_value = True + mock_plugin.subscribe.return_value = "sub_id_123" + mock_plugin.get_info.return_value = {} + + data_api.plugins = [mock_plugin] + + callback = mock.MagicMock() + sub_id = data_api.subscribe("samx", "value", "test_scan", callback) + + assert sub_id == "sub_id_123" + mock_plugin.subscribe.assert_called_once_with("samx", "value", "test_scan", callback) + + def test_subscribe_no_capable_plugin(self, data_api): + """Test subscription when no plugin can provide data.""" + mock_plugin = mock.MagicMock() + mock_plugin.can_provide.return_value = False + mock_plugin.get_info.return_value = {} + + data_api.plugins = [mock_plugin] + + callback = mock.MagicMock() + sub_id = data_api.subscribe("samx", "value", "test_scan", callback) + + assert sub_id is None + mock_plugin.subscribe.assert_not_called() + + def test_unsubscribe_by_id(self, data_api): + """Test unsubscribe by subscription ID.""" + mock_plugin = mock.MagicMock() + mock_plugin.get_info.return_value = {} + 
data_api.plugins = [mock_plugin] + + data_api.unsubscribe(subscription_id="sub_123") + + mock_plugin.unsubscribe.assert_called_once_with("sub_123", None, None) + + def test_unsubscribe_by_scan_and_callback(self, data_api): + """Test unsubscribe by scan ID and callback.""" + mock_plugin = mock.MagicMock() + mock_plugin.get_info.return_value = {} + data_api.plugins = [mock_plugin] + + callback = mock.MagicMock() + data_api.unsubscribe(scan_id="test_scan", callback=callback) + + mock_plugin.unsubscribe.assert_called_once_with(None, "test_scan", callback) + + +class TestBECLiveDataPlugin: + """Tests for the BECLiveDataPlugin class.""" + + def test_connect_registers_callback(self, mock_client): + """Test that connect registers scan_segment callback.""" + plugin = BECLiveDataPlugin(mock_client) + plugin.connect() + + mock_client.callbacks.register.assert_called_once_with( + "scan_segment", plugin._handle_scan_segment_update + ) + assert plugin._connect_id == "callback_id" + + def test_disconnect_removes_callback(self, mock_client): + """Test that disconnect removes callback and cleans up.""" + plugin = BECLiveDataPlugin(mock_client) + plugin.connect() + plugin.disconnect() + + mock_client.callbacks.remove.assert_called_once_with("callback_id") + assert plugin._connect_id is None + + def test_disconnect_unregisters_async_subscriptions(self, mock_client): + """Test that disconnect unregisters all async signal subscriptions.""" + plugin = BECLiveDataPlugin(mock_client) + + # Mock connector.unregister as a MagicMock + mock_client.connector.unregister = mock.MagicMock() + + # Simulate having async subscriptions + plugin._async_subscriptions[("scan1", "dev1", "sig1")] = _AsyncSubscription( + scan_id="scan1", + device_name="dev1", + device_entry="sig1", + callback_refs=[], + connector_id="conn_id_1", + ) + plugin._async_subscriptions[("scan2", "dev2", "sig2")] = _AsyncSubscription( + scan_id="scan2", + device_name="dev2", + device_entry="sig2", + callback_refs=[], + 
connector_id="conn_id_2", + ) + + plugin.disconnect() + + assert mock_client.connector.unregister.call_count == 2 + assert len(plugin._async_subscriptions) == 0 + + def test_has_scan_data_client_not_started(self, mock_client): + """Test has_scan_data returns False when client not started.""" + mock_client.started = False + plugin = BECLiveDataPlugin(mock_client) + + assert plugin.has_scan_data("test_scan") is False + + def test_has_scan_data_no_queue(self, mock_client): + """Test has_scan_data returns False when queue is None.""" + mock_client.queue = None + plugin = BECLiveDataPlugin(mock_client) + + assert plugin.has_scan_data("test_scan") is False + + def test_has_scan_data_scan_not_found(self, mock_client): + """Test has_scan_data returns False when scan not found.""" + mock_client.queue.scan_storage.find_scan_by_ID.return_value = None + plugin = BECLiveDataPlugin(mock_client) + + assert plugin.has_scan_data("test_scan") is False + + def test_has_scan_data_scan_closed(self, mock_client): + """Test has_scan_data returns False for closed scans.""" + scan_item = mock.MagicMock() + scan_item.status = "closed" + mock_client.queue.scan_storage.find_scan_by_ID.return_value = scan_item + + plugin = BECLiveDataPlugin(mock_client) + + assert plugin.has_scan_data("test_scan") is False + + def test_has_scan_data_scan_open(self, mock_client): + """Test has_scan_data returns True for open scans.""" + scan_item = mock.MagicMock() + scan_item.status = "open" + mock_client.queue.scan_storage.find_scan_by_ID.return_value = scan_item + + plugin = BECLiveDataPlugin(mock_client) + + assert plugin.has_scan_data("test_scan") is True + + def test_device_entry_is_monitored(self, mock_client, scan_item_with_monitored_devices): + """Test detection of monitored device entries.""" + plugin = BECLiveDataPlugin(mock_client) + + assert ( + plugin._device_entry_is_monitored("samx", "samx", scan_item_with_monitored_devices) + is True + ) + assert ( + plugin._device_entry_is_monitored("samz", 
"samz", scan_item_with_monitored_devices) + is False + ) + + def test_device_entry_is_async_signal(self, mock_client): + """Test detection of async signal device entries.""" + mock_client.device_manager.get_bec_signals.return_value = [ + ("async_sig1", None, {"device_name": "detector1"}), + ("async_sig2", None, {"device_name": "detector2"}), + ] + + plugin = BECLiveDataPlugin(mock_client) + + assert plugin._device_entry_is_async_signal("detector1", "async_sig1") is True + assert plugin._device_entry_is_async_signal("detector2", "async_sig2") is True + assert plugin._device_entry_is_async_signal("detector1", "async_sig2") is False + + def test_subscribe_to_monitored_device( + self, mock_client, scan_item_with_monitored_devices, mock_callback + ): + """Test subscription to monitored device.""" + plugin = BECLiveDataPlugin(mock_client) + + sub_id = plugin._subscribe_to_monitored_device( + "samx", "samx", "test_scan_id", mock_callback + ) + + assert sub_id is not None + assert sub_id in plugin._subscriptions + assert plugin._subscriptions[sub_id].device_name == "samx" + assert plugin._subscriptions[sub_id].device_entry == "samx" + assert plugin._subscriptions[sub_id].subscription_type == "monitored" + + def test_subscribe_to_monitored_device_multiple_callbacks( + self, mock_client, scan_item_with_monitored_devices + ): + """Test multiple callbacks for same monitored device.""" + plugin = BECLiveDataPlugin(mock_client) + + def callback1(data, metadata): + pass + + def callback2(data, metadata): + pass + + sub_id1 = plugin._subscribe_to_monitored_device("samx", "samx", "test_scan_id", callback1) + sub_id2 = plugin._subscribe_to_monitored_device("samx", "samx", "test_scan_id", callback2) + + assert sub_id1 != sub_id2 + assert len(plugin._monitored_subscriptions["test_scan_id"]) == 2 + + def test_subscribe_to_async_signal(self, mock_client, mock_callback): + """Test subscription to async signal.""" + mock_client.device_manager.get_bec_signals.return_value = [ + 
("async_sig1", None, {"device_name": "detector1"}) + ] + + # Setup scan item + scan_item = ScanItem( + queue_id="test_queue", scan_number=[1], scan_id=["test_scan_id"], status="open" + ) + scan_item.status_message = messages.ScanStatusMessage( + scan_id="test_scan_id", status="open", info={} + ) + mock_client.queue.scan_storage.find_scan_by_ID.return_value = scan_item + + plugin = BECLiveDataPlugin(mock_client) + + # Mock the connector.register method properly + mock_client.connector.register = mock.MagicMock(return_value="redis_conn_id") + + sub_id = plugin._subscribe_to_async_signal( + "detector1", "async_sig1", "test_scan_id", mock_callback + ) + + assert sub_id is not None + assert sub_id in plugin._subscriptions + assert plugin._subscriptions[sub_id].subscription_type == "async_signal" + + # Check that redis connector was registered + mock_client.connector.register.assert_called_once() + call_args = mock_client.connector.register.call_args + assert call_args.kwargs["from_start"] is True + assert call_args.kwargs["scan_id"] == "test_scan_id" + assert call_args.kwargs["device_name"] == "detector1" + assert call_args.kwargs["device_entry"] == "async_sig1" + + def test_subscribe_to_async_signal_shared_subscription(self, mock_client): + """Test that multiple callbacks share the same async signal subscription.""" + mock_client.device_manager.get_bec_signals.return_value = [ + ("async_sig1", None, {"device_name": "detector1"}) + ] + + scan_item = ScanItem( + queue_id="test_queue", scan_number=[1], scan_id=["test_scan_id"], status="open" + ) + scan_item.status_message = messages.ScanStatusMessage( + scan_id="test_scan_id", status="open", info={} + ) + mock_client.queue.scan_storage.find_scan_by_ID.return_value = scan_item + + plugin = BECLiveDataPlugin(mock_client) + + def callback1(data, metadata): + pass + + def callback2(data, metadata): + pass + + # Mock the connector.register method properly + mock_client.connector.register = 
mock.MagicMock(return_value="redis_conn_id") + + sub_id1 = plugin._subscribe_to_async_signal( + "detector1", "async_sig1", "test_scan_id", callback1 + ) + sub_id2 = plugin._subscribe_to_async_signal( + "detector1", "async_sig1", "test_scan_id", callback2 + ) + + # Should only register once with redis connector + assert mock_client.connector.register.call_count == 1 + + # Both subscriptions should share the same async subscription + key = ("test_scan_id", "detector1", "async_sig1") + assert key in plugin._async_subscriptions + assert len(plugin._async_subscriptions[key].callback_refs) == 2 + + def test_handle_scan_segment_update_monitored( + self, mock_client, scan_item_with_monitored_devices, mock_callback + ): + """Test handling scan segment updates for monitored devices with proper ScanMessage.""" + plugin = BECLiveDataPlugin(mock_client) + + # Subscribe to monitored device + plugin._subscribe_to_monitored_device("samx", "samx", "test_scan_id", mock_callback) + + # Create proper ScanMessages and set them in live_data + # Each ScanMessage represents one point with single values + scan_msgs = [ + messages.ScanMessage( + point_id=0, + scan_id="test_scan_id", + data={"samx": {"samx": {"value": 1.0, "timestamp": 100.0}}}, + metadata={"scan_id": "test_scan_id"}, + ), + messages.ScanMessage( + point_id=1, + scan_id="test_scan_id", + data={"samx": {"samx": {"value": 2.0, "timestamp": 101.0}}}, + metadata={"scan_id": "test_scan_id"}, + ), + messages.ScanMessage( + point_id=2, + scan_id="test_scan_id", + data={"samx": {"samx": {"value": 3.0, "timestamp": 102.0}}}, + metadata={"scan_id": "test_scan_id"}, + ), + ] + + # Set the messages in live_data + for idx, msg in enumerate(scan_msgs): + scan_item_with_monitored_devices.live_data.set(idx, msg) + + # Trigger the callback for each point + for msg in scan_msgs: + plugin._handle_scan_segment_update(msg.content, msg.metadata) + + # Callback should be called for each data point + assert len(mock_callback.calls) >= 1 # At least 
one call should have been made + + # If called, check the structure of the first call + if len(mock_callback.calls) > 0: + call_data, call_metadata = mock_callback.calls[0] + assert "samx" in call_data + assert "samx" in call_data["samx"] + assert call_metadata["scan_id"] == "test_scan_id" + + def test_async_signal_callback_with_device_message(self, mock_client, mock_callback): + """Test async signal callback receives proper DeviceMessage.""" + plugin = BECLiveDataPlugin(mock_client) + + # Create async subscription manually + callback_ref = louie.saferef.safe_ref(mock_callback) + async_sub = _AsyncSubscription( + scan_id="test_scan_id", + device_name="detector1", + device_entry="async_sig1", + callback_refs=[callback_ref], + connector_id="conn_id", + ) + plugin._async_subscriptions[("test_scan_id", "detector1", "async_sig1")] = async_sub + + # Create a proper DeviceMessage + device_msg = messages.DeviceMessage( + signals={ + "async_sig1": {"value": 42.0, "timestamp": 123.456}, + "status": {"value": "ok", "timestamp": 123.456}, + }, + metadata={"timestamp": 123.456}, + ) + + # Call the static callback method + BECLiveDataPlugin._async_signal_sync_callback( + {"data": device_msg}, plugin, "test_scan_id", "detector1", "async_sig1" + ) + + # Callback should be invoked with the data + assert len(mock_callback.calls) == 1 + call_data, call_metadata = mock_callback.calls[0] + assert "detector1" in call_data + # The signals dict is stored as the value + assert call_data["detector1"]["async_sig1"]["value"] == device_msg.signals + assert call_data["detector1"]["async_sig1"]["timestamp"] == 123.456 + + def test_data_synchronization_mixed_sources( + self, mock_client, scan_item_with_monitored_devices, mock_callback + ): + """Test data synchronization between monitored and async sources.""" + mock_client.device_manager.get_bec_signals.return_value = [ + ("async_sig1", None, {"device_name": "detector1"}) + ] + + plugin = BECLiveDataPlugin(mock_client) + + # Subscribe to both 
monitored and async + plugin._subscribe_to_monitored_device("samx", "samx", "test_scan_id", mock_callback) + + # Mock the connector.register method properly + mock_client.connector.register = mock.MagicMock(return_value="redis_conn_id") + plugin._subscribe_to_async_signal("detector1", "async_sig1", "test_scan_id", mock_callback) + + # Add monitored data via proper ScanMessage + scan_msg = messages.ScanMessage( + point_id=0, + scan_id="test_scan_id", + data={"samx": {"samx": {"value": 1.0, "timestamp": 100.0}}}, + metadata={"scan_id": "test_scan_id"}, + ) + scan_item_with_monitored_devices.live_data.set(0, scan_msg) + + plugin._handle_scan_segment_update(scan_msg.content, scan_msg.metadata) + + # At this point, callback should not be called yet (waiting for async data) + assert len(mock_callback.calls) == 0 + + # Now add async data + device_msg = messages.DeviceMessage( + signals={"async_sig1": {"value": 10.0, "timestamp": 100.0}} + ) + + BECLiveDataPlugin._async_signal_sync_callback( + {"data": device_msg}, plugin, "test_scan_id", "detector1", "async_sig1" + ) + + # Callback should eventually be called with synchronized data + # The synchronization requires all sources to have data + assert len(mock_callback.calls) == 1 + + # Check that the data contains both sources + call_data, call_metadata = mock_callback.calls[0] + assert "samx" in call_data + assert "detector1" in call_data + + def test_unsubscribe_monitored_device( + self, mock_client, scan_item_with_monitored_devices, mock_callback + ): + """Test unsubscribing from monitored device.""" + plugin = BECLiveDataPlugin(mock_client) + + sub_id = plugin._subscribe_to_monitored_device( + "samx", "samx", "test_scan_id", mock_callback + ) + + # Verify subscription exists + assert sub_id in plugin._subscriptions + assert "test_scan_id" in plugin._monitored_subscriptions + + # Unsubscribe + plugin.unsubscribe(subscription_id=sub_id) + + # Verify cleanup + assert sub_id not in plugin._subscriptions + assert 
"test_scan_id" not in plugin._monitored_subscriptions + + def test_unsubscribe_async_signal(self, mock_client, mock_callback): + """Test unsubscribing from async signal.""" + mock_client.device_manager.get_bec_signals.return_value = [ + ("async_sig1", None, {"device_name": "detector1"}) + ] + + scan_item = ScanItem( + queue_id="test_queue", scan_number=[1], scan_id=["test_scan_id"], status="open" + ) + scan_item.status_message = messages.ScanStatusMessage( + scan_id="test_scan_id", status="open", info={} + ) + mock_client.queue.scan_storage.find_scan_by_ID.return_value = scan_item + + plugin = BECLiveDataPlugin(mock_client) + + # Mock the connector methods properly + mock_client.connector.register = mock.MagicMock(return_value="redis_conn_id") + mock_client.connector.unregister = mock.MagicMock() + + sub_id = plugin._subscribe_to_async_signal( + "detector1", "async_sig1", "test_scan_id", mock_callback + ) + + # Verify subscription exists + assert sub_id in plugin._subscriptions + key = ("test_scan_id", "detector1", "async_sig1") + assert key in plugin._async_subscriptions + + # Unsubscribe + plugin.unsubscribe(subscription_id=sub_id) + + # Verify cleanup + assert sub_id not in plugin._subscriptions + assert key not in plugin._async_subscriptions + mock_client.connector.unregister.assert_called_once_with("redis_conn_id") + + def test_unsubscribe_by_scan_id(self, mock_client, scan_item_with_monitored_devices): + """Test unsubscribing all subscriptions for a scan ID.""" + plugin = BECLiveDataPlugin(mock_client) + + def callback1(data, metadata): + pass + + def callback2(data, metadata): + pass + + sub_id1 = plugin._subscribe_to_monitored_device("samx", "samx", "test_scan_id", callback1) + sub_id2 = plugin._subscribe_to_monitored_device("samy", "samy", "test_scan_id", callback2) + + # Unsubscribe all for scan + plugin.unsubscribe(scan_id="test_scan_id") + + # Verify all cleaned up + assert sub_id1 not in plugin._subscriptions + assert sub_id2 not in 
plugin._subscriptions + assert "test_scan_id" not in plugin._monitored_subscriptions + + def test_can_provide_monitored(self, mock_client, scan_item_with_monitored_devices): + """Test can_provide returns True for monitored devices.""" + plugin = BECLiveDataPlugin(mock_client) + + assert plugin.can_provide("samx", "samx", "test_scan_id") is True + assert plugin.can_provide("samz", "samz", "test_scan_id") is False + + def test_can_provide_async_signal(self, mock_client): + """Test can_provide returns True for async signals.""" + mock_client.device_manager.get_bec_signals.return_value = [ + ("async_sig1", None, {"device_name": "detector1"}) + ] + + scan_item = ScanItem( + queue_id="test_queue", scan_number=[1], scan_id=["test_scan_id"], status="open" + ) + scan_item.status_message = messages.ScanStatusMessage( + scan_id="test_scan_id", status="open", info={} + ) + mock_client.queue.scan_storage.find_scan_by_ID.return_value = scan_item + + plugin = BECLiveDataPlugin(mock_client) + + assert plugin.can_provide("detector1", "async_sig1", "test_scan_id") is True + assert plugin.can_provide("detector1", "other_signal", "test_scan_id") is False + + def test_get_info(self, live_plugin): + """Test get_info returns empty dict (base implementation).""" + info = live_plugin.get_info() + assert info == {} + + def test_cache_invalidation_on_new_scan(self, mock_client): + """Test that device mode cache is properly scoped per scan.""" + plugin = BECLiveDataPlugin(mock_client) + + # Setup first scan with samx monitored + scan_item1 = ScanItem(queue_id="queue1", scan_number=[1], scan_id=["scan_1"], status="open") + scan_item1.status_message = messages.ScanStatusMessage( + scan_id="scan_1", status="open", info={} + ) + scan_item1.status_message.readout_priority = {"monitored": ["samx"]} + + # Setup second scan with samx as baseline (not monitored) + scan_item2 = ScanItem(queue_id="queue2", scan_number=[2], scan_id=["scan_2"], status="open") + scan_item2.status_message = 
messages.ScanStatusMessage( + scan_id="scan_2", status="open", info={} + ) + scan_item2.status_message.readout_priority = {"baseline": ["samx"]} + + def find_scan_side_effect(scan_id): + if scan_id == "scan_1": + return scan_item1 + elif scan_id == "scan_2": + return scan_item2 + return None + + mock_client.queue.scan_storage.find_scan_by_ID.side_effect = find_scan_side_effect + + # Check scan_1 - samx should be monitored + mode1 = plugin._get_device_mode("samx", "samx", "scan_1") + assert mode1 == "monitored" + + # Check scan_2 - samx should NOT be monitored (different scan) + mode2 = plugin._get_device_mode("samx", "samx", "scan_2") + assert mode2 is None + + +class TestDataSubscription: + """Test suite for the DataSubscription class.""" + + def test_create_subscription(self, data_api): + """Test creating a subscription object.""" + sub = data_api.create_subscription("test_scan") + assert sub.scan_id == "test_scan" + assert sub.devices == [] + sub.close() + + def test_add_device_without_callback(self, data_api): + """Test adding devices before setting a callback.""" + sub = data_api.create_subscription("test_scan") + sub.add_device("samx", "samx") + sub.add_device("samy", "samy") + + assert len(sub.devices) == 2 + assert ("samx", "samx") in sub.devices + assert ("samy", "samy") in sub.devices + sub.close() + + def test_add_device_with_callback( + self, data_api, mock_callback, scan_item_with_monitored_devices + ): + """Test adding devices after setting a callback triggers immediate subscription.""" + sub = data_api.create_subscription("test_scan_id") + sub.set_callback(mock_callback) + sub.add_device("samx", "samx") + + assert len(sub.devices) == 1 + assert ("samx", "samx") in sub.devices + sub.close() + + def test_set_callback_after_devices( + self, data_api, mock_callback, scan_item_with_monitored_devices + ): + """Test setting callback after adding devices triggers subscription.""" + sub = data_api.create_subscription("test_scan_id") + sub.add_device("samx", 
"samx") + sub.add_device("samy", "samy") + + # Setting callback should subscribe all queued devices + sub.set_callback(mock_callback) + + assert len(sub.devices) == 2 + sub.close() + + def test_method_chaining(self, data_api, mock_callback, scan_item_with_monitored_devices): + """Test that methods support chaining.""" + sub = ( + data_api.create_subscription("test_scan_id") + .add_device("samx", "samx") + .add_device("samy", "samy") + .set_callback(mock_callback) + ) + + assert len(sub.devices) == 2 + sub.close() + + def test_remove_device(self, data_api, mock_callback, scan_item_with_monitored_devices): + """Test removing a device from subscription.""" + sub = data_api.create_subscription("test_scan_id") + sub.set_callback(mock_callback) + sub.add_device("samx", "samx") + sub.add_device("samy", "samy") + + assert len(sub.devices) == 2 + + sub.remove_device("samx", "samx") + assert len(sub.devices) == 1 + assert ("samx", "samx") not in sub.devices + assert ("samy", "samy") in sub.devices + + sub.close() + + def test_close_unsubscribes_all( + self, data_api, mock_callback, scan_item_with_monitored_devices + ): + """Test that close() unsubscribes from all devices.""" + sub = data_api.create_subscription("test_scan_id") + sub.set_callback(mock_callback) + sub.add_device("samx", "samx") + sub.add_device("samy", "samy") + + sub.close() + + # After close, devices should be cleared + assert len(sub.devices) == 0 + + def test_context_manager(self, data_api, mock_callback, scan_item_with_monitored_devices): + """Test using subscription as a context manager.""" + with data_api.create_subscription("test_scan_id") as sub: + sub.set_callback(mock_callback) + sub.add_device("samx", "samx") + assert len(sub.devices) == 1 + + # After context exit, subscription should be closed + with pytest.raises(RuntimeError, match="Cannot add device to a closed subscription"): + sub.add_device("samy", "samy") + + def test_cannot_modify_closed_subscription(self, data_api, mock_callback): + 
"""Test that operations on closed subscription raise errors.""" + sub = data_api.create_subscription("test_scan") + sub.close() + + with pytest.raises(RuntimeError, match="Cannot add device to a closed subscription"): + sub.add_device("samx", "samx") + + with pytest.raises(RuntimeError, match="Cannot remove device from a closed subscription"): + sub.remove_device("samx", "samx") + + with pytest.raises(RuntimeError, match="Cannot set callback on a closed subscription"): + sub.set_callback(mock_callback) + + with pytest.raises(RuntimeError, match="Cannot reload a closed subscription"): + sub.reload() + + def test_reload_resubscribes(self, data_api, mock_callback, scan_item_with_monitored_devices): + """Test that reload() resubscribes to all devices.""" + sub = data_api.create_subscription("test_scan_id") + sub.set_callback(mock_callback) + sub.add_device("samx", "samx") + + # Reload should resubscribe + sub.reload() + + assert len(sub.devices) == 1 + sub.close() + + def test_callback_change_resubscribes(self, data_api, scan_item_with_monitored_devices): + """Test that changing callback resubscribes all devices.""" + calls1 = [] + calls2 = [] + + def callback1(data, metadata): + calls1.append((data, metadata)) + + def callback2(data, metadata): + calls2.append((data, metadata)) + + sub = data_api.create_subscription("test_scan_id") + sub.set_callback(callback1) + sub.add_device("samx", "samx") + + # Change callback + sub.set_callback(callback2) + + # Should still have the same device + assert len(sub.devices) == 1 + sub.close() + + def test_add_duplicate_device(self, data_api, mock_callback, scan_item_with_monitored_devices): + """Test adding the same device twice doesn't create duplicate subscriptions.""" + sub = data_api.create_subscription("test_scan_id") + sub.set_callback(mock_callback) + sub.add_device("samx", "samx") + sub.add_device("samx", "samx") # duplicate + + assert len(sub.devices) == 1 + sub.close() + + def test_remove_nonexistent_device(self, data_api, 
mock_callback): + """Test removing a device that wasn't subscribed doesn't cause errors.""" + sub = data_api.create_subscription("test_scan") + sub.set_callback(mock_callback) + + # Should not raise an error + sub.remove_device("nonexistent", "signal") + sub.close() + + def test_reload_without_callback(self, data_api): + """Test that reload without callback logs warning and doesn't crash.""" + sub = data_api.create_subscription("test_scan") + sub.add_device("samx", "samx") + + # Should not crash, just log warning + sub.reload() + sub.close() + + def test_destructor_cleanup(self, data_api, mock_callback, scan_item_with_monitored_devices): + """Test that __del__ properly cleans up subscriptions.""" + sub = data_api.create_subscription("test_scan_id") + sub.set_callback(mock_callback) + sub.add_device("samx", "samx") + + # Explicitly delete the object + del sub + + # Subscription should be cleaned up (this is implicitly tested by no errors) + + def test_set_scan_id(self, data_api, mock_callback, mock_client): + """Test changing the scan_id.""" + # Setup two scan items + scan_item1 = ScanItem(queue_id="queue1", scan_number=[1], scan_id=["scan_1"], status="open") + scan_item1.status_message = messages.ScanStatusMessage( + scan_id="scan_1", status="open", info={} + ) + scan_item1.status_message.readout_priority = {"monitored": ["samx"]} + scan_item1.live_data = LiveScanData() + + scan_item2 = ScanItem(queue_id="queue2", scan_number=[2], scan_id=["scan_2"], status="open") + scan_item2.status_message = messages.ScanStatusMessage( + scan_id="scan_2", status="open", info={} + ) + scan_item2.status_message.readout_priority = {"monitored": ["samx"]} + scan_item2.live_data = LiveScanData() + + def find_scan_side_effect(scan_id): + if scan_id == "scan_1": + return scan_item1 + elif scan_id == "scan_2": + return scan_item2 + return None + + mock_client.queue.scan_storage.find_scan_by_ID.side_effect = find_scan_side_effect + + sub = data_api.create_subscription("scan_1") + 
sub.set_callback(mock_callback) + sub.add_device("samx", "samx") + + assert sub.scan_id == "scan_1" + + # Change to new scan + sub.set_scan_id("scan_2") + assert sub.scan_id == "scan_2" + + # Devices should still be there + assert len(sub.devices) == 1 + sub.close() + + def test_set_scan_id_without_callback(self, data_api): + """Test changing scan_id when no callback is set yet.""" + sub = data_api.create_subscription("scan_1") + sub.add_device("samx", "samx") + + # Should work without errors + sub.set_scan_id("scan_2") + assert sub.scan_id == "scan_2" + assert len(sub.devices) == 1 + sub.close() + + def test_set_scan_id_same_value(self, data_api, mock_callback): + """Test setting scan_id to the same value is a no-op.""" + sub = data_api.create_subscription("scan_1") + sub.set_callback(mock_callback) + + # Should not trigger resubscription + sub.set_scan_id("scan_1") + assert sub.scan_id == "scan_1" + sub.close() + + def test_set_scan_id_on_closed(self, data_api): + """Test that changing scan_id on closed subscription raises error.""" + sub = data_api.create_subscription("scan_1") + sub.close() + + with pytest.raises(RuntimeError, match="Cannot change scan_id on a closed subscription"): + sub.set_scan_id("scan_2") + + def test_set_scan_id_method_chaining(self, data_api, mock_callback, mock_client): + """Test that set_scan_id supports method chaining.""" + scan_item = ScanItem(queue_id="queue1", scan_number=[1], scan_id=["scan_2"], status="open") + scan_item.status_message = messages.ScanStatusMessage( + scan_id="scan_2", status="open", info={} + ) + scan_item.status_message.readout_priority = {"monitored": ["samx"]} + scan_item.live_data = LiveScanData() + mock_client.queue.scan_storage.find_scan_by_ID.return_value = scan_item + + sub = ( + data_api.create_subscription("scan_1") + .set_callback(mock_callback) + .set_scan_id("scan_2") + .add_device("samx", "samx") + ) + + assert sub.scan_id == "scan_2" + assert len(sub.devices) == 1 + sub.close() + + +class 
TestDataSubscriptionBuffered: + """Test suite for buffered mode in DataSubscription.""" + + def test_create_buffered_subscription(self, data_api): + """Test creating a buffered subscription.""" + sub = data_api.create_subscription("test_scan", buffered=True) + assert sub.buffered is True + assert sub.scan_id == "test_scan" + sub.close() + + def test_create_non_buffered_subscription(self, data_api): + """Test creating a non-buffered subscription (default).""" + sub = data_api.create_subscription("test_scan") + assert sub.buffered is False + sub.close() + + def test_buffered_mode_accumulates_data(self, data_api, scan_item_with_monitored_devices): + """Test that buffered mode accumulates and re-emits all data.""" + calls = [] + + def callback(data, metadata): + # Deep copy to capture state at callback time + calls.append(copy.deepcopy(data)) + + sub = data_api.create_subscription("test_scan_id", buffered=True) + sub.set_callback(callback) + sub.add_device("samx", "samx") + + # Simulate multiple data updates through the plugin + plugin = data_api.plugins[0] + + # First data point + scan_msg1 = messages.ScanMessage( + point_id=0, + scan_id="test_scan_id", + data={"samx": {"samx": {"value": 1.0, "timestamp": 100.0}}}, + metadata={"scan_id": "test_scan_id"}, + ) + scan_item_with_monitored_devices.live_data.set(0, scan_msg1) + plugin._handle_scan_segment_update(scan_msg1.content, scan_msg1.metadata) + + # Second data point + scan_msg2 = messages.ScanMessage( + point_id=1, + scan_id="test_scan_id", + data={"samx": {"samx": {"value": 2.0, "timestamp": 200.0}}}, + metadata={"scan_id": "test_scan_id"}, + ) + scan_item_with_monitored_devices.live_data.set(1, scan_msg2) + plugin._handle_scan_segment_update(scan_msg2.content, scan_msg2.metadata) + + # In buffered mode, each callback should contain ALL accumulated data + # The plugin processes all data in live_data each time, so we get multiple callbacks + assert len(calls) >= 2 + + # Find the last call - it should have all 
accumulated data + last_call = calls[-1] + assert "samx" in last_call + assert "samx" in last_call["samx"] + # Should be a list with all accumulated points + assert isinstance(last_call["samx"]["samx"], list) + assert len(last_call["samx"]["samx"]) >= 2 + # Verify the accumulated values are present + values = [point["value"] for point in last_call["samx"]["samx"]] + assert 1.0 in values + assert 2.0 in values + + sub.close() + + def test_non_buffered_mode_emits_only_new(self, data_api, scan_item_with_monitored_devices): + """Test that non-buffered mode only emits new data blocks.""" + calls = [] + + def callback(data, metadata): + calls.append(copy.deepcopy(data)) + + sub = data_api.create_subscription("test_scan_id", buffered=False) + sub.set_callback(callback) + sub.add_device("samx", "samx") + + plugin = data_api.plugins[0] + + # First data point + scan_msg1 = messages.ScanMessage( + point_id=0, + scan_id="test_scan_id", + data={"samx": {"samx": {"value": 1.0, "timestamp": 100.0}}}, + metadata={"scan_id": "test_scan_id"}, + ) + scan_item_with_monitored_devices.live_data.set(0, scan_msg1) + plugin._handle_scan_segment_update(scan_msg1.content, scan_msg1.metadata) + + # Second data point + scan_msg2 = messages.ScanMessage( + point_id=1, + scan_id="test_scan_id", + data={"samx": {"samx": {"value": 2.0, "timestamp": 200.0}}}, + metadata={"scan_id": "test_scan_id"}, + ) + scan_item_with_monitored_devices.live_data.set(1, scan_msg2) + plugin._handle_scan_segment_update(scan_msg2.content, scan_msg2.metadata) + + # In non-buffered mode, each callback contains only individual data blocks + assert len(calls) >= 2 + + # All calls should have single value/timestamp dict (not a list) + for call in calls: + assert "samx" in call + assert "samx" in call["samx"] + # In non-buffered mode, data is NOT a list + assert isinstance(call["samx"]["samx"], dict) + assert "value" in call["samx"]["samx"] + assert "timestamp" in call["samx"]["samx"] + + # Verify we got both values at some 
point + values = [call["samx"]["samx"]["value"] for call in calls] + assert 1.0 in values + assert 2.0 in values + + sub.close() + + def test_set_buffered_mode(self, data_api): + """Test changing buffered mode dynamically.""" + sub = data_api.create_subscription("test_scan", buffered=False) + assert sub.buffered is False + + sub.set_buffered(True) + assert sub.buffered is True + + sub.set_buffered(False) + assert sub.buffered is False + + sub.close() + + def test_set_buffered_same_value_is_noop(self, data_api): + """Test setting buffered to the same value is a no-op.""" + sub = data_api.create_subscription("test_scan", buffered=True) + sub.set_buffered(True) # No-op + assert sub.buffered is True + sub.close() + + def test_set_buffered_on_closed_raises_error(self, data_api): + """Test that changing buffered mode on closed subscription raises error.""" + sub = data_api.create_subscription("test_scan") + sub.close() + + with pytest.raises( + RuntimeError, match="Cannot change buffered mode on a closed subscription" + ): + sub.set_buffered(True) + + def test_set_buffered_clears_buffer_when_disabling( + self, data_api, scan_item_with_monitored_devices + ): + """Test that disabling buffered mode clears the accumulated buffer.""" + calls = [] + + def callback(data, metadata): + calls.append(copy.deepcopy(data)) + + sub = data_api.create_subscription("test_scan_id", buffered=True) + sub.set_callback(callback) + sub.add_device("samx", "samx") + + plugin = data_api.plugins[0] + + # Add some data in buffered mode + scan_msg1 = messages.ScanMessage( + point_id=0, + scan_id="test_scan_id", + data={"samx": {"samx": {"value": 1.0, "timestamp": 100.0}}}, + metadata={"scan_id": "test_scan_id"}, + ) + scan_item_with_monitored_devices.live_data.set(0, scan_msg1) + plugin._handle_scan_segment_update(scan_msg1.content, scan_msg1.metadata) + + assert len(calls) >= 1 + # Verify we got buffered data (as a list) + assert isinstance(calls[-1]["samx"]["samx"], list) + calls.clear() + + # 
Switch to non-buffered mode (should clear buffer) + sub.set_buffered(False) + + # Add new data - should only get this new data, not accumulated buffer + scan_msg2 = messages.ScanMessage( + point_id=1, + scan_id="test_scan_id", + data={"samx": {"samx": {"value": 2.0, "timestamp": 200.0}}}, + metadata={"scan_id": "test_scan_id"}, + ) + scan_item_with_monitored_devices.live_data.set(1, scan_msg2) + plugin._handle_scan_segment_update(scan_msg2.content, scan_msg2.metadata) + + assert len(calls) >= 1 + # Should have non-buffered data (dict, not list) + assert isinstance(calls[-1]["samx"]["samx"], dict) + assert calls[-1]["samx"]["samx"]["value"] == 2.0 + + sub.close() + + def test_buffered_method_chaining(self, data_api, mock_callback): + """Test that set_buffered supports method chaining.""" + sub = ( + data_api.create_subscription("test_scan").set_buffered(True).set_callback(mock_callback) + ) + + assert sub.buffered is True + sub.close() + + def test_scan_id_change_clears_buffer(self, data_api, mock_client): + """Test that changing scan_id clears the accumulated buffer.""" + calls = [] + + def callback(data, metadata): + calls.append(copy.deepcopy(data)) + + # Setup two scans + scan_item1 = ScanItem(queue_id="queue1", scan_number=[1], scan_id=["scan_1"], status="open") + scan_item1.status_message = messages.ScanStatusMessage( + scan_id="scan_1", status="open", info={} + ) + scan_item1.status_message.readout_priority = {"monitored": ["samx"]} + scan_item1.live_data = LiveScanData() + + scan_item2 = ScanItem(queue_id="queue2", scan_number=[2], scan_id=["scan_2"], status="open") + scan_item2.status_message = messages.ScanStatusMessage( + scan_id="scan_2", status="open", info={} + ) + scan_item2.status_message.readout_priority = {"monitored": ["samx"]} + scan_item2.live_data = LiveScanData() + + def find_scan_side_effect(scan_id): + if scan_id == "scan_1": + return scan_item1 + elif scan_id == "scan_2": + return scan_item2 + return None + + 
mock_client.queue.scan_storage.find_scan_by_ID.side_effect = find_scan_side_effect + + sub = data_api.create_subscription("scan_1", buffered=True) + sub.set_callback(callback) + sub.add_device("samx", "samx") + + plugin = data_api.plugins[0] + + # Add data to scan_1 + scan_msg1 = messages.ScanMessage( + point_id=0, + scan_id="scan_1", + data={"samx": {"samx": {"value": 1.0, "timestamp": 100.0}}}, + metadata={"scan_id": "scan_1"}, + ) + scan_item1.live_data.set(0, scan_msg1) + plugin._handle_scan_segment_update(scan_msg1.content, scan_msg1.metadata) + + assert len(calls) == 1 + calls.clear() + + # Change scan_id (should clear buffer) + sub.set_scan_id("scan_2") + + # Add data to scan_2 - should start fresh buffer + scan_msg2 = messages.ScanMessage( + point_id=0, + scan_id="scan_2", + data={"samx": {"samx": {"value": 2.0, "timestamp": 200.0}}}, + metadata={"scan_id": "scan_2"}, + ) + scan_item2.live_data.set(0, scan_msg2) + plugin._handle_scan_segment_update(scan_msg2.content, scan_msg2.metadata) + + assert len(calls) == 1 + # Should only have data from scan_2 (one point) + assert len(calls[0]["samx"]["samx"]) == 1 + assert calls[0]["samx"]["samx"][0]["value"] == 2.0 + + sub.close() From 26c3f855d37c3f22875500550f4121aebcad822d Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Mon, 12 Jan 2026 09:47:49 +0100 Subject: [PATCH 2/2] wip - async signals --- bec_lib/bec_lib/data_api/data_api.py | 14 +++++++ bec_lib/bec_lib/data_api/plugins.py | 55 +++++++++++++++------------- bec_lib/tests/test_data_api.py | 16 ++++---- 3 files changed, 52 insertions(+), 33 deletions(-) diff --git a/bec_lib/bec_lib/data_api/data_api.py b/bec_lib/bec_lib/data_api/data_api.py index 37a5b64ed..c0b4e0414 100644 --- a/bec_lib/bec_lib/data_api/data_api.py +++ b/bec_lib/bec_lib/data_api/data_api.py @@ -416,3 +416,17 @@ def unsubscribe( """ for plugin in self.plugins: plugin.unsubscribe(subscription_id, scan_id, callback) + + +if __name__ == "__main__": + from bec_lib.client import BECClient + + def 
my_callback(data, metadata): + print("Received data:", data) + print("With metadata:", metadata) + + client = BECClient() + data_api = DataAPI(client) + sub = data_api.create_subscription("test_scan") + sub.add_device(device_name="waveform", device_entry="waveform_data") + sub.set_callback(my_callback) diff --git a/bec_lib/bec_lib/data_api/plugins.py b/bec_lib/bec_lib/data_api/plugins.py index cbd1614f3..8aee1066c 100644 --- a/bec_lib/bec_lib/data_api/plugins.py +++ b/bec_lib/bec_lib/data_api/plugins.py @@ -295,13 +295,28 @@ def _device_entry_is_async_signal(self, device_name: str, device_entry: str) -> Returns: True if the device entry is an async signal, False otherwise. """ + async_signal_info = self._get_async_signal_info(device_name, device_entry) + return async_signal_info is not None + + def _get_async_signal_info(self, device_name: str, device_entry: str) -> dict | None: + """ + Get the async signal information for the given device and entry. + + Args: + device_name: Name of the device. + device_entry: Specific entry of the device. + Returns: + The async signal information dict if found, None otherwise. 
+ """ if not self.client.device_manager: - return False - async_signals = self.client.device_manager.get_bec_signals("AsyncSignal") - for entry_name, _, entry_data in async_signals: - if entry_name == device_entry and entry_data.get("device_name") == device_name: - return True - return False + return None + async_signals = self.client.device_manager.get_bec_signals( + ["AsyncSignal", "AsyncMultiSignal", "DynamicSignal"] + ) + for dev_name, _, entry_info in async_signals: + if entry_info.get("obj_name") == device_entry and dev_name == device_name: + return entry_info + return None def subscribe( self, @@ -665,9 +680,16 @@ def _subscribe_to_async_signal( async_sub.callback_refs.append(callback_ref) else: # Create new redis connector subscription + async_signal_info = self._get_async_signal_info(device_name, device_entry) + if async_signal_info is None: + raise ValueError( + f"Cannot subscribe to async signal '{device_name}' entry '{device_entry}': signal not found." + ) connector_id = self.client.connector.register( MessageEndpoints.device_async_signal( - scan_id=scan_id, device=device_name, signal=device_entry + scan_id=scan_id, + device=device_name, + signal=async_signal_info.get("storage_name"), ), cb=self._async_signal_sync_callback, from_start=True, @@ -783,11 +805,7 @@ def _get_expected_device_count(self, callback_ref: CallbackRef, scan_id: str) -> count += len(self._monitored_subscriptions[scan_id][callback_ref].devices) # Count async signals - for ( - sub_scan_id, - device_name, - device_entry, - ), async_sub in self._async_subscriptions.items(): + for (sub_scan_id, _, _), async_sub in self._async_subscriptions.items(): if sub_scan_id == scan_id and callback_ref in async_sub.callback_refs: count += 1 @@ -854,16 +872,3 @@ def _check_and_emit_synchronized_data(self, callback_ref: CallbackRef, scan_id: # Update the min_length to track what we've already emitted callback_buffer.min_length = min_length - - -""" -NOTES - -- AsyncSignal subscriptions should be 
shared between multiple subscribers to avoid redundant subscriptions. -- Whenever the redis connector triggers the callback, we broadcast to all subscribers of that device/entry/scan combination. -- When the user subscribed to an AsyncSignal, we check if there is already a subscription for that device/entry/scan combination. -- If yes, we just add the callback to the list of callbacks for that subscription. -- When the user subscribes to multiple AsyncSignals, we synchronize the data length and only broadcast the data of equal length. Same for mixtures - of monitored devices and AsyncSignals. - -""" diff --git a/bec_lib/tests/test_data_api.py b/bec_lib/tests/test_data_api.py index ad5e01111..e5551cca0 100644 --- a/bec_lib/tests/test_data_api.py +++ b/bec_lib/tests/test_data_api.py @@ -308,8 +308,8 @@ def test_device_entry_is_monitored(self, mock_client, scan_item_with_monitored_d def test_device_entry_is_async_signal(self, mock_client): """Test detection of async signal device entries.""" mock_client.device_manager.get_bec_signals.return_value = [ - ("async_sig1", None, {"device_name": "detector1"}), - ("async_sig2", None, {"device_name": "detector2"}), + ("detector1", None, {"obj_name": "async_sig1", "storage_name": "async_sig1"}), + ("detector2", None, {"obj_name": "async_sig2", "storage_name": "async_sig2"}), ] plugin = BECLiveDataPlugin(mock_client) @@ -355,12 +355,12 @@ def callback2(data, metadata): def test_subscribe_to_async_signal(self, mock_client, mock_callback): """Test subscription to async signal.""" mock_client.device_manager.get_bec_signals.return_value = [ - ("async_sig1", None, {"device_name": "detector1"}) + ("detector1", None, {"obj_name": "async_sig1", "storage_name": "async_sig1"}) ] # Setup scan item scan_item = ScanItem( - queue_id="test_queue", scan_number=[1], scan_id=["test_scan_id"], status="open" + queue_id="test_queue", scan_number=1, scan_id="test_scan_id", status="open" ) scan_item.status_message = messages.ScanStatusMessage( 
scan_id="test_scan_id", status="open", info={} @@ -391,7 +391,7 @@ def test_subscribe_to_async_signal(self, mock_client, mock_callback): def test_subscribe_to_async_signal_shared_subscription(self, mock_client): """Test that multiple callbacks share the same async signal subscription.""" mock_client.device_manager.get_bec_signals.return_value = [ - ("async_sig1", None, {"device_name": "detector1"}) + ("detector1", None, {"obj_name": "async_sig1", "storage_name": "async_sig1"}) ] scan_item = ScanItem( @@ -520,7 +520,7 @@ def test_data_synchronization_mixed_sources( ): """Test data synchronization between monitored and async sources.""" mock_client.device_manager.get_bec_signals.return_value = [ - ("async_sig1", None, {"device_name": "detector1"}) + ("detector1", None, {"obj_name": "async_sig1", "storage_name": "async_sig1"}) ] plugin = BECLiveDataPlugin(mock_client) @@ -588,7 +588,7 @@ def test_unsubscribe_monitored_device( def test_unsubscribe_async_signal(self, mock_client, mock_callback): """Test unsubscribing from async signal.""" mock_client.device_manager.get_bec_signals.return_value = [ - ("async_sig1", None, {"device_name": "detector1"}) + ("detector1", None, {"obj_name": "async_sig1", "storage_name": "async_sig1"}) ] scan_item = ScanItem( @@ -653,7 +653,7 @@ def test_can_provide_monitored(self, mock_client, scan_item_with_monitored_devic def test_can_provide_async_signal(self, mock_client): """Test can_provide returns True for async signals.""" mock_client.device_manager.get_bec_signals.return_value = [ - ("async_sig1", None, {"device_name": "detector1"}) + ("detector1", None, {"obj_name": "async_sig1", "storage_name": "async_sig1"}) ] scan_item = ScanItem(