From 755c170f73e239059103f0ea7e28c5c4495877fa Mon Sep 17 00:00:00 2001 From: David Perl Date: Wed, 14 Jan 2026 16:49:46 +0100 Subject: [PATCH 1/6] feat: lock devices during a scan --- .../bec_server/scan_server/device_locking.py | 58 +++++++++++ .../bec_server/scan_server/scan_server.py | 5 +- .../bec_server/scan_server/scan_worker.py | 97 ++++++++++--------- bec_server/bec_server/scan_server/scans.py | 25 ++++- 4 files changed, 138 insertions(+), 47 deletions(-) create mode 100644 bec_server/bec_server/scan_server/device_locking.py diff --git a/bec_server/bec_server/scan_server/device_locking.py b/bec_server/bec_server/scan_server/device_locking.py new file mode 100644 index 000000000..be3b12129 --- /dev/null +++ b/bec_server/bec_server/scan_server/device_locking.py @@ -0,0 +1,58 @@ +import threading +from contextlib import contextmanager +from typing import Dict, Iterable + +from bec_lib.logger import bec_logger + +logger = bec_logger.logger + + +class DeviceLockManager: + """ + Manages locks for devices, identified simply as their name. + Allows acquiring multiple item locks atomically via a context manager. + """ + + def __init__(self) -> None: + self._locks: Dict[str, threading.RLock] = {} + self._locks_guard = threading.RLock() + + def _get_lock(self, key: str) -> threading.RLock: + """ + Get (or create) a lock for a given key. + """ + with self._locks_guard: + if key not in self._locks: + self._locks[key] = threading.RLock() + return self._locks[key] + + @contextmanager + def lock(self, keys: Iterable[str], blocking: bool = True): + """ + Context manager to lock one or more items. + """ + keys = list(set(keys)) + try: + if not self.acquire(*keys, blocking=blocking): + return + yield + finally: + self.release(*keys) + + def acquire(self, *keys: str, blocking: bool = True): + logger.info(f"Locking devices: {keys}") + with self._locks_guard: + new_locks = [] + for key in sorted(keys): + next_lock = self._get_lock(key) + if not next_lock.acquire(blocking=blocking): + [lock.release() for lock in new_locks] + return False + new_locks.append(next_lock) + return True + + def release(self, *keys: str): + logger.info(f"Releasing devices: {keys}") + with self._locks_guard: + for key in reversed(sorted(keys)): + self._get_lock(key).release() diff --git a/bec_server/bec_server/scan_server/scan_server.py b/bec_server/bec_server/scan_server/scan_server.py index a3e439cfc..ea9723216 100644 --- a/bec_server/bec_server/scan_server/scan_server.py +++ b/bec_server/bec_server/scan_server/scan_server.py @@ -2,7 +2,6 @@ from typing import TYPE_CHECKING -from bec_lib import messages from bec_lib.alarm_handler import Alarms from bec_lib.bec_service import BECService from bec_lib.devicemanager import DeviceManagerBase as DeviceManager @@ -10,10 +9,13 @@ from bec_lib.logger import bec_logger from bec_lib.scan_number_container import ScanNumberContainer from bec_lib.service_config import ServiceConfig + +from bec_lib import messages from bec_server.procedures.container_utils import podman_available from bec_server.procedures.container_worker import ContainerProcedureWorker from bec_server.procedures.manager import ProcedureManager from bec_server.procedures.subprocess_worker import SubProcessWorker +from bec_server.scan_server.device_locking import DeviceLockManager from .scan_assembler import ScanAssembler from .scan_guard import ScanGuard @@ -36,6 +38,7 @@ class ScanServer(BECService): def __init__(self, config: ServiceConfig, connector_cls: type[RedisConnector]): super().__init__(config, connector_cls, unique_service=True) + self.device_locks = DeviceLockManager() self._start_scan_manager() self._start_device_manager() self._start_queue_manager() diff --git a/bec_server/bec_server/scan_server/scan_worker.py b/bec_server/bec_server/scan_server/scan_worker.py index d66d94d9a..914712c5f 100644 --- a/bec_server/bec_server/scan_server/scan_worker.py +++ b/bec_server/bec_server/scan_server/scan_worker.py @@ -12,6 +12,7 @@ from bec_lib.endpoints import MessageEndpoints from bec_lib.file_utils import compile_file_components from bec_lib.logger import bec_logger +from bec_server.scan_server.scans import RequestBase from .errors import DeviceInstructionError, ScanAbortion from .scan_queue import InstructionQueueItem, InstructionQueueStatus, RequestBlock @@ -398,25 +399,57 @@ def _process_instructions(self, queue: InstructionQueueItem) -> None: self._wait_for_device_server() queue.is_active = True - try: - for instr in queue: - self._check_for_interruption() - if instr is None: - continue - self._exposure_time = getattr(queue.active_request_block.scan, "exp_time", None) - self._instruction_step(instr) - except ScanAbortion as exc: - if queue.stopped or not (queue.return_to_start and queue.active_request_block): - raise ScanAbortion from exc - queue.stopped = True + scan_instance: RequestBase | None + if (scan_instance := getattr(queue.active_request_block, "scan", None)) is None: + devices_to_lock = [] + else: + devices_to_lock = scan_instance.instance_device_access().device_locking + with self.parent.device_locks.lock(devices_to_lock): try: - cleanup = queue.active_request_block.scan.move_to_start() - self.status = InstructionQueueStatus.RUNNING - for instr in cleanup: + for instr in queue: self._check_for_interruption() - instr.metadata["scan_id"] = queue.queue.active_rb.scan_id - instr.metadata["queue_id"] = queue.queue_id + if instr is None: + continue + self._exposure_time = getattr(queue.active_request_block.scan, "exp_time", None) self._instruction_step(instr) + except ScanAbortion as exc: + if queue.stopped or not (queue.return_to_start and queue.active_request_block): + raise ScanAbortion from exc + queue.stopped = True + try: + cleanup = queue.active_request_block.scan.move_to_start() + self.status = InstructionQueueStatus.RUNNING + for instr in cleanup: + self._check_for_interruption() + instr.metadata["scan_id"] = queue.queue.active_rb.scan_id + instr.metadata["queue_id"] = queue.queue_id + self._instruction_step(instr) + except DeviceInstructionError as exc_di: + content = traceback.format_exc() + logger.error(content) + self.connector.raise_alarm( + severity=Alarms.MAJOR, + info=exc_di.error_info, + metadata=self._get_metadata_for_alarm(), + ) + raise ScanAbortion from exc_di + except Exception as exc_return_to_start: + # if the return_to_start fails, raise the original exception + content = traceback.format_exc() + logger.error(content) + error_info = messages.ErrorInfo( + error_message=content, + compact_error_message=traceback.format_exc(limit=0), + exception_type=exc_return_to_start.__class__.__name__, + device=None, + ) + self.connector.raise_alarm( + severity=Alarms.MAJOR, + info=error_info, + metadata=self._get_metadata_for_alarm(), + ) + raise ScanAbortion from exc + raise ScanAbortion from exc except DeviceInstructionError as exc_di: content = traceback.format_exc() logger.error(content) @@ -425,46 +458,22 @@ def _process_instructions(self, queue: InstructionQueueItem) -> None: info=exc_di.error_info, metadata=self._get_metadata_for_alarm(), ) + raise ScanAbortion from exc_di - except Exception as exc_return_to_start: - # if the return_to_start fails, raise the original exception + except Exception as exc: content = traceback.format_exc() logger.error(content) error_info = messages.ErrorInfo( error_message=content, compact_error_message=traceback.format_exc(limit=0), - exception_type=exc_return_to_start.__class__.__name__, + exception_type=exc.__class__.__name__, device=None, ) self.connector.raise_alarm( severity=Alarms.MAJOR, info=error_info, metadata=self._get_metadata_for_alarm() ) - raise ScanAbortion from exc - raise ScanAbortion from exc - except DeviceInstructionError as exc_di: - content = traceback.format_exc() - logger.error(content) - self.connector.raise_alarm( - severity=Alarms.MAJOR, - info=exc_di.error_info, - metadata=self._get_metadata_for_alarm(), - ) - - raise ScanAbortion from exc_di - except Exception as exc: - content = traceback.format_exc() - logger.error(content) - error_info = messages.ErrorInfo( - error_message=content, - compact_error_message=traceback.format_exc(limit=0), - exception_type=exc.__class__.__name__, - device=None, - ) - self.connector.raise_alarm( - severity=Alarms.MAJOR, info=error_info, metadata=self._get_metadata_for_alarm() - ) - raise ScanAbortion from exc + raise ScanAbortion from exc queue.is_active = False queue.status = InstructionQueueStatus.COMPLETED self.current_instruction_queue_item = None diff --git a/bec_server/bec_server/scan_server/scans.py b/bec_server/bec_server/scan_server/scans.py index 1549429f0..338f9df2d 100644 --- a/bec_server/bec_server/scan_server/scans.py +++ b/bec_server/bec_server/scan_server/scans.py @@ -9,6 +9,7 @@ from typing import Any, Literal import numpy as np +from pydantic import BaseModel from bec_lib import messages from bec_lib.alarm_handler import Alarms @@ -245,7 +246,7 @@ class RequestBase(ABC): """ scan_name = "" - arg_input = {} + arg_input: dict[str, ScanArgType] = {} arg_bundle_size = {"bundle": len(arg_input), "min": None, "max": None} gui_args = {} required_kwargs = [] @@ -376,6 +377,26 @@ def update_readout_priority(self): def run(self): pass + @classmethod + def device_access(cls, scan_parameters: dict) -> ScanDeviceAccessList: + """Provide the devices for which permissions and locking are needed for this scan, with the given parameter set.""" + devices_used_in_scan = set( + str(scan_parameters.get(arg)) + for arg, T in cls.arg_input.items() + if T == ScanArgType.DEVICE + ) + return ScanDeviceAccessList( + device_permissions=devices_used_in_scan, device_locking=devices_used_in_scan + ) + + def instance_device_access(self) -> ScanDeviceAccessList: + return self.device_access(self.parameter) + + +class ScanDeviceAccessList(BaseModel): + device_permissions: set[str] + device_locking: set[str] + class ScanBase(RequestBase, PathOptimizerMixin): """ @@ -403,7 +424,7 @@ class ScanBase(RequestBase, PathOptimizerMixin): Attributes: scan_name (str): name of the scan scan_type (str): scan type. Can be "step" or "fly" - arg_input (list): list of scan argument types + arg_input (dict[str, ScanArgType]): list of scan argument types arg_bundle_size (dict): - bundle: number of arguments that are bundled together - min: minimum number of bundles From 3d5c6a9f7ebf2dbce84a06cd8b758e2b5e7b8b2b Mon Sep 17 00:00:00 2001 From: David Perl Date: Thu, 29 Jan 2026 10:13:43 +0100 Subject: [PATCH 2/6] wip tests --- .../tests/tests_scan_server/conftest.py | 22 +++++++++++++++++++ .../tests_scan_server/test_device_locking.py | 0 2 files changed, 22 insertions(+) create mode 100644 bec_server/tests/tests_scan_server/test_device_locking.py diff --git a/bec_server/tests/tests_scan_server/conftest.py b/bec_server/tests/tests_scan_server/conftest.py index 75bb68c11..88f700d51 100644 --- a/bec_server/tests/tests_scan_server/conftest.py +++ b/bec_server/tests/tests_scan_server/conftest.py @@ -1,8 +1,11 @@ +from typing import Callable, Generator + import fakeredis import pytest from bec_lib.logger import bec_logger from bec_lib.redis_connector import RedisConnector +from bec_server.scan_server.scan_queue import QueueManager # overwrite threads_check fixture from bec_lib, # to have it in autouse @@ -27,3 +30,22 @@ def connected_connector(): yield connector finally: connector.shutdown() + + +@pytest.fixture +def queuemanager_mock( + scan_server_mock, +) -> Generator[Callable[[None | str | list[str]], QueueManager], None, None]: + def _get_queuemanager(queues=None): + scan_server = scan_server_mock + if queues is None: + queues = ["primary"] + if isinstance(queues, str): + queues = [queues] + for queue in queues: + scan_server.queue_manager.add_queue(queue) + return scan_server.queue_manager + + yield _get_queuemanager + + scan_server_mock.queue_manager.shutdown() diff --git a/bec_server/tests/tests_scan_server/test_device_locking.py b/bec_server/tests/tests_scan_server/test_device_locking.py new file mode 100644 index 000000000..e69de29bb From 3a58dbe042c61fa039aa0097d6e6ef2ddfbe6b74 Mon Sep 17 00:00:00 2001 From: David Perl Date: Thu, 29 Jan 2026 10:33:27 +0100 Subject: [PATCH 3/6] wip tests --- bec_server/bec_server/scan_server/scans.py | 18 +++++---- .../tests/tests_scan_server/conftest.py | 7 ++-- .../tests_scan_server/test_device_locking.py | 38 +++++++++++++++++++ .../test_scan_server_queue.py | 18 --------- 4 files changed, 51 insertions(+), 30 deletions(-) diff --git a/bec_server/bec_server/scan_server/scans.py b/bec_server/bec_server/scan_server/scans.py index 338f9df2d..02831fb0e 100644 --- a/bec_server/bec_server/scan_server/scans.py +++ b/bec_server/bec_server/scan_server/scans.py @@ -9,14 +9,14 @@ from typing import Any, Literal import numpy as np -from pydantic import BaseModel - -from bec_lib import messages from bec_lib.alarm_handler import Alarms from bec_lib.device import DeviceBase from bec_lib.devicemanager import DeviceManagerBase from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger +from pydantic import BaseModel + +from bec_lib import messages from bec_server.scan_server.instruction_handler import InstructionHandler from .errors import LimitError, ScanAbortion @@ -380,11 +380,13 @@ def run(self): @classmethod def device_access(cls, scan_parameters: dict) -> ScanDeviceAccessList: """Provide the devices for which permissions and locking are needed for this scan, with the given parameter set.""" - devices_used_in_scan = set( - str(scan_parameters.get(arg)) - for arg, T in cls.arg_input.items() - if T == ScanArgType.DEVICE - ) + arg_devices = set(scan_parameters.get("args", {}).keys()) + param_kwargs = scan_parameters.get("kwargs", {}) + kwarg_devices = set() + for arg, T in cls.arg_input.items(): + if T == ScanArgType.DEVICE and arg in param_kwargs: + kwarg_devices.add(str(param_kwargs[arg])) + devices_used_in_scan = arg_devices | kwarg_devices return ScanDeviceAccessList( device_permissions=devices_used_in_scan, device_locking=devices_used_in_scan ) diff --git a/bec_server/tests/tests_scan_server/conftest.py b/bec_server/tests/tests_scan_server/conftest.py index 88f700d51..167297357 100644 --- a/bec_server/tests/tests_scan_server/conftest.py +++ b/bec_server/tests/tests_scan_server/conftest.py @@ -6,6 +6,7 @@ from bec_lib.logger import bec_logger from bec_lib.redis_connector import RedisConnector from bec_server.scan_server.scan_queue import QueueManager +from bec_server.scan_server.tests.fixtures import scan_server_mock # overwrite threads_check fixture from bec_lib, # to have it in autouse @@ -33,10 +34,8 @@ def connected_connector(): @pytest.fixture -def queuemanager_mock( - scan_server_mock, -) -> Generator[Callable[[None | str | list[str]], QueueManager], None, None]: - def _get_queuemanager(queues=None): +def queuemanager_mock(scan_server_mock): + def _get_queuemanager(queues=None) -> QueueManager: scan_server = scan_server_mock if queues is None: queues = ["primary"] diff --git a/bec_server/tests/tests_scan_server/test_device_locking.py b/bec_server/tests/tests_scan_server/test_device_locking.py index e69de29bb..31502efa4 100644 --- a/bec_server/tests/tests_scan_server/test_device_locking.py +++ b/bec_server/tests/tests_scan_server/test_device_locking.py @@ -0,0 +1,38 @@ +from typing import Callable + +import pytest +from bec_server.scan_server.scan_queue import QueueManager + +from bec_lib import messages + + +@pytest.fixture +def qm_with_3_qs_and_lock_man(queuemanager_mock: Callable[..., QueueManager]): + queue_manager = queuemanager_mock(["1", "2", "3"]) + yield queue_manager, queue_manager.parent.device_locks + + +def _linescan_msg(dev: str, start: float, stop: float): + return messages.ScanQueueMessage( + scan_type="line_scan", + parameter={"args": {dev: (start, stop)}, "kwargs": {}}, + queue="primary", + metadata={"RID": "something"}, + ) + + +def test_devices_from_instance(queuemanager_mock): + q_manager = queuemanager_mock() + assembler = q_manager.parent.scan_assembler + scan_instance = assembler.assemble_device_instructions(_linescan_msg("samx", -1, 1), "test") + device_access = scan_instance.instance_device_access() + assert device_access.device_locking == set(("samx",)) + + +def test_queuemanager_add_to_queue_restarts_queue_if_worker_is_dead(qm_with_3_qs_and_lock_man): + queue_manager, locks = qm_with_3_qs_and_lock_man + msg = _linescan_msg("samx", -5, 5) + + queue_manager.add_to_queue(scan_queue="1", msg=msg) + + ... diff --git a/bec_server/tests/tests_scan_server/test_scan_server_queue.py b/bec_server/tests/tests_scan_server/test_scan_server_queue.py index bd5477ac6..f20f1db92 100644 --- a/bec_server/tests/tests_scan_server/test_scan_server_queue.py +++ b/bec_server/tests/tests_scan_server/test_scan_server_queue.py @@ -20,30 +20,12 @@ ScanQueueStatus, ) from bec_server.scan_server.scan_worker import ScanWorker -from bec_server.scan_server.tests.fixtures import scan_server_mock # pylint: disable=missing-function-docstring # pylint: disable=protected-access ScanQueue.AUTO_SHUTDOWN_TIME = 1 # Reduce auto-shutdown time for testing -@pytest.fixture -def queuemanager_mock(scan_server_mock) -> QueueManager: - def _get_queuemanager(queues=None): - scan_server = scan_server_mock - if queues is None: - queues = ["primary"] - if isinstance(queues, str): - queues = [queues] - for queue in queues: - scan_server.queue_manager.add_queue(queue) - return scan_server.queue_manager - - yield _get_queuemanager - - scan_server_mock.queue_manager.shutdown() - - class RequestBlockQueueMock(RequestBlockQueue): request_blocks = [] _scan_id = [] From 63ced9d42cbf25a0dbf36bb57b8ff86318d3c90d Mon Sep 17 00:00:00 2001 From: David Perl Date: Wed, 4 Feb 2026 16:46:37 +0100 Subject: [PATCH 4/6] wip tests --- bec_lib/bec_lib/tests/utils.py | 14 +++++- .../tests_scan_server/test_device_locking.py | 29 +++++++----- .../tests_scan_server/test_procedures.py | 46 ++++++++----------- 3 files changed, 49 insertions(+), 40 deletions(-) diff --git a/bec_lib/bec_lib/tests/utils.py b/bec_lib/bec_lib/tests/utils.py index a318be1bb..62799378d 100644 --- a/bec_lib/bec_lib/tests/utils.py +++ b/bec_lib/bec_lib/tests/utils.py @@ -5,7 +5,7 @@ import os import time from types import SimpleNamespace -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Callable, Literal import bec_lib from bec_lib import messages @@ -860,3 +860,15 @@ def redis_server_is_running(self): def get_last(self, topic, key): return None + + +def wait_until(predicate: Callable[[], bool], timeout_s: float = 0.1): + """Sleep until 'predicate' returns True, or raise a TimeoutError""" + # Yes I know this is actually more like retries than a timeout, + # it's just to make sure the threads have plenty of chances to switch in the test + elapsed, step = 0.0, timeout_s / 10 + while not predicate(): + time.sleep(step) + elapsed += step + if elapsed > timeout_s: + raise TimeoutError() diff --git a/bec_server/tests/tests_scan_server/test_device_locking.py b/bec_server/tests/tests_scan_server/test_device_locking.py index 31502efa4..3b407546d 100644 --- a/bec_server/tests/tests_scan_server/test_device_locking.py +++ b/bec_server/tests/tests_scan_server/test_device_locking.py @@ -1,9 +1,10 @@ from typing import Callable import pytest -from bec_server.scan_server.scan_queue import QueueManager from bec_lib import messages +from bec_lib.tests.utils import wait_until +from bec_server.scan_server.scan_queue import QueueManager @pytest.fixture @@ -12,27 +13,33 @@ def qm_with_3_qs_and_lock_man(queuemanager_mock: Callable[..., QueueManager]): yield queue_manager, queue_manager.parent.device_locks -def _linescan_msg(dev: str, start: float, stop: float): +def _linescan_msg(*args: tuple[str, float, float]): return messages.ScanQueueMessage( scan_type="line_scan", - parameter={"args": {dev: (start, stop)}, "kwargs": {}}, + parameter={"args": {d: (a, b) for (d, a, b) in args}, "kwargs": {}}, queue="primary", metadata={"RID": "something"}, ) -def test_devices_from_instance(queuemanager_mock): +@pytest.mark.parametrize( + ["msg", "devices"], + [ + (_linescan_msg(("samx", -1, 1)), ("samx",)), + (_linescan_msg(("samx", -1, 1), ("samy", -1, 1)), ("samx", "samy")), + (_linescan_msg(("a", -1, 1), ("b", -1, 1), ("c", -1, 1)), ("a", "b", "c")), + ], +) +def test_devices_from_instance(queuemanager_mock, msg, devices): q_manager = queuemanager_mock() assembler = q_manager.parent.scan_assembler - scan_instance = assembler.assemble_device_instructions(_linescan_msg("samx", -1, 1), "test") + scan_instance = assembler.assemble_device_instructions(msg, "test") device_access = scan_instance.instance_device_access() - assert device_access.device_locking == set(("samx",)) + assert device_access.device_locking == set(devices) -def test_queuemanager_add_to_queue_restarts_queue_if_worker_is_dead(qm_with_3_qs_and_lock_man): +def test_scan_worker_locks_devices_single(qm_with_3_qs_and_lock_man): queue_manager, locks = qm_with_3_qs_and_lock_man - msg = _linescan_msg("samx", -5, 5) - + msg = _linescan_msg(("samx", -5, 5)) queue_manager.add_to_queue(scan_queue="1", msg=msg) - - ... + wait_until(lambda: locks._locks != {}, timeout_s=1) diff --git a/bec_server/tests/tests_scan_server/test_procedures.py b/bec_server/tests/tests_scan_server/test_procedures.py index 268726c2d..8a90f3cf9 100644 --- a/bec_server/tests/tests_scan_server/test_procedures.py +++ b/bec_server/tests/tests_scan_server/test_procedures.py @@ -22,6 +22,7 @@ from bec_lib.procedures.helper import FrontendProcedureHelper, ProcedureState from bec_lib.serialization import MsgpackSerialization from bec_lib.service_config import ServiceConfig +from bec_lib.tests.utils import wait_until from bec_server.procedures.builtin_procedures import run_macro from bec_server.procedures.constants import PROCEDURE, BecProcedure, WorkerAlreadyExists from bec_server.procedures.manager import ProcedureManager @@ -272,17 +273,6 @@ def __init__(self, server: str, queue: str, lifetime_s: int | None = None, execu self.execution_id = execution_id -def _wait_until(predicate: Callable[[], bool], timeout_s: float = 0.1): - # Yes I know this is actually more like retries than a timeout, - # it's just to make sure the threads have plenty of chances to switch in the test - elapsed, step = 0.0, timeout_s / 10 - while not predicate(): - time.sleep(step) - elapsed += step - if elapsed > timeout_s: - raise TimeoutError() - - @patch("bec_server.procedures.worker_base.RedisConnector") @patch("bec_server.procedures.manager.RedisConnector", MagicMock()) def test_spawn(redis_connector, procedure_manager: ProcedureManager): @@ -297,12 +287,12 @@ def test_spawn(redis_connector, procedure_manager: ProcedureManager): assert queue in procedure_manager._active_workers.keys() # spawn method should be added as a future - _wait_until(procedure_manager._active_workers[queue]["future"].running) + wait_until(procedure_manager._active_workers[queue]["future"].running) # and then create the worker - _wait_until(lambda: procedure_manager._active_workers[queue].get("worker") is not None) + wait_until(lambda: procedure_manager._active_workers[queue].get("worker") is not None) worker = procedure_manager._active_workers[queue]["worker"] assert isinstance(worker, UnlockableWorker) - _wait_until(lambda: worker.status == ProcedureWorkerStatus.RUNNING) + wait_until(lambda: worker.status == ProcedureWorkerStatus.RUNNING) # check that you can't instantiate the same worker twice - call spawn directly to # raise the exception in this thread @@ -313,11 +303,11 @@ def test_spawn(redis_connector, procedure_manager: ProcedureManager): with procedure_manager.lock: worker.event_1.set() # let the task end and return to ProcedureWorker.work() # queue deletion callback needs the lock so we can catch it in FINISHED - _wait_until(lambda: worker.status == ProcedureWorkerStatus.IDLE) + wait_until(lambda: worker.status == ProcedureWorkerStatus.IDLE) worker.event_2.set() - _wait_until(lambda: worker.status == ProcedureWorkerStatus.FINISHED) + wait_until(lambda: worker.status == ProcedureWorkerStatus.FINISHED) # spawn deletes the worker queue - _wait_until(lambda: len(procedure_manager._active_workers) == 0) + wait_until(lambda: len(procedure_manager._active_workers) == 0) @patch("bec_server.procedures.builtin_procedures.logger") @@ -401,20 +391,20 @@ def test_manager_status_api(_conn, procedure_manager): procedure_manager._worker_cls = UnlockableWorker for message in PROCESS_REQUEST_TEST_CASES: procedure_manager._process_queue_request(message) - _wait_until(lambda: procedure_manager.active_workers() == ["primary", "queue2"]) - _wait_until( + wait_until(lambda: procedure_manager.active_workers() == ["primary", "queue2"]) + wait_until( lambda: procedure_manager.worker_statuses() == {"primary": ProcedureWorkerStatus.RUNNING, "queue2": ProcedureWorkerStatus.RUNNING} ) for w in procedure_manager._active_workers.values(): w["worker"].event_1.set() - _wait_until( + wait_until( lambda: procedure_manager.worker_statuses() == {"primary": ProcedureWorkerStatus.IDLE, "queue2": ProcedureWorkerStatus.IDLE} ) for w in procedure_manager._active_workers.values(): w["worker"].event_2.set() - _wait_until(lambda: procedure_manager.active_workers() == []) + wait_until(lambda: procedure_manager.active_workers() == []) _ManagerWithMsgs = tuple[ProcedureManager, list[ProcedureExecutionMessage]] @@ -548,14 +538,14 @@ def test_abort_all(manager_with_test_msgs: _ManagerWithMsgs): def test_procedure_status_rejected(procedure_manager): status = procedure_manager._helper.request.procedure("doesn't exist") assert status.state == ProcedureState.REQUESTED - _wait_until(lambda: status.state == ProcedureState.REJECTED) + wait_until(lambda: status.state == ProcedureState.REJECTED) assert status.done @patch("bec_server.procedures.oop_worker_base.BECIPythonClient", MagicMock) def test_procedure_status_rejected_not_cancellable(procedure_manager): status = procedure_manager._helper.request.procedure("doesn't exist") - _wait_until(lambda: status.state == ProcedureState.REJECTED) + wait_until(lambda: status.state == ProcedureState.REJECTED) with pytest.raises(ValueError) as e: status.cancel() @@ -571,13 +561,13 @@ def test_procedure_status_accepted(procedure_manager): ) status = procedure_manager._helper.request._procedure(msg) assert status.state == ProcedureState.REQUESTED - _wait_until(lambda: procedure_manager._active_workers.get("primary") is not None, timeout_s=1) + wait_until(lambda: procedure_manager._active_workers.get("primary") is not None, timeout_s=1) worker = procedure_manager._active_workers["primary"]["worker"] assert isinstance(worker, FakeRedisUnlockable) worker.event_1.set() - _wait_until(lambda: status.state == ProcedureState.RUNNING, timeout_s=10) + wait_until(lambda: status.state == ProcedureState.RUNNING, timeout_s=10) worker.event_2.set() - _wait_until(lambda: status.state == ProcedureState.SUCCESS, timeout_s=10) + wait_until(lambda: status.state == ProcedureState.SUCCESS, timeout_s=10) def _mock_error_procedure(*args, **kwargs): @@ -590,7 +580,7 @@ def test_procedure_status_error(procedure_manager): msg = ProcedureRequestMessage(identifier="error", execution_id="test") status = procedure_manager._helper.request._procedure(msg) assert status.state == ProcedureState.REQUESTED - _wait_until(lambda: procedure_manager._active_workers.get("primary") is not None, timeout_s=1) + wait_until(lambda: procedure_manager._active_workers.get("primary") is not None, timeout_s=1) worker = procedure_manager._active_workers["primary"]["worker"] assert isinstance(worker, InlineWorker) @@ -616,7 +606,7 @@ def test_procedure_status_error(procedure_manager): assert "RuntimeError: Encountered error in procedure" in str(status) worker._ending = True - _wait_until(lambda: procedure_manager._active_workers.get("primary") is None, timeout_s=2) + wait_until(lambda: procedure_manager._active_workers.get("primary") is None, timeout_s=2) def test_builtin_proc_run_macro_found(shutdown_client): From 5e9d68b17655fdd3eb60de9365f789fad2773bda Mon Sep 17 00:00:00 2001 From: David Perl Date: Wed, 4 Feb 2026 17:27:08 +0100 Subject: [PATCH 5/6] style: factorise scan worker --- .../bec_server/scan_server/scan_worker.py | 163 +++++++++--------- 1 file changed, 86 insertions(+), 77 deletions(-) diff --git a/bec_server/bec_server/scan_server/scan_worker.py b/bec_server/bec_server/scan_server/scan_worker.py index 914712c5f..61ca85031 100644 --- a/bec_server/bec_server/scan_server/scan_worker.py +++ b/bec_server/bec_server/scan_server/scan_worker.py @@ -4,14 +4,17 @@ import threading import time import traceback +from functools import partial from string import Template -from typing import TYPE_CHECKING, Literal +from traceback import TracebackException +from typing import TYPE_CHECKING, Callable, Literal from bec_lib import messages from bec_lib.alarm_handler import Alarms from bec_lib.endpoints import MessageEndpoints from bec_lib.file_utils import compile_file_components from bec_lib.logger import bec_logger +from bec_lib.messages import ErrorInfo from bec_server.scan_server.scans import RequestBase from .errors import DeviceInstructionError, ScanAbortion @@ -24,6 +27,12 @@ from bec_server.scan_server.scan_server import ScanServer +class _NewErrorInfo: ... + + +_NEW_ERRORINFO = _NewErrorInfo() + + class ScanWorker(threading.Thread): """ Scan worker receives device instructions and pre-processes them before sending them to the device server @@ -377,107 +386,107 @@ def _get_metadata_for_alarm(self) -> dict: metadata["scan_number"] = self.current_scan_info["scan_number"] return metadata - def _process_instructions(self, queue: InstructionQueueItem) -> None: - """ - Process scan instructions and send DeviceInstructions to OPAAS. - For now this is an in-memory communication. In the future however, - we might want to pass it through a dedicated Kafka topic. - Args: - queue: instruction queue + ############################# + # PROCESS INSTRUCTIONS LOOP # + ############################# - Returns: - - """ + def _init_process_loop(self, queue: InstructionQueueItem) -> float | None: + """Sets up the conditions for the loop and returns the start time""" if not queue: return None self.current_instruction_queue_item = queue start = time.time() self.max_point_id = 0 - # make sure the device server is ready to receive data self._wait_for_device_server() queue.is_active = True - scan_instance: RequestBase | None - if (scan_instance := getattr(queue.active_request_block, "scan", None)) is None: - devices_to_lock = [] - else: - devices_to_lock = scan_instance.instance_device_access().device_locking - with self.parent.device_locks.lock(devices_to_lock): + return start + + def _propagate_instruction_error( + self, content: str, info: ErrorInfo | Callable[[str], ErrorInfo], exc: Exception + ): + logger.error(content) + info = info(content) if callable(info) else info + self.connector.raise_alarm( + severity=Alarms.MAJOR, info=info, metadata=self._get_metadata_for_alarm() + ) + raise ScanAbortion from exc + + def _process_instructions_inner(self, queue: InstructionQueueItem): + try: + for instr in queue: + self._check_for_interruption() + if instr is None: + continue + self._exposure_time = getattr(queue.active_request_block.scan, "exp_time", None) + self._instruction_step(instr) + except ScanAbortion as exc: + if queue.stopped or not (queue.return_to_start and queue.active_request_block): + raise ScanAbortion from exc + queue.stopped = True try: - for instr in queue: + cleanup = queue.active_request_block.scan.move_to_start() + self.status = InstructionQueueStatus.RUNNING + for instr in cleanup: self._check_for_interruption() - if instr is None: - continue - self._exposure_time = getattr(queue.active_request_block.scan, "exp_time", None) + instr.metadata["scan_id"] = queue.queue.active_rb.scan_id + instr.metadata["queue_id"] = queue.queue_id self._instruction_step(instr) - except ScanAbortion as exc: - if queue.stopped or not (queue.return_to_start and queue.active_request_block): - raise ScanAbortion from exc - queue.stopped = True - try: - cleanup = queue.active_request_block.scan.move_to_start() - self.status = InstructionQueueStatus.RUNNING - for instr in cleanup: - self._check_for_interruption() - instr.metadata["scan_id"] = queue.queue.active_rb.scan_id - instr.metadata["queue_id"] = queue.queue_id - self._instruction_step(instr) - except DeviceInstructionError as exc_di: - content = traceback.format_exc() - logger.error(content) - self.connector.raise_alarm( - severity=Alarms.MAJOR, - info=exc_di.error_info, - metadata=self._get_metadata_for_alarm(), - ) - raise ScanAbortion from exc_di - except Exception as exc_return_to_start: - # if the return_to_start fails, raise the original exception - content = traceback.format_exc() - logger.error(content) - error_info = messages.ErrorInfo( - error_message=content, + except DeviceInstructionError as exc_di: + self._propagate_instruction_error(traceback.format_exc(), exc_di.error_info, exc_di) + except Exception as exc_return_to_start: + # if the return_to_start fails, raise the original exception + self._propagate_instruction_error( + traceback.format_exc(), + lambda msg: ErrorInfo( + error_message=msg, compact_error_message=traceback.format_exc(limit=0), exception_type=exc_return_to_start.__class__.__name__, device=None, - ) - self.connector.raise_alarm( - severity=Alarms.MAJOR, - info=error_info, - metadata=self._get_metadata_for_alarm(), - ) - raise ScanAbortion from exc - raise ScanAbortion from exc - except DeviceInstructionError as exc_di: - content = traceback.format_exc() - logger.error(content) - self.connector.raise_alarm( - severity=Alarms.MAJOR, - info=exc_di.error_info, - metadata=self._get_metadata_for_alarm(), + ), + exc_return_to_start, ) - - raise ScanAbortion from exc_di - except Exception as exc: - content = traceback.format_exc() - logger.error(content) - error_info = messages.ErrorInfo( - error_message=content, + raise ScanAbortion from exc + except DeviceInstructionError as exc_di: + self._propagate_instruction_error(traceback.format_exc(), exc_di.error_info, exc_di) + except Exception as exc: + self._propagate_instruction_error( + traceback.format_exc(), + lambda msg: ErrorInfo( + error_message=msg, compact_error_message=traceback.format_exc(limit=0), exception_type=exc.__class__.__name__, device=None, - ) - self.connector.raise_alarm( - severity=Alarms.MAJOR, info=error_info, metadata=self._get_metadata_for_alarm() - ) + ), + exc, + ) + + def _process_instructions(self, queue: InstructionQueueItem) -> None: + """ + Process scan instructions and send DeviceInstructions to OPAAS. + For now this is an in-memory communication. In the future however, + we might want to pass it through a dedicated Kafka topic. + Args: + queue: instruction queue + + Returns: + + """ + if (start := self._init_process_loop(queue)) is None: + return None + + scan_instance: RequestBase | None = getattr(queue.active_request_block, "scan", None) + devices_to_lock = ( + [] if scan_instance is None else scan_instance.instance_device_access().device_locking + ) + with self.parent.device_locks.lock(devices_to_lock): + self._process_instructions_inner(queue) - raise ScanAbortion from exc queue.is_active = False queue.status = InstructionQueueStatus.COMPLETED self.current_instruction_queue_item = None - logger.info(f"QUEUE ITEM finished after {time.time()-start:.2f} seconds") self.reset() From 7dc8f2aa2c733f1ef595f21524153750a3acc94f Mon Sep 17 00:00:00 2001 From: David Perl Date: Wed, 4 Feb 2026 17:28:09 +0100 Subject: [PATCH 6/6] f --- .../tests/tests_scan_server/test_device_locking.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/bec_server/tests/tests_scan_server/test_device_locking.py b/bec_server/tests/tests_scan_server/test_device_locking.py index 3b407546d..18ccd1495 100644 --- a/bec_server/tests/tests_scan_server/test_device_locking.py +++ b/bec_server/tests/tests_scan_server/test_device_locking.py @@ -38,8 +38,8 @@ def test_devices_from_instance(queuemanager_mock, msg, devices): assert device_access.device_locking == set(devices) -def test_scan_worker_locks_devices_single(qm_with_3_qs_and_lock_man): - queue_manager, locks = qm_with_3_qs_and_lock_man - msg = _linescan_msg(("samx", -5, 5)) - queue_manager.add_to_queue(scan_queue="1", msg=msg) - wait_until(lambda: locks._locks != {}, timeout_s=1) +# def test_scan_worker_locks_devices_single(qm_with_3_qs_and_lock_man): +# queue_manager, locks = qm_with_3_qs_and_lock_man +# msg = _linescan_msg(("samx", -5, 5)) +# queue_manager.add_to_queue(scan_queue="1", msg=msg) +# wait_until(lambda: locks._locks != {}, timeout_s=1)