From ec0f2e3b5eb9d4df5ab0210720c24eac4df298db Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Thu, 28 Aug 2025 02:02:25 +0000 Subject: [PATCH 1/2] Revert "Feat(experimental): DBT project conversion (#4495)" This reverts commit 976ffee5329c46cbac9a5d480004db4bb440d272. --- sqlmesh/cli/main.py | 37 -- sqlmesh/core/config/root.py | 9 +- sqlmesh/core/constants.py | 3 - sqlmesh/core/loader.py | 109 +--- sqlmesh/core/model/definition.py | 61 +- sqlmesh/core/model/kind.py | 5 +- sqlmesh/core/renderer.py | 1 - sqlmesh/dbt/adapter.py | 3 - sqlmesh/dbt/builtin.py | 2 +- sqlmesh/dbt/converter/__init__.py | 0 sqlmesh/dbt/converter/common.py | 40 -- sqlmesh/dbt/converter/console.py | 117 ---- sqlmesh/dbt/converter/convert.py | 420 ------------ sqlmesh/dbt/converter/jinja.py | 604 ------------------ sqlmesh/dbt/converter/jinja_builtins.py | 109 ---- sqlmesh/dbt/converter/jinja_transforms.py | 465 -------------- sqlmesh/dbt/loader.py | 11 +- sqlmesh/dbt/model.py | 1 - sqlmesh/dbt/target.py | 93 +-- sqlmesh/utils/jinja.py | 99 +-- tests/core/test_config.py | 27 - tests/core/test_loader.py | 129 ---- tests/core/test_model.py | 93 +-- tests/dbt/converter/conftest.py | 21 - .../fixtures/empty_dbt_project/.gitignore | 2 - .../empty_dbt_project/analyses/.gitkeep | 0 .../fixtures/empty_dbt_project/config.py | 7 - .../empty_dbt_project/dbt_project.yml | 22 - .../empty_dbt_project/macros/.gitkeep | 0 .../empty_dbt_project/models/.gitkeep | 0 .../empty_dbt_project/models/sources.yml | 6 - .../empty_dbt_project/packages/.gitkeep | 0 .../fixtures/empty_dbt_project/profiles.yml | 6 - .../fixtures/empty_dbt_project/seeds/.gitkeep | 0 .../empty_dbt_project/seeds/items.csv | 94 --- .../empty_dbt_project/seeds/properties.yml | 13 - .../empty_dbt_project/snapshots/.gitkeep | 0 .../fixtures/empty_dbt_project/tests/.gitkeep | 0 .../converter/fixtures/jinja_nested_if.sql | 15 - .../fixtures/macro_dbt_incremental.sql | 11 - .../fixtures/macro_func_with_params.sql | 17 - .../fixtures/model_query_incremental.sql | 34 - tests/dbt/converter/test_convert.py | 105 --- tests/dbt/converter/test_jinja.py | 450 ------------- tests/dbt/converter/test_jinja_transforms.py | 453 ------------- tests/utils/test_jinja.py | 78 --- 46 files changed, 40 insertions(+), 3732 deletions(-) delete mode 100644 sqlmesh/dbt/converter/__init__.py delete mode 100644 sqlmesh/dbt/converter/common.py delete mode 100644 sqlmesh/dbt/converter/console.py delete mode 100644 sqlmesh/dbt/converter/convert.py delete mode 100644 sqlmesh/dbt/converter/jinja.py delete mode 100644 sqlmesh/dbt/converter/jinja_builtins.py delete mode 100644 sqlmesh/dbt/converter/jinja_transforms.py delete mode 100644 tests/dbt/converter/conftest.py delete mode 100644 tests/dbt/converter/fixtures/empty_dbt_project/.gitignore delete mode 100644 tests/dbt/converter/fixtures/empty_dbt_project/analyses/.gitkeep delete mode 100644 tests/dbt/converter/fixtures/empty_dbt_project/config.py delete mode 100644 tests/dbt/converter/fixtures/empty_dbt_project/dbt_project.yml delete mode 100644 tests/dbt/converter/fixtures/empty_dbt_project/macros/.gitkeep delete mode 100644 tests/dbt/converter/fixtures/empty_dbt_project/models/.gitkeep delete mode 100644 tests/dbt/converter/fixtures/empty_dbt_project/models/sources.yml delete mode 100644 tests/dbt/converter/fixtures/empty_dbt_project/packages/.gitkeep delete mode 100644 tests/dbt/converter/fixtures/empty_dbt_project/profiles.yml delete mode 100644 tests/dbt/converter/fixtures/empty_dbt_project/seeds/.gitkeep delete mode 100644 tests/dbt/converter/fixtures/empty_dbt_project/seeds/items.csv delete mode 100644 tests/dbt/converter/fixtures/empty_dbt_project/seeds/properties.yml delete mode 100644 tests/dbt/converter/fixtures/empty_dbt_project/snapshots/.gitkeep delete mode 100644 tests/dbt/converter/fixtures/empty_dbt_project/tests/.gitkeep delete mode 100644 tests/dbt/converter/fixtures/jinja_nested_if.sql delete mode 100644 tests/dbt/converter/fixtures/macro_dbt_incremental.sql delete mode 100644 tests/dbt/converter/fixtures/macro_func_with_params.sql delete mode 100644 tests/dbt/converter/fixtures/model_query_incremental.sql delete mode 100644 tests/dbt/converter/test_convert.py delete mode 100644 tests/dbt/converter/test_jinja.py delete mode 100644 tests/dbt/converter/test_jinja_transforms.py diff --git a/sqlmesh/cli/main.py b/sqlmesh/cli/main.py index 961b78069e..2d8673405f 100644 --- a/sqlmesh/cli/main.py +++ b/sqlmesh/cli/main.py @@ -39,7 +39,6 @@ "rollback", "run", "table_name", - "dbt", ) SKIP_CONTEXT_COMMANDS = ("init", "ui") @@ -1307,39 +1306,3 @@ def state_import(obj: Context, input_file: Path, replace: bool, no_confirm: bool """Import a state export file back into the state database""" confirm = not no_confirm obj.import_state(input_file=input_file, clear=replace, confirm=confirm) - - -@cli.group(no_args_is_help=True, hidden=True) -def dbt() -> None: - """Commands for doing dbt-specific things""" - pass - - -@dbt.command("convert") -@click.option( - "-i", - "--input-dir", - help="Path to the DBT project", - required=True, - type=click.Path(exists=True, dir_okay=True, file_okay=False, readable=True, path_type=Path), -) -@click.option( - "-o", - "--output-dir", - required=True, - help="Path to write out the converted SQLMesh project", - type=click.Path(exists=False, dir_okay=True, file_okay=False, readable=True, path_type=Path), -) -@click.option("--no-prompts", is_flag=True, help="Disable interactive prompts", default=False) -@click.pass_obj -@error_handler -@cli_analytics -def dbt_convert(obj: Context, input_dir: Path, output_dir: Path, no_prompts: bool) -> None: - """Convert a DBT project to a SQLMesh project""" - from sqlmesh.dbt.converter.convert import convert_project_files - - convert_project_files( - input_dir.absolute(), - output_dir.absolute(), - no_prompts=no_prompts, - ) diff --git a/sqlmesh/core/config/root.py b/sqlmesh/core/config/root.py index 65889cb7cf..9b6fae63e3 100644 --- a/sqlmesh/core/config/root.py +++ b/sqlmesh/core/config/root.py @@ -42,7 +42,7 @@ scheduler_config_validator, ) from sqlmesh.core.config.ui import UIConfig -from sqlmesh.core.loader import Loader, SqlMeshLoader, MigratedDbtProjectLoader +from sqlmesh.core.loader import Loader, SqlMeshLoader from sqlmesh.core.notification_target import NotificationTarget from sqlmesh.core.user import User from sqlmesh.utils.date import to_timestamp, now @@ -227,13 +227,6 @@ def _normalize_and_validate_fields(cls, data: t.Any) -> t.Any: f"^{k}$": v for k, v in physical_schema_override.items() } - if ( - (variables := data.get("variables", "")) - and isinstance(variables, dict) - and c.MIGRATED_DBT_PROJECT_NAME in variables - ): - data["loader"] = MigratedDbtProjectLoader - return data @model_validator(mode="after") diff --git a/sqlmesh/core/constants.py b/sqlmesh/core/constants.py index 2df7697b9d..a1d117f4fb 100644 --- a/sqlmesh/core/constants.py +++ b/sqlmesh/core/constants.py @@ -32,9 +32,6 @@ MAX_MODEL_DEFINITION_SIZE = 10000 """Maximum number of characters in a model definition""" -MIGRATED_DBT_PROJECT_NAME = "__dbt_project_name__" -MIGRATED_DBT_PACKAGES = "__dbt_packages__" - # The maximum number of fork processes, used for loading projects # None means default to process pool, 1 means don't fork, :N is number of processes diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index 8126b39107..6647a2edba 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -38,11 +38,7 @@ from sqlmesh.core.test import ModelTestMetadata, filter_tests_by_patterns from sqlmesh.utils import UniqueKeyDict, sys_path from sqlmesh.utils.errors import ConfigError -from sqlmesh.utils.jinja import ( - JinjaMacroRegistry, - MacroExtractor, - SQLMESH_DBT_COMPATIBILITY_PACKAGE, -) +from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor from sqlmesh.utils.metaprogramming import import_python_file from sqlmesh.utils.pydantic import validation_error_message from sqlmesh.utils.process import create_process_pool_executor @@ -561,7 +557,6 @@ def _load_sql_models( signals: UniqueKeyDict[str, signal], cache: CacheBase, gateway: t.Optional[str], - loading_default_kwargs: t.Optional[t.Dict[str, t.Any]] = None, ) -> UniqueKeyDict[str, Model]: """Loads the sql models into a Dict""" models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") @@ -604,7 +599,6 @@ def _load_sql_models( signal_definitions=signals, default_catalog_per_gateway=self.context.default_catalog_per_gateway, virtual_environment_mode=self.config.virtual_environment_mode, - **loading_default_kwargs or {}, ) with create_process_pool_executor( @@ -971,104 +965,3 @@ def _model_cache_entry_id(self, model_path: Path) -> str: self._loader.context.gateway or self._loader.config.default_gateway_name, ] ) - - -class MigratedDbtProjectLoader(SqlMeshLoader): - @property - def migrated_dbt_project_name(self) -> str: - return self.config.variables[c.MIGRATED_DBT_PROJECT_NAME] - - def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]: - from sqlmesh.dbt.converter.common import infer_dbt_package_from_path - from sqlmesh.dbt.target import TARGET_TYPE_TO_CONFIG_CLASS - - # Store a copy of the macro registry - standard_macros = macro.get_registry() - - jinja_macros = JinjaMacroRegistry( - create_builtins_module=SQLMESH_DBT_COMPATIBILITY_PACKAGE, - top_level_packages=["dbt", self.migrated_dbt_project_name], - ) - extractor = MacroExtractor() - - macros_max_mtime: t.Optional[float] = None - - for path in self._glob_paths( - self.config_path / c.MACROS, - ignore_patterns=self.config.ignore_patterns, - extension=".py", - ): - if import_python_file(path, self.config_path): - self._track_file(path) - macro_file_mtime = self._path_mtimes[path] - macros_max_mtime = ( - max(macros_max_mtime, macro_file_mtime) - if macros_max_mtime - else macro_file_mtime - ) - - for path in self._glob_paths( - self.config_path / c.MACROS, - ignore_patterns=self.config.ignore_patterns, - extension=".sql", - ): - self._track_file(path) - macro_file_mtime = self._path_mtimes[path] - macros_max_mtime = ( - max(macros_max_mtime, macro_file_mtime) if macros_max_mtime else macro_file_mtime - ) - - with open(path, "r", encoding="utf-8") as file: - try: - package = infer_dbt_package_from_path(path) or self.migrated_dbt_project_name - - jinja_macros.add_macros( - extractor.extract(file.read(), dialect=self.config.model_defaults.dialect), - package=package, - ) - except Exception as e: - raise ConfigError(f"Failed to load macro file: {e}", path) - - self._macros_max_mtime = macros_max_mtime - - macros = macro.get_registry() - macro.set_registry(standard_macros) - - connection_config = self.context.connection_config - # this triggers the DBT create_builtins_module to have a `target` property which is required for a bunch of DBT macros to work - if dbt_config_type := TARGET_TYPE_TO_CONFIG_CLASS.get(connection_config.type_): - try: - jinja_macros.add_globals( - { - "target": dbt_config_type.from_sqlmesh( - connection_config, - name=self.config.default_gateway_name, - ).attribute_dict() - } - ) - except NotImplementedError: - raise ConfigError(f"Unsupported dbt target type: {connection_config.type_}") - - return macros, jinja_macros - - def _load_sql_models( - self, - macros: MacroRegistry, - jinja_macros: JinjaMacroRegistry, - audits: UniqueKeyDict[str, ModelAudit], - signals: UniqueKeyDict[str, signal], - cache: CacheBase, - gateway: t.Optional[str], - loading_default_kwargs: t.Optional[t.Dict[str, t.Any]] = None, - ) -> UniqueKeyDict[str, Model]: - return super()._load_sql_models( - macros=macros, - jinja_macros=jinja_macros, - audits=audits, - signals=signals, - cache=cache, - gateway=gateway, - loading_default_kwargs=dict( - migrated_dbt_project_name=self.migrated_dbt_project_name, - ), - ) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 8e71b3aa02..dba8eedc31 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -2061,7 +2061,6 @@ def load_sql_based_model( variables: t.Optional[t.Dict[str, t.Any]] = None, infer_names: t.Optional[bool] = False, blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, - migrated_dbt_project_name: t.Optional[str] = None, **kwargs: t.Any, ) -> Model: """Load a model from a parsed SQLMesh model SQL file. @@ -2239,7 +2238,6 @@ def load_sql_based_model( query_or_seed_insert, kind=kind, time_column_format=time_column_format, - migrated_dbt_project_name=migrated_dbt_project_name, **common_kwargs, ) @@ -2451,7 +2449,6 @@ def _create_model( signal_definitions: t.Optional[SignalRegistry] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, - migrated_dbt_project_name: t.Optional[str] = None, **kwargs: t.Any, ) -> Model: validate_extra_and_required_fields( @@ -2531,31 +2528,16 @@ def _create_model( if jinja_macros: jinja_macros = ( - jinja_macros - if jinja_macros.trimmed - else jinja_macros.trim(jinja_macro_references, package=migrated_dbt_project_name) + jinja_macros if jinja_macros.trimmed else jinja_macros.trim(jinja_macro_references) ) else: jinja_macros = JinjaMacroRegistry() - if migrated_dbt_project_name: - # extract {{ var() }} references used in all jinja macro dependencies to check for any variables specific - # to a migrated DBT package and resolve them accordingly - # vars are added into __sqlmesh_vars__ in the Python env so that the native SQLMesh var() function can resolve them - variables = variables or {} - - nested_macro_used_variables, flattened_package_variables = ( - _extract_migrated_dbt_variable_references(jinja_macros, variables) + for jinja_macro in jinja_macros.root_macros.values(): + referenced_variables.update( + extract_macro_references_and_variables(jinja_macro.definition)[1] ) - referenced_variables.update(nested_macro_used_variables) - variables.update(flattened_package_variables) - else: - for jinja_macro in jinja_macros.root_macros.values(): - referenced_variables.update( - extract_macro_references_and_variables(jinja_macro.definition)[1] - ) - # Merge model-specific audits with default audits if default_audits := defaults.pop("audits", None): kwargs["audits"] = default_audits + d.extract_function_calls(kwargs.pop("audits", [])) @@ -2943,7 +2925,7 @@ def render_expression( "cron_tz": lambda value: exp.Literal.string(value), "partitioned_by_": _single_expr_or_tuple, "clustered_by": _single_expr_or_tuple, - "depends_on_": lambda value: exp.Tuple(expressions=sorted(value)) if value else "()", + "depends_on_": lambda value: exp.Tuple(expressions=sorted(value)), "pre": _list_of_calls_to_exp, "post": _list_of_calls_to_exp, "audits": _list_of_calls_to_exp, @@ -3020,37 +3002,4 @@ def clickhouse_partition_func( ) -def _extract_migrated_dbt_variable_references( - jinja_macros: JinjaMacroRegistry, project_variables: t.Dict[str, t.Any] -) -> t.Tuple[t.Set[str], t.Dict[str, t.Any]]: - if not jinja_macros.trimmed: - raise ValueError("Expecting a trimmed JinjaMacroRegistry") - - used_variables = set() - # note: JinjaMacroRegistry is trimmed here so "all_macros" should be just be all the macros used by this model - for _, _, jinja_macro in jinja_macros.all_macros: - _, extracted_variable_names = extract_macro_references_and_variables(jinja_macro.definition) - used_variables.update(extracted_variable_names) - - flattened = {} - if (dbt_package_variables := project_variables.get(c.MIGRATED_DBT_PACKAGES)) and isinstance( - dbt_package_variables, dict - ): - # flatten the nested dict structure from the migrated dbt package variables in the SQLmesh config into __dbt_packages.. - # to match what extract_macro_references_and_variables() returns. This allows the usage checks in create_python_env() to work - def _flatten(prefix: str, root: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - acc = {} - for k, v in root.items(): - key_with_prefix = f"{prefix}.{k}" - if isinstance(v, dict): - acc.update(_flatten(key_with_prefix, v)) - else: - acc[key_with_prefix] = v - return acc - - flattened = _flatten(c.MIGRATED_DBT_PACKAGES, dbt_package_variables) - - return used_variables, flattened - - TIME_COL_PARTITION_FUNC = {"clickhouse": clickhouse_partition_func} diff --git a/sqlmesh/core/model/kind.py b/sqlmesh/core/model/kind.py index 6fbbc3534b..dc5f533c21 100644 --- a/sqlmesh/core/model/kind.py +++ b/sqlmesh/core/model/kind.py @@ -4,7 +4,7 @@ from enum import Enum from typing_extensions import Self -from pydantic import Field, BeforeValidator +from pydantic import Field from sqlglot import exp from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlglot.optimizer.qualify_columns import quote_identifiers @@ -33,7 +33,6 @@ field_validator, get_dialect, validate_string, - positive_int_validator, validate_expression, ) @@ -505,7 +504,7 @@ class IncrementalByUniqueKeyKind(_IncrementalBy): unique_key: SQLGlotListOfFields when_matched: t.Optional[exp.Whens] = None merge_filter: t.Optional[exp.Expression] = None - batch_concurrency: t.Annotated[t.Literal[1], BeforeValidator(positive_int_validator)] = 1 + batch_concurrency: t.Literal[1] = 1 @field_validator("when_matched", mode="before") def _when_matched_validator( diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py index 8b733d4c55..4078d718a6 100644 --- a/sqlmesh/core/renderer.py +++ b/sqlmesh/core/renderer.py @@ -179,7 +179,6 @@ def _resolve_table(table: str | exp.Table) -> str: ) render_kwargs = { - "dialect": self._dialect, **date_dict( to_datetime(execution_time or c.EPOCH), start_time, diff --git a/sqlmesh/dbt/adapter.py b/sqlmesh/dbt/adapter.py index 2dc9890ca4..9e1ade1565 100644 --- a/sqlmesh/dbt/adapter.py +++ b/sqlmesh/dbt/adapter.py @@ -40,9 +40,6 @@ def __init__( self.jinja_globals = jinja_globals.copy() if jinja_globals else {} self.jinja_globals["adapter"] = self self.project_dialect = project_dialect - self.jinja_globals["dialect"] = ( - project_dialect # so the dialect is available in the jinja env created by self.dispatch() - ) self.quote_policy = quote_policy or Policy() @abc.abstractmethod diff --git a/sqlmesh/dbt/builtin.py b/sqlmesh/dbt/builtin.py index 4b564eb781..4edfea687a 100644 --- a/sqlmesh/dbt/builtin.py +++ b/sqlmesh/dbt/builtin.py @@ -157,7 +157,7 @@ class Var: def __init__(self, variables: t.Dict[str, t.Any]) -> None: self.variables = variables - def __call__(self, name: str, default: t.Optional[t.Any] = None, **kwargs: t.Any) -> t.Any: + def __call__(self, name: str, default: t.Optional[t.Any] = None) -> t.Any: return self.variables.get(name, default) def has_var(self, name: str) -> bool: diff --git a/sqlmesh/dbt/converter/__init__.py b/sqlmesh/dbt/converter/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/sqlmesh/dbt/converter/common.py b/sqlmesh/dbt/converter/common.py deleted file mode 100644 index 2bf4131065..0000000000 --- a/sqlmesh/dbt/converter/common.py +++ /dev/null @@ -1,40 +0,0 @@ -from __future__ import annotations -import jinja2.nodes as j -from sqlglot import exp -import typing as t -import sqlmesh.core.constants as c -from pathlib import Path - - -# jinja transform is a function that takes (current node, previous node, parent node) and returns a new Node or None -# returning None means the current node is removed from the tree -# returning a different Node means the current node is replaced with the new Node -JinjaTransform = t.Callable[[j.Node, t.Optional[j.Node], t.Optional[j.Node]], t.Optional[j.Node]] -SQLGlotTransform = t.Callable[[exp.Expression], t.Optional[exp.Expression]] - - -def _sqlmesh_predefined_macro_variables() -> t.Set[str]: - def _gen() -> t.Iterable[str]: - for suffix in ("dt", "date", "ds", "ts", "tstz", "hour", "epoch", "millis"): - for prefix in ("start", "end", "execution"): - yield f"{prefix}_{suffix}" - - for item in ("runtime_stage", "gateway", "this_model", "this_env", "model_kind_name"): - yield item - - return set(_gen()) - - -SQLMESH_PREDEFINED_MACRO_VARIABLES = _sqlmesh_predefined_macro_variables() - - -def infer_dbt_package_from_path(path: Path) -> t.Optional[str]: - """ - Given a path like "sqlmesh-project/macros/__dbt_packages__/foo/bar.sql" - - Infer that 'foo' is the DBT package - """ - if c.MIGRATED_DBT_PACKAGES in path.parts: - idx = path.parts.index(c.MIGRATED_DBT_PACKAGES) - return path.parts[idx + 1] - return None diff --git a/sqlmesh/dbt/converter/console.py b/sqlmesh/dbt/converter/console.py deleted file mode 100644 index 3fb12bcbc5..0000000000 --- a/sqlmesh/dbt/converter/console.py +++ /dev/null @@ -1,117 +0,0 @@ -from __future__ import annotations -import typing as t -from pathlib import Path -from rich.console import Console as RichConsole -from rich.tree import Tree -from rich.progress import Progress, TextColumn, BarColumn, MofNCompleteColumn, TimeElapsedColumn -from sqlmesh.core.console import PROGRESS_BAR_WIDTH -from sqlmesh.utils import columns_to_types_all_known -from sqlmesh.utils import rich as srich -import logging -from rich.prompt import Confirm - -logger = logging.getLogger(__name__) - -if t.TYPE_CHECKING: - from sqlmesh.dbt.converter.convert import ConversionReport - - -def make_progress_bar( - console: t.Optional[RichConsole] = None, - justify: t.Literal["default", "left", "center", "right", "full"] = "right", -) -> Progress: - return Progress( - TextColumn("[bold blue]{task.description}", justify=justify), - BarColumn(bar_width=PROGRESS_BAR_WIDTH), - "[progress.percentage]{task.percentage:>3.1f}%", - "•", - MofNCompleteColumn(), - "•", - TimeElapsedColumn(), - console=console, - ) - - -class DbtConversionConsole: - """Console for displaying DBT project conversion progress""" - - def __init__(self, console: t.Optional[RichConsole] = None) -> None: - self.console: RichConsole = console or srich.console - - def log_message(self, message: str) -> None: - self.console.print(message) - - def start_project_conversion(self, input_path: Path) -> None: - self.log_message(f"DBT project loaded from {input_path}; starting conversion") - - def prompt_clear_directory(self, prefix: str, path: Path) -> bool: - return Confirm.ask( - f"{prefix}'{path}' is not empty.\nWould you like to clear it?", console=self.console - ) - - # Models - def start_models_conversion(self, model_count: int) -> None: - self.progress_bar = make_progress_bar(justify="left", console=self.console) - self.progress_bar.start() - self.models_progress_task_id = self.progress_bar.add_task( - "Converting models", total=model_count - ) - - def start_model_conversion(self, model_name: str) -> None: - logger.debug(f"Converting model {model_name}") - self.progress_bar.update(self.models_progress_task_id, description=None, refresh=True) - - def complete_model_conversion(self) -> None: - self.progress_bar.update(self.models_progress_task_id, refresh=True, advance=1) - - def complete_models_conversion(self) -> None: - self.progress_bar.update(self.models_progress_task_id, description=None, refresh=True) - - # Audits - - def start_audits_conversion(self, audit_count: int) -> None: - self.audits_progress_task_id = self.progress_bar.add_task( - "Converting audits", total=audit_count - ) - - def start_audit_conversion(self, audit_name: str) -> None: - self.progress_bar.update(self.audits_progress_task_id, description=None, refresh=True) - - def complete_audit_conversion(self) -> None: - self.progress_bar.update(self.audits_progress_task_id, refresh=True, advance=1) - - def complete_audits_conversion(self) -> None: - self.progress_bar.update(self.audits_progress_task_id, description=None, refresh=True) - - # Macros - - def start_macros_conversion(self, macro_count: int) -> None: - self.macros_progress_task_id = self.progress_bar.add_task( - "Converting macros", total=macro_count - ) - - def start_macro_conversion(self, macro_name: str) -> None: - self.progress_bar.update(self.macros_progress_task_id, description=None, refresh=True) - - def complete_macro_conversion(self) -> None: - self.progress_bar.update(self.macros_progress_task_id, refresh=True, advance=1) - - def complete_macros_conversion(self) -> None: - self.progress_bar.update(self.macros_progress_task_id, description=None, refresh=True) - self.progress_bar.stop() - - def output_report(self, report: ConversionReport) -> None: - tree = Tree( - "[blue]The following models are self-referencing and their column types could not be statically inferred:" - ) - - for output_path, model in report.self_referencing_models: - if not model.columns_to_types or not columns_to_types_all_known(model.columns_to_types): - tree_node = tree.add(f"[green]{model.name}") - tree_node.add(output_path.as_posix()) - - self.console.print(tree) - - self.log_message( - "[red]These will need to be manually fixed.[/red]\nEither specify the column types in the MODEL block or ensure the outer SELECT lists all columns" - ) diff --git a/sqlmesh/dbt/converter/convert.py b/sqlmesh/dbt/converter/convert.py deleted file mode 100644 index 7eab536946..0000000000 --- a/sqlmesh/dbt/converter/convert.py +++ /dev/null @@ -1,420 +0,0 @@ -import typing as t -from pathlib import Path -import shutil -import os - -from sqlmesh.dbt.loader import sqlmesh_config, DbtLoader, DbtContext, Project -from sqlmesh.core.context import Context -import sqlmesh.core.dialect as d -from sqlmesh.core import constants as c - -from sqlmesh.core.model.kind import SeedKind -from sqlmesh.core.model import SqlModel, SeedModel -from sqlmesh.dbt.converter.jinja import convert_jinja_query, convert_jinja_macro -from sqlmesh.dbt.converter.common import infer_dbt_package_from_path -import dataclasses -from dataclasses import dataclass - -from sqlmesh.dbt.converter.console import DbtConversionConsole -from sqlmesh.utils.jinja import JinjaMacroRegistry, extract_macro_references_and_variables -from sqlmesh.utils import yaml - - -@dataclass -class ConversionReport: - self_referencing_models: t.List[t.Tuple[Path, SqlModel]] = dataclasses.field( - default_factory=list - ) - - -@dataclass -class InputPaths: - # todo: read paths from DBT project yaml - - base: Path - - @property - def models(self) -> Path: - return self.base / "models" - - @property - def seeds(self) -> Path: - return self.base / "seeds" - - @property - def tests(self) -> Path: - return self.base / "tests" - - @property - def macros(self) -> Path: - return self.base / "macros" - - @property - def snapshots(self) -> Path: - return self.base / "snapshots" - - @property - def packages(self) -> Path: - return self.base / "dbt_packages" - - -@dataclass -class OutputPaths: - base: Path - - @property - def models(self) -> Path: - return self.base / "models" - - @property - def seeds(self) -> Path: - return self.base / "seeds" - - @property - def audits(self) -> Path: - return self.base / "audits" - - @property - def macros(self) -> Path: - return self.base / "macros" - - -def convert_project_files(src: Path, dest: Path, no_prompts: bool = True) -> None: - console = DbtConversionConsole() - report = ConversionReport() - - console.log_message(f"Converting project at '{src}' to '{dest}'") - - ctx, dbt_project = _load_project(src) - dbt_load_context = dbt_project.context - - console.start_project_conversion(src) - - input_paths, output_paths = _ensure_paths(src, dest, console, no_prompts) - - model_count = len(ctx.models) - - # DBT Models -> SQLMesh Models - console.start_models_conversion(model_count) - _convert_models(ctx, input_paths, output_paths, report, console) - console.complete_models_conversion() - - # DBT Tests -> Standalone Audits - console.start_audits_conversion(len(ctx.standalone_audits)) - _convert_standalone_audits(ctx, input_paths, output_paths, console) - console.complete_audits_conversion() - - # DBT Macros -> SQLMesh Jinja Macros - all_macros = list( - iterate_macros(input_paths.macros, output_paths.macros, dbt_load_context, ctx) - ) - console.start_macros_conversion(len(all_macros)) - for package, macro_text, input_id, output_file_path, should_transform in all_macros: - console.start_macro_conversion(input_id) - - output_file_path.parent.mkdir(parents=True, exist_ok=True) - converted = ( - convert_jinja_macro(ctx, macro_text, package) if should_transform else macro_text - ) - output_file_path.write_text(converted, encoding="utf8") - - console.complete_macro_conversion() - - console.complete_macros_conversion() - - # Generate SQLMesh config - # TODO: read all profiles from config and convert to gateways instead of just the current profile? - console.log_message("Writing SQLMesh config") - new_config = _generate_sqlmesh_config(ctx, dbt_project, dbt_load_context) - (dest / "config.yml").write_text(yaml.dump(new_config)) - - if report.self_referencing_models: - console.output_report(report) - - console.log_message("All done") - - -def _load_project(src: Path) -> t.Tuple[Context, Project]: - config = sqlmesh_config(project_root=src) - - ctx = Context(config=config, paths=src) - - dbt_loader = ctx._loaders[0] - assert isinstance(dbt_loader, DbtLoader) - - dbt_project = dbt_loader._projects[0] - - return ctx, dbt_project - - -def _ensure_paths( - src: Path, dest: Path, console: DbtConversionConsole, no_prompts: bool -) -> t.Tuple[InputPaths, OutputPaths]: - if not dest.exists(): - console.log_message(f"Creating output directory: {dest}") - dest.mkdir() - - if dest.is_file(): - raise ValueError(f"Output path must be a directory") - - if any(dest.iterdir()): - if not no_prompts and console.prompt_clear_directory("Output directory ", dest): - for path in dest.glob("**/*"): - if path.is_file(): - path.unlink() - elif path.is_dir(): - shutil.rmtree(path) - console.log_message(f"Output directory '{dest}' cleared") - else: - raise ValueError("Please ensure the output directory is empty") - - input_paths = InputPaths(src) - output_paths = OutputPaths(dest) - - for dir in (output_paths.models, output_paths.seeds, output_paths.audits, output_paths.macros): - dir.mkdir() - - return input_paths, output_paths - - -def _convert_models( - ctx: Context, - input_paths: InputPaths, - output_paths: OutputPaths, - report: ConversionReport, - console: DbtConversionConsole, -) -> None: - # Iterating in DAG order helps minimize re-rendering when the fingerprint cache is busted when we call upsert_model() to check if - # a self-referencing model has all its columns_to_types known or not - for fqn in ctx.dag: - model = ctx.models.get(fqn) - - if not model: - # some entries in the dag are not models - continue - - model_name = fqn - - # todo: support DBT model_paths[] being not `models` or being a list - # todo: write out column_descriptions() into model block - console.start_model_conversion(model_name) - - if model.kind.is_external: - # skip external models - # they can be created with `sqlmesh create_external_models` post-conversion - console.complete_model_conversion() # still advance the progress bar - continue - - if model.kind.is_seed: - # this will produce the original seed file, eg "items.csv" - if model._path is None: - raise ValueError(f"Unhandled model path for model {model_name}") - seed_filename = model._path.relative_to(input_paths.seeds) - - # seed definition - rename "items.csv" -> "items.sql" - model_filename = seed_filename.with_suffix(".sql") - - # copy the seed data itself to the seeds dir - shutil.copyfile(model._path, output_paths.seeds / seed_filename) - - # monkeypatch the model kind to have a relative reference to the seed file - assert isinstance(model.kind, SeedKind) - model.kind.path = str(Path("../seeds", seed_filename)) - else: - if model._path is None: - raise ValueError(f"Unhandled model path for model {model_name}") - if input_paths.models in model._path.parents: - model_filename = model._path.relative_to(input_paths.models) - elif input_paths.snapshots in model._path.parents: - # /base/path/snapshots/foo.sql -> /output/path/models/dbt_snapshots/foo.sql - model_filename = "dbt_snapshots" / model._path.relative_to(input_paths.snapshots) - elif input_paths.packages in model._path.parents: - model_filename = c.MIGRATED_DBT_PACKAGES / model._path.relative_to( - input_paths.packages - ) - else: - raise ValueError(f"Unhandled model path: {model._path}") - - # todo: a SQLGLot transform on `audits` in the model definition to lowercase the names? - model_output_path = output_paths.models / model_filename - model_output_path.parent.mkdir(parents=True, exist_ok=True) - model_package = infer_dbt_package_from_path(model_output_path) - - def _render(e: d.exp.Expression) -> str: - if isinstance(e, d.Jinja): - e = convert_jinja_query(ctx, model, e, model_package) - rendered = e.sql(dialect=model.dialect, pretty=True) - if not isinstance(e, d.Jinja): - rendered += ";" - return rendered - - model_to_render = model.model_copy( - update=dict(depends_on_=None if len(model.depends_on) > 0 else set()) - ) - if isinstance(model, (SqlModel, SeedModel)): - # Keep depends_on for SQL Models because sometimes the entire query is a macro call. - # If we clear it and rely on inference, the SQLMesh native loader will throw: - # - ConfigError: Dependencies must be provided explicitly for models that can be rendered only at runtime - model_to_render = model.model_copy( - update=dict(depends_on_=resolve_fqns_to_model_names(ctx, model.depends_on)) - ) - - rendered_queries = [ - _render(q) - for q in model_to_render.render_definition(render_query=False, include_python=False) - ] - - # add inline audits - # todo: handle these better - # maybe output generic audits for the 4 DBT audits (not_null, unique, accepted_values, relationships) and emit definitions for them? - for _, audit in model.audit_definitions.items(): - rendered_queries.append("\n" + _render(d.parse_one(f"AUDIT (name {audit.name})"))) - # todo: or do we want the original? - rendered_queries.append(_render(model.render_audit_query(audit))) - - model_definition = "\n".join(rendered_queries) - - model_output_path.write_text(model_definition) - - console.complete_model_conversion() - - -def _convert_standalone_audits( - ctx: Context, input_paths: InputPaths, output_paths: OutputPaths, console: DbtConversionConsole -) -> None: - for _, audit in ctx.standalone_audits.items(): - console.start_audit_conversion(audit.name) - audit_definition = audit.render_definition(include_python=False) - - stringified = [] - for expression in audit_definition: - if isinstance(expression, d.JinjaQuery): - expression = convert_jinja_query(ctx, audit, expression) - stringified.append(expression.sql(dialect=audit.dialect, pretty=True)) - - audit_definition_string = ";\n".join(stringified) - - if audit._path is None: - continue - audit_filename = audit._path.relative_to(input_paths.tests) - audit_output_path = output_paths.audits / audit_filename - audit_output_path.write_text(audit_definition_string) - console.complete_audit_conversion() - return None - - -def _generate_sqlmesh_config( - ctx: Context, dbt_project: Project, dbt_load_context: DbtContext -) -> t.Dict[str, t.Any]: - DEFAULT_ARGS: t.Dict[str, t.Any] - from sqlmesh.utils.pydantic import DEFAULT_ARGS - - base_config = ctx.config.model_dump( - mode="json", include={"gateways", "model_defaults", "variables"}, **DEFAULT_ARGS - ) - # Extend with the variables loaded from DBT - if "variables" not in base_config: - base_config["variables"] = {} - if c.MIGRATED_DBT_PACKAGES not in base_config["variables"]: - base_config["variables"][c.MIGRATED_DBT_PACKAGES] = {} - - # this is used when loading with the native loader to set the package name for top level macros - base_config["variables"][c.MIGRATED_DBT_PROJECT_NAME] = dbt_project.context.project_name - - migrated_package_names = [] - for package in dbt_project.packages.values(): - dbt_load_context.set_and_render_variables(package.variables, package.name) - - if package.name == dbt_project.context.project_name: - base_config["variables"].update(dbt_load_context.variables) - else: - base_config["variables"][c.MIGRATED_DBT_PACKAGES][package.name] = ( - dbt_load_context.variables - ) - migrated_package_names.append(package.name) - - for package_name in migrated_package_names: - # these entries are duplicates because the DBT loader already applies any project specific overrides to the - # package level variables - base_config["variables"].pop(package_name, None) - - return base_config - - -def iterate_macros( - input_macros_dir: Path, output_macros_dir: Path, dbt_load_context: DbtContext, ctx: Context -) -> t.Iterator[t.Tuple[t.Optional[str], str, str, Path, bool]]: - """ - Return an iterator over all the macros that need to be migrated - - The main project level ones are read from the source macros directory (it's assumed these are written by the user) - - The rest / library level ones are read from the DBT manifest based on merging together all the model JinjaMacroRegistry's from the SQLMesh context - """ - - all_macro_references = set() - - for dirpath, _, files in os.walk( - input_macros_dir - ): # note: pathlib doesnt have a walk function until python 3.12 - for name in files: - if name.lower().endswith(".sql"): - input_file_path = Path(dirpath) / name - - output_file_path = output_macros_dir / ( - input_file_path.relative_to(input_macros_dir) - ) - - input_file_contents = input_file_path.read_text(encoding="utf8") - - # as we migrate user-defined macros, keep track of other macros they reference from other packages/libraries - # so we can be sure theyre included - # (since there is no guarantee a model references a user-defined macro which means the dependencies may not be pulled in automatically) - macro_refs, _ = extract_macro_references_and_variables( - input_file_contents, dbt_target_name=dbt_load_context.target_name - ) - all_macro_references.update(macro_refs) - - yield ( - None, - input_file_contents, - str(input_file_path), - output_file_path, - True, - ) - - jmr = JinjaMacroRegistry() - for model in ctx.models.values(): - jmr = jmr.merge(model.jinja_macros) - - # add any macros that are referenced in user macros but not necessarily directly in models - # this can happen if a user has defined a macro that is currently unused in a model but we still want to migrate it - jmr = jmr.merge( - dbt_load_context.jinja_macros.trim( - all_macro_references, package=dbt_load_context.project_name - ) - ) - - for package, name, macro in jmr.all_macros: - if package and package != dbt_load_context.project_name: - output_file_path = output_macros_dir / c.MIGRATED_DBT_PACKAGES / package / f"{name}.sql" - - yield ( - package, - macro.definition, - f"{package}.{name}", - output_file_path, - "var(" in macro.definition, # todo: check for ref() etc as well? - ) - - -def resolve_fqns_to_model_names(ctx: Context, fqns: t.Set[str]) -> t.Set[str]: - # model.depends_on is provided by the DbtLoader as a list of fully qualified table name strings - # if we output them verbatim, when loading them back we get errors like: - # - ConfigError: Failed to load model definition: 'Dot' object has no attribute 'catalog' - # So we need to resolve them to model names instead. - # External models also need to be excluded because the "name" is still a FQN string so cause the above error - - return { - ctx.models[i].name for i in fqns if i in ctx.models and not ctx.models[i].kind.is_external - } diff --git a/sqlmesh/dbt/converter/jinja.py b/sqlmesh/dbt/converter/jinja.py deleted file mode 100644 index 783ae5a74f..0000000000 --- a/sqlmesh/dbt/converter/jinja.py +++ /dev/null @@ -1,604 +0,0 @@ -import typing as t -import jinja2.nodes as j -import sqlmesh.core.dialect as d -from sqlmesh.core.context import Context -from sqlmesh.core.snapshot import Node -from sqlmesh.core.model import SqlModel, load_sql_based_model -from sqlglot import exp -from sqlmesh.dbt.converter.common import JinjaTransform -from inspect import signature -from more_itertools import windowed -from itertools import chain -from sqlmesh.dbt.context import DbtContext -import sqlmesh.dbt.converter.jinja_transforms as jt -from sqlmesh.utils.errors import ConfigError -from sqlmesh.utils.jinja import SQLMESH_DBT_COMPATIBILITY_PACKAGE - -# for j.Operand.op -OPERATOR_MAP = { - "eq": "==", - "ne": "!=", - "lt": "<", - "gt": ">", - "lteq": "<=", - "gteq": ">=", - "in": "in", - "notin": "not in", -} - - -def lpad_windowed(iterable: t.Iterable[j.Node]) -> t.Iterator[t.Tuple[t.Optional[j.Node], j.Node]]: - for prev, curr in windowed(chain([None], iterable), 2): - if curr is None: - raise ValueError("Current item cannot be None") - yield prev, curr - - -class JinjaGenerator: - def generate( - self, node: j.Node, prev: t.Optional[j.Node] = None, parent: t.Optional[j.Node] = None - ) -> str: - if not isinstance(node, j.Node): - raise ValueError(f"Generator only works with Jinja AST nodes, not: {type(node)}") - - acc = "" - - node_type = type(node) - generator_fn_name = f"_generate_{node_type.__name__.lower()}" - - if generator_fn := getattr(self, generator_fn_name, None): - sig = signature(generator_fn) - kwargs: t.Dict[str, t.Optional[j.Node]] = {"node": node} - if "prev" in sig.parameters: - kwargs["prev"] = prev - if "parent" in sig.parameters: - kwargs["parent"] = parent - acc += generator_fn(**kwargs) - else: - raise NotImplementedError(f"Generator for node type '{type(node)}' is not implemented") - - return acc - - def _generate_template(self, node: j.Template) -> str: - acc = [] - for prev, curr in lpad_windowed(node.body): - if curr: - acc.append(self.generate(curr, prev, node)) - - return "".join(acc) - - def _generate_output(self, node: j.Output) -> str: - acc = [] - for prev, curr in lpad_windowed(node.nodes): - acc.append(self.generate(curr, prev, node)) - - return "".join(acc) - - def _generate_templatedata(self, node: j.TemplateData) -> str: - return node.data - - def _generate_name( - self, node: j.Name, prev: t.Optional[j.Node], parent: t.Optional[j.Node] - ) -> str: - return self._wrap_in_expression_if_necessary(node.name, prev, parent) - - def _generate_getitem( - self, node: j.Getitem, prev: t.Optional[j.Node], parent: t.Optional[j.Node] - ) -> str: - item_name = self.generate(node.node, parent=node) - if node.arg: - if node.node.find(j.Filter): - # for when someone has {{ (foo | bar | baz)[0] }} - item_name = f"({item_name})" - item_name = f"{item_name}[{self.generate(node.arg, parent=node)}]" - - return self._wrap_in_expression_if_necessary(item_name, prev, parent) - - def _generate_getattr( - self, node: j.Getattr, prev: t.Optional[j.Node], parent: t.Optional[j.Node] - ) -> str: - what_str = self.generate(node.node, parent=node) - - return self._wrap_in_expression_if_necessary(f"{what_str}.{node.attr}", prev, parent) - - def _generate_const( - self, node: j.Const, prev: t.Optional[j.Node], parent: t.Optional[j.Node] - ) -> str: - quotechar = "" - node_value: str - if isinstance(node.value, str): - quotechar = "'" if "'" not in node.value else '"' - node_value = node.value - else: - node_value = str(node.value) - - const_value = quotechar + node_value + quotechar - - return self._wrap_in_expression_if_necessary(const_value, prev, parent) - - def _generate_keyword(self, node: j.Keyword) -> str: - return node.key + "=" + self.generate(node.value, parent=node) - - def _generate_test(self, node: j.Test, parent: t.Optional[j.Node]) -> str: - var_name = self.generate(node.node, parent=node) - test = "is" if not isinstance(parent, j.Not) else "is not" - if node.name: - return f"{var_name} {test} {node.name}" - return var_name - - def _generate_assign(self, node: j.Assign) -> str: - target_str = self.generate(node.target, parent=node) - what_str = self.generate(node.node, parent=node) - return "{% set " + target_str + " = " + what_str + " %}" - - def _generate_assignblock(self, node: j.AssignBlock) -> str: - target_str = self.generate(node.target, parent=node) - body_str = "".join(self.generate(c, parent=node) for c in node.body) - # todo: node.filter? - return "{% set " + target_str + " %}" + body_str + "{% endset %}" - - def _generate_call( - self, node: j.Call, prev: t.Optional[j.Node], parent: t.Optional[j.Node] - ) -> str: - call_name = self.generate(node.node, parent=node) - call_args = ", ".join(self.generate(a, parent=node) for a in node.args) - call_kwargs = ", ".join(self.generate(a, parent=node) for a in node.kwargs) - sep = ", " if call_args and call_kwargs else "" - call_str = call_name + f"({call_args}{sep}{call_kwargs})" - - return self._wrap_in_expression_if_necessary(call_str, prev, parent) - - def _generate_if(self, node: j.If, parent: t.Optional[j.Node]) -> str: - test_str = self.generate(node.test, parent=node) - body_str = "".join(self.generate(c, parent=node) for c in node.body) - elifs_str = "".join(self.generate(c, parent=node) for c in node.elif_) - elses_str = "".join(self.generate(c, parent=node) for c in node.else_) - - end_block_name: t.Optional[str] - block_name, end_block_name = "if", "endif" - if isinstance(parent, j.If): - if node in parent.elif_: - block_name, end_block_name = "elif", None - - end_block = "{% " + end_block_name + " %}" if end_block_name else "" - - elses_str = "{% else %}" + elses_str if elses_str else "" - - return ( - "{% " - + block_name - + " " - + test_str - + " %}" - + body_str - + elifs_str - + elses_str - + end_block - ) - - def _generate_macro(self, node: j.Macro, prev: t.Optional[j.Node]) -> str: - name_str = node.name - rendered_defaults = list(reversed([self.generate(d, parent=node) for d in node.defaults])) - rendered_args = [self.generate(a, parent=node) for a in node.args] - - # the defaults, if they exist, line up with the last arguments in the list - # so we reverse the lists to match the arrays and then reverse the result to get the original order - args_with_defaults = [ - (arg, next(iter(rendered_defaults[idx : idx + 1]), None)) - for idx, arg in enumerate(reversed(rendered_args)) - ] - args_with_defaults = list(reversed(args_with_defaults)) - - args_str = ", ".join(f"{a}={d}" if d is not None else a for a, d in args_with_defaults) - body_str = "".join(self.generate(c, parent=node) for c in node.body) - - # crude sql comment detection that will cause false positives that hopefully shouldnt matter - # this is to work around a WONTFIX bug in the SQLGlot tokenizer that if the macro body contains a SQL comment - # and {% endmacro %} is on the same line, it gets included as comment instead of a proper token - # the bug also occurs if the {% macro %} tag is on a line that starts with a SQL comment - start_tag = "{% macro " - if prev: - prev_str = self.generate(prev) - if "--" in prev_str and not prev_str.rstrip(" ").endswith("\n"): - start_tag = "\n" + start_tag - - end_tag = "{% endmacro %}" - if "--" in body_str and not body_str.rstrip(" ").endswith("\n"): - end_tag = "\n" + end_tag - - return start_tag + name_str + "(" + args_str + ")" + " %}" + body_str + end_tag - - def _generate_for(self, node: j.For) -> str: - target_str = self.generate(node.target, parent=node) - iter_str = self.generate(node.iter, parent=node) - test_str = "if " + self.generate(node.test, parent=node) if node.test else None - body_str = "".join(self.generate(c, parent=node) for c in node.body) - - acc = "{% for " + target_str + " in " + iter_str - if test_str: - acc += f" {test_str}" - acc += " %}" - acc += body_str - acc += "{% endfor %}" - - return acc - - def _generate_list(self, node: j.List, parent: t.Optional[j.Node]) -> str: - items_str_array = [self.generate(i, parent=node) for i in node.items] - items_on_newline = ( - not isinstance(parent, j.Pair) - and len(items_str_array) > 1 - and any(len(i) > 50 for i in items_str_array) - ) - item_separator = "\n\t" if items_on_newline else " " - items_str = f",{item_separator}".join(items_str_array) - start_separator = "\n\t" if items_on_newline else "" - end_separator = "\n" if items_on_newline else "" - return f"[{start_separator}{items_str}{end_separator}]" - - def _generate_dict(self, node: j.Dict) -> str: - items_str = ", ".join(self.generate(c, parent=node) for c in node.items) - return "{ " + items_str + " }" - - def _generate_pair(self, node: j.Pair) -> str: - key_str = self.generate(node.key, parent=node) - value_str = self.generate(node.value, parent=node) - return f"{key_str}: {value_str}" - - def _generate_not(self, node: j.Not) -> str: - if isinstance(node.node, j.Test): - return self.generate(node.node, parent=node) - - return self.__generate_unaryexp(node) - - def _generate_neg(self, node: j.Neg) -> str: - return self.__generate_unaryexp(node) - - def _generate_pos(self, node: j.Pos) -> str: - return self.__generate_unaryexp(node) - - def _generate_compare(self, node: j.Compare) -> str: - what_str = self.generate(node.expr, parent=node) - - # todo: is this correct? need to test with multiple ops - ops_str = "".join(self.generate(o, parent=node) for o in node.ops) - - return f"{what_str} {ops_str}" - - def _generate_slice(self, node: j.Slice) -> str: - start_str = self.generate(node.start, parent=node) if node.start else "" - stop_str = self.generate(node.stop, parent=node) if node.stop else "" - # todo: need a syntax example of step - return f"{start_str}:{stop_str}" - - def _generate_operand(self, node: j.Operand) -> str: - assert isinstance(node, j.Operand) - value_str = self.generate(node.expr, parent=node) - - return f"{OPERATOR_MAP[node.op]} " + value_str - - def _generate_add(self, node: j.Add, parent: t.Optional[j.Node]) -> str: - return self.__generate_binexp(node, parent) - - def _generate_mul(self, node: j.Mul, parent: t.Optional[j.Node]) -> str: - return self.__generate_binexp(node, parent) - - def _generate_div(self, node: j.Div, parent: t.Optional[j.Node]) -> str: - return self.__generate_binexp(node, parent) - - def _generate_sub(self, node: j.Sub, parent: t.Optional[j.Node]) -> str: - return self.__generate_binexp(node, parent) - - def _generate_floordiv(self, node: j.FloorDiv, parent: t.Optional[j.Node]) -> str: - return self.__generate_binexp(node, parent) - - def _generate_mod(self, node: j.Mod, parent: t.Optional[j.Node]) -> str: - return self.__generate_binexp(node, parent) - - def _generate_pow(self, node: j.Pow, parent: t.Optional[j.Node]) -> str: - return self.__generate_binexp(node, parent) - - def _generate_or(self, node: j.Or, parent: t.Optional[j.Node]) -> str: - return self.__generate_binexp(node, parent) - - def _generate_and(self, node: j.And, parent: t.Optional[j.Node]) -> str: - return self.__generate_binexp(node, parent) - - def _generate_concat(self, node: j.Concat) -> str: - return " ~ ".join(self.generate(c, parent=node) for c in node.nodes) - - def _generate_tuple(self, node: j.Tuple, parent: t.Optional[j.Node]) -> str: - parenthesis = isinstance(parent, (j.Operand, j.Call)) - items_str = ", ".join(self.generate(i, parent=node) for i in node.items) - return items_str if not parenthesis else f"({items_str})" - - def _generate_filter( - self, node: j.Filter, prev: t.Optional[j.Node], parent: t.Optional[j.Node] - ) -> str: - # node.node may be None if this Filter is part of a FilterBlock - what_str = self.generate(node.node, parent=node) if node.node else None - if isinstance(node.node, j.CondExpr): - what_str = f"({what_str})" - - args_str = ", ".join(self.generate(a, parent=node) for a in node.args + node.kwargs) - if args_str: - args_str = f"({args_str})" - - filter_expr = f"{node.name}{args_str}" - if what_str: - filter_expr = f"{what_str} | {filter_expr}" - - return self._wrap_in_expression_if_necessary(filter_expr, prev=prev, parent=parent) - - def _generate_filterblock(self, node: j.FilterBlock) -> str: - filter_str = self.generate(node.filter, parent=node) - body_str = "".join(self.generate(c, parent=node) for c in node.body) - return "{% filter " + filter_str + " %}" + body_str + "{% endfilter %}" - - def _generate_exprstmt(self, node: j.ExprStmt) -> str: - node_str = self.generate(node.node, parent=node) - return "{% do " + node_str + " %}" - - def _generate_condexpr( - self, node: j.CondExpr, prev: t.Optional[j.Node], parent: t.Optional[j.Node] - ) -> str: - test_sql = self.generate(node.test, parent=node) - expr1_sql = self.generate(node.expr1, parent=node) - - if node.expr2 is None: - raise ValueError("CondExpr lacked an 'else', not sure how to handle this") - - expr2_sql = self.generate(node.expr2, parent=node) - return self._wrap_in_expression_if_necessary( - f"{expr1_sql} if {test_sql} else {expr2_sql}", prev, parent - ) - - def __generate_binexp(self, node: j.BinExpr, parent: t.Optional[j.Node]) -> str: - left_str = self.generate(node.left, parent=node) - right_str = self.generate(node.right, parent=node) - - wrap_left = isinstance(node.left, j.BinExpr) - wrap_right = isinstance(node.right, j.BinExpr) - - acc = f"({left_str})" if wrap_left else left_str - acc += f" {node.operator} " - acc += f"({right_str})" if wrap_right else right_str - - return acc - - def __generate_unaryexp(self, node: j.UnaryExpr) -> str: - body_str = self.generate(node.node, parent=node) - return f"{node.operator} {body_str}" - - def _generate_nsref(self, node: j.NSRef) -> str: - return f"{node.name}.{node.attr}" - - def _generate_callblock(self, node: j.CallBlock) -> str: - call = self.generate(node.call, parent=node) - body = "".join(self.generate(e, parent=node) for e in node.body) - args = ", ".join(self.generate(arg, parent=node) for arg in node.args) - - open_tag = "{% call" - - if args: - open_tag += "(" + args + ")" - - if len(node.defaults) > 0: - raise NotImplementedError("Not sure how to handle CallBlock.defaults") - - return open_tag + " " + call + " %}" + body + "{% endcall %}" - - def _wrap_in_expression_if_necessary( - self, string: str, prev: t.Optional[j.Node], parent: t.Optional[j.Node] - ) -> str: - wrap = False - if isinstance(prev, j.TemplateData): - wrap = True - elif prev is None and isinstance(parent, j.Output): - wrap = True - elif parent: - # if the node is nested inside eg an {% if %} block, dont wrap it in {{ }} - wrap = not any(isinstance(parent, t) for t in (j.Operand, j.Stmt, j.Expr, j.Helper)) - - return "{{ " + string + " }}" if wrap else string - - -def _contains_jinja(query: str) -> bool: - if "{{" in query: - return True - if "{%" in query: - return True - return False - - -def transform(base: j.Node, handler: JinjaTransform) -> j.Node: - sig = signature(handler) - - def _build_handler_kwargs( - node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node] - ) -> t.Dict[str, t.Any]: - kwargs: t.Dict[str, t.Optional[j.Node]] = {"node": node} - if "prev" in sig.parameters: - kwargs["prev"] = prev - if "parent" in sig.parameters: - kwargs["parent"] = parent - return kwargs - - def _transform( - node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node] - ) -> t.Optional[j.Node]: - transformed_node: t.Optional[j.Node] = handler(**_build_handler_kwargs(node, prev, parent)) # type: ignore - - if not transformed_node: - return None - - node = transformed_node - - new_children: t.Dict[j.Node, t.Optional[j.Node]] = {} - prev = None - for child in list(node.iter_child_nodes()): - transformed_child = _transform(node=child, prev=prev, parent=node) - if transformed_child != child: - new_children[child] = transformed_child - prev = child - - if new_children: - replacement_fields: t.Dict[str, t.Union[j.Node, t.List[j.Node]]] = {} - for name, value in node.iter_fields(): - assert isinstance(name, str) - - if isinstance(value, list): - replacement_value_list = [new_children.get(i, i) for i in value] - replacement_fields[name] = [r for r in replacement_value_list if r is not None] - elif isinstance(value, j.Node): - replacement_value = new_children.get(value) or value - replacement_fields[name] = replacement_value - for name, value in replacement_fields.items(): - setattr(node, name, value) - - return node - - transformed = _transform(node=base, prev=None, parent=None) - if transformed is None: - raise ValueError( - f"Transform '{handler.__name__}' consumed the entire AST; this indicates a bug" - ) - return transformed - - -def convert_jinja_query( - context: Context, - node: Node, - query: d.Jinja, - package: t.Optional[str] = None, - exclude: t.Optional[t.List[t.Callable]] = None, -) -> t.Union[d.JinjaQuery, d.JinjaStatement, exp.Query, exp.DDL]: - jinja_env = node.jinja_macros.build_environment() - - ast: j.Node = jinja_env.parse(query.text("this")) # type: ignore - - transforms = [ - # transform {{ ref("foo") }} -> schema.foo (NOT "fully_qualified"."schema"."foo") - jt.resolve_dbt_ref_to_model_name(context.models, jinja_env, node.dialect), - # Rewrite ref() calls that cant be converted to strings (maybe theyre macro aguments) to __migrated_ref() calls - jt.rewrite_dbt_ref_to_migrated_ref(context.models, jinja_env, node.dialect), - # transform {{ source("upstream"."foo") }} -> upstream.foo (NOT "fully_qualified"."upstream"."foo") - jt.resolve_dbt_source_to_model_name(context.models, jinja_env, node.dialect), - # Rewrite source() calls that cant be converted to strings (maybe theyre macro aguments) to __migrated_source() calls - jt.rewrite_dbt_source_to_migrated_source(context.models, jinja_env, node.dialect), - # transform {{ this }} -> model.name - jt.resolve_dbt_this_to_model_name(node.name), - # deuplicate where both {% if sqlmesh_incremental %} and {% if is_incremental() %} are used - jt.deduplicate_incremental_checks(), - # unpack {% if is_incremental() %} blocks because they arent necessary when running a native project - jt.unpack_incremental_checks(), - ] - - if package: - transforms.append(jt.append_dbt_package_kwarg_to_var_calls(package)) - - transforms = [ - t for t in transforms if not any(e.__name__ in t.__name__ for e in (exclude or [])) - ] - - for handler in transforms: - ast = transform(ast, handler) - - generator = JinjaGenerator() - pre_post_processing = generator.generate(ast) - if isinstance(node, SqlModel) and isinstance(query, d.JinjaQuery) and not node.depends_on_self: - # is it self-referencing now is_incremental() has been removed? - # if so, and columns_to_types are not all known, then we can't remove is_incremental() or we will get a load error - - # try to load the converted model with the native loader - model_definition = node.copy(update=dict(audits=[])).render_definition()[0].sql() - - # we need the Jinja builtins that inclide the compatibility shims because the transforms may have created eg __migrated_ref() calls - jinja_macros = node.jinja_macros.copy( - update=dict(create_builtins_module=SQLMESH_DBT_COMPATIBILITY_PACKAGE) - ) - - converted_node = load_sql_based_model( - expressions=[d.parse_one(model_definition), d.JinjaQuery(this=pre_post_processing)], - jinja_macros=jinja_macros, - defaults=context.config.model_defaults.dict(), - default_catalog=node.default_catalog, - ) - original_model = context.models[node.fqn] - - if converted_node.depends_on_self: - try: - # we need to upsert the model into the context to trigger columns_to_types inference - # note that this can sometimes bust the optimized query cache which can lead to long pauses converting some models in large projects - context.upsert_model(converted_node) - except ConfigError as e: - if "Self-referencing models require inferrable column types" in str(e): - # we have a self-referencing model where the columns_to_types cannot be inferred - # run the conversion again without the unpack_incremental_checks transform - return convert_jinja_query( - context, node, query, exclude=[jt.unpack_incremental_checks] - ) - raise - except Exception: - # todo: perhaps swallow this so that we just continue on with the original logic - raise - finally: - context.upsert_model(original_model) # put the original model definition back - - ast = transform(ast, jt.rewrite_sqlmesh_predefined_variables_to_sqlmesh_macro_syntax()) - post_processed = generator.generate(ast) - - # post processing - have we removed all the jinja so this can effectively be a normal SQL query? - if not _contains_jinja(post_processed): - parsed = d.parse_one(post_processed, dialect=node.dialect) - - # converting DBT '{{ start_ds }}' to a SQLMesh macro results in single quoted '@start_ds' but we really need unquoted @start_ds - transformed = parsed.transform(jt.unwrap_macros_in_string_literals()) - if isinstance(transformed, (exp.Query, exp.DDL)): - return transformed - - raise ValueError( - f"Transformation resulted in a {type(transformed)} node instead of Query / DDL statement" - ) - - if isinstance(query, d.JinjaQuery): - return d.JinjaQuery(this=pre_post_processing) - if isinstance(query, d.JinjaStatement): - return d.JinjaStatement(this=pre_post_processing) - - raise ValueError(f"Not sure how to handle: {type(query)}") - - -def convert_jinja_macro(context: Context, src: str, package: t.Optional[str] = None) -> str: - jinja_macros = DbtContext().jinja_macros # ensures the correct create_builtins_module is set - jinja_macros = jinja_macros.merge(context._jinja_macros) - - jinja_env = jinja_macros.build_environment() - - dialect = context.default_dialect - if not dialect: - raise ValueError("No project dialect configured?") - - transforms = [ - # transform {{ ref("foo") }} -> schema.foo (NOT "fully_qualified"."schema"."foo") - jt.resolve_dbt_ref_to_model_name(context.models, jinja_env, dialect), - # Rewrite ref() calls that cant be converted to strings (maybe theyre macro aguments) to __migrated_ref() calls - jt.rewrite_dbt_ref_to_migrated_ref(context.models, jinja_env, dialect), - # transform {{ source("foo", "bar") }} -> `qualified`.`foo`.`bar` - jt.resolve_dbt_source_to_model_name(context.models, jinja_env, dialect), - # transform {{ var('foo') }} -> {{ var('foo', __dbt_package='') }} - jt.append_dbt_package_kwarg_to_var_calls(package), - # deduplicate where both {% if sqlmesh_incremental %} and {% if is_incremental() %} are used - jt.deduplicate_incremental_checks(), - # unpack {% if sqlmesh_incremental %} blocks because they arent necessary when running a native project - jt.unpack_incremental_checks(), - ] - - ast: j.Node = jinja_env.parse(src) - - for handler in transforms: - ast = transform(ast, handler) - - generator = JinjaGenerator() - - return generator.generate(ast) diff --git a/sqlmesh/dbt/converter/jinja_builtins.py b/sqlmesh/dbt/converter/jinja_builtins.py deleted file mode 100644 index 59303ad344..0000000000 --- a/sqlmesh/dbt/converter/jinja_builtins.py +++ /dev/null @@ -1,109 +0,0 @@ -import typing as t -import functools -from sqlmesh.utils.jinja import JinjaMacroRegistry -from dbt.adapters.base.relation import BaseRelation -from sqlmesh.dbt.builtin import Api -from sqlmesh.core.engine_adapter import EngineAdapter -from sqlmesh.utils.errors import ConfigError -from dbt.adapters.base import BaseRelation -from sqlglot import exp - -from dbt.adapters.base import BaseRelation - - -def migrated_ref( - dbt_api: Api, - database: t.Optional[str] = None, - schema: t.Optional[str] = None, - identifier: t.Optional[str] = None, - version: t.Optional[int] = None, - sqlmesh_model_name: t.Optional[str] = None, -) -> BaseRelation: - if version: - raise ValueError("dbt model versions are not supported in converted projects.") - - return dbt_api.Relation.create(database=database, schema=schema, identifier=identifier) - - -def migrated_source( - dbt_api: Api, - database: t.Optional[str] = None, - schema: t.Optional[str] = None, - identifier: t.Optional[str] = None, -) -> BaseRelation: - return dbt_api.Relation.create(database=database, schema=schema, identifier=identifier) - - -def create_builtin_globals( - jinja_macros: JinjaMacroRegistry, - global_vars: t.Dict[str, t.Any], - engine_adapter: t.Optional[EngineAdapter], - *args: t.Any, - **kwargs: t.Any, -) -> t.Dict[str, t.Any]: - import sqlmesh.utils.jinja as sqlmesh_native_jinja - import sqlmesh.dbt.builtin as sqlmesh_dbt_jinja - - # Capture dialect before the dbt builtins pops it - dialect = global_vars.get("dialect") - - sqlmesh_native_globals = sqlmesh_native_jinja.create_builtin_globals( - jinja_macros, global_vars, *args, **kwargs - ) - - if this_model := global_vars.get("this_model"): - # create a DBT-compatible version of @this_model for {{ this }} - if isinstance(this_model, str): - if not dialect: - raise ConfigError("No dialect?") - - # in audits, `this_model` is a SQL SELECT query that selects from the current table - # elsewhere, it's a fqn string - parsed: exp.Expression = exp.maybe_parse(this_model, dialect=dialect) - - table: t.Optional[exp.Table] = None - if isinstance(parsed, exp.Column): - table = exp.to_table(this_model, dialect=dialect) - elif isinstance(parsed, exp.Query): - table = parsed.find(exp.Table) - else: - raise ConfigError(f"Not sure how to handle this_model: {this_model}") - - if table: - # sqlmesh_dbt_jinja.create_builtin_globals() will construct a Relation for {{ this }} based on the supplied dict - global_vars["this"] = { - "database": table.catalog, - "schema": table.db, - "identifier": table.name, - } - - else: - raise ConfigError(f"Unhandled this_model type: {type(this_model)}") - - sqlmesh_dbt_globals = sqlmesh_dbt_jinja.create_builtin_globals( - jinja_macros, global_vars, engine_adapter, *args, **kwargs - ) - - def source(dbt_api: Api, source_name: str, table_name: str) -> BaseRelation: - # some source() calls cant be converted to __migrated_source() calls because they contain dynamic parameters - # this is a fallback and will be wrong in some situations because `sources` in DBT can be aliased in config - # TODO: maybe we migrate sources into the SQLMesh variables so we can look them up here? - return dbt_api.Relation.create(database=source_name, identifier=table_name) - - def ref(dbt_api: Api, ref_name: str, package: t.Optional[str] = None) -> BaseRelation: - # some ref() calls cant be converted to __migrated_ref() calls because they contain dynamic parameters - raise NotImplementedError( - f"Unable to resolve ref: {ref_name}. Please replace it with an actual model name or use a SQLMesh macro to generate dynamic model name." - ) - - dbt_compatibility_shims = { - "dialect": dialect, - "__migrated_ref": functools.partial(migrated_ref, sqlmesh_dbt_globals["api"]), - "__migrated_source": functools.partial(migrated_source, sqlmesh_dbt_globals["api"]), - "source": functools.partial(source, sqlmesh_dbt_globals["api"]), - "ref": functools.partial(ref, sqlmesh_dbt_globals["api"]), - # make {{ config(...) }} a no-op, some macros call it but its meaningless in a SQLMesh Native project - "config": lambda *_args, **_kwargs: None, - } - - return {**sqlmesh_native_globals, **sqlmesh_dbt_globals, **dbt_compatibility_shims} diff --git a/sqlmesh/dbt/converter/jinja_transforms.py b/sqlmesh/dbt/converter/jinja_transforms.py deleted file mode 100644 index 4c4cf03edc..0000000000 --- a/sqlmesh/dbt/converter/jinja_transforms.py +++ /dev/null @@ -1,465 +0,0 @@ -import typing as t -from types import MappingProxyType -from sqlmesh.core.model import Model -from jinja2 import Environment -import jinja2.nodes as j -from sqlmesh.dbt.converter.common import ( - SQLMESH_PREDEFINED_MACRO_VARIABLES, - JinjaTransform, - SQLGlotTransform, -) -from dbt.adapters.base.relation import BaseRelation -from sqlmesh.core.dialect import normalize_model_name -from sqlglot import exp -import sqlmesh.core.dialect as d -from functools import wraps - - -def _make_standalone_call_transform(fn_name: str, handler: JinjaTransform) -> JinjaTransform: - """ - Creates a transform that identifies standalone Call nodes (that arent nested in other Call nodes) and replaces them with nodes - containing the result of the handler() function - """ - - def _handle( - node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node] - ) -> t.Optional[j.Node]: - if isinstance(node, j.Call): - if isinstance(parent, (j.Call, j.List, j.Keyword)): - return node - - if (name := node.find(j.Name)) and name.name == fn_name: - return handler(node, prev, parent) - - return node - - return _handle - - -def _make_single_expression_transform( - mapping: t.Union[ - t.Dict[str, str], - t.Callable[[j.Node, t.Optional[j.Node], t.Optional[j.Node], str], t.Optional[str]], - ], -) -> JinjaTransform: - """ - Creates a transform that looks for standalone {{ expression }} nodes - It then looks up 'expression' in the provided mapping and replaces it with a TemplateData node containing the value - """ - - def _handle(node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node]) -> j.Node: - # the assumption is that individual expressions are nested in between TemplateData - if prev and not isinstance(prev, j.TemplateData): - return node - - if isinstance(node, j.Name) and not isinstance(parent, j.Getattr): - if isinstance(mapping, dict): - result = mapping.get(node.name) - else: - result = mapping(node, prev, parent, node.name) - if result is not None: - return j.TemplateData(result) - - return node - - return _handle - - -def _dbt_relation_to_model_name( - models: MappingProxyType[str, t.Union[Model, str]], relation: BaseRelation, dialect: str -) -> t.Optional[str]: - model_fqn = normalize_model_name( - table=relation.render(), default_catalog=relation.database, dialect=dialect - ) - if resolved_value := models.get(model_fqn): - return resolved_value if isinstance(resolved_value, str) else resolved_value.name - return None - - -def _dbt_relation_to_kwargs(relation: BaseRelation) -> t.List[j.Keyword]: - kwargs = [] - if database := relation.database: - kwargs.append(j.Keyword("database", j.Const(database))) - if schema := relation.schema: - kwargs.append(j.Keyword("schema", j.Const(schema))) - if identifier := relation.identifier: - kwargs.append(j.Keyword("identifier", j.Const(identifier))) - return kwargs - - -ASTTransform = t.TypeVar("ASTTransform", JinjaTransform, SQLGlotTransform) - - -def ast_transform(fn: t.Callable[..., ASTTransform]) -> t.Callable[..., ASTTransform]: - """ - Decorator to mark functions as being Jinja or SQLGlot AST transforms - - The purpose is to set __name__ to be the outer function name so that the transforms have stable names for an exclude list - The function itself as well as the ASTTransform returned by the function should have the same __name__ for this to work - """ - - @wraps(fn) - def wrapper(*args: t.Any, **kwargs: t.Any) -> ASTTransform: - result = fn(*args, **kwargs) - result.__name__ = fn.__name__ - return result - - return wrapper - - -@ast_transform -def resolve_dbt_ref_to_model_name( - models: MappingProxyType[str, t.Union[Model, str]], env: Environment, dialect: str -) -> JinjaTransform: - """ - Takes an expression like "{{ ref('foo') }}" - And turns it into "sqlmesh.foo" based on the provided list of models and resolver() function - - Args: - models: A dict of models (or model names) keyed by model fqn - jinja_env: Should contain an implementation of {{ ref() }} to turn a DBT relation name into a DBT relation object - - Returns: - A string containing the **model name** (not fqn) of the model referenced by the DBT "{{ ref() }}" call - """ - - ref: t.Callable = env.globals["ref"] # type: ignore - - def _resolve( - node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node] - ) -> t.Optional[j.Node]: - if isinstance(node, j.Call) and node.args and isinstance(node.args[0], j.Const): - ref_name = node.args[0].value - version = None - if version_kwarg := next((k for k in node.kwargs if k.key in ("version", "v")), None): - if isinstance(version_kwarg.value, j.Const): - version = version_kwarg.value.value - else: - # the version arg is present but its some kind of dynamic runtime value - # this means we cant resolve the ref to a model - return node - - if relation := ref(ref_name, version=version): - if not isinstance(relation, BaseRelation): - raise ValueError( - f"ref() returned non-relation type for '{ref_name}': {relation}" - ) - if model_name := _dbt_relation_to_model_name(models, relation, dialect): - return j.TemplateData(model_name) - return j.TemplateData(f"__unresolved_ref__.{ref_name}") - - return node - - return _make_standalone_call_transform("ref", _resolve) - - -@ast_transform -def rewrite_dbt_ref_to_migrated_ref( - models: MappingProxyType[str, t.Union[Model, str]], env: Environment, dialect: str -) -> JinjaTransform: - """ - Takes an expression like "{{ ref('foo') }}" - And turns it into "{{ __migrated_ref(database='foo', schema='bar', identifier='baz', sqlmesh_model_name='') }}" - so that the SQLMesh Native loader can construct a Relation instance without needing the Context - - Args: - models: A dict of models (or model names) keyed by model fqn - jinja_env: Should contain an implementation of {{ ref() }} to turn a DBT relation name into a DBT relation object - - Returns: - A new Call node with enough data to reconstruct the Relation - """ - - ref: t.Callable = env.globals["ref"] # type: ignore - - def _rewrite( - node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node] - ) -> t.Optional[j.Node]: - if isinstance(node, j.Call) and isinstance(node.node, j.Name) and node.node.name == "ref": - if node.args and isinstance(node.args[0], j.Const): - ref_name = node.args[0].value - version_kwarg = next((k for k in node.kwargs if k.key == "version"), None) - if (relation := ref(ref_name)) and isinstance(relation, BaseRelation): - if model_name := _dbt_relation_to_model_name(models, relation, dialect): - kwargs = _dbt_relation_to_kwargs(relation) - if version_kwarg: - kwargs.append(version_kwarg) - kwargs.append(j.Keyword("sqlmesh_model_name", j.Const(model_name))) - return j.Call(j.Name("__migrated_ref", "load"), [], kwargs, None, None) - - return node - - return _rewrite - - -@ast_transform -def resolve_dbt_source_to_model_name( - models: MappingProxyType[str, t.Union[Model, str]], env: Environment, dialect: str -) -> JinjaTransform: - """ - Takes an expression like "{{ source('foo', 'bar') }}" - And turns it into "foo.bar" based on the provided list of models and resolver() function - - Args: - models: A dict of models (or model names) keyed by model fqn - jinja_env: Should contain an implementation of {{ source() }} to turn a DBT source name / table name into a DBT relation object - - Returns: - A string containing the table fqn of the external table referenced by the DBT "{{ source() }}" call - """ - source: t.Callable = env.globals["source"] # type: ignore - - def _resolve( - node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node] - ) -> t.Optional[j.Node]: - if isinstance(node, j.Call) and isinstance(parent, (j.TemplateData, j.Output)): - if ( - len(node.args) == 2 - and isinstance(node.args[0], j.Const) - and isinstance(node.args[1], j.Const) - ): - source_name = node.args[0].value - table_name = node.args[1].value - if relation := source(source_name, table_name): - if not isinstance(relation, BaseRelation): - raise ValueError( - f"source() returned non-relation type for '{source_name}.{table_name}': {relation}" - ) - if model_name := _dbt_relation_to_model_name(models, relation, dialect): - return j.TemplateData(model_name) - return j.TemplateData(relation.render()) - # source() didnt resolve anything, just pass through the arguments verbatim - return j.TemplateData(f"{source_name}.{table_name}") - - return node - - return _make_standalone_call_transform("source", _resolve) - - -@ast_transform -def rewrite_dbt_source_to_migrated_source( - models: MappingProxyType[str, t.Union[Model, str]], env: Environment, dialect: str -) -> JinjaTransform: - """ - Takes an expression like "{{ source('foo', 'bar') }}" - And turns it into "{{ __migrated_source(database='foo', identifier='bar') }}" - so that the SQLMesh Native loader can construct a Relation instance without needing the Context - - Args: - models: A dict of models (or model names) keyed by model fqn - jinja_env: Should contain an implementation of {{ source() }} to turn a DBT source name / table name into a DBT relation object - - Returns: - A new Call node with enough data to reconstruct the Relation - """ - - source: t.Callable = env.globals["source"] # type: ignore - - def _rewrite( - node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node] - ) -> t.Optional[j.Node]: - if ( - isinstance(node, j.Call) - and isinstance(node.node, j.Name) - and node.node.name == "source" - ): - if ( - len(node.args) == 2 - and isinstance(node.args[0], j.Const) - and isinstance(node.args[1], j.Const) - ): - source_name = node.args[0].value - table_name = node.args[1].value - if (relation := source(source_name, table_name)) and isinstance( - relation, BaseRelation - ): - kwargs = _dbt_relation_to_kwargs(relation) - return j.Call(j.Name("__migrated_source", "load"), [], kwargs, None, None) - - return node - - return _rewrite - - -@ast_transform -def resolve_dbt_this_to_model_name(model_name: str) -> JinjaTransform: - """ - Takes an expression like "{{ this }}" and turns it into the provided "model_name" string - """ - return _make_single_expression_transform({"this": model_name}) - - -@ast_transform -def deduplicate_incremental_checks() -> JinjaTransform: - """ - Some files may have been designed to run with both the SQLMesh DBT loader and DBT itself and contain sections like: - - --- - select * from foo - where - {% if is_incremental() %}ds > (select max(ds)) from {{ this }}{% endif %} - {% if sqlmesh_incremental is defined %}ds BETWEEN {{ start_ds }} and {{ end_ds }}{% endif %} - --- - - This is transform detects usages of {% if sqlmesh_incremental ... %} - If it finds them, it: - - removes occurances of {% if is_incremental() %} in favour of the {% if sqlmesh_incremental %} check - - If no instances of {% if sqlmesh_incremental %} are found, nothing changes - - For for example, the above will be transformed into: - --- - select * from foo - where - ds BETWEEN {{ start_ds }} and {{ end_ds }} - --- - - But if it didnt contain the {% if sqlmesh_incremental %} block, this transform would output: - --- - select * from foo - where - {% if is_incremental() %}ds > (select max(ds)) from {{ this }}){% endif %} - --- - - """ - has_sqlmesh_incremental = False - - def _handle( - node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node] - ) -> t.Optional[j.Node]: - nonlocal has_sqlmesh_incremental - - if isinstance(node, j.Template): - for if_node in node.find_all(j.If): - if test_name := if_node.test.find(j.Name): - if test_name.name == "sqlmesh_incremental": - has_sqlmesh_incremental = True - - # only remove the {% if is_incremental() %} checks in the present of {% sqlmesh_incremental is defined %} checks - if has_sqlmesh_incremental: - if isinstance(node, j.If) and node.test: - if test_name := node.test.find(j.Name): - if test_name.name == "is_incremental": - return None - - return node - - return _handle - - -@ast_transform -def unpack_incremental_checks() -> JinjaTransform: - """ - This takes queries like: - - > select * from foo where {% if sqlmesh_incremental is defined %}ds BETWEEN {{ start_ds }} and {{ end_ds }}{% endif %} - > select * from foo where {% if is_incremental() %}ds > (select max(ds)) from foo.table){% endif %} - - And, if possible, removes the {% if sqlmesh_incremental is defined %} / {% is_incremental %} block to achieve: - - > select * from foo where ds BETWEEN {{ start_ds }} and {{ end_ds }} - > select * from foo where ds > (select max(ds)) from foo.table) - - Note that if there is a {% else %} portion to the block, there is no SQLMesh equivalent so in that case the check is untouched. - - Also, if both may be present in a model, run the deduplicate_incremental_checks() transform first so only one gets unpacked by this transform - """ - - def _handle(node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node]) -> j.Node: - if isinstance(node, j.If) and node.test: - if test_name := node.test.find(j.Name): - if ( - test_name.name in ("is_incremental", "sqlmesh_incremental") - and not node.elif_ - and not node.else_ - ): - return j.Output(node.body) - - return node - - return _handle - - -@ast_transform -def rewrite_sqlmesh_predefined_variables_to_sqlmesh_macro_syntax() -> JinjaTransform: - """ - If there are SQLMesh predefined variables in Jinja form, eg "{{ start_dt }}" - Rewrite them to eg "@start_dt" - - Example: - - select * from foo where ds between {{ start_dt }} and {{ end_dt }} - - > select * from foo where ds between @start_dt and @end_dt - """ - - mapping = {v: f"@{v}" for v in SQLMESH_PREDEFINED_MACRO_VARIABLES} - - literal_remapping = {"dt": "ts", "date": "ds"} - - def _mapping_func( - node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node], name: str - ) -> t.Optional[str]: - wrapped_in_literal = False - if prev and isinstance(prev, j.TemplateData): - data = prev.data.strip() - if data.endswith("'"): - wrapped_in_literal = True - - if wrapped_in_literal: - for original, new in literal_remapping.items(): - if name.endswith(original): - name = name.removesuffix(original) + new - - return mapping.get(name) - - return _make_single_expression_transform(_mapping_func) - - -@ast_transform -def append_dbt_package_kwarg_to_var_calls(package_name: t.Optional[str]) -> JinjaTransform: - """ " - If there are calls like: - - > {% if 'col_name' in var('history_columns') %} - - Assuming package_name=foo, change it to: - - > {% if 'col_name' in var('history_columns', __dbt_package="foo") %} - - The point of this is to give a hint to the "var" shim in SQLMesh Native so it knows which key - under "__dbt_packages__" in the project variables to look for - """ - - def _append( - node: j.Node, prev: t.Optional[j.Node], parent: t.Optional[j.Node] - ) -> t.Optional[j.Node]: - if package_name and isinstance(node, j.Call): - node.kwargs.append(j.Keyword("__dbt_package", j.Const(package_name))) - return node - - return _make_standalone_call_transform("var", _append) - - -@ast_transform -def unwrap_macros_in_string_literals() -> SQLGlotTransform: - """ - Given a query containing string literals *that match SQLMesh predefined macro variables* like: - - > select * from foo where ds between '@start_dt' and '@end_dt' - - Unwrap them into: - - > select * from foo where ds between @start_dt and @end_dt - """ - values_to_check = {f"@{var}": var for var in SQLMESH_PREDEFINED_MACRO_VARIABLES} - - def _transform(e: exp.Expression) -> exp.Expression: - if isinstance(e, exp.Literal) and e.is_string: - if (value := e.text("this")) and value in values_to_check: - return d.MacroVar( - this=values_to_check[value] - ) # MacroVar adds in the @ so dont want to add it twice - return e - - return _transform diff --git a/sqlmesh/dbt/loader.py b/sqlmesh/dbt/loader.py index 594c5a8807..cd1c4b6c1a 100644 --- a/sqlmesh/dbt/loader.py +++ b/sqlmesh/dbt/loader.py @@ -98,12 +98,11 @@ def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]: for file in macro_files: self._track_file(file) - jinja_macros = JinjaMacroRegistry() - for project in self._load_projects(): - jinja_macros = jinja_macros.merge(project.context.jinja_macros) - jinja_macros.add_globals(project.context.jinja_globals) - - return (macro.get_registry(), jinja_macros) + # This doesn't do anything, the actual content will be loaded from the manifest + return ( + macro.get_registry(), + JinjaMacroRegistry(), + ) def _load_models( self, diff --git a/sqlmesh/dbt/model.py b/sqlmesh/dbt/model.py index 6d31efe772..d2d1a52abc 100644 --- a/sqlmesh/dbt/model.py +++ b/sqlmesh/dbt/model.py @@ -590,7 +590,6 @@ def to_sqlmesh( kind=kind, start=self.start or context.sqlmesh_config.model_defaults.start, audit_definitions=audit_definitions, - path=model_kwargs.pop("path", self.path), # This ensures that we bypass query rendering that would otherwise be required to extract additional # dependencies from the model's SQL. # Note: any table dependencies that are not referenced using the `ref` macro will not be included. diff --git a/sqlmesh/dbt/target.py b/sqlmesh/dbt/target.py index 035b5b9e93..82574a044c 100644 --- a/sqlmesh/dbt/target.py +++ b/sqlmesh/dbt/target.py @@ -103,8 +103,26 @@ def load(cls, data: t.Dict[str, t.Any]) -> TargetConfig: The configuration of the provided profile target """ db_type = data["type"] - if config_class := TARGET_TYPE_TO_CONFIG_CLASS.get(db_type): - return config_class(**data) + if db_type == "databricks": + return DatabricksConfig(**data) + if db_type == "duckdb": + return DuckDbConfig(**data) + if db_type == "postgres": + return PostgresConfig(**data) + if db_type == "redshift": + return RedshiftConfig(**data) + if db_type == "snowflake": + return SnowflakeConfig(**data) + if db_type == "bigquery": + return BigQueryConfig(**data) + if db_type == "sqlserver": + return MSSQLConfig(**data) + if db_type == "trino": + return TrinoConfig(**data) + if db_type == "clickhouse": + return ClickhouseConfig(**data) + if db_type == "athena": + return AthenaConfig(**data) raise ConfigError(f"{db_type} not supported.") @@ -117,10 +135,6 @@ def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: """Converts target config to SQLMesh connection config""" raise NotImplementedError - @classmethod - def from_sqlmesh(cls, config: ConnectionConfig, **kwargs: t.Dict[str, t.Any]) -> "TargetConfig": - raise NotImplementedError - def attribute_dict(self) -> AttributeDict: fields = self.dict(include=SERIALIZABLE_FIELDS).copy() fields["target_name"] = self.name @@ -214,18 +228,6 @@ def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: **kwargs, ) - @classmethod - def from_sqlmesh(cls, config: ConnectionConfig, **kwargs: t.Dict[str, t.Any]) -> "DuckDbConfig": - if not isinstance(config, DuckDBConnectionConfig): - raise ValueError(f"Incorrect config type: {type(config)}") - - return cls( - path=config.database, - extensions=config.extensions, - settings=config.connector_config, - **kwargs, - ) - class SnowflakeConfig(TargetConfig): """ @@ -398,28 +400,6 @@ def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: **kwargs, ) - @classmethod - def from_sqlmesh( - cls, config: ConnectionConfig, **kwargs: t.Dict[str, t.Any] - ) -> "PostgresConfig": - if not isinstance(config, PostgresConnectionConfig): - raise ValueError(f"Incorrect config type: {type(config)}") - - return cls( - schema="public", - host=config.host, - user=config.user, - password=config.password, - port=config.port, - dbname=config.database, - keepalives_idle=config.keepalives_idle, - threads=config.concurrent_tasks, - connect_timeout=config.connect_timeout, - role=config.role, - sslmode=config.sslmode, - **kwargs, - ) - class RedshiftConfig(TargetConfig): """ @@ -664,39 +644,6 @@ def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: **kwargs, ) - @classmethod - def from_sqlmesh( - cls, config: ConnectionConfig, **kwargs: t.Dict[str, t.Any] - ) -> "BigQueryConfig": - if not isinstance(config, BigQueryConnectionConfig): - raise ValueError(f"Incorrect config type: {type(config)}") - - return cls( - schema="__unknown__", - method=config.method, - project=config.project, - execution_project=config.execution_project, - quota_project=config.quota_project, - location=config.location, - threads=config.concurrent_tasks, - keyfile=config.keyfile, - keyfile_json=config.keyfile_json, - token=config.token, - refresh_token=config.refresh_token, - client_id=config.client_id, - client_secret=config.client_secret, - token_uri=config.token_uri, - scopes=config.scopes, - impersonated_service_account=config.impersonated_service_account, - job_creation_timeout_seconds=config.job_creation_timeout_seconds, - job_execution_timeout_seconds=config.job_execution_timeout_seconds, - job_retries=config.job_retries, - job_retry_deadline_seconds=config.job_retry_deadline_seconds, - priority=config.priority, - maximum_bytes_billed=config.maximum_bytes_billed, - **kwargs, - ) - class MSSQLConfig(TargetConfig): """ diff --git a/sqlmesh/utils/jinja.py b/sqlmesh/utils/jinja.py index d2d830c521..9764e625a4 100644 --- a/sqlmesh/utils/jinja.py +++ b/sqlmesh/utils/jinja.py @@ -22,7 +22,6 @@ CallNames = t.Tuple[t.Tuple[str, ...], t.Union[nodes.Call, nodes.Getattr]] SQLMESH_JINJA_PACKAGE = "sqlmesh.utils.jinja" -SQLMESH_DBT_COMPATIBILITY_PACKAGE = "sqlmesh.dbt.converter.jinja_builtins" def environment(**kwargs: t.Any) -> Environment: @@ -95,11 +94,7 @@ def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]: macro_str = self._find_sql(macro_start, self._next) macros[name] = MacroInfo( definition=macro_str, - depends_on=list( - extract_macro_references_and_variables(macro_str, dbt_target_name=dialect)[ - 0 - ] - ), + depends_on=list(extract_macro_references_and_variables(macro_str)[0]), ) self._advance() @@ -171,35 +166,6 @@ def parse() -> t.List[CallNames]: return parse() -def extract_dbt_adapter_dispatch_targets(jinja_str: str) -> t.List[t.Tuple[str, t.Optional[str]]]: - """ - Given a jinja string, identify {{ adapter.dispatch('foo','bar') }} calls and extract the (foo, bar) part as a tuple - """ - ast = ENVIRONMENT.parse(jinja_str) - - extracted = [] - - def _extract(node: nodes.Node, parent: t.Optional[nodes.Node] = None) -> None: - if ( - isinstance(node, nodes.Getattr) - and isinstance(parent, nodes.Call) - and (node_name := node.find(nodes.Name)) - ): - if node_name.name == "adapter" and node.attr == "dispatch": - call_args = [arg.value for arg in parent.args if isinstance(arg, nodes.Const)][0:2] - if len(call_args) == 1: - call_args.append(None) - macro_name, package = call_args - extracted.append((macro_name, package)) - - for child_node in node.iter_child_nodes(): - _extract(child_node, parent=node) - - _extract(ast) - - return extracted - - def is_variable_node(n: nodes.Node) -> bool: return ( isinstance(n, nodes.Call) @@ -209,33 +175,11 @@ def is_variable_node(n: nodes.Node) -> bool: def extract_macro_references_and_variables( - *jinja_strs: str, dbt_target_name: t.Optional[str] = None + *jinja_strs: str, ) -> t.Tuple[t.Set[MacroReference], t.Set[str]]: macro_references = set() variables = set() for jinja_str in jinja_strs: - if dbt_target_name and "adapter.dispatch" in jinja_str: - for dispatch_target_name, package in extract_dbt_adapter_dispatch_targets(jinja_str): - # here we are guessing at the macro names that the {{ adapter.dispatch() }} call will invoke - # there is a defined resolution order: https://docs.getdbt.com/reference/dbt-jinja-functions/dispatch - # we rely on JinjaMacroRegistry.trim() to tune the dependencies down into just the ones that actually exist - macro_references.add( - MacroReference(package=package, name=f"default__{dispatch_target_name}") - ) - macro_references.add( - MacroReference( - package=package, name=f"{dbt_target_name}__{dispatch_target_name}" - ) - ) - if package and package.startswith("dbt"): - # handle the case where macros like `current_timestamp()` in the `dbt` package expect an implementation in eg the `dbt_bigquery` package - macro_references.add( - MacroReference( - package=f"dbt_{dbt_target_name}", - name=f"{dbt_target_name}__{dispatch_target_name}", - ) - ) - for call_name, node in extract_call_names(jinja_str): if call_name[0] in (c.VAR, c.BLUEPRINT_VAR): if not is_variable_node(node): @@ -249,24 +193,7 @@ def extract_macro_references_and_variables( node = t.cast(nodes.Call, node) args = [jinja_call_arg_name(arg) for arg in node.args] if args and args[0]: - variable_name = args[0].lower() - - # check if this {{ var() }} reference is from a migrated DBT package - # if it is, there will be a __dbt_package= kwarg - dbt_package = next( - ( - kwarg.value - for kwarg in node.kwargs - if isinstance(kwarg, nodes.Keyword) and kwarg.key == "__dbt_package" - ), - None, - ) - if dbt_package and isinstance(dbt_package, nodes.Const): - dbt_package = dbt_package.value - # this convention is a flat way of referencing the nested values under `__dbt_packages__` in the SQLMesh project variables - variable_name = f"{c.MIGRATED_DBT_PACKAGES}.{dbt_package}.{variable_name}" - - variables.add(variable_name) + variables.add(args[0].lower()) elif call_name[0] == c.GATEWAY: variables.add(c.GATEWAY) elif len(call_name) == 1: @@ -344,19 +271,6 @@ def _convert( def trimmed(self) -> bool: return self._trimmed - @property - def all_macros(self) -> t.Iterable[t.Tuple[t.Optional[str], str, MacroInfo]]: - """ - Returns (package, macro_name, MacroInfo) tuples for every macro in this registry - Root macros will have package=None - """ - for name, macro in self.root_macros.items(): - yield None, name, macro - - for package, macros in self.packages.items(): - for name, macro in macros.items(): - yield (package, name, macro) - def add_macros(self, macros: t.Dict[str, MacroInfo], package: t.Optional[str] = None) -> None: """Adds macros to the target package. @@ -698,12 +612,7 @@ def jinja_call_arg_name(node: nodes.Node) -> str: def create_var(variables: t.Dict[str, t.Any]) -> t.Callable: - def _var( - var_name: str, default: t.Optional[t.Any] = None, **kwargs: t.Any - ) -> t.Optional[t.Any]: - if dbt_package := kwargs.get("__dbt_package"): - var_name = f"{c.MIGRATED_DBT_PACKAGES}.{dbt_package}.{var_name}" - + def _var(var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: value = variables.get(var_name.lower(), default) if isinstance(value, SqlValue): return value.sql diff --git a/tests/core/test_config.py b/tests/core/test_config.py index 10881dc493..d0fad16e76 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -31,7 +31,6 @@ from sqlmesh.core.engine_adapter.athena import AthenaEngineAdapter from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter from sqlmesh.core.engine_adapter.redshift import RedshiftEngineAdapter -from sqlmesh.core.loader import MigratedDbtProjectLoader from sqlmesh.core.notification_target import ConsoleNotificationTarget from sqlmesh.core.user import User from sqlmesh.utils.errors import ConfigError @@ -1066,32 +1065,6 @@ def test_config_complex_types_supplied_as_json_strings_from_env(tmp_path: Path) assert conn.keyfile_json == {"foo": "bar"} -def test_loader_for_migrated_dbt_project(tmp_path: Path): - config_path = tmp_path / "config.yaml" - config_path.write_text(""" - gateways: - bigquery: - connection: - type: bigquery - project: unit-test - - default_gateway: bigquery - - model_defaults: - dialect: bigquery - - variables: - __dbt_project_name__: sushi -""") - - config = load_config_from_paths( - Config, - project_paths=[config_path], - ) - - assert config.loader == MigratedDbtProjectLoader - - def test_config_user_macro_function(tmp_path: Path) -> None: config_path = tmp_path / "config.yaml" config_path.write_text(""" diff --git a/tests/core/test_loader.py b/tests/core/test_loader.py index b3d605e353..14a20ec09a 100644 --- a/tests/core/test_loader.py +++ b/tests/core/test_loader.py @@ -4,9 +4,6 @@ from sqlmesh.core.config import Config, ModelDefaultsConfig from sqlmesh.core.context import Context from sqlmesh.utils.errors import ConfigError -import sqlmesh.core.constants as c -from sqlmesh.core.config import load_config_from_yaml -from sqlmesh.utils.yaml import dump @pytest.fixture @@ -204,129 +201,3 @@ def my_model(context, **kwargs): assert model.description == "model_payload_a" path_b.write_text(model_payload_b) context.load() # raise no error to duplicate key if the functions are identical (by registry class_method) - - -def test_load_migrated_dbt_adapter_dispatch_macros(tmp_path: Path): - init_example_project(tmp_path, engine_type="duckdb") - - migrated_package_path = tmp_path / "macros" / c.MIGRATED_DBT_PACKAGES / "dbt_utils" - migrated_package_path.mkdir(parents=True) - - (migrated_package_path / "deduplicate.sql").write_text(""" - {%- macro deduplicate(relation) -%} - {{ return(adapter.dispatch('deduplicate', 'dbt_utils')(relation)) }} - {% endmacro %} - """) - - (migrated_package_path / "default__deduplicate.sql").write_text(""" - {%- macro default__deduplicate(relation) -%} - select 'default impl' from {{ relation }} - {% endmacro %} - """) - - (migrated_package_path / "duckdb__deduplicate.sql").write_text(""" - {%- macro duckdb__deduplicate(relation) -%} - select 'duckdb impl' from {{ relation }} - {% endmacro %} - """) - - # this should be pruned from the JinjaMacroRegistry because the target is duckdb, not bigquery - (migrated_package_path / "bigquery__deduplicate.sql").write_text(""" - {%- macro bigquery__deduplicate(relation) -%} - select 'bigquery impl' from {{ relation }} - {% endmacro %} - """) - - (tmp_path / "models" / "test_model.sql").write_text(""" - MODEL ( - name sqlmesh_example.test, - kind FULL, - ); -JINJA_QUERY_BEGIN; -{{ dbt_utils.deduplicate(__migrated_ref(schema='sqlmesh_example', identifier='full_model')) }} -JINJA_END; - """) - - config_path = tmp_path / "config.yaml" - assert config_path.exists() - config = load_config_from_yaml(config_path) - config["variables"] = {} - config["variables"][c.MIGRATED_DBT_PROJECT_NAME] = "test" - - config_path.write_text(dump(config)) - - ctx = Context(paths=tmp_path) - - model = ctx.models['"db"."sqlmesh_example"."test"'] - assert model.dialect == "duckdb" - assert {(package, name) for package, name, _ in model.jinja_macros.all_macros} == { - ("dbt_utils", "deduplicate"), - ("dbt_utils", "default__deduplicate"), - ("dbt_utils", "duckdb__deduplicate"), - } - - assert ( - model.render_query_or_raise().sql(dialect="duckdb") - == """SELECT \'duckdb impl\' AS "duckdb impl" FROM "db"."sqlmesh_example"."full_model" AS "full_model\"""" - ) - - -def test_load_migrated_dbt_adapter_dispatch_macros_in_different_packages(tmp_path: Path): - # some things like dbt.current_timestamp() dispatch to macros in a different package - init_example_project(tmp_path, engine_type="duckdb") - - migrated_package_path_dbt = tmp_path / "macros" / c.MIGRATED_DBT_PACKAGES / "dbt" - migrated_package_path_dbt_duckdb = tmp_path / "macros" / c.MIGRATED_DBT_PACKAGES / "dbt_duckdb" - migrated_package_path_dbt.mkdir(parents=True) - migrated_package_path_dbt_duckdb.mkdir(parents=True) - - (migrated_package_path_dbt / "current_timestamp.sql").write_text(""" - {%- macro current_timestamp(relation) -%} - {{ return(adapter.dispatch('current_timestamp', 'dbt')()) }} - {% endmacro %} - """) - - (migrated_package_path_dbt / "default__current_timestamp.sql").write_text(""" - {% macro default__current_timestamp() -%} - {{ exceptions.raise_not_implemented('current_timestamp macro not implemented') }} - {%- endmacro %} - """) - - (migrated_package_path_dbt_duckdb / "duckdb__current_timestamp.sql").write_text(""" - {%- macro duckdb__current_timestamp() -%} - 'duckdb current_timestamp impl' - {% endmacro %} - """) - - (tmp_path / "models" / "test_model.sql").write_text(""" - MODEL ( - name sqlmesh_example.test, - kind FULL, - ); -JINJA_QUERY_BEGIN; -select {{ dbt.current_timestamp() }} as a -JINJA_END; - """) - - config_path = tmp_path / "config.yaml" - assert config_path.exists() - config = load_config_from_yaml(config_path) - config["variables"] = {} - config["variables"][c.MIGRATED_DBT_PROJECT_NAME] = "test" - - config_path.write_text(dump(config)) - - ctx = Context(paths=tmp_path) - - model = ctx.models['"db"."sqlmesh_example"."test"'] - assert model.dialect == "duckdb" - assert {(package, name) for package, name, _ in model.jinja_macros.all_macros} == { - ("dbt", "current_timestamp"), - ("dbt", "default__current_timestamp"), - ("dbt_duckdb", "duckdb__current_timestamp"), - } - - assert ( - model.render_query_or_raise().sql(dialect="duckdb") - == "SELECT 'duckdb current_timestamp impl' AS \"a\"" - ) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 9266a56c10..3850e08164 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -15,7 +15,6 @@ from sqlmesh.cli.project_init import init_example_project, ProjectTemplate from sqlmesh.core.environment import EnvironmentNamingInfo from sqlmesh.core.model.kind import TimeColumn, ModelKindName -from pydantic import ValidationError from sqlmesh import CustomMaterialization, CustomKind from pydantic import model_validator, ValidationError @@ -66,13 +65,7 @@ from sqlmesh.core.snapshot import Snapshot, SnapshotChangeCategory from sqlmesh.utils.date import TimeLike, to_datetime, to_ds, to_timestamp from sqlmesh.utils.errors import ConfigError, SQLMeshError, LinterError -from sqlmesh.utils.jinja import ( - JinjaMacroRegistry, - MacroInfo, - MacroExtractor, - MacroReference, - SQLMESH_DBT_COMPATIBILITY_PACKAGE, -) +from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroInfo, MacroExtractor from sqlmesh.utils.metaprogramming import Executable, SqlValue from sqlmesh.core.macros import RuntimeStage from tests.utils.test_helpers import use_terminal_console @@ -6386,59 +6379,6 @@ def model_with_variables(context, **kwargs): assert df.to_dict(orient="records") == [{"a": "test_value", "b": "default_value", "c": None}] -def test_variables_migrated_dbt_package_macro(): - expressions = parse( - """ - MODEL( - name test_model, - kind FULL, - ); - - JINJA_QUERY_BEGIN; - SELECT '{{ var('TEST_VAR_A') }}' as a, '{{ test.test_macro_var() }}' as b - JINJA_END; - """, - default_dialect="bigquery", - ) - - jinja_macros = JinjaMacroRegistry( - create_builtins_module=SQLMESH_DBT_COMPATIBILITY_PACKAGE, - packages={ - "test": { - "test_macro_var": MacroInfo( - definition=""" - {% macro test_macro_var() %} - {{- var('test_var_b', __dbt_package='test') }} - {%- endmacro %}""", - depends_on=[MacroReference(name="var")], - ) - } - }, - ) - - model = load_sql_based_model( - expressions, - variables={ - "test_var_a": "test_var_a_value", - c.MIGRATED_DBT_PACKAGES: { - "test": {"test_var_b": "test_var_b_value", "unused": "unused_value"}, - }, - "test_var_c": "test_var_c_value", - }, - jinja_macros=jinja_macros, - migrated_dbt_project_name="test", - dialect="bigquery", - ) - assert model.python_env[c.SQLMESH_VARS] == Executable.value( - {"test_var_a": "test_var_a_value", "__dbt_packages__.test.test_var_b": "test_var_b_value"}, - sort_root_dict=True, - ) - assert ( - model.render_query().sql(dialect="bigquery") - == "SELECT 'test_var_a_value' AS `a`, 'test_var_b_value' AS `b`" - ) - - def test_load_external_model_python(sushi_context) -> None: @model( "test_load_external_model_python", @@ -8150,37 +8090,6 @@ def test_model_kind_to_expression(): ) -def test_incremental_by_unique_key_batch_concurrency(): - with pytest.raises(ValidationError, match=r"Input should be 1"): - load_sql_based_model( - d.parse(""" - MODEL ( - name db.table, - kind INCREMENTAL_BY_UNIQUE_KEY ( - unique_key a, - batch_concurrency 2 - ) - ); - select 1; - """) - ) - - model = load_sql_based_model( - d.parse(""" - MODEL ( - name db.table, - kind INCREMENTAL_BY_UNIQUE_KEY ( - unique_key a, - batch_concurrency 1 - ) - ); - select 1; - """) - ) - assert isinstance(model.kind, IncrementalByUniqueKeyKind) - assert model.kind.batch_concurrency == 1 - - def test_bad_model_kind(): with pytest.raises( SQLMeshError, diff --git a/tests/dbt/converter/conftest.py b/tests/dbt/converter/conftest.py deleted file mode 100644 index e8dffeb263..0000000000 --- a/tests/dbt/converter/conftest.py +++ /dev/null @@ -1,21 +0,0 @@ -from pathlib import Path -import typing as t -import pytest -from sqlmesh.core.context import Context - - -@pytest.fixture -def sushi_dbt_context(copy_to_temp_path: t.Callable) -> Context: - return Context(paths=copy_to_temp_path("examples/sushi_dbt")) - - -@pytest.fixture -def empty_dbt_context(copy_to_temp_path: t.Callable) -> Context: - fixture_path = Path(__file__).parent / "fixtures" / "empty_dbt_project" - assert fixture_path.exists() - - actual_path = copy_to_temp_path(fixture_path)[0] - - ctx = Context(paths=actual_path) - - return ctx diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/.gitignore b/tests/dbt/converter/fixtures/empty_dbt_project/.gitignore deleted file mode 100644 index 232ccd1d8c..0000000000 --- a/tests/dbt/converter/fixtures/empty_dbt_project/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -target/ -logs/ diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/analyses/.gitkeep b/tests/dbt/converter/fixtures/empty_dbt_project/analyses/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/config.py b/tests/dbt/converter/fixtures/empty_dbt_project/config.py deleted file mode 100644 index e7e28c98e4..0000000000 --- a/tests/dbt/converter/fixtures/empty_dbt_project/config.py +++ /dev/null @@ -1,7 +0,0 @@ -from pathlib import Path - -from sqlmesh.dbt.loader import sqlmesh_config - -config = sqlmesh_config(Path(__file__).parent) - -test_config = config diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/dbt_project.yml b/tests/dbt/converter/fixtures/empty_dbt_project/dbt_project.yml deleted file mode 100644 index 007649e553..0000000000 --- a/tests/dbt/converter/fixtures/empty_dbt_project/dbt_project.yml +++ /dev/null @@ -1,22 +0,0 @@ - -name: 'test' -version: '1.0.0' -config-version: 2 -profile: 'test' - -model-paths: ["models"] -analysis-paths: ["analyses"] -test-paths: ["tests"] -seed-paths: ["seeds"] -macro-paths: ["macros"] -snapshot-paths: ["snapshots"] - -target-path: "target" - -models: - +start: Jan 1 2022 - -seeds: - +schema: raw - -vars: {} diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/macros/.gitkeep b/tests/dbt/converter/fixtures/empty_dbt_project/macros/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/models/.gitkeep b/tests/dbt/converter/fixtures/empty_dbt_project/models/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/models/sources.yml b/tests/dbt/converter/fixtures/empty_dbt_project/models/sources.yml deleted file mode 100644 index 49354831f4..0000000000 --- a/tests/dbt/converter/fixtures/empty_dbt_project/models/sources.yml +++ /dev/null @@ -1,6 +0,0 @@ -version: 2 - -sources: - - name: external - tables: - - name: orders diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/packages/.gitkeep b/tests/dbt/converter/fixtures/empty_dbt_project/packages/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/profiles.yml b/tests/dbt/converter/fixtures/empty_dbt_project/profiles.yml deleted file mode 100644 index 6d91ecbe65..0000000000 --- a/tests/dbt/converter/fixtures/empty_dbt_project/profiles.yml +++ /dev/null @@ -1,6 +0,0 @@ -test: - outputs: - in_memory: - type: duckdb - schema: project - target: in_memory diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/seeds/.gitkeep b/tests/dbt/converter/fixtures/empty_dbt_project/seeds/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/seeds/items.csv b/tests/dbt/converter/fixtures/empty_dbt_project/seeds/items.csv deleted file mode 100644 index 0f87cb2507..0000000000 --- a/tests/dbt/converter/fixtures/empty_dbt_project/seeds/items.csv +++ /dev/null @@ -1,94 +0,0 @@ -id,name,price,ds -0,Maguro,4.34,2022-01-01 -1,Ika,7.35,2022-01-01 -2,Aji,6.06,2022-01-01 -3,Hotate,8.5,2022-01-01 -4,Escolar,8.46,2022-01-01 -5,Sake,4.91,2022-01-01 -6,Tamago,4.94,2022-01-01 -7,Umi Masu,8.61,2022-01-01 -8,Bincho,9.71,2022-01-01 -9,Toro,9.13,2022-01-01 -10,Aoyagi,5.5,2022-01-01 -11,Hamachi,6.51,2022-01-01 -12,Tobiko,7.78,2022-01-01 -13,Unagi,7.99,2022-01-01 -14,Tako,5.59,2022-01-01 -0,Kani,8.22,2022-01-02 -1,Amaebi,9.14,2022-01-02 -2,Uni,4.55,2022-01-02 -3,Sake Toro,5.01,2022-01-02 -4,Maguro,9.95,2022-01-02 -5,Katsuo,9.03,2022-01-02 -6,Hamachi Toro,3.76,2022-01-02 -7,Iwashi,5.56,2022-01-02 -8,Tamago,6.96,2022-01-02 -9,Tai,5.84,2022-01-02 -10,Ika,3.23,2022-01-02 -0,Hirame,7.74,2022-01-03 -1,Uni,3.98,2022-01-03 -2,Tai,4.09,2022-01-03 -3,Kanpachi,7.55,2022-01-03 -4,Tobiko,9.87,2022-01-03 -5,Hotate,7.86,2022-01-03 -6,Iwashi,8.33,2022-01-03 -7,Ikura,5.98,2022-01-03 -8,Maguro,3.97,2022-01-03 -9,Tsubugai,4.51,2022-01-03 -10,Tako,8.35,2022-01-03 -11,Sake,3.38,2022-01-03 -12,Tamago,6.43,2022-01-03 -13,Ika,4.26,2022-01-03 -14,Unagi,7.42,2022-01-03 -0,Ikura,5.02,2022-01-04 -1,Tobiko,9.15,2022-01-04 -2,Hamachi,6.66,2022-01-04 -3,Bincho,8.4,2022-01-04 -4,Tsubugai,5.26,2022-01-04 -5,Hotate,8.92,2022-01-04 -6,Toro,7.52,2022-01-04 -7,Aji,7.49,2022-01-04 -8,Ebi,5.67,2022-01-04 -9,Kanpachi,7.51,2022-01-04 -10,Kani,6.97,2022-01-04 -11,Hirame,4.51,2022-01-04 -0,Saba,7.41,2022-01-05 -1,Unagi,8.45,2022-01-05 -2,Uni,3.67,2022-01-05 -3,Maguro,8.76,2022-01-05 -4,Katsuo,5.99,2022-01-05 -5,Bincho,9.15,2022-01-05 -6,Sake Toro,3.67,2022-01-05 -7,Aji,9.55,2022-01-05 -8,Umi Masu,9.88,2022-01-05 -9,Hamachi,6.53,2022-01-05 -10,Tai,6.83,2022-01-05 -11,Tsubugai,4.62,2022-01-05 -12,Ikura,4.86,2022-01-05 -13,Ahi,9.66,2022-01-05 -14,Hotate,7.85,2022-01-05 -0,Hamachi Toro,4.87,2022-01-06 -1,Ika,3.26,2022-01-06 -2,Kanpachi,8.63,2022-01-06 -3,Hirame,5.34,2022-01-06 -4,Katsuo,9.24,2022-01-06 -5,Iwashi,8.67,2022-01-06 -6,Sake Toro,9.75,2022-01-06 -7,Bincho,9.7,2022-01-06 -8,Aji,7.14,2022-01-06 -9,Hokigai,5.18,2022-01-06 -10,Umi Masu,9.43,2022-01-06 -11,Unagi,3.35,2022-01-06 -12,Sake,4.58,2022-01-06 -13,Aoyagi,5.54,2022-01-06 -0,Amaebi,6.94,2022-01-07 -1,Ebi,7.84,2022-01-07 -2,Saba,5.28,2022-01-07 -3,Anago,4.53,2022-01-07 -4,Escolar,7.28,2022-01-07 -5,Ahi,6.48,2022-01-07 -6,Katsuo,5.16,2022-01-07 -7,Umi Masu,6.09,2022-01-07 -8,Maguro,7.7,2022-01-07 -9,Hokigai,7.37,2022-01-07 -10,Sake Toro,6.99,2022-01-07 diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/seeds/properties.yml b/tests/dbt/converter/fixtures/empty_dbt_project/seeds/properties.yml deleted file mode 100644 index 86ce6964fe..0000000000 --- a/tests/dbt/converter/fixtures/empty_dbt_project/seeds/properties.yml +++ /dev/null @@ -1,13 +0,0 @@ -version: 2 - -seeds: - - name: items - columns: - - name: id - description: Item id - - name: name - description: Name of the item - - name: price - description: Price of the item - - name: ds - description: Date \ No newline at end of file diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/snapshots/.gitkeep b/tests/dbt/converter/fixtures/empty_dbt_project/snapshots/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/dbt/converter/fixtures/empty_dbt_project/tests/.gitkeep b/tests/dbt/converter/fixtures/empty_dbt_project/tests/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/dbt/converter/fixtures/jinja_nested_if.sql b/tests/dbt/converter/fixtures/jinja_nested_if.sql deleted file mode 100644 index e7a1bed137..0000000000 --- a/tests/dbt/converter/fixtures/jinja_nested_if.sql +++ /dev/null @@ -1,15 +0,0 @@ -{% if foo == 'bar' %} - baz - {% if baz == 'bing' %} - bong - {% else %} - qux - {% endif %} -{% elif a == fn(b) %} - {% if c == 'f' and fn1(a, c, 'foo') == 'test' %} - output1 - {% elif z is defined %} - output2 - {% endif %} - output -{% endif %} \ No newline at end of file diff --git a/tests/dbt/converter/fixtures/macro_dbt_incremental.sql b/tests/dbt/converter/fixtures/macro_dbt_incremental.sql deleted file mode 100644 index a76f60713b..0000000000 --- a/tests/dbt/converter/fixtures/macro_dbt_incremental.sql +++ /dev/null @@ -1,11 +0,0 @@ -{% macro incremental_by_time(col, time_type) %} - {% if is_incremental() %} - WHERE - {{ col }} > (select max({{ col }}) from {{ this }}) - {% endif %} - {% if sqlmesh_incremental is defined %} - {% set dates = incremental_dates_by_time_type(time_type) %} - WHERE - {{ col }} BETWEEN '{{ dates[0] }}' AND '{{ dates[1] }}' - {% endif %} -{% endmacro %} \ No newline at end of file diff --git a/tests/dbt/converter/fixtures/macro_func_with_params.sql b/tests/dbt/converter/fixtures/macro_func_with_params.sql deleted file mode 100644 index 06bb757ef9..0000000000 --- a/tests/dbt/converter/fixtures/macro_func_with_params.sql +++ /dev/null @@ -1,17 +0,0 @@ -{% macro func_with_params(amount, category) %} - case - {% for row in [ - { 'category': '1', 'range': [0, 10], 'consider': True }, - { 'category': '2', 'range': [11, 20], 'consider': None } - ] %} - when {{ category }} = '{{ row.category }}' - and {{ amount }} >= {{ row.range[0] }} - {% if row.consider is not none %} - and {{ amount }} < {{ row.range[1] }} - {% endif %} - then - ({{ amount }} * {{ row.range[0] }} + {{ row.range[1] }}) * 4 - {% endfor %} - else null - end -{% endmacro %} \ No newline at end of file diff --git a/tests/dbt/converter/fixtures/model_query_incremental.sql b/tests/dbt/converter/fixtures/model_query_incremental.sql deleted file mode 100644 index a9603dbcbb..0000000000 --- a/tests/dbt/converter/fixtures/model_query_incremental.sql +++ /dev/null @@ -1,34 +0,0 @@ -WITH cte AS ( - SELECT - oi.order_id AS order_id, - FROM {{ ref('order_items') }} AS oi - LEFT JOIN {{ ref('items') }} AS i - ON oi.item_id = i.id AND oi.ds = i.ds -{% if is_incremental() %} -WHERE - oi.ds > (select max(ds) from {{ this }}) -{% endif %} -{% if sqlmesh_incremental is defined %} -WHERE - oi.ds BETWEEN '{{ start_ds }}' AND '{{ end_ds }}' -{% endif %} -GROUP BY - oi.order_id, - oi.ds -) -SELECT - o.customer_id::INT AS customer_id, /* Customer id */ - SUM(ot.total)::NUMERIC AS revenue, /* Revenue from orders made by this customer */ - o.ds::TEXT AS ds /* Date */ -FROM {{ ref('orders') }} AS o - LEFT JOIN order_total AS ot - ON o.id = ot.order_id AND o.ds = ot.ds -{% if is_incremental() %} - WHERE o.ds > (select max(ds) from {{ this }}) -{% endif %} -{% if sqlmesh_incremental is defined %} - WHERE o.ds BETWEEN '{{ start_ds }}' AND '{{ end_ds }}' -{% endif %} -GROUP BY - o.customer_id, - o.ds \ No newline at end of file diff --git a/tests/dbt/converter/test_convert.py b/tests/dbt/converter/test_convert.py deleted file mode 100644 index 001b1f82cc..0000000000 --- a/tests/dbt/converter/test_convert.py +++ /dev/null @@ -1,105 +0,0 @@ -from pathlib import Path -from sqlmesh.core.context import Context -from sqlmesh.dbt.converter.convert import convert_project_files, resolve_fqns_to_model_names -import uuid -import sqlmesh.core.constants as c - - -def test_convert_project_files(sushi_dbt_context: Context, tmp_path: Path) -> None: - src_context = sushi_dbt_context - src_path = sushi_dbt_context.path - output_path = tmp_path / f"output_{uuid.uuid4().hex}" - - convert_project_files(src_path, output_path) - - target_context = Context(paths=output_path) - - assert src_context.models.keys() == target_context.models.keys() - - target_context.plan(auto_apply=True) - - -def test_convert_project_files_includes_library_macros( - sushi_dbt_context: Context, tmp_path: Path -) -> None: - src_path = sushi_dbt_context.path - output_path = tmp_path / f"output_{uuid.uuid4().hex}" - - (src_path / "macros" / "call_library.sql").write_text(""" -{% macro call_library() %} - {{ dbt.current_timestamp() }} -{% endmacro %} -""") - - convert_project_files(src_path, output_path) - - migrated_output_macros_path = output_path / "macros" / c.MIGRATED_DBT_PACKAGES - assert (migrated_output_macros_path / "dbt" / "current_timestamp.sql").exists() - # note: the DBT manifest is smart enough to prune "dbt / default__current_timestamp.sql" from the list so it is not migrated - assert (migrated_output_macros_path / "dbt_duckdb" / "duckdb__current_timestamp.sql").exists() - - -def test_resolve_fqns_to_model_names(empty_dbt_context: Context) -> None: - ctx = empty_dbt_context - - # macro that uses a property of {{ ref() }} and also creates another ref() - (ctx.path / "macros" / "foo.sql").write_text( - """ -{% macro foo(relation) %} - {{ relation.name }} r - left join {{ source('external', 'orders') }} et - on r.id = et.id -{% endmacro %} -""" - ) - - # model 1 - can be fully unwrapped - (ctx.path / "models" / "model1.sql").write_text( - """ -{{ - config( - materialized='incremental', - incremental_strategy='delete+insert', - time_column='ds' - ) -}} - -select * from {{ ref('items') }} -{% if is_incremental() %} - where ds > (select max(ds) from {{ this }}) -{% endif %} -""" - ) - - # model 2 - has ref passed to macro as parameter and also another ref nested in macro - (ctx.path / "models" / "model2.sql").write_text( - """ -select * from {{ foo(ref('model1')) }} union select * from {{ ref('items') }} -""" - ) - - ctx.load() - - assert len(ctx.models) == 3 - - model1 = ctx.models['"memory"."project"."model1"'] - model2 = ctx.models['"memory"."project"."model2"'] - - assert model1.depends_on == {'"memory"."project_raw"."items"'} - assert model2.depends_on == { - '"memory"."project"."model1"', - '"memory"."external"."orders"', - '"memory"."project_raw"."items"', - } - - # All dependencies in model 1 can be tracked by the native loader but its very difficult to cover all the edge cases at conversion time - # so we still populate depends_on() - assert resolve_fqns_to_model_names(ctx, model1.depends_on) == {"project_raw.items"} - - # For model 2, the external model "external.orders" should be removed from depends_on - # If it was output verbatim as depends_on ("memory"."external"."orders"), the native loader would throw an error like: - # - Error: Failed to load model definition, 'Dot' object is not iterable - assert resolve_fqns_to_model_names(ctx, model2.depends_on) == { - "project.model1", - "project_raw.items", - } diff --git a/tests/dbt/converter/test_jinja.py b/tests/dbt/converter/test_jinja.py deleted file mode 100644 index 5d3e4508d3..0000000000 --- a/tests/dbt/converter/test_jinja.py +++ /dev/null @@ -1,450 +0,0 @@ -import pytest -from sqlmesh.utils.jinja import ( - JinjaMacroRegistry, - MacroExtractor, - extract_macro_references_and_variables, -) -from sqlmesh.dbt.converter.jinja import JinjaGenerator, convert_jinja_query, convert_jinja_macro -import sqlmesh.dbt.converter.jinja_transforms as jt -from pathlib import Path -from sqlmesh.core.context import Context -import sqlmesh.core.dialect as d -from sqlglot import exp -from _pytest.mark.structures import ParameterSet -from sqlmesh.core.model import SqlModel, load_sql_based_model -from sqlmesh.utils import columns_to_types_all_known - - -def _load_fixture(name: str) -> ParameterSet: - return pytest.param( - (Path(__file__).parent / "fixtures" / name).read_text(encoding="utf8"), id=name - ) - - -@pytest.mark.parametrize( - "original_jinja", - [ - "select 1", - "select bar from {{ ref('foo') }} as f", - "select max(ds) from {{ this }}", - "{% if is_incremental() %}where ds > (select max(ds) from {{ this }}){% endif %}", - "foo {% if sqlmesh_incremental is defined %} bar {% endif %} bar", - "foo between '{{ start_ds }}' and '{{ end_ds }}'", - "{{ 42 }}", - "{{ foo.bar }}", - "{{ 'baz' }}", - "{{ col }} BETWEEN '{{ dates[0] }}' AND '{{ dates[1] }}'", - "{% set foo = bar(baz, bing='bong') %}", - "{% if a == 'ds' %}foo{% elif a == 'ts' %}bar{% elif a < 'ys' or (b != 'ds' and c >= 'ts') %}baz{% else %}bing{% endif %}", - "{% set my_string = my_string ~ stuff ~ ', ' ~ 1 %}", - "{{ context.do_some_action('param') }}", - "{% set big_ole_block %}foo{% endset %}", - "{% if not loop.last %}foo{% endif %}", - "{% for a, b in some_func(a=foo['bar'][0], b=c.d[5]).items() %}foo_{{ a }}_{{ b }}{% endfor %}", - "{{ column | replace(prefix, '') }}", - "{{ column | filter('a', foo='bar') }}", - "{% filter upper %}foo{% endfilter %}", - "{% filter foo(0, bar='baz') %}foo{% endfilter %}", - "{% if foo in ('bar', 'baz') %}bar{% endif %}", - "{% if foo not in ('bar', 'baz') %}bing{% endif %}", - "{% if (field.a if field.a else field.b) | lower not in ('c', 'd') %}foo{% endif %}", - "{% do foo.bar('baz') %}", - "{% set a = (col | lower + '_') + b %}", - "{{ foo[1:10] | lower }}", - "{{ foo[1:] }}", - "{{ foo[:1] }}", - "{% for col in all_columns if col.name in columns_to_compare and col.name in special_names %}{{ col }}{% endfor %}", - "{{ ' or ' if not loop.first else '' }}", - "{% set foo = ['a', 'b', c, d.e, f[0], g.h.i[0][1]] %}", - """{% set foo = "('%Y%m%d', partition_id)" %}""", - "{% set foo = (graph.nodes.values() | selectattr('name', 'equalto', model_name) | list)[0] %}", - "{% set foo.bar = baz.bing(database='foo') %}", - "{{ return(('some', 'tuple')) }}", - "{% call foo('bar', baz=True) %}bar{% endcall %}", - "{% call(user) dump_users(list_of_user) %}bar{% endcall %}", - "{% macro foo(a, b='default', c=None) %}{% endmacro %}", - # "{# some comment #}", #todo: comments get stripped entirely - # "foo\n{%- if bar -%} baz {% endif -%}", #todo: whitespace trim handling is a nice-to-have - _load_fixture("model_query_incremental.sql"), - _load_fixture("macro_dbt_incremental.sql"), - _load_fixture("jinja_nested_if.sql"), - ], -) -def test_generator_roundtrip(original_jinja: str) -> None: - registry = JinjaMacroRegistry() - env = registry.build_environment() - - ast = env.parse(original_jinja) - generated = JinjaGenerator().generate(ast) - - assert generated == original_jinja - - me = MacroExtractor() - # basically just test this doesnt throw an exception. - # The MacroExtractor uses SQLGLot's tokenizer and not Jinja's so these need to work when the converted project is loaded by the native loader - me.extract(generated) - - -def test_generator_sql_comment_macro(): - jinja_str = "-- before sql comment{% macro foo() %}-- inner sql comment{% endmacro %}" - - registry = JinjaMacroRegistry() - env = registry.build_environment() - - ast = env.parse(jinja_str) - generated = JinjaGenerator().generate(ast) - - assert ( - generated == "-- before sql comment\n{% macro foo() %}-- inner sql comment\n{% endmacro %}" - ) - - # check roundtripping an existing newline doesnt keep adding newlines - assert JinjaGenerator().generate(env.parse(generated)) == generated - - -@pytest.mark.parametrize("original_jinja", [_load_fixture("macro_func_with_params.sql")]) -def test_generator_roundtrip_ignore_whitespace(original_jinja: str) -> None: - """ - This makes the following assumptions: - - SQL isnt too sensitive about indentation / whitespace - - The Jinja AST doesnt capture enough information to perfectly replicate the input template with regards to whitespace handling - - So if, disregarding whitespace, the original input string is the same as the AST being run through the generator: the test passes - """ - registry = JinjaMacroRegistry() - env = registry.build_environment() - - ast = env.parse(original_jinja) - - generated = JinjaGenerator().generate(ast) - - assert " ".join(original_jinja.split()) == " ".join(generated.split()) - - -def test_convert_jinja_query(sushi_dbt_context: Context) -> None: - model = sushi_dbt_context.models['"memory"."sushi"."customer_revenue_by_day"'] - assert isinstance(model, SqlModel) - - query = model.query - assert isinstance(query, d.JinjaQuery) - - result = convert_jinja_query(sushi_dbt_context, model, query) - - assert isinstance(result, exp.Query) - - assert ( - result.sql(dialect=model.dialect, pretty=True) - == """WITH order_total AS ( - SELECT - oi.order_id AS order_id, - SUM(oi.quantity * i.price) AS total, - oi.ds AS ds - FROM sushi_raw.order_items AS oi - LEFT JOIN sushi_raw.items AS i - ON oi.item_id = i.id AND oi.ds = i.ds - WHERE - oi.ds BETWEEN @start_ds AND @end_ds - GROUP BY - oi.order_id, - oi.ds -) -SELECT - CAST(o.customer_id AS INT) AS customer_id, /* Customer id */ - CAST(SUM(ot.total) AS DOUBLE) AS revenue, /* Revenue from orders made by this customer */ - CAST(o.ds AS TEXT) AS ds /* Date */ -FROM sushi_raw.orders AS o -LEFT JOIN order_total AS ot - ON o.id = ot.order_id AND o.ds = ot.ds -WHERE - o.ds BETWEEN @start_ds AND @end_ds -GROUP BY - o.customer_id, - o.ds""" - ) - - -def test_convert_jinja_query_exclude_transform(empty_dbt_context: Context) -> None: - ctx = empty_dbt_context - - (ctx.path / "models" / "model1.sql").write_text(""" - {{ - config( - materialized='incremental', - incremental_strategy='delete+insert', - time_column='ds' - ) - }} - - select * from {{ ref('items') }} - {% if is_incremental() %} - where ds > (select max(ds) from {{ this }}) - {% endif %} - """) - - ctx.load() - - model = ctx.models['"memory"."project"."model1"'] - assert isinstance(model, SqlModel) - - query = model.query - assert isinstance(query, d.JinjaQuery) - - converted_query = convert_jinja_query( - ctx, - model, - query, - exclude=[jt.resolve_dbt_ref_to_model_name, jt.rewrite_dbt_ref_to_migrated_ref], - ) - sql = converted_query.sql() - - assert "{{ ref('items') }}" in sql - assert "{{ this }}" not in sql - assert "{% if is_incremental() %}" not in sql - assert "{% endif %}" not in sql - - -def test_convert_jinja_query_self_referencing(empty_dbt_context: Context) -> None: - ctx = empty_dbt_context - - (ctx.path / "models" / "model1.sql").write_text(""" - {{ - config( - materialized='incremental', - incremental_strategy='delete+insert', - time_column='ds' - ) - }} - - select * from {{ ref('items') }} - {% if is_incremental() %} - where ds > (select max(ds) from {{ this }}) - {% endif %} - """) - - ctx.load() - - model = ctx.models['"memory"."project"."model1"'] - assert model.columns_to_types_or_raise - assert ( - not model.depends_on_self - ) # the DBT loader doesnt detect self-references within is_incremental blocks - assert isinstance(model, SqlModel) - - query = model.query - assert isinstance(query, d.JinjaQuery) - - converted_query = convert_jinja_query(ctx, model, query) - converted_model_definition = model.copy().render_definition()[0].sql() - - # load from scratch to use the native loader and clear @cached_property's - ctx.upsert_model( - load_sql_based_model( - expressions=[d.parse_one(converted_model_definition), converted_query], - default_catalog=ctx.default_catalog, - ) - ) - converted_model = ctx.models['"memory"."project"."model1"'] - assert isinstance(converted_model, SqlModel) - - assert not "{% is_incremental" in converted_model.query.sql() - assert ( - converted_model.depends_on_self - ) # Once the is_incremental blocks are removed, the model can be detected as self referencing - assert columns_to_types_all_known( - converted_model.columns_to_types_or_raise - ) # columns to types must all be known for self-referencing models - - -def test_convert_jinja_query_self_referencing_columns_to_types_not_all_known( - empty_dbt_context: Context, -) -> None: - ctx = empty_dbt_context - - (ctx.path / "models" / "model1.sql").write_text(""" - {{ - config( - materialized='incremental', - incremental_strategy='delete+insert', - time_column='ds' - ) - }} - - select id, name, ds from external.table - {% if is_incremental() %} - where ds > (select max(ds) from {{ this }}) - {% endif %} - """) - - ctx.load() - - model = ctx.models['"memory"."project"."model1"'] - assert model.columns_to_types_or_raise - assert ( - not model.depends_on_self - ) # the DBT loader doesnt detect self-references within is_incremental blocks - assert isinstance(model, SqlModel) - - query = model.query - assert isinstance(query, d.JinjaQuery) - - converted_query = convert_jinja_query(ctx, model, query) - converted_model_definition = model.render_definition()[0].sql() - - # load from scratch to use the native loader and clear @cached_property's - ctx.upsert_model( - load_sql_based_model( - expressions=[d.parse_one(converted_model_definition), converted_query], - jinja_macros=model.jinja_macros, - default_catalog=ctx.default_catalog, - ) - ) - converted_model = ctx.models['"memory"."project"."model1"'] - assert isinstance(converted_model, SqlModel) - - # {% is_incremental() %} block should be retained because removing it would make the model self-referencing but the columns_to_types - # arent all known so this would create a load error like: Error: Self-referencing models require inferrable column types. - assert "{% if is_incremental" in converted_model.query.sql() - assert "{{ this }}" not in converted_model.query.sql() - assert not converted_model.depends_on_self - - assert not columns_to_types_all_known( - converted_model.columns_to_types_or_raise - ) # this is ok because the model is not self-referencing - - -def test_convert_jinja_query_migrated_ref(empty_dbt_context: Context) -> None: - ctx = empty_dbt_context - - (ctx.path / "models" / "model1.sql").write_text(""" - {{ - config( - materialized='incremental', - incremental_strategy='delete+insert', - time_column='ds' - ) - }} - - {% macro ref_handler(relation) %} - {{ relation.name }} - {% endmacro %} - - select * from {{ ref_handler(ref("items")) }} - """) - - ctx.load() - - model = ctx.models['"memory"."project"."model1"'] - assert isinstance(model, SqlModel) - query = model.query - assert isinstance(query, d.JinjaQuery) - - converted_query = convert_jinja_query(ctx, model, query) - - assert ( - """select * from {{ ref_handler(__migrated_ref(database='memory', schema='project_raw', identifier='items', sqlmesh_model_name='project_raw.items')) }}""" - in converted_query.sql() - ) - - -def test_convert_jinja_query_post_statement(empty_dbt_context: Context) -> None: - ctx = empty_dbt_context - - (ctx.path / "models" / "model1.sql").write_text(""" - {{ - config( - materialized='incremental', - incremental_strategy='delete+insert', - time_column='ds', - post_hook="create index foo_idx on {{ this }} (id)" - ) - }} - - select * from {{ ref("items") }} - """) - - ctx.load() - - model = ctx.models['"memory"."project"."model1"'] - assert isinstance(model, SqlModel) - - assert model.post_statements - post_statement = model.post_statements[0] - assert isinstance(post_statement, d.JinjaStatement) - - converted_post_statement = convert_jinja_query(ctx, model, post_statement) - - assert "CREATE INDEX foo_idx ON project.model1(id)" in converted_post_statement.sql( - dialect="duckdb" - ) - - -@pytest.mark.parametrize( - "input,expected", - [ - ( - """ - {% macro incremental_by_time(col, time_type) %} - {% if is_incremental() %} - WHERE - {{ col }} > (select max({{ col }}) from {{ this }}) - {% endif %} - {% if sqlmesh_incremental is defined %} - {% set dates = incremental_dates_by_time_type(time_type) %} - WHERE - {{ col }} BETWEEN '{{ dates[0] }}' AND '{{ dates[1] }}' - {% endif %} - {% endmacro %} - """, - """ - {% macro incremental_by_time(col, time_type) %} - {% set dates = incremental_dates_by_time_type(time_type) %} - WHERE - {{ col }} BETWEEN '{{ dates[0] }}' AND '{{ dates[1] }}' - {% endmacro %} - """, - ), - ( - """ - {% macro foo(iterations) %} - with base as ( - select * from {{ ref('customer_revenue_by_day') }} - ), - iter as ( - {% for i in range(0, iterations) %} - 'iter_{{ i }}' as iter_num_{{ i }} - {% if not loop.last %},{% endif %} - {% endfor %} - ) - select 1 - {% endmacro %}""", - """ - {% macro foo(iterations) %} - with base as ( - select * from sushi.customer_revenue_by_day - ), - iter as ( - {% for i in range(0, iterations) %} - 'iter_{{ i }}' as iter_num_{{ i }} - {% if not loop.last %},{% endif %} - {% endfor %} - ) - select 1 - {% endmacro %}""", - ), - ( - """{% macro expand_ref(model_name) %}{{ ref(model_name) }}{% endmacro %}""", - """{% macro expand_ref(model_name) %}{{ ref(model_name) }}{% endmacro %}""", - ), - ], -) -def test_convert_jinja_macro(input: str, expected: str, sushi_dbt_context: Context) -> None: - result = convert_jinja_macro(sushi_dbt_context, input.strip()) - - assert " ".join(result.split()) == " ".join(expected.strip().split()) - - -def test_extract_macro_references_and_variables() -> None: - input = """JINJA_QUERY('{%- set something = "'"~var("variable").split("|") -%}""" - _, variables = extract_macro_references_and_variables(input) - assert len(variables) == 1 - assert variables == {"variable"} diff --git a/tests/dbt/converter/test_jinja_transforms.py b/tests/dbt/converter/test_jinja_transforms.py deleted file mode 100644 index c7d060ea40..0000000000 --- a/tests/dbt/converter/test_jinja_transforms.py +++ /dev/null @@ -1,453 +0,0 @@ -import pytest -import typing as t -from sqlglot import parse_one -from sqlmesh.core.model import create_sql_model, create_external_model -from sqlmesh.dbt.converter.jinja import transform, JinjaGenerator -import sqlmesh.dbt.converter.jinja_transforms as jt -from sqlmesh.dbt.converter.common import JinjaTransform -from sqlmesh.utils.jinja import environment, Environment, ENVIRONMENT -from sqlmesh.core.context import Context -from sqlmesh.core.config import Config, ModelDefaultsConfig - - -def transform_str( - input: str, handler: JinjaTransform, environment: t.Optional[Environment] = None -) -> str: - environment = environment or ENVIRONMENT - ast = environment.parse(input) - return JinjaGenerator().generate(transform(ast, handler)) - - -@pytest.mark.parametrize( - "input,expected", - [ - ("select * from {{ ref('bar') }} as t", "select * from foo.bar as t"), - ("select * from {{ ref('bar', version=1) }} as t", "select * from foo.bar_v1 as t"), - ("select * from {{ ref('bar', v=1) }} as t", "select * from foo.bar_v1 as t"), - ( - "select * from {{ ref('unknown') }} as t", - "select * from __unresolved_ref__.unknown as t", - ), - ( - "{% macro foo() %}select * from {{ ref('bar') }}{% endmacro %}", - "{% macro foo() %}select * from foo.bar{% endmacro %}", - ), - # these shouldnt be transformed as the macro call might rely on some property of the Relation object returned by ref() - ("{{ dbt_utils.union_relations([ref('foo')]) }},", None), - ("select * from {% if some_macro(ref('bar')) %}foo{% endif %}", None), - ( - "select * from {% if some_macro(ref('bar')) %}{{ ref('bar') }}{% endif %}", - "select * from {% if some_macro(ref('bar')) %}foo.bar{% endif %}", - ), - ("{{ some_macro(ref('bar')) }}", None), - ("{{ some_macro(table=ref('bar')) }}", None), - ], -) -def test_resolve_dbt_ref_to_model_name(input: str, expected: t.Optional[str]) -> None: - expected = expected or input - - from dbt.adapters.base import BaseRelation - - # note: bigquery dialect chosen because its identifiers have backticks - # but internally SQLMesh stores model fqn with double quotes - config = Config(model_defaults=ModelDefaultsConfig(dialect="bigquery")) - ctx = Context(config=config) - ctx.default_catalog = "sqlmesh" - - assert ctx.default_catalog == "sqlmesh" - assert ctx.default_dialect == "bigquery" - - model = create_sql_model( - name="foo.bar", query=parse_one("select 1"), default_catalog=ctx.default_catalog - ) - model2 = create_sql_model( - name="foo.bar_v1", query=parse_one("select 1"), default_catalog=ctx.default_catalog - ) - ctx.upsert_model(model) - ctx.upsert_model(model2) - - assert '"sqlmesh"."foo"."bar"' in ctx.models - - def _resolve_ref(ref_name: str, version: t.Optional[int] = None) -> t.Optional[BaseRelation]: - if ref_name == "bar": - identifier = "bar" - if version: - identifier = f"bar_v{version}" - - relation = BaseRelation.create( - database="sqlmesh", schema="foo", identifier=identifier, quote_character="`" - ) - assert ( - relation.render() == "`sqlmesh`.`foo`.`bar`" - if not version - else f"`sqlmesh`.`foo`.`bar_v{version}`" - ) - return relation - return None - - jinja_env = environment() - jinja_env.globals["ref"] = _resolve_ref - - assert ( - transform_str( - input, - jt.resolve_dbt_ref_to_model_name(ctx.models, jinja_env, dialect=ctx.default_dialect), - ) - == expected - ) - - -@pytest.mark.parametrize( - "input,expected", - [ - ( - "select * from {{ ref('bar') }} as t", - "select * from {{ __migrated_ref(database='sqlmesh', schema='foo', identifier='bar', sqlmesh_model_name='foo.bar') }} as t", - ), - ( - "{% macro foo() %}select * from {{ ref('bar') }}{% endmacro %}", - "{% macro foo() %}select * from {{ __migrated_ref(database='sqlmesh', schema='foo', identifier='bar', sqlmesh_model_name='foo.bar') }}{% endmacro %}", - ), - ( - "{{ dbt_utils.union_relations([ref('bar')]) }}", - "{{ dbt_utils.union_relations([__migrated_ref(database='sqlmesh', schema='foo', identifier='bar', sqlmesh_model_name='foo.bar')]) }}", - ), - ( - "select * from {% if some_macro(ref('bar')) %}foo{% endif %}", - "select * from {% if some_macro(__migrated_ref(database='sqlmesh', schema='foo', identifier='bar', sqlmesh_model_name='foo.bar')) %}foo{% endif %}", - ), - ( - "select * from {% if some_macro(ref('bar')) %}{{ ref('bar') }}{% endif %}", - "select * from {% if some_macro(__migrated_ref(database='sqlmesh', schema='foo', identifier='bar', sqlmesh_model_name='foo.bar')) %}{{ __migrated_ref(database='sqlmesh', schema='foo', identifier='bar', sqlmesh_model_name='foo.bar') }}{% endif %}", - ), - ( - "{{ some_macro(ref('bar')) }}", - "{{ some_macro(__migrated_ref(database='sqlmesh', schema='foo', identifier='bar', sqlmesh_model_name='foo.bar')) }}", - ), - ( - "{{ some_macro(table=ref('bar')) }}", - "{{ some_macro(table=__migrated_ref(database='sqlmesh', schema='foo', identifier='bar', sqlmesh_model_name='foo.bar')) }}", - ), - ], -) -def test_rewrite_dbt_ref_to_migrated_ref(input: str, expected: t.Optional[str]) -> None: - expected = expected or input - - from dbt.adapters.base import BaseRelation - - # note: bigquery dialect chosen because its identifiers have backticks - # but internally SQLMesh stores model fqn with double quotes - config = Config(model_defaults=ModelDefaultsConfig(dialect="bigquery")) - ctx = Context(config=config) - ctx.default_catalog = "sqlmesh" - - assert ctx.default_catalog == "sqlmesh" - assert ctx.default_dialect == "bigquery" - - model = create_sql_model( - name="foo.bar", query=parse_one("select 1"), default_catalog=ctx.default_catalog - ) - ctx.upsert_model(model) - - assert '"sqlmesh"."foo"."bar"' in ctx.models - - def _resolve_ref(ref_name: str) -> t.Optional[BaseRelation]: - if ref_name == "bar": - relation = BaseRelation.create( - database="sqlmesh", schema="foo", identifier="bar", quote_character="`" - ) - assert relation.render() == "`sqlmesh`.`foo`.`bar`" - return relation - return None - - jinja_env = environment() - jinja_env.globals["ref"] = _resolve_ref - - assert ( - transform_str( - input, - jt.rewrite_dbt_ref_to_migrated_ref(ctx.models, jinja_env, dialect=ctx.default_dialect), - ) - == expected - ) - - -@pytest.mark.parametrize( - "input,expected", - [ - ("select * from {{ source('upstream', 'foo') }} as t", "select * from upstream.foo as t"), - ("select * from {{ source('unknown', 'foo') }} as t", "select * from unknown.foo as t"), - ( - "{% macro foo() %}select * from {{ source('upstream', 'foo') }}{% endmacro %}", - "{% macro foo() %}select * from upstream.foo{% endmacro %}", - ), - # these shouldnt be transformed as the macro call might rely on some property of the Relation object returned by source() - ("select * from {% if some_macro(source('upstream', 'foo')) %}foo{% endif %}", None), - ("{{ dbt_utils.union_relations([source('upstream', 'foo')]) }},", None), - ( - "select * from {% if some_macro(source('upstream', 'foo')) %}{{ source('upstream', 'foo') }}{% endif %}", - "select * from {% if some_macro(source('upstream', 'foo')) %}upstream.foo{% endif %}", - ), - ("{{ some_macro(source('upstream', 'foo')) }}", None), - ("{% set results = run_query('select foo from ' ~ source('schema', 'table')) %}", None), - ], -) -def test_resolve_dbt_source_to_model_name(input: str, expected: t.Optional[str]) -> None: - expected = expected or input - - from dbt.adapters.base import BaseRelation - - # note: bigquery dialect chosen because its identifiers have backticks - # but internally SQLMesh stores model fqn with double quotes - config = Config(model_defaults=ModelDefaultsConfig(dialect="bigquery")) - ctx = Context(config=config) - ctx.default_catalog = "sqlmesh" - - assert ctx.default_catalog == "sqlmesh" - assert ctx.default_dialect == "bigquery" - - model = create_external_model(name="upstream.foo", default_catalog=ctx.default_catalog) - ctx.upsert_model(model) - - assert '"sqlmesh"."upstream"."foo"' in ctx.models - - def _resolve_source(schema_name: str, table_name: str) -> t.Optional[BaseRelation]: - if schema_name == "upstream" and table_name == "foo": - relation = BaseRelation.create( - database="sqlmesh", schema="upstream", identifier="foo", quote_character="`" - ) - assert relation.render() == "`sqlmesh`.`upstream`.`foo`" - return relation - return None - - jinja_env = environment() - jinja_env.globals["source"] = _resolve_source - - assert ( - transform_str( - input, - jt.resolve_dbt_source_to_model_name(ctx.models, jinja_env, dialect=ctx.default_dialect), - ) - == expected - ) - - -@pytest.mark.parametrize( - "input,expected", - [ - ( - "select * from {{ source('upstream', 'foo') }} as t", - "select * from {{ __migrated_source(database='sqlmesh', schema='upstream', identifier='foo') }} as t", - ), - ( - "select * from {{ source('unknown', 'foo') }} as t", - "select * from {{ source('unknown', 'foo') }} as t", - ), - ( - "{% macro foo() %}select * from {{ source('upstream', 'foo') }}{% endmacro %}", - "{% macro foo() %}select * from {{ __migrated_source(database='sqlmesh', schema='upstream', identifier='foo') }}{% endmacro %}", - ), - ( - "select * from {% if some_macro(source('upstream', 'foo')) %}foo{% endif %}", - "select * from {% if some_macro(__migrated_source(database='sqlmesh', schema='upstream', identifier='foo')) %}foo{% endif %}", - ), - ( - "{{ dbt_utils.union_relations([source('upstream', 'foo')]) }},", - "{{ dbt_utils.union_relations([__migrated_source(database='sqlmesh', schema='upstream', identifier='foo')]) }},", - ), - ( - "select * from {% if some_macro(source('upstream', 'foo')) %}{{ source('upstream', 'foo') }}{% endif %}", - "select * from {% if some_macro(__migrated_source(database='sqlmesh', schema='upstream', identifier='foo')) %}{{ __migrated_source(database='sqlmesh', schema='upstream', identifier='foo') }}{% endif %}", - ), - ( - "{{ some_macro(source('upstream', 'foo')) }}", - "{{ some_macro(__migrated_source(database='sqlmesh', schema='upstream', identifier='foo')) }}", - ), - ( - "{% set results = run_query('select foo from ' ~ source('upstream', 'foo')) %}", - "{% set results = run_query('select foo from ' ~ __migrated_source(database='sqlmesh', schema='upstream', identifier='foo')) %}", - ), - ], -) -def test_rewrite_dbt_source_to_migrated_source(input: str, expected: t.Optional[str]) -> None: - expected = expected or input - - from dbt.adapters.base import BaseRelation - - # note: bigquery dialect chosen because its identifiers have backticks - # but internally SQLMesh stores model fqn with double quotes - config = Config(model_defaults=ModelDefaultsConfig(dialect="bigquery")) - ctx = Context(config=config) - ctx.default_catalog = "sqlmesh" - - assert ctx.default_catalog == "sqlmesh" - assert ctx.default_dialect == "bigquery" - - model = create_external_model(name="upstream.foo", default_catalog=ctx.default_catalog) - ctx.upsert_model(model) - - assert '"sqlmesh"."upstream"."foo"' in ctx.models - - def _resolve_source(schema_name: str, table_name: str) -> t.Optional[BaseRelation]: - if schema_name == "upstream" and table_name == "foo": - relation = BaseRelation.create( - database="sqlmesh", schema="upstream", identifier="foo", quote_character="`" - ) - assert relation.render() == "`sqlmesh`.`upstream`.`foo`" - return relation - return None - - jinja_env = environment() - jinja_env.globals["source"] = _resolve_source - - assert ( - transform_str( - input, - jt.rewrite_dbt_source_to_migrated_source( - ctx.models, jinja_env, dialect=ctx.default_dialect - ), - ) - == expected - ) - - -@pytest.mark.parametrize( - "input,expected", - [ - ("select * from {{ this }}", "select * from foo.bar"), - ("{% if foo(this) %}bar{% endif %}", None), - ("select * from {{ this.identifier }}", None), - ], -) -def test_resolve_dbt_this_to_model_name(input: str, expected: t.Optional[str]): - expected = expected or input - assert transform_str(input, jt.resolve_dbt_this_to_model_name("foo.bar")) == expected - - -@pytest.mark.parametrize( - "input,expected", - [ - # sqlmesh_incremental present, is_incremental() block removed - ( - """ - select * from foo where - {% if is_incremental() %}ds > (select max(ds)) from {{ this }}){% endif %} - {% if sqlmesh_incremental is defined %}ds BETWEEN {{ start_ds }} and {{ end_ds }}{% endif %} - """, - """ - select * from foo - where - {% if sqlmesh_incremental is defined %}ds BETWEEN {{ start_ds }} and {{ end_ds }}{% endif %} - """, - ), - # sqlmesh_incremental is NOT present; is_incremental() blocks untouched - ( - """ - select * from foo - where - {% if is_incremental() %}ds > (select max(ds)) from {{ this }}){% endif %} - """, - """ - select * from foo - where - {% if is_incremental() %}ds > (select max(ds)) from {{ this }}){% endif %} - """, - ), - ], -) -def test_deduplicate_incremental_checks(input: str, expected: str) -> None: - assert " ".join(transform_str(input, jt.deduplicate_incremental_checks()).split()) == " ".join( - expected.split() - ) - - -@pytest.mark.parametrize( - "input,expected", - [ - # is_incremental() removed - ( - "select * from foo where {% if is_incremental() %}ds >= (select max(ds) from {{ this }} ){% endif %}", - "select * from foo where ds >= (select max(ds) from {{ this }} )", - ), - # sqlmesh_incremental removed - ( - "select * from foo where {% if sqlmesh_incremental is defined %}ds BETWEEN {{ start_ds }} and {{ end_ds }}{% endif %}", - "select * from foo where ds BETWEEN {{ start_ds }} and {{ end_ds }}", - ), - # else untouched - ( - "select * from foo where {% if is_incremental() %}ds >= (select max(ds) from {{ this }} ){% else %}ds is not null{% endif %}", - "select * from foo where {% if is_incremental() %}ds >= (select max(ds) from {{ this }} ){% else %}ds is not null{% endif %}", - ), - ], -) -def test_unpack_incremental_checks(input: str, expected: str) -> None: - assert " ".join(transform_str(input, jt.unpack_incremental_checks()).split()) == " ".join( - expected.split() - ) - - -@pytest.mark.parametrize( - "input,expected", - [ - ("{{ start_ds }}", "@start_ds"), - ( - "select id, ds from foo where ds between {{ start_ts }} and {{ end_ts }}", - "select id, ds from foo where ds between @start_ts and @end_ts", - ), - ("select {{ some_macro(start_ts) }}", None), - ("{{ start_date }}", "@start_date"), - ("'{{ start_date }}'", "'@start_ds'"), # date inside string literal should remain a string - ], -) -def test_rewrite_sqlmesh_predefined_variables_to_sqlmesh_macro_syntax( - input: str, expected: t.Optional[str] -) -> None: - expected = expected or input - assert ( - transform_str(input, jt.rewrite_sqlmesh_predefined_variables_to_sqlmesh_macro_syntax()) - == expected - ) - - -@pytest.mark.parametrize( - "input,expected,package", - [ - ("{{ var('foo') }}", "{{ var('foo') }}", None), - ("{{ var('foo') }}", "{{ var('foo', __dbt_package='test') }}", "test"), - ( - "{{ var('foo', 'default') }}", - "{{ var('foo', 'default', __dbt_package='test') }}", - "test", - ), - ( - "{% if 'col_name' in var('history_columns') %}bar{% endif %}", - "{% if 'col_name' in var('history_columns', __dbt_package='test') %}bar{% endif %}", - "test", - ), - ], -) -def test_append_dbt_package_kwarg_to_var_calls( - input: str, expected: str, package: t.Optional[str] -) -> None: - assert ( - transform_str(input, jt.append_dbt_package_kwarg_to_var_calls(package_name=package)) - == expected - ) - - -@pytest.mark.parametrize( - "input,expected", - [ - ( - "select * from foo where ds between '@start_dt' and '@end_dt'", - "SELECT * FROM foo WHERE ds BETWEEN @start_dt AND @end_dt", - ), - ( - "select * from foo where bar <> '@unrelated'", - "SELECT * FROM foo WHERE bar <> '@unrelated'", - ), - ], -) -def test_unwrap_macros_in_string_literals(input: str, expected: str) -> None: - assert parse_one(input).transform(jt.unwrap_macros_in_string_literals()).sql() == expected diff --git a/tests/utils/test_jinja.py b/tests/utils/test_jinja.py index 3660adaa95..5eb00aeb3c 100644 --- a/tests/utils/test_jinja.py +++ b/tests/utils/test_jinja.py @@ -9,8 +9,6 @@ MacroReturnVal, call_name, nodes, - extract_macro_references_and_variables, - extract_dbt_adapter_dispatch_targets, ) @@ -177,54 +175,6 @@ def test_macro_registry_trim(): assert not trimmed_registry_for_package_b.root_macros -def test_macro_registry_trim_keeps_dbt_adapter_dispatch(): - registry = JinjaMacroRegistry() - extractor = MacroExtractor() - - registry.add_macros( - extractor.extract( - """ - {% macro foo(col) %} - {{ adapter.dispatch('foo', 'test_package') }} - {% endmacro %} - - {% macro default__foo(col) %} - foo_{{ col }} - {% endmacro %} - - {% macro unrelated() %}foo{% endmacro %} - """, - dialect="duckdb", - ), - package="test_package", - ) - - assert sorted(list(registry.packages["test_package"].keys())) == [ - "default__foo", - "foo", - "unrelated", - ] - assert sorted(str(r) for r in registry.packages["test_package"]["foo"].depends_on) == [ - "adapter.dispatch", - "test_package.default__foo", - "test_package.duckdb__foo", - ] - - query_str = """ - select * from {{ test_package.foo('bar') }} - """ - - references, _ = extract_macro_references_and_variables(query_str, dbt_target_name="test") - references_list = list(references) - assert len(references_list) == 1 - assert str(references_list[0]) == "test_package.foo" - - trimmed_registry = registry.trim(references) - - # duckdb__foo is missing from this list because it's not actually defined as a macro - assert sorted(list(trimmed_registry.packages["test_package"].keys())) == ["default__foo", "foo"] - - def test_macro_return(): macros = "{% macro test_return() %}{{ macro_return([1, 2, 3]) }}{% endmacro %}" @@ -352,31 +302,3 @@ def test_dbt_adapter_macro_scope(): rendered = registry.build_environment().from_string("{{ spark__macro_a() }}").render() assert rendered.strip() == "macro_a" - - -def test_extract_dbt_adapter_dispatch_targets(): - assert extract_dbt_adapter_dispatch_targets(""" - {% macro my_macro(arg1, arg2) -%} - {{ return(adapter.dispatch('my_macro')(arg1, arg2)) }} - {% endmacro %} - """) == [("my_macro", None)] - - assert extract_dbt_adapter_dispatch_targets(""" - {% macro my_macro(arg1, arg2) -%} - {{ return(adapter.dispatch('my_macro', 'foo')(arg1, arg2)) }} - {% endmacro %} - """) == [("my_macro", "foo")] - - assert extract_dbt_adapter_dispatch_targets("""{{ adapter.dispatch('my_macro') }}""") == [ - ("my_macro", None) - ] - - assert extract_dbt_adapter_dispatch_targets(""" - {% macro foo() %} - {{ adapter.dispatch('my_macro') }} - {{ some_other_call() }} - {{ return(adapter.dispatch('other_macro', 'other_package')) }} - {% endmacro %} - """) == [("my_macro", None), ("other_macro", "other_package")] - - assert extract_dbt_adapter_dispatch_targets("no jinja") == [] From 6c20e4d1c327851ffe569704f8726175d857efb2 Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Thu, 28 Aug 2025 02:56:15 +0000 Subject: [PATCH 2/2] fix mypy errors for newer pandas --- sqlmesh/core/test/definition.py | 4 +++- tests/integrations/github/cicd/test_integration.py | 8 +++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sqlmesh/core/test/definition.py b/sqlmesh/core/test/definition.py index 8123f52d26..b995310d09 100644 --- a/sqlmesh/core/test/definition.py +++ b/sqlmesh/core/test/definition.py @@ -648,7 +648,9 @@ def _create_df( if partial: columns = referenced_columns - return pd.DataFrame.from_records(rows, columns=columns) + return pd.DataFrame.from_records( + rows, columns=[str(c) for c in columns] if columns else None + ) def _add_missing_columns( self, query: exp.Query, all_columns: t.Optional[t.Collection[str]] = None diff --git a/tests/integrations/github/cicd/test_integration.py b/tests/integrations/github/cicd/test_integration.py index e974ea6fc2..f78419889d 100644 --- a/tests/integrations/github/cicd/test_integration.py +++ b/tests/integrations/github/cicd/test_integration.py @@ -37,11 +37,9 @@ def get_environment_objects(controller: GithubController, environment: str) -> t def get_num_days_loaded(controller: GithubController, environment: str, model: str) -> int: - return int( - controller._context.engine_adapter.fetchdf( - f"SELECT distinct event_date FROM sushi__{environment}.{model}" - ).count() - ) + return controller._context.engine_adapter.fetchdf( + f"SELECT distinct event_date FROM sushi__{environment}.{model}" + ).shape[0] def get_columns(