Skip to content

Commit eac8745

Browse files
committed
Feat: yaml config for dbt projects
1 parent 0c70406 commit eac8745

File tree

11 files changed

+211
-73
lines changed

11 files changed

+211
-73
lines changed

sqlmesh/cli/project_init.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sqlmesh.utils.date import yesterday_ds
99
from sqlmesh.utils.errors import SQLMeshError
1010

11+
from sqlmesh.core.config.common import DBT_PROJECT_FILENAME
1112
from sqlmesh.core.config.connection import (
1213
CONNECTION_CONFIG_TO_TYPE,
1314
DIALECT_TO_TYPE,
@@ -113,11 +114,10 @@ def _gen_config(
113114
- ambiguousorinvalidcolumn
114115
- invalidselectstarexpansion
115116
""",
116-
ProjectTemplate.DBT: """from pathlib import Path
117-
118-
from sqlmesh.dbt.loader import sqlmesh_config
119-
120-
config = sqlmesh_config(Path(__file__).parent)
117+
ProjectTemplate.DBT: f"""# --- Model Defaults ---
118+
# https://sqlmesh.readthedocs.io/en/stable/reference/model_configuration/#model-defaults
119+
model_defaults:
120+
start: {start or yesterday_ds()}
121121
""",
122122
}
123123

@@ -285,8 +285,13 @@ def init_example_project(
285285
cli_mode: InitCliMode = InitCliMode.DEFAULT,
286286
) -> Path:
287287
root_path = Path(path)
288-
config_extension = "py" if template == ProjectTemplate.DBT else "yaml"
289-
config_path = root_path / f"config.{config_extension}"
288+
289+
config_path = root_path / "config.yaml"
290+
if template == ProjectTemplate.DBT:
291+
# name the config file `sqlmesh.yaml` to make it clear that within the context of all
292+
# the existing yaml files in a DBT project, this one specifically relates to configuring the sqlmesh engine
293+
config_path = root_path / "sqlmesh.yaml"
294+
290295
audits_path = root_path / "audits"
291296
macros_path = root_path / "macros"
292297
models_path = root_path / "models"
@@ -298,7 +303,7 @@ def init_example_project(
298303
f"Found an existing config file '{config_path}'.\n\nPlease change to another directory or remove the existing file."
299304
)
300305

301-
if template == ProjectTemplate.DBT and not Path(root_path, "dbt_project.yml").exists():
306+
if template == ProjectTemplate.DBT and not Path(root_path, DBT_PROJECT_FILENAME).exists():
302307
raise SQLMeshError(
303308
"Required dbt project file 'dbt_project.yml' not found in the current directory.\n\nPlease add it or change directories before running `sqlmesh init` to set up your project."
304309
)

sqlmesh/core/config/common.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,16 @@
88
from sqlmesh.utils.errors import ConfigError
99
from sqlmesh.utils.pydantic import field_validator
1010

11+
# Config files that can be present in the project dir
12+
ALL_CONFIG_FILENAMES = ("config.py", "config.yml", "config.yaml", "sqlmesh.yml", "sqlmesh.yaml")
13+
14+
# For personal paths (~/.sqlmesh/) where python config is not supported
15+
# Materialized as a tuple rather than a generator expression: this constant is iterated
# on every config load, and a generator would be exhausted (yielding nothing) after its first use.
YAML_CONFIG_FILENAMES = tuple(n for n in ALL_CONFIG_FILENAMES if not n.endswith(".py"))
16+
17+
# Note: this is here to prevent having to import from sqlmesh.dbt.loader which introduces a dependency
18+
# on dbt-core in a native project
19+
DBT_PROJECT_FILENAME = "dbt_project.yml"
20+
1121

1222
class EnvironmentSuffixTarget(str, Enum):
1323
# Intended to create virtual environments in their own schemas, with names like "<model_schema_name>__<env name>". The view name is untouched.

sqlmesh/core/config/loader.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010
from sqlglot.helper import ensure_list
1111

1212
from sqlmesh.core import constants as c
13+
from sqlmesh.core.config.common import (
14+
ALL_CONFIG_FILENAMES,
15+
YAML_CONFIG_FILENAMES,
16+
DBT_PROJECT_FILENAME,
17+
)
1318
from sqlmesh.core.config.model import ModelDefaultsConfig
1419
from sqlmesh.core.config.root import Config
1520
from sqlmesh.utils import env_vars, merge_dicts, sys_path
@@ -51,10 +56,7 @@ def load_configs(
5156
return {path: config for path in absolute_paths}
5257

5358
config_env_vars = None
54-
personal_paths = [
55-
sqlmesh_path / "config.yml",
56-
sqlmesh_path / "config.yaml",
57-
]
59+
personal_paths = [sqlmesh_path / name for name in YAML_CONFIG_FILENAMES]
5860
for path in personal_paths:
5961
if path.exists():
6062
config_env_vars = load_config_from_yaml(path).get("env_vars")
@@ -65,7 +67,7 @@ def load_configs(
6567
return {
6668
path: load_config_from_paths(
6769
config_type,
68-
project_paths=[path / "config.py", path / "config.yml", path / "config.yaml"],
70+
project_paths=[path / name for name in ALL_CONFIG_FILENAMES],
6971
personal_paths=personal_paths,
7072
config_name=config,
7173
)
@@ -156,6 +158,22 @@ def load_config_from_paths(
156158
)
157159

158160
no_dialect_err_msg = "Default model SQL dialect is a required configuration parameter. Set it in the `model_defaults` `dialect` key in your config file."
161+
162+
# if "dbt_project.yml" is present *and there was no python config already defined*,
163+
# create a basic one to ensure we are using the DBT loader.
164+
# any config within yaml files will get overlaid on top of it.
165+
if not python_config:
166+
potential_project_files = [f / DBT_PROJECT_FILENAME for f in visited_folders]
167+
dbt_project_file = next((f for f in potential_project_files if f.exists()), None)
168+
if dbt_project_file:
169+
from sqlmesh.dbt.loader import sqlmesh_config
170+
171+
dbt_python_config = sqlmesh_config(project_root=dbt_project_file.parent)
172+
if type(dbt_python_config) != config_type:
173+
dbt_python_config = convert_config_type(dbt_python_config, config_type)
174+
175+
python_config = dbt_python_config
176+
159177
if python_config:
160178
model_defaults = python_config.model_defaults
161179
if model_defaults.dialect is None:
@@ -165,6 +183,7 @@ def load_config_from_paths(
165183
model_defaults = non_python_config.model_defaults
166184
if model_defaults.dialect is None:
167185
raise ConfigError(no_dialect_err_msg)
186+
168187
return non_python_config
169188

170189

sqlmesh/dbt/common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
from sqlmesh.utils.jinja import MacroReference
1515
from sqlmesh.utils.pydantic import PydanticModel, field_validator
1616
from sqlmesh.utils.yaml import load
17+
from sqlmesh.core.config.common import DBT_PROJECT_FILENAME
1718

1819
T = t.TypeVar("T", bound="GeneralConfig")
1920

20-
PROJECT_FILENAME = "dbt_project.yml"
21+
PROJECT_FILENAME = DBT_PROJECT_FILENAME
2122

2223
JINJA_ONLY = {
2324
"adapter",

sqlmesh/dbt/manifest.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from dbt.tracking import do_not_track
3333

3434
from sqlmesh.core import constants as c
35+
from sqlmesh.core.config import ModelDefaultsConfig
3536
from sqlmesh.dbt.basemodel import Dependencies
3637
from sqlmesh.dbt.builtin import BUILTIN_FILTERS, BUILTIN_GLOBALS, OVERRIDDEN_MACROS
3738
from sqlmesh.dbt.model import ModelConfig
@@ -78,12 +79,14 @@ def __init__(
7879
target: TargetConfig,
7980
variable_overrides: t.Optional[t.Dict[str, t.Any]] = None,
8081
cache_dir: t.Optional[str] = None,
82+
model_defaults: t.Optional[ModelDefaultsConfig] = None,
8183
):
8284
self.project_path = project_path
8385
self.profiles_path = profiles_path
8486
self.profile_name = profile_name
8587
self.target = target
8688
self.variable_overrides = variable_overrides or {}
89+
self.model_defaults = model_defaults or ModelDefaultsConfig()
8790

8891
self.__manifest: t.Optional[Manifest] = None
8992
self._project_name: str = ""
@@ -380,9 +383,12 @@ def _load_manifest(self) -> Manifest:
380383
profile = self._load_profile()
381384
project = self._load_project(profile)
382385

383-
if not any(k in project.models for k in ("start", "+start")):
386+
if (
387+
not any(k in project.models for k in ("start", "+start"))
388+
and not self.model_defaults.start
389+
):
384390
raise ConfigError(
385-
"SQLMesh's requires a start date in order to have a finite range of backfilling data. Add start to the 'models:' block in dbt_project.yml. https://sqlmesh.readthedocs.io/en/stable/integrations/dbt/#setting-model-backfill-start-dates"
391+
"SQLMesh requires a start date in order to have a finite range of backfilling data. Add start to the 'models:' block in dbt_project.yml. https://sqlmesh.readthedocs.io/en/stable/integrations/dbt/#setting-model-backfill-start-dates"
386392
)
387393

388394
runtime_config = RuntimeConfig.from_parts(project, profile, args)

sqlmesh/dbt/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ def to_sqlmesh(
569569
query,
570570
dialect=model_dialect,
571571
kind=kind,
572-
start=self.start,
572+
start=self.start or context.sqlmesh_config.model_defaults.start,
573573
audit_definitions=audit_definitions,
574574
path=model_kwargs.pop("path", self.path),
575575
# This ensures that we bypass query rendering that would otherwise be required to extract additional

sqlmesh/dbt/project.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N
7676
target=profile.target,
7777
variable_overrides=variable_overrides,
7878
cache_dir=context.sqlmesh_config.cache_dir,
79+
model_defaults=context.sqlmesh_config.model_defaults,
7980
)
8081

8182
extra_fields = profile.target.extra

sqlmesh_dbt/operations.py

Lines changed: 17 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def create(
6161

6262
from sqlmesh import configure_logging
6363
from sqlmesh.core.context import Context
64-
from sqlmesh.dbt.loader import sqlmesh_config, DbtLoader
64+
from sqlmesh.dbt.loader import DbtLoader
6565
from sqlmesh.core.console import set_console
6666
from sqlmesh_dbt.console import DbtCliConsole
6767
from sqlmesh.utils.errors import SQLMeshError
@@ -71,34 +71,14 @@ def create(
7171

7272
progress.update(load_task_id, description="Loading project", total=None)
7373

74-
# inject default start date if one is not specified to prevent the user from having to do anything
75-
_inject_default_start_date(project_dir)
76-
77-
config = sqlmesh_config(
78-
project_root=project_dir,
79-
# do we want to use a local duckdb for state?
80-
# warehouse state has a bunch of overhead to initialize, is slow for ongoing operations and will create tables that perhaps the user was not expecting
81-
# on the other hand, local state is not portable
82-
state_connection=None,
83-
)
74+
project_dir = project_dir or Path.cwd()
75+
init_project_if_required(project_dir)
8476

8577
sqlmesh_context = Context(
86-
config=config,
78+
paths=[project_dir],
8779
load=True,
8880
)
8981

90-
# this helps things which want a default project-level start date, like the "effective from date" for forward-only plans
91-
if not sqlmesh_context.config.model_defaults.start:
92-
min_start_date = min(
93-
(
94-
model.start
95-
for model in sqlmesh_context.models.values()
96-
if model.start is not None
97-
),
98-
default=None,
99-
)
100-
sqlmesh_context.config.model_defaults.start = min_start_date
101-
10282
dbt_loader = sqlmesh_context._loaders[0]
10383
if not isinstance(dbt_loader, DbtLoader):
10484
raise SQLMeshError(f"Unexpected loader type: {type(dbt_loader)}")
@@ -109,25 +89,20 @@ def create(
10989
return DbtOperations(sqlmesh_context, dbt_project)
11090

11191

112-
def _inject_default_start_date(project_dir: t.Optional[Path] = None) -> None:
92+
def init_project_if_required(project_dir: Path) -> None:
11393
"""
114-
SQLMesh needs a start date to as the starting point for calculating intervals on incremental models
94+
SQLMesh needs a start date to use as the starting point for calculating intervals on incremental models, amongst other things
11595
11696
Rather than forcing the user to update their config manually or having a default that is not saved between runs,
117-
we can inject it automatically to the dbt_project.yml file
97+
we can generate a basic SQLMesh config if it doesn't exist.
98+
99+
This is preferable to trying to inject config into `dbt_project.yml` because it means we have full control over the file
100+
and don't need to worry about accidentally reformatting it or accidentally clobbering other config
118101
"""
119-
from sqlmesh.dbt.project import PROJECT_FILENAME, load_yaml
120-
from sqlmesh.utils.yaml import dump
121-
from sqlmesh.utils.date import yesterday_ds
122-
123-
project_yaml_path = (project_dir or Path.cwd()) / PROJECT_FILENAME
124-
if project_yaml_path.exists():
125-
loaded_project_file = load_yaml(project_yaml_path)
126-
start_date_keys = ("start", "+start")
127-
if "models" in loaded_project_file and all(
128-
k not in loaded_project_file["models"] for k in start_date_keys
129-
):
130-
loaded_project_file["models"]["+start"] = yesterday_ds()
131-
# todo: this may format the file differently, is that acceptable?
132-
with project_yaml_path.open("w") as f:
133-
dump(loaded_project_file, f)
102+
from sqlmesh.cli.project_init import init_example_project, ProjectTemplate
103+
from sqlmesh.core.config.common import ALL_CONFIG_FILENAMES
104+
from sqlmesh.core.console import get_console
105+
106+
if not any(f.exists() for f in [project_dir / file for file in ALL_CONFIG_FILENAMES]):
107+
get_console().log_warning("No existing SQLMesh config detected; creating one")
108+
init_example_project(path=project_dir, engine_type=None, template=ProjectTemplate.DBT)

tests/cli/test_project_init.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import pytest
2+
from pathlib import Path
3+
from sqlmesh.utils.errors import SQLMeshError
4+
from sqlmesh.cli.project_init import init_example_project, ProjectTemplate
5+
from sqlmesh.utils import yaml
6+
7+
8+
def test_project_init_dbt(tmp_path: Path):
9+
assert not len(list(tmp_path.glob("**/*")))
10+
11+
with pytest.raises(SQLMeshError, match=r"Required dbt project file.*not found"):
12+
init_example_project(path=tmp_path, engine_type=None, template=ProjectTemplate.DBT)
13+
14+
with (tmp_path / "dbt_project.yml").open("w") as f:
15+
yaml.dump({"name": "jaffle_shop"}, f)
16+
17+
init_example_project(path=tmp_path, engine_type=None, template=ProjectTemplate.DBT)
18+
files = [f for f in tmp_path.glob("**/*") if f.is_file()]
19+
20+
assert set([f.name for f in files]) == set(["sqlmesh.yaml", "dbt_project.yml"])
21+
22+
sqlmesh_config = next(f for f in files if f.name == "sqlmesh.yaml")
23+
assert "model_defaults" in sqlmesh_config.read_text()
24+
assert "start: " in sqlmesh_config.read_text()

0 commit comments

Comments
 (0)