Skip to content

Commit 34dc9fd

Browse files
authored
Feat!: bring dbt node information through to SQLMesh (#5412)
1 parent e00e860 commit 34dc9fd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+575
-229
lines changed

.circleci/continue_config.yml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,11 @@ jobs:
144144
- halt_unless_core
145145
- checkout
146146
- run:
147-
name: Run the migration test
148-
command: ./.circleci/test_migration.sh
147+
name: Run the migration test - sushi
148+
command: ./.circleci/test_migration.sh sushi "--gateway duckdb_persistent"
149+
- run:
150+
name: Run the migration test - sushi_dbt
151+
command: ./.circleci/test_migration.sh sushi_dbt "--config migration_test_config"
149152

150153
ui_style:
151154
docker:

.circleci/test_migration.sh

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
11
#!/usr/bin/env bash
22
set -ex
33

4-
GATEWAY_NAME="duckdb_persistent"
5-
TMP_DIR=$(mktemp -d)
6-
SUSHI_DIR="$TMP_DIR/sushi"
7-
8-
94
if [[ -z $(git tag --points-at HEAD) ]]; then
105
# If the current commit is not tagged, we need to find the last tag
116
LAST_TAG=$(git describe --tags --abbrev=0)
@@ -14,28 +9,48 @@ else
149
LAST_TAG=$(git tag --sort=-creatordate | head -n 2 | tail -n 1)
1510
fi
1611

12+
if [ "$1" == "" ]; then
13+
echo "Usage: $0 <example name> <sqlmesh opts>"
14+
echo "eg $0 sushi '--gateway duckdb_persistent'"
15+
exit 1
16+
fi
17+
18+
19+
TMP_DIR=$(mktemp -d)
20+
EXAMPLE_NAME="$1"
21+
SQLMESH_OPTS="$2"
22+
EXAMPLE_DIR="./examples/$EXAMPLE_NAME"
23+
TEST_DIR="$TMP_DIR/$EXAMPLE_NAME"
24+
25+
echo "Running migration test for '$EXAMPLE_NAME' in '$TEST_DIR' for example project '$EXAMPLE_DIR' using options '$SQLMESH_OPTS'"
26+
1727
git checkout $LAST_TAG
1828

1929
# Install dependencies from the previous release.
2030
make install-dev
2131

22-
cp -r ./examples/sushi $TMP_DIR
32+
cp -r $EXAMPLE_DIR $TEST_DIR
33+
34+
# this is only needed temporarily until the released tag for $LAST_TAG includes this config
35+
if [ "$EXAMPLE_NAME" == "sushi_dbt" ]; then
36+
echo 'migration_test_config = sqlmesh_config(Path(__file__).parent, dbt_target_name="duckdb")' >> $TEST_DIR/config.py
37+
fi
2338

2439
# Run initial plan
25-
pushd $SUSHI_DIR
40+
pushd $TEST_DIR
2641
rm -rf ./data/*
27-
sqlmesh --gateway $GATEWAY_NAME plan --no-prompts --auto-apply
42+
sqlmesh $SQLMESH_OPTS plan --no-prompts --auto-apply
2843
rm -rf .cache
2944
popd
3045

31-
# Switch back to the starting state of the repository
46+
# Switch back to the starting state of the repository
3247
git checkout -
3348

3449
# Install updated dependencies.
3550
make install-dev
3651

3752
# Migrate and make sure the diff is empty
38-
pushd $SUSHI_DIR
39-
sqlmesh --gateway $GATEWAY_NAME migrate
40-
sqlmesh --gateway $GATEWAY_NAME diff prod
41-
popd
53+
pushd $TEST_DIR
54+
sqlmesh $SQLMESH_OPTS migrate
55+
sqlmesh $SQLMESH_OPTS diff prod
56+
popd

examples/sushi_dbt/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@
55
config = sqlmesh_config(Path(__file__).parent)
66

77
test_config = config
8+
9+
migration_test_config = sqlmesh_config(Path(__file__).parent, dbt_target_name="duckdb")

sqlmesh/cli/project_init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ def init_example_project(
298298
dlt_path: t.Optional[str] = None,
299299
schema_name: str = "sqlmesh_example",
300300
cli_mode: InitCliMode = InitCliMode.DEFAULT,
301+
start: t.Optional[str] = None,
301302
) -> Path:
302303
root_path = Path(path)
303304

@@ -336,7 +337,6 @@ def init_example_project(
336337

337338
models: t.Set[t.Tuple[str, str]] = set()
338339
settings = None
339-
start = None
340340
if engine_type and template == ProjectTemplate.DLT:
341341
project_dialect = dialect or DIALECT_TO_TYPE.get(engine_type)
342342
if pipeline and project_dialect:

sqlmesh/core/audit/definition.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
sorted_python_env_payloads,
2020
)
2121
from sqlmesh.core.model.common import make_python_env, single_value_or_tuple, ParsableSql
22-
from sqlmesh.core.node import _Node
22+
from sqlmesh.core.node import _Node, DbtInfoMixin, DbtNodeInfo
2323
from sqlmesh.core.renderer import QueryRenderer
2424
from sqlmesh.utils.date import TimeLike
2525
from sqlmesh.utils.errors import AuditConfigError, SQLMeshError, raise_config_error
@@ -120,7 +120,7 @@ def audit_map_validator(cls: t.Type, v: t.Any, values: t.Any) -> t.Dict[str, t.A
120120
return {}
121121

122122

123-
class ModelAudit(PydanticModel, AuditMixin, frozen=True):
123+
class ModelAudit(PydanticModel, AuditMixin, DbtInfoMixin, frozen=True):
124124
"""
125125
Audit is an assertion made about your tables.
126126
@@ -137,6 +137,7 @@ class ModelAudit(PydanticModel, AuditMixin, frozen=True):
137137
expressions_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="expressions")
138138
jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry()
139139
formatting: t.Optional[bool] = Field(default=None, exclude=True)
140+
dbt_node_info_: t.Optional[DbtNodeInfo] = Field(alias="dbt_node_info", default=None)
140141

141142
_path: t.Optional[Path] = None
142143

@@ -150,6 +151,10 @@ def __str__(self) -> str:
150151
path = f": {self._path.name}" if self._path else ""
151152
return f"{self.__class__.__name__}<{self.name}{path}>"
152153

154+
@property
155+
def dbt_node_info(self) -> t.Optional[DbtNodeInfo]:
156+
return self.dbt_node_info_
157+
153158

154159
class StandaloneAudit(_Node, AuditMixin):
155160
"""
@@ -552,4 +557,5 @@ def _maybe_parse_arg_pair(e: exp.Expression) -> t.Tuple[str, exp.Expression]:
552557
"depends_on_": lambda value: exp.Tuple(expressions=sorted(value)),
553558
"tags": single_value_or_tuple,
554559
"default_catalog": exp.to_identifier,
560+
"dbt_node_info_": lambda value: value.to_expression(),
555561
}

sqlmesh/core/context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1697,9 +1697,9 @@ def plan_builder(
16971697
console=self.console,
16981698
user_provided_flags=user_provided_flags,
16991699
selected_models={
1700-
dbt_name
1700+
dbt_unique_id
17011701
for model in model_selector.expand_model_selections(select_models or "*")
1702-
if (dbt_name := snapshots[model].node.dbt_name)
1702+
if (dbt_unique_id := snapshots[model].node.dbt_unique_id)
17031703
},
17041704
explain=explain or False,
17051705
ignore_cron=ignore_cron or False,

sqlmesh/core/model/definition.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,6 +1197,9 @@ def metadata_hash(self) -> str:
11971197
for k, v in sorted(args.items()):
11981198
metadata.append(f"{k}:{gen(v)}")
11991199

1200+
if self.dbt_node_info:
1201+
metadata.append(self.dbt_node_info.json(sort_keys=True))
1202+
12001203
metadata.extend(self._additional_metadata)
12011204

12021205
self._metadata_hash = hash_data(metadata)
@@ -3019,6 +3022,7 @@ def render_expression(
30193022
"formatting": str,
30203023
"optimize_query": str,
30213024
"virtual_environment_mode": lambda value: exp.Literal.string(value.value),
3025+
"dbt_node_info_": lambda value: value.to_expression(),
30223026
}
30233027

30243028

sqlmesh/core/node.py

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,101 @@ def milliseconds(self) -> int:
153153
return self.seconds * 1000
154154

155155

156+
class DbtNodeInfo(PydanticModel):
157+
"""
158+
Represents dbt-specific model information set by the dbt loader and intended to be made available at the Snapshot level
159+
(as opposed to hidden within the individual model jinja macro registries).
160+
161+
This allows for things like injecting implementations of variables / functions into the Jinja context that are compatible with
162+
their dbt equivalents but are backed by the sqlmesh snapshots in any given plan / environment
163+
"""
164+
165+
unique_id: str
166+
"""This is the node/resource name/unique_id that's used as the node key in the dbt manifest.
167+
It's prefixed by the resource type and is exposed in context variables like {{ selected_resources }}.
168+
169+
Examples:
170+
- test.jaffle_shop.unique_stg_orders_order_id.e3b841c71a
171+
- seed.jaffle_shop.raw_payments
172+
- model.jaffle_shop.stg_orders
173+
"""
174+
175+
name: str
176+
"""Name of this object in the dbt global namespace, used by things like {{ ref() }} calls.
177+
178+
Examples:
179+
- unique_stg_orders_order_id
180+
- raw_payments
181+
- stg_orders
182+
"""
183+
184+
fqn: str
185+
"""Used for selectors in --select/--exclude.
186+
Takes the filesystem into account so may be structured differently to :unique_id.
187+
188+
Examples:
189+
- jaffle_shop.staging.unique_stg_orders_order_id
190+
- jaffle_shop.raw_payments
191+
- jaffle_shop.staging.stg_orders
192+
"""
193+
194+
alias: t.Optional[str] = None
195+
"""This is dbt's way of overriding the _physical table_ a model is written to.
196+
197+
It's used in the following situation:
198+
- Say you have two models, "stg_customers" and "customers"
199+
- You want "stg_customers" to be written to the "staging" schema as eg "staging.customers" - NOT "staging.stg_customers"
200+
- But you cant rename the file to "customers" because it will conflict with your other model file "customers"
201+
- Even if you put it in a different folder, eg "staging/customers.sql" - dbt still has a global namespace so it will conflict
202+
when you try to do something like "{{ ref('customers') }}"
203+
- So dbt's solution to this problem is to keep calling it "stg_customers" at the dbt project/model level,
204+
but allow overriding the physical table to "customers" via something like "{{ config(alias='customers', schema='staging') }}"
205+
206+
Note that if :alias is set, it does *not* replace :name at the model level and cannot be used interchangably with :name.
207+
It also does not affect the :fqn or :unique_id. It's just used to override :name when it comes time to generate the physical table name.
208+
"""
209+
210+
@model_validator(mode="after")
211+
def post_init(self) -> Self:
212+
# by default, dbt sets alias to the same as :name
213+
# however, we only want to include :alias if it is actually different / actually providing an override
214+
if self.alias == self.name:
215+
self.alias = None
216+
return self
217+
218+
def to_expression(self) -> exp.Expression:
219+
"""Produce a SQLGlot expression representing this object, for use in things like the model/audit definition renderers"""
220+
return exp.tuple_(
221+
*(
222+
exp.PropertyEQ(this=exp.var(k), expression=exp.Literal.string(v))
223+
for k, v in sorted(self.model_dump(exclude_none=True).items())
224+
)
225+
)
226+
227+
228+
class DbtInfoMixin:
229+
"""This mixin encapsulates properties that only exist for dbt compatibility and are otherwise not required
230+
for native projects"""
231+
232+
@property
233+
def dbt_node_info(self) -> t.Optional[DbtNodeInfo]:
234+
raise NotImplementedError()
235+
236+
@property
237+
def dbt_unique_id(self) -> t.Optional[str]:
238+
"""Used for compatibility with jinja context variables such as {{ selected_resources }}"""
239+
if self.dbt_node_info:
240+
return self.dbt_node_info.unique_id
241+
return None
242+
243+
@property
244+
def dbt_fqn(self) -> t.Optional[str]:
245+
"""Used in the selector engine for compatibility with selectors that select models by dbt fqn"""
246+
if self.dbt_node_info:
247+
return self.dbt_node_info.fqn
248+
return None
249+
250+
156251
# this must be sorted in descending order
157252
INTERVAL_SECONDS = {
158253
IntervalUnit.YEAR: 60 * 60 * 24 * 365,
@@ -165,7 +260,7 @@ def milliseconds(self) -> int:
165260
}
166261

167262

168-
class _Node(PydanticModel):
263+
class _Node(DbtInfoMixin, PydanticModel):
169264
"""
170265
Node is the core abstraction for entity that can be executed within the scheduler.
171266
@@ -199,7 +294,7 @@ class _Node(PydanticModel):
199294
interval_unit_: t.Optional[IntervalUnit] = Field(alias="interval_unit", default=None)
200295
tags: t.List[str] = []
201296
stamp: t.Optional[str] = None
202-
dbt_name: t.Optional[str] = None # dbt node name
297+
dbt_node_info_: t.Optional[DbtNodeInfo] = Field(alias="dbt_node_info", default=None)
203298
_path: t.Optional[Path] = None
204299
_data_hash: t.Optional[str] = None
205300
_metadata_hash: t.Optional[str] = None
@@ -446,6 +541,10 @@ def is_audit(self) -> bool:
446541
"""Return True if this is an audit node"""
447542
return False
448543

544+
@property
545+
def dbt_node_info(self) -> t.Optional[DbtNodeInfo]:
546+
return self.dbt_node_info_
547+
449548

450549
class NodeType(str, Enum):
451550
MODEL = "model"

sqlmesh/core/scheduler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -839,7 +839,9 @@ def _run_or_audit(
839839
run_environment_statements=run_environment_statements,
840840
audit_only=audit_only,
841841
auto_restatement_triggers=auto_restatement_triggers,
842-
selected_models={s.node.dbt_name for s in merged_intervals if s.node.dbt_name},
842+
selected_models={
843+
s.node.dbt_unique_id for s in merged_intervals if s.node.dbt_unique_id
844+
},
843845
)
844846

845847
return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS

sqlmesh/dbt/basemodel.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from sqlmesh.core.config.base import UpdateStrategy
1414
from sqlmesh.core.config.common import VirtualEnvironmentMode
1515
from sqlmesh.core.model import Model
16+
from sqlmesh.core.node import DbtNodeInfo
1617
from sqlmesh.dbt.column import (
1718
ColumnConfig,
1819
column_descriptions_to_sqlmesh,
@@ -120,8 +121,10 @@ class BaseModelConfig(GeneralConfig):
120121
grain: t.Union[str, t.List[str]] = []
121122

122123
# DBT configuration fields
124+
unique_id: str = ""
123125
name: str = ""
124126
package_name: str = ""
127+
fqn: t.List[str] = []
125128
schema_: str = Field("", alias="schema")
126129
database: t.Optional[str] = None
127130
alias: t.Optional[str] = None
@@ -273,12 +276,10 @@ def sqlmesh_config_fields(self) -> t.Set[str]:
273276
return {"description", "owner", "stamp", "storage_format"}
274277

275278
@property
276-
def node_name(self) -> str:
277-
resource_type = getattr(self, "resource_type", "model")
278-
node_name = f"{resource_type}.{self.package_name}.{self.name}"
279-
if self.version:
280-
node_name += f".v{self.version}"
281-
return node_name
279+
def node_info(self) -> DbtNodeInfo:
280+
return DbtNodeInfo(
281+
unique_id=self.unique_id, name=self.name, fqn=".".join(self.fqn), alias=self.alias
282+
)
282283

283284
def sqlmesh_model_kwargs(
284285
self,
@@ -349,8 +350,8 @@ def to_sqlmesh(
349350
def _model_jinja_context(
350351
self, context: DbtContext, dependencies: Dependencies
351352
) -> t.Dict[str, t.Any]:
352-
if context._manifest and self.node_name in context._manifest._manifest.nodes:
353-
attributes = context._manifest._manifest.nodes[self.node_name].to_dict()
353+
if context._manifest and self.unique_id in context._manifest._manifest.nodes:
354+
attributes = context._manifest._manifest.nodes[self.unique_id].to_dict()
354355
if dependencies.model_attrs.all_attrs:
355356
model_node: AttributeDict[str, t.Any] = AttributeDict(attributes)
356357
else:

0 commit comments

Comments
 (0)