Skip to content

Commit ca8cdb3

Browse files
committed
feat: dbt microbatch ref filter support
1 parent e0cd531 commit ca8cdb3

File tree

6 files changed

+180
-1
lines changed

6 files changed

+180
-1
lines changed

sqlmesh/core/renderer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,13 @@ def _resolve_table(table: str | exp.Table) -> str:
214214
dialect=self._dialect, identify=True, comments=False
215215
)
216216

217+
all_refs = list(self._jinja_macro_registry.global_objs.get("sources", {}).values()) + list( # type: ignore
218+
self._jinja_macro_registry.global_objs.get("refs", {}).values() # type: ignore
219+
)
220+
for ref in all_refs:
221+
if ref.event_time_filter:
222+
ref.event_time_filter["start"] = render_kwargs["start_tstz"]
223+
ref.event_time_filter["end"] = render_kwargs["end_tstz"]
217224
jinja_env = self._jinja_macro_registry.build_environment(**jinja_env_kwargs)
218225

219226
expressions = []

sqlmesh/dbt/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,25 @@
22
create_builtin_filters as create_builtin_filters,
33
create_builtin_globals as create_builtin_globals,
44
)
5+
from sqlmesh.dbt.util import DBT_VERSION
6+
7+
8+
if DBT_VERSION >= (1, 9, 0):
9+
from dbt.adapters.base.relation import BaseRelation, EventTimeFilter
10+
11+
def _render_event_time_filtered_inclusive(
12+
self: BaseRelation, event_time_filter: EventTimeFilter
13+
) -> str:
14+
"""
15+
Returns "" if start and end are both None
16+
"""
17+
filter = ""
18+
if event_time_filter.start and event_time_filter.end:
19+
filter = f"{event_time_filter.field_name} BETWEEN '{event_time_filter.start}' and '{event_time_filter.end}'"
20+
elif event_time_filter.start:
21+
filter = f"{event_time_filter.field_name} >= '{event_time_filter.start}'"
22+
elif event_time_filter.end:
23+
filter = f"{event_time_filter.field_name} <= '{event_time_filter.end}'"
24+
return filter
25+
26+
BaseRelation._render_event_time_filtered = _render_event_time_filtered_inclusive # type: ignore

sqlmesh/dbt/basemodel.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
)
2929
from sqlmesh.dbt.relation import Policy, RelationType
3030
from sqlmesh.dbt.test import TestConfig
31+
from sqlmesh.dbt.util import DBT_VERSION
3132
from sqlmesh.utils import AttributeDict
3233
from sqlmesh.utils.errors import ConfigError
3334
from sqlmesh.utils.pydantic import field_validator
@@ -130,6 +131,7 @@ class BaseModelConfig(GeneralConfig):
130131
grants: t.Dict[str, t.List[str]] = {}
131132
columns: t.Dict[str, ColumnConfig] = {}
132133
quoting: t.Dict[str, t.Optional[bool]] = {}
134+
event_time: t.Optional[str] = None
133135

134136
version: t.Optional[int] = None
135137
latest_version: t.Optional[int] = None
@@ -222,13 +224,20 @@ def relation_info(self) -> AttributeDict[str, t.Any]:
222224
else:
223225
relation_type = RelationType.Table
224226

227+
extras = {}
228+
if DBT_VERSION >= (1, 9, 0) and self.event_time:
229+
extras["event_time_filter"] = {
230+
"field_name": self.event_time,
231+
}
232+
225233
return AttributeDict(
226234
{
227235
"database": self.database,
228236
"schema": self.table_schema,
229237
"identifier": self.table_name,
230238
"type": relation_type.value,
231239
"quote_policy": AttributeDict(self.quoting),
240+
**extras,
232241
}
233242
)
234243

sqlmesh/dbt/source.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sqlmesh.dbt.column import ColumnConfig
99
from sqlmesh.dbt.common import GeneralConfig
1010
from sqlmesh.dbt.relation import RelationType
11+
from sqlmesh.dbt.util import DBT_VERSION
1112
from sqlmesh.utils import AttributeDict
1213
from sqlmesh.utils.errors import ConfigError
1314

@@ -46,6 +47,7 @@ class SourceConfig(GeneralConfig):
4647
external: t.Optional[t.Dict[str, t.Any]] = {}
4748
source_meta: t.Optional[t.Dict[str, t.Any]] = {}
4849
columns: t.Dict[str, ColumnConfig] = {}
50+
event_time: t.Optional[str] = None
4951

5052
_canonical_name: t.Optional[str] = None
5153

@@ -94,6 +96,11 @@ def relation_info(self) -> AttributeDict:
9496
if external_location:
9597
extras["external"] = external_location.replace("{name}", self.table_name)
9698

99+
if DBT_VERSION >= (1, 9, 0) and self.event_time:
100+
extras["event_time_filter"] = {
101+
"field_name": self.event_time,
102+
}
103+
97104
return AttributeDict(
98105
{
99106
"database": self.database,

tests/core/test_snapshot.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1078,7 +1078,9 @@ def test_fingerprint_jinja_macros_global_objs(model: Model, global_obj_key: str)
10781078
)
10791079
fingerprint = fingerprint_from_node(model, nodes={})
10801080
model = model.copy()
1081-
model.jinja_macros.global_objs[global_obj_key] = AttributeDict({"test": "test"})
1081+
model.jinja_macros.global_objs[global_obj_key] = AttributeDict(
1082+
{"test": AttributeDict({"test": "test"})}
1083+
)
10821084
updated_fingerprint = fingerprint_from_node(model, nodes={})
10831085
assert updated_fingerprint.data_hash != fingerprint.data_hash
10841086
assert updated_fingerprint.metadata_hash == fingerprint.metadata_hash

tests/dbt/test_model.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,3 +298,135 @@ def test_load_microbatch_required_only(
298298
column=exp.to_column("ds", quoted=True), format="%Y-%m-%d"
299299
)
300300
assert model.kind.batch_size is None
301+
302+
303+
@pytest.mark.slow
def test_load_microbatch_with_ref(
    tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project
) -> None:
    """Rendering a microbatch model pushes its start/end interval into upstream
    source()/ref() relations that declare an `event_time`."""
    project_dir, model_dir = create_empty_project()

    # A source table carrying an `event_time` config, so dbt attaches an
    # event-time filter to references of it.
    schema_doc = {
        "version": 2,
        "sources": [
            {
                "name": "my_source",
                "tables": [{"name": "my_table", "config": {"event_time": "ds"}}],
            }
        ],
    }
    with open(model_dir / "source_schema.yml", "w", encoding="utf-8") as out:
        YAML().dump(schema_doc, out)

    # First microbatch model selects from the event-time-enabled source.
    upstream_sql = """
{{
    config(
        materialized='incremental',
        incremental_strategy='microbatch',
        event_time='ds',
        begin='2020-01-01',
        batch_size='day'
    )
}}

SELECT cola, ds FROM {{ source('my_source', 'my_table') }}
"""
    with open(model_dir / "microbatch.sql", "w", encoding="utf-8") as out:
        out.write(upstream_sql)

    # Second microbatch model refs the first, which itself has event_time='ds'.
    downstream_sql = """
{{
    config(
        materialized='incremental',
        incremental_strategy='microbatch',
        event_time='ds',
        begin='2020-01-05',
        batch_size='day'
    )
}}

SELECT cola, ds FROM {{ ref('microbatch') }}
"""
    with open(model_dir / "microbatch_two.sql", "w", encoding="utf-8") as out:
        out.write(downstream_sql)

    context = Context(paths=project_dir)

    # The source reference gets wrapped in a BETWEEN filter over the interval.
    assert (
        context.render('"local"."main"."microbatch"', start="2025-01-01", end="2025-01-10").sql()
        == 'SELECT "cola" AS "cola", "ds" AS "ds" FROM (SELECT * FROM "local"."my_source"."my_table" AS "my_table" WHERE "ds" BETWEEN \'2025-01-01 00:00:00+00:00\' AND \'2025-01-10 23:59:59.999999+00:00\') AS "_q_0"'
    )
    # The ref to the upstream microbatch model is likewise interval-filtered.
    assert (
        context.render('"local"."main"."microbatch_two"', start="2025-01-01", end="2025-01-10").sql()
        == 'SELECT "_q_0"."cola" AS "cola", "_q_0"."ds" AS "ds" FROM (SELECT "microbatch"."cola" AS "cola", "microbatch"."ds" AS "ds" FROM "local"."main"."microbatch" AS "microbatch" WHERE "microbatch"."ds" <= \'2025-01-10 23:59:59.999999+00:00\' AND "microbatch"."ds" >= \'2025-01-01 00:00:00+00:00\') AS "_q_0"'
    )
367+
368+
369+
@pytest.mark.slow
def test_load_microbatch_with_ref_no_filter(
    tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project
) -> None:
    """Calling `.render()` on a source()/ref() relation opts out of the
    event-time filter: the rendered SQL contains no interval predicate."""
    project_dir, model_dir = create_empty_project()

    # Source table declares an `event_time`, but the models below bypass the
    # filter via `.render()`.
    schema_doc = {
        "version": 2,
        "sources": [
            {
                "name": "my_source",
                "tables": [{"name": "my_table", "config": {"event_time": "ds"}}],
            }
        ],
    }
    with open(model_dir / "source_schema.yml", "w", encoding="utf-8") as out:
        YAML().dump(schema_doc, out)

    # First microbatch model renders the source relation unfiltered.
    upstream_sql = """
{{
    config(
        materialized='incremental',
        incremental_strategy='microbatch',
        event_time='ds',
        begin='2020-01-01',
        batch_size='day'
    )
}}

SELECT cola, ds FROM {{ source('my_source', 'my_table').render() }}
"""
    with open(model_dir / "microbatch.sql", "w", encoding="utf-8") as out:
        out.write(upstream_sql)

    # Second microbatch model renders the ref relation unfiltered.
    downstream_sql = """
{{
    config(
        materialized='incremental',
        incremental_strategy='microbatch',
        event_time='ds',
        begin='2020-01-01',
        batch_size='day'
    )
}}

SELECT cola, ds FROM {{ ref('microbatch').render() }}
"""
    with open(model_dir / "microbatch_two.sql", "w", encoding="utf-8") as out:
        out.write(downstream_sql)

    context = Context(paths=project_dir)

    # No WHERE clause on the source reference.
    assert (
        context.render('"local"."main"."microbatch"', start="2025-01-01", end="2025-01-10").sql()
        == 'SELECT "cola" AS "cola", "ds" AS "ds" FROM "local"."my_source"."my_table" AS "my_table"'
    )
    # No WHERE clause on the ref either.
    assert (
        context.render('"local"."main"."microbatch_two"', start="2025-01-01", end="2025-01-10").sql()
        == 'SELECT "microbatch"."cola" AS "cola", "microbatch"."ds" AS "ds" FROM "local"."main"."microbatch" AS "microbatch"'
    )

0 commit comments

Comments
 (0)