Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ repos:
language: python
types_or: [python, pyi]
require_serial: true
files: &files ^(sqlmesh/|tests/|web/|examples/|setup.py)
files: &files ^(sqlmesh/|sqlmesh_dbt/|tests/|web/|examples/|setup.py)
- id: ruff-format
name: ruff-format
entry: ruff format --force-exclude --line-length 100
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ risingwave = ["psycopg2"]

[project.scripts]
sqlmesh = "sqlmesh.cli.main:cli"
sqlmesh_dbt = "sqlmesh_dbt.cli:dbt"
sqlmesh_cicd = "sqlmesh.cicd.bot:bot"
sqlmesh_lsp = "sqlmesh.lsp.main:main"

Expand All @@ -164,7 +165,7 @@ fallback_version = "0.0.0"
local_scheme = "no-local-version"

[tool.setuptools.packages.find]
include = ["sqlmesh", "sqlmesh.*", "web*"]
include = ["sqlmesh", "sqlmesh.*", "sqlmesh_dbt", "sqlmesh_dbt.*", "web*"]

[tool.setuptools.package-data]
web = ["client/dist/**"]
Expand Down
5 changes: 5 additions & 0 deletions sqlmesh_dbt/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Note: `sqlmesh_dbt` is deliberately kept in a separate package from `sqlmesh` to avoid the upfront
# time overhead that comes from `import sqlmesh`.
#
# Obviously we still have to `import sqlmesh` at some point, but this allows us to defer it until needed,
# which means we can make the CLI feel more responsive by being able to output something immediately.
80 changes: 80 additions & 0 deletions sqlmesh_dbt/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import typing as t
import sys
import click
from sqlmesh_dbt.operations import DbtOperations, create


def _get_dbt_operations(ctx: click.Context) -> DbtOperations:
    """Return the DbtOperations instance that the `dbt` group stored on the click context."""
    obj = ctx.obj
    if isinstance(obj, DbtOperations):
        return obj
    raise ValueError(f"Unexpected click context object: {type(obj)}")


@click.group()
@click.pass_context
def dbt(ctx: click.Context) -> None:
    """
    An ELT tool for managing your SQL transformations and data models, powered by the SQLMesh engine.
    """
    # Rendering `--help` needs neither the SQLMesh import nor a loaded project,
    # so skip project loading entirely to keep help output instant.
    if "--help" not in sys.argv:
        # TODO: conditionally call create() if there are times we don't want/need to import sqlmesh and load a project
        ctx.obj = create()


@dbt.command()
@click.option("-s", "-m", "--select", "--models", "--model", help="Specify the nodes to include.")
@click.option(
    "-f",
    "--full-refresh",
    is_flag=True,  # without is_flag, click would treat this as a value-taking option and pass a str
    default=False,
    help="If specified, dbt will drop incremental models and fully-recalculate the incremental table from the model definition.",
)
@click.pass_context
def run(ctx: click.Context, select: t.Optional[str], full_refresh: bool) -> None:
    """Compile SQL and execute against the current target database."""
    _get_dbt_operations(ctx).run(select=select, full_refresh=full_refresh)


@dbt.command(name="list")
@click.pass_context
def list_(ctx: click.Context) -> None:
    """List the resources in your project"""
    operations = _get_dbt_operations(ctx)
    operations.list_()


@dbt.command(name="ls", hidden=True)  # hidden alias for list
@click.pass_context
def ls(ctx: click.Context) -> None:
    """List the resources in your project"""
    # Delegate to the canonical `list` command so the two can never diverge.
    ctx.forward(list_)


def _not_implemented(name: str) -> None:
    """Register a placeholder subcommand `name` that reports it is not implemented.

    The inner command gets its own identifier rather than shadowing this factory's name.
    """

    @dbt.command(name=name)
    def _placeholder() -> None:
        """Not implemented"""
        click.echo(f"dbt {name} not implemented")


# Stub out the rest of the dbt CLI surface so unsupported verbs fail gracefully
# with a clear message instead of click's "No such command" error.
for subcommand in (
    "build",
    "clean",
    "clone",
    "compile",
    "debug",
    "deps",
    "docs",
    "init",
    "parse",
    "retry",
    "run-operation",
    "seed",
    "show",
    "snapshot",
    "source",
    "test",
):
    _not_implemented(subcommand)
8 changes: 8 additions & 0 deletions sqlmesh_dbt/console.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from sqlmesh.core.console import TerminalConsole


class DbtCliConsole(TerminalConsole):
    """Console used by the `sqlmesh_dbt` CLI to render output in a dbt-like fashion."""

    # TODO: build this out

    def print(self, msg: str) -> None:
        # Public pass-through to TerminalConsole's internal _print so CLI code
        # has a supported entry point for emitting a plain message.
        return self._print(msg)
133 changes: 133 additions & 0 deletions sqlmesh_dbt/operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from __future__ import annotations
import typing as t
from rich.progress import Progress
from pathlib import Path

if t.TYPE_CHECKING:
# important to gate these to be able to defer importing sqlmesh until we need to
from sqlmesh.core.context import Context
from sqlmesh.dbt.project import Project
from sqlmesh_dbt.console import DbtCliConsole


class DbtOperations:
    """Implements dbt CLI verbs (`run`, `list`, ...) on top of a loaded SQLMesh context."""

    def __init__(self, sqlmesh_context: Context, dbt_project: Project):
        self.context = sqlmesh_context
        self.project = dbt_project

    def list_(self) -> None:
        """Print the name of every model in the project."""
        # Only the model objects are needed, so iterate values directly.
        for model in self.context.models.values():
            self.console.print(model.name)

    def run(self, select: t.Optional[str] = None, full_refresh: bool = False) -> None:
        """Compile and execute the project against the current target, like `dbt run`."""
        # A dbt run both updates data and changes schemas and has no way of rolling back so more closely maps to a SQLMesh forward-only plan
        # TODO: if --full-refresh specified, mark incrementals as breaking instead of forward_only?

        # TODO: we need to either convert DBT selector syntax to SQLMesh selector syntax
        # or make the model selection engine configurable
        select_models = None
        if select:
            # Accept either comma- or space-separated selectors; strip whitespace and
            # drop empty entries so inputs like "a, b" select ["a", "b"].
            separator = "," if "," in select else " "
            select_models = [part.strip() for part in select.split(separator) if part.strip()]

        self.context.plan(
            select_models=select_models,
            no_auto_categorization=True,  # everything is breaking / forward-only
            run=True,
            no_diff=True,
            no_prompts=True,
            auto_apply=True,
        )

    @property
    def console(self) -> DbtCliConsole:
        """The active console, verified to be the DbtCliConsole installed at CLI startup."""
        console = self.context.console
        from sqlmesh_dbt.console import DbtCliConsole

        if not isinstance(console, DbtCliConsole):
            raise ValueError(f"Expecting dbt cli console, got: {console}")

        return console


def create(
    project_dir: t.Optional[Path] = None, profiles_dir: t.Optional[Path] = None, debug: bool = False
) -> DbtOperations:
    """Import SQLMesh, load the dbt project under `project_dir` and return a DbtOperations facade.

    Args:
        project_dir: Root of the dbt project; defaults to the current working directory.
        profiles_dir: Reserved for locating profiles.yml; currently unused.
        debug: When True, force debug-level logging.
    """
    with Progress(transient=True) as progress:
        # Indeterminate progress bar before SQLMesh import to provide feedback to the user that something is indeed happening
        load_task_id = progress.add_task("Loading engine", total=None)

        from sqlmesh import configure_logging
        from sqlmesh.core.context import Context
        from sqlmesh.dbt.loader import sqlmesh_config, DbtLoader
        from sqlmesh.core.console import set_console
        from sqlmesh_dbt.console import DbtCliConsole
        from sqlmesh.utils.errors import SQLMeshError

        configure_logging(force_debug=debug)
        set_console(DbtCliConsole())

        progress.update(load_task_id, description="Loading project", total=None)

        # inject default start date if one is not specified to prevent the user from having to do anything
        # TODO(review follow-up): instead of mutating dbt_project.yml, read/write a `sqlmesh.yaml`
        # at the dbt project root (same format as SQLMesh's config.yaml), auto-creating it with
        # start = yesterday_ds() on first run so the directive persists for subsequent invocations.
        _inject_default_start_date(project_dir)

        config = sqlmesh_config(
            project_root=project_dir,
            # do we want to use a local duckdb for state?
            # warehouse state has a bunch of overhead to initialize, is slow for ongoing operations and will create tables that perhaps the user was not expecting
            # on the other hand, local state is not portable
            state_connection=None,
        )

        sqlmesh_context = Context(
            config=config,
            load=True,
        )

        # this helps things which want a default project-level start date, like the "effective from date" for forward-only plans
        if not sqlmesh_context.config.model_defaults.start:
            min_start_date = min(
                (
                    model.start
                    for model in sqlmesh_context.models.values()
                    if model.start is not None
                ),
                default=None,
            )
            sqlmesh_context.config.model_defaults.start = min_start_date

        dbt_loader = sqlmesh_context._loaders[0]
        if not isinstance(dbt_loader, DbtLoader):
            raise SQLMeshError(f"Unexpected loader type: {type(dbt_loader)}")

        # so that DbtOperations can query information from the DBT project files in order to invoke SQLMesh correctly
        dbt_project = dbt_loader._projects[0]

        return DbtOperations(sqlmesh_context, dbt_project)


def _inject_default_start_date(project_dir: t.Optional[Path] = None) -> None:
    """
    SQLMesh needs a start date as the starting point for calculating intervals on incremental models

    Rather than forcing the user to update their config manually or having a default that is not saved between runs,
    we can inject it automatically to the dbt_project.yml file
    """
    from sqlmesh.dbt.project import PROJECT_FILENAME, load_yaml
    from sqlmesh.utils.yaml import dump
    from sqlmesh.utils.date import yesterday_ds

    project_yaml_path = (project_dir or Path.cwd()) / PROJECT_FILENAME
    if not project_yaml_path.exists():
        return

    loaded_project_file = load_yaml(project_yaml_path)
    models = loaded_project_file.get("models")
    # An empty `models:` key loads as None, which would previously crash the membership
    # test; only inject when there is an actual mapping to update.
    if isinstance(models, dict) and all(k not in models for k in ("start", "+start")):
        models["+start"] = yesterday_ds()
        # todo: this may format the file differently, is that acceptable?
        with project_yaml_path.open("w") as f:
            dump(loaded_project_file, f)
29 changes: 29 additions & 0 deletions tests/dbt/cli/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import typing as t
from pathlib import Path
import os
import functools
from click.testing import CliRunner, Result
import pytest


@pytest.fixture
def jaffle_shop_duckdb(copy_to_temp_path: t.Callable[..., t.List[Path]]) -> t.Iterable[Path]:
    """Copy the jaffle_shop_duckdb fixture project to a temp dir and chdir into it for the test."""
    fixture_path = Path(__file__).parent / "fixtures" / "jaffle_shop_duckdb"
    assert fixture_path.exists()

    current_path = os.getcwd()
    output_path = copy_to_temp_path(paths=fixture_path)[0]

    # so that we can invoke commands from the perspective of a user that is already in the correct directory
    os.chdir(output_path)
    try:
        yield output_path
    finally:
        # always restore the original cwd, even if the generator is torn down
        # abnormally, so a failure here cannot poison subsequent tests
        os.chdir(current_path)


@pytest.fixture
def invoke_cli() -> t.Callable[..., Result]:
    """Return a callable that invokes the sqlmesh_dbt CLI through click's test runner."""
    from sqlmesh_dbt.cli import dbt

    runner = CliRunner()

    def _invoke(*args: t.Any, **kwargs: t.Any) -> Result:
        return runner.invoke(dbt, *args, **kwargs)

    return _invoke
34 changes: 34 additions & 0 deletions tests/dbt/cli/fixtures/jaffle_shop_duckdb/dbt_project.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# dbt project configuration for the jaffle_shop_duckdb test fixture.
# Nesting reconstructed to conventional dbt_project.yml layout (the scraped
# diff had stripped the indentation, which is structurally significant in YAML).
name: 'jaffle_shop'

config-version: 2
version: '0.1'

profile: 'jaffle_shop'

model-paths: ["models"]
seed-paths: ["seeds"]
test-paths: ["tests"]
analysis-paths: ["analysis"]
macro-paths: ["macros"]

target-path: "target"
clean-targets:
  - "target"
  - "dbt_modules"
  - "logs"

require-dbt-version: [">=1.0.0", "<2.0.0"]

seeds:
  +docs:
    node_color: '#cd7f32'

models:
  jaffle_shop:
    +materialized: table
    staging:
      +materialized: view
      +docs:
        node_color: 'silver'
    +docs:
      node_color: 'gold'
69 changes: 69 additions & 0 deletions tests/dbt/cli/fixtures/jaffle_shop_duckdb/models/customers.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
-- jaffle_shop `customers` model: one row per customer with order history
-- summary stats and lifetime payment value.
with customers as (

    select * from {{ ref('stg_customers') }}

),

orders as (

    select * from {{ ref('stg_orders') }}

),

payments as (

    select * from {{ ref('stg_payments') }}

),

-- Per-customer order stats: first/most recent order dates and order count.
customer_orders as (

    select
        customer_id,

        min(order_date) as first_order,
        max(order_date) as most_recent_order,
        count(order_id) as number_of_orders
    from orders

    group by customer_id

),

-- Per-customer payment totals; joined through orders to attribute payments
-- (which only carry order_id) to a customer.
customer_payments as (

    select
        orders.customer_id,
        sum(amount) as total_amount

    from payments

    left join orders on
         payments.order_id = orders.order_id

    group by orders.customer_id

),

-- Left joins keep customers with no orders/payments (stats will be NULL).
final as (

    select
        customers.customer_id,
        customers.first_name,
        customers.last_name,
        customer_orders.first_order,
        customer_orders.most_recent_order,
        customer_orders.number_of_orders,
        customer_payments.total_amount as customer_lifetime_value

    from customers

    left join customer_orders
        on customers.customer_id = customer_orders.customer_id

    left join customer_payments
        on customers.customer_id = customer_payments.customer_id

)

select * from final
Loading
Loading