Skip to content

Commit 503d13f

Browse files
authored
Feat: Allow CustomKind subclasses for custom materializations (#3863)
1 parent d379826 commit 503d13f

File tree

13 files changed

+392
-28
lines changed

13 files changed

+392
-28
lines changed

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ install-dev:
44
pip3 install -e ".[dev,web,slack,dlt]"
55

66
install-cicd-test:
7-
pip3 install -e ".[dev,web,slack,cicdtest,dlt]"
7+
pip3 install -e ".[dev,web,slack,cicdtest,dlt]" ./examples/custom_materializations
88

99
install-doc:
1010
pip3 install -r ./docs/requirements.txt
@@ -153,7 +153,7 @@ guard-%:
153153
fi
154154

155155
engine-%-install:
156-
pip3 install -e ".[dev,web,slack,${*}]"
156+
pip3 install -e ".[dev,web,slack,${*}]" ./examples/custom_materializations
157157

158158
engine-docker-%-up:
159159
docker compose -f ./tests/core/engine_adapter/integration/docker/compose.${*}.yaml up -d

docs/guides/custom_materializations.md

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,103 @@ class CustomFullMaterialization(CustomMaterialization):
157157
# Example existing materialization for look and feel: https://github.com/TobikoData/sqlmesh/blob/main/sqlmesh/core/snapshot/evaluator.py
158158
```
159159

160+
## Extending `CustomKind`
161+
162+
!!! warning
163+
This is even lower level usage that contains a bunch of extra complexity and relies on knowledge of the SQLMesh internals.
164+
If you dont need this level of complexity, stick with the method described above.
165+
166+
In many cases, the above usage of a custom materialization will suffice.
167+
168+
However, you may still want tighter integration with SQLMesh's internals:
169+
170+
- You may want more control over what is considered a metadata change vs a data change
171+
- You may want to validate custom properties are correct before any database connections are made
172+
- You may want to leverage existing functionality of SQLMesh that relies on specific properties being present
173+
174+
In this case, you can provide a subclass of `CustomKind` for SQLMesh to use instead of `CustomKind` itself.
175+
During project load, SQLMesh will instantiate your *subclass* instead of `CustomKind`.
176+
177+
This allows you to run custom validators at load time rather than having to perform extra validation when `insert()` is invoked on your `CustomMaterialization`.
178+
179+
This approach also allows you set "top-level" properties directly in the `kind (...)` block rather than nesting them under `materialization_properties`.
180+
181+
To extend `CustomKind`, first you define a subclass like so:
182+
183+
```python linenums="1" hl_lines="7"
184+
from sqlmesh import CustomKind
185+
from pydantic import field_validator, ValidationInfo
186+
from sqlmesh.utils.pydantic import list_of_fields_validator
187+
188+
class MyCustomKind(CustomKind):
189+
190+
primary_key: t.List[exp.Expression]
191+
192+
@field_validator("primary_key", mode="before")
193+
@classmethod
194+
def _validate_primary_key(cls, value: t.Any, info: ValidationInfo) -> t.Any:
195+
return list_of_fields_validator(value, info.data)
196+
197+
```
198+
199+
In this example, we define a field called `primary_key` that takes a list of fields. Notice that the field validation is just a simple Pydantic `@field_validator` with the [exact same usage](https://github.com/TobikoData/sqlmesh/blob/ade5f7245950822f3cfe5a68a0c243f91ceca600/sqlmesh/core/model/kind.py#L470) as the standard SQLMesh model kinds.
200+
201+
To use it within a model, we can do something like:
202+
203+
```sql linenums="1" hl_lines="5"
204+
MODEL (
205+
name my_db.my_model,
206+
kind CUSTOM (
207+
materialization 'my_custom_full',
208+
primary_key (col1, col2)
209+
)
210+
);
211+
```
212+
213+
Notice that the `primary_key` field we declared is top-level within the `kind` block instead of being nested under `materialization_properties`.
214+
215+
To indicate to SQLMesh that it should use this subclass, specify it as a generic type parameter on your custom materialization class like so:
216+
217+
```python linenums="1" hl_lines="1 16"
218+
class CustomFullMaterialization(CustomMaterialization[MyCustomKind]):
219+
NAME = "my_custom_full"
220+
221+
def insert(
222+
self,
223+
table_name: str,
224+
query_or_df: QueryOrDF,
225+
model: Model,
226+
is_first_insert: bool,
227+
**kwargs: t.Any,
228+
) -> None:
229+
assert isinstance(model.kind, MyCustomKind)
230+
231+
self.adapter.merge(
232+
...,
233+
unique_key=model.kind.primary_key
234+
)
235+
```
236+
237+
When SQLMesh loads your custom materialization, it will inspect the Python type signature for generic parameters that are subclasses of `CustomKind`. If it finds one, it will instantiate your subclass when building `model.kind` instead of using the default `CustomKind` class.
238+
239+
In this example, this means that:
240+
241+
- Validation for `primary_key` happens at load time instead of evaluation time.
242+
- When your custom materialization is called to load data into tables, `model.kind` will resolve to your custom kind object so you can access the extra properties you defined without first needing to validate them / coerce them to a usable type.
243+
244+
### Data vs Metadata changes
245+
246+
Subclasses of `CustomKind` that add extra properties can also decide if they are data properties (changes may trigger the creation of new snapshots) or metadata properties (changes just update metadata about the model).
247+
248+
They can also decide if they are relevant for text diffing when SQLMesh detects changes to a model.
249+
250+
You can opt in to SQLMesh's change tracking by overriding the following methods:
251+
252+
- If changing the property should change the data fingerprint, add it to [data_hash_values()](https://github.com/TobikoData/sqlmesh/blob/ade5f7245950822f3cfe5a68a0c243f91ceca600/sqlmesh/core/model/kind.py#L858)
253+
- If changing the property should change the metadata fingerprint, add it to [metadata_hash_values()](https://github.com/TobikoData/sqlmesh/blob/ade5f7245950822f3cfe5a68a0c243f91ceca600/sqlmesh/core/model/kind.py#L867)
254+
- If the property should show up in context diffs, add it to [to_expression()](https://github.com/TobikoData/sqlmesh/blob/ade5f7245950822f3cfe5a68a0c243f91ceca600/sqlmesh/core/model/kind.py#L880)
255+
256+
160257
## Sharing custom materializations
161258

162259
### Copying files
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from __future__ import annotations
2+
3+
import typing as t
4+
5+
from sqlmesh import CustomMaterialization, CustomKind, Model
6+
from sqlmesh.utils.pydantic import validate_string
7+
from pydantic import field_validator
8+
9+
if t.TYPE_CHECKING:
10+
from sqlmesh import QueryOrDF
11+
12+
13+
class ExtendedCustomKind(CustomKind):
14+
custom_property: t.Optional[str] = None
15+
16+
@field_validator("custom_property", mode="before")
17+
@classmethod
18+
def _validate_custom_property(cls, v: t.Any) -> str:
19+
return validate_string(v)
20+
21+
22+
class CustomFullWithCustomKindMaterialization(CustomMaterialization[ExtendedCustomKind]):
23+
NAME = "custom_full_with_custom_kind"
24+
25+
def insert(
26+
self,
27+
table_name: str,
28+
query_or_df: QueryOrDF,
29+
model: Model,
30+
is_first_insert: bool,
31+
**kwargs: t.Any,
32+
) -> None:
33+
assert type(model.kind).__name__ == "ExtendedCustomKind"
34+
35+
self._replace_query_for_model(model, table_name, query_or_df)

examples/custom_materializations/setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
entry_points={
77
"sqlmesh.materializations": [
88
"custom_full_materialization = custom_materializations.full:CustomFullMaterialization",
9+
"custom_full_with_custom_kind = custom_materializations.custom_kind:CustomFullWithCustomKindMaterialization",
910
],
1011
},
1112
install_requires=[
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
MODEL (
2+
name sushi.latest_order,
3+
kind CUSTOM (
4+
materialization 'custom_full_with_custom_kind',
5+
custom_property 'sushi!!!'
6+
),
7+
cron '@daily'
8+
);
9+
10+
SELECT id, customer_id, start_ts, end_ts, event_date
11+
FROM sushi.orders
12+
ORDER BY event_date DESC LIMIT 1
13+

sqlmesh/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from sqlmesh.core.snapshot.evaluator import (
3030
CustomMaterialization as CustomMaterialization,
3131
)
32+
from sqlmesh.core.model.kind import CustomKind as CustomKind
3233
from sqlmesh.utils import (
3334
debug_mode_enabled as debug_mode_enabled,
3435
enable_debug_mode as enable_debug_mode,

sqlmesh/core/model/kind.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -975,6 +975,21 @@ def create_model_kind(v: t.Any, dialect: str, defaults: t.Dict[str, t.Any]) -> M
975975
):
976976
props["on_destructive_change"] = defaults.get("on_destructive_change")
977977

978+
if kind_type == CustomKind:
979+
# load the custom materialization class and check if it uses a custom kind type
980+
from sqlmesh.core.snapshot.evaluator import get_custom_materialization_type
981+
982+
if "materialization" not in props:
983+
raise ConfigError(
984+
"The 'materialization' property is required for models of the CUSTOM kind"
985+
)
986+
987+
actual_kind_type, _ = get_custom_materialization_type(
988+
validate_string(props.get("materialization"))
989+
)
990+
991+
return actual_kind_type(**props)
992+
978993
return kind_type(**props)
979994

980995
name = (v.name if isinstance(v, exp.Expression) else str(v)).upper()

sqlmesh/core/snapshot/evaluator.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
SCDType2ByColumnKind,
5050
SCDType2ByTimeKind,
5151
ViewKind,
52+
CustomKind,
5253
)
5354
from sqlmesh.core.schema_diff import has_drop_alteration, get_dropped_column_names
5455
from sqlmesh.core.snapshot import (
@@ -1130,7 +1131,7 @@ def _evaluation_strategy(snapshot: SnapshotInfoLike, adapter: EngineAdapter) ->
11301131
raise SQLMeshError(
11311132
f"Missing the name of a custom evaluation strategy in model '{snapshot.name}'."
11321133
)
1133-
klass = get_custom_materialization_type(snapshot.custom_materialization)
1134+
_, klass = get_custom_materialization_type(snapshot.custom_materialization)
11341135
return klass(adapter)
11351136
elif snapshot.is_managed:
11361137
klass = EngineManagedStrategy
@@ -1897,7 +1898,10 @@ def _is_materialized_view(self, model: Model) -> bool:
18971898
return isinstance(model.kind, ViewKind) and model.kind.materialized
18981899

18991900

1900-
class CustomMaterialization(MaterializableStrategy):
1901+
C = t.TypeVar("C", bound=CustomKind)
1902+
1903+
1904+
class CustomMaterialization(MaterializableStrategy, t.Generic[C]):
19011905
"""Base class for custom materializations."""
19021906

19031907
def insert(
@@ -1924,14 +1928,36 @@ def insert(
19241928
)
19251929

19261930

1927-
_custom_materialization_type_cache: t.Optional[t.Dict[str, t.Type[CustomMaterialization]]] = None
1931+
_custom_materialization_type_cache: t.Optional[
1932+
t.Dict[str, t.Tuple[t.Type[CustomKind], t.Type[CustomMaterialization]]]
1933+
] = None
1934+
1935+
1936+
def get_custom_materialization_kind_type(st: t.Type[CustomMaterialization]) -> t.Type[CustomKind]:
1937+
# try to read if there is a custom 'kind' type in use by inspecting the type signature
1938+
# eg try to read 'MyCustomKind' from:
1939+
# >>>> class MyCustomMaterialization(CustomMaterialization[MyCustomKind])
1940+
# and fall back to base CustomKind if there is no generic type declared
1941+
if hasattr(st, "__orig_bases__"):
1942+
for base in st.__orig_bases__:
1943+
if hasattr(base, "__origin__") and base.__origin__ == CustomMaterialization:
1944+
for generic_arg in t.get_args(base):
1945+
if not issubclass(generic_arg, CustomKind):
1946+
raise SQLMeshError(
1947+
f"Custom materialization kind '{generic_arg.__name__}' must be a subclass of CustomKind"
1948+
)
1949+
1950+
return generic_arg
19281951

1952+
return CustomKind
19291953

1930-
def get_custom_materialization_type(name: str) -> t.Type[CustomMaterialization]:
1954+
1955+
def get_custom_materialization_type(
1956+
name: str,
1957+
) -> t.Tuple[t.Type[CustomKind], t.Type[CustomMaterialization]]:
19311958
global _custom_materialization_type_cache
19321959

19331960
strategy_key = name.lower()
1934-
19351961
if (
19361962
_custom_materialization_type_cache is None
19371963
or strategy_key not in _custom_materialization_type_cache
@@ -1948,16 +1974,22 @@ def get_custom_materialization_type(name: str) -> t.Type[CustomMaterialization]:
19481974
strategy_types.append(strategy_type)
19491975

19501976
_custom_materialization_type_cache = {
1951-
getattr(strategy_type, "NAME", strategy_type.__name__).lower(): strategy_type
1977+
getattr(strategy_type, "NAME", strategy_type.__name__).lower(): (
1978+
get_custom_materialization_kind_type(strategy_type),
1979+
strategy_type,
1980+
)
19521981
for strategy_type in strategy_types
19531982
}
19541983

19551984
if strategy_key not in _custom_materialization_type_cache:
19561985
raise ConfigError(f"Materialization strategy with name '{name}' was not found.")
19571986

1958-
strategy_type = _custom_materialization_type_cache[strategy_key]
1959-
logger.debug("Resolved custom materialization '%s' to '%s'", name, strategy_type)
1960-
return strategy_type
1987+
strategy_kind_type, strategy_type = _custom_materialization_type_cache[strategy_key]
1988+
logger.debug(
1989+
"Resolved custom materialization '%s' to '%s' (%s)", name, strategy_type, strategy_kind_type
1990+
)
1991+
1992+
return strategy_kind_type, strategy_type
19611993

19621994

19631995
class EngineManagedStrategy(MaterializableStrategy):

tests/core/analytics/test_collector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def test_on_plan_apply(
183183
{
184184
"seq_num": 0,
185185
"event_type": "PLAN_APPLY_START",
186-
"event": f'{{"plan_id": "{plan_id}", "engine_type": "bigquery", "state_sync_type": "mysql", "scheduler_type": "builtin", "is_dev": false, "skip_backfill": false, "no_gaps": false, "forward_only": false, "ensure_finalized_snapshots": false, "has_restatements": false, "directly_modified_count": 18, "indirectly_modified_count": 0, "environment_name_hash": "d6e4a9b6646c62fc48baa6dd6150d1f7"}}',
186+
"event": f'{{"plan_id": "{plan_id}", "engine_type": "bigquery", "state_sync_type": "mysql", "scheduler_type": "builtin", "is_dev": false, "skip_backfill": false, "no_gaps": false, "forward_only": false, "ensure_finalized_snapshots": false, "has_restatements": false, "directly_modified_count": 19, "indirectly_modified_count": 0, "environment_name_hash": "d6e4a9b6646c62fc48baa6dd6150d1f7"}}',
187187
**common_fields,
188188
}
189189
),

tests/core/test_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -813,7 +813,7 @@ def test_janitor(sushi_context, mocker: MockerFixture) -> None:
813813
)
814814
# Assert that the views are dropped for each snapshot just once and make sure that the name used is the
815815
# view name with the environment as a suffix
816-
assert adapter_mock.drop_view.call_count == 13
816+
assert adapter_mock.drop_view.call_count == 14
817817
adapter_mock.drop_view.assert_has_calls(
818818
[
819819
call(

0 commit comments

Comments
 (0)