Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 37 additions & 20 deletions src/blueapi/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@
from functools import wraps
from pathlib import Path
from pprint import pprint
from typing import ParamSpec, TypeVar
from typing import Any, ParamSpec, TypeGuard, TypeVar

import click
from bluesky.callbacks.best_effort import BestEffortCallback
from bluesky_stomp.messaging import MessageContext, StompClient
from bluesky_stomp.models import Broker
from click.core import Context, Parameter
from click.exceptions import ClickException
from click.types import ParamType
from observability_utils.tracing import setup_tracing
from pydantic import ValidationError

from blueapi import __version__, config
from blueapi.cli.format import OutputFormat
Expand Down Expand Up @@ -44,6 +45,35 @@

LOGGER = logging.getLogger(__name__)

TaskParameters = dict[str, Any]


class ParametersType(ParamType):
"""CLI input parameter to accept a JSON object as an argument"""

name = "TaskParameters"

def convert(
self,
value: str | dict[str, Any] | None,
param: Parameter | None,
ctx: Context | None,
) -> TaskParameters:
if isinstance(value, str):
try:
params = json.loads(value)
if is_str_dict(params):
return params
self.fail("Parameters must be a JSON object with string keys")
except json.JSONDecodeError as jde:
self.fail(f"Parameters are not valid JSON: {jde}")
else:
return super().convert(value, param, ctx)


def is_str_dict(val: Any) -> TypeGuard[TaskParameters]:
return isinstance(val, dict) and all(isinstance(k, str) for k in val)


@click.group(
invoke_without_command=True, context_settings={"auto_envvar_prefix": "BLUEAPI"}
Expand Down Expand Up @@ -258,7 +288,7 @@ def on_event(

@controller.command(name="run")
@click.argument("name", type=str)
@click.argument("parameters", type=str, required=False)
@click.argument("parameters", type=ParametersType(), default={}, required=False)
@click.option(
"--foreground/--background", "--fg/--bg", type=bool, is_flag=True, default=True
)
Expand All @@ -284,29 +314,16 @@ def on_event(
def run_plan(
obj: dict,
name: str,
parameters: str | None,
timeout: float | None,
foreground: bool,
instrument_session: str,
parameters: TaskParameters,
) -> None:
"""Run a plan with parameters"""
client: BlueapiClient = obj["client"]

parameters = parameters or "{}"
try:
parsed_params = json.loads(parameters) if isinstance(parameters, str) else {}
except json.JSONDecodeError as jde:
raise ClickException(f"Parameters are not valid JSON: {jde}") from jde

try:
task = TaskRequest(
name=name,
params=parsed_params,
instrument_session=instrument_session,
)
except ValidationError as ve:
ip = InvalidParametersError.from_validation_error(ve)
raise ClickException(ip.message()) from ip
task = TaskRequest(
name=name, params=parameters, instrument_session=instrument_session
)

try:
if foreground:
Expand Down
14 changes: 10 additions & 4 deletions tests/unit_tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from stomp.connect import StompConnection11 as Connection

from blueapi import __version__
from blueapi.cli.cli import main
from blueapi.cli.cli import ParametersType, main
from blueapi.cli.format import OutputFormat, fmt_dict
from blueapi.client.event_bus import BlueskyStreamingError
from blueapi.client.rest import (
Expand Down Expand Up @@ -714,7 +714,7 @@ def test_error_handling(exception, error_message, runner: CliRunner):
"params, error",
[
("{", "Parameters are not valid JSON"),
("[]", ""),
("[]", "Parameters must be a JSON object with string keys"),
],
)
def test_run_task_parsing_errors(params: str, error: str, runner: CliRunner):
Expand All @@ -731,8 +731,8 @@ def test_run_task_parsing_errors(params: str, error: str, runner: CliRunner):
params,
],
)
assert result.stderr.startswith("Error: " + error)
assert result.exit_code == 1
assert error in result.stderr
assert result.exit_code == 2


def test_device_output_formatting():
Expand Down Expand Up @@ -1329,3 +1329,9 @@ def test_config_schema(
stream.write.assert_called()
else:
assert json.loads(result.output) == expected


@pytest.mark.parametrize("value,result", [({}, {}), ("{}", {}), (None, None)])
def test_task_parameter_type(value, result):
t = ParametersType()
assert t.convert(value, None, None) == result