Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion sqlmesh/core/config/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def load_configs(
paths: t.Union[str | Path, t.Iterable[str | Path]],
sqlmesh_path: t.Optional[Path] = None,
dotenv_path: t.Optional[Path] = None,
**kwargs: t.Any,
) -> t.Dict[Path, C]:
sqlmesh_path = sqlmesh_path or c.SQLMESH_PATH
config = config or "config"
Expand Down Expand Up @@ -70,6 +71,7 @@ def load_configs(
project_paths=[path / name for name in ALL_CONFIG_FILENAMES],
personal_paths=personal_paths,
config_name=config,
**kwargs,
)
for path in absolute_paths
}
Expand All @@ -81,6 +83,7 @@ def load_config_from_paths(
personal_paths: t.Optional[t.List[Path]] = None,
config_name: str = "config",
load_from_env: bool = True,
**kwargs: t.Any,
) -> C:
project_paths = project_paths or []
personal_paths = personal_paths or []
Expand Down Expand Up @@ -168,7 +171,11 @@ def load_config_from_paths(
if dbt_project_file:
from sqlmesh.dbt.loader import sqlmesh_config

dbt_python_config = sqlmesh_config(project_root=dbt_project_file.parent)
dbt_python_config = sqlmesh_config(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The config_loader_kwargs parameter was added to Context so it could flow through to here.

I first attempted to make these arguments to the DbtLoader, but theyre needed to work out the ConnectionConfig and by the time DbtLoader is invoked it's too late

project_root=dbt_project_file.parent,
dbt_profile_name=kwargs.pop("profile", None),
dbt_target_name=kwargs.pop("target", None),
)
if type(dbt_python_config) != config_type:
dbt_python_config = convert_config_type(dbt_python_config, config_type)

Expand Down
5 changes: 4 additions & 1 deletion sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,12 @@ def __init__(
loader: t.Optional[t.Type[Loader]] = None,
load: bool = True,
users: t.Optional[t.List[User]] = None,
config_loader_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
):
self.configs = (
config if isinstance(config, dict) else load_configs(config, self.CONFIG_TYPE, paths)
config
if isinstance(config, dict)
else load_configs(config, self.CONFIG_TYPE, paths, **(config_loader_kwargs or {}))
)
self._projects = {config.project for config in self.configs.values()}
self.dag: DAG[str] = DAG()
Expand Down
3 changes: 2 additions & 1 deletion sqlmesh/dbt/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,14 @@
def sqlmesh_config(
project_root: t.Optional[Path] = None,
state_connection: t.Optional[ConnectionConfig] = None,
dbt_profile_name: t.Optional[str] = None,
dbt_target_name: t.Optional[str] = None,
variables: t.Optional[t.Dict[str, t.Any]] = None,
register_comments: t.Optional[bool] = None,
**kwargs: t.Any,
) -> Config:
project_root = project_root or Path()
context = DbtContext(project_root=project_root)
context = DbtContext(project_root=project_root, profile_name=dbt_profile_name)
profile = Profile.load(context, target_name=dbt_target_name)
model_defaults = kwargs.pop("model_defaults", ModelDefaultsConfig())
if model_defaults.dialect is None:
Expand Down
4 changes: 3 additions & 1 deletion sqlmesh/dbt/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,10 @@ def _read_profile(
target_name = context.render(project_data.get("target"))

if target_name not in outputs:
target_names = "\n".join(f"- {name}" for name in outputs)
raise ConfigError(
f"Target '{target_name}' not specified in profiles for '{context.profile_name}'."
f"Target '{target_name}' not specified in profiles for '{context.profile_name}'. "
f"The valid target names for this profile are:\n{target_names}"
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that the dbt cli lists valid targets if you get the profile correct but the target wrong, so I updated this to do the same

)

target_fields = load_yaml(context.render(yaml.dump(outputs[target_name])))
Expand Down
18 changes: 15 additions & 3 deletions sqlmesh_dbt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import sys
import click
from sqlmesh_dbt.operations import DbtOperations, create
from sqlmesh_dbt.error import cli_global_error_handler
from pathlib import Path


def _get_dbt_operations(ctx: click.Context) -> DbtOperations:
Expand All @@ -10,9 +12,14 @@ def _get_dbt_operations(ctx: click.Context) -> DbtOperations:
return ctx.obj


@click.group()
@click.group(invoke_without_command=True)
@click.option("--profile", help="Which existing profile to load. Overrides output.profile")
@click.option("-t", "--target", help="Which target to load for the given profile")
@click.pass_context
def dbt(ctx: click.Context) -> None:
@cli_global_error_handler
def dbt(
ctx: click.Context, profile: t.Optional[str] = None, target: t.Optional[str] = None
) -> None:
"""
An ELT tool for managing your SQL transformations and data models, powered by the SQLMesh engine.
"""
Expand All @@ -22,7 +29,12 @@ def dbt(ctx: click.Context) -> None:
return

# TODO: conditionally call create() if there are times we dont want/need to import sqlmesh and load a project
ctx.obj = create()
ctx.obj = create(project_dir=Path.cwd(), profile=profile, target=target)

if not ctx.invoked_subcommand:
click.echo(
f"No command specified. Run `{ctx.info_name} --help` to see the available commands."
)


@dbt.command()
Expand Down
41 changes: 41 additions & 0 deletions sqlmesh_dbt/error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import typing as t
import logging
from functools import wraps
import click
import sys

logger = logging.getLogger(__name__)


def cli_global_error_handler(
func: t.Callable[..., t.Any],
) -> t.Callable[..., t.Any]:
@wraps(func)
def wrapper(*args: t.List[t.Any], **kwargs: t.Any) -> t.Any:
try:
return func(*args, **kwargs)
except Exception as ex:
# these imports are deliberately deferred to avoid the penalty of importing the `sqlmesh`
# package up front for every CLI command
from sqlmesh.utils.errors import SQLMeshError
from sqlglot.errors import SqlglotError

if isinstance(ex, (SQLMeshError, SqlglotError, ValueError)):
click.echo(click.style("Error: " + str(ex), fg="red"))
sys.exit(1)
else:
raise
finally:
context_or_obj = args[0]
sqlmesh_context = (
context_or_obj.obj if isinstance(context_or_obj, click.Context) else context_or_obj
)
if sqlmesh_context is not None:
# important to import this only if a context was created
# otherwise something like `sqlmesh_dbt run --help` will trigger this import because it's in the finally: block
from sqlmesh import Context

if isinstance(sqlmesh_context, Context):
sqlmesh_context.close()

return wrapper
6 changes: 5 additions & 1 deletion sqlmesh_dbt/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ def console(self) -> DbtCliConsole:


def create(
project_dir: t.Optional[Path] = None, profiles_dir: t.Optional[Path] = None, debug: bool = False
project_dir: t.Optional[Path] = None,
profile: t.Optional[str] = None,
target: t.Optional[str] = None,
debug: bool = False,
) -> DbtOperations:
with Progress(transient=True) as progress:
# Indeterminate progress bar before SQLMesh import to provide feedback to the user that something is indeed happening
Expand All @@ -76,6 +79,7 @@ def create(

sqlmesh_context = Context(
paths=[project_dir],
config_loader_kwargs=dict(profile=profile, target=target),
load=True,
)

Expand Down
30 changes: 30 additions & 0 deletions tests/dbt/cli/test_global_flags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import typing as t
from pathlib import Path
import pytest
from click.testing import Result

pytestmark = pytest.mark.slow


def test_profile_and_target(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
# profile doesnt exist - error
result = invoke_cli(["--profile", "nonexist"])
assert result.exit_code == 1
assert "Profile 'nonexist' not found in profiles" in result.output

# profile exists - successful load with default target
result = invoke_cli(["--profile", "jaffle_shop"])
assert result.exit_code == 0
assert "No command specified" in result.output

# profile exists but target doesnt - error
result = invoke_cli(["--profile", "jaffle_shop", "--target", "nonexist"])
assert result.exit_code == 1
assert "Target 'nonexist' not specified in profiles" in result.output
assert "valid target names for this profile are" in result.output
assert "- dev" in result.output

# profile exists and so does target - successful load with specified target
result = invoke_cli(["--profile", "jaffle_shop", "--target", "dev"])
assert result.exit_code == 0
assert "No command specified" in result.output
16 changes: 16 additions & 0 deletions tests/dbt/cli/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest
from sqlmesh_dbt.operations import create
from sqlmesh.utils import yaml
from sqlmesh.utils.errors import SQLMeshError
import time_machine

pytestmark = pytest.mark.slow
Expand Down Expand Up @@ -53,3 +54,18 @@ def test_create_uses_configured_start_date_if_supplied(jaffle_shop_duckdb: Path)
for model in operations.context.models.values()
if not model.kind.is_seed
)


def test_create_can_specify_profile_and_target(jaffle_shop_duckdb: Path):
with pytest.raises(SQLMeshError, match=r"Profile 'foo' not found"):
create(profile="foo")

with pytest.raises(
SQLMeshError, match=r"Target 'prod' not specified in profiles for 'jaffle_shop'"
):
create(profile="jaffle_shop", target="prod")

dbt_project = create(profile="jaffle_shop", target="dev").project

assert dbt_project.context.profile_name == "jaffle_shop"
assert dbt_project.context.target_name == "dev"