Commit a1ab547

Feat: Add support for pre, post, on_virtual_update statements in config (#4995)
1 parent 812bc27 commit a1ab547

9 files changed: +379 -1 lines changed
docs/concepts/models/python_models.md

Lines changed: 4 additions & 0 deletions
@@ -102,6 +102,8 @@ For example, pre/post-statements might modify settings or create indexes. Howeve
 
 You can set the `pre_statements` and `post_statements` arguments to a list of SQL strings, SQLGlot expressions, or macro calls to define the model's pre/post-statements.
 
+**Project-level defaults:** You can also define pre/post-statements at the project level using `model_defaults` in your configuration. These will be applied to all models in your project and merged with any model-specific statements. Default statements are executed first, followed by model-specific statements. Learn more about this in the [model configuration reference](../../reference/model_configuration.md#model-defaults).
+
 ``` python linenums="1" hl_lines="8-12"
 @model(
     "db.test_model",
@@ -182,6 +184,8 @@ These can be used, for example, to grant privileges on views of the virtual laye
 
 Similar to pre/post-statements you can set the `on_virtual_update` argument in the `@model` decorator to a list of SQL strings, SQLGlot expressions, or macro calls.
 
+**Project-level defaults:** You can also define on-virtual-update statements at the project level using `model_defaults` in your configuration. These will be applied to all models in your project (including Python models) and merged with any model-specific statements. Default statements are executed first, followed by model-specific statements. Learn more about this in the [model configuration reference](../../reference/model_configuration.md#model-defaults).
+
 ``` python linenums="1" hl_lines="8"
 @model(
     "db.test_model",

docs/concepts/models/seed_models.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ ALTER SESSION SET TIMEZONE = 'PST';
203203

204204
Seed models also support on-virtual-update statements, which are executed after the completion of the [Virtual Update](#virtual-update).
205205

206+
**Project-level defaults:** You can also define on-virtual-update statements at the project level using `model_defaults` in your configuration. These will be applied to all models in your project (including seed models) and merged with any model-specific statements. Default statements are executed first, followed by model-specific statements. Learn more about this in the [model configuration reference](../../reference/model_configuration.md#model-defaults).
207+
206208
These must be enclosed within an `ON_VIRTUAL_UPDATE_BEGIN;` ...; `ON_VIRTUAL_UPDATE_END;` block:
207209

208210
```sql linenums="1" hl_lines="8-13"

docs/concepts/models/sql_models.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ For example, pre/post-statements might modify settings or create a table index.
6767

6868
Pre/post-statements are just standard SQL commands located before/after the model query. They must end with a semi-colon, and the model query must end with a semi-colon if a post-statement is present. The [example above](#example) contains both pre- and post-statements.
6969

70+
**Project-level defaults:** You can also define pre/post-statements at the project level using `model_defaults` in your configuration. These will be applied to all models in your project and merged with any model-specific statements. Default statements are executed first, followed by model-specific statements. Learn more about this in the [model configuration reference](../../reference/model_configuration.md#model-defaults).
71+
7072
!!! warning
7173

7274
Pre/post-statements are evaluated twice: when a model's table is created and when its query logic is evaluated. Executing statements more than once can have unintended side-effects, so you can [conditionally execute](../macros/sqlmesh_macros.md#prepost-statements) them based on SQLMesh's [runtime stage](../macros/macro_variables.md#runtime-variables).
@@ -97,6 +99,8 @@ The optional on-virtual-update statements allow you to execute SQL commands afte
9799

98100
These can be used, for example, to grant privileges on views of the virtual layer.
99101

102+
**Project-level defaults:** You can also define on-virtual-update statements at the project level using `model_defaults` in your configuration. These will be applied to all models in your project and merged with any model-specific statements. Default statements are executed first, followed by model-specific statements. Learn more about this in the [model configuration reference](../../reference/model_configuration.md#model-defaults).
103+
100104
These SQL statements must be enclosed within an `ON_VIRTUAL_UPDATE_BEGIN;` ...; `ON_VIRTUAL_UPDATE_END;` block like this:
101105

102106
```sql linenums="1" hl_lines="10-15"

docs/reference/model_configuration.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,42 @@ You can also use the `@model_kind_name` variable to fine-tune control over `phys
136136
)
137137
```
138138

139+
You can aso define `pre_statements`, `post_statements` and `on_virtual_update` statements at the project level that will be applied to all models. These default statements are merged with any model-specific statements, with default statements executing first, followed by model-specific statements.
140+
141+
=== "YAML"
142+
143+
```yaml linenums="1"
144+
model_defaults:
145+
dialect: duckdb
146+
pre_statements:
147+
- "SET timeout = 300000"
148+
post_statements:
149+
- "@IF(@runtime_stage = 'evaluating', ANALYZE @this_model)"
150+
on_virtual_update:
151+
- "GRANT SELECT ON @this_model TO ROLE analyst_role"
152+
```
153+
154+
=== "Python"
155+
156+
```python linenums="1"
157+
from sqlmesh.core.config import Config, ModelDefaultsConfig
158+
159+
config = Config(
160+
model_defaults=ModelDefaultsConfig(
161+
dialect="duckdb",
162+
pre_statements=[
163+
"SET query_timeout = 300000",
164+
],
165+
post_statements=[
166+
"@IF(@runtime_stage = 'evaluating', ANALYZE @this_model)",
167+
],
168+
on_virtual_update=[
169+
"GRANT SELECT ON @this_model TO ROLE analyst_role",
170+
],
171+
),
172+
)
173+
```
174+
139175

140176
The SQLMesh project-level `model_defaults` key supports the following options, described in the [general model properties](#general-model-properties) table above:
141177

@@ -155,6 +191,9 @@ The SQLMesh project-level `model_defaults` key supports the following options, d
155191
- allow_partials
156192
- enabled
157193
- interval_unit
194+
- pre_statements (described [here](../concepts/models/sql_models.md#pre--and-post-statements))
195+
- post_statements (described [here](../concepts/models/sql_models.md#pre--and-post-statements))
196+
- on_virtual_update (described [here](../concepts/models/sql_models.md#on-virtual-update-statements))
158197

159198

160199
### Model Naming

sqlmesh/core/config/model.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import typing as t
44

5+
from sqlglot import exp
56
from sqlmesh.core.dialect import parse_one, extract_func_call
67
from sqlmesh.core.config.base import BaseConfig
78
from sqlmesh.core.model.kind import (
@@ -41,6 +42,9 @@ class ModelDefaultsConfig(BaseConfig):
4142
allow_partials: Whether the models can process partial (incomplete) data intervals.
4243
enabled: Whether the models are enabled.
4344
interval_unit: The temporal granularity of the models data intervals. By default computed from cron.
45+
pre_statements: The list of SQL statements that get executed before a model runs.
46+
post_statements: The list of SQL statements that get executed before a model runs.
47+
on_virtual_update: The list of SQL statements to be executed after the virtual update.
4448
4549
"""
4650

@@ -61,6 +65,9 @@ class ModelDefaultsConfig(BaseConfig):
6165
interval_unit: t.Optional[t.Union[str, IntervalUnit]] = None
6266
enabled: t.Optional[t.Union[str, bool]] = None
6367
formatting: t.Optional[t.Union[str, bool]] = None
68+
pre_statements: t.Optional[t.List[t.Union[str, exp.Expression]]] = None
69+
post_statements: t.Optional[t.List[t.Union[str, exp.Expression]]] = None
70+
on_virtual_update: t.Optional[t.List[t.Union[str, exp.Expression]]] = None
6471

6572
_model_kind_validator = model_kind_validator
6673
_on_destructive_change_validator = on_destructive_change_validator

sqlmesh/core/model/definition.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2472,6 +2472,24 @@ def _create_model(
24722472

24732473
statements: t.List[t.Union[exp.Expression, t.Tuple[exp.Expression, bool]]] = []
24742474

2475+
# Merge default pre_statements with model-specific pre_statements
2476+
if "pre_statements" in defaults:
2477+
kwargs["pre_statements"] = [
2478+
exp.maybe_parse(stmt, dialect=dialect) for stmt in defaults["pre_statements"]
2479+
] + kwargs.get("pre_statements", [])
2480+
2481+
# Merge default post_statements with model-specific post_statements
2482+
if "post_statements" in defaults:
2483+
kwargs["post_statements"] = [
2484+
exp.maybe_parse(stmt, dialect=dialect) for stmt in defaults["post_statements"]
2485+
] + kwargs.get("post_statements", [])
2486+
2487+
# Merge default on_virtual_update with model-specific on_virtual_update
2488+
if "on_virtual_update" in defaults:
2489+
kwargs["on_virtual_update"] = [
2490+
exp.maybe_parse(stmt, dialect=dialect) for stmt in defaults["on_virtual_update"]
2491+
] + kwargs.get("on_virtual_update", [])
2492+
24752493
if "pre_statements" in kwargs:
24762494
statements.extend(kwargs["pre_statements"])
24772495
if "query" in kwargs:
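
The merge in this hunk follows one pattern for all three keys: default statements are placed first, and the model's own statements are appended. The sketch below restates that ordering in plain Python, under the assumption that statements are simple strings; the real code parses each default with `exp.maybe_parse` and mutates `kwargs` in place. The helper name `merge_default_statements` and the constant `STATEMENT_KEYS` are hypothetical and do not exist in SQLMesh.

```python
from typing import Any, Dict

# Hypothetical constant: the three statement keys handled by the hunk above.
STATEMENT_KEYS = ("pre_statements", "post_statements", "on_virtual_update")


def merge_default_statements(defaults: Dict[str, Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of kwargs with default statements prepended to model-specific ones."""
    merged = dict(kwargs)
    for key in STATEMENT_KEYS:
        # Keys missing from the defaults are left untouched, mirroring the `if key in defaults` checks.
        if key in defaults:
            merged[key] = list(defaults[key]) + list(kwargs.get(key, []))
    return merged


model_kwargs = merge_default_statements(
    defaults={"pre_statements": ["SET memory_limit = '10GB'"]},
    kwargs={"pre_statements": ["SET threads = 4"], "post_statements": ["ANALYZE t"]},
)
print(model_kwargs["pre_statements"])
```

Returning a copy instead of mutating `kwargs` is a choice made for the sketch only; it keeps the example easy to test in isolation.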

tests/core/test_config.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,65 @@ def test_load_model_defaults_audits(tmp_path):
676676
assert config.model_defaults.audits[1][1]["threshold"].this == "1000"
677677

678678

679+
def test_load_model_defaults_statements(tmp_path):
680+
config_path = tmp_path / "config_model_defaults_statements.yaml"
681+
with open(config_path, "w", encoding="utf-8") as fd:
682+
fd.write(
683+
"""
684+
model_defaults:
685+
dialect: duckdb
686+
pre_statements:
687+
- SET memory_limit = '10GB'
688+
- CREATE TEMP TABLE temp_data AS SELECT 1 as id
689+
post_statements:
690+
- DROP TABLE IF EXISTS temp_data
691+
- ANALYZE @this_model
692+
- SET memory_limit = '5GB'
693+
on_virtual_update:
694+
- UPDATE stats_table SET last_update = CURRENT_TIMESTAMP
695+
"""
696+
)
697+
698+
config = load_config_from_paths(
699+
Config,
700+
project_paths=[config_path],
701+
)
702+
703+
assert config.model_defaults.pre_statements is not None
704+
assert len(config.model_defaults.pre_statements) == 2
705+
assert isinstance(exp.maybe_parse(config.model_defaults.pre_statements[0]), exp.Set)
706+
assert isinstance(exp.maybe_parse(config.model_defaults.pre_statements[1]), exp.Create)
707+
708+
assert config.model_defaults.post_statements is not None
709+
assert len(config.model_defaults.post_statements) == 3
710+
assert isinstance(exp.maybe_parse(config.model_defaults.post_statements[0]), exp.Drop)
711+
assert isinstance(exp.maybe_parse(config.model_defaults.post_statements[1]), exp.Analyze)
712+
assert isinstance(exp.maybe_parse(config.model_defaults.post_statements[2]), exp.Set)
713+
714+
assert config.model_defaults.on_virtual_update is not None
715+
assert len(config.model_defaults.on_virtual_update) == 1
716+
assert isinstance(exp.maybe_parse(config.model_defaults.on_virtual_update[0]), exp.Update)
717+
718+
719+
def test_load_model_defaults_validation_statements(tmp_path):
720+
config_path = tmp_path / "config_model_defaults_statements_wrong.yaml"
721+
with open(config_path, "w", encoding="utf-8") as fd:
722+
fd.write(
723+
"""
724+
model_defaults:
725+
dialect: duckdb
726+
pre_statements:
727+
- 313
728+
"""
729+
)
730+
731+
with pytest.raises(TypeError, match=r"expected str instance, int found"):
732+
config = load_config_from_paths(
733+
Config,
734+
project_paths=[config_path],
735+
)
736+
737+
679738
def test_scheduler_config(tmp_path_factory):
680739
config_path = tmp_path_factory.mktemp("yaml_config") / "config.yaml"
681740
with open(config_path, "w", encoding="utf-8") as fd:

tests/core/test_context.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2731,3 +2731,139 @@ def _get_missing_intervals(name: str) -> t.List[t.Tuple[datetime, datetime]]:
27312731
assert context.engine_adapter.fetchall(
27322732
"select min(start_dt), max(end_dt) from sqlmesh_example__pr_env.unrelated_monthly_model"
27332733
) == [(to_datetime("2020-01-01 00:00:00"), to_datetime("2020-01-31 23:59:59.999999"))]
2734+
2735+
2736+
def test_defaults_pre_post_statements(tmp_path: Path):
2737+
config_path = tmp_path / "config.yaml"
2738+
models_path = tmp_path / "models"
2739+
models_path.mkdir()
2740+
2741+
# Create config with default statements
2742+
config_path.write_text(
2743+
"""
2744+
model_defaults:
2745+
dialect: duckdb
2746+
pre_statements:
2747+
- SET memory_limit = '10GB'
2748+
- SET threads = @var1
2749+
post_statements:
2750+
- ANALYZE @this_model
2751+
variables:
2752+
var1: 4
2753+
"""
2754+
)
2755+
2756+
# Create a model
2757+
model_path = models_path / "test_model.sql"
2758+
model_path.write_text(
2759+
"""
2760+
MODEL (
2761+
name test_model,
2762+
kind FULL
2763+
);
2764+
2765+
SELECT 1 as id, 'test' as status;
2766+
"""
2767+
)
2768+
2769+
ctx = Context(paths=[tmp_path])
2770+
2771+
# Initial plan and apply
2772+
initial_plan = ctx.plan(auto_apply=True, no_prompts=True)
2773+
assert len(initial_plan.new_snapshots) == 1
2774+
2775+
snapshot = list(initial_plan.new_snapshots)[0]
2776+
model = snapshot.model
2777+
2778+
# Verify statements are in the model and python environment has been popuplated
2779+
assert len(model.pre_statements) == 2
2780+
assert len(model.post_statements) == 1
2781+
assert model.python_env[c.SQLMESH_VARS].payload == "{'var1': 4}"
2782+
2783+
# Verify the statements contain the expected SQL
2784+
assert model.pre_statements[0].sql() == "SET memory_limit = '10GB'"
2785+
assert model.render_pre_statements()[0].sql() == "SET \"memory_limit\" = '10GB'"
2786+
assert model.pre_statements[1].sql() == "SET threads = @var1"
2787+
assert model.render_pre_statements()[1].sql() == 'SET "threads" = 4'
2788+
2789+
# Update config to change pre_statement
2790+
config_path.write_text(
2791+
"""
2792+
model_defaults:
2793+
dialect: duckdb
2794+
pre_statements:
2795+
- SET memory_limit = '5GB' # Changed value
2796+
post_statements:
2797+
- ANALYZE @this_model
2798+
"""
2799+
)
2800+
2801+
# Reload context and create new plan
2802+
ctx = Context(paths=[tmp_path])
2803+
updated_plan = ctx.plan(no_prompts=True)
2804+
2805+
# Should detect a change due to different pre_statements
2806+
assert len(updated_plan.directly_modified) == 1
2807+
2808+
# Apply the plan
2809+
ctx.apply(updated_plan)
2810+
2811+
# Reload the models to get the updated version
2812+
ctx.load()
2813+
new_model = ctx.models['"test_model"']
2814+
2815+
# Verify updated statements
2816+
assert len(new_model.pre_statements) == 1
2817+
assert new_model.pre_statements[0].sql() == "SET memory_limit = '5GB'"
2818+
assert new_model.render_pre_statements()[0].sql() == "SET \"memory_limit\" = '5GB'"
2819+
2820+
# Verify the change was detected by the plan
2821+
assert len(updated_plan.directly_modified) == 1
2822+
2823+
2824+
def test_model_defaults_statements_with_on_virtual_update(tmp_path: Path):
2825+
config_path = tmp_path / "config.yaml"
2826+
models_path = tmp_path / "models"
2827+
models_path.mkdir()
2828+
2829+
# Create config with on_virtual_update
2830+
config_path.write_text(
2831+
"""
2832+
model_defaults:
2833+
dialect: duckdb
2834+
on_virtual_update:
2835+
- SELECT 'Model-defailt virtual update' AS message
2836+
"""
2837+
)
2838+
2839+
# Create a model with its own on_virtual_update as wel
2840+
model_path = models_path / "test_model.sql"
2841+
model_path.write_text(
2842+
"""
2843+
MODEL (
2844+
name test_model,
2845+
kind FULL
2846+
);
2847+
2848+
SELECT 1 as id, 'test' as name;
2849+
2850+
ON_VIRTUAL_UPDATE_BEGIN;
2851+
SELECT 'Model-specific update' AS message;
2852+
ON_VIRTUAL_UPDATE_END;
2853+
"""
2854+
)
2855+
2856+
ctx = Context(paths=[tmp_path])
2857+
2858+
# Plan and apply
2859+
plan = ctx.plan(auto_apply=True, no_prompts=True)
2860+
2861+
snapshot = list(plan.new_snapshots)[0]
2862+
model = snapshot.model
2863+
2864+
# Verify both default and model-specific on_virtual_update statements
2865+
assert len(model.on_virtual_update) == 2
2866+
2867+
# Default statements should come first
2868+
assert model.on_virtual_update[0].sql() == "SELECT 'Model-defailt virtual update' AS message"
2869+
assert model.on_virtual_update[1].sql() == "SELECT 'Model-specific update' AS message"
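
The first test in this file relies on a consequence of the merge: because default statements become part of each model's definition, editing a project-level default shows up as a direct modification of the model. The sketch below illustrates that idea with a plain content hash. It is an illustration only and not SQLMesh's actual fingerprinting logic; the function `statements_fingerprint` is a hypothetical name.

```python
import hashlib
from typing import Dict, List


def statements_fingerprint(statements: Dict[str, List[str]]) -> str:
    """Hash the merged statement lists; editing any statement changes the digest."""
    hasher = hashlib.sha256()
    for key in sorted(statements):
        hasher.update(key.encode("utf-8"))
        for stmt in statements[key]:
            # A separator byte keeps adjacent statements from running together.
            hasher.update(b"\x00" + stmt.encode("utf-8"))
    return hasher.hexdigest()


# Statement lists taken from the two config versions written by the test.
before = statements_fingerprint(
    {"pre_statements": ["SET memory_limit = '10GB'", "SET threads = @var1"]}
)
after = statements_fingerprint({"pre_statements": ["SET memory_limit = '5GB'"]})
changed = before != after
print(changed)
```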
