Skip to content

Commit 0f90cd8

Browse files
committed
Feat(dbt_cli): Add --select and --exclude
1 parent 31763bf commit 0f90cd8

File tree

7 files changed

+300
-22
lines changed

7 files changed

+300
-22
lines changed

sqlmesh_dbt/cli.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,18 @@ def _get_dbt_operations(ctx: click.Context) -> DbtOperations:
1212
return ctx.obj
1313

1414

15+
# Node-selection options shared by every sub-command that accepts selectors.
# "--select" is the first long flag, so the parsed values are passed to the command as `select`.
_SELECT_FLAGS = ("-s", "-m", "--select", "--models", "--model")

select_option = click.option(
    *_SELECT_FLAGS,
    multiple=True,
    help="Specify the nodes to include.",
)

exclude_option = click.option(
    "--exclude",
    multiple=True,
    help="Specify the nodes to exclude.",
)
25+
26+
1527
@click.group(invoke_without_command=True)
1628
@click.option("--profile", help="Which existing profile to load. Overrides output.profile")
1729
@click.option("-t", "--target", help="Which target to load for the given profile")
@@ -38,23 +50,26 @@ def dbt(
3850

3951

4052
@dbt.command()
@select_option
@exclude_option
@click.option(
    "-f",
    "--full-refresh",
    # Declared as a boolean flag: without `is_flag` click expects a value after
    # `--full-refresh`, which does not match the `bool` the operations layer takes.
    is_flag=True,
    default=False,
    help="If specified, dbt will drop incremental models and fully-recalculate the incremental table from the model definition.",
)
@click.pass_context
def run(ctx: click.Context, **kwargs: t.Any) -> None:
    """Compile SQL and execute against the current target database."""
    # kwargs carries the parsed `select`, `exclude` and `full_refresh` values;
    # they are forwarded unchanged to DbtOperations.run.
    _get_dbt_operations(ctx).run(**kwargs)
5164

5265

5366
@dbt.command(name="list")
@select_option
@exclude_option
@click.pass_context
def list_(ctx: click.Context, **kwargs: t.Any) -> None:
    """List the resources in your project"""
    # Hand the parsed selector options straight through to the operations layer.
    operations = _get_dbt_operations(ctx)
    operations.list_(**kwargs)
5873

5974

6075
@dbt.command(name="ls", hidden=True) # hidden alias for list

sqlmesh_dbt/console.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,27 @@
1+
import typing as t
12
from sqlmesh.core.console import TerminalConsole
3+
from sqlmesh.core.model import Model
4+
from rich.tree import Tree
25

36

47
class DbtCliConsole(TerminalConsole):
    """Terminal console for the dbt-style CLI, adding a tree view of models."""

    def print(self, msg: str) -> None:
        """Emit ``msg`` via the inherited ``_print`` helper."""
        return self._print(msg)

    def list_models(
        self, models: t.List[Model], list_parents: bool = True, list_audits: bool = True
    ) -> None:
        """Print a tree with one branch per model.

        Args:
            models: The models to display, in the order given.
            list_parents: When true, add a ``depends_on: ...`` leaf for every entry
                in ``model.depends_on``.
            list_audits: When true, add an ``audit: ...`` leaf for every entry
                yielded by iterating ``model.audit_definitions``.
        """
        root = Tree("[bold]Models in project:[/bold]")

        for model in models:
            branch = root.add(model.name)

            # The flags are checked first so that the model attributes are only
            # read when the corresponding section was requested.
            parent_labels = (
                [f"depends_on: {parent}" for parent in model.depends_on] if list_parents else []
            )
            audit_labels = (
                [f"audit: {audit_name}" for audit_name in model.audit_definitions]
                if list_audits
                else []
            )

            for label in parent_labels + audit_labels:
                branch.add(label)

        self._print(root)

sqlmesh_dbt/operations.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,35 +2,46 @@
22
import typing as t
33
from rich.progress import Progress
44
from pathlib import Path
5+
import logging
6+
from sqlmesh_dbt import selectors
57

68
if t.TYPE_CHECKING:
79
# important to gate these to be able to defer importing sqlmesh until we need to
810
from sqlmesh.core.context import Context
911
from sqlmesh.dbt.project import Project
1012
from sqlmesh_dbt.console import DbtCliConsole
13+
from sqlmesh.core.model import Model
14+
15+
logger = logging.getLogger(__name__)
1116

1217

1318
class DbtOperations:
1419
def __init__(self, sqlmesh_context: Context, dbt_project: Project):
1520
self.context = sqlmesh_context
1621
self.project = dbt_project
1722

18-
def list_(self) -> None:
19-
for _, model in self.context.models.items():
20-
self.console.print(model.name)
21-
22-
def run(self, select: t.Optional[str] = None, full_refresh: bool = False) -> None:
23-
# A dbt run both updates data and changes schemas and has no way of rolling back so more closely maps to a SQLMesh forward-only plan
24-
# TODO: if --full-refresh specified, mark incrementals as breaking instead of forward_only?
25-
26-
# TODO: we need to either convert DBT selector syntax to SQLMesh selector syntax
27-
# or make the model selection engine configurable
23+
def list_(
    self,
    select: t.Optional[t.List[str]] = None,
    exclude: t.Optional[t.List[str]] = None,
) -> None:
    """Print the models matched by the given selectors.

    Mirrors `dbt list`, which prints the models and the "data tests" (audits)
    attached to them. Because the selectors are applied before printing, this
    is also a convenient way to check what a selector resolves to.

    Args:
        select: Selector strings to include; ``None`` means no inclusion filter.
        exclude: Selector strings to exclude; ``None`` means no exclusion filter.
    """
    models_by_fqn = self._selected_models(select, exclude)
    self.console.list_models([*models_by_fqn.values()])
34+
35+
def run(
36+
self,
37+
select: t.Optional[t.List[str]] = None,
38+
exclude: t.Optional[t.List[str]] = None,
39+
full_refresh: bool = False,
40+
) -> None:
2841
select_models = None
29-
if select:
30-
if "," in select:
31-
select_models = select.split(",")
32-
else:
33-
select_models = select.split(" ")
42+
43+
if sqlmesh_selector := selectors.to_sqlmesh(select or [], exclude or []):
44+
select_models = [sqlmesh_selector]
3445

3546
self.context.plan(
3647
select_models=select_models,
@@ -41,6 +52,21 @@ def run(self, select: t.Optional[str] = None, full_refresh: bool = False) -> Non
4152
auto_apply=True,
4253
)
4354

55+
def _selected_models(
56+
self, select: t.Optional[t.List[str]] = None, exclude: t.Optional[t.List[str]] = None
57+
) -> t.Dict[str, Model]:
58+
if sqlmesh_selector := selectors.to_sqlmesh(select or [], exclude or []):
59+
model_selector = self.context._new_selector()
60+
selected_models = {
61+
fqn: model
62+
for fqn, model in self.context.models.items()
63+
if fqn in model_selector.expand_model_selections([sqlmesh_selector])
64+
}
65+
else:
66+
selected_models = dict(self.context.models)
67+
68+
return selected_models
69+
4470
@property
4571
def console(self) -> DbtCliConsole:
4672
console = self.context.console

sqlmesh_dbt/selectors.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import typing as t
2+
import logging
3+
4+
logger = logging.getLogger(__name__)
5+
6+
7+
def to_sqlmesh(dbt_select: t.Collection[str], dbt_exclude: t.Collection[str]) -> t.Optional[str]:
    """
    Given selectors defined in the format of the dbt cli --select and --exclude arguments, convert them into a selector expression that
    the SQLMesh selector engine can understand.

    Note that actually implementing compatible dbt selector syntax and maintaining compatibility with existing dbt selectors is considered out of scope
    at this stage, so the incoming selectors are expected to follow the SQLMesh syntax.

    The main things being mapped are:
        - set union (" " between items within the same selector string)
        - `--exclude`. The SQLMesh selector engine does not treat this as a separate parameter and rather treats exclusion as a normal selector
          that just happens to contain negation syntax, so we generate these by negating each expression

    Things that are *not* being mapped include:
        - set intersection ("," between items) as the SQLMesh selector engine doesn't support this
        - selectors based on file paths
        - selectors based on partially qualified names like "model_a". The SQLMesh selector engine requires either:
            - wildcards, eg "*model_a*"
            - the full model name qualified with the schema, eg "staging.model_a"

    Examples:
        --select "main.model_a"
            -> "main.model_a"
        --select "main.model_a" --select "main.model_b"
            -> "main.model_a & main.model_b"
        --select "main.model_a main.model_b"
            -> "main.model_a & main.model_b"
        --select "(main.model_a | ^main.model_b)"
            -> "(main.model_a | ^main.model_b)"
        --select "+main.model_a" --exclude "raw.src_data"
            -> "+main.model_a & ^(raw.src_data)"
        --select "+main.model_a" --select "main.*b+" --exclude "raw.src_data"
            -> "(+main.model_a & main.*b+) & ^(raw.src_data)"

    Args:
        dbt_select: The values passed to `--select`; each value may hold several space-separated items.
        dbt_exclude: The values passed to `--exclude`, in the same format.

    Returns:
        The combined selector expression, or None when no selector items were supplied
        (including when every supplied string is empty or only spaces).
    """
    if not dbt_select and not dbt_exclude:
        return None

    def _is_wrapped_in_parenthesis(test: str) -> bool:
        # True only when the first "(" is closed by the very last ")". A string
        # such as "(a)|(b)" starts and ends with brackets but is NOT wrapped as a
        # whole, so prefixing it with "^" would only negate the first group.
        text = test.strip()
        if not (text.startswith("(") and text.endswith(")")):
            return False

        depth = 0
        last_index = len(text) - 1
        for index, char in enumerate(text):
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
                if depth == 0 and index != last_index:
                    return False
        return depth == 0

    # expand space-separated items like: "my_first_model my_second_model" into multiple items
    # but take into account brackets, eg "(my_first_model & my_second_model)" should not be split
    def _split_selector_string(selector_str: str) -> t.List[str]:
        splits: t.List[str] = []
        buf: t.List[str] = []
        depth = 0

        for char in selector_str:
            if char == " " and depth <= 0:
                # only split on a space if we are not within parenthesis; runs of
                # spaces (or leading/trailing ones) must not yield empty items,
                # which would otherwise end up as dangling " & " operators
                if buf:
                    splits.append("".join(buf))
                    buf = []
                continue

            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1

            buf.append(char)

        if buf:
            splits.append("".join(buf))

        return splits

    split_dbt_select = [item for s in dbt_select for item in _split_selector_string(s)]
    split_dbt_exclude = [item for s in dbt_exclude for item in _split_selector_string(s)]

    # NOTE(review): the items are joined with " & " although the docstring describes
    # this as a set union - confirm against the SQLMesh selector grammar that "&" is
    # the intended operator here.
    main_expr = " & ".join(split_dbt_select)

    if split_dbt_exclude:
        negated_dbt_exclude = [
            f"^{e}" if _is_wrapped_in_parenthesis(e) else f"^({e})" for e in split_dbt_exclude
        ]
        negated_expr = " & ".join(negated_dbt_exclude)

        # only wrap in extra parenthesis if there was more than 1 exclusion expression with some inclusion expressions
        # otherwise it can stand by itself with no parenthesis
        if len(split_dbt_exclude) > 1 and split_dbt_select:
            negated_expr = f"({negated_expr})"

        if len(split_dbt_select) > 1:
            main_expr = f"({main_expr})"

        main_expr = f"{main_expr} & {negated_expr}" if main_expr else negated_expr

    logger.debug(
        "Expanded dbt select: %s, exclude: %s into SQLMesh: %s", dbt_select, dbt_exclude, main_expr
    )

    # Blank-only input leaves no items; report that as "no selector" rather than "".
    return main_expr or None

tests/dbt/cli/test_list.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,34 @@ def test_list(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
1515
assert "main.orders" in result.output
1616
assert "main.customers" in result.output
1717
assert "main.stg_payments" in result.output
18+
assert "main.raw_orders" in result.output
19+
20+
21+
def test_list_select(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
22+
result = invoke_cli(["list", "--select", "main.raw_customers+"])
23+
24+
assert result.exit_code == 0
25+
assert not result.exception
26+
27+
assert "main.orders" in result.output
28+
assert "main.customers" in result.output
29+
assert "main.stg_customers" in result.output
30+
assert "main.raw_customers" in result.output
31+
32+
assert "main.stg_payments" not in result.output
33+
assert "main.raw_orders" not in result.output
34+
35+
36+
def test_list_select_exclude(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
37+
result = invoke_cli(["list", "--select", "main.raw_customers+", "--exclude", "main.orders"])
38+
39+
assert result.exit_code == 0
40+
assert not result.exception
41+
42+
assert "main.customers" in result.output
43+
assert "main.stg_customers" in result.output
44+
assert "main.raw_customers" in result.output
45+
46+
assert "main.orders" not in result.output
47+
assert "main.stg_payments" not in result.output
48+
assert "main.raw_orders" not in result.output

tests/dbt/cli/test_run.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import pytest
33
from pathlib import Path
44
from click.testing import Result
5+
import time_machine
6+
from tests.cli.test_cli import FREEZE_TIME
57

68
pytestmark = pytest.mark.slow
79

@@ -13,3 +15,26 @@ def test_run(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
1315
assert not result.exception
1416

1517
assert "Model batches executed" in result.output
18+
19+
20+
def test_run_with_selectors(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
21+
with time_machine.travel(FREEZE_TIME):
22+
# do an initial run to create the objects
23+
# otherwise the selected subset may depend on something that hasn't been created
24+
result = invoke_cli(["run"])
25+
assert result.exit_code == 0
26+
assert "main.orders" in result.output
27+
28+
result = invoke_cli(["run", "--select", "main.raw_customers+", "--exclude", "main.orders"])
29+
30+
assert result.exit_code == 0
31+
assert not result.exception
32+
33+
assert "main.stg_customers" in result.output
34+
assert "main.stg_orders" in result.output
35+
assert "main.stg_payments" in result.output
36+
assert "main.customers" in result.output
37+
38+
assert "main.orders" not in result.output
39+
40+
assert "Model batches executed" in result.output

tests/dbt/cli/test_selectors.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import typing as t
2+
import pytest
3+
from sqlmesh_dbt import selectors
4+
5+
6+
@pytest.mark.parametrize(
7+
"dbt_select,expected",
8+
[
9+
([], None),
10+
(["main.model_a"], "main.model_a"),
11+
(["main.model_a main.model_b"], "main.model_a & main.model_b"),
12+
(["main.model_a", "main.model_b"], "main.model_a & main.model_b"),
13+
(["(main.model_a | ^main.model_b)"], "(main.model_a | ^main.model_b)"),
14+
(
15+
["(+main.model_a | ^main.model_b)", "main.model_c"],
16+
"(+main.model_a | ^main.model_b) & main.model_c",
17+
),
18+
],
19+
)
20+
def test_selection(dbt_select: t.List[str], expected: t.Optional[str]):
21+
assert selectors.to_sqlmesh(dbt_select=dbt_select, dbt_exclude=[]) == expected
22+
23+
24+
@pytest.mark.parametrize(
    ("dbt_exclude", "expected"),
    [
        ([], None),
        (["main.model_a"], "^(main.model_a)"),
        (["(main.model_a & main.model_b)"], "^(main.model_a & main.model_b)"),
        (["main.model_a +main.model_b"], "^(main.model_a) & ^(+main.model_b)"),
        (
            ["(+main.model_a | ^main.model_b)", "main.model_c"],
            "^(+main.model_a | ^main.model_b) & ^(main.model_c)",
        ),
    ],
)
def test_exclusion(dbt_exclude: t.List[str], expected: t.Optional[str]):
    """Exclusion-only input is translated to negated SQLMesh expressions."""
    actual = selectors.to_sqlmesh(dbt_select=[], dbt_exclude=dbt_exclude)
    assert actual == expected
39+
40+
41+
@pytest.mark.parametrize(
    ("dbt_select", "dbt_exclude", "expected"),
    [
        ([], [], None),
        (["+main.model_a"], ["raw.src_data"], "+main.model_a & ^(raw.src_data)"),
        (
            ["+main.model_a", "main.*b+"],
            ["raw.src_data"],
            "(+main.model_a & main.*b+) & ^(raw.src_data)",
        ),
        (
            ["+main.model_a", "main.*b+"],
            ["raw.src_data", "tag:disabled"],
            "(+main.model_a & main.*b+) & (^(raw.src_data) & ^(tag:disabled))",
        ),
    ],
)
def test_selection_and_exclusion(
    dbt_select: t.List[str], dbt_exclude: t.List[str], expected: t.Optional[str]
):
    """Combined selection and exclusion input yields the expected grouped expression."""
    actual = selectors.to_sqlmesh(dbt_select=dbt_select, dbt_exclude=dbt_exclude)
    assert actual == expected

0 commit comments

Comments
 (0)