diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 9be9549c8..4e44935a1 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -83,26 +83,23 @@ def is_str_dict(val: Any) -> TypeGuard[TaskParameters]: "-c", "--config", type=Path, help="Path to configuration YAML file", multiple=True ) @click.pass_context -def main(ctx: click.Context, config: Path | None | tuple[Path, ...]) -> None: +def main(ctx: click.Context, config: tuple[Path, ...]) -> None: # if no command is supplied, run with the options passed # Set umask to DLS standard os.umask(stat.S_IWOTH) config_loader = ConfigLoader(ApplicationConfig) - if config is not None: - configs = (config,) if isinstance(config, Path) else config - for path in configs: - if path.exists(): - config_loader.use_values_from_yaml(path) - else: - raise FileNotFoundError(f"Cannot find file: {path}") + try: + config_loader.use_values_from_yaml(*config) + except FileNotFoundError as fnfe: + raise ClickException(f"Config file not found: {fnfe.filename}") from fnfe - ctx.ensure_object(dict) loaded_config: ApplicationConfig = config_loader.load() set_up_logging(loaded_config.logging) + ctx.ensure_object(dict) ctx.obj["config"] = loaded_config if ctx.invoked_subcommand is None: diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 214016860..6f0ca4ae7 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -358,9 +358,9 @@ def recursively_update_map(old: dict[str, Any], new: Mapping[str, Any]) -> None: recursively_update_map(self._values, values) - def use_values_from_yaml(self, path: Path) -> None: + def use_values_from_yaml(self, *paths: Path) -> None: """ - Use all values provided in a YAML/JSON file in the + Use all values provided in a YAML/JSON files in the config, override any defaults and values set by previous calls into this class. @@ -368,9 +368,9 @@ def use_values_from_yaml(self, path: Path) -> None: path (Path): Path to YAML/JSON file """ - with path.open("r") as stream: - values = yaml.load(stream, yaml.Loader) - self.use_values(values) + for path in paths: + with path.open("r") as stream: + self.use_values(yaml.load(stream, yaml.Loader)) def load(self) -> C: """