Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions bec_ipython_client/tests/end-2-end/test_scans_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ def test_mv_scan_nested_device(capsys, bec_ipython_client_fixture):
bec.metadata.update({"unit_test": "test_mv_scan_nested_device"})
dev = bec.device_manager.devices
scans.mv(dev.hexapod.x, 10, dev.hexapod.y, 20, relative=False).wait()
if not bec.connector._messages_queue.empty():
while not bec.connector._messages_queue.empty():
print("Waiting for messages to be processed")
time.sleep(0.5)
time.sleep(0.1)
current_pos_hexapod_x = dev.hexapod.x.read(cached=True)["hexapod_x"]["value"]
current_pos_hexapod_y = dev.hexapod.y.read(cached=True)["hexapod_y"]["value"]
assert np.isclose(
Expand All @@ -126,9 +126,9 @@ def test_mv_scan_nested_device(capsys, bec_ipython_client_fixture):
current_pos_hexapod_y, 20, atol=dev.hexapod._config["deviceConfig"].get("tolerance", 0.5)
)
scans.umv(dev.hexapod.x, 10, dev.hexapod.y, 20, relative=False)
if not bec.connector._messages_queue.empty():
while not bec.connector._messages_queue.empty():
print("Waiting for messages to be processed")
time.sleep(0.5)
time.sleep(0.1)
current_pos_hexapod_x = dev.hexapod.x.read(cached=True)["hexapod_x"]["value"]
current_pos_hexapod_y = dev.hexapod.y.read(cached=True)["hexapod_y"]["value"]
captured = capsys.readouterr()
Expand Down
115 changes: 115 additions & 0 deletions bec_lib/bec_lib/bec_serializable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# pylint: disable=too-many-lines
from __future__ import annotations

import base64
from io import BytesIO
from types import UnionType
from typing import Annotated, Any, Callable, ClassVar, Union, get_args, get_origin

import numpy as np
from pydantic import (
BaseModel,
ConfigDict,
PlainSerializer,
WithJsonSchema,
computed_field,
model_validator,
)

_NDARRAY_TAG = b"__NP_NDARRAY__"
_NDARRAY_TAG_STR = _NDARRAY_TAG.decode()
_NDARRAY_TAG_OFFSET = len(_NDARRAY_TAG)


def ndarray_to_bytes(arr: np.ndarray) -> bytes:
if not isinstance(arr, np.ndarray):
return arr
out_buf = BytesIO()
np.save(out_buf, arr)
return _NDARRAY_TAG + base64.urlsafe_b64encode(out_buf.getvalue())


def numpy_decode(input: str | bytes):
is_str = isinstance(input, str)
is_bytes = isinstance(input, bytes)
# let pydantic handle any other validation or coercion
if not (is_str or is_bytes):
return input
if is_str and not input.startswith(_NDARRAY_TAG_STR):
return input
if is_bytes and not input.startswith(_NDARRAY_TAG):
return input
# strip the tag, decode, and load
io = BytesIO(base64.urlsafe_b64decode(input[_NDARRAY_TAG_OFFSET:]))
return np.load(io)


NumpyField = Annotated[
np.ndarray,
PlainSerializer(ndarray_to_bytes),
WithJsonSchema({"type": "string", "contentEncoding": "base64"}),
]


def serialize_type(cls: type):
return cls.__name__


class BecCodecInfo(BaseModel):
type_name: str


class BECSerializable(BaseModel):

_deserialization_registry: ClassVar[
list[tuple[tuple[type | Annotated, ...], Callable[[Any], Any]]]
] = [((NumpyField, np.ndarray), numpy_decode)]

model_config = ConfigDict(
json_schema_serialization_defaults_required=True,
json_encoders={np.ndarray: ndarray_to_bytes},
arbitrary_types_allowed=True,
)

@computed_field()
@property
def bec_codec(self) -> BecCodecInfo:
return BecCodecInfo(type_name=self.__class__.__name__)

@classmethod
def _try_apply_registry(cls, anno: type, data: dict, field: str):
for entry, deserializer in cls._deserialization_registry:
if anno in entry:
data[field] = deserializer(data[field])

@model_validator(mode="before")
@classmethod
def deser_custom(cls, data: dict[str, Any]):
for field in data:
if (field_info := cls.model_fields.get(field)) is not None:
if field_info.annotation is None:
continue # No need to do anything for NoneType
if get_origin(field_info.annotation) in [UnionType, Union]:
for arg in get_args(field_info.annotation):
cls._try_apply_registry(arg, data, field)
else:
cls._try_apply_registry(field_info.annotation, data, field)
return data


class BecWrappedValue(BECSerializable):
data: np.ndarray # can be extended, must be in registry

def __getattr__(self, name: str) -> Any:
if hasattr(self.data, name):
return getattr(self.data, name)
else:
raise AttributeError(
f"{self.__class__.__name__} wrapping data type {type(self.data)} has no attribute {name} on either itself or its data."
)

def __getitem__(self, item):
if hasattr(self.data, "__getitem__"):
return self.data.__getitem__(item)
else:
raise AttributeError(f"Wrapped data type {type(self.data)} has no __getitem__.")
134 changes: 1 addition & 133 deletions bec_lib/bec_lib/codecs.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,9 @@
from __future__ import annotations

import enum
from abc import ABC, abstractmethod
from typing import Any, Type

import numpy as np
from pydantic import BaseModel

from bec_lib import messages as messages_module
from bec_lib import numpy_encoder
from bec_lib.device import DeviceBase
from bec_lib.endpoints import EndpointInfo
from bec_lib.messages import BECMessage, BECStatus


class BECCodec(ABC):
Expand All @@ -24,137 +16,13 @@ class BECCodec(ABC):
def encode(obj: Any) -> Any:
"""Encode an object into a serializable format."""

@staticmethod
@abstractmethod
def decode(type_name: str, data: dict):
"""Decode data into an object."""


class NumpyEncoder(BECCodec):
obj_type: list[Type] = [np.ndarray, np.bool_, np.number, complex]

@staticmethod
def encode(obj: np.ndarray) -> dict:
return numpy_encoder.numpy_encode(obj)

@staticmethod
def decode(type_name: str, data: dict) -> np.ndarray:
return numpy_encoder.numpy_decode(data)


class NumpyEncoderList(BECCodec):
obj_type: list[Type] = [np.ndarray, np.bool_, np.number, complex]

@staticmethod
def encode(obj: np.ndarray) -> dict:
return numpy_encoder.numpy_encode_list(obj)

@staticmethod
def decode(type_name: str, data: dict) -> np.ndarray:
return numpy_encoder.numpy_decode_list(data)


class BECMessageEncoder(BECCodec):
obj_type: Type = BECMessage

@staticmethod
def encode(obj: BECMessage) -> dict:
return obj.__dict__

@staticmethod
def decode(type_name: str, data: dict) -> BECMessage:
return getattr(messages_module, type_name)(**data)


class EnumEncoder(BECCodec):
obj_type: Type = enum.Enum

@staticmethod
def encode(obj: enum.Enum) -> Any:
return obj.value

@staticmethod
def decode(type_name: str, data: Any) -> Any:
if type_name == "BECStatus":
return BECStatus(data)
return data


class BECDeviceEncoder(BECCodec):
obj_type: Type = DeviceBase
obj_type = DeviceBase

@staticmethod
def encode(obj: DeviceBase) -> str:
if hasattr(obj, "_compile_function_path"):
# pylint: disable=protected-access
return obj._compile_function_path()
return obj.name

@staticmethod
def decode(type_name: str, data: str) -> str:
"""
DeviceBase objects are encoded as strings. No decoding is necessary.
"""
return data


class PydanticEncoder(BECCodec):
obj_type: Type = BaseModel

@staticmethod
def encode(obj: BaseModel) -> dict:
return obj.model_dump()

@staticmethod
def decode(type_name: str, data: dict) -> dict:
return data


class EndpointInfoEncoder(BECCodec):
obj_type: Type = EndpointInfo

@staticmethod
def encode(obj: EndpointInfo) -> dict:
return {
"endpoint": obj.endpoint,
"message_type": obj.message_type.__name__,
"message_op": obj.message_op,
}

@staticmethod
def decode(type_name: str, data: dict) -> EndpointInfo:
return EndpointInfo(
endpoint=data["endpoint"],
message_type=getattr(messages_module, data["message_type"]),
message_op=data["message_op"],
)


class SetEncoder(BECCodec):
obj_type: Type = set

@staticmethod
def encode(obj: set) -> list:
return list(obj)

@staticmethod
def decode(type_name: str, data: list) -> set:
return set(data)


class BECTypeEncoder(BECCodec):
obj_type: Type = type

@staticmethod
def encode(obj: type) -> dict:
return {"type_name": obj.__name__, "module": obj.__module__}

@staticmethod
def decode(type_name: str, data: dict) -> type:
if data["module"] == "builtins":
return __builtins__.get(data["type_name"])
if data["module"] == "bec_lib.messages":
return getattr(messages_module, data["type_name"])
if data["module"] == "numpy":
return getattr(np, data["type_name"])
raise ValueError(f"Unknown type {data['type_name']} in module {data['module']}")
10 changes: 7 additions & 3 deletions bec_lib/bec_lib/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from bec_lib.atlas_models import _DeviceModelCore
from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger
from bec_lib.messages import DeviceMessage
from bec_lib.queue_items import QueueItem
from bec_lib.utils.import_utils import lazy_import

Expand Down Expand Up @@ -1111,12 +1112,15 @@ def limits(self):
"""
Returns the device limits.
"""
limit_msg = self.root.parent.connector.get(MessageEndpoints.device_limits(self.root.name))
limit_msg: DeviceMessage = self.root.parent.connector.get(
MessageEndpoints.device_limits(self.root.name)
)
if not limit_msg:
return [0, 0]

limits = [
limit_msg.content["signals"].get("low", {}).get("value", 0),
limit_msg.content["signals"].get("high", {}).get("value", 0),
limit_msg.signals.get("low", {}).get("value", 0),
limit_msg.signals.get("high", {}).get("value", 0),
]
return limits

Expand Down
38 changes: 24 additions & 14 deletions bec_lib/bec_lib/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,18 @@
from __future__ import annotations

import enum
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from functools import lru_cache
from typing import Annotated, Any, Callable, ClassVar, Generic, TypeVar

from bec_lib.utils.import_utils import lazy_import
from pydantic import PlainSerializer

from bec_lib import messages
from bec_lib.bec_serializable import BECSerializable, serialize_type

# pylint: disable=too-many-public-methods
# pylint: disable=too-many-lines


if TYPE_CHECKING: # pragma: no cover
from bec_lib import messages
else:
# TODO: put back normal import when Pydantic gets faster
# from bec_lib import messages
messages = lazy_import("bec_lib.messages")


class EndpointType(str, enum.Enum):
"""Endpoint type enum"""

Expand Down Expand Up @@ -49,8 +44,21 @@ class MessageOp(list[str], enum.Enum):
MessageType = TypeVar("MessageType", bound="type[messages.BECMessage]")


@dataclass
class EndpointInfo(Generic[MessageType]):
@lru_cache()
def _resolve_msg_type(n: type[messages.BECMessage] | str):
if n is Any or n == "Any":
return Any
if isinstance(n, str):
return messages.__dict__.get(n, messages.RawMessage)
if issubclass(n, messages.BECMessage):
return n
raise TypeError(f"Invalid argument for _resolve_msg_type: {n}")


MesageAnno = Annotated[MessageType, PlainSerializer(serialize_type)]


class EndpointInfo(BECSerializable, Generic[MessageType]):
"""
Dataclass for endpoint info.

Expand All @@ -60,8 +68,10 @@ class EndpointInfo(Generic[MessageType]):
message_op (MessageOp): Message operation.
"""

_deserialization_registry = [((MesageAnno,), _resolve_msg_type)]

endpoint: str
message_type: MessageType
message_type: MesageAnno | Any
message_op: MessageOp


Expand Down
Loading
Loading