From fc3d25baa0a2adc2c4aa38362e843b5b33d13947 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 20 Jun 2025 17:16:32 +0100 Subject: [PATCH] feat: Add click ParamType to handle JSON parsing Argument parsing has been moved into a click option validator. This lets clicks error handling deal with invalid JSON being passed to the run_plan method instead of having it with the rest of the run logic. --- src/blueapi/cli/cli.py | 57 +++++++++++++++++++++++------------- tests/unit_tests/test_cli.py | 14 ++++++--- 2 files changed, 47 insertions(+), 24 deletions(-) diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 4a2fff09a..9be9549c8 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -8,15 +8,16 @@ from functools import wraps from pathlib import Path from pprint import pprint -from typing import ParamSpec, TypeVar +from typing import Any, ParamSpec, TypeGuard, TypeVar import click from bluesky.callbacks.best_effort import BestEffortCallback from bluesky_stomp.messaging import MessageContext, StompClient from bluesky_stomp.models import Broker +from click.core import Context, Parameter from click.exceptions import ClickException +from click.types import ParamType from observability_utils.tracing import setup_tracing -from pydantic import ValidationError from blueapi import __version__, config from blueapi.cli.format import OutputFormat @@ -44,6 +45,35 @@ LOGGER = logging.getLogger(__name__) +TaskParameters = dict[str, Any] + + +class ParametersType(ParamType): + """CLI input parameter to accept a JSON object as an argument""" + + name = "TaskParameters" + + def convert( + self, + value: str | dict[str, Any] | None, + param: Parameter | None, + ctx: Context | None, + ) -> TaskParameters: + if isinstance(value, str): + try: + params = json.loads(value) + if is_str_dict(params): + return params + self.fail("Parameters must be a JSON object with string keys") + except json.JSONDecodeError as jde: + self.fail(f"Parameters are not valid JSON: {jde}") + else: + return super().convert(value, param, ctx) + + +def is_str_dict(val: Any) -> TypeGuard[TaskParameters]: + return isinstance(val, dict) and all(isinstance(k, str) for k in val) + @click.group( invoke_without_command=True, context_settings={"auto_envvar_prefix": "BLUEAPI"} @@ -258,7 +288,7 @@ def on_event( @controller.command(name="run") @click.argument("name", type=str) -@click.argument("parameters", type=str, required=False) +@click.argument("parameters", type=ParametersType(), default={}, required=False) @click.option( "--foreground/--background", "--fg/--bg", type=bool, is_flag=True, default=True ) @@ -284,29 +314,16 @@ def on_event( def run_plan( obj: dict, name: str, - parameters: str | None, timeout: float | None, foreground: bool, instrument_session: str, + parameters: TaskParameters, ) -> None: """Run a plan with parameters""" client: BlueapiClient = obj["client"] - - parameters = parameters or "{}" - try: - parsed_params = json.loads(parameters) if isinstance(parameters, str) else {} - except json.JSONDecodeError as jde: - raise ClickException(f"Parameters are not valid JSON: {jde}") from jde - - try: - task = TaskRequest( - name=name, - params=parsed_params, - instrument_session=instrument_session, - ) - except ValidationError as ve: - ip = InvalidParametersError.from_validation_error(ve) - raise ClickException(ip.message()) from ip + task = TaskRequest( + name=name, params=parameters, instrument_session=instrument_session + ) try: if foreground: diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 7210ec2bb..7b1916be8 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -25,7 +25,7 @@ from stomp.connect import StompConnection11 as Connection from blueapi import __version__ -from blueapi.cli.cli import main +from blueapi.cli.cli import ParametersType, main from blueapi.cli.format import OutputFormat, fmt_dict from blueapi.client.event_bus import BlueskyStreamingError from blueapi.client.rest import ( @@ -714,7 +714,7 @@ def test_error_handling(exception, error_message, runner: CliRunner): "params, error", [ ("{", "Parameters are not valid JSON"), - ("[]", ""), + ("[]", "Parameters must be a JSON object with string keys"), ], ) def test_run_task_parsing_errors(params: str, error: str, runner: CliRunner): @@ -731,8 +731,8 @@ def test_run_task_parsing_errors(params: str, error: str, runner: CliRunner): params, ], ) - assert result.stderr.startswith("Error: " + error) - assert result.exit_code == 1 + assert error in result.stderr + assert result.exit_code == 2 def test_device_output_formatting(): @@ -1329,3 +1329,9 @@ def test_config_schema( stream.write.assert_called() else: assert json.loads(result.output) == expected + + +@pytest.mark.parametrize("value,result", [({}, {}), ("{}", {}), (None, None)]) +def test_task_parameter_type(value, result): + t = ParametersType() + assert t.convert(value, None, None) == result