Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bec_lib/bec_lib/atlas_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pydantic import BaseModel, Field, PrivateAttr, create_model, field_validator, model_validator
from pydantic_core import PydanticUndefined

from bec_lib.utils.json import ExtendedEncoder
from bec_lib.utils.json_extended import ExtendedEncoder

_BM = TypeVar("BM", bound=BaseModel)

Expand Down
2 changes: 1 addition & 1 deletion bec_lib/bec_lib/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def load_high_level_interface(self, module_name: str) -> None:

def _update_username(self):
# pylint: disable=protected-access
self._username = self.connector._redis_conn.acl_whoami()
self._username = self.connector.username
self._system_user = getpass.getuser()

def _start_scan_queue(self):
Expand Down
2 changes: 1 addition & 1 deletion bec_lib/bec_lib/config_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from bec_lib.logger import bec_logger
from bec_lib.messages import ConfigAction
from bec_lib.utils.import_utils import lazy_import_from
from bec_lib.utils.json import ExtendedEncoder
from bec_lib.utils.json_extended import ExtendedEncoder

if TYPE_CHECKING: # pragma: no cover
from bec_lib.devicemanager import DeviceManagerBase
Expand Down
9 changes: 8 additions & 1 deletion bec_lib/bec_lib/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,15 @@ def _run_rpc_call(self, device, func_call, *args, wait_for_rpc_response=True, **
# pylint: disable=protected-access
if client.scans._scan_def_id:
msg.metadata["scan_def_id"] = client.scans._scan_def_id

msg.metadata["client_info"] = {
"acl_user": client.username,
"username": client._system_user,
"hostname": client._hostname,
}

# send RPC message
client.connector.send(MessageEndpoints.scan_queue_request(), msg)
client.connector.send(MessageEndpoints.scan_queue_request(client.username), msg)

# wait for RPC response
if not wait_for_rpc_response:
Expand Down
10 changes: 7 additions & 3 deletions bec_lib/bec_lib/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,15 +604,19 @@ def scan_queue_insert():
)

@staticmethod
def scan_queue_request():
def scan_queue_request(username: str):
"""
Endpoint for scan queue request. This endpoint is used to request the new scans.
The request is sent using a messages.ScanQueueMessage message.
The request is sent using a messages.ScanQueueMessage message and scoped to the user
making the request through redis ACLs.

Args:
username (str): Username of the user making the request.

Returns:
EndpointInfo: Endpoint for scan queue request.
"""
endpoint = f"{EndpointType.USER.value}/queue/queue_request"
endpoint = f"{EndpointType.PERSONAL.value}/{username}/queue/queue_request"
return EndpointInfo(
endpoint=endpoint, message_type=messages.ScanQueueMessage, message_op=MessageOp.SEND
)
Expand Down
10 changes: 10 additions & 0 deletions bec_lib/bec_lib/redis_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,16 @@ def authenticate(self, *, username: str = "default", password: str | None = "nul
self._redis_conn.connection_pool.connection_kwargs.update(old_kwargs)
raise exc

@property
def username(self) -> str:
"""
Get the username used for authentication

Returns:
str: username
"""
return str(self._redis_conn.acl_whoami())

def _close_pubsub(self):
if self._events_listener_thread:
self._stop_events_listener_thread.set()
Expand Down
2 changes: 0 additions & 2 deletions bec_lib/bec_lib/request_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,7 @@ def update_with_response(self, response_msg: messages.RequestResponseMessage) ->
logger.trace("Scan queue request exists. Updating with response.")
return

# it could be that the response arrived before the request
self.storage.append(RequestItem.from_response(self.scan_manager, response_msg))
logger.trace("Scan queue request does not exist. Creating from response.")

def update_with_request(self, request_msg: messages.ScanQueueMessage) -> None:
"""create or update request item based on a new ScanQueueMessage (i.e. request message)"""
Expand Down
7 changes: 0 additions & 7 deletions bec_lib/bec_lib/scan_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,6 @@ def __init__(self, connector: RedisConnector):
self.connector.register(
topics=MessageEndpoints.scan_queue_status(), cb=self._scan_queue_status_callback
)
self.connector.register(
topics=MessageEndpoints.scan_queue_request(), cb=self._scan_queue_request_callback
)
self.connector.register(
topics=MessageEndpoints.scan_queue_request_response(),
cb=self._scan_queue_request_response_callback,
Expand Down Expand Up @@ -244,10 +241,6 @@ def _scan_queue_status_callback(self, msg, **_kwargs) -> None:
return
self.update_with_queue_status(queue_status)

def _scan_queue_request_callback(self, msg, **_kwargs) -> None:
request = msg.value
self.request_storage.update_with_request(request)

def _scan_queue_request_response_callback(self, msg, **_kwargs) -> None:
response = msg.value
self.request_storage.update_with_response(response)
Expand Down
20 changes: 8 additions & 12 deletions bec_lib/bec_lib/scans.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ def _run(

request = Scans.prepare_scan_request(self.scan_name, self.scan_info, *args, **kwargs)
request_id = str(uuid.uuid4())
request.metadata["client_info"] = {
"acl_user": self.client.username,
"username": self.client._system_user,
"hostname": self.client._hostname,
}

# pylint: disable=unsupported-assignment-operation
request.metadata["RID"] = request_id
Expand All @@ -144,20 +149,11 @@ def _run(

return report

def _start_register(self, request: messages.ScanQueueMessage) -> None:
"""Start a register for the given request"""
self.client.device_manager.connector.register(
[
MessageEndpoints.device_readback(dev)
for dev in request.content["parameter"]["args"].keys()
],
threaded=False,
cb=(lambda msg: msg),
)

def _send_scan_request(self, request: messages.ScanQueueMessage) -> None:
"""Send a scan request to the scan server"""
self.client.device_manager.connector.send(MessageEndpoints.scan_queue_request(), request)
self.client.device_manager.connector.send(
MessageEndpoints.scan_queue_request(self.client.username), request
)


class Scans:
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion bec_lib/tests/test_device_hashing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
HashableDeviceSet,
HashInclusion,
)
from bec_lib.utils.json import ExtendedEncoder
from bec_lib.utils.json_extended import ExtendedEncoder

TEST_DEVICE_DICT = {
"name": "test_device",
Expand Down
2 changes: 1 addition & 1 deletion bec_lib/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pydantic import BaseModel

from bec_lib.utils.json import ExtendedEncoder
from bec_lib.utils.json_extended import ExtendedEncoder


def test_encoder_encodes_set():
Expand Down
94 changes: 71 additions & 23 deletions bec_server/bec_server/scan_server/scan_guard.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import re
import traceback
import uuid
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -43,7 +44,9 @@ def __init__(self, *, parent: ScanServer):
self.connector = self.parent.connector

self.connector.register(
MessageEndpoints.scan_queue_request(), cb=self._scan_queue_request_callback, parent=self
patterns=MessageEndpoints.scan_queue_request("*"),
cb=self._scan_queue_request_callback,
parent=self,
)

self.connector.register(
Expand All @@ -57,24 +60,61 @@ def __init__(self, *, parent: ScanServer):
parent=self,
)

def _is_valid_scan_request(self, request) -> ScanStatus:
def _is_valid_scan_request(
self, request: messages.ScanQueueMessage, username: str
) -> ScanStatus:
"""
Perform validity checks on the scan request.
Args:
request(messages.ScanQueueMessage): A scan queue message
username(str): The username associated with the request
Returns:
ScanStatus: The status of the scan request
"""
try:
self._check_valid_request(request)
self._check_valid_request(request, username)
self._check_valid_scan(request)
self._check_baton(request)
self._check_motors_movable(request)
self._check_soft_limits(request)
# pylint: disable=broad-except
except Exception:
content = traceback.format_exc()
return ScanStatus(False, str(content))
return ScanStatus()

def _check_valid_request(self, request) -> None:
def _check_valid_request(self, request: messages.ScanQueueMessage, username: str) -> None:
"""
Check if the scan request is valid.
It must be a proper scan queue message and the username in the topic
must match the client info username, thus ensuring that users cannot
submit scans on behalf of other users.

Args:
request(messages.ScanQueueMessage): A scan queue message
username(str): The username associated with the request
Raises:
ScanRejection: If the request is invalid
"""
if request is None:
raise ScanRejection("Invalid request.")
client_info = request.metadata.get("client_info")
if not client_info:
raise ScanRejection("Missing client info in request metadata.")

# Note: the default redis user does not have an acl username
acl_user = client_info.get("acl_user") or "default"
if acl_user != username:
raise ScanRejection("Username in topic does not match client info.")

def _check_valid_scan(self, request: messages.ScanQueueMessage) -> None:
"""
Check if the scan is valid and known.

def _check_valid_scan(self, request) -> None:
Args:
request(messages.ScanQueueMessage): A scan queue message
Raises:
ScanRejection: If the scan is invalid
"""
avail_scans = self.connector.get(MessageEndpoints.available_scans())
scan_type = request.content.get("scan_type")
if scan_type not in avail_scans.resource:
Expand All @@ -93,20 +133,28 @@ def _device_rpc_is_valid(self, device: str, func: str) -> bool:
return False
return True

def _check_baton(self, request) -> None:
def _check_baton(self, request: messages.ScanQueueMessage) -> None:
# TODO: Implement baton handling
pass

def _check_motors_movable(self, request) -> None:
if request.content["scan_type"] == "device_rpc":
device = request.content["parameter"]["device"]
def _check_motors_movable(self, request: messages.ScanQueueMessage) -> None:
"""
Check if the motors involved in the scan request are movable.
Args:
request(messages.ScanQueueMessage): A scan queue message
Raises:
ScanRejection: If any motor is not enabled or movable
"""
parameter = request.parameter
if request.scan_type == "device_rpc":
device = parameter.get("device")
if not isinstance(device, list):
device = [device]
for dev in device:
if not self.device_manager.devices[dev].enabled:
raise ScanRejection(f"Device {dev} is not enabled.")
return
motor_args = request.content["parameter"].get("args")
motor_args = parameter.get("args")
if not motor_args:
return
for motor in motor_args:
Expand All @@ -119,16 +167,18 @@ def _check_motors_movable(self, request) -> None:
if not self.device_manager.devices[motor].enabled:
raise ScanRejection(f"Device {motor} is not enabled.")

def _check_soft_limits(self, request) -> None:
# TODO: Implement soft limit checks
pass

@staticmethod
def _scan_queue_request_callback(msg, parent, **_kwargs):
content = msg.value.content
logger.info(f"Receiving scan request: {content}")
username_regex = f"^{MessageEndpoints.scan_queue_request('([^/]+)').endpoint}$"
result = re.match(username_regex, msg.topic)
if not result:
raise ScanRejection("Could not extract username from topic.")
username = result.group(1)

logger.info(f"Receiving scan request: {content} from user {username}")
# pylint: disable=protected-access
parent._handle_scan_request(msg.value)
parent._handle_scan_request(msg.value, username=username)

@staticmethod
def _scan_queue_modification_request_callback(msg, parent, **_kwargs):
Expand All @@ -154,17 +204,15 @@ def _send_scan_request_response(self, scan_status: ScanStatus, metadata: dict):
)
self.device_manager.connector.send(sqrr, rrm)

def _handle_scan_request(self, msg: messages.ScanQueueMessage):
def _handle_scan_request(self, msg: messages.ScanQueueMessage, username: str):
"""
Perform validity checks on the scan request and reply with a 'scan_request_response'.
If the scan is accepted it will be enqueued.
Args:
msg: ConsumerRecord value

Returns:

msg(messages.ScanQueueMessage): A scan queue message
username(str): The username associated with the request
"""
scan_status = self._is_valid_scan_request(msg)
scan_status = self._is_valid_scan_request(msg, username=username)

self._send_scan_request_response(scan_status, msg.metadata)
if not scan_status.accepted:
Expand Down
Loading
Loading