Skip to content

Commit a38ca6b

Browse files
committed
Feat(dbt_cli): Add support for --vars
1 parent fe64851 commit a38ca6b

File tree

7 files changed

+95
-22
lines changed

7 files changed

+95
-22
lines changed

sqlmesh/core/config/loader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ def load_config_from_paths(
175175
project_root=dbt_project_file.parent,
176176
dbt_profile_name=kwargs.pop("profile", None),
177177
dbt_target_name=kwargs.pop("target", None),
178+
variables=kwargs.pop("variables", None),
178179
)
179180
if type(dbt_python_config) != config_type:
180181
dbt_python_config = convert_config_type(dbt_python_config, config_type)

sqlmesh_dbt/cli.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,31 @@
44
from sqlmesh_dbt.operations import DbtOperations, create
55
from sqlmesh_dbt.error import cli_global_error_handler
66
from pathlib import Path
7+
from sqlmesh_dbt.options import YamlParamType
8+
import functools
79

810

9-
def _get_dbt_operations(ctx: click.Context) -> DbtOperations:
10-
if not isinstance(ctx.obj, DbtOperations):
11+
def _get_dbt_operations(ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]]) -> DbtOperations:
12+
if not isinstance(ctx.obj, functools.partial):
1113
raise ValueError(f"Unexpected click context object: {type(ctx.obj)}")
12-
return ctx.obj
14+
15+
dbt_operations = ctx.obj(vars=vars)
16+
17+
if not isinstance(dbt_operations, DbtOperations):
18+
raise ValueError(f"Unexpected dbt operations type: {type(dbt_operations)}")
19+
20+
@ctx.call_on_close
21+
def _cleanup() -> None:
22+
dbt_operations.close()
23+
24+
return dbt_operations
25+
26+
27+
vars_option = click.option(
28+
"--vars",
29+
type=YamlParamType(),
30+
help="Supply variables to the project. This argument overrides variables defined in your dbt_project.yml file. This argument should be a YAML string, eg. '{my_variable: my_value}'",
31+
)
1332

1433

1534
@click.group(invoke_without_command=True)
@@ -28,8 +47,9 @@ def dbt(
2847
# we dont need to import sqlmesh/load the project for CLI help
2948
return
3049

31-
# TODO: conditionally call create() if there are times we dont want/need to import sqlmesh and load a project
32-
ctx.obj = create(project_dir=Path.cwd(), profile=profile, target=target)
50+
# we have a partially applied function here because subcommands might set extra options like --vars
51+
# that need to be known before we attempt to load the project
52+
ctx.obj = functools.partial(create, project_dir=Path.cwd(), profile=profile, target=target)
3353

3454
if not ctx.invoked_subcommand:
3555
click.echo(
@@ -44,17 +64,24 @@ def dbt(
4464
"--full-refresh",
4565
help="If specified, dbt will drop incremental models and fully-recalculate the incremental table from the model definition.",
4666
)
67+
@vars_option
4768
@click.pass_context
48-
def run(ctx: click.Context, select: t.Optional[str], full_refresh: bool) -> None:
69+
def run(
70+
ctx: click.Context,
71+
vars: t.Optional[t.Dict[str, t.Any]],
72+
select: t.Optional[str],
73+
full_refresh: bool,
74+
) -> None:
4975
"""Compile SQL and execute against the current target database."""
50-
_get_dbt_operations(ctx).run(select=select, full_refresh=full_refresh)
76+
_get_dbt_operations(ctx, vars).run(select=select, full_refresh=full_refresh)
5177

5278

5379
@dbt.command(name="list")
80+
@vars_option
5481
@click.pass_context
55-
def list_(ctx: click.Context) -> None:
82+
def list_(ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]]) -> None:
5683
"""List the resources in your project"""
57-
_get_dbt_operations(ctx).list_()
84+
_get_dbt_operations(ctx, vars).list_()
5885

5986

6087
@dbt.command(name="ls", hidden=True) # hidden alias for list

sqlmesh_dbt/error.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,5 @@ def wrapper(*args: t.List[t.Any], **kwargs: t.Any) -> t.Any:
2525
sys.exit(1)
2626
else:
2727
raise
28-
finally:
29-
context_or_obj = args[0]
30-
sqlmesh_context = (
31-
context_or_obj.obj if isinstance(context_or_obj, click.Context) else context_or_obj
32-
)
33-
if sqlmesh_context is not None:
34-
# important to import this only if a context was created
35-
# otherwise something like `sqlmesh_dbt run --help` will trigger this import because it's in the finally: block
36-
from sqlmesh import Context
37-
38-
if isinstance(sqlmesh_context, Context):
39-
sqlmesh_context.close()
4028

4129
return wrapper

sqlmesh_dbt/operations.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,15 @@ def console(self) -> DbtCliConsole:
5151

5252
return console
5353

54+
def close(self) -> None:
55+
self.context.close()
56+
5457

5558
def create(
5659
project_dir: t.Optional[Path] = None,
5760
profile: t.Optional[str] = None,
5861
target: t.Optional[str] = None,
62+
vars: t.Optional[t.Dict[str, t.Any]] = None,
5963
debug: bool = False,
6064
) -> DbtOperations:
6165
with Progress(transient=True) as progress:
@@ -79,7 +83,7 @@ def create(
7983

8084
sqlmesh_context = Context(
8185
paths=[project_dir],
82-
config_loader_kwargs=dict(profile=profile, target=target),
86+
config_loader_kwargs=dict(profile=profile, target=target, variables=vars),
8387
load=True,
8488
)
8589

sqlmesh_dbt/options.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import typing as t
2+
import click
3+
from click.core import Context, Parameter
4+
5+
6+
class YamlParamType(click.ParamType):
7+
name = "yaml"
8+
9+
def convert(
10+
self, value: t.Any, param: t.Optional[Parameter], ctx: t.Optional[Context]
11+
) -> t.Any:
12+
if not isinstance(value, str):
13+
self.fail(f"Input value '{value}' should be a string", param, ctx)
14+
15+
from sqlmesh.utils import yaml
16+
17+
try:
18+
parsed = yaml.load(source=value, render_jinja=False)
19+
except:
20+
self.fail(f"String '{value}' is not valid YAML", param, ctx)
21+
22+
if not isinstance(parsed, dict):
23+
self.fail(f"String '{value}' did not evaluate to a dict, got: {parsed}", param, ctx)
24+
25+
return parsed

tests/dbt/cli/test_list.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,17 @@ def test_list(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
1515
assert "main.orders" in result.output
1616
assert "main.customers" in result.output
1717
assert "main.stg_payments" in result.output
18+
19+
20+
def test_list_with_vars(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
21+
(jaffle_shop_duckdb / "models" / "aliased_model.sql").write_text("""
22+
{{ config(alias='model_' + var('foo')) }}
23+
select 1
24+
""")
25+
26+
result = invoke_cli(["list", "--vars", "foo: bar"])
27+
28+
assert result.exit_code == 0
29+
assert not result.exception
30+
31+
assert "model_bar" in result.output

tests/dbt/cli/test_operations.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,17 @@ def test_create_can_specify_profile_and_target(jaffle_shop_duckdb: Path):
6969

7070
assert dbt_project.context.profile_name == "jaffle_shop"
7171
assert dbt_project.context.target_name == "dev"
72+
73+
74+
def test_create_can_set_project_variables(jaffle_shop_duckdb: Path):
75+
(jaffle_shop_duckdb / "models" / "test_model.sql").write_text("""
76+
select '{{ var('foo') }}' as a
77+
""")
78+
79+
dbt_project = create(vars={"foo": "bar"})
80+
assert dbt_project.context.config.variables["foo"] == "bar"
81+
82+
test_model = dbt_project.context.models['"jaffle_shop"."main"."test_model"']
83+
query = test_model.render_query()
84+
assert query is not None
85+
assert query.sql() == "SELECT 'bar' AS \"a\""

0 commit comments

Comments
 (0)