diff --git a/sqlmesh_dbt/cli.py b/sqlmesh_dbt/cli.py index 2daa3f9d54..c215663f0a 100644 --- a/sqlmesh_dbt/cli.py +++ b/sqlmesh_dbt/cli.py @@ -2,7 +2,7 @@ import sys import click from sqlmesh_dbt.operations import DbtOperations, create -from sqlmesh_dbt.error import cli_global_error_handler +from sqlmesh_dbt.error import cli_global_error_handler, ErrorHandlingGroup from pathlib import Path from sqlmesh_dbt.options import YamlParamType import functools @@ -43,7 +43,7 @@ def _cleanup() -> None: exclude_option = click.option("--exclude", multiple=True, help="Specify the nodes to exclude.") -@click.group(invoke_without_command=True) +@click.group(cls=ErrorHandlingGroup, invoke_without_command=True) @click.option("--profile", help="Which existing profile to load. Overrides output.profile") @click.option("-t", "--target", help="Which target to load for the given profile") @click.option( diff --git a/sqlmesh_dbt/error.py b/sqlmesh_dbt/error.py index 005ca87c50..49a2f8195b 100644 --- a/sqlmesh_dbt/error.py +++ b/sqlmesh_dbt/error.py @@ -27,3 +27,10 @@ def wrapper(*args: t.List[t.Any], **kwargs: t.Any) -> t.Any: raise return wrapper + + +class ErrorHandlingGroup(click.Group): + def add_command(self, cmd: click.Command, name: t.Optional[str] = None) -> None: + if cmd.callback: + cmd.callback = cli_global_error_handler(cmd.callback) + super().add_command(cmd, name=name) diff --git a/tests/dbt/cli/test_global_flags.py b/tests/dbt/cli/test_global_flags.py index 802d359346..66dee7236c 100644 --- a/tests/dbt/cli/test_global_flags.py +++ b/tests/dbt/cli/test_global_flags.py @@ -1,7 +1,10 @@ import typing as t from pathlib import Path import pytest +from pytest_mock import MockerFixture from click.testing import Result +from sqlmesh.utils.errors import SQLMeshError +from sqlglot.errors import SqlglotError pytestmark = pytest.mark.slow @@ -28,3 +31,65 @@ def test_profile_and_target(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[... result = invoke_cli(["--profile", "jaffle_shop", "--target", "dev"]) assert result.exit_code == 0 assert "No command specified" in result.output + + +def test_run_error_handler( + jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result], mocker: MockerFixture +) -> None: + mock_run = mocker.patch("sqlmesh_dbt.operations.DbtOperations.run") + mock_run.side_effect = SQLMeshError("Test error message") + + result = invoke_cli(["run"]) + assert result.exit_code == 1 + assert "Error: Test error message" in result.output + assert "Traceback" not in result.output + + # test SqlglotError in run command + mock_run = mocker.patch("sqlmesh_dbt.operations.DbtOperations.run") + mock_run.side_effect = SqlglotError("Invalid SQL syntax") + + result = invoke_cli(["run"]) + + assert result.exit_code == 1 + assert "Error: Invalid SQL syntax" in result.output + assert "Traceback" not in result.output + + # test ValueError in run command + mock_run = mocker.patch("sqlmesh_dbt.operations.DbtOperations.run") + mock_run.side_effect = ValueError("Invalid configuration value") + + result = invoke_cli(["run"]) + + assert result.exit_code == 1 + assert "Error: Invalid configuration value" in result.output + assert "Traceback" not in result.output + + # test SQLMeshError in list command + mock_list = mocker.patch("sqlmesh_dbt.operations.DbtOperations.list_") + mock_list.side_effect = SQLMeshError("List command error") + + result = invoke_cli(["list"]) + + assert result.exit_code == 1 + assert "Error: List command error" in result.output + assert "Traceback" not in result.output + + # test SQLMeshError in main command without subcommand + mock_create = mocker.patch("sqlmesh_dbt.cli.create") + mock_create.side_effect = SQLMeshError("Failed to load project") + result = invoke_cli(["--profile", "jaffle_shop"]) + + assert result.exit_code == 1 + assert "Error: Failed to load project" in result.output + assert "Traceback" not in result.output + mocker.stopall() + + # test error with select option + mock_run_select = mocker.patch("sqlmesh_dbt.operations.DbtOperations.run") + mock_run_select.side_effect = SQLMeshError("Error with selector") + + result = invoke_cli(["run", "--select", "model1"]) + + assert result.exit_code == 1 + assert "Error: Error with selector" in result.output + assert "Traceback" not in result.output