diff --git a/sqlmesh/core/config/loader.py b/sqlmesh/core/config/loader.py index 2c1554454b..a7b997e303 100644 --- a/sqlmesh/core/config/loader.py +++ b/sqlmesh/core/config/loader.py @@ -32,6 +32,7 @@ def load_configs( paths: t.Union[str | Path, t.Iterable[str | Path]], sqlmesh_path: t.Optional[Path] = None, dotenv_path: t.Optional[Path] = None, + **kwargs: t.Any, ) -> t.Dict[Path, C]: sqlmesh_path = sqlmesh_path or c.SQLMESH_PATH config = config or "config" @@ -70,6 +71,7 @@ def load_configs( project_paths=[path / name for name in ALL_CONFIG_FILENAMES], personal_paths=personal_paths, config_name=config, + **kwargs, ) for path in absolute_paths } @@ -81,6 +83,7 @@ def load_config_from_paths( personal_paths: t.Optional[t.List[Path]] = None, config_name: str = "config", load_from_env: bool = True, + **kwargs: t.Any, ) -> C: project_paths = project_paths or [] personal_paths = personal_paths or [] @@ -168,7 +171,11 @@ def load_config_from_paths( if dbt_project_file: from sqlmesh.dbt.loader import sqlmesh_config - dbt_python_config = sqlmesh_config(project_root=dbt_project_file.parent) + dbt_python_config = sqlmesh_config( + project_root=dbt_project_file.parent, + dbt_profile_name=kwargs.pop("profile", None), + dbt_target_name=kwargs.pop("target", None), + ) if type(dbt_python_config) != config_type: dbt_python_config = convert_config_type(dbt_python_config, config_type) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index eca60ecea9..9022f3f069 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -367,9 +367,12 @@ def __init__( loader: t.Optional[t.Type[Loader]] = None, load: bool = True, users: t.Optional[t.List[User]] = None, + config_loader_kwargs: t.Optional[t.Dict[str, t.Any]] = None, ): self.configs = ( - config if isinstance(config, dict) else load_configs(config, self.CONFIG_TYPE, paths) + config + if isinstance(config, dict) + else load_configs(config, self.CONFIG_TYPE, paths, **(config_loader_kwargs or {})) ) self._projects = {config.project for config in self.configs.values()} self.dag: DAG[str] = DAG() diff --git a/sqlmesh/dbt/loader.py b/sqlmesh/dbt/loader.py index 3a22b61bf6..b4e8caf0bc 100644 --- a/sqlmesh/dbt/loader.py +++ b/sqlmesh/dbt/loader.py @@ -44,13 +44,14 @@ def sqlmesh_config( project_root: t.Optional[Path] = None, state_connection: t.Optional[ConnectionConfig] = None, + dbt_profile_name: t.Optional[str] = None, dbt_target_name: t.Optional[str] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, register_comments: t.Optional[bool] = None, **kwargs: t.Any, ) -> Config: project_root = project_root or Path() - context = DbtContext(project_root=project_root) + context = DbtContext(project_root=project_root, profile_name=dbt_profile_name) profile = Profile.load(context, target_name=dbt_target_name) model_defaults = kwargs.pop("model_defaults", ModelDefaultsConfig()) if model_defaults.dialect is None: diff --git a/sqlmesh/dbt/profile.py b/sqlmesh/dbt/profile.py index 72634833a6..ea0384c786 100644 --- a/sqlmesh/dbt/profile.py +++ b/sqlmesh/dbt/profile.py @@ -101,8 +101,10 @@ def _read_profile( target_name = context.render(project_data.get("target")) if target_name not in outputs: + target_names = "\n".join(f"- {name}" for name in outputs) raise ConfigError( - f"Target '{target_name}' not specified in profiles for '{context.profile_name}'." + f"Target '{target_name}' not specified in profiles for '{context.profile_name}'. " + f"The valid target names for this profile are:\n{target_names}" ) target_fields = load_yaml(context.render(yaml.dump(outputs[target_name]))) diff --git a/sqlmesh_dbt/cli.py b/sqlmesh_dbt/cli.py index 2ec59b665d..500a9d6fa0 100644 --- a/sqlmesh_dbt/cli.py +++ b/sqlmesh_dbt/cli.py @@ -2,6 +2,8 @@ import sys import click from sqlmesh_dbt.operations import DbtOperations, create +from sqlmesh_dbt.error import cli_global_error_handler +from pathlib import Path def _get_dbt_operations(ctx: click.Context) -> DbtOperations: @@ -10,9 +12,14 @@ def _get_dbt_operations(ctx: click.Context) -> DbtOperations: return ctx.obj -@click.group() +@click.group(invoke_without_command=True) +@click.option("--profile", help="Which existing profile to load. Overrides output.profile") +@click.option("-t", "--target", help="Which target to load for the given profile") @click.pass_context -def dbt(ctx: click.Context) -> None: +@cli_global_error_handler +def dbt( + ctx: click.Context, profile: t.Optional[str] = None, target: t.Optional[str] = None +) -> None: """ An ELT tool for managing your SQL transformations and data models, powered by the SQLMesh engine. """ @@ -22,7 +29,12 @@ def dbt(ctx: click.Context) -> None: return # TODO: conditionally call create() if there are times we dont want/need to import sqlmesh and load a project - ctx.obj = create() + ctx.obj = create(project_dir=Path.cwd(), profile=profile, target=target) + + if not ctx.invoked_subcommand: + click.echo( + f"No command specified. Run `{ctx.info_name} --help` to see the available commands." + ) @dbt.command() diff --git a/sqlmesh_dbt/error.py b/sqlmesh_dbt/error.py new file mode 100644 index 0000000000..f5d4bc438c --- /dev/null +++ b/sqlmesh_dbt/error.py @@ -0,0 +1,41 @@ +import typing as t +import logging +from functools import wraps +import click +import sys + +logger = logging.getLogger(__name__) + + +def cli_global_error_handler( + func: t.Callable[..., t.Any], +) -> t.Callable[..., t.Any]: + @wraps(func) + def wrapper(*args: t.List[t.Any], **kwargs: t.Any) -> t.Any: + try: + return func(*args, **kwargs) + except Exception as ex: + # these imports are deliberately deferred to avoid the penalty of importing the `sqlmesh` + # package up front for every CLI command + from sqlmesh.utils.errors import SQLMeshError + from sqlglot.errors import SqlglotError + + if isinstance(ex, (SQLMeshError, SqlglotError, ValueError)): + click.echo(click.style("Error: " + str(ex), fg="red")) + sys.exit(1) + else: + raise + finally: + context_or_obj = args[0] + sqlmesh_context = ( + context_or_obj.obj if isinstance(context_or_obj, click.Context) else context_or_obj + ) + if sqlmesh_context is not None: + # important to import this only if a context was created + # otherwise something like `sqlmesh_dbt run --help` will trigger this import because it's in the finally: block + from sqlmesh import Context + + if isinstance(sqlmesh_context, Context): + sqlmesh_context.close() + + return wrapper diff --git a/sqlmesh_dbt/operations.py b/sqlmesh_dbt/operations.py index ec07efd37b..f9aae3cdac 100644 --- a/sqlmesh_dbt/operations.py +++ b/sqlmesh_dbt/operations.py @@ -53,7 +53,10 @@ def console(self) -> DbtCliConsole: def create( - project_dir: t.Optional[Path] = None, profiles_dir: t.Optional[Path] = None, debug: bool = False + project_dir: t.Optional[Path] = None, + profile: t.Optional[str] = None, + target: t.Optional[str] = None, + debug: bool = False, ) -> DbtOperations: with Progress(transient=True) as progress: # Indeterminate progress bar before SQLMesh import to provide feedback to the user that something is indeed happening @@ -76,6 +79,7 @@ def create( sqlmesh_context = Context( paths=[project_dir], + config_loader_kwargs=dict(profile=profile, target=target), load=True, ) diff --git a/tests/dbt/cli/test_global_flags.py b/tests/dbt/cli/test_global_flags.py new file mode 100644 index 0000000000..802d359346 --- /dev/null +++ b/tests/dbt/cli/test_global_flags.py @@ -0,0 +1,30 @@ +import typing as t +from pathlib import Path +import pytest +from click.testing import Result + +pytestmark = pytest.mark.slow + + +def test_profile_and_target(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]): + # profile doesnt exist - error + result = invoke_cli(["--profile", "nonexist"]) + assert result.exit_code == 1 + assert "Profile 'nonexist' not found in profiles" in result.output + + # profile exists - successful load with default target + result = invoke_cli(["--profile", "jaffle_shop"]) + assert result.exit_code == 0 + assert "No command specified" in result.output + + # profile exists but target doesnt - error + result = invoke_cli(["--profile", "jaffle_shop", "--target", "nonexist"]) + assert result.exit_code == 1 + assert "Target 'nonexist' not specified in profiles" in result.output + assert "valid target names for this profile are" in result.output + assert "- dev" in result.output + + # profile exists and so does target - successful load with specified target + result = invoke_cli(["--profile", "jaffle_shop", "--target", "dev"]) + assert result.exit_code == 0 + assert "No command specified" in result.output diff --git a/tests/dbt/cli/test_operations.py b/tests/dbt/cli/test_operations.py index c35cab992c..9d36b10f60 100644 --- a/tests/dbt/cli/test_operations.py +++ b/tests/dbt/cli/test_operations.py @@ -2,6 +2,7 @@ import pytest from sqlmesh_dbt.operations import create from sqlmesh.utils import yaml +from sqlmesh.utils.errors import SQLMeshError import time_machine pytestmark = pytest.mark.slow @@ -53,3 +54,18 @@ def test_create_uses_configured_start_date_if_supplied(jaffle_shop_duckdb: Path) for model in operations.context.models.values() if not model.kind.is_seed ) + + +def test_create_can_specify_profile_and_target(jaffle_shop_duckdb: Path): + with pytest.raises(SQLMeshError, match=r"Profile 'foo' not found"): + create(profile="foo") + + with pytest.raises( + SQLMeshError, match=r"Target 'prod' not specified in profiles for 'jaffle_shop'" + ): + create(profile="jaffle_shop", target="prod") + + dbt_project = create(profile="jaffle_shop", target="dev").project + + assert dbt_project.context.profile_name == "jaffle_shop" + assert dbt_project.context.target_name == "dev"