Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 0 additions & 37 deletions sqlmesh/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
"rollback",
"run",
"table_name",
"dbt",
)
SKIP_CONTEXT_COMMANDS = ("init", "ui")

Expand Down Expand Up @@ -1307,39 +1306,3 @@ def state_import(obj: Context, input_file: Path, replace: bool, no_confirm: bool
"""Import a state export file back into the state database"""
confirm = not no_confirm
obj.import_state(input_file=input_file, clear=replace, confirm=confirm)


@cli.group(no_args_is_help=True, hidden=True)
def dbt() -> None:
    """Commands for doing dbt-specific things"""
    # Intentionally empty: this group exists only to namespace its
    # subcommands (e.g. `sqlmesh dbt convert`); click handles dispatch.


@dbt.command("convert")
@click.option(
    "-i",
    "--input-dir",
    help="Path to the DBT project",
    required=True,
    type=click.Path(exists=True, dir_okay=True, file_okay=False, readable=True, path_type=Path),
)
@click.option(
    "-o",
    "--output-dir",
    required=True,
    help="Path to write out the converted SQLMesh project",
    type=click.Path(exists=False, dir_okay=True, file_okay=False, readable=True, path_type=Path),
)
@click.option("--no-prompts", is_flag=True, help="Disable interactive prompts", default=False)
@click.pass_obj
@error_handler
@cli_analytics
def dbt_convert(obj: Context, input_dir: Path, output_dir: Path, no_prompts: bool) -> None:
    """Convert a DBT project to a SQLMesh project"""
    # Deferred import: the converter pulls in the dbt toolchain, which is
    # only needed when this subcommand actually runs.
    from sqlmesh.dbt.converter.convert import convert_project_files

    source = input_dir.absolute()
    destination = output_dir.absolute()
    convert_project_files(source, destination, no_prompts=no_prompts)
9 changes: 1 addition & 8 deletions sqlmesh/core/config/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
scheduler_config_validator,
)
from sqlmesh.core.config.ui import UIConfig
from sqlmesh.core.loader import Loader, SqlMeshLoader, MigratedDbtProjectLoader
from sqlmesh.core.loader import Loader, SqlMeshLoader
from sqlmesh.core.notification_target import NotificationTarget
from sqlmesh.core.user import User
from sqlmesh.utils.date import to_timestamp, now
Expand Down Expand Up @@ -227,13 +227,6 @@ def _normalize_and_validate_fields(cls, data: t.Any) -> t.Any:
f"^{k}$": v for k, v in physical_schema_override.items()
}

if (
(variables := data.get("variables", ""))
and isinstance(variables, dict)
and c.MIGRATED_DBT_PROJECT_NAME in variables
):
data["loader"] = MigratedDbtProjectLoader

return data

@model_validator(mode="after")
Expand Down
3 changes: 0 additions & 3 deletions sqlmesh/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@
MAX_MODEL_DEFINITION_SIZE = 10000
"""Maximum number of characters in a model definition"""

MIGRATED_DBT_PROJECT_NAME = "__dbt_project_name__"
MIGRATED_DBT_PACKAGES = "__dbt_packages__"


# The maximum number of fork processes, used for loading projects
# None means default to process pool, 1 means don't fork, :N is number of processes
Expand Down
109 changes: 1 addition & 108 deletions sqlmesh/core/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,7 @@
from sqlmesh.core.test import ModelTestMetadata, filter_tests_by_patterns
from sqlmesh.utils import UniqueKeyDict, sys_path
from sqlmesh.utils.errors import ConfigError
from sqlmesh.utils.jinja import (
JinjaMacroRegistry,
MacroExtractor,
SQLMESH_DBT_COMPATIBILITY_PACKAGE,
)
from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor
from sqlmesh.utils.metaprogramming import import_python_file
from sqlmesh.utils.pydantic import validation_error_message
from sqlmesh.utils.process import create_process_pool_executor
Expand Down Expand Up @@ -561,7 +557,6 @@ def _load_sql_models(
signals: UniqueKeyDict[str, signal],
cache: CacheBase,
gateway: t.Optional[str],
loading_default_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
) -> UniqueKeyDict[str, Model]:
"""Loads the sql models into a Dict"""
models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
Expand Down Expand Up @@ -604,7 +599,6 @@ def _load_sql_models(
signal_definitions=signals,
default_catalog_per_gateway=self.context.default_catalog_per_gateway,
virtual_environment_mode=self.config.virtual_environment_mode,
**loading_default_kwargs or {},
)

with create_process_pool_executor(
Expand Down Expand Up @@ -971,104 +965,3 @@ def _model_cache_entry_id(self, model_path: Path) -> str:
self._loader.context.gateway or self._loader.config.default_gateway_name,
]
)


class MigratedDbtProjectLoader(SqlMeshLoader):
    """Loader for SQLMesh projects that were converted from dbt.

    Extends :class:`SqlMeshLoader` so that jinja macros are organized into
    dbt-style packages and a dbt ``target`` global is available, which many
    dbt macros require.
    """

    @property
    def migrated_dbt_project_name(self) -> str:
        # The converter records the original dbt project name as a config
        # variable under this well-known key.
        return self.config.variables[c.MIGRATED_DBT_PROJECT_NAME]

    def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
        """Load python and jinja macros, attributing jinja macros to dbt packages.

        Returns a tuple of (python macro registry, jinja macro registry).
        """
        from sqlmesh.dbt.converter.common import infer_dbt_package_from_path
        from sqlmesh.dbt.target import TARGET_TYPE_TO_CONFIG_CLASS

        # Store a copy of the macro registry
        # (importing macro files below mutates the global registry; the copy
        # is restored afterwards so loading has no lasting global side effect)
        standard_macros = macro.get_registry()

        # Use the dbt-compatibility builtins and expose both the `dbt`
        # namespace and the migrated project's own package at the top level.
        jinja_macros = JinjaMacroRegistry(
            create_builtins_module=SQLMESH_DBT_COMPATIBILITY_PACKAGE,
            top_level_packages=["dbt", self.migrated_dbt_project_name],
        )
        extractor = MacroExtractor()

        # Track the newest macro-file mtime so downstream caching can detect staleness.
        macros_max_mtime: t.Optional[float] = None

        # Python macro files: importing them registers their macros globally.
        for path in self._glob_paths(
            self.config_path / c.MACROS,
            ignore_patterns=self.config.ignore_patterns,
            extension=".py",
        ):
            if import_python_file(path, self.config_path):
                self._track_file(path)
                macro_file_mtime = self._path_mtimes[path]
                macros_max_mtime = (
                    max(macros_max_mtime, macro_file_mtime)
                    if macros_max_mtime
                    else macro_file_mtime
                )

        # Jinja (.sql) macro files: extracted and attributed to a dbt package
        # inferred from the file's path, falling back to the project package.
        for path in self._glob_paths(
            self.config_path / c.MACROS,
            ignore_patterns=self.config.ignore_patterns,
            extension=".sql",
        ):
            self._track_file(path)
            macro_file_mtime = self._path_mtimes[path]
            macros_max_mtime = (
                max(macros_max_mtime, macro_file_mtime) if macros_max_mtime else macro_file_mtime
            )

            with open(path, "r", encoding="utf-8") as file:
                try:
                    package = infer_dbt_package_from_path(path) or self.migrated_dbt_project_name

                    jinja_macros.add_macros(
                        extractor.extract(file.read(), dialect=self.config.model_defaults.dialect),
                        package=package,
                    )
                except Exception as e:
                    raise ConfigError(f"Failed to load macro file: {e}", path)

        self._macros_max_mtime = macros_max_mtime

        # Snapshot everything registered during loading, then restore the
        # pre-load registry captured above.
        macros = macro.get_registry()
        macro.set_registry(standard_macros)

        connection_config = self.context.connection_config
        # this triggers the DBT create_builtins_module to have a `target` property which is required for a bunch of DBT macros to work
        if dbt_config_type := TARGET_TYPE_TO_CONFIG_CLASS.get(connection_config.type_):
            try:
                jinja_macros.add_globals(
                    {
                        "target": dbt_config_type.from_sqlmesh(
                            connection_config,
                            name=self.config.default_gateway_name,
                        ).attribute_dict()
                    }
                )
            except NotImplementedError:
                raise ConfigError(f"Unsupported dbt target type: {connection_config.type_}")

        return macros, jinja_macros

    def _load_sql_models(
        self,
        macros: MacroRegistry,
        jinja_macros: JinjaMacroRegistry,
        audits: UniqueKeyDict[str, ModelAudit],
        signals: UniqueKeyDict[str, signal],
        cache: CacheBase,
        gateway: t.Optional[str],
        loading_default_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
    ) -> UniqueKeyDict[str, Model]:
        """Load SQL models, tagging each with the migrated dbt project name.

        Note: the incoming ``loading_default_kwargs`` is intentionally not
        forwarded — it is replaced with the migrated project name so model
        loading can resolve dbt-package-scoped variables.
        """
        return super()._load_sql_models(
            macros=macros,
            jinja_macros=jinja_macros,
            audits=audits,
            signals=signals,
            cache=cache,
            gateway=gateway,
            loading_default_kwargs=dict(
                migrated_dbt_project_name=self.migrated_dbt_project_name,
            ),
        )
61 changes: 5 additions & 56 deletions sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2061,7 +2061,6 @@ def load_sql_based_model(
variables: t.Optional[t.Dict[str, t.Any]] = None,
infer_names: t.Optional[bool] = False,
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
migrated_dbt_project_name: t.Optional[str] = None,
**kwargs: t.Any,
) -> Model:
"""Load a model from a parsed SQLMesh model SQL file.
Expand Down Expand Up @@ -2239,7 +2238,6 @@ def load_sql_based_model(
query_or_seed_insert,
kind=kind,
time_column_format=time_column_format,
migrated_dbt_project_name=migrated_dbt_project_name,
**common_kwargs,
)

Expand Down Expand Up @@ -2451,7 +2449,6 @@ def _create_model(
signal_definitions: t.Optional[SignalRegistry] = None,
variables: t.Optional[t.Dict[str, t.Any]] = None,
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
migrated_dbt_project_name: t.Optional[str] = None,
**kwargs: t.Any,
) -> Model:
validate_extra_and_required_fields(
Expand Down Expand Up @@ -2531,31 +2528,16 @@ def _create_model(

if jinja_macros:
jinja_macros = (
jinja_macros
if jinja_macros.trimmed
else jinja_macros.trim(jinja_macro_references, package=migrated_dbt_project_name)
jinja_macros if jinja_macros.trimmed else jinja_macros.trim(jinja_macro_references)
)
else:
jinja_macros = JinjaMacroRegistry()

if migrated_dbt_project_name:
# extract {{ var() }} references used in all jinja macro dependencies to check for any variables specific
# to a migrated DBT package and resolve them accordingly
# vars are added into __sqlmesh_vars__ in the Python env so that the native SQLMesh var() function can resolve them
variables = variables or {}

nested_macro_used_variables, flattened_package_variables = (
_extract_migrated_dbt_variable_references(jinja_macros, variables)
for jinja_macro in jinja_macros.root_macros.values():
referenced_variables.update(
extract_macro_references_and_variables(jinja_macro.definition)[1]
)

referenced_variables.update(nested_macro_used_variables)
variables.update(flattened_package_variables)
else:
for jinja_macro in jinja_macros.root_macros.values():
referenced_variables.update(
extract_macro_references_and_variables(jinja_macro.definition)[1]
)

# Merge model-specific audits with default audits
if default_audits := defaults.pop("audits", None):
kwargs["audits"] = default_audits + d.extract_function_calls(kwargs.pop("audits", []))
Expand Down Expand Up @@ -2943,7 +2925,7 @@ def render_expression(
"cron_tz": lambda value: exp.Literal.string(value),
"partitioned_by_": _single_expr_or_tuple,
"clustered_by": _single_expr_or_tuple,
"depends_on_": lambda value: exp.Tuple(expressions=sorted(value)) if value else "()",
"depends_on_": lambda value: exp.Tuple(expressions=sorted(value)),
"pre": _list_of_calls_to_exp,
"post": _list_of_calls_to_exp,
"audits": _list_of_calls_to_exp,
Expand Down Expand Up @@ -3020,37 +3002,4 @@ def clickhouse_partition_func(
)


def _extract_migrated_dbt_variable_references(
    jinja_macros: JinjaMacroRegistry, project_variables: t.Dict[str, t.Any]
) -> t.Tuple[t.Set[str], t.Dict[str, t.Any]]:
    """Collect variable names referenced by a model's jinja macros and flatten
    migrated dbt package variables into dotted keys.

    Returns a tuple of (variable names referenced across all macros in the
    trimmed registry, flattened ``__dbt_packages__.<pkg>.<var>`` -> value map).

    Raises:
        ValueError: if the registry has not been trimmed to this model's macros.
    """
    if not jinja_macros.trimmed:
        raise ValueError("Expecting a trimmed JinjaMacroRegistry")

    # Registry is trimmed, so "all_macros" is exactly the macros this model uses.
    referenced: t.Set[str] = set()
    for _, _, macro_def in jinja_macros.all_macros:
        referenced |= extract_macro_references_and_variables(macro_def.definition)[1]

    def _walk(prefix: str, node: t.Dict[str, t.Any]) -> t.Iterator[t.Tuple[str, t.Any]]:
        # Depth-first traversal that preserves insertion order, yielding
        # (dotted-key, leaf-value) pairs to match what
        # extract_macro_references_and_variables() produces.
        for key, value in node.items():
            dotted = f"{prefix}.{key}"
            if isinstance(value, dict):
                yield from _walk(dotted, value)
            else:
                yield dotted, value

    flattened: t.Dict[str, t.Any] = {}
    package_vars = project_variables.get(c.MIGRATED_DBT_PACKAGES)
    if package_vars and isinstance(package_vars, dict):
        flattened = dict(_walk(c.MIGRATED_DBT_PACKAGES, package_vars))

    return referenced, flattened


# Dialect-specific builders for the partition expression derived from a model's
# time column; dialects absent from this map fall back to default handling.
TIME_COL_PARTITION_FUNC = {"clickhouse": clickhouse_partition_func}
5 changes: 2 additions & 3 deletions sqlmesh/core/model/kind.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from enum import Enum
from typing_extensions import Self

from pydantic import Field, BeforeValidator
from pydantic import Field
from sqlglot import exp
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import quote_identifiers
Expand Down Expand Up @@ -33,7 +33,6 @@
field_validator,
get_dialect,
validate_string,
positive_int_validator,
validate_expression,
)

Expand Down Expand Up @@ -505,7 +504,7 @@ class IncrementalByUniqueKeyKind(_IncrementalBy):
unique_key: SQLGlotListOfFields
when_matched: t.Optional[exp.Whens] = None
merge_filter: t.Optional[exp.Expression] = None
batch_concurrency: t.Annotated[t.Literal[1], BeforeValidator(positive_int_validator)] = 1
batch_concurrency: t.Literal[1] = 1

@field_validator("when_matched", mode="before")
def _when_matched_validator(
Expand Down
1 change: 0 additions & 1 deletion sqlmesh/core/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ def _resolve_table(table: str | exp.Table) -> str:
)

render_kwargs = {
"dialect": self._dialect,
**date_dict(
to_datetime(execution_time or c.EPOCH),
start_time,
Expand Down
4 changes: 3 additions & 1 deletion sqlmesh/core/test/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,9 @@ def _create_df(
if partial:
columns = referenced_columns

return pd.DataFrame.from_records(rows, columns=columns)
return pd.DataFrame.from_records(
rows, columns=[str(c) for c in columns] if columns else None
)

def _add_missing_columns(
self, query: exp.Query, all_columns: t.Optional[t.Collection[str]] = None
Expand Down
3 changes: 0 additions & 3 deletions sqlmesh/dbt/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ def __init__(
self.jinja_globals = jinja_globals.copy() if jinja_globals else {}
self.jinja_globals["adapter"] = self
self.project_dialect = project_dialect
self.jinja_globals["dialect"] = (
project_dialect # so the dialect is available in the jinja env created by self.dispatch()
)
self.quote_policy = quote_policy or Policy()

@abc.abstractmethod
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/dbt/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class Var:
def __init__(self, variables: t.Dict[str, t.Any]) -> None:
self.variables = variables

def __call__(self, name: str, default: t.Optional[t.Any] = None, **kwargs: t.Any) -> t.Any:
def __call__(self, name: str, default: t.Optional[t.Any] = None) -> t.Any:
return self.variables.get(name, default)

def has_var(self, name: str) -> bool:
Expand Down
Empty file removed sqlmesh/dbt/converter/__init__.py
Empty file.
Loading