diff --git a/sqlmesh/cli/project_init.py b/sqlmesh/cli/project_init.py index 613ea72c45..b6dc5050bc 100644 --- a/sqlmesh/cli/project_init.py +++ b/sqlmesh/cli/project_init.py @@ -8,6 +8,7 @@ from sqlmesh.utils.date import yesterday_ds from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.core.config.common import DBT_PROJECT_FILENAME from sqlmesh.core.config.connection import ( CONNECTION_CONFIG_TO_TYPE, DIALECT_TO_TYPE, @@ -113,11 +114,10 @@ def _gen_config( - ambiguousorinvalidcolumn - invalidselectstarexpansion """, - ProjectTemplate.DBT: """from pathlib import Path - -from sqlmesh.dbt.loader import sqlmesh_config - -config = sqlmesh_config(Path(__file__).parent) + ProjectTemplate.DBT: f"""# --- Model Defaults --- +# https://sqlmesh.readthedocs.io/en/stable/reference/model_configuration/#model-defaults +model_defaults: + start: {start or yesterday_ds()} """, } @@ -285,8 +285,13 @@ def init_example_project( cli_mode: InitCliMode = InitCliMode.DEFAULT, ) -> Path: root_path = Path(path) - config_extension = "py" if template == ProjectTemplate.DBT else "yaml" - config_path = root_path / f"config.{config_extension}" + + config_path = root_path / "config.yaml" + if template == ProjectTemplate.DBT: + # name the config file `sqlmesh.yaml` to make it clear that within the context of all + # the existing yaml files DBT project, this one specifically relates to configuring the sqlmesh engine + config_path = root_path / "sqlmesh.yaml" + audits_path = root_path / "audits" macros_path = root_path / "macros" models_path = root_path / "models" @@ -298,7 +303,7 @@ def init_example_project( f"Found an existing config file '{config_path}'.\n\nPlease change to another directory or remove the existing file." 
# Config files that can be present in the project dir.
# Each of these names is probed inside the project directory when configuration is loaded.
ALL_CONFIG_FILENAMES = ("config.py", "config.yml", "config.yaml", "sqlmesh.yml", "sqlmesh.yaml")

# For personal paths (~/.sqlmesh/) where python config is not supported.
# Selected by an explicit YAML suffix match rather than "anything that is not .py", so that a
# future non-YAML, non-Python entry in ALL_CONFIG_FILENAMES is never handed to the YAML loader.
YAML_CONFIG_FILENAMES = tuple(n for n in ALL_CONFIG_FILENAMES if n.endswith((".yml", ".yaml")))

# Note: is here to prevent having to import from sqlmesh.dbt.loader which introduces a dependency
# on dbt-core in a native project
DBT_PROJECT_FILENAME = "dbt_project.yml"
diff --git a/sqlmesh/core/config/loader.py b/sqlmesh/core/config/loader.py index c381252fb9..2c1554454b 100644 --- a/sqlmesh/core/config/loader.py +++ b/sqlmesh/core/config/loader.py @@ -10,6 +10,11 @@ from sqlglot.helper import ensure_list from sqlmesh.core import constants as c +from sqlmesh.core.config.common import ( + ALL_CONFIG_FILENAMES, + YAML_CONFIG_FILENAMES, + DBT_PROJECT_FILENAME, +) from sqlmesh.core.config.model import ModelDefaultsConfig from sqlmesh.core.config.root import Config from sqlmesh.utils import env_vars, merge_dicts, sys_path @@ -51,10 +56,7 @@ def load_configs( return {path: config for path in absolute_paths} config_env_vars = None - personal_paths = [ - sqlmesh_path / "config.yml", - sqlmesh_path / "config.yaml", - ] + personal_paths = [sqlmesh_path / name for name in YAML_CONFIG_FILENAMES] for path in personal_paths: if path.exists(): config_env_vars = load_config_from_yaml(path).get("env_vars") @@ -65,7 +67,7 @@ def load_configs( return { path: load_config_from_paths( config_type, - project_paths=[path / "config.py", path / "config.yml", path / "config.yaml"], + project_paths=[path / name for name in ALL_CONFIG_FILENAMES], personal_paths=personal_paths, config_name=config, ) @@ -156,6 +158,22 @@ def load_config_from_paths( ) no_dialect_err_msg = "Default model SQL dialect is a required configuration parameter. Set it in the `model_defaults` `dialect` key in your config file." + + # if "dbt_project.yml" is present *and there was no python config already defined*, + # create a basic one to ensure we are using the DBT loader. + # any config within yaml files will get overlayed on top of it. 
+ if not python_config: + potential_project_files = [f / DBT_PROJECT_FILENAME for f in visited_folders] + dbt_project_file = next((f for f in potential_project_files if f.exists()), None) + if dbt_project_file: + from sqlmesh.dbt.loader import sqlmesh_config + + dbt_python_config = sqlmesh_config(project_root=dbt_project_file.parent) + if type(dbt_python_config) != config_type: + dbt_python_config = convert_config_type(dbt_python_config, config_type) + + python_config = dbt_python_config + if python_config: model_defaults = python_config.model_defaults if model_defaults.dialect is None: @@ -165,6 +183,7 @@ def load_config_from_paths( model_defaults = non_python_config.model_defaults if model_defaults.dialect is None: raise ConfigError(no_dialect_err_msg) + return non_python_config diff --git a/sqlmesh/dbt/common.py b/sqlmesh/dbt/common.py index 49d6c7ca18..d9db5a472c 100644 --- a/sqlmesh/dbt/common.py +++ b/sqlmesh/dbt/common.py @@ -14,10 +14,11 @@ from sqlmesh.utils.jinja import MacroReference from sqlmesh.utils.pydantic import PydanticModel, field_validator from sqlmesh.utils.yaml import load +from sqlmesh.core.config.common import DBT_PROJECT_FILENAME T = t.TypeVar("T", bound="GeneralConfig") -PROJECT_FILENAME = "dbt_project.yml" +PROJECT_FILENAME = DBT_PROJECT_FILENAME JINJA_ONLY = { "adapter", diff --git a/sqlmesh/dbt/manifest.py b/sqlmesh/dbt/manifest.py index 19795a0b9b..4f839b9c9b 100644 --- a/sqlmesh/dbt/manifest.py +++ b/sqlmesh/dbt/manifest.py @@ -32,6 +32,7 @@ from dbt.tracking import do_not_track from sqlmesh.core import constants as c +from sqlmesh.core.config import ModelDefaultsConfig from sqlmesh.dbt.basemodel import Dependencies from sqlmesh.dbt.builtin import BUILTIN_FILTERS, BUILTIN_GLOBALS, OVERRIDDEN_MACROS from sqlmesh.dbt.model import ModelConfig @@ -78,12 +79,14 @@ def __init__( target: TargetConfig, variable_overrides: t.Optional[t.Dict[str, t.Any]] = None, cache_dir: t.Optional[str] = None, + model_defaults: 
t.Optional[ModelDefaultsConfig] = None, ): self.project_path = project_path self.profiles_path = profiles_path self.profile_name = profile_name self.target = target self.variable_overrides = variable_overrides or {} + self.model_defaults = model_defaults or ModelDefaultsConfig() self.__manifest: t.Optional[Manifest] = None self._project_name: str = "" @@ -380,9 +383,12 @@ def _load_manifest(self) -> Manifest: profile = self._load_profile() project = self._load_project(profile) - if not any(k in project.models for k in ("start", "+start")): + if ( + not any(k in project.models for k in ("start", "+start")) + and not self.model_defaults.start + ): raise ConfigError( - "SQLMesh's requires a start date in order to have a finite range of backfilling data. Add start to the 'models:' block in dbt_project.yml. https://sqlmesh.readthedocs.io/en/stable/integrations/dbt/#setting-model-backfill-start-dates" + "SQLMesh requires a start date in order to have a finite range of backfilling data. Add start to the 'models:' block in dbt_project.yml. 
https://sqlmesh.readthedocs.io/en/stable/integrations/dbt/#setting-model-backfill-start-dates" ) runtime_config = RuntimeConfig.from_parts(project, profile, args) diff --git a/sqlmesh/dbt/model.py b/sqlmesh/dbt/model.py index 8563d20d22..4198fabca7 100644 --- a/sqlmesh/dbt/model.py +++ b/sqlmesh/dbt/model.py @@ -569,7 +569,7 @@ def to_sqlmesh( query, dialect=model_dialect, kind=kind, - start=self.start, + start=self.start or context.sqlmesh_config.model_defaults.start, audit_definitions=audit_definitions, path=model_kwargs.pop("path", self.path), # This ensures that we bypass query rendering that would otherwise be required to extract additional diff --git a/sqlmesh/dbt/project.py b/sqlmesh/dbt/project.py index ac36ee4e0a..d37c9cc6c4 100644 --- a/sqlmesh/dbt/project.py +++ b/sqlmesh/dbt/project.py @@ -76,6 +76,7 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N target=profile.target, variable_overrides=variable_overrides, cache_dir=context.sqlmesh_config.cache_dir, + model_defaults=context.sqlmesh_config.model_defaults, ) extra_fields = profile.target.extra diff --git a/sqlmesh_dbt/operations.py b/sqlmesh_dbt/operations.py index b826a00e37..ec07efd37b 100644 --- a/sqlmesh_dbt/operations.py +++ b/sqlmesh_dbt/operations.py @@ -61,7 +61,7 @@ def create( from sqlmesh import configure_logging from sqlmesh.core.context import Context - from sqlmesh.dbt.loader import sqlmesh_config, DbtLoader + from sqlmesh.dbt.loader import DbtLoader from sqlmesh.core.console import set_console from sqlmesh_dbt.console import DbtCliConsole from sqlmesh.utils.errors import SQLMeshError @@ -71,34 +71,14 @@ def create( progress.update(load_task_id, description="Loading project", total=None) - # inject default start date if one is not specified to prevent the user from having to do anything - _inject_default_start_date(project_dir) - - config = sqlmesh_config( - project_root=project_dir, - # do we want to use a local duckdb for state? 
def init_project_if_required(project_dir: Path) -> None:
    """
    Ensure the project directory contains a SQLMesh config file, generating one when absent.

    SQLMesh needs a start date as the starting point for calculating intervals on incremental
    models, amongst other things. Instead of asking the user to edit their config by hand, or
    relying on a default that is not saved between runs, a basic SQLMesh config is generated
    from the DBT project template when none of the known config files exists yet.

    Writing a dedicated file is preferable to injecting settings into `dbt_project.yml`: we keep
    full control over it and cannot accidentally reformat or clobber the user's other config.
    """
    from sqlmesh.cli.project_init import init_example_project, ProjectTemplate
    from sqlmesh.core.config.common import ALL_CONFIG_FILENAMES
    from sqlmesh.core.console import get_console

    # any one of the recognised config filenames means there is nothing to generate
    if any((project_dir / filename).exists() for filename in ALL_CONFIG_FILENAMES):
        return

    get_console().log_warning("No existing SQLMesh config detected; creating one")
    init_example_project(path=project_dir, engine_type=None, template=ProjectTemplate.DBT)
def test_project_init_dbt(tmp_path: Path):
    """Initialising the DBT template requires `dbt_project.yml` and adds only `sqlmesh.yaml`."""
    # the temp dir starts out empty
    assert not list(tmp_path.glob("**/*"))

    # without a dbt project file, init must refuse to proceed
    with pytest.raises(SQLMeshError, match=r"Required dbt project file.*not found"):
        init_example_project(path=tmp_path, engine_type=None, template=ProjectTemplate.DBT)

    with (tmp_path / "dbt_project.yml").open("w") as project_file:
        yaml.dump({"name": "jaffle_shop"}, project_file)

    init_example_project(path=tmp_path, engine_type=None, template=ProjectTemplate.DBT)

    # the generated config must be the only file added next to the dbt project file
    created = [entry for entry in tmp_path.glob("**/*") if entry.is_file()]
    assert {entry.name for entry in created} == {"sqlmesh.yaml", "dbt_project.yml"}

    config_text = next(entry for entry in created if entry.name == "sqlmesh.yaml").read_text()
    assert "model_defaults" in config_text
    assert "start: " in config_text
"yml"): + config_file = tmp_path / f"sqlmesh.{extension}" + config_file.write_text(""" +model_defaults: + start: '2023-04-05' + dialect: bigquery""") + + configs = load_configs(config=None, config_type=Config, paths=[tmp_path]) + assert len(configs) == 1 + + config: Config = list(configs.values())[0] + + assert config.model_defaults.start == "2023-04-05" + assert config.model_defaults.dialect == "bigquery" + + config_file.unlink() + + +def test_load_configs_without_main_connection(tmp_path: Path): + # this is for DBT projects where the main connection is defined in profiles.yml + # but we also need to be able to specify the sqlmesh state connection without editing any DBT files + # and without also duplicating the main connection + config_file = tmp_path / "sqlmesh.yaml" + with config_file.open("w") as f: + yaml.dump( + { + "gateways": {"": {"state_connection": {"type": "duckdb", "database": "state.db"}}}, + "model_defaults": {"dialect": "duckdb", "start": "2020-01-01"}, + }, + f, + ) + + configs = list(load_configs(config=None, config_type=Config, paths=[tmp_path]).values()) + assert len(configs) == 1 + + config = configs[0] + state_connection_config = config.get_state_connection() + assert isinstance(state_connection_config, DuckDBConnectionConfig) + assert state_connection_config.database == "state.db" + + +def test_load_configs_in_dbt_project_without_config_py(tmp_path: Path): + # this is when someone either: + # - inits a dbt project for sqlmesh, which creates a sqlmesh.yaml file + # - uses the sqlmesh_dbt cli for the first time, which runs init if the config doesnt exist, which creates a config + # when in pure yaml mode, sqlmesh should be able to auto-detect the presence of DBT and select the DbtLoader instead + # of the main loader + (tmp_path / "dbt_project.yml").write_text(""" +name: jaffle_shop + """) + + (tmp_path / "profiles.yml").write_text(""" +jaffle_shop: + + target: dev + outputs: + dev: + type: duckdb + path: 'jaffle_shop.duckdb' + """) + + 
def test_create_sets_and_persists_default_start_date(jaffle_shop_duckdb: Path):
    """
    Check that `create()` defaults the start date to yesterday and that the value persists.

    The first `create()` runs with the clock frozen at 2020-01-02, so the expected default start
    is 2020-01-01. The second `create()` runs at the real current time and must still report
    2020-01-01, for both the project-level default and every non-seed model.
    """
    from sqlmesh.utils.date import to_ds, yesterday_ds

    def assert_start_date(operations, expected: str) -> None:
        # the project-level default must be set and equal to the expected date
        default_start = operations.context.config.model_defaults.start
        assert default_start
        assert to_ds(default_start) == expected

        # Every non-seed model must carry the same start date. The comparison is applied to the
        # converted value explicitly: `a if c else None == x` parses as `a if c else (None == x)`,
        # which would only verify that a start exists, never what it is.
        assert all(
            bool(model.start) and to_ds(model.start) == expected
            for model in operations.context.models.values()
            if not model.kind.is_seed
        )

    with time_machine.travel("2020-01-02 00:00:00 UTC"):
        assert yesterday_ds() == "2020-01-01"
        assert_start_date(create(), "2020-01-01")

    # check that the date set on the first invocation persists to future invocations
    assert yesterday_ds() != "2020-01-01"
    assert_start_date(create(), "2020-01-01")