Skip to content

Commit deaad34

Browse files
committed
Feat: dbt cli
1 parent 9eba5c1 commit deaad34

26 files changed

+1038
-2
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ repos:
77
language: python
88
types_or: [python, pyi]
99
require_serial: true
10-
files: &files ^(sqlmesh/|tests/|web/|examples/|setup.py)
10+
files: &files ^(sqlmesh/|sqlmesh_dbt/|tests/|web/|examples/|setup.py)
1111
- id: ruff-format
1212
name: ruff-format
1313
entry: ruff format --force-exclude --line-length 100

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ risingwave = ["psycopg2"]
142142

143143
[project.scripts]
144144
sqlmesh = "sqlmesh.cli.main:cli"
145+
sqlmesh_dbt = "sqlmesh_dbt.cli:dbt"
145146
sqlmesh_cicd = "sqlmesh.cicd.bot:bot"
146147
sqlmesh_lsp = "sqlmesh.lsp.main:main"
147148

@@ -164,7 +165,7 @@ fallback_version = "0.0.0"
164165
local_scheme = "no-local-version"
165166

166167
[tool.setuptools.packages.find]
167-
include = ["sqlmesh", "sqlmesh.*", "web*"]
168+
include = ["sqlmesh", "sqlmesh.*", "sqlmesh_dbt", "sqlmesh_dbt.*", "web*"]
168169

169170
[tool.setuptools.package-data]
170171
web = ["client/dist/**"]

sqlmesh_dbt/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Note: `sqlmesh_dbt` is deliberately in its own package from `sqlmesh` to avoid the upfront time overhead
2+
# that comes from `import sqlmesh`
3+
#
4+
# Obviously we still have to `import sqlmesh` at some point but this allows us to defer it until needed,
5+
# which means we can make the CLI feel more responsive by being able to output something immediately

sqlmesh_dbt/cli.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import typing as t
2+
import sys
3+
import click
4+
from sqlmesh_dbt.operations import DbtOperations, create
5+
6+
7+
def _get_dbt_operations(ctx: click.Context) -> DbtOperations:
    """Fetch the DbtOperations instance stashed on the click context, failing loudly on anything else."""
    if isinstance(ctx.obj, DbtOperations):
        return ctx.obj
    raise ValueError(f"Unexpected click context object: {type(ctx.obj)}")
11+
12+
13+
@click.group()
@click.pass_context
def dbt(ctx: click.Context) -> None:
    """
    An ELT tool for managing your SQL transformations and data models, powered by the SQLMesh engine.
    """
    # Rendering help needs neither `import sqlmesh` nor a loaded project, so skip
    # the expensive setup entirely and keep `--help` instant.
    wants_help_only = "--help" in sys.argv
    if wants_help_only:
        return

    # TODO: conditionally call create() if there are times we dont want/need to import sqlmesh and load a project
    ctx.obj = create()
26+
27+
28+
@dbt.command()
@click.option("-s", "-m", "--select", "--models", "--model", help="Specify the nodes to include.")
@click.option(
    "-f",
    "--full-refresh",
    is_flag=True,
    default=False,
    help="If specified, dbt will drop incremental models and fully-recalculate the incremental table from the model definition.",
)
@click.pass_context
def run(ctx: click.Context, select: t.Optional[str], full_refresh: bool) -> None:
    """Compile SQL and execute against the current target database."""
    # NOTE: `is_flag=True` is required here — without it click treats --full-refresh as a
    # value option and `-f` would consume the following argument instead of acting as a switch.
    _get_dbt_operations(ctx).run(select=select, full_refresh=full_refresh)
39+
40+
41+
@dbt.command(name="list")
@click.pass_context
def list_(ctx: click.Context) -> None:
    """List the resources in your project"""
    operations = _get_dbt_operations(ctx)
    operations.list_()
46+
47+
48+
@dbt.command(name="ls", hidden=True)  # hidden alias for list
@click.pass_context
def ls(ctx: click.Context) -> None:
    """List the resources in your project"""
    # Delegate to the canonical `list` command so the two can never drift apart.
    ctx.forward(list_)
53+
54+
55+
def _not_implemented(name: str) -> None:
    """Register a placeholder `dbt` subcommand called *name* that just reports it is unimplemented."""

    # Give the inner callback its own name: the original reused `_not_implemented`,
    # shadowing the factory inside its own body and confusing readers.
    @dbt.command(name=name)
    def _stub() -> None:
        """Not implemented"""
        click.echo(f"dbt {name} not implemented")
60+
61+
62+
# Every dbt subcommand that does not yet have a SQLMesh-backed implementation
# gets a stub so the CLI surface matches dbt's.
_UNIMPLEMENTED_SUBCOMMANDS = (
    "build",
    "clean",
    "clone",
    "compile",
    "debug",
    "deps",
    "docs",
    "init",
    "parse",
    "retry",
    "run-operation",
    "seed",
    "show",
    "snapshot",
    "source",
    "test",
)

for _subcommand in _UNIMPLEMENTED_SUBCOMMANDS:
    _not_implemented(_subcommand)

sqlmesh_dbt/console.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from sqlmesh.core.console import TerminalConsole
2+
3+
4+
class DbtCliConsole(TerminalConsole):
    """Terminal console used by the `sqlmesh_dbt` CLI for user-facing output."""

    # TODO: build this out

    def print(self, msg: str) -> None:
        # Delegate straight to TerminalConsole's internal printer.
        return self._print(msg)

sqlmesh_dbt/operations.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
from __future__ import annotations
2+
import typing as t
3+
from rich.progress import Progress
4+
from pathlib import Path
5+
6+
if t.TYPE_CHECKING:
7+
# important to gate these to be able to defer importing sqlmesh until we need to
8+
from sqlmesh.core.context import Context
9+
from sqlmesh.dbt.project import Project
10+
from sqlmesh_dbt.console import DbtCliConsole
11+
12+
13+
class DbtOperations:
    """Implements dbt CLI verbs (`run`, `list`, ...) on top of a loaded SQLMesh Context.

    `dbt_project` is kept so callers can consult the parsed dbt project files when
    deciding how to invoke SQLMesh.
    """

    def __init__(self, sqlmesh_context: Context, dbt_project: Project):
        self.context = sqlmesh_context
        self.project = dbt_project

    def list_(self) -> None:
        """Print the name of every model in the project."""
        # iterate values directly; the dict keys are unused
        for model in self.context.models.values():
            self.console.print(model.name)

    def run(self, select: t.Optional[str] = None, full_refresh: bool = False) -> None:
        """Materialize the (optionally selected) models against the current target."""
        # A dbt run both updates data and changes schemas and has no way of rolling back so more closely maps to a SQLMesh forward-only plan
        # TODO: if --full-refresh specified, mark incrementals as breaking instead of forward_only?

        # TODO: we need to either convert DBT selector syntax to SQLMesh selector syntax
        # or make the model selection engine configurable
        select_models = self._parse_select(select)

        self.context.plan(
            select_models=select_models,
            forward_only=True,
            no_auto_categorization=True,  # everything is breaking / forward-only
            effective_from=self.context.config.model_defaults.start,
            run=True,
            auto_apply=True,
        )

    @staticmethod
    def _parse_select(select: t.Optional[str]) -> t.Optional[t.List[str]]:
        """Split a dbt-style selection string on commas (or, failing that, whitespace).

        Entries are stripped and empties dropped so "a, b" selects "b", not " b".
        Returns None when no selection was given.
        """
        if not select:
            return None
        separator = "," if "," in select else " "
        parts = [part.strip() for part in select.split(separator)]
        return [part for part in parts if part]

    @property
    def console(self) -> DbtCliConsole:
        """The active console, verified to be the dbt CLI console installed by `create()`."""
        console = self.context.console
        from sqlmesh_dbt.console import DbtCliConsole

        if not isinstance(console, DbtCliConsole):
            raise ValueError(f"Expecting dbt cli console, got: {console}")

        return console
53+
54+
55+
def create(
    project_dir: t.Optional[Path] = None, profiles_dir: t.Optional[Path] = None, debug: bool = False
) -> DbtOperations:
    """Import SQLMesh, load the dbt project rooted at `project_dir` (cwd by default)
    and return a DbtOperations facade over the resulting SQLMesh Context.

    NOTE(review): `profiles_dir` is accepted but never used below — presumably reserved
    for dbt profile resolution; confirm before relying on it.
    """
    with Progress(transient=True) as progress:
        # Indeterminate progress bar before SQLMesh import to provide feedback to the user that something is indeed happening
        load_task_id = progress.add_task("Loading engine", total=None)

        # Deferred imports: `import sqlmesh` is slow, so it happens under the progress bar
        # (see the note in sqlmesh_dbt/__init__.py).
        from sqlmesh import configure_logging
        from sqlmesh.core.context import Context
        from sqlmesh.dbt.loader import sqlmesh_config, DbtLoader
        from sqlmesh.core.console import set_console
        from sqlmesh_dbt.console import DbtCliConsole
        from sqlmesh.utils.errors import SQLMeshError

        configure_logging(force_debug=debug)
        # Install the CLI console before the Context is created so all subsequent
        # output is routed through DbtCliConsole.
        set_console(DbtCliConsole())

        progress.update(load_task_id, description="Loading project", total=None)

        # inject default start date if one is not specified to prevent the user from having to do anything
        # (must run before sqlmesh_config() reads dbt_project.yml)
        _inject_default_start_date(project_dir)

        config = sqlmesh_config(
            project_root=project_dir,
            # do we want to use a local duckdb for state?
            # warehouse state has a bunch of overhead to initialize, is slow for ongoing operations and will create tables that perhaps the user was not expecting
            # on the other hand, local state is not portable
            state_connection=None,
        )

        sqlmesh_context = Context(
            config=config,
            load=True,
        )

        # this helps things which want a default project-level start date, like the "effective from date" for forward-only plans
        if not sqlmesh_context.config.model_defaults.start:
            # Fall back to the earliest explicit per-model start; None when no model declares one.
            min_start_date = min(
                (
                    model.start
                    for model in sqlmesh_context.models.values()
                    if model.start is not None
                ),
                default=None,
            )
            sqlmesh_context.config.model_defaults.start = min_start_date

        dbt_loader = sqlmesh_context._loaders[0]
        if not isinstance(dbt_loader, DbtLoader):
            raise SQLMeshError(f"Unexpected loader type: {type(dbt_loader)}")

        # so that DbtOperations can query information from the DBT project files in order to invoke SQLMesh correctly
        dbt_project = dbt_loader._projects[0]

        return DbtOperations(sqlmesh_context, dbt_project)
110+
111+
112+
def _inject_default_start_date(project_dir: t.Optional[Path] = None) -> None:
    """
    SQLMesh needs a start date as the starting point for calculating intervals on incremental models

    Rather than forcing the user to update their config manually or having a default that is not saved between runs,
    we can inject it automatically to the dbt_project.yml file
    """
    from sqlmesh.dbt.project import PROJECT_FILENAME, load_yaml
    from sqlmesh.utils.yaml import dump
    from sqlmesh.utils.date import yesterday_ds

    project_yaml_path = (project_dir or Path.cwd()) / PROJECT_FILENAME
    if not project_yaml_path.exists():
        return

    loaded_project_file = load_yaml(project_yaml_path)
    models_config = loaded_project_file.get("models")
    # Guard against `models:` being present but empty (loads as None) or not a mapping —
    # the original membership test would raise TypeError in that case.
    if not isinstance(models_config, dict):
        return

    start_date_keys = ("start", "+start")
    if all(k not in models_config for k in start_date_keys):
        models_config["+start"] = yesterday_ds()
        # todo: this may format the file differently, is that acceptable?
        with project_yaml_path.open("w") as f:
            dump(loaded_project_file, f)

tests/dbt/cli/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# NOTE(review): `pytestmark` conventionally holds pytest mark objects (e.g. `pytest.mark.slow`),
# not bare strings — confirm "foo" is an intentional placeholder and not a mistake.
pytestmark = ["foo"]

tests/dbt/cli/conftest.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import typing as t
2+
from pathlib import Path
3+
import os
4+
import functools
5+
from click.testing import CliRunner, Result
6+
import pytest
7+
8+
9+
@pytest.fixture
def jaffle_shop_duckdb(copy_to_temp_path: t.Callable[..., t.List[Path]]) -> t.Iterable[Path]:
    """Copy the jaffle_shop_duckdb fixture project to a temp dir and chdir into it for the test."""
    fixture_path = Path(__file__).parent / "fixtures" / "jaffle_shop_duckdb"
    assert fixture_path.exists()

    current_path = os.getcwd()
    output_path = copy_to_temp_path(paths=fixture_path)[0]

    # so that we can invoke commands from the perspective of a user that is already in the correct directory
    os.chdir(output_path)
    try:
        yield output_path
    finally:
        # Restore the cwd even if the fixture generator is closed early (GeneratorExit);
        # otherwise later tests would run from the (deleted) temp directory.
        os.chdir(current_path)
23+
24+
25+
@pytest.fixture
def invoke_cli() -> t.Callable[..., Result]:
    """Return a callable that invokes the `dbt` CLI entry point via click's test runner."""
    from sqlmesh_dbt.cli import dbt

    runner = CliRunner()

    def _invoke(*args: t.Any, **kwargs: t.Any) -> Result:
        return runner.invoke(dbt, *args, **kwargs)

    return _invoke
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# dbt project configuration for the jaffle_shop_duckdb test fixture.
# NOTE(review): indentation was reconstructed from a flattened diff — verify the
# nesting of the two `+docs` blocks (silver under staging, gold under jaffle_shop).
name: 'jaffle_shop'

config-version: 2
version: '0.1'

profile: 'jaffle_shop'

model-paths: ["models"]
seed-paths: ["seeds"]
test-paths: ["tests"]
analysis-paths: ["analysis"]
macro-paths: ["macros"]

target-path: "target"
clean-targets:
  - "target"
  - "dbt_modules"
  - "logs"

require-dbt-version: [">=1.0.0", "<2.0.0"]

seeds:
  +docs:
    node_color: '#cd7f32'

models:
  jaffle_shop:
    +materialized: table
    staging:
      +materialized: view
      +docs:
        node_color: 'silver'
    +docs:
      node_color: 'gold'
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
-- Customer mart: one row per customer, enriched with order history and lifetime payment value.
with customers as (

    select * from {{ ref('stg_customers') }}

),

orders as (

    select * from {{ ref('stg_orders') }}

),

payments as (

    select * from {{ ref('stg_payments') }}

),

-- per-customer order stats: first/most recent order dates and order count
customer_orders as (

    select
        customer_id,

        min(order_date) as first_order,
        max(order_date) as most_recent_order,
        count(order_id) as number_of_orders
    from orders

    group by customer_id

),

-- per-customer total paid; payments join through orders to reach the customer
customer_payments as (

    select
        orders.customer_id,
        sum(amount) as total_amount

    from payments

    left join orders on
         payments.order_id = orders.order_id

    group by orders.customer_id

),

-- left joins keep customers with no orders/payments (their metrics are null)
final as (

    select
        customers.customer_id,
        customers.first_name,
        customers.last_name,
        customer_orders.first_order,
        customer_orders.most_recent_order,
        customer_orders.number_of_orders,
        customer_payments.total_amount as customer_lifetime_value

    from customers

    left join customer_orders
        on customers.customer_id = customer_orders.customer_id

    left join customer_payments
        on customers.customer_id = customer_payments.customer_id

)

select * from final

0 commit comments

Comments
 (0)