From 5c33f6e540bcce120b674aa1ccf5d40d205a5d41 Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Thu, 21 Aug 2025 21:57:20 +0000 Subject: [PATCH 1/4] Feat(dbt_cli): Add support for --vars --- sqlmesh/core/config/loader.py | 1 + sqlmesh_dbt/cli.py | 39 ++++++++++++++++++++++++-------- sqlmesh_dbt/error.py | 12 ---------- sqlmesh_dbt/operations.py | 6 ++++- sqlmesh_dbt/options.py | 25 ++++++++++++++++++++ tests/dbt/cli/test_list.py | 14 ++++++++++++ tests/dbt/cli/test_operations.py | 14 ++++++++++++ 7 files changed, 89 insertions(+), 22 deletions(-) create mode 100644 sqlmesh_dbt/options.py diff --git a/sqlmesh/core/config/loader.py b/sqlmesh/core/config/loader.py index fe9deed0c2..70b086e471 100644 --- a/sqlmesh/core/config/loader.py +++ b/sqlmesh/core/config/loader.py @@ -176,6 +176,7 @@ def load_config_from_paths( project_root=dbt_project_file.parent, dbt_profile_name=kwargs.pop("profile", None), dbt_target_name=kwargs.pop("target", None), + variables=kwargs.pop("variables", None), ) if type(dbt_python_config) != config_type: dbt_python_config = convert_config_type(dbt_python_config, config_type) diff --git a/sqlmesh_dbt/cli.py b/sqlmesh_dbt/cli.py index 7d98e812b7..2c0a207515 100644 --- a/sqlmesh_dbt/cli.py +++ b/sqlmesh_dbt/cli.py @@ -4,12 +4,31 @@ from sqlmesh_dbt.operations import DbtOperations, create from sqlmesh_dbt.error import cli_global_error_handler from pathlib import Path +from sqlmesh_dbt.options import YamlParamType +import functools -def _get_dbt_operations(ctx: click.Context) -> DbtOperations: - if not isinstance(ctx.obj, DbtOperations): +def _get_dbt_operations(ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]]) -> DbtOperations: + if not isinstance(ctx.obj, functools.partial): raise ValueError(f"Unexpected click context object: {type(ctx.obj)}") - return ctx.obj + + dbt_operations = ctx.obj(vars=vars) + + if not isinstance(dbt_operations, DbtOperations): + raise ValueError(f"Unexpected dbt operations type: {type(dbt_operations)}") + + @ctx.call_on_close + def _cleanup() -> None: + dbt_operations.close() + + return dbt_operations + + +vars_option = click.option( + "--vars", + type=YamlParamType(), + help="Supply variables to the project. This argument overrides variables defined in your dbt_project.yml file. This argument should be a YAML string, eg. '{my_variable: my_value}'", +) select_option = click.option( @@ -40,8 +59,9 @@ def dbt( # we dont need to import sqlmesh/load the project for CLI help return - # TODO: conditionally call create() if there are times we dont want/need to import sqlmesh and load a project - ctx.obj = create(project_dir=Path.cwd(), profile=profile, target=target) + # we have a partially applied function here because subcommands might set extra options like --vars + # that need to be known before we attempt to load the project + ctx.obj = functools.partial(create, project_dir=Path.cwd(), profile=profile, target=target) if not ctx.invoked_subcommand: click.echo( @@ -57,19 +77,20 @@ def dbt( "--full-refresh", help="If specified, dbt will drop incremental models and fully-recalculate the incremental table from the model definition.", ) +@vars_option @click.pass_context -def run(ctx: click.Context, **kwargs: t.Any) -> None: +def run(ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]], **kwargs: t.Any) -> None: """Compile SQL and execute against the current target database.""" - _get_dbt_operations(ctx).run(**kwargs) + _get_dbt_operations(ctx, vars).run(**kwargs) @dbt.command(name="list") @select_option @exclude_option @click.pass_context -def list_(ctx: click.Context, **kwargs: t.Any) -> None: +def list_(ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]], **kwargs: t.Any) -> None: """List the resources in your project""" - _get_dbt_operations(ctx).list_(**kwargs) + _get_dbt_operations(ctx, vars).list_(**kwargs) @dbt.command(name="ls", hidden=True) # hidden alias for list diff --git a/sqlmesh_dbt/error.py b/sqlmesh_dbt/error.py index f5d4bc438c..005ca87c50 100644 --- a/sqlmesh_dbt/error.py +++ b/sqlmesh_dbt/error.py @@ -25,17 +25,5 @@ def wrapper(*args: t.List[t.Any], **kwargs: t.Any) -> t.Any: sys.exit(1) else: raise - finally: - context_or_obj = args[0] - sqlmesh_context = ( - context_or_obj.obj if isinstance(context_or_obj, click.Context) else context_or_obj - ) - if sqlmesh_context is not None: - # important to import this only if a context was created - # otherwise something like `sqlmesh_dbt run --help` will trigger this import because it's in the finally: block - from sqlmesh import Context - - if isinstance(sqlmesh_context, Context): - sqlmesh_context.close() return wrapper diff --git a/sqlmesh_dbt/operations.py b/sqlmesh_dbt/operations.py index 2b89c0f3e9..296000847c 100644 --- a/sqlmesh_dbt/operations.py +++ b/sqlmesh_dbt/operations.py @@ -76,11 +76,15 @@ def console(self) -> DbtCliConsole: return console + def close(self) -> None: + self.context.close() + def create( project_dir: t.Optional[Path] = None, profile: t.Optional[str] = None, target: t.Optional[str] = None, + vars: t.Optional[t.Dict[str, t.Any]] = None, debug: bool = False, ) -> DbtOperations: with Progress(transient=True) as progress: @@ -104,7 +108,7 @@ def create( sqlmesh_context = Context( paths=[project_dir], - config_loader_kwargs=dict(profile=profile, target=target), + config_loader_kwargs=dict(profile=profile, target=target, variables=vars), load=True, ) diff --git a/sqlmesh_dbt/options.py b/sqlmesh_dbt/options.py new file mode 100644 index 0000000000..5a7cabe93b --- /dev/null +++ b/sqlmesh_dbt/options.py @@ -0,0 +1,25 @@ +import typing as t +import click +from click.core import Context, Parameter + + +class YamlParamType(click.ParamType): + name = "yaml" + + def convert( + self, value: t.Any, param: t.Optional[Parameter], ctx: t.Optional[Context] + ) -> t.Any: + if not isinstance(value, str): + self.fail(f"Input value '{value}' should be a string", param, ctx) + + from sqlmesh.utils import yaml + + try: + parsed = yaml.load(source=value, render_jinja=False) + except: + self.fail(f"String '{value}' is not valid YAML", param, ctx) + + if not isinstance(parsed, dict): + self.fail(f"String '{value}' did not evaluate to a dict, got: {parsed}", param, ctx) + + return parsed diff --git a/tests/dbt/cli/test_list.py b/tests/dbt/cli/test_list.py index fe3e1e6829..915097913c 100644 --- a/tests/dbt/cli/test_list.py +++ b/tests/dbt/cli/test_list.py @@ -46,3 +46,17 @@ def test_list_select_exclude(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[.. assert "main.orders" not in result.output assert "main.stg_payments" not in result.output assert "main.raw_orders" not in result.output + + +def test_list_with_vars(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]): + (jaffle_shop_duckdb / "models" / "aliased_model.sql").write_text(""" + {{ config(alias='model_' + var('foo')) }} + select 1 + """) + + result = invoke_cli(["list", "--vars", "foo: bar"]) + + assert result.exit_code == 0 + assert not result.exception + + assert "model_bar" in result.output \ No newline at end of file diff --git a/tests/dbt/cli/test_operations.py b/tests/dbt/cli/test_operations.py index 9d36b10f60..9b5b3113b3 100644 --- a/tests/dbt/cli/test_operations.py +++ b/tests/dbt/cli/test_operations.py @@ -69,3 +69,17 @@ def test_create_can_specify_profile_and_target(jaffle_shop_duckdb: Path): assert dbt_project.context.profile_name == "jaffle_shop" assert dbt_project.context.target_name == "dev" + + +def test_create_can_set_project_variables(jaffle_shop_duckdb: Path): + (jaffle_shop_duckdb / "models" / "test_model.sql").write_text(""" + select '{{ var('foo') }}' as a + """) + + dbt_project = create(vars={"foo": "bar"}) + assert dbt_project.context.config.variables["foo"] == "bar" + + test_model = dbt_project.context.models['"jaffle_shop"."main"."test_model"'] + query = test_model.render_query() + assert query is not None + assert query.sql() == "SELECT 'bar' AS \"a\"" From bea35583d0bac9b670db7e70e19454cd5f467c65 Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Thu, 21 Aug 2025 23:14:49 +0000 Subject: [PATCH 2/4] fix test --- sqlmesh_dbt/cli.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sqlmesh_dbt/cli.py b/sqlmesh_dbt/cli.py index 2c0a207515..7bcb6a700b 100644 --- a/sqlmesh_dbt/cli.py +++ b/sqlmesh_dbt/cli.py @@ -64,6 +64,10 @@ def dbt( ctx.obj = functools.partial(create, project_dir=Path.cwd(), profile=profile, target=target) if not ctx.invoked_subcommand: + if profile or target: + # trigger a project load to validate the specified profile / target + ctx.obj() + click.echo( f"No command specified. Run `{ctx.info_name} --help` to see the available commands." ) From 4b800c9d8eacc6ec814cf916dec7913f5cf48dea Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Mon, 25 Aug 2025 05:27:24 +0000 Subject: [PATCH 3/4] Add tests for YamlParamType --- sqlmesh_dbt/cli.py | 1 + tests/dbt/cli/test_list.py | 2 +- tests/dbt/cli/test_options.py | 23 +++++++++++++++++++++++ 3 files changed, 25 insertions(+), 1 deletion(-) create mode 100644 tests/dbt/cli/test_options.py diff --git a/sqlmesh_dbt/cli.py b/sqlmesh_dbt/cli.py index 7bcb6a700b..d82c2afd92 100644 --- a/sqlmesh_dbt/cli.py +++ b/sqlmesh_dbt/cli.py @@ -91,6 +91,7 @@ def run(ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]], **kwargs: t.An @dbt.command(name="list") @select_option @exclude_option +@vars_option @click.pass_context def list_(ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]], **kwargs: t.Any) -> None: """List the resources in your project""" diff --git a/tests/dbt/cli/test_list.py b/tests/dbt/cli/test_list.py index 915097913c..e854954903 100644 --- a/tests/dbt/cli/test_list.py +++ b/tests/dbt/cli/test_list.py @@ -59,4 +59,4 @@ def test_list_with_vars(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Re assert result.exit_code == 0 assert not result.exception - assert "model_bar" in result.output \ No newline at end of file + assert "model_bar" in result.output diff --git a/tests/dbt/cli/test_options.py b/tests/dbt/cli/test_options.py new file mode 100644 index 0000000000..962ff0beb3 --- /dev/null +++ b/tests/dbt/cli/test_options.py @@ -0,0 +1,23 @@ +import typing as t +import pytest +from sqlmesh_dbt.options import YamlParamType +from click.exceptions import BadParameter + + +@pytest.mark.parametrize( + "input,expected", + [ + (1, BadParameter("Input value '1' should be a string")), + ("", BadParameter("String '' is not valid YAML")), + ("['a', 'b']", BadParameter("String.*did not evaluate to a dict, got.*")), + ("foo: bar", {"foo": "bar"}), + ('{"key": "value", "date": 20180101}', {"key": "value", "date": 20180101}), + ("{key: value, date: 20180101}", {"key": "value", "date": 20180101}), + ], +) +def test_yaml_param_type(input: str, expected: t.Union[BadParameter, t.Dict[str, t.Any]]): + if isinstance(expected, BadParameter): + with pytest.raises(BadParameter, match=expected.message): + YamlParamType().convert(input, None, None) + else: + assert YamlParamType().convert(input, None, None) == expected From bc35c948ba86edfd21a8915360db4c37b5985166 Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Mon, 25 Aug 2025 21:55:07 +0000 Subject: [PATCH 4/4] rebase --- sqlmesh/core/config/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmesh/core/config/loader.py b/sqlmesh/core/config/loader.py index 70b086e471..75915800e6 100644 --- a/sqlmesh/core/config/loader.py +++ b/sqlmesh/core/config/loader.py @@ -176,7 +176,7 @@ def load_config_from_paths( project_root=dbt_project_file.parent, dbt_profile_name=kwargs.pop("profile", None), dbt_target_name=kwargs.pop("target", None), - variables=kwargs.pop("variables", None), + variables=variables, ) if type(dbt_python_config) != config_type: dbt_python_config = convert_config_type(dbt_python_config, config_type)