diff --git a/docs/tutorials/scripting-plans.md b/docs/tutorials/scripting-plans.md new file mode 100644 index 0000000000..77acbb7eee --- /dev/null +++ b/docs/tutorials/scripting-plans.md @@ -0,0 +1,168 @@ +# Scripting Plans + +While the CLI can be used to query devices and run plans, it can be useful to +combine multiple plans within a better interface than bash/shell scripting. + +For this, `blueapi` can be used as a library providing a `BlueapiClient` +wrapping interactions with the server. + +## Login to blueapi + +The following steps require the user to have logged in blueapi. This can be done +via the `blueapi login` command. + +## Create an instance of the client + +```python +from blueapi.client.client import BlueapiClient + +# A client can be created from either a config instance or the path to a config +# file. The minimal configuration required # is: +# api: +# url: https://address.of.blueapi:1234 +# stomp: +# enabled: true +# url: tcp://address.of.rabbitmq:61613 +bc = BlueapiClient.from_config_file("/path/to/config.yaml") +``` + +Plans and devices are available via the `plans` and `devices` attributes of the +client. It can be useful to alias these locally to reduce the boilerplate in +scripts. + +```python +plans = bc.plans +devices = bc.devices +``` + +## Query devices + +The devices available on the server are accessible via the `devices` attribute +of the client. + +```python +for device in bc.devices: + print(device) +``` + +Individual devices can be accessed as attributes on the `devices` field. It can +also be useful to alias these locally. + +```python +det = bc.devices.det +stage = bc.devices.stage +``` + +Child devices can be accessed via their parent devices + +```python +stage_x = stage.x +``` + +Trying to access a child device that does not exist will raise an +`AttributeError` + +## Run a plan + +Plans are accessible via the `plans` attribute of the client instance. They can, +for the most part, be treated as if they were local functions. + +```python +bc.plans.count([bc.devices.det], num=3, delay=4.2) +``` + +Running a plan in this way will block until the plan is complete. If the script +is interrupted (eg via Ctrl-C) while a plan is running it will be aborted before +the script exits. + +Where parameters to a plan are optional, they can be omitted from the method +call. Where parameters are required, they can be passed either as positional or +named arguments. + +## Run multiple plans + +Plans can then be co-ordinated using standard python constructs, eg to run a +plan multiple times + +```python +for temp in range(1, 5): + plans.set_absolute({devices.temp: temp}) + plans.count([devices.det]) +``` + +## Passing more complex arguments + +Anything passed to a plan function will be serialized into JSON before being +sent to the server. For many types you can pass the instance directly and the +serialization should handle the conversion for you. + +```python +from scanspec.specs import Line + +bc.plans.spec_scan(detectors=[det], spec=Line(bc.devices.stage.x, 0, 10, 11)) +``` + +if a type does not serialize correctly, passing the JSON equivalent should be +possible instead. For instance the above is equivalent to + +```python +bc.plans.spec_scan(detectors=[det], spec={ + "axis": "stage.x", + "start": 0.0, + "stop": 10.0, + "num": 11, + "type": "Line"}) +``` + +## Add callbacks + +By default there is no indication of progress while a scan is running however it +is possible to subscribe to events so that updates can be provided. + +A callback should accept a single parameter which will be the event from server. +This will be one of `WorkerEvent`, `ProgressEvent` or `DataEvent`. + +An example that prints data for each point could be something like + +```python +def feedback(evt): + match evt: + case DataEvent(name="start"): + print("Run started") + case DataEvent(name="stop", doc={"exit_status": status}): + print("Run complete: ", status) + case DataEvent(name="event", doc={"seq_num": point, "data": data}): + print(f" Point {point}: {data}") + +bc.add_callback(feedback) + +bc.plans.spec_scan([bc.devices.det], Line(bc.devices.stage.x, 0, 1, 11)) +``` + +The above prints the following as the scan progresses + +``` +Run started + Point 1: {'stage-x': 0.0} + Point 2: {'stage-x': 0.1} + Point 3: {'stage-x': 0.2} + Point 4: {'stage-x': 0.3} + Point 5: {'stage-x': 0.4} + Point 6: {'stage-x': 0.5} + Point 7: {'stage-x': 0.6} + Point 8: {'stage-x': 0.7000000000000001} + Point 9: {'stage-x': 0.8} + Point 10: {'stage-x': 0.9} + Point 11: {'stage-x': 1.0} +Run complete: success +``` + +The `add_callback` method returns an ID that can be used to remove the callback + +```python +# Add the callback and record the handle +hnd = bc.add_callback(callback_function) + +# remove the callback using the returned handle +bc.remove_callback(hnd) +``` diff --git a/pyproject.toml b/pyproject.toml index dda9611ea9..2eea434322 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "pyjwt[crypto]", "tomlkit", "graypy>=2.1.0", + "jinja2>=3.1.6", ] dynamic = ["version"] license.file = "LICENSE" diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 4d974d35e2..e642e24129 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -17,7 +17,6 @@ from click.exceptions import ClickException from observability_utils.tracing import setup_tracing from pydantic import ValidationError -from requests.exceptions import ConnectionError from blueapi import __version__, config from blueapi.cli.format import OutputFormat @@ -26,6 +25,7 @@ from blueapi.client.rest import ( BlueskyRemoteControlError, InvalidParametersError, + ServiceUnavailableError, UnauthorisedAccessError, UnknownPlanError, ) @@ -36,9 +36,10 @@ from blueapi.core import OTLP_EXPORT_ENABLED, DataEvent from blueapi.log import set_up_logging from blueapi.service.authentication import SessionCacheManager, SessionManager -from blueapi.service.model import SourceInfo, TaskRequest +from blueapi.service.model import DeviceResponse, PlanResponse, SourceInfo, TaskRequest from blueapi.worker import ProgressEvent, WorkerEvent +from . import stubgen from .scratch import setup_scratch from .updates import CliEventRenderer @@ -154,6 +155,23 @@ def start_application(obj: dict): start(config) +@main.command() +@click.pass_obj +@click.argument("target", type=click.Path(file_okay=False)) +def generate_stubs(obj: dict, target: Path): + """ + Generate a type-stubs project for blueapi for the currently running server. + This enables users using blueapi as a library to benefit from type checking + and linting when writing scripts against the BlueapiClient. + """ + click.echo(f"Writing stubs to {target}") + + config: ApplicationConfig = obj["config"] + bc = BlueapiClient.from_config(config) + + stubgen.generate_stubs(Path(target), list(bc.plans), list(bc.devices)) + + @main.group() @click.option( "-o", @@ -185,7 +203,7 @@ def check_connection(func: Callable[P, T]) -> Callable[P, T]: def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: try: return func(*args, **kwargs) - except ConnectionError as ce: + except ServiceUnavailableError as ce: raise ClickException( "Failed to establish connection to blueapi server." ) from ce @@ -206,7 +224,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: def get_plans(obj: dict) -> None: """Get a list of plans available for the worker to use""" client: BlueapiClient = obj["client"] - obj["fmt"].display(client.get_plans()) + obj["fmt"].display(PlanResponse(plans=[p.model for p in client.plans])) @controller.command(name="devices") @@ -215,7 +233,7 @@ def get_plans(obj: dict) -> None: def get_devices(obj: dict) -> None: """Get a list of devices available for the worker to use""" client: BlueapiClient = obj["client"] - obj["fmt"].display(client.get_devices()) + obj["fmt"].display(DeviceResponse(devices=[dev.model for dev in client.devices])) @controller.command(name="listen") @@ -347,7 +365,7 @@ def get_state(obj: dict) -> None: """Print the current state of the worker""" client: BlueapiClient = obj["client"] - print(client.get_state().name) + print(client.state.name) @controller.command(name="pause") @@ -430,7 +448,7 @@ def env( status = client.reload_environment(timeout=timeout) print("Environment is initialized") else: - status = client.get_environment() + status = client.environment print(status) @@ -472,14 +490,13 @@ def login(obj: dict) -> None: print("Logged in") except Exception: client = BlueapiClient.from_config(config) - oidc_config = client.get_oidc_config() - if oidc_config is None: + if oidc := client.oidc_config: + auth = SessionManager( + oidc, cache_manager=SessionCacheManager(config.auth_token_path) + ) + auth.start_device_flow() + else: print("Server is not configured to use authentication!") - return - auth = SessionManager( - oidc_config, cache_manager=SessionCacheManager(config.auth_token_path) - ) - auth.start_device_flow() @main.command(name="logout") diff --git a/src/blueapi/cli/format.py b/src/blueapi/cli/format.py index 98a8dafba4..29ba01b732 100644 --- a/src/blueapi/cli/format.py +++ b/src/blueapi/cli/format.py @@ -12,7 +12,9 @@ from blueapi.core.bluesky_types import DataEvent from blueapi.service.model import ( + DeviceModel, DeviceResponse, + PlanModel, PlanResponse, PythonEnvironmentResponse, SourceInfo, @@ -54,17 +56,21 @@ def display_full(obj: Any, stream: Stream): match obj: case PlanResponse(plans=plans): for plan in plans: - print(plan.name) - if desc := plan.description: - print(indent(dedent(desc).strip(), " ")) - if schema := plan.parameter_schema: - print(" Schema") - print(indent(json.dumps(schema, indent=2), " ")) + display_full(plan, stream) + case PlanModel(name=name, description=desc, parameter_schema=schema): + print(name) + if desc: + print(indent(dedent(desc).strip(), " ")) + if schema: + print(" Schema") + print(indent(json.dumps(schema, indent=2), " ")) case DeviceResponse(devices=devices): for dev in devices: - print(dev.name) - for proto in dev.protocols: - print(f" {proto}") + display_full(dev, stream) + case DeviceModel(name=name, protocols=protocols): + print(name) + for proto in protocols: + print(f" {proto}") case DataEvent(name=name, doc=doc): print(f"{name.title()}:{fmt_dict(doc)}") case WorkerEvent(state=st, task_status=task): @@ -98,11 +104,13 @@ def display_json(obj: Any, stream: Stream): print = partial(builtins.print, file=stream) match obj: case PlanResponse(plans=plans): - print(json.dumps([p.model_dump() for p in plans], indent=2)) + display_json(plans, stream) case DeviceResponse(devices=devices): - print(json.dumps([d.model_dump() for d in devices], indent=2)) + display_json(devices, stream) case BaseModel(): print(json.dumps(obj.model_dump())) + case list(): + print(json.dumps([it.model_dump() for it in obj], indent=2)) case _: print(json.dumps(obj)) @@ -112,26 +120,30 @@ def display_compact(obj: Any, stream: Stream): match obj: case PlanResponse(plans=plans): for plan in plans: - print(plan.name) - if desc := plan.description: - print(indent(dedent(desc.split("\n\n")[0].strip("\n")), " ")) - if schema := plan.parameter_schema: - print(" Args") - for arg, spec in schema.get("properties", {}).items(): - req = arg in schema.get("required", {}) - print(f" {arg}={_describe_type(spec, req)}") + display_compact(plan, stream) + case PlanModel(name=name, description=desc, parameter_schema=schema): + print(name) + if desc: + print(indent(dedent(desc.split("\n\n")[0].strip("\n")), " ")) + if schema: + print(" Args") + for arg, spec in schema.get("properties", {}).items(): + req = arg in schema.get("required", {}) + print(f" {arg}={_describe_type(spec, req)}") case DeviceResponse(devices=devices): for dev in devices: - print(dev.name) - print( - indent( - textwrap.fill( - ", ".join(str(proto) for proto in dev.protocols), - 80, - ), - " ", - ) + display_compact(dev, stream) + case DeviceModel(name=name, protocols=protocols): + print(name) + print( + indent( + textwrap.fill( + ", ".join(str(proto) for proto in protocols), + 80, + ), + " ", ) + ) case DataEvent(name=name): print(f"Data Event: {name}") case WorkerEvent(state=state): diff --git a/src/blueapi/cli/stubgen.py b/src/blueapi/cli/stubgen.py new file mode 100644 index 0000000000..6f6fbf4bd6 --- /dev/null +++ b/src/blueapi/cli/stubgen.py @@ -0,0 +1,117 @@ +import logging +from dataclasses import dataclass +from inspect import cleandoc +from pathlib import Path +from textwrap import dedent +from typing import Self, TextIO + +from jinja2 import Environment, PackageLoader + +from blueapi.client.cache import DeviceRef, Plan +from blueapi.core import context +from blueapi.core.bluesky_types import BLUESKY_PROTOCOLS + +log = logging.getLogger(__name__) + + +@dataclass +class ArgSpec: + name: str + type: str + optional: bool + + +@dataclass +class PlanSpec: + name: str + docs: str + args: list[ArgSpec] + + @classmethod + def from_plan(cls, plan: Plan) -> Self: + req = set(plan.required) + args = [ + ArgSpec(arg, _type_string(spec), arg not in req) + for arg, spec in plan.model.parameter_schema.get("properties", {}).items() + ] + return cls(plan.name, plan.help_text, args) + + +BLUESKY_PROTOCOL_NAMES = {context.qualified_name(proto) for proto in BLUESKY_PROTOCOLS} + + +def _type_string(spec) -> str: + """Best effort attempt at making useful type hints for plans""" + match spec.get("type"): + case "array": + return f"list[{_type_string(spec.get('items'))}]" + case "integer": + return "int" + case "number": + return "float" + case proto if proto in BLUESKY_PROTOCOL_NAMES: + return "DeviceRef" + case "object": + return "dict[str, Any]" + case "string": + return "str" + case "boolean": + return "bool" + case None if opts := spec.get("anyOf"): + return " | ".join(_type_string(opt) for opt in opts) + case _: + return "Any" + + +def generate_stubs(target: Path, plans: list[Plan], devices: list[DeviceRef]): + log.info("Generating stubs for %d plans and %d devices", len(plans), len(devices)) + target.mkdir(parents=True, exist_ok=True) + client_dir = target / "src" / "blueapi-stubs" / "client" + + log.debug("Making project structure: %s", client_dir) + client_dir.mkdir(parents=True, exist_ok=True) + + stub_file = client_dir / "cache.pyi" + project_file = target / "pyproject.toml" + py_typed = target / "src" / "blueapi-stubs" / "py.typed" + + log.debug("Writing pyproject.toml to %s", project_file) + with open(project_file, "w") as out: + out.write( + dedent(""" + [project] + name = "blueapi-stubs" + version = "0.1.0" + description = "Generated client stubs for a running server" + readme = "README.md" + requires-python = ">=3.11" + + dependencies = [ + "blueapi" + ] + """) + ) + + log.debug("Writing py.typed file to %s", py_typed) + with open(py_typed, "w") as out: + out.write("partial\n") + + log.debug("Writing stub file to %s", stub_file) + with open(stub_file, "w") as out: + render_stub_file(out, plans, devices) + + +def _docstring(text: str) -> str: + # """Convert a docstring to a format that can be inserted into the template""" + return cleandoc(text).replace('"""', '\\"""') + + +def render_stub_file( + stub_file: TextIO, plan_models: list[Plan], devices: list[DeviceRef] +): + plans = [PlanSpec.from_plan(p) for p in plan_models] + + env = Environment(loader=PackageLoader("blueapi", package_path="stubs/templates")) + env.filters["docstring"] = _docstring + tmpl = env.get_template("cache_template.pyi.jinja") + stub_file.write(tmpl.render(plans=plans, devices=devices)) diff --git a/src/blueapi/client/cache.py b/src/blueapi/client/cache.py new file mode 100644 index 0000000000..0ec8c4c87c --- /dev/null +++ b/src/blueapi/client/cache.py @@ -0,0 +1,177 @@ +import logging +from collections.abc import Callable +from itertools import chain +from typing import Any + +from blueapi.client.rest import BlueapiRestClient +from blueapi.service.model import DeviceModel, PlanModel +from blueapi.worker.event import WorkerEvent + +log = logging.getLogger(__name__) + + +# This file should be kept in sync with the type stub template in stubs/templates + + +PlanRunner = Callable[[str, dict[str, Any]], WorkerEvent] + + +class PlanCache: + """ + Cache of plans available on the server + """ + + def __init__(self, runner: PlanRunner, plans: list[PlanModel]): + self._cache = {model.name: Plan(model=model, runner=runner) for model in plans} + for name, plan in self._cache.items(): + if name.startswith("_"): + continue + setattr(self, name, plan) + + def __getitem__(self, name: str) -> "Plan": + return self._cache[name] + + def __getattr__(self, name: str) -> "Plan": + raise AttributeError(f"No plan named '{name}' available") + + def __iter__(self): + return iter(self._cache.values()) + + def __repr__(self) -> str: + return f"PlanCache({len(self._cache)} plans)" + + +class Plan: + """ + An interface to a plan on the blueapi server + + This allows remote plans to be called (mostly) as if they were local + methods when writing user scripts. + + If you are seeing this help while using blueapi as a library, generating + type stubs may be helpful for type checking and plan discovery, eg + + blueapi generate-stubs /tmp/blueapi-stubs + uv add --editable /tmp/blueapi-stubs + + """ + + model: PlanModel + + def __init__(self, model: PlanModel, runner: PlanRunner): + self.model = model + self._runner = runner + self.__doc__ = model.description + + def __call__(self, *args, **kwargs) -> WorkerEvent: + """ + Run the plan on the server mapping the given args into the required parameters + """ + return self._runner(self.name, self._build_args(*args, **kwargs)) + + @property + def name(self) -> str: + return self.model.name + + @property + def help_text(self) -> str: + return self.model.description or f"Plan {self!r}" + + @property + def properties(self) -> set[str]: + return self.model.parameter_schema.get("properties", {}).keys() + + @property + def required(self) -> list[str]: + return self.model.parameter_schema.get("required", []) + + def _build_args(self, *args, **kwargs): + log.info( + "Building args for %s, using %s and %s", + "[" + ",".join(self.properties) + "]", + args, + kwargs, + ) + + if len(args) > len(self.properties): + raise TypeError(f"{self.name} got too many arguments") + if extra := {k for k in kwargs if k not in self.properties}: + raise TypeError(f"{self.name} got unexpected arguments: {extra}") + + params = {} + # Initially fill parameters using positional args assuming the order + # from the parameter_schema + for req, arg in zip(self.properties, args, strict=False): + params[req] = arg + + # Then append any values given via kwargs + for key, value in kwargs.items(): + # If we've already assumed a positional arg was this value, bail out + if key in params: + raise TypeError(f"{self.name} got multiple values for {key}") + params[key] = value + + if missing := {k for k in self.required if k not in params}: + raise TypeError(f"Missing argument(s) for {missing}") + return params + + def __repr__(self): + opts = [p for p in self.properties if p not in self.required] + params = ", ".join(chain(self.required, (f"{opt}=None" for opt in opts))) + return f"{self.name}({params})" + + +class DeviceCache: + def __init__(self, rest: BlueapiRestClient): + self._rest = rest + self._cache = { + model.name: DeviceRef(name=model.name, cache=self, model=model) + for model in rest.get_devices().devices + } + for name, device in self._cache.items(): + if name.startswith("_"): + continue + setattr(self, name, device) + + def __getitem__(self, name: str) -> "DeviceRef": + if dev := self._cache.get(name): + return dev + try: + model = self._rest.get_device(name) + device = DeviceRef(name=name, cache=self, model=model) + self._cache[name] = device + setattr(self, model.name, device) + return device + except KeyError: + pass + raise AttributeError(f"No device named '{name}' available") + + def __getattr__(self, name: str) -> "DeviceRef": + if name.startswith("_"): + return super().__getattribute__(name) + return self[name] + + def __iter__(self): + return iter(self._cache.values()) + + def __repr__(self) -> str: + return f"DeviceCache({len(self._cache)} devices)" + + +class DeviceRef(str): + model: DeviceModel + _cache: DeviceCache + + def __new__(cls, name: str, cache: DeviceCache, model: DeviceModel): + instance = super().__new__(cls, name) + instance.model = model + instance._cache = cache + return instance + + def __getattr__(self, name) -> "DeviceRef": + if name.startswith("_"): + raise AttributeError(f"No child device named {name}") + return self._cache[f"{self}.{name}"] + + def __repr__(self): + return f"Device({self})" diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 0930e240a9..292ff06e06 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -1,5 +1,11 @@ +import itertools +import logging import time +from collections.abc import Iterable from concurrent.futures import Future +from functools import cached_property +from pathlib import Path +from typing import Any, Self from bluesky_stomp.messaging import MessageContext, StompClient from bluesky_stomp.models import Broker @@ -8,9 +14,13 @@ start_as_current_span, ) -from blueapi.config import ApplicationConfig, MissingStompConfigurationError +from blueapi.config import ( + ApplicationConfig, + ConfigLoader, + MissingStompConfigurationError, +) from blueapi.core.bluesky_types import DataEvent -from blueapi.service.authentication import SessionManager +from blueapi.service.authentication import SessionCacheManager, SessionManager from blueapi.service.model import ( DeviceModel, DeviceResponse, @@ -25,20 +35,33 @@ TasksListResponse, WorkerTask, ) -from blueapi.worker import TrackableTask, WorkerEvent, WorkerState +from blueapi.utils import deprecated +from blueapi.worker import WorkerEvent, WorkerState from blueapi.worker.event import ProgressEvent, TaskStatus +from blueapi.worker.task_worker import TrackableTask +from .cache import DeviceCache, PlanCache from .event_bus import AnyEvent, BlueskyStreamingError, EventBusClient, OnAnyEvent from .rest import BlueapiRestClient, BlueskyRemoteControlError TRACER = get_tracer("client") +log = logging.getLogger(__name__) + + +class MissingInstrumentSessionError(Exception): + pass + + class BlueapiClient: """Unified client for controlling blueapi""" _rest: BlueapiRestClient _events: EventBusClient | None + _instrument_session: str | None = None + _callbacks: dict[int, OnAnyEvent] + _callback_id: itertools.count def __init__( self, @@ -47,9 +70,17 @@ def __init__( ): self._rest = rest self._events = events + self._callbacks = {} + self._callback_id = itertools.count() + + @classmethod + def from_config_file(cls, config_file: str) -> Self: + conf = ConfigLoader(ApplicationConfig) + conf.use_values_from_yaml(Path(config_file)) + return cls.from_config(conf.load()) @classmethod - def from_config(cls, config: ApplicationConfig) -> "BlueapiClient": + def from_config(cls, config: ApplicationConfig) -> Self: session_manager: SessionManager | None = None try: session_manager = SessionManager.from_cache(config.auth_token_path) @@ -71,7 +102,33 @@ def from_config(cls, config: ApplicationConfig) -> "BlueapiClient": else: return cls(rest) + @cached_property @start_as_current_span(TRACER) + def plans(self) -> PlanCache: + return PlanCache(self.run_plan, self._rest.get_plans().plans) + + @cached_property + @start_as_current_span(TRACER) + def devices(self) -> DeviceCache: + return DeviceCache(self._rest) + + @property + def instrument_session(self) -> str: + if self._instrument_session is None: + raise MissingInstrumentSessionError() + return self._instrument_session + + @instrument_session.setter + def instrument_session(self, session: str): + log.debug("Setting instrument_session to %s", session) + self._instrument_session = session + + def with_instrument_session(self, session: str) -> Self: + self.instrument_session = session + return self + + @start_as_current_span(TRACER) + @deprecated("plans property") def get_plans(self) -> PlanResponse: """ List plans available @@ -82,6 +139,7 @@ def get_plans(self) -> PlanResponse: return self._rest.get_plans() @start_as_current_span(TRACER, "name") + @deprecated("plans[name]") def get_plan(self, name: str) -> PlanModel: """ Get details of a single plan @@ -95,6 +153,7 @@ def get_plan(self, name: str) -> PlanModel: return self._rest.get_plan(name) @start_as_current_span(TRACER) + @deprecated("devices property") def get_devices(self) -> DeviceResponse: """ List devices available @@ -105,7 +164,20 @@ def get_devices(self) -> DeviceResponse: return self._rest.get_devices() + def add_callback(self, callback: OnAnyEvent) -> int: + cb_id = next(self._callback_id) + self._callbacks[cb_id] = callback + return cb_id + + def remove_callback(self, id: int): + self._callbacks.pop(id) + + @property + def callbacks(self) -> Iterable[OnAnyEvent]: + return self._callbacks.values() + @start_as_current_span(TRACER, "name") + @deprecated("devices[name]") def get_device(self, name: str) -> DeviceModel: """ Get details of a single device @@ -119,8 +191,9 @@ def get_device(self, name: str) -> DeviceModel: return self._rest.get_device(name) + @property @start_as_current_span(TRACER) - def get_state(self) -> WorkerState: + def state(self) -> WorkerState: """ Get current state of the blueapi worker @@ -130,6 +203,18 @@ def get_state(self) -> WorkerState: return self._rest.get_state() + @start_as_current_span(TRACER) + @deprecated("state property") + def get_state(self) -> WorkerState: + """ + Get current state of the blueapi worker + + Returns: + WorkerState: Current state + """ + + return self.state + @start_as_current_span(TRACER, "defer") def pause(self, defer: bool = False) -> WorkerState: """ @@ -159,6 +244,7 @@ def resume(self) -> WorkerState: return self._rest.set_state(WorkerState.RUNNING, defer=False) @start_as_current_span(TRACER, "task_id") + @deprecated("rest client") def get_task(self, task_id: str) -> TrackableTask: """ Get a task stored by the worker @@ -173,6 +259,7 @@ def get_task(self, task_id: str) -> TrackableTask: return self._rest.get_task(task_id) @start_as_current_span(TRACER) + @deprecated("rest client") def get_all_tasks(self) -> TasksListResponse: """ Get a list of all task stored by the worker @@ -183,8 +270,9 @@ def get_all_tasks(self) -> TasksListResponse: return self._rest.get_all_tasks() + @property @start_as_current_span(TRACER) - def get_active_task(self) -> WorkerTask: + def active_task(self) -> WorkerTask: """ Get the currently active task, if any @@ -195,6 +283,28 @@ def get_active_task(self) -> WorkerTask: return self._rest.get_active_task() + @start_as_current_span(TRACER) + @deprecated("active_task property") + def get_active_task(self) -> WorkerTask: + """ + Get the currently active task, if any + + Returns: + WorkerTask: The currently active task, the task the worker + is executing right now. + """ + + return self.active_task + + @start_as_current_span(TRACER, "name", "params") + def run_plan(self, name: str, params: dict[str, Any]) -> WorkerEvent: + req = TaskRequest( + name=name, + params=params, + instrument_session=self.instrument_session, + ) + return self.run_task(req) + @start_as_current_span(TRACER, "task", "timeout") def run_task( self, @@ -221,7 +331,7 @@ def run_task( "Stomp configuration required to run plans is missing or disabled" ) - task_response = self.create_task(task) + task_response = self._rest.create_task(task) task_id = task_response.task_id complete: Future[WorkerEvent] = Future() @@ -239,6 +349,13 @@ def inner_on_event(event: AnyEvent, ctx: MessageContext) -> None: if relates_to_task: if on_event is not None: on_event(event) + for cb in self._callbacks.values(): + try: + cb(event) + except Exception as e: + log.error( + f"Callback ({cb}) failed for event: {event}", exc_info=e + ) if isinstance(event, WorkerEvent) and ( (event.is_complete()) and (ctx.correlation_id == task_id) ): @@ -255,7 +372,7 @@ def inner_on_event(event: AnyEvent, ctx: MessageContext) -> None: with self._events: self._events.subscribe_to_all_events(inner_on_event) - self.start_task(WorkerTask(task_id=task_id)) + self._rest.update_worker_task(WorkerTask(task_id=task_id)) return complete.result(timeout=timeout) @start_as_current_span(TRACER, "task") @@ -271,8 +388,10 @@ def create_and_start_task(self, task: TaskRequest) -> TaskResponse: TaskResponse: Acknowledgement of request """ - response = self.create_task(task) - worker_response = self.start_task(WorkerTask(task_id=response.task_id)) + response = self._rest.create_task(task) + worker_response = self._rest.update_worker_task( + WorkerTask(task_id=response.task_id) + ) if worker_response.task_id == response.task_id: return response else: @@ -282,6 +401,7 @@ def create_and_start_task(self, task: TaskRequest) -> TaskResponse: ) @start_as_current_span(TRACER, "task") + @deprecated("rest client") def create_task(self, task: TaskRequest) -> TaskResponse: """ Create a new task, does not start execution @@ -296,6 +416,7 @@ def create_task(self, task: TaskRequest) -> TaskResponse: return self._rest.create_task(task) @start_as_current_span(TRACER) + @deprecated("rest client") def clear_task(self, task_id: str) -> TaskResponse: """ Delete a stored task on the worker @@ -310,6 +431,7 @@ def clear_task(self, task_id: str) -> TaskResponse: return self._rest.clear_task(task_id) @start_as_current_span(TRACER, "task") + @deprecated("rest client") def start_task(self, task: WorkerTask) -> WorkerTask: """ Instruct the worker to start a stored task immediately @@ -358,7 +480,15 @@ def stop(self) -> WorkerState: return self._rest.cancel_current_task(WorkerState.STOPPING) + @property @start_as_current_span(TRACER) + def environment(self) -> EnvironmentResponse: + """Details of the worker environment""" + + return self._rest.get_environment() + + @start_as_current_span(TRACER) + @deprecated("environment property") def get_environment(self) -> EnvironmentResponse: """ Get details of the worker environment @@ -368,7 +498,7 @@ def get_environment(self) -> EnvironmentResponse: environment. """ - return self._rest.get_environment() + return self.environment @start_as_current_span(TRACER, "timeout", "polling_interval") def reload_environment( @@ -433,7 +563,15 @@ def _wait_for_reload( "seconds, a server restart is recommended" ) + @property + @start_as_current_span(TRACER) + def oidc_config(self) -> OIDCConfig | None: + """OIDC config from the server""" + + return self._rest.get_oidc_config() + @start_as_current_span(TRACER) + @deprecated("oidc_config property") def get_oidc_config(self) -> OIDCConfig | None: """ Get oidc config from the server @@ -442,7 +580,7 @@ def get_oidc_config(self) -> OIDCConfig | None: OIDCConfig: Details of the oidc Config """ - return self._rest.get_oidc_config() + return self.oidc_config @start_as_current_span(TRACER) def get_python_env( @@ -457,3 +595,18 @@ def get_python_env( """ return self._rest.get_python_environment(name=name, source=source) + + def login(self, token_path: Path | None = None): + try: + auth: SessionManager = SessionManager.from_cache(token_path) + access_token = auth.get_valid_access_token() + assert access_token + print("Logged in") + except Exception: + if oidc := self.oidc_config: + auth = SessionManager( + oidc, cache_manager=SessionCacheManager(token_path) + ) + auth.start_device_flow() + else: + print("Server is not configured to use authentication!") diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index 3ff119449e..52150d36fd 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -136,6 +136,8 @@ def _create_task_exceptions(response: requests.Response) -> Exception | None: class BlueapiRestClient: _config: RestConfig + _session_manager: SessionManager | None + _pool: requests.Session def __init__( self, @@ -144,6 +146,7 @@ def __init__( ) -> None: self._config = config or RestConfig() self._session_manager = session_manager + self._pool = requests.Session() def get_plans(self) -> PlanResponse: return self._request_and_deserialize("/plans", PlanResponse) @@ -252,14 +255,17 @@ def _request_and_deserialize( url = self._config.url.unicode_string().removesuffix("/") + suffix # Get the trace context to propagate to the REST API carr = get_context_propagator() - response = requests.request( - method, - url, - json=data, - params=params, - headers=carr, - auth=JWTAuth(self._session_manager), - ) + try: + response = self._pool.request( + method, + url, + json=data, + params=params, + headers=carr, + auth=JWTAuth(self._session_manager), + ) + except requests.exceptions.ConnectionError as ce: + raise ServiceUnavailableError() from ce exception = get_exception(response) if exception is not None: raise exception @@ -289,3 +295,7 @@ def __getattr__(name: str): ) return rename raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +class ServiceUnavailableError(Exception): + pass diff --git a/src/blueapi/stubs/templates/cache_template.pyi.jinja b/src/blueapi/stubs/templates/cache_template.pyi.jinja new file mode 100644 index 0000000000..b06ef36385 --- /dev/null +++ b/src/blueapi/stubs/templates/cache_template.pyi.jinja @@ -0,0 +1,72 @@ +from collections.abc import Callable +from typing import Any +from blueapi.client.rest import BlueapiRestClient +from blueapi.service.model import DeviceModel, PlanModel +from blueapi.worker.event import WorkerEvent + +{#- + This file is based on the cache.py file in blueapi/client/cache.py and should + be kept in sync with changes there. +#} + +# This file is auto-generated for a live server and should not be modified directly + +PlanRunner = Callable[[str, dict[str, Any]], WorkerEvent] + +class PlanCache: + def __init__(self, runner: PlanRunner, plans: list[PlanModel]) -> None: ... + def __getitem__(self, name: str) -> Plan: ... + def __iter__(self): # -> Iterator[Plan]: + ... + def __repr__(self) -> str: ... + +### Generated plans +{%- for item in plans %} + def {{ item.name }}(self,{% for arg in item.args %} + {{ arg.name }}: {{ arg.type }}{% if arg.optional %} | None = None{% endif %}, + {%- endfor %} + ) -> WorkerEvent: + """ + {{ item.docs | docstring | indent(8) }} + """ + ... +{%- endfor %} +### End + + +class Plan: + model: PlanModel + def __init__(self, model: PlanModel, runner: PlanRunner) -> None: ... + def __call__(self, *args, **kwargs): # -> None: + ... + + @property + def name(self) -> str: ... + @property + def help_text(self) -> str: ... + @property + def properties(self) -> set[str]: ... + @property + def required(self) -> list[str]: ... + def __repr__(self) -> str: ... + + +class DeviceRef(str): + model: DeviceModel + _cache: DeviceCache + def __new__(cls, name: str, cache: DeviceCache, model: DeviceModel): ... + def __getattr__(self, name) -> DeviceRef: ... + def __repr__(self) -> str: ... + +class DeviceCache: + def __init__(self, rest: BlueapiRestClient) -> None: ... + def __getitem__(self, name: str) -> DeviceRef: ... + def __iter__(self): # -> Iterator[DeviceRef]: + ... + def __repr__(self) -> str: ... + +### Generated devices + {%- for item in devices %} + {{ item }}: DeviceRef + {%- endfor %} +### End diff --git a/src/blueapi/utils/__init__.py b/src/blueapi/utils/__init__.py index 81c100bdc2..4b2e41f2c1 100644 --- a/src/blueapi/utils/__init__.py +++ b/src/blueapi/utils/__init__.py @@ -1,3 +1,7 @@ +from collections.abc import Callable +from functools import wraps +from typing import ParamSpec, TypeVar + from .base_model import BlueapiBaseModel, BlueapiModelConfig, BlueapiPlanModelConfig from .connect_devices import connect_devices, report_successful_devices from .file_permissions import get_owner_gid, is_sgid_set @@ -21,4 +25,31 @@ "is_sgid_set", "get_owner_gid", "is_function_sourced_from_module", + "deprecated", ] + +Args = ParamSpec("Args") +Return = TypeVar("Return") + + +def deprecated(alternative): + from warnings import warn + + def deprecated(func: Callable[Args, Return]) -> Callable[Args, Return]: + called = False + + @wraps(func) + def wrapped(*args, **kwargs): + nonlocal called + if not called: + warn( + f"Function {func.__name__} is deprecated - use {alternative}", + DeprecationWarning, + stacklevel=2, + ) + called = True + return func(*args, **kwargs) + + return wrapped + + return deprecated diff --git a/tests/system_tests/test_blueapi_system.py b/tests/system_tests/test_blueapi_system.py index e4ce504fab..5aecae5461 100644 --- a/tests/system_tests/test_blueapi_system.py +++ b/tests/system_tests/test_blueapi_system.py @@ -9,13 +9,14 @@ import requests from bluesky_stomp.models import BasicAuthentication from pydantic import TypeAdapter -from requests.exceptions import ConnectionError from blueapi.client import BlueapiClient from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError from blueapi.client.rest import ( + BlueapiRestClient, BlueskyRemoteControlError, BlueskyRequestError, + ServiceUnavailableError, ) from blueapi.config import ( ApplicationConfig, @@ -129,9 +130,9 @@ def client_with_stomp() -> Generator[BlueapiClient]: def wait_for_server(client: BlueapiClient): for _ in range(20): try: - client.get_environment() + _ = client.environment return - except ConnectionError: + except ServiceUnavailableError: ... time.sleep(0.5) raise TimeoutError("No connection to the blueapi server") @@ -148,6 +149,11 @@ def client() -> Generator[BlueapiClient]: yield BlueapiClient.from_config(config=ApplicationConfig()) +@pytest.fixture +def rest_client(client: BlueapiClient) -> BlueapiRestClient: + return client._rest + + @pytest.fixture def expected_plans() -> PlanResponse: return TypeAdapter(PlanResponse).validate_json( @@ -163,25 +169,22 @@ def expected_devices() -> DeviceResponse: @pytest.fixture -def blueapi_client_get_methods() -> list[str]: +def blueapi_rest_client_get_methods() -> list[str]: # Get a list of methods that take only one argument (self) - # This will currently return - # ['get_plans', 'get_devices', 'get_state', 'get_all_tasks', - # 'get_active_task','get_environment','resume', 'stop','get_oidc_config'] return [ - method - for method in BlueapiClient.__dict__ - if callable(getattr(BlueapiClient, method)) - and not method.startswith("__") - and len(inspect.signature(getattr(BlueapiClient, method)).parameters) == 1 - and "self" in inspect.signature(getattr(BlueapiClient, method)).parameters + name + for name, method in BlueapiRestClient.__dict__.items() + if not name.startswith("__") + and callable(method) + and len(params := inspect.signature(method).parameters) == 1 + and "self" in params ] @pytest.fixture(autouse=True) -def clean_existing_tasks(client: BlueapiClient): - for task in client.get_all_tasks().tasks: - client.clear_task(task.task_id) +def clean_existing_tasks(rest_client: BlueapiRestClient): + for task in rest_client.get_all_tasks().tasks: + rest_client.clear_task(task.task_id) yield @@ -213,26 +216,26 @@ def reset_numtracker(server_config: ApplicationConfig): def test_cannot_access_endpoints( - client_without_auth: BlueapiClient, blueapi_client_get_methods: list[str] + client_without_auth: BlueapiClient, blueapi_rest_client_get_methods: list[str] ): - blueapi_client_get_methods.remove( + blueapi_rest_client_get_methods.remove( "get_oidc_config" ) # get_oidc_config can be accessed without auth - for get_method in blueapi_client_get_methods: + for get_method in blueapi_rest_client_get_methods: with pytest.raises(BlueskyRemoteControlError, match=r""): - getattr(client_without_auth, get_method)() + getattr(client_without_auth._rest, get_method)() def test_can_get_oidc_config_without_auth(client_without_auth: BlueapiClient): - assert client_without_auth.get_oidc_config() == OIDCConfig( + assert client_without_auth.oidc_config == OIDCConfig( well_known_url=KEYCLOAK_BASE_URL + "realms/master/.well-known/openid-configuration", client_id="ixx-cli-blueapi", ) -def test_get_plans(client: BlueapiClient, expected_plans: PlanResponse): - retrieved_plans = client.get_plans() +def test_get_plans(rest_client: BlueapiRestClient, expected_plans: PlanResponse): + retrieved_plans = rest_client.get_plans() retrieved_plans.plans.sort(key=lambda x: x.name) expected_plans.plans.sort(key=lambda x: x.name) @@ -241,40 +244,52 @@ def test_get_plans(client: BlueapiClient, expected_plans: PlanResponse): def test_get_plans_by_name(client: BlueapiClient, expected_plans: PlanResponse): for plan in expected_plans.plans: - assert client.get_plan(plan.name) == plan + assert client.plans[plan.name].model == plan -def test_get_non_existent_plan(client: BlueapiClient): +def test_get_non_existent_plan(rest_client: BlueapiRestClient): with pytest.raises(KeyError, match="{'detail': 'Item not found'}"): - client.get_plan("Not exists") + rest_client.get_plan("Not exists") + +def test_client_non_existent_plan(client: BlueapiClient): + with pytest.raises(AttributeError, match="No plan named 'missing' available"): + _ = client.plans.missing -def test_get_devices(client: BlueapiClient, expected_devices: DeviceResponse): - retrieved_devices = client.get_devices() + +def test_get_devices(rest_client: BlueapiRestClient, expected_devices: DeviceResponse): + retrieved_devices = rest_client.get_devices() retrieved_devices.devices.sort(key=lambda x: x.name) expected_devices.devices.sort(key=lambda x: x.name) assert retrieved_devices == expected_devices -def test_get_device_by_name(client: BlueapiClient, expected_devices: DeviceResponse): +def test_get_device_by_name( + rest_client: BlueapiRestClient, expected_devices: DeviceResponse +): for device in expected_devices.devices: - assert client.get_device(device.name) == device + assert rest_client.get_device(device.name) == device -def test_get_non_existent_device(client: BlueapiClient): +def test_get_non_existent_device(rest_client: BlueapiRestClient): with pytest.raises(KeyError, match="{'detail': 'Item not found'}"): - client.get_device("Not exists") + rest_client.get_device("Not exists") + + +def test_client_non_existent_device(client: BlueapiClient): + with pytest.raises(AttributeError, match="No device named 'missing' available"): + _ = client.devices.missing -def test_create_task_and_delete_task_by_id(client: BlueapiClient): - create_task = client.create_task(_SIMPLE_TASK) - client.clear_task(create_task.task_id) +def test_create_task_and_delete_task_by_id(rest_client: BlueapiRestClient): + create_task = rest_client.create_task(_SIMPLE_TASK) + rest_client.clear_task(create_task.task_id) -def test_instrument_session_propagated(client: BlueapiClient): - response = client.create_task(_SIMPLE_TASK) - trackable_task = client.get_task(response.task_id) +def test_instrument_session_propagated(rest_client: BlueapiRestClient): + response = rest_client.create_task(_SIMPLE_TASK) + trackable_task = rest_client.get_task(response.task_id) assert trackable_task.task.metadata == { "instrument_session": AUTHORIZED_INSTRUMENT_SESSION, "tiled_access_tags": [ @@ -283,9 +298,9 @@ def test_instrument_session_propagated(client: BlueapiClient): } -def test_create_task_validation_error(client: BlueapiClient): +def test_create_task_validation_error(rest_client: BlueapiRestClient): with pytest.raises(BlueskyRequestError, match="Internal Server Error"): - client.create_task( + rest_client.create_task( TaskRequest( name="Not-exists", params={"Not-exists": 0.0}, @@ -294,26 +309,26 @@ def test_create_task_validation_error(client: BlueapiClient): ) -def test_get_all_tasks(client: BlueapiClient): +def test_get_all_tasks(rest_client: BlueapiRestClient): created_tasks: list[TaskResponse] = [] for task in [_SIMPLE_TASK, _LONG_TASK]: - created_task = client.create_task(task) + created_task = rest_client.create_task(task) created_tasks.append(created_task) task_ids = [task.task_id for task in created_tasks] - task_list = client.get_all_tasks() + task_list = rest_client.get_all_tasks() for trackable_task in task_list.tasks: assert trackable_task.task_id in task_ids assert trackable_task.is_complete is False and trackable_task.is_pending is True for task_id in task_ids: - client.clear_task(task_id) + rest_client.clear_task(task_id) -def test_get_task_by_id(client: BlueapiClient): - created_task = client.create_task(_SIMPLE_TASK) +def test_get_task_by_id(rest_client: BlueapiRestClient): + created_task = rest_client.create_task(_SIMPLE_TASK) - get_task = client.get_task(created_task.task_id) + get_task = rest_client.get_task(created_task.task_id) assert ( get_task.task_id == created_task.task_id and get_task.is_pending @@ -321,45 +336,45 @@ def test_get_task_by_id(client: BlueapiClient): and len(get_task.errors) == 0 ) - client.clear_task(created_task.task_id) + rest_client.clear_task(created_task.task_id) -def test_get_non_existent_task(client: BlueapiClient): +def test_get_non_existent_task(rest_client: BlueapiRestClient): with pytest.raises(KeyError, match="{'detail': 'Item not found'}"): - client.get_task("Not-exists") + rest_client.get_task("Not-exists") -def test_delete_non_existent_task(client: BlueapiClient): +def test_delete_non_existent_task(rest_client: BlueapiRestClient): with pytest.raises(KeyError, match="{'detail': 'Item not found'}"): - client.clear_task("Not-exists") + rest_client.clear_task("Not-exists") -def test_put_worker_task(client: BlueapiClient): - created_task = client.create_task(_SIMPLE_TASK) - client.start_task(WorkerTask(task_id=created_task.task_id)) - active_task = client.get_active_task() +def test_put_worker_task(rest_client: BlueapiRestClient): + created_task = rest_client.create_task(_SIMPLE_TASK) + rest_client.update_worker_task(WorkerTask(task_id=created_task.task_id)) + active_task = rest_client.get_active_task() assert active_task.task_id == created_task.task_id - client.clear_task(created_task.task_id) + rest_client.clear_task(created_task.task_id) -def test_put_worker_task_fails_if_not_idle(client: BlueapiClient): - small_task = client.create_task(_SIMPLE_TASK) - long_task = client.create_task(_LONG_TASK) +def test_put_worker_task_fails_if_not_idle(rest_client: BlueapiRestClient): + small_task = rest_client.create_task(_SIMPLE_TASK) + long_task = rest_client.create_task(_LONG_TASK) - client.start_task(WorkerTask(task_id=long_task.task_id)) - active_task = client.get_active_task() + rest_client.update_worker_task(WorkerTask(task_id=long_task.task_id)) + active_task = rest_client.get_active_task() assert active_task.task_id == long_task.task_id with pytest.raises(BlueskyRemoteControlError) as exception: - client.start_task(WorkerTask(task_id=small_task.task_id)) + rest_client.update_worker_task(WorkerTask(task_id=small_task.task_id)) assert "" in str(exception) - client.abort() - client.clear_task(small_task.task_id) - client.clear_task(long_task.task_id) + rest_client.cancel_current_task(WorkerState.ABORTING) + rest_client.clear_task(small_task.task_id) + rest_client.clear_task(long_task.task_id) def test_get_worker_state(client: BlueapiClient): - assert client.get_state() == WorkerState.IDLE + assert client.state == WorkerState.IDLE def test_set_state_transition_error(client: BlueapiClient): @@ -371,10 +386,10 @@ def test_set_state_transition_error(client: BlueapiClient): assert "" in str(exception) -def test_get_task_by_status(client: BlueapiClient): - task_1 = client.create_task(_SIMPLE_TASK) - task_2 = client.create_task(_SIMPLE_TASK) - task_by_pending = client.get_all_tasks() +def test_get_task_by_status(rest_client: BlueapiRestClient): + task_1 = rest_client.create_task(_SIMPLE_TASK) + task_2 = rest_client.create_task(_SIMPLE_TASK) + task_by_pending = rest_client.get_all_tasks() # https://github.com/DiamondLightSource/blueapi/issues/680 # task_by_pending = client.get_tasks_by_status(TaskStatusEnum.PENDING) assert len(task_by_pending.tasks) == 2 @@ -383,13 +398,13 @@ def test_get_task_by_status(client: BlueapiClient): trackable_task = TypeAdapter(TrackableTask).validate_python(task) assert trackable_task.is_complete is False and trackable_task.is_pending is True - client.start_task(WorkerTask(task_id=task_1.task_id)) - while not client.get_task(task_1.task_id).is_complete: + rest_client.update_worker_task(WorkerTask(task_id=task_1.task_id)) + while not rest_client.get_task(task_1.task_id).is_complete: time.sleep(0.1) - client.start_task(WorkerTask(task_id=task_2.task_id)) - while not client.get_task(task_2.task_id).is_complete: + rest_client.update_worker_task(WorkerTask(task_id=task_2.task_id)) + while not rest_client.get_task(task_2.task_id).is_complete: time.sleep(0.1) - task_by_completed = client.get_all_tasks() + task_by_completed = rest_client.get_all_tasks() # https://github.com/DiamondLightSource/blueapi/issues/680 # task_by_pending = client.get_tasks_by_status(TaskStatusEnum.COMPLETE) assert len(task_by_completed.tasks) == 2 @@ -398,8 +413,8 @@ def test_get_task_by_status(client: BlueapiClient): trackable_task = TypeAdapter(TrackableTask).validate_python(task) assert trackable_task.is_complete is True and trackable_task.is_pending is False - client.clear_task(task_id=task_1.task_id) - client.clear_task(task_id=task_2.task_id) + rest_client.clear_task(task_id=task_1.task_id) + rest_client.clear_task(task_id=task_2.task_id) def test_progress_with_stomp(client_with_stomp: BlueapiClient): @@ -440,13 +455,13 @@ def on_event(event: AnyEvent): def test_get_current_state_of_environment(client: BlueapiClient): - assert client.get_environment().initialized + assert client.environment.initialized def test_delete_current_environment(client: BlueapiClient): - old_env = client.get_environment() + old_env = client.environment client.reload_environment() - new_env = client.get_environment() + new_env = client.environment assert new_env.initialized assert new_env.environment_id != old_env.environment_id assert new_env.error_message is None diff --git a/tests/unit_tests/cli/test_stubgen.py b/tests/unit_tests/cli/test_stubgen.py new file mode 100644 index 0000000000..766f3e2f06 --- /dev/null +++ b/tests/unit_tests/cli/test_stubgen.py @@ -0,0 +1,214 @@ +from io import StringIO +from textwrap import dedent +from types import FunctionType +from unittest.mock import Mock + +import pytest + +from blueapi.cli.stubgen import ( + _docstring, + _type_string, + generate_stubs, + render_stub_file, +) +from blueapi.client.cache import DeviceRef, Plan +from blueapi.service.model import DeviceModel, PlanModel + + +def single_line(): + """Single line docstring""" + + +def single_line_new_line(): + """ + Single line docstring + """ + + +def multi_line_inline(): + """First line + Second line""" + + +def multi_line_new_line(): + """ + First line + Second line + """ + + +def indented_multi_line(): + """ + First line + indented + """ + + +@pytest.mark.parametrize( + "input,expected", + [ + (single_line, "Single line docstring"), + (single_line_new_line, "Single line docstring"), + (multi_line_inline, "First line\nSecond line"), + (multi_line_new_line, "First line\nSecond line"), + (indented_multi_line, "First line\n indented"), + ], +) +def test_docstring_filter(input: FunctionType, expected: str): + assert input.__doc__ + assert _docstring(input.__doc__) == expected + + +@pytest.mark.parametrize( + "typ,expected", + [ + ({"type": "string"}, "str"), + ({"type": "number"}, "float"), + ({"type": "integer"}, "int"), + ({"type": "object"}, "dict[str, Any]"), + ({"type": "boolean"}, "bool"), + ({"type": "array", "items": {"type": "integer"}}, "list[int]"), + ({"type": "array", "items": {"type": "object"}}, "list[dict[str, Any]]"), + ( + { + "type": "array", + "items": {"anyOf": [{"type": "integer"}, {"type": "boolean"}]}, + }, + "list[int | bool]", + ), + ({"anyOf": [{"type": "object"}, {"type": "string"}]}, "dict[str, Any] | str"), + ({"type": "unknown.other.Type"}, "Any"), + # Special case the bluesky protocols to require device references + ({"type": "bluesky.protocols.Readable"}, "DeviceRef"), + ({}, "Any"), + ], + ids=lambda param: param.get("type") if isinstance(param, dict) else param, +) +def test_type_string(typ: dict, expected: str): + assert _type_string(typ) == expected + + +def test_render_empty(): + output = StringIO() + + render_stub_file(output, [], []) + plan_text, device_text = _extract_rendered(output) + + assert plan_text == "" + assert device_text == "" + + +FOO = PlanModel(name="empty", description="Doc string for empty", schema={}) + +BAR = PlanModel( + name="two_args", + description="Doc string for two_args", + schema={ + "properties": { + "one": {"type": "integer"}, + "two": {"type": "string"}, + }, + "required": ["one"], + }, +) + + +def test_render_empty_plan_function(): + output = StringIO() + plans = [Plan(model=FOO, runner=Mock())] + render_stub_file(output, plans, []) + plan_text, device_text = _extract_rendered(output) + + assert device_text == "" + + assert ( + plan_text + == """\ + def empty(self, + ) -> WorkerEvent: + \""" + Doc string for empty + \""" + ...\n""" + ) + + +def test_render_multiple_plan_functions(): + output = StringIO() + runner = Mock() + plans = [Plan(FOO, runner), Plan(BAR, runner)] + render_stub_file(output, plans, []) + plan_text, device_text = _extract_rendered(output) + assert device_text == "" + + assert ( + plan_text + == """\ + def empty(self, + ) -> WorkerEvent: + \""" + Doc string for empty + \""" + ... + def two_args(self, + one: int, + two: str | None = None, + ) -> WorkerEvent: + \""" + Doc string for two_args + \""" + ...\n""" + ) + + +def test_device_fields(): + output = StringIO() + cache = Mock() + devices = [ + DeviceRef("one", cache, DeviceModel(name="one", protocols=[])), + DeviceRef("two", cache, DeviceModel(name="two", protocols=[])), + ] + render_stub_file(output, [], devices) + + plan_text, device_text = _extract_rendered(output) + assert plan_text == "" + assert device_text == " one: DeviceRef\n two: DeviceRef\n" + + +def test_package_creation(tmp_path): + generate_stubs(tmp_path / "blueapi-stubs", [], []) + with open(tmp_path / "blueapi-stubs" / "pyproject.toml") as pyproj: + assert pyproj.read().startswith( + dedent(""" + [project] + name = "blueapi-stubs" + version = "0.1.0" + """) + ) + with open( + tmp_path / "blueapi-stubs" / "src" / "blueapi-stubs" / "py.typed" + ) as typed: + assert typed.read() == "partial\n" + + assert ( + tmp_path / "blueapi-stubs" / "src" / "blueapi-stubs" / "client" / "cache.pyi" + ).exists() + + +def _extract_rendered(src: StringIO) -> tuple[str, str]: + src.seek(0) + _read_until_line(src, "### Generated plans") + plan_text = _read_until_line(src, "### End") + _read_until_line(src, "### Generated devices") + device_text = _read_until_line(src, "### End") + return plan_text, device_text + + +def _read_until_line(src: StringIO, match: str) -> str: + text = "" + for line in src: + if line.startswith(match): + break + text += line + + return text diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index d6d2e1f22b..5649270d2a 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -8,8 +8,11 @@ JsonObjectSpanExporter, asserting_span_exporter, ) +from pydantic import HttpUrl from blueapi.client import BlueapiClient +from blueapi.client.cache import DeviceCache, DeviceRef, Plan, PlanCache +from blueapi.client.client import MissingInstrumentSessionError from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient from blueapi.client.rest import BlueapiRestClient, BlueskyRemoteControlError from blueapi.config import MissingStompConfigurationError @@ -20,6 +23,7 @@ EnvironmentResponse, PlanModel, PlanResponse, + ProtocolInfo, TaskRequest, TaskResponse, TasksListResponse, @@ -35,6 +39,19 @@ ] ) PLAN = PlanModel(name="foo") +FULL_PLAN = PlanModel( + name="foobar", + description="Description of plan foobar", + schema={ + "title": "foobar", + "description": "Model description of plan foobar", + "properties": { + "one": {}, + "two": {}, + }, + "required": ["one"], + }, +) DEVICES = DeviceResponse( devices=[ DeviceModel(name="foo", protocols=[]), @@ -72,9 +89,9 @@ def mock_rest() -> BlueapiRestClient: mock = Mock(spec=BlueapiRestClient) mock.get_plans.return_value = PLANS - mock.get_plan.return_value = PLAN + mock.get_plan.side_effect = lambda n: {p.name: p for p in PLANS.plans}[n] mock.get_devices.return_value = DEVICES - mock.get_device.return_value = DEVICE + mock.get_device.side_effect = lambda n: {d.name: d for d in DEVICES.devices}[n] mock.get_state.return_value = WorkerState.IDLE mock.get_task.return_value = TASK mock.get_all_tasks.return_value = TASKS @@ -105,116 +122,72 @@ def client_with_events(mock_rest: Mock, mock_events: MagicMock): return BlueapiClient(rest=mock_rest, events=mock_events) +def test_client_from_config(): + bc = BlueapiClient.from_config_file( + "tests/unit_tests/valid_example_config/client.yaml" + ) + assert bc._rest._config.url == HttpUrl("http://example.com:8082") + + def test_get_plans(client: BlueapiClient): - assert client.get_plans() == PLANS + assert PlanResponse(plans=[p.model for p in client.plans]) == PLANS def test_get_plan(client: BlueapiClient): - assert client.get_plan("foo") == PLAN + assert client.plans.foo.model == PLAN + assert client.plans["foo"].model == PLAN def test_get_nonexistant_plan( client: BlueapiClient, - mock_rest: Mock, ): - mock_rest.get_plan.side_effect = KeyError("Not found") - with pytest.raises(KeyError): - client.get_plan("baz") + with pytest.raises(AttributeError): + _ = client.plans.fizz_buzz.model def test_get_devices(client: BlueapiClient): - assert client.get_devices() == DEVICES + assert DeviceResponse(devices=[d.model for d in client.devices]) == DEVICES def test_get_device(client: BlueapiClient): - assert client.get_device("foo") == DEVICE - - -def test_get_nonexistant_device( - client: BlueapiClient, - mock_rest: Mock, -): - mock_rest.get_device.side_effect = KeyError("Not found") - with pytest.raises(KeyError): - client.get_device("baz") - - -def test_get_state(client: BlueapiClient): - assert client.get_state() == WorkerState.IDLE - + assert client.devices.foo.model == DEVICE -def test_get_task(client: BlueapiClient): - assert client.get_task("foo") == TASK - -def test_get_nonexistent_task( +def test_get_nonexistent_device( client: BlueapiClient, - mock_rest: Mock, ): - mock_rest.get_task.side_effect = KeyError("Not found") - with pytest.raises(KeyError): - client.get_task("baz") + with pytest.raises(AttributeError): + _ = client.devices.baz -def test_get_task_with_empty_id(client: BlueapiClient): - with pytest.raises(AssertionError) as exc: - client.get_task("") - assert str(exc) == "Task ID not provided!" +def test_get_child_device(mock_rest: Mock, client: BlueapiClient): + mock_rest.get_device.side_effect = lambda name: ( + DeviceModel(name="foo.x", protocols=[ProtocolInfo(name="One")]) + if name == "foo.x" + else None + ) + foo = client.devices.foo + assert foo == "foo" + x = client.devices.foo.x + assert x == "foo.x" -def test_get_all_tasks( - client: BlueapiClient, -): - assert client.get_all_tasks() == TASKS +def test_state_property(client: BlueapiClient): + assert client.state == WorkerState.IDLE -def test_create_task( - client: BlueapiClient, - mock_rest: Mock, -): - client.create_task(task=TaskRequest(name="foo", instrument_session="cm12345-1")) - mock_rest.create_task.assert_called_once_with( - TaskRequest(name="foo", instrument_session="cm12345-1") - ) +def test_get_state(client: BlueapiClient): + assert client.get_state() == WorkerState.IDLE -def test_create_task_does_not_start_task( - client: BlueapiClient, - mock_rest: Mock, -): - client.create_task(task=TaskRequest(name="foo", instrument_session="cm12345-1")) - mock_rest.update_worker_task.assert_not_called() - - -def test_clear_task( - client: BlueapiClient, - mock_rest: Mock, -): - client.clear_task(task_id="foo") - mock_rest.clear_task.assert_called_once_with("foo") +def test_active_task_property(client: BlueapiClient): + assert client.active_task == ACTIVE_TASK def test_get_active_task(client: BlueapiClient): assert client.get_active_task() == ACTIVE_TASK -def test_start_task( - client: BlueapiClient, - mock_rest: Mock, -): - client.start_task(task=WorkerTask(task_id="bar")) - mock_rest.update_worker_task.assert_called_once_with(WorkerTask(task_id="bar")) - - -def test_start_nonexistant_task( - client: BlueapiClient, - mock_rest: Mock, -): - mock_rest.update_worker_task.side_effect = KeyError("Not found") - with pytest.raises(KeyError): - client.start_task(task=WorkerTask(task_id="bar")) - - def test_create_and_start_task_calls_both_creating_and_starting_endpoints( client: BlueapiClient, mock_rest: Mock, @@ -265,6 +238,10 @@ def test_create_and_start_task_fails_if_task_start_fails( ) +def test_environment_property(client: BlueapiClient): + assert client.environment == ENV + + def test_get_environment(client: BlueapiClient): assert client.get_environment() == ENV @@ -439,6 +416,15 @@ def test_run_task_fails_on_failing_event( on_event.assert_called_with(FAILED_EVENT) +@patch("blueapi.client.client.BlueapiClient.run_task") +def test_run_plan(run_task, client, mock_rest): + client.instrument_session = "cm12345-2" + client.run_plan("foo", {"foo": "bar"}) + run_task.assert_called_once_with( + TaskRequest(name="foo", params={"foo": "bar"}, instrument_session="cm12345-2") + ) + + @pytest.mark.parametrize( "test_event", [ @@ -521,76 +507,44 @@ def callback(on_event: Callable[[AnyEvent, MessageContext], None]): mock_on_event.assert_called_once_with(COMPLETE_EVENT) -def test_get_plans_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_plans"): - client.get_plans() - +def test_oidc_config_property(client, mock_rest): + assert client.oidc_config == mock_rest.get_oidc_config() -def test_get_plan_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_plan", "name"): - client.get_plan("foo") +def test_get_oidc_config(client, mock_rest): + assert client.get_oidc_config() == mock_rest.get_oidc_config() -def test_get_devices_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_devices"): - client.get_devices() - - -def test_get_device_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_device", "name"): - client.get_device("foo") - -def test_get_state_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_state"): - client.get_state() +def test_get_plans_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): + with asserting_span_exporter(exporter, "plans"): + _ = client.plans -def test_get_task_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): - with asserting_span_exporter(exporter, "get_task", "task_id"): - client.get_task("foo") +def test_get_plan_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): + with asserting_span_exporter(exporter, "plans"): + _ = client.plans.foo -def test_get_all_tasks_span_ok( - exporter: JsonObjectSpanExporter, - client: BlueapiClient, -): - with asserting_span_exporter(exporter, "get_all_tasks"): - client.get_all_tasks() +def test_get_devices_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): + with asserting_span_exporter(exporter, "devices"): + _ = client.devices -def test_create_task_span_ok( - exporter: JsonObjectSpanExporter, - client: BlueapiClient, - mock_rest: Mock, -): - with asserting_span_exporter(exporter, "create_task", "task"): - client.create_task(task=TaskRequest(name="foo", instrument_session="cm12345-1")) +def test_get_device_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): + with asserting_span_exporter(exporter, "devices"): + _ = client.devices.foo -def test_clear_task_span_ok( - exporter: JsonObjectSpanExporter, - client: BlueapiClient, - mock_rest: Mock, -): - with asserting_span_exporter(exporter, "clear_task"): - client.clear_task(task_id="foo") +def test_get_state_span_ok(exporter: JsonObjectSpanExporter, client: BlueapiClient): + with asserting_span_exporter(exporter, "state"): + _ = client.state def test_get_active_task_span_ok( exporter: JsonObjectSpanExporter, client: BlueapiClient ): - with asserting_span_exporter(exporter, "get_active_task"): - client.get_active_task() - - -def test_start_task_span_ok( - exporter: JsonObjectSpanExporter, - client: BlueapiClient, - mock_rest: Mock, -): - with asserting_span_exporter(exporter, "start_task", "task"): - client.start_task(task=WorkerTask(task_id="bar")) + with asserting_span_exporter(exporter, "active_task"): + _ = client.active_task def test_create_and_start_task_span_ok( @@ -609,8 +563,8 @@ def test_create_and_start_task_span_ok( def test_get_environment_span_ok( exporter: JsonObjectSpanExporter, client: BlueapiClient ): - with asserting_span_exporter(exporter, "get_environment"): - client.get_environment() + with asserting_span_exporter(exporter, "environment"): + _ = client.environment def test_reload_environment_span_ok( @@ -668,3 +622,274 @@ def test_cannot_run_task_span_ok( ): with asserting_span_exporter(exporter, "grun_task"): client.run_task(TaskRequest(name="foo", instrument_session="cm12345-1")) + + +def test_instrument_session_required(client): + with pytest.raises(MissingInstrumentSessionError): + _ = client.instrument_session + + +def test_setting_instrument_session(client): + # This looks like a completely pointless test but instrument_session is a + # property with some logic so it's not purely to get coverage up + client.instrument_session = "cm12345-4" + assert client.instrument_session == "cm12345-4" + + +def test_fluent_instrument_session_setter(client): + client2 = client.with_instrument_session("cm12345-3") + assert client is client2 + assert client.instrument_session == "cm12345-3" + + +def test_plan_cache_ignores_underscores(client): + cache = PlanCache(client, [PlanModel(name="_ignored"), PlanModel(name="used")]) + with pytest.raises(AttributeError, match="_ignored"): + _ = cache._ignored + + +def test_plan_cache_repr(client): + assert repr(client.plans) == "PlanCache(2 plans)" + + +def test_device_cache_ignores_underscores(): + rest = Mock() + rest.get_devices.return_value = DeviceResponse( + devices=[ + DeviceModel(name="_ignored", protocols=[]), + ] + ) + cache = DeviceCache(rest) + with pytest.raises(AttributeError, match="_ignored"): + _ = cache._ignored + + rest.get_devices.reset_mock() + with pytest.raises(AttributeError, match="_anything"): + _ = cache._anything + rest.get_device.assert_not_called() + + +def test_devices_are_cached(mock_rest): + cache = DeviceCache(mock_rest) + _ = cache.foo + mock_rest.get_device.assert_not_called() + _ = cache["foo"] + mock_rest.get_device.assert_not_called() + + +def test_device_cache_repr(client): + assert repr(client.devices) == "DeviceCache(2 devices)" + + +def test_device_repr(): + cache = Mock() + model = Mock() + dev = DeviceRef(name="foo", cache=cache, model=model) + assert repr(dev) == "Device(foo)" + + +def test_device_ignores_underscores(): + cache = MagicMock() + model = Mock() + dev = DeviceRef(name="foo", cache=cache, model=model) + with pytest.raises(AttributeError, match="_underscore"): + _ = dev._underscore + cache.__getitem__.assert_not_called() + + +def test_plan_help_text(): + plan = Plan(PlanModel(name="foo", description="help for foo"), Mock()) + assert plan.help_text == "help for foo" + + +def test_plan_fallback_help_text(): + plan = Plan( + PlanModel( + name="foo", + schema={"properties": {"one": {}, "two": {}}, "required": ["one"]}, + ), + Mock(), + ) + assert plan.help_text == "Plan foo(one, two=None)" + + +def test_plan_properties(): + plan = Plan( + PlanModel( + name="foo", + schema={"properties": {"one": {}, "two": {}}, "required": ["one"]}, + ), + Mock(), + ) + + assert plan.properties == {"one", "two"} + assert plan.required == ["one"] + + +def test_plan_empty_fallback_help_text(): + plan = Plan( + PlanModel(name="foo", schema={"properties": {}, "required": []}), Mock() + ) + assert plan.help_text == "Plan foo()" + + +p = pytest.param + + +@pytest.mark.parametrize( + "args,kwargs,params", + [ + p((1,), {}, {"one": 1}, id="required_as_positional"), + p((), {"one": 7}, {"one": 7}, id="required_as_keyword"), + p((1,), {"two": 23}, {"one": 1, "two": 23}, id="all_as_mixed_args_kwargs"), + p((1, 2), {}, {"one": 1, "two": 2}, id="all_as_positional"), + p((), {"one": 21, "two": 42}, {"one": 21, "two": 42}, id="all_as_keyword"), + ], +) +def test_plan_param_mapping(args, kwargs, params): + runner = Mock() + plan = Plan(FULL_PLAN, runner) + + plan(*args, **kwargs) + runner.assert_called_once_with("foobar", params) + + +@pytest.mark.parametrize( + "args,kwargs,msg", + [ + p((), {}, r"Missing argument\(s\) for \{'one'\}", id="missing_required"), + p((1,), {"one": 7}, "multiple values for one", id="duplicate_required"), + p((1, 2), {"two": 23}, "multiple values for two", id="duplicate_optional"), + p((1, 2, 3), {}, "too many arguments", id="too_many_args"), + p( + (), + {"unknown_key": 42}, + r"got unexpected arguments: \{'unknown_key'\}", + id="unknown_arg", + ), + ], +) +def test_plan_invalid_param_mapping(args, kwargs, msg): + runner = Mock(spec=Callable) + plan = Plan( + FULL_PLAN, + runner, + ) + + with pytest.raises(TypeError, match=msg): + plan(*args, **kwargs) + runner.assert_not_called() + + +def test_adding_removing_callback(client): + def callback(*a, **kw): + pass + + cb_id = client.add_callback(callback) + assert len(client.callbacks) == 1 + client.remove_callback(cb_id) + assert len(client.callbacks) == 0 + + +@pytest.mark.parametrize( + "test_event", + [ + WorkerEvent( + state=WorkerState.RUNNING, + task_status=TaskStatus( + task_id="foo", + task_complete=False, + task_failed=False, + ), + ), + ProgressEvent(task_id="foo"), + DataEvent(name="start", doc={}, task_id="0000-1111"), + ], +) +def test_client_callbacks( + client_with_events: BlueapiClient, + mock_rest: Mock, + mock_events: MagicMock, + test_event: AnyEvent, +): + callback = Mock() + client_with_events.add_callback(callback) + mock_rest.create_task.return_value = TaskResponse(task_id="foo") + mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo") + + ctx = Mock() + ctx.correlation_id = "foo" + + def subscribe(on_event: Callable[[AnyEvent, MessageContext], None]): + on_event(test_event, ctx) + on_event(COMPLETE_EVENT, ctx) + + mock_events.subscribe_to_all_events = subscribe # type: ignore + + client_with_events.run_task(TaskRequest(name="foo", instrument_session="cm12345-1")) + + assert callback.mock_calls == [call(test_event), call(COMPLETE_EVENT)] + + +def test_client_callback_failures( + client_with_events: BlueapiClient, + mock_rest: Mock, + mock_events: MagicMock, +): + failing_callback = Mock(side_effect=ValueError("Broken callback")) + callback = Mock() + client_with_events.add_callback(failing_callback) + client_with_events.add_callback(callback) + mock_rest.create_task.return_value = TaskResponse(task_id="foo") + mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo") + + ctx = Mock() + ctx.correlation_id = "foo" + + evt = DataEvent(name="start", doc={}, task_id="foo") + + def subscribe(on_event: Callable[[AnyEvent, MessageContext], None]): + on_event(evt, ctx) + on_event(COMPLETE_EVENT, ctx) + + mock_events.subscribe_to_all_events = subscribe # type: ignore + + client_with_events.run_task(TaskRequest(name="foo", instrument_session="cm12345-1")) + + assert failing_callback.mock_calls == [call(evt), call(COMPLETE_EVENT)] + assert callback.mock_calls == [call(evt), call(COMPLETE_EVENT)] + + +@patch("blueapi.client.client.SessionManager") +def test_client_login_existing_login(mock_session_manager: Mock, client: BlueapiClient): + client.login() + + mock_session_manager.from_cache.assert_called_once() + mock_session_manager.from_cache().get_valid_access_token.assert_called_once() + + +@patch("blueapi.client.client.SessionManager") +def test_client_new_login(mock_session_manager: Mock, client: BlueapiClient): + manager = Mock() + manager.get_valid_access_token.side_effect = ValueError("No existing token") + + mock_session_manager.from_cache.return_value = manager + + client.login() + + mock_session_manager.assert_called_once() + mock_session_manager.return_value.start_device_flow.assert_called_once() + + +@patch("blueapi.client.client.SessionManager") +def test_client_login_no_oidc( + mock_session_manager: Mock, mock_rest: Mock, client: BlueapiClient +): + mock_rest.get_oidc_config.return_value = None + mock_session_manager.from_cache.return_value.get_valid_access_token.side_effect = ( + ValueError("No existing token") + ) + + client.login() + + mock_session_manager.assert_not_called() diff --git a/tests/unit_tests/client/test_rest.py b/tests/unit_tests/client/test_rest.py index c8fce9d101..2ddcdd3800 100644 --- a/tests/unit_tests/client/test_rest.py +++ b/tests/unit_tests/client/test_rest.py @@ -45,7 +45,7 @@ def rest_with_auth(oidc_config: OIDCConfig, tmp_path) -> BlueapiRestClient: (500, BlueskyRemoteControlError), ], ) -@patch("blueapi.client.rest.requests.request") +@patch("blueapi.client.rest.requests.Session.request") def test_rest_error_code( mock_request: Mock, rest: BlueapiRestClient, diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 2a49a1fe80..a789d21eb5 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -110,7 +110,7 @@ def test_runs_with_umask_002( mock_umask.assert_called_once_with(0o002) -@patch("requests.request") +@patch("blueapi.client.rest.requests.Session.request") def test_connection_error_caught_by_wrapper_func( mock_requests: Mock, runner: CliRunner ): @@ -120,7 +120,7 @@ def test_connection_error_caught_by_wrapper_func( assert result.output == "Error: Failed to establish connection to blueapi server.\n" -@patch("requests.request") +@patch("blueapi.client.rest.requests.Session.request") def test_authentication_error_caught_by_wrapper_func( mock_requests: Mock, runner: CliRunner ): @@ -133,7 +133,7 @@ def test_authentication_error_caught_by_wrapper_func( ) -@patch("requests.request") +@patch("blueapi.client.rest.requests.Session.request") def test_remote_error_raised_by_wrapper_func(mock_requests: Mock, runner: CliRunner): mock_requests.side_effect = BlueskyRemoteControlError("Response [450]") @@ -198,15 +198,15 @@ def test_invalid_config_path_handling(runner: CliRunner): assert result.exit_code == 1 -@patch("blueapi.cli.cli.BlueapiClient.get_plans") +@patch("blueapi.cli.cli.BlueapiClient.plans") @patch("blueapi.cli.cli.OutputFormat.FULL.display") def test_options_via_env(mock_display, mock_plans, runner: CliRunner): result = runner.invoke( main, args=["controller", "plans"], env={"BLUEAPI_CONTROLLER_OUTPUT": "full"} ) - mock_plans.assert_called_once_with() - mock_display.assert_called_once_with(mock_plans.return_value) + mock_plans.__iter__.assert_called_once_with() + mock_display.assert_called_once_with(PlanResponse(plans=list(mock_plans))) assert result.exit_code == 0 @@ -493,9 +493,7 @@ def test_valid_stomp_config_for_listener( @responses.activate -def test_get_env( - runner: CliRunner, -): +def test_get_env(runner: CliRunner): environment_id = uuid.uuid4() responses.add( responses.GET, @@ -514,6 +512,17 @@ def test_get_env( ) +@responses.activate +def test_get_state(runner: CliRunner): + responses.add( + responses.GET, "http://localhost:8000/worker/state", json="IDLE", status=200 + ) + state = runner.invoke(main, ["controller", "state"]) + print(state.stderr) + assert state.exit_code == 0 + assert state.output == "IDLE\n" + + @responses.activate(assert_all_requests_are_fired=True) @patch("blueapi.client.client.time.sleep", return_value=None) def test_reset_env_client_behavior( @@ -1320,3 +1329,17 @@ def test_config_schema( stream.write.assert_called() else: assert json.loads(result.output) == expected + pass + + +@patch("blueapi.client.client.BlueapiClient.from_config") +@patch("blueapi.cli.cli.stubgen") +def test_genstubs( + stubgen, + client, + runner: CliRunner, +): + runner.invoke(main, ["generate-stubs", "/path/to/stub_dir"]) + stubgen.generate_stubs.assert_called_once_with( + Path("/path/to/stub_dir"), list(client().plans), list(client().devices) + ) diff --git a/tests/unit_tests/utils/test_deprecated.py b/tests/unit_tests/utils/test_deprecated.py new file mode 100644 index 0000000000..6e1d4787dc --- /dev/null +++ b/tests/unit_tests/utils/test_deprecated.py @@ -0,0 +1,19 @@ +import warnings + +import pytest + +from blueapi.utils import deprecated + + +def test_deprecated_annotation(): + @deprecated("bar") + def foo(): + return 1 + + with pytest.warns(DeprecationWarning, match="Function foo is deprecated - use bar"): + assert foo() == 1 + + # The second time a function is called, the warning should not be raised + with warnings.catch_warnings(): + warnings.simplefilter("error") + assert foo() == 1 diff --git a/uv.lock b/uv.lock index 1eb18bb2a0..6daa8f8d29 100644 --- a/uv.lock +++ b/uv.lock @@ -428,6 +428,7 @@ dependencies = [ { name = "fastapi" }, { name = "gitpython" }, { name = "graypy" }, + { name = "jinja2" }, { name = "observability-utils" }, { name = "opentelemetry-distro" }, { name = "opentelemetry-instrumentation-fastapi" }, @@ -484,6 +485,7 @@ requires-dist = [ { name = "fastapi", specifier = ">=0.112.0" }, { name = "gitpython" }, { name = "graypy", specifier = ">=2.1.0" }, + { name = "jinja2", specifier = ">=3.1.6" }, { name = "observability-utils", specifier = ">=0.1.4" }, { name = "opentelemetry-distro", specifier = ">=0.48b0" }, { name = "opentelemetry-instrumentation-fastapi", specifier = ">=0.48b0" },