Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion bec_lib/bec_lib/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import time
from types import SimpleNamespace
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Callable, Literal

import bec_lib
from bec_lib import messages
Expand Down Expand Up @@ -860,3 +860,15 @@ def redis_server_is_running(self):

def get_last(self, topic, key):
    """Stub lookup: return None for every topic/key combination."""
    return None


def wait_until(predicate: Callable[[], bool], timeout_s: float = 0.1):
    """Sleep until 'predicate' returns True, or raise a TimeoutError.

    The budget is spent as a fixed number of polls instead of being measured
    against a clock. This is deliberately closer to "retries" than to a real
    timeout: it makes sure other threads get plenty of chances to switch in
    while a test is waiting.

    Args:
        predicate: zero-argument callable, polled until it returns a truthy value.
        timeout_s: nominal total budget in seconds; the pause between two polls
            is a tenth of it.

    Raises:
        TimeoutError: if 'predicate' is still falsy after the last poll.
    """
    polls_within_budget = 10
    step = timeout_s / polls_within_budget
    # Counting polls (rather than summing float steps and comparing the sum to
    # timeout_s) keeps the number of attempts deterministic and guarantees
    # termination even for timeout_s == 0, where the step is zero.
    # One extra poll is made once the nominal budget has been used up.
    for _ in range(polls_within_budget + 1):
        if predicate():
            return
        time.sleep(step)
    raise TimeoutError(f"Condition not met within {timeout_s} s")
58 changes: 58 additions & 0 deletions bec_server/bec_server/scan_server/device_locking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import threading
from contextlib import contextmanager
from typing import Dict, Iterable

from bec_lib.logger import bec_logger

logger = bec_logger.logger


class DeviceLockManager:
    """
    Manages locks for devices, identified simply as their name.
    Allows acquiring multiple item locks all-or-nothing via a context manager.

    Each device name maps to its own re-entrant lock, created on first use.
    Locks are always taken in sorted name order, so two threads requesting
    overlapping sets of devices cannot deadlock each other.
    """

    def __init__(self) -> None:
        self._locks: Dict[str, threading.RLock] = {}
        # Protects the name -> lock registry only. It must never be held while
        # waiting for a device lock, otherwise the thread that currently owns
        # that device lock could not get through release().
        self._locks_guard = threading.RLock()

    def _get_lock(self, key: str) -> threading.RLock:
        """
        Get (or create) a lock for a given key.
        """
        with self._locks_guard:
            if key not in self._locks:
                self._locks[key] = threading.RLock()
            return self._locks[key]

    @contextmanager
    def lock(self, keys: Iterable[str], blocking: bool = True):
        """
        Context manager to lock one or more items.

        Args:
            keys: names of the devices to lock; duplicates are ignored.
            blocking: if True, wait until every lock is available. If False,
                fail immediately when any of the locks is currently held by
                another thread.

        Raises:
            RuntimeError: if the locks could not be acquired (only possible
                with blocking=False). The managed body is not executed and no
                lock is left held.
        """
        unique_keys = list(set(keys))
        # Acquire outside the try/finally: when acquisition fails there is
        # nothing to release, and releasing anyway would either raise an
        # unrelated error or drop a re-entrant level owned by an outer scope.
        if not self.acquire(*unique_keys, blocking=blocking):
            raise RuntimeError(f"Could not acquire locks for devices: {sorted(unique_keys)}")
        try:
            yield
        finally:
            self.release(*unique_keys)

    def acquire(self, *keys: str, blocking: bool = True):
        """
        Acquire the locks of all given devices, in sorted name order.

        Args:
            *keys: names of the devices to lock.
            blocking: passed on to each individual lock acquisition.

        Returns:
            True if every lock was acquired. False if one of them was
            unavailable; in that case the locks taken so far by this call are
            released again before returning.
        """
        logger.info(f"Locking devices: {keys}")
        acquired = []
        for key in sorted(keys):
            # _get_lock only holds the registry guard for the lookup itself,
            # so the (possibly blocking) acquire below happens without it.
            next_lock = self._get_lock(key)
            if not next_lock.acquire(blocking=blocking):
                # roll back, so that the call is all-or-nothing
                for held in reversed(acquired):
                    held.release()
                return False
            acquired.append(next_lock)
        return True

    def release(self, *keys: str):
        """
        Release the locks of all given devices, in reverse sorted name order.

        Args:
            *keys: names of the devices to unlock.

        Raises:
            RuntimeError: raised by the underlying lock if the calling thread
                does not own it.
        """
        logger.info(f"Releasing devices: {keys}")
        for key in sorted(keys, reverse=True):
            self._get_lock(key).release()
5 changes: 4 additions & 1 deletion bec_server/bec_server/scan_server/scan_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@

from typing import TYPE_CHECKING

from bec_lib import messages
from bec_lib.alarm_handler import Alarms
from bec_lib.bec_service import BECService
from bec_lib.devicemanager import DeviceManagerBase as DeviceManager
from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger
from bec_lib.scan_number_container import ScanNumberContainer
from bec_lib.service_config import ServiceConfig

from bec_lib import messages
from bec_server.procedures.container_utils import podman_available
from bec_server.procedures.container_worker import ContainerProcedureWorker
from bec_server.procedures.manager import ProcedureManager
from bec_server.procedures.subprocess_worker import SubProcessWorker
from bec_server.scan_server.device_locking import DeviceLockManager

from .scan_assembler import ScanAssembler
from .scan_guard import ScanGuard
Expand All @@ -36,6 +38,7 @@ class ScanServer(BECService):

def __init__(self, config: ServiceConfig, connector_cls: type[RedisConnector]):
super().__init__(config, connector_cls, unique_service=True)
self.device_locks = DeviceLockManager()
self._start_scan_manager()
self._start_device_manager()
self._start_queue_manager()
Expand Down
122 changes: 70 additions & 52 deletions bec_server/bec_server/scan_server/scan_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@
import threading
import time
import traceback
from functools import partial
from string import Template
from typing import TYPE_CHECKING, Literal
from traceback import TracebackException
from typing import TYPE_CHECKING, Callable, Literal

from bec_lib import messages
from bec_lib.alarm_handler import Alarms
from bec_lib.endpoints import MessageEndpoints
from bec_lib.file_utils import compile_file_components
from bec_lib.logger import bec_logger
from bec_lib.messages import ErrorInfo
from bec_server.scan_server.scans import RequestBase

from .errors import DeviceInstructionError, ScanAbortion
from .scan_queue import InstructionQueueItem, InstructionQueueStatus, RequestBlock
Expand All @@ -23,6 +27,12 @@
from bec_server.scan_server.scan_server import ScanServer


# Empty marker class with no attributes or behaviour of its own.
# NOTE(review): not referenced anywhere in the visible part of this change -
# presumably intended as a sentinel type; confirm it is used, or remove it.
class _NewErrorInfo: ...


# Module-level instance of the marker class above.
_NEW_ERRORINFO = _NewErrorInfo()


class ScanWorker(threading.Thread):
"""
Scan worker receives device instructions and pre-processes them before sending them to the device server
Expand Down Expand Up @@ -376,28 +386,35 @@ def _get_metadata_for_alarm(self) -> dict:
metadata["scan_number"] = self.current_scan_info["scan_number"]
return metadata

def _process_instructions(self, queue: InstructionQueueItem) -> None:
"""
Process scan instructions and send DeviceInstructions to OPAAS.
For now this is an in-memory communication. In the future however,
we might want to pass it through a dedicated Kafka topic.
Args:
queue: instruction queue
#############################
# PROCESS INSTRUCTIONS LOOP #
#############################

Returns:

"""
def _init_process_loop(self, queue: InstructionQueueItem) -> float | None:
"""Sets up the conditions for the loop and returns the start time"""
if not queue:
return None
self.current_instruction_queue_item = queue

start = time.time()
self.max_point_id = 0

# make sure the device server is ready to receive data
self._wait_for_device_server()

queue.is_active = True
return start

def _propagate_instruction_error(
    self, content: str, info: ErrorInfo | Callable[[str], ErrorInfo], exc: Exception
):
    """
    Log an error, raise a MAJOR alarm for it and abort the scan.

    Args:
        content: error text; it is logged and, if `info` is a callable,
            handed to that callable
        info: either the error info to attach to the alarm, or a callable
            that builds it from `content`
        exc: the exception that triggered the abortion; set as the cause of
            the raised ScanAbortion

    Raises:
        ScanAbortion: always
    """
    logger.error(content)
    if callable(info):
        error_info = info(content)
    else:
        error_info = info
    self.connector.raise_alarm(
        severity=Alarms.MAJOR, info=error_info, metadata=self._get_metadata_for_alarm()
    )
    raise ScanAbortion from exc

def _process_instructions_inner(self, queue: InstructionQueueItem):
try:
for instr in queue:
self._check_for_interruption()
Expand All @@ -418,57 +435,58 @@ def _process_instructions(self, queue: InstructionQueueItem) -> None:
instr.metadata["queue_id"] = queue.queue_id
self._instruction_step(instr)
except DeviceInstructionError as exc_di:
content = traceback.format_exc()
logger.error(content)
self.connector.raise_alarm(
severity=Alarms.MAJOR,
info=exc_di.error_info,
metadata=self._get_metadata_for_alarm(),
)
raise ScanAbortion from exc_di
self._propagate_instruction_error(traceback.format_exc(), exc_di.error_info, exc_di)
except Exception as exc_return_to_start:
# if the return_to_start fails, raise the original exception
content = traceback.format_exc()
logger.error(content)
error_info = messages.ErrorInfo(
error_message=content,
compact_error_message=traceback.format_exc(limit=0),
exception_type=exc_return_to_start.__class__.__name__,
device=None,
)
self.connector.raise_alarm(
severity=Alarms.MAJOR, info=error_info, metadata=self._get_metadata_for_alarm()
self._propagate_instruction_error(
traceback.format_exc(),
lambda msg: ErrorInfo(
error_message=msg,
compact_error_message=traceback.format_exc(limit=0),
exception_type=exc_return_to_start.__class__.__name__,
device=None,
),
exc_return_to_start,
)
raise ScanAbortion from exc
raise ScanAbortion from exc
except DeviceInstructionError as exc_di:
content = traceback.format_exc()
logger.error(content)
self.connector.raise_alarm(
severity=Alarms.MAJOR,
info=exc_di.error_info,
metadata=self._get_metadata_for_alarm(),
)

raise ScanAbortion from exc_di
self._propagate_instruction_error(traceback.format_exc(), exc_di.error_info, exc_di)
except Exception as exc:
content = traceback.format_exc()
logger.error(content)
error_info = messages.ErrorInfo(
error_message=content,
compact_error_message=traceback.format_exc(limit=0),
exception_type=exc.__class__.__name__,
device=None,
)
self.connector.raise_alarm(
severity=Alarms.MAJOR, info=error_info, metadata=self._get_metadata_for_alarm()
self._propagate_instruction_error(
traceback.format_exc(),
lambda msg: ErrorInfo(
error_message=msg,
compact_error_message=traceback.format_exc(limit=0),
exception_type=exc.__class__.__name__,
device=None,
),
exc,
)

raise ScanAbortion from exc
def _process_instructions(self, queue: InstructionQueueItem) -> None:
    """
    Process the instructions of a queue item while holding the locks of the
    devices its scan declares for locking.

    The instruction loop itself is delegated to `_process_instructions_inner`.
    Afterwards the queue item is marked as inactive and completed, and the
    worker is reset.

    Args:
        queue: instruction queue

    Returns:
        None. Nothing is done if `_init_process_loop` returns None for the
        given queue.
    """
    start = self._init_process_loop(queue)
    if start is None:
        return None

    # not every request block carries a scan; without one there is nothing to lock
    scan_instance: RequestBase | None = getattr(queue.active_request_block, "scan", None)
    if scan_instance is None:
        devices_to_lock = []
    else:
        devices_to_lock = scan_instance.instance_device_access().device_locking

    with self.parent.device_locks.lock(devices_to_lock):
        self._process_instructions_inner(queue)

    queue.is_active = False
    queue.status = InstructionQueueStatus.COMPLETED
    self.current_instruction_queue_item = None

    logger.info(f"QUEUE ITEM finished after {time.time()-start:.2f} seconds")
    self.reset()

Expand Down
31 changes: 27 additions & 4 deletions bec_server/bec_server/scan_server/scans.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
from typing import Any, Literal

import numpy as np

from bec_lib import messages
from bec_lib.alarm_handler import Alarms
from bec_lib.device import DeviceBase
from bec_lib.devicemanager import DeviceManagerBase
from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger
from pydantic import BaseModel

from bec_lib import messages
from bec_server.scan_server.instruction_handler import InstructionHandler

from .errors import LimitError, ScanAbortion
Expand Down Expand Up @@ -245,7 +246,7 @@ class RequestBase(ABC):
"""

scan_name = ""
arg_input = {}
arg_input: dict[str, ScanArgType] = {}
arg_bundle_size = {"bundle": len(arg_input), "min": None, "max": None}
gui_args = {}
required_kwargs = []
Expand Down Expand Up @@ -376,6 +377,28 @@ def update_readout_priority(self):
def run(self):
pass

@classmethod
def device_access(cls, scan_parameters: dict) -> ScanDeviceAccessList:
    """
    Provide the devices for which permissions and locking are needed for this scan,
    with the given parameter set.

    Devices are collected from two places:
        - the keys of "args", when "args" is a mapping
        - the values of those "kwargs" whose name is declared as
          ScanArgType.DEVICE in `arg_input`

    Args:
        scan_parameters: parameter set of the scan; the entries "args" and
            "kwargs" are read, both are optional

    Returns:
        ScanDeviceAccessList with the same set of device names used for both
        permissions and locking.
    """
    raw_args = scan_parameters.get("args")
    # "args" may be missing, None or a plain sequence instead of a mapping;
    # none of those has keys to read device names from.
    # NOTE(review): devices passed positionally in a sequence are not
    # detected here - confirm whether any scan relies on that.
    arg_devices = set(raw_args.keys()) if hasattr(raw_args, "keys") else set()

    param_kwargs = scan_parameters.get("kwargs") or {}
    kwarg_devices = {
        str(param_kwargs[arg])
        for arg, arg_type in cls.arg_input.items()
        if arg_type == ScanArgType.DEVICE and arg in param_kwargs
    }

    devices_used_in_scan = arg_devices | kwarg_devices
    return ScanDeviceAccessList(
        device_permissions=devices_used_in_scan, device_locking=devices_used_in_scan
    )

def instance_device_access(self) -> ScanDeviceAccessList:
    """
    Provide the device access list for this instance.

    Returns:
        The result of `device_access` called with this instance's `parameter`.
    """
    resolve_access = self.device_access
    return resolve_access(self.parameter)


# Model grouping the device names a scan needs access to.
class ScanDeviceAccessList(BaseModel):
    # names of the devices for which permissions are needed
    device_permissions: set[str]
    # names of the devices for which locking is needed
    device_locking: set[str]


class ScanBase(RequestBase, PathOptimizerMixin):
"""
Expand Down Expand Up @@ -403,7 +426,7 @@ class ScanBase(RequestBase, PathOptimizerMixin):
Attributes:
scan_name (str): name of the scan
scan_type (str): scan type. Can be "step" or "fly"
arg_input (list): list of scan argument types
arg_input (dict[str, ScanArgType]): mapping of scan argument names to their argument types
arg_bundle_size (dict):
- bundle: number of arguments that are bundled together
- min: minimum number of bundles
Expand Down
21 changes: 21 additions & 0 deletions bec_server/tests/tests_scan_server/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import Callable, Generator

import fakeredis
import pytest

from bec_lib.logger import bec_logger
from bec_lib.redis_connector import RedisConnector
from bec_server.scan_server.scan_queue import QueueManager
from bec_server.scan_server.tests.fixtures import scan_server_mock

# overwrite threads_check fixture from bec_lib,
# to have it in autouse
Expand All @@ -27,3 +31,20 @@ def connected_connector():
yield connector
finally:
connector.shutdown()


@pytest.fixture
def queuemanager_mock(scan_server_mock):
    """
    Provide a factory that returns the queue manager of the mocked scan server,
    after adding the requested queues to it.

    The factory accepts a single queue name, an iterable of queue names, or
    None (in which case the queue "primary" is added). The queue manager is
    shut down on fixture teardown.
    """

    def _get_queuemanager(queues=None) -> QueueManager:
        if queues is None:
            queue_names = ["primary"]
        elif isinstance(queues, str):
            # a single name is treated as a one-element list
            queue_names = [queues]
        else:
            queue_names = queues
        for name in queue_names:
            scan_server_mock.queue_manager.add_queue(name)
        return scan_server_mock.queue_manager

    yield _get_queuemanager

    scan_server_mock.queue_manager.shutdown()
Loading
Loading