Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sqlmesh/core/config/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def load_config_from_paths(
project_root=dbt_project_file.parent,
dbt_profile_name=kwargs.pop("profile", None),
dbt_target_name=kwargs.pop("target", None),
variables=variables,
)
if type(dbt_python_config) != config_type:
dbt_python_config = convert_config_type(dbt_python_config, config_type)
Expand Down
44 changes: 35 additions & 9 deletions sqlmesh_dbt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,31 @@
from sqlmesh_dbt.operations import DbtOperations, create
from sqlmesh_dbt.error import cli_global_error_handler
from pathlib import Path
from sqlmesh_dbt.options import YamlParamType
import functools


def _get_dbt_operations(ctx: click.Context) -> DbtOperations:
if not isinstance(ctx.obj, DbtOperations):
def _get_dbt_operations(ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]]) -> DbtOperations:
if not isinstance(ctx.obj, functools.partial):
raise ValueError(f"Unexpected click context object: {type(ctx.obj)}")
return ctx.obj

dbt_operations = ctx.obj(vars=vars)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we handle this differently from profile / target`?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

--profile / --target are top level dbt options - they show in dbt --help and apply to every subcommand.

--vars only applies to certain subcommands. eg it shows in dbt list --help and dbt run --help but not dbt source --help


if not isinstance(dbt_operations, DbtOperations):
raise ValueError(f"Unexpected dbt operations type: {type(dbt_operations)}")

@ctx.call_on_close
def _cleanup() -> None:
dbt_operations.close()

return dbt_operations


vars_option = click.option(
"--vars",
type=YamlParamType(),
help="Supply variables to the project. This argument overrides variables defined in your dbt_project.yml file. This argument should be a YAML string, eg. '{my_variable: my_value}'",
)


select_option = click.option(
Expand Down Expand Up @@ -40,10 +59,15 @@ def dbt(
# we dont need to import sqlmesh/load the project for CLI help
return

# TODO: conditionally call create() if there are times we dont want/need to import sqlmesh and load a project
ctx.obj = create(project_dir=Path.cwd(), profile=profile, target=target)
# we have a partially applied function here because subcommands might set extra options like --vars
# that need to be known before we attempt to load the project
ctx.obj = functools.partial(create, project_dir=Path.cwd(), profile=profile, target=target)

if not ctx.invoked_subcommand:
if profile or target:
# trigger a project load to validate the specified profile / target
ctx.obj()

click.echo(
f"No command specified. Run `{ctx.info_name} --help` to see the available commands."
)
Expand All @@ -57,19 +81,21 @@ def dbt(
"--full-refresh",
help="If specified, dbt will drop incremental models and fully-recalculate the incremental table from the model definition.",
)
@vars_option
@click.pass_context
def run(ctx: click.Context, **kwargs: t.Any) -> None:
def run(ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]], **kwargs: t.Any) -> None:
"""Compile SQL and execute against the current target database."""
_get_dbt_operations(ctx).run(**kwargs)
_get_dbt_operations(ctx, vars).run(**kwargs)


@dbt.command(name="list")
@select_option
@exclude_option
@vars_option
@click.pass_context
def list_(ctx: click.Context, **kwargs: t.Any) -> None:
def list_(ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]], **kwargs: t.Any) -> None:
"""List the resources in your project"""
_get_dbt_operations(ctx).list_(**kwargs)
_get_dbt_operations(ctx, vars).list_(**kwargs)


@dbt.command(name="ls", hidden=True) # hidden alias for list
Expand Down
12 changes: 0 additions & 12 deletions sqlmesh_dbt/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,5 @@ def wrapper(*args: t.List[t.Any], **kwargs: t.Any) -> t.Any:
sys.exit(1)
else:
raise
finally:
context_or_obj = args[0]
sqlmesh_context = (
context_or_obj.obj if isinstance(context_or_obj, click.Context) else context_or_obj
)
if sqlmesh_context is not None:
# important to import this only if a context was created
# otherwise something like `sqlmesh_dbt run --help` will trigger this import because it's in the finally: block
from sqlmesh import Context

if isinstance(sqlmesh_context, Context):
sqlmesh_context.close()

return wrapper
6 changes: 5 additions & 1 deletion sqlmesh_dbt/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,15 @@ def console(self) -> DbtCliConsole:

return console

def close(self) -> None:
self.context.close()


def create(
project_dir: t.Optional[Path] = None,
profile: t.Optional[str] = None,
target: t.Optional[str] = None,
vars: t.Optional[t.Dict[str, t.Any]] = None,
debug: bool = False,
) -> DbtOperations:
with Progress(transient=True) as progress:
Expand All @@ -104,7 +108,7 @@ def create(

sqlmesh_context = Context(
paths=[project_dir],
config_loader_kwargs=dict(profile=profile, target=target),
config_loader_kwargs=dict(profile=profile, target=target, variables=vars),
load=True,
)

Expand Down
25 changes: 25 additions & 0 deletions sqlmesh_dbt/options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import typing as t
import click
from click.core import Context, Parameter


class YamlParamType(click.ParamType):
name = "yaml"

def convert(
self, value: t.Any, param: t.Optional[Parameter], ctx: t.Optional[Context]
) -> t.Any:
if not isinstance(value, str):
self.fail(f"Input value '{value}' should be a string", param, ctx)

from sqlmesh.utils import yaml

try:
parsed = yaml.load(source=value, render_jinja=False)
except:
self.fail(f"String '{value}' is not valid YAML", param, ctx)

if not isinstance(parsed, dict):
self.fail(f"String '{value}' did not evaluate to a dict, got: {parsed}", param, ctx)

return parsed
14 changes: 14 additions & 0 deletions tests/dbt/cli/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,17 @@ def test_list_select_exclude(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..
assert "main.orders" not in result.output
assert "main.stg_payments" not in result.output
assert "main.raw_orders" not in result.output


def test_list_with_vars(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
(jaffle_shop_duckdb / "models" / "aliased_model.sql").write_text("""
{{ config(alias='model_' + var('foo')) }}
select 1
""")

result = invoke_cli(["list", "--vars", "foo: bar"])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to test more variances like:

$ dbt run --vars '{"key": "value", "date": 20180101}'
$ dbt run --vars '{key: value, date: 20180101}'

or are we confident that the YAML parser takes care of it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The YAML parser does indeed take care of it, but I added a test test_yaml_param_type() to prove it


assert result.exit_code == 0
assert not result.exception

assert "model_bar" in result.output
14 changes: 14 additions & 0 deletions tests/dbt/cli/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,17 @@ def test_create_can_specify_profile_and_target(jaffle_shop_duckdb: Path):

assert dbt_project.context.profile_name == "jaffle_shop"
assert dbt_project.context.target_name == "dev"


def test_create_can_set_project_variables(jaffle_shop_duckdb: Path):
(jaffle_shop_duckdb / "models" / "test_model.sql").write_text("""
select '{{ var('foo') }}' as a
""")

dbt_project = create(vars={"foo": "bar"})
assert dbt_project.context.config.variables["foo"] == "bar"

test_model = dbt_project.context.models['"jaffle_shop"."main"."test_model"']
query = test_model.render_query()
assert query is not None
assert query.sql() == "SELECT 'bar' AS \"a\""
23 changes: 23 additions & 0 deletions tests/dbt/cli/test_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import typing as t
import pytest
from sqlmesh_dbt.options import YamlParamType
from click.exceptions import BadParameter


@pytest.mark.parametrize(
"input,expected",
[
(1, BadParameter("Input value '1' should be a string")),
("", BadParameter("String '' is not valid YAML")),
("['a', 'b']", BadParameter("String.*did not evaluate to a dict, got.*")),
("foo: bar", {"foo": "bar"}),
('{"key": "value", "date": 20180101}', {"key": "value", "date": 20180101}),
("{key: value, date: 20180101}", {"key": "value", "date": 20180101}),
],
)
def test_yaml_param_type(input: str, expected: t.Union[BadParameter, t.Dict[str, t.Any]]):
if isinstance(expected, BadParameter):
with pytest.raises(BadParameter, match=expected.message):
YamlParamType().convert(input, None, None)
else:
assert YamlParamType().convert(input, None, None) == expected