From 0fcce6fde85abcd027dd95e00447a34353274471 Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Tue, 6 Jan 2026 09:18:21 +0100 Subject: [PATCH 1/4] refactor: rename json module to json_extended to avoid import conflicts with normal json --- bec_lib/bec_lib/atlas_models.py | 2 +- bec_lib/bec_lib/config_helper.py | 2 +- bec_lib/bec_lib/utils/{json.py => json_extended.py} | 0 bec_lib/tests/test_device_hashing.py | 2 +- bec_lib/tests/test_utils.py | 2 +- 5 files changed, 4 insertions(+), 4 deletions(-) rename bec_lib/bec_lib/utils/{json.py => json_extended.py} (100%) diff --git a/bec_lib/bec_lib/atlas_models.py b/bec_lib/bec_lib/atlas_models.py index c72ca5cef..9676dcefc 100644 --- a/bec_lib/bec_lib/atlas_models.py +++ b/bec_lib/bec_lib/atlas_models.py @@ -12,7 +12,7 @@ from pydantic import BaseModel, Field, PrivateAttr, create_model, field_validator, model_validator from pydantic_core import PydanticUndefined -from bec_lib.utils.json import ExtendedEncoder +from bec_lib.utils.json_extended import ExtendedEncoder _BM = TypeVar("BM", bound=BaseModel) diff --git a/bec_lib/bec_lib/config_helper.py b/bec_lib/bec_lib/config_helper.py index 4281f9b94..11699e024 100644 --- a/bec_lib/bec_lib/config_helper.py +++ b/bec_lib/bec_lib/config_helper.py @@ -31,7 +31,7 @@ from bec_lib.logger import bec_logger from bec_lib.messages import ConfigAction from bec_lib.utils.import_utils import lazy_import_from -from bec_lib.utils.json import ExtendedEncoder +from bec_lib.utils.json_extended import ExtendedEncoder if TYPE_CHECKING: # pragma: no cover from bec_lib.devicemanager import DeviceManagerBase diff --git a/bec_lib/bec_lib/utils/json.py b/bec_lib/bec_lib/utils/json_extended.py similarity index 100% rename from bec_lib/bec_lib/utils/json.py rename to bec_lib/bec_lib/utils/json_extended.py diff --git a/bec_lib/tests/test_device_hashing.py b/bec_lib/tests/test_device_hashing.py index ee283e721..15af1d504 100644 --- a/bec_lib/tests/test_device_hashing.py +++ b/bec_lib/tests/test_device_hashing.py @@ -11,7 +11,7 @@ HashableDeviceSet, HashInclusion, ) -from bec_lib.utils.json import ExtendedEncoder +from bec_lib.utils.json_extended import ExtendedEncoder TEST_DEVICE_DICT = { "name": "test_device", diff --git a/bec_lib/tests/test_utils.py b/bec_lib/tests/test_utils.py index 45025985b..72f5676c3 100644 --- a/bec_lib/tests/test_utils.py +++ b/bec_lib/tests/test_utils.py @@ -2,7 +2,7 @@ from pydantic import BaseModel -from bec_lib.utils.json import ExtendedEncoder +from bec_lib.utils.json_extended import ExtendedEncoder def test_encoder_encodes_set(): From 334c2ac35c4b693a82fd9b465ce6514ab2388bc9 Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Tue, 6 Jan 2026 09:52:56 +0100 Subject: [PATCH 2/4] refactor(scans): remove unused register function --- bec_lib/bec_lib/scans.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/bec_lib/bec_lib/scans.py b/bec_lib/bec_lib/scans.py index b9358398c..c3fb54936 100644 --- a/bec_lib/bec_lib/scans.py +++ b/bec_lib/bec_lib/scans.py @@ -144,17 +144,6 @@ def _run( return report - def _start_register(self, request: messages.ScanQueueMessage) -> None: - """Start a register for the given request""" - self.client.device_manager.connector.register( - [ - MessageEndpoints.device_readback(dev) - for dev in request.content["parameter"]["args"].keys() - ], - threaded=False, - cb=(lambda msg: msg), - ) - def _send_scan_request(self, request: messages.ScanQueueMessage) -> None: """Send a scan request to the scan server""" self.client.device_manager.connector.send(MessageEndpoints.scan_queue_request(), request) From 19f211bff404547272f916646a773aa842f2ba87 Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Tue, 6 Jan 2026 09:54:05 +0100 Subject: [PATCH 3/4] feat(redis connector): add username as RO property --- bec_lib/bec_lib/redis_connector.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/bec_lib/bec_lib/redis_connector.py b/bec_lib/bec_lib/redis_connector.py index 9850f8d67..f142e4923 100644 --- a/bec_lib/bec_lib/redis_connector.py +++ b/bec_lib/bec_lib/redis_connector.py @@ -307,6 +307,16 @@ def authenticate(self, *, username: str = "default", password: str | None = "nul self._redis_conn.connection_pool.connection_kwargs.update(old_kwargs) raise exc + @property + def username(self) -> str: + """ + Get the username used for authentication + + Returns: + str: username + """ + return str(self._redis_conn.acl_whoami()) + def _close_pubsub(self): if self._events_listener_thread: self._stop_events_listener_thread.set() From d5867ed9d3f1840367047b47a939bed0013879e6 Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Tue, 6 Jan 2026 10:28:13 +0100 Subject: [PATCH 4/4] feat: scope scan requests to acl user --- bec_lib/bec_lib/client.py | 2 +- bec_lib/bec_lib/device.py | 9 +- bec_lib/bec_lib/endpoints.py | 10 +- bec_lib/bec_lib/request_items.py | 2 - bec_lib/bec_lib/scan_manager.py | 7 -- bec_lib/bec_lib/scans.py | 9 +- .../bec_server/scan_server/scan_guard.py | 94 ++++++++++++++----- .../tests_scan_server/test_scan_guard.py | 44 +++++++-- 8 files changed, 131 insertions(+), 46 deletions(-) diff --git a/bec_lib/bec_lib/client.py b/bec_lib/bec_lib/client.py index ef8ae52eb..922c81343 100644 --- a/bec_lib/bec_lib/client.py +++ b/bec_lib/bec_lib/client.py @@ -296,7 +296,7 @@ def load_high_level_interface(self, module_name: str) -> None: def _update_username(self): # pylint: disable=protected-access - self._username = self.connector._redis_conn.acl_whoami() + self._username = self.connector.username self._system_user = getpass.getuser() def _start_scan_queue(self): diff --git a/bec_lib/bec_lib/device.py b/bec_lib/bec_lib/device.py index 334bc618a..edd3f7601 100644 --- a/bec_lib/bec_lib/device.py +++ b/bec_lib/bec_lib/device.py @@ -368,8 +368,15 @@ def _run_rpc_call(self, device, func_call, *args, wait_for_rpc_response=True, ** # pylint: disable=protected-access if client.scans._scan_def_id: msg.metadata["scan_def_id"] = client.scans._scan_def_id + + msg.metadata["client_info"] = { + "acl_user": client.username, + "username": client._system_user, + "hostname": client._hostname, + } + # send RPC message - client.connector.send(MessageEndpoints.scan_queue_request(), msg) + client.connector.send(MessageEndpoints.scan_queue_request(client.username), msg) # wait for RPC response if not wait_for_rpc_response: diff --git a/bec_lib/bec_lib/endpoints.py b/bec_lib/bec_lib/endpoints.py index 578486be0..d87f1aeee 100644 --- a/bec_lib/bec_lib/endpoints.py +++ b/bec_lib/bec_lib/endpoints.py @@ -604,15 +604,19 @@ def scan_queue_insert(): ) @staticmethod - def scan_queue_request(): + def scan_queue_request(username: str): """ Endpoint for scan queue request. This endpoint is used to request the new scans. - The request is sent using a messages.ScanQueueMessage message. + The request is sent using a messages.ScanQueueMessage message and scoped to the user + making the request through redis ACLs. + + Args: + username (str): Username of the user making the request. Returns: EndpointInfo: Endpoint for scan queue request. """ - endpoint = f"{EndpointType.USER.value}/queue/queue_request" + endpoint = f"{EndpointType.PERSONAL.value}/{username}/queue/queue_request" return EndpointInfo( endpoint=endpoint, message_type=messages.ScanQueueMessage, message_op=MessageOp.SEND ) diff --git a/bec_lib/bec_lib/request_items.py b/bec_lib/bec_lib/request_items.py index 20ebc268f..f02336731 100644 --- a/bec_lib/bec_lib/request_items.py +++ b/bec_lib/bec_lib/request_items.py @@ -132,9 +132,7 @@ def update_with_response(self, response_msg: messages.RequestResponseMessage) -> logger.trace("Scan queue request exists. Updating with response.") return - # it could be that the response arrived before the request self.storage.append(RequestItem.from_response(self.scan_manager, response_msg)) - logger.trace("Scan queue request does not exist. Creating from response.") def update_with_request(self, request_msg: messages.ScanQueueMessage) -> None: """create or update request item based on a new ScanQueueMessage (i.e. request message)""" diff --git a/bec_lib/bec_lib/scan_manager.py b/bec_lib/bec_lib/scan_manager.py index 993ce766b..8ec30280a 100644 --- a/bec_lib/bec_lib/scan_manager.py +++ b/bec_lib/bec_lib/scan_manager.py @@ -44,9 +44,6 @@ def __init__(self, connector: RedisConnector): self.connector.register( topics=MessageEndpoints.scan_queue_status(), cb=self._scan_queue_status_callback ) - self.connector.register( - topics=MessageEndpoints.scan_queue_request(), cb=self._scan_queue_request_callback - ) self.connector.register( topics=MessageEndpoints.scan_queue_request_response(), cb=self._scan_queue_request_response_callback, @@ -244,10 +241,6 @@ def _scan_queue_status_callback(self, msg, **_kwargs) -> None: return self.update_with_queue_status(queue_status) - def _scan_queue_request_callback(self, msg, **_kwargs) -> None: - request = msg.value - self.request_storage.update_with_request(request) - def _scan_queue_request_response_callback(self, msg, **_kwargs) -> None: response = msg.value self.request_storage.update_with_response(response) diff --git a/bec_lib/bec_lib/scans.py b/bec_lib/bec_lib/scans.py index c3fb54936..40a258f22 100644 --- a/bec_lib/bec_lib/scans.py +++ b/bec_lib/bec_lib/scans.py @@ -124,6 +124,11 @@ def _run( request = Scans.prepare_scan_request(self.scan_name, self.scan_info, *args, **kwargs) request_id = str(uuid.uuid4()) + request.metadata["client_info"] = { + "acl_user": self.client.username, + "username": self.client._system_user, + "hostname": self.client._hostname, + } # pylint: disable=unsupported-assignment-operation request.metadata["RID"] = request_id @@ -146,7 +151,9 @@ def _run( def _send_scan_request(self, request: messages.ScanQueueMessage) -> None: """Send a scan request to the scan server""" - self.client.device_manager.connector.send(MessageEndpoints.scan_queue_request(), request) + self.client.device_manager.connector.send( + MessageEndpoints.scan_queue_request(self.client.username), request + ) class Scans: diff --git a/bec_server/bec_server/scan_server/scan_guard.py b/bec_server/bec_server/scan_server/scan_guard.py index 8d56f4f59..078860f6c 100644 --- a/bec_server/bec_server/scan_server/scan_guard.py +++ b/bec_server/bec_server/scan_server/scan_guard.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re import traceback import uuid from typing import TYPE_CHECKING @@ -43,7 +44,9 @@ def __init__(self, *, parent: ScanServer): self.connector = self.parent.connector self.connector.register( - MessageEndpoints.scan_queue_request(), cb=self._scan_queue_request_callback, parent=self + patterns=MessageEndpoints.scan_queue_request("*"), + cb=self._scan_queue_request_callback, + parent=self, ) self.connector.register( @@ -57,24 +60,61 @@ def __init__(self, *, parent: ScanServer): parent=self, ) - def _is_valid_scan_request(self, request) -> ScanStatus: + def _is_valid_scan_request( + self, request: messages.ScanQueueMessage, username: str + ) -> ScanStatus: + """ + Perform validity checks on the scan request. + Args: + request(messages.ScanQueueMessage): A scan queue message + username(str): The username associated with the request + Returns: + ScanStatus: The status of the scan request + """ try: - self._check_valid_request(request) + self._check_valid_request(request, username) self._check_valid_scan(request) self._check_baton(request) self._check_motors_movable(request) - self._check_soft_limits(request) # pylint: disable=broad-except except Exception: content = traceback.format_exc() return ScanStatus(False, str(content)) return ScanStatus() - def _check_valid_request(self, request) -> None: + def _check_valid_request(self, request: messages.ScanQueueMessage, username: str) -> None: + """ + Check if the scan request is valid. + It must be a proper scan queue message and the username in the topic + must match the client info username, thus ensuring that users cannot + submit scans on behalf of other users. + + Args: + request(messages.ScanQueueMessage): A scan queue message + username(str): The username associated with the request + Raises: + ScanRejection: If the request is invalid + """ if request is None: raise ScanRejection("Invalid request.") + client_info = request.metadata.get("client_info") + if not client_info: + raise ScanRejection("Missing client info in request metadata.") + + # Note: the default redis user does not have an acl username + acl_user = client_info.get("acl_user") or "default" + if acl_user != username: + raise ScanRejection("Username in topic does not match client info.") + + def _check_valid_scan(self, request: messages.ScanQueueMessage) -> None: + """ + Check if the scan is valid and known. - def _check_valid_scan(self, request) -> None: + Args: + request(messages.ScanQueueMessage): A scan queue message + Raises: + ScanRejection: If the scan is invalid + """ avail_scans = self.connector.get(MessageEndpoints.available_scans()) scan_type = request.content.get("scan_type") if scan_type not in avail_scans.resource: @@ -93,20 +133,28 @@ def _device_rpc_is_valid(self, device: str, func: str) -> bool: return False return True - def _check_baton(self, request) -> None: + def _check_baton(self, request: messages.ScanQueueMessage) -> None: # TODO: Implement baton handling pass - def _check_motors_movable(self, request) -> None: - if request.content["scan_type"] == "device_rpc": - device = request.content["parameter"]["device"] + def _check_motors_movable(self, request: messages.ScanQueueMessage) -> None: + """ + Check if the motors involved in the scan request are movable. + Args: + request(messages.ScanQueueMessage): A scan queue message + Raises: + ScanRejection: If any motor is not enabled or movable + """ + parameter = request.parameter + if request.scan_type == "device_rpc": + device = parameter.get("device") if not isinstance(device, list): device = [device] for dev in device: if not self.device_manager.devices[dev].enabled: raise ScanRejection(f"Device {dev} is not enabled.") return - motor_args = request.content["parameter"].get("args") + motor_args = parameter.get("args") if not motor_args: return for motor in motor_args: @@ -119,16 +167,18 @@ def _check_motors_movable(self, request) -> None: if not self.device_manager.devices[motor].enabled: raise ScanRejection(f"Device {motor} is not enabled.") - def _check_soft_limits(self, request) -> None: - # TODO: Implement soft limit checks - pass - @staticmethod def _scan_queue_request_callback(msg, parent, **_kwargs): content = msg.value.content - logger.info(f"Receiving scan request: {content}") + username_regex = f"^{MessageEndpoints.scan_queue_request('([^/]+)').endpoint}$" + result = re.match(username_regex, msg.topic) + if not result: + raise ScanRejection("Could not extract username from topic.") + username = result.group(1) + + logger.info(f"Receiving scan request: {content} from user {username}") # pylint: disable=protected-access - parent._handle_scan_request(msg.value) + parent._handle_scan_request(msg.value, username=username) @staticmethod def _scan_queue_modification_request_callback(msg, parent, **_kwargs): @@ -154,17 +204,15 @@ def _send_scan_request_response(self, scan_status: ScanStatus, metadata: dict): ) self.device_manager.connector.send(sqrr, rrm) - def _handle_scan_request(self, msg: messages.ScanQueueMessage): + def _handle_scan_request(self, msg: messages.ScanQueueMessage, username: str): """ Perform validity checks on the scan request and reply with a 'scan_request_response'. If the scan is accepted it will be enqueued. Args: - msg: ConsumerRecord value - - Returns: - + msg(messages.ScanQueueMessage): A scan queue message + username(str): The username associated with the request """ - scan_status = self._is_valid_scan_request(msg) + scan_status = self._is_valid_scan_request(msg, username=username) self._send_scan_request_response(scan_status, msg.metadata) if not scan_status.accepted: diff --git a/bec_server/tests/tests_scan_server/test_scan_guard.py b/bec_server/tests/tests_scan_server/test_scan_guard.py index 130cb3754..66d04ebb6 100644 --- a/bec_server/tests/tests_scan_server/test_scan_guard.py +++ b/bec_server/tests/tests_scan_server/test_scan_guard.py @@ -74,6 +74,7 @@ def test_device_rpc_is_valid(scan_guard_mock, device, func, is_valid): scan_type="fermat_scan", parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, queue="primary", + metadata={"client_info": {"acl_user": "default"}}, ), True, ), @@ -82,6 +83,7 @@ def test_device_rpc_is_valid(scan_guard_mock, device, func, is_valid): scan_type="device_rpc", parameter={"device": "samy", "args": {}, "kwargs": {}}, queue="primary", + metadata={"client_info": {"acl_user": "default"}}, ), True, ), @@ -90,6 +92,7 @@ def test_device_rpc_is_valid(scan_guard_mock, device, func, is_valid): scan_type="device_rpc", parameter={"device": ["samy"], "args": {}, "kwargs": {}}, queue="primary", + metadata={"client_info": {"acl_user": "default"}}, ), True, ), @@ -107,7 +110,7 @@ def test_valid_request(scan_server_mock, scan_queue_msg, valid): with mock.patch.object(sg, "_check_valid_scan") as valid_scan: k.device_manager.devices["samx"].enabled = True k.device_manager.devices["samy"].enabled = True - status = sg._is_valid_scan_request(scan_queue_msg) + status = sg._is_valid_scan_request(scan_queue_msg, username="default") valid_scan.assert_called_once_with(scan_queue_msg) assert status.accepted == valid @@ -220,10 +223,10 @@ def test_scan_queue_request_callback(scan_guard_mock): parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, queue="primary", ) - msg_obj = MessageObject(MessageEndpoints.scan_queue_request(), msg) + msg_obj = MessageObject(MessageEndpoints.scan_queue_request("default").endpoint, msg) with mock.patch.object(sg, "_handle_scan_request") as handle: sg._scan_queue_request_callback(msg_obj, sg) - handle.assert_called_once_with(msg) + handle.assert_called_once_with(msg, username="default") def test_scan_queue_modification_request_callback(scan_guard_mock): @@ -257,7 +260,7 @@ def test_handle_scan_request(scan_guard_mock): with mock.patch.object(sg, "_is_valid_scan_request") as valid: with mock.patch.object(sg, "_append_to_scan_queue") as append: valid.return_value = ScanStatus(accepted=True, message="") - sg._handle_scan_request(msg) + sg._handle_scan_request(msg, username="default") append.assert_called_once_with(msg) @@ -323,7 +326,7 @@ def test_handle_scan_request_bypassed_for_read(scan_guard_mock, msg): with mock.patch.object(sg, "_is_valid_scan_request") as valid: with mock.patch.object(sg, "_append_to_scan_queue") as append: valid.return_value = ScanStatus(accepted=True, message="") - sg._handle_scan_request(msg) + sg._handle_scan_request(msg, username="default") append.assert_not_called() send.assert_called_once_with(MessageEndpoints.device_instructions(), mock.ANY) @@ -338,7 +341,7 @@ def test_handle_scan_request_rejected(scan_guard_mock): with mock.patch.object(sg, "_is_valid_scan_request") as valid: with mock.patch.object(sg, "_append_to_scan_queue") as append: valid.return_value = ScanStatus(accepted=False, message="") - sg._handle_scan_request(msg) + sg._handle_scan_request(msg, username="default") append.assert_not_called() @@ -348,18 +351,43 @@ def test_is_valid_scan_request_returns_scan_status_on_error(scan_guard_mock): scan_type="fermat_scan", parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, queue="primary", + metadata={"client_info": {"acl_user": "default"}}, ) with mock.patch.object(sg, "_check_valid_scan") as valid: valid.side_effect = Exception("Test exception") - status = sg._is_valid_scan_request(msg) + status = sg._is_valid_scan_request(msg, username="default") assert status.accepted == False assert "Test exception" in status.message +def test_check_valid_request_raises_for_missing_client_info(scan_guard_mock): + sg = scan_guard_mock + msg = messages.ScanQueueMessage( + scan_type="fermat_scan", + parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + queue="primary", + metadata={}, + ) + with pytest.raises(ScanRejection, match="Missing client info in request metadata."): + sg._check_valid_request(msg, username="default") + + +def test_check_valid_request_raises_for_username_mismatch(scan_guard_mock): + sg = scan_guard_mock + msg = messages.ScanQueueMessage( + scan_type="fermat_scan", + parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, + queue="primary", + metadata={"client_info": {"acl_user": "other_user"}}, + ) + with pytest.raises(ScanRejection, match="Username in topic does not match client info."): + sg._check_valid_request(msg, username="default") + + def test_check_valid_request_raises_for_empty_request(scan_guard_mock): sg = scan_guard_mock with pytest.raises(ScanRejection) as scan_rejection: - sg._check_valid_request(None) + sg._check_valid_request(None, username="default") assert "Invalid request." in scan_rejection.value.args