diff --git a/sqlmesh_dbt/cli.py b/sqlmesh_dbt/cli.py index 500a9d6fa0..7d98e812b7 100644 --- a/sqlmesh_dbt/cli.py +++ b/sqlmesh_dbt/cli.py @@ -12,6 +12,18 @@ def _get_dbt_operations(ctx: click.Context) -> DbtOperations: return ctx.obj +select_option = click.option( + "-s", + "-m", + "--select", + "--models", + "--model", + multiple=True, + help="Specify the nodes to include.", +) +exclude_option = click.option("--exclude", multiple=True, help="Specify the nodes to exclude.") + + @click.group(invoke_without_command=True) @click.option("--profile", help="Which existing profile to load. Overrides output.profile") @click.option("-t", "--target", help="Which target to load for the given profile") @@ -38,23 +50,26 @@ def dbt( @dbt.command() -@click.option("-s", "-m", "--select", "--models", "--model", help="Specify the nodes to include.") +@select_option +@exclude_option @click.option( "-f", "--full-refresh", help="If specified, dbt will drop incremental models and fully-recalculate the incremental table from the model definition.", ) @click.pass_context -def run(ctx: click.Context, select: t.Optional[str], full_refresh: bool) -> None: +def run(ctx: click.Context, **kwargs: t.Any) -> None: """Compile SQL and execute against the current target database.""" - _get_dbt_operations(ctx).run(select=select, full_refresh=full_refresh) + _get_dbt_operations(ctx).run(**kwargs) @dbt.command(name="list") +@select_option +@exclude_option @click.pass_context -def list_(ctx: click.Context) -> None: +def list_(ctx: click.Context, **kwargs: t.Any) -> None: """List the resources in your project""" - _get_dbt_operations(ctx).list_() + _get_dbt_operations(ctx).list_(**kwargs) @dbt.command(name="ls", hidden=True) # hidden alias for list diff --git a/sqlmesh_dbt/console.py b/sqlmesh_dbt/console.py index 7d804ceb71..3c62adfe68 100644 --- a/sqlmesh_dbt/console.py +++ b/sqlmesh_dbt/console.py @@ -1,8 +1,27 @@ +import typing as t from sqlmesh.core.console import TerminalConsole +from 
sqlmesh.core.model import Model +from rich.tree import Tree class DbtCliConsole(TerminalConsole): - # TODO: build this out - def print(self, msg: str) -> None: return self._print(msg) + + def list_models( + self, models: t.List[Model], list_parents: bool = True, list_audits: bool = True + ) -> None: + model_list = Tree("[bold]Models in project:[/bold]") + + for model in models: + model_tree = model_list.add(model.name) + + if list_parents: + for parent in model.depends_on: + model_tree.add(f"depends_on: {parent}") + + if list_audits: + for audit_name in model.audit_definitions: + model_tree.add(f"audit: {audit_name}") + + self._print(model_list) diff --git a/sqlmesh_dbt/operations.py b/sqlmesh_dbt/operations.py index e8e443a64a..2b89c0f3e9 100644 --- a/sqlmesh_dbt/operations.py +++ b/sqlmesh_dbt/operations.py @@ -2,12 +2,17 @@ import typing as t from rich.progress import Progress from pathlib import Path +import logging +from sqlmesh_dbt import selectors if t.TYPE_CHECKING: # important to gate these to be able to defer importing sqlmesh until we need to from sqlmesh.core.context import Context from sqlmesh.dbt.project import Project from sqlmesh_dbt.console import DbtCliConsole + from sqlmesh.core.model import Model + +logger = logging.getLogger(__name__) class DbtOperations: @@ -15,22 +20,28 @@ def __init__(self, sqlmesh_context: Context, dbt_project: Project): self.context = sqlmesh_context self.project = dbt_project - def list_(self) -> None: - for _, model in self.context.models.items(): - self.console.print(model.name) - - def run(self, select: t.Optional[str] = None, full_refresh: bool = False) -> None: - # A dbt run both updates data and changes schemas and has no way of rolling back so more closely maps to a SQLMesh forward-only plan - # TODO: if --full-refresh specified, mark incrementals as breaking instead of forward_only? 
- - # TODO: we need to either convert DBT selector syntax to SQLMesh selector syntax - # or make the model selection engine configurable + def list_( + self, + select: t.Optional[t.List[str]] = None, + exclude: t.Optional[t.List[str]] = None, + ) -> None: + # dbt list prints: + # - models + # - "data tests" (audits) for those models + # it also applies selectors which is useful for testing selectors + selected_models = list(self._selected_models(select, exclude).values()) + self.console.list_models(selected_models) + + def run( + self, + select: t.Optional[t.List[str]] = None, + exclude: t.Optional[t.List[str]] = None, + full_refresh: bool = False, + ) -> None: select_models = None - if select: - if "," in select: - select_models = select.split(",") - else: - select_models = select.split(" ") + + if sqlmesh_selector := selectors.to_sqlmesh(select or [], exclude or []): + select_models = [sqlmesh_selector] self.context.plan( select_models=select_models, @@ -40,6 +51,21 @@ def run(self, select: t.Optional[str] = None, full_refresh: bool = False) -> Non auto_apply=True, ) + def _selected_models( + self, select: t.Optional[t.List[str]] = None, exclude: t.Optional[t.List[str]] = None + ) -> t.Dict[str, Model]: + if sqlmesh_selector := selectors.to_sqlmesh(select or [], exclude or []): + model_selector = self.context._new_selector() + selected_models = { + fqn: model + for fqn, model in self.context.models.items() + if fqn in model_selector.expand_model_selections([sqlmesh_selector]) + } + else: + selected_models = dict(self.context.models) + + return selected_models + @property def console(self) -> DbtCliConsole: console = self.context.console diff --git a/sqlmesh_dbt/selectors.py b/sqlmesh_dbt/selectors.py new file mode 100644 index 0000000000..16f5c2ea98 --- /dev/null +++ b/sqlmesh_dbt/selectors.py @@ -0,0 +1,130 @@ +import typing as t +import logging + +logger = logging.getLogger(__name__) + + +def to_sqlmesh(dbt_select: t.Collection[str], dbt_exclude: 
t.Collection[str]) -> t.Optional[str]: + """ + Given selectors defined in the format of the dbt cli --select and --exclude arguments, convert them into a selector expression that + the SQLMesh selector engine can understand. + + The main things being mapped are: + - set union (" " between items within the same selector string OR multiple --select arguments) is mapped to " | " + - set intersection ("," between items within the same selector string) is mapped to " & " + - `--exclude`. The SQLMesh selector engine does not treat this as a separate parameter and rather treats exclusion as a normal selector + that just happens to contain negation syntax, so we generate these by negating each expression and then intersecting the result + with any --select expressions + + Things that are *not* currently being mapped include: + - selectors based on file paths + - selectors based on partially qualified names like "model_a". The SQLMesh selector engine requires either: + - wildcards, eg "*model_a*" + - the full model name qualified with the schema, eg "staging.model_a" + + Examples: + --select "model_a" + -> "model_a" + --select "main.model_a" + -> "main.model_a" + --select "main.model_a" --select "main.model_b" + -> "main.model_a | main.model_b" + --select "main.model_a main.model_b" + -> "main.model_a | main.model_b" + --select "(main.model_a+ & ^main.model_b)" + -> "(main.model_a+ & ^main.model_b)" + --select "+main.model_a" --exclude "raw.src_data" + -> "+main.model_a & ^(raw.src_data)" + --select "+main.model_a" --select "main.*b+" --exclude "raw.src_data" + -> "(+main.model_a | main.*b+) & ^(raw.src_data)" + """ + if not dbt_select and not dbt_exclude: + return None + + select_expr = " | ".join(_to_sqlmesh(expr) for expr in dbt_select) + select_expr = _wrap(select_expr) if dbt_exclude and len(dbt_select) > 1 else select_expr + + exclude_expr = " | ".join(_to_sqlmesh(expr, negate=True) for expr in dbt_exclude) + exclude_expr = _wrap(exclude_expr) if dbt_select and 
len(dbt_exclude) > 1 else exclude_expr + + main_expr = " & ".join([expr for expr in [select_expr, exclude_expr] if expr]) + + logger.debug( + f"Expanded dbt select: {dbt_select}, exclude: {dbt_exclude} into SQLMesh: {main_expr}" + ) + + return main_expr + + +def _to_sqlmesh(selector_str: str, negate: bool = False) -> str: + unions, intersections = _split_unions_and_intersections(selector_str) + + if negate: + unions = [_negate(u) for u in unions] + intersections = [_negate(i) for i in intersections] + + union_expr = " | ".join(unions) + intersection_expr = " & ".join(intersections) + + if len(unions) > 1 and intersections: + union_expr = f"({union_expr})" + + if len(intersections) > 1 and unions: + intersection_expr = f"({intersection_expr})" + + return " | ".join([expr for expr in [union_expr, intersection_expr] if expr]) + + +def _split_unions_and_intersections(selector_str: str) -> t.Tuple[t.List[str], t.List[str]]: + # break space-separated items like: "my_first_model my_second_model" into a list of selectors to union + # and comma-separated items like: "my_first_model,my_second_model" into a list of selectors to intersect + # but, take into account brackets, eg "(my_first_model & my_second_model)" should not be split + + def _split_by(input: str, delimiter: str) -> t.Iterator[str]: + buf = "" + depth = 0 + + for char in input: + if char == delimiter and depth <= 0: + # only split on the delimiter if we are not within parentheses + yield buf + buf = "" + continue + elif char == "(": + depth += 1 + elif char == ")": + depth -= 1 + + buf += char + + if buf: + yield buf + + # first, break up based on spaces + segments = list(_split_by(selector_str, " ")) + + # then, within each segment, identify the unions and intersections + unions = [] + intersections = [] + + for segment in segments: + maybe_intersections = list(_split_by(segment, ",")) + if len(maybe_intersections) > 1: + intersections.extend(maybe_intersections) + else: + unions.append(segment) + + return unions, 
intersections + + +def _negate(expr: str) -> str: + return f"^{_wrap(expr)}" + + +def _wrap(expr: str) -> str: + already_wrapped = expr.strip().startswith("(") and expr.strip().endswith(")") + + if expr and not already_wrapped: + return f"({expr})" + + return expr diff --git a/tests/dbt/cli/test_list.py b/tests/dbt/cli/test_list.py index 9312be8635..fe3e1e6829 100644 --- a/tests/dbt/cli/test_list.py +++ b/tests/dbt/cli/test_list.py @@ -15,3 +15,34 @@ def test_list(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]): assert "main.orders" in result.output assert "main.customers" in result.output assert "main.stg_payments" in result.output + assert "main.raw_orders" in result.output + + +def test_list_select(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]): + result = invoke_cli(["list", "--select", "main.raw_customers+"]) + + assert result.exit_code == 0 + assert not result.exception + + assert "main.orders" in result.output + assert "main.customers" in result.output + assert "main.stg_customers" in result.output + assert "main.raw_customers" in result.output + + assert "main.stg_payments" not in result.output + assert "main.raw_orders" not in result.output + + +def test_list_select_exclude(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]): + result = invoke_cli(["list", "--select", "main.raw_customers+", "--exclude", "main.orders"]) + + assert result.exit_code == 0 + assert not result.exception + + assert "main.customers" in result.output + assert "main.stg_customers" in result.output + assert "main.raw_customers" in result.output + + assert "main.orders" not in result.output + assert "main.stg_payments" not in result.output + assert "main.raw_orders" not in result.output diff --git a/tests/dbt/cli/test_run.py b/tests/dbt/cli/test_run.py index 0e4a04bcb1..4d80514fc8 100644 --- a/tests/dbt/cli/test_run.py +++ b/tests/dbt/cli/test_run.py @@ -2,6 +2,8 @@ import pytest from pathlib import Path from click.testing import Result 
+import time_machine +from tests.cli.test_cli import FREEZE_TIME pytestmark = pytest.mark.slow @@ -13,3 +15,26 @@ def test_run(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]): assert not result.exception assert "Model batches executed" in result.output + + +def test_run_with_selectors(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]): + with time_machine.travel(FREEZE_TIME): + # do an initial run to create the objects + # otherwise the selected subset may depend on something that hasn't been created + result = invoke_cli(["run"]) + assert result.exit_code == 0 + assert "main.orders" in result.output + + result = invoke_cli(["run", "--select", "main.raw_customers+", "--exclude", "main.orders"]) + + assert result.exit_code == 0 + assert not result.exception + + assert "main.stg_customers" in result.output + assert "main.stg_orders" in result.output + assert "main.stg_payments" in result.output + assert "main.customers" in result.output + + assert "main.orders" not in result.output + + assert "Model batches executed" in result.output diff --git a/tests/dbt/cli/test_selectors.py b/tests/dbt/cli/test_selectors.py new file mode 100644 index 0000000000..e494ed98a3 --- /dev/null +++ b/tests/dbt/cli/test_selectors.py @@ -0,0 +1,78 @@ +import typing as t +import pytest +from sqlmesh_dbt import selectors + + +@pytest.mark.parametrize( + "dbt_select,expected", + [ + ([], None), + (["main.model_a"], "main.model_a"), + (["main.model_a main.model_b"], "main.model_a | main.model_b"), + (["main.model_a", "main.model_b"], "main.model_a | main.model_b"), + (["(main.model_a & ^main.model_b)"], "(main.model_a & ^main.model_b)"), + ( + ["(+main.model_a & ^main.model_b)", "main.model_c"], + "(+main.model_a & ^main.model_b) | main.model_c", + ), + ], +) +def test_selection(dbt_select: t.List[str], expected: t.Optional[str]): + assert selectors.to_sqlmesh(dbt_select=dbt_select, dbt_exclude=[]) == expected + + +@pytest.mark.parametrize( + "dbt_exclude,expected", + 
[ + ([], None), + (["main.model_a"], "^(main.model_a)"), + (["(main.model_a & main.model_b)"], "^(main.model_a & main.model_b)"), + (["main.model_a +main.model_b"], "^(main.model_a) | ^(+main.model_b)"), + ( + ["(+main.model_a & ^main.model_b)", "main.model_c"], + "^(+main.model_a & ^main.model_b) | ^(main.model_c)", + ), + ], +) +def test_exclusion(dbt_exclude: t.List[str], expected: t.Optional[str]): + assert selectors.to_sqlmesh(dbt_select=[], dbt_exclude=dbt_exclude) == expected + + +@pytest.mark.parametrize( + "dbt_select,dbt_exclude,expected", + [ + ([], [], None), + (["+main.model_a"], ["raw.src_data"], "+main.model_a & ^(raw.src_data)"), + ( + ["+main.model_a", "main.*b+"], + ["raw.src_data"], + "(+main.model_a | main.*b+) & ^(raw.src_data)", + ), + ( + ["+main.model_a", "main.*b+"], + ["raw.src_data", "tag:disabled"], + "(+main.model_a | main.*b+) & (^(raw.src_data) | ^(tag:disabled))", + ), + ], +) +def test_selection_and_exclusion( + dbt_select: t.List[str], dbt_exclude: t.List[str], expected: t.Optional[str] +): + assert selectors.to_sqlmesh(dbt_select=dbt_select, dbt_exclude=dbt_exclude) == expected + + +@pytest.mark.parametrize( + "expression,expected", + [ + ("", ([], [])), + ("model_a", (["model_a"], [])), + ("model_a model_b", (["model_a", "model_b"], [])), + ("model_a,model_b", ([], ["model_a", "model_b"])), + ("model_a model_b,model_c", (["model_a"], ["model_b", "model_c"])), + ("model_a,model_b model_c", (["model_c"], ["model_a", "model_b"])), + ], +) +def test_split_unions_and_intersections( + expression: str, expected: t.Tuple[t.List[str], t.List[str]] +): + assert selectors._split_unions_and_intersections(expression) == expected