Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def _statement_renderer(self, expression: exp.Expression) -> ExpressionRenderer:
python_env=self.python_env,
only_execution_time=False,
default_catalog=self.default_catalog,
model_fqn=self.fqn,
model=self,
)
return self._statement_renderer_cache[expression_key]

Expand Down Expand Up @@ -1573,14 +1573,14 @@ def _query_renderer(self) -> QueryRenderer:
self.dialect,
self.macro_definitions,
schema=self.mapping_schema,
model_fqn=self.fqn,
path=self._path,
jinja_macro_registry=self.jinja_macros,
python_env=self.python_env,
only_execution_time=self.kind.only_execution_time,
default_catalog=self.default_catalog,
quote_identifiers=not no_quote_identifiers,
optimize_query=self.optimize_query,
model=self,
)

@property
Expand Down
51 changes: 29 additions & 22 deletions sqlmesh/core/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from sqlglot.dialects.dialect import DialectType

from sqlmesh.core.linter.rule import Rule
from sqlmesh.core.model.definition import _Model
from sqlmesh.core.snapshot import DeployabilityIndex, Snapshot


Expand All @@ -50,9 +51,9 @@ def __init__(
schema: t.Optional[t.Dict[str, t.Any]] = None,
default_catalog: t.Optional[str] = None,
quote_identifiers: bool = True,
model_fqn: t.Optional[str] = None,
normalize_identifiers: bool = True,
optimize_query: t.Optional[bool] = True,
model: t.Optional[_Model] = None,
):
self._expression = expression
self._dialect = dialect
Expand All @@ -66,8 +67,9 @@ def __init__(
self._quote_identifiers = quote_identifiers
self.update_schema({} if schema is None else schema)
self._cache: t.List[t.Optional[exp.Expression]] = []
self._model_fqn = model_fqn
self._model_fqn = model.fqn if model else None
self._optimize_query_flag = optimize_query is not False
self._model = model

def update_schema(self, schema: t.Dict[str, t.Any]) -> None:
self.schema = d.normalize_mapping_schema(schema, dialect=self._dialect)
Expand Down Expand Up @@ -188,30 +190,32 @@ def _resolve_table(table: str | exp.Table) -> str:
}

variables = kwargs.pop("variables", {})
jinja_env_kwargs = {
**{
**render_kwargs,
**_prepare_python_env_for_jinja(macro_evaluator, self._python_env),
**variables,
},
"snapshots": snapshots or {},
"table_mapping": table_mapping,
"deployability_index": deployability_index,
"default_catalog": self._default_catalog,
"runtime_stage": runtime_stage.value,
"resolve_table": _resolve_table,
}
if this_model:
render_kwargs["this_model"] = this_model
jinja_env_kwargs["this_model"] = this_model.sql(
dialect=self._dialect, identify=True, comments=False
)

jinja_env = self._jinja_macro_registry.build_environment(**jinja_env_kwargs)

expressions = [self._expression]
if isinstance(self._expression, d.Jinja):
try:
jinja_env_kwargs = {
**{
**render_kwargs,
**_prepare_python_env_for_jinja(macro_evaluator, self._python_env),
**variables,
},
"snapshots": snapshots or {},
"table_mapping": table_mapping,
"deployability_index": deployability_index,
"default_catalog": self._default_catalog,
"runtime_stage": runtime_stage.value,
"resolve_table": _resolve_table,
"model_instance": self._model,
}

if this_model:
jinja_env_kwargs["this_model"] = this_model.sql(
dialect=self._dialect, identify=True, comments=False
)

jinja_env = self._jinja_macro_registry.build_environment(**jinja_env_kwargs)

expressions = []
rendered_expression = jinja_env.from_string(self._expression.name).render()
logger.debug(
Expand All @@ -229,6 +233,9 @@ def _resolve_table(table: str | exp.Table) -> str:
f"Could not render or parse jinja at '{self._path}'.\n{ex}"
) from ex

if this_model:
render_kwargs["this_model"] = this_model

macro_evaluator.locals.update(render_kwargs)

if variables:
Expand Down
33 changes: 16 additions & 17 deletions sqlmesh/dbt/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
DbtConfig,
Dependencies,
GeneralConfig,
RAW_CODE_KEY,
SqlStr,
sql_str_validator,
)
Expand Down Expand Up @@ -167,14 +168,6 @@ def _validate_grants(cls, v: t.Dict[str, str]) -> t.Dict[str, t.List[str]]:
},
}

@property
def sql_no_config(self) -> SqlStr:
return SqlStr("")

@property
def sql_embedded_config(self) -> SqlStr:
return SqlStr("")

@property
def table_schema(self) -> str:
"""
Expand Down Expand Up @@ -375,15 +368,21 @@ def to_sqlmesh(
def _model_jinja_context(
self, context: DbtContext, dependencies: Dependencies
) -> t.Dict[str, t.Any]:
model_node: AttributeDict[str, t.Any] = AttributeDict(
{
k: v
for k, v in context._manifest._manifest.nodes[self.node_name].to_dict().items()
if k in dependencies.model_attrs
}
if context._manifest and self.node_name in context._manifest._manifest.nodes
else {}
)
if context._manifest and self.node_name in context._manifest._manifest.nodes:
attributes = context._manifest._manifest.nodes[self.node_name].to_dict()
if dependencies.model_attrs.all_attrs:
model_node: AttributeDict[str, t.Any] = AttributeDict(attributes)
else:
model_node = AttributeDict(
filter(lambda kv: kv[0] in dependencies.model_attrs.attrs, attributes.items())
)

# We exclude the raw SQL code to reduce the payload size. It's still accessible through
# the JinjaQuery instance stored in the resulting SQLMesh model's `query` field.
model_node.pop(RAW_CODE_KEY, None)
else:
model_node = AttributeDict({})

return {
"this": self.relation_info,
"model": model_node,
Expand Down
11 changes: 11 additions & 0 deletions sqlmesh/dbt/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

from sqlmesh.core.console import get_console
from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.core.model.definition import SqlModel
from sqlmesh.core.snapshot.definition import DeployabilityIndex
from sqlmesh.dbt.adapter import BaseAdapter, ParsetimeAdapter, RuntimeAdapter
from sqlmesh.dbt.common import RAW_CODE_KEY
from sqlmesh.dbt.relation import Policy
from sqlmesh.dbt.target import TARGET_TYPE_TO_CONFIG_CLASS
from sqlmesh.dbt.util import DBT_VERSION
Expand Down Expand Up @@ -469,12 +471,21 @@ def create_builtin_globals(
is_incremental &= snapshot_table_exists
else:
is_incremental = False

builtin_globals["is_incremental"] = lambda: is_incremental

builtin_globals["builtins"] = AttributeDict(
{k: builtin_globals.get(k) for k in ("ref", "source", "config", "var")}
)

if (model := jinja_globals.pop("model", None)) is not None:
if isinstance(model_instance := jinja_globals.pop("model_instance", None), SqlModel):
builtin_globals["model"] = AttributeDict(
{**model, RAW_CODE_KEY: model_instance.query.name}
)
else:
builtin_globals["model"] = AttributeDict(model.copy())

if engine_adapter is not None:
builtin_globals["flags"] = Flags(which="run")
adapter: BaseAdapter = RuntimeAdapter(
Expand Down
18 changes: 15 additions & 3 deletions sqlmesh/dbt/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,26 @@

import re
import typing as t
from dataclasses import dataclass
from pathlib import Path

from ruamel.yaml.constructor import DuplicateKeyError
from sqlglot.helper import ensure_list

from sqlmesh.dbt.util import DBT_VERSION
from sqlmesh.core.config.base import BaseConfig, UpdateStrategy
from sqlmesh.core.config.common import DBT_PROJECT_FILENAME
from sqlmesh.utils import AttributeDict
from sqlmesh.utils.conversions import ensure_bool, try_str_to_bool
from sqlmesh.utils.errors import ConfigError
from sqlmesh.utils.jinja import MacroReference
from sqlmesh.utils.pydantic import PydanticModel, field_validator
from sqlmesh.utils.yaml import load
from sqlmesh.core.config.common import DBT_PROJECT_FILENAME

T = t.TypeVar("T", bound="GeneralConfig")

PROJECT_FILENAME = DBT_PROJECT_FILENAME
RAW_CODE_KEY = "raw_code" if DBT_VERSION >= (1, 3, 0) else "raw_sql" # type: ignore

JINJA_ONLY = {
"adapter",
Expand Down Expand Up @@ -172,6 +175,12 @@ def sqlmesh_config_fields(self) -> t.Set[str]:
return set()


@dataclass
class ModelAttrs:
attrs: t.Set[str]
all_attrs: bool = False


class Dependencies(PydanticModel):
"""
DBT dependencies for a model, macro, etc.
Expand All @@ -186,7 +195,7 @@ class Dependencies(PydanticModel):
sources: t.Set[str] = set()
refs: t.Set[str] = set()
variables: t.Set[str] = set()
model_attrs: t.Set[str] = set()
model_attrs: ModelAttrs = ModelAttrs(attrs=set())

has_dynamic_var_names: bool = False

Expand All @@ -196,7 +205,10 @@ def union(self, other: Dependencies) -> Dependencies:
sources=self.sources | other.sources,
refs=self.refs | other.refs,
variables=self.variables | other.variables,
model_attrs=self.model_attrs | other.model_attrs,
model_attrs=ModelAttrs(
attrs=self.model_attrs.attrs | other.model_attrs.attrs,
all_attrs=self.model_attrs.all_attrs or other.model_attrs.all_attrs,
),
has_dynamic_var_names=self.has_dynamic_var_names or other.has_dynamic_var_names,
)

Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/dbt/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _to_sqlmesh(config: BMC, context: DbtContext) -> Model:
package_models: t.Dict[str, BaseModelConfig] = {**package.models, **package.seeds}

for model in package_models.values():
if isinstance(model, ModelConfig) and not model.sql_no_config:
if isinstance(model, ModelConfig) and not model.sql.strip():
logger.info(f"Skipping empty model '{model.name}' at path '{model.path}'.")
continue

Expand Down
46 changes: 38 additions & 8 deletions sqlmesh/dbt/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@
from sqlmesh.core import constants as c
from sqlmesh.utils.errors import SQLMeshError
from sqlmesh.core.config import ModelDefaultsConfig
from sqlmesh.dbt.basemodel import Dependencies
from sqlmesh.dbt.builtin import BUILTIN_FILTERS, BUILTIN_GLOBALS, OVERRIDDEN_MACROS
from sqlmesh.dbt.common import Dependencies
from sqlmesh.dbt.model import ModelConfig
from sqlmesh.dbt.package import HookConfig, MacroConfig
from sqlmesh.dbt.seed import SeedConfig
Expand Down Expand Up @@ -354,7 +354,9 @@ def _load_models_and_seeds(self) -> None:
dependencies = Dependencies(
macros=macro_references, refs=_refs(node), sources=_sources(node)
)
dependencies = dependencies.union(self._extra_dependencies(sql, node.package_name))
dependencies = dependencies.union(
self._extra_dependencies(sql, node.package_name, track_all_model_attrs=True)
)
dependencies = dependencies.union(
self._flatten_dependencies_from_macros(dependencies.macros, node.package_name)
)
Expand Down Expand Up @@ -552,17 +554,37 @@ def _flatten_dependencies_from_macros(
dependencies = dependencies.union(macro_dependencies)
return dependencies

def _extra_dependencies(self, target: str, package: str) -> Dependencies:
# We sometimes observe that the manifest doesn't capture all macros, refs, and sources within a macro.
# This behavior has been observed with macros like dbt.current_timestamp(), dbt_utils.slugify(), and source().
# Here we apply our custom extractor to make a best effort to supplement references captured in the manifest.
def _extra_dependencies(
self,
target: str,
package: str,
track_all_model_attrs: bool = False,
) -> Dependencies:
"""
We sometimes observe that the manifest doesn't capture all macros, refs, and sources within a macro.
This behavior has been observed with macros like dbt.current_timestamp(), dbt_utils.slugify(), and source().
Here we apply our custom extractor to make a best effort to supplement references captured in the manifest.
"""
dependencies = Dependencies()

# Whether all `model` attributes (e.g., `model.config`) should be included in the dependencies
all_model_attrs = False

for call_name, node in extract_call_names(target, cache=self._calls):
if call_name[0] == "config":
continue
elif isinstance(node, jinja2.nodes.Getattr):

if (
track_all_model_attrs
and not all_model_attrs
and isinstance(node, jinja2.nodes.Call)
and any(isinstance(a, jinja2.nodes.Name) and a.name == "model" for a in node.args)
):
all_model_attrs = True

if isinstance(node, jinja2.nodes.Getattr):
if call_name[0] == "model":
dependencies.model_attrs.add(call_name[1])
dependencies.model_attrs.attrs.add(call_name[1])
elif call_name[0] == "source":
args = [jinja_call_arg_name(arg) for arg in node.args]
if args and all(arg for arg in args):
Expand Down Expand Up @@ -606,6 +628,14 @@ def _extra_dependencies(self, target: str, package: str) -> Dependencies:
call_name[0], call_name[1], dependencies.macros.append
)

# When `model` is referenced as-is, e.g. it's passed as an argument to a macro call like
# `{{ foo(model) }}`, we can't easily track the attributes that are actually used, because
# it may be aliased and hence tracking actual uses of `model` requires a proper data flow
# analysis. We conservatively deal with this by including all of its supported attributes
# if a standalone reference is found.
if all_model_attrs:
dependencies.model_attrs.all_attrs = True

return dependencies


Expand Down
Loading