Skip to content

Commit 703a4b5

Browse files
authored
Merge branch 'main' into main
2 parents da1a27a + a2f8a05 commit 703a4b5

File tree

18 files changed

+495
-65
lines changed

18 files changed

+495
-65
lines changed

docs/guides/signals.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ It then divides those into _batches_ (configured with the model's [batch_size](.
2424

2525
Signal checking functions examine a batch of time intervals. The function is always called with a batch of time intervals (DateTimeRanges). It can also optionally be called with keyword arguments. It may return `True` if all intervals are ready for evaluation, `False` if no intervals are ready, or the time intervals themselves if only some are ready. A checking function is defined with the `@signal` decorator.
2626

27+
!!! note "One model, multiple signals"
28+
29+
Multiple signals may be specified for a model. SQLMesh categorizes a candidate interval as ready for evaluation if **all** the signal checking functions determine it is ready.
30+
2731
## Defining a signal
2832

2933
To define a signal, create a `signals` directory in your project folder. Define your signal in a file named `__init__.py` in that directory (you can have additional Python files as well).

sqlmesh/core/console.py

Lines changed: 19 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2163,13 +2163,12 @@ def log_test_results(self, result: ModelTextTestResult, target_dialect: str) ->
21632163
self._print("-" * divider_length)
21642164
self._print("Test Failure Summary", style="red")
21652165
self._print("=" * divider_length)
2166-
failures = len(result.failures) + len(result.errors)
2166+
fail_and_error_tests = result.get_fail_and_error_tests()
21672167
self._print(f"{message} \n")
21682168

2169-
self._print(f"Failed tests ({failures}):")
2170-
for test, _ in result.failures + result.errors:
2171-
if isinstance(test, ModelTest):
2172-
self._print(f" • {test.path}::{test.test_name}")
2169+
self._print(f"Failed tests ({len(fail_and_error_tests)}):")
2170+
for test in fail_and_error_tests:
2171+
self._print(f" • {test.path}::{test.test_name}")
21732172
self._print("=" * divider_length, end="\n\n")
21742173

21752174
def _captured_unit_test_results(self, result: ModelTextTestResult) -> str:
@@ -2721,28 +2720,15 @@ def _log_test_details(
27212720
Args:
27222721
result: The unittest test result that contains metrics like the number of successes, failures, etc.
27232722
"""
2724-
27252723
if result.wasSuccessful():
27262724
self._print("\n", end="")
27272725
return
27282726

2729-
errors = result.errors
2730-
failures = result.failures
2731-
skipped = result.skipped
2732-
2733-
infos = []
2734-
if failures:
2735-
infos.append(f"failures={len(failures)}")
2736-
if errors:
2737-
infos.append(f"errors={len(errors)}")
2738-
if skipped:
2739-
infos.append(f"skipped={skipped}")
2740-
27412727
if unittest_char_separator:
27422728
self._print(f"\n{unittest.TextTestResult.separator1}\n\n", end="")
27432729

27442730
for (test_case, failure), test_failure_tables in zip_longest( # type: ignore
2745-
failures, result.failure_tables
2731+
result.failures, result.failure_tables
27462732
):
27472733
self._print(unittest.TextTestResult.separator2)
27482734
self._print(f"FAIL: {test_case}")
@@ -2758,7 +2744,7 @@ def _log_test_details(
27582744
self._print(failure_table)
27592745
self._print("\n", end="")
27602746

2761-
for test_case, error in errors:
2747+
for test_case, error in result.errors:
27622748
self._print(unittest.TextTestResult.separator2)
27632749
self._print(f"ERROR: {test_case}")
27642750
self._print(f"{unittest.TextTestResult.separator2}")
@@ -3080,27 +3066,27 @@ def log_test_results(self, result: ModelTextTestResult, target_dialect: str) ->
30803066
fail_shared_style = {**shared_style, **fail_color}
30813067
header = str(h("span", {"style": fail_shared_style}, "-" * divider_length))
30823068
message = str(h("span", {"style": fail_shared_style}, "Test Failure Summary"))
3069+
fail_and_error_tests = result.get_fail_and_error_tests()
30833070
failed_tests = [
30843071
str(
30853072
h(
30863073
"span",
30873074
{"style": fail_shared_style},
3088-
f"Failed tests ({len(result.failures) + len(result.errors)}):",
3075+
f"Failed tests ({len(fail_and_error_tests)}):",
30893076
)
30903077
)
30913078
]
30923079

3093-
for test, _ in result.failures + result.errors:
3094-
if isinstance(test, ModelTest):
3095-
failed_tests.append(
3096-
str(
3097-
h(
3098-
"span",
3099-
{"style": fail_shared_style},
3100-
f" • {test.model.name}::{test.test_name}",
3101-
)
3080+
for test in fail_and_error_tests:
3081+
failed_tests.append(
3082+
str(
3083+
h(
3084+
"span",
3085+
{"style": fail_shared_style},
3086+
f" • {test.model.name}::{test.test_name}",
31023087
)
31033088
)
3089+
)
31043090
failures = "<br>".join(failed_tests)
31053091
footer = str(h("span", {"style": fail_shared_style}, "=" * divider_length))
31063092
error_output = widgets.Textarea(output, layout={"height": "300px", "width": "100%"})
@@ -3508,10 +3494,10 @@ def log_test_results(self, result: ModelTextTestResult, target_dialect: str) ->
35083494
self._log_test_details(result, unittest_char_separator=False)
35093495
self._print("```\n\n")
35103496

3511-
failures = len(result.failures) + len(result.errors)
3497+
fail_and_error_tests = result.get_fail_and_error_tests()
35123498
self._print(f"**{message}**\n")
3513-
self._print(f"**Failed tests ({failures}):**")
3514-
for test, _ in result.failures + result.errors:
3499+
self._print(f"**Failed tests ({len(fail_and_error_tests)}):**")
3500+
for test in fail_and_error_tests:
35153501
if isinstance(test, ModelTest):
35163502
self._print(f" • `{test.model.name}`::`{test.test_name}`\n\n")
35173503

sqlmesh/core/dialect.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from sqlglot import Dialect, Generator, ParseError, Parser, Tokenizer, TokenType, exp
1414
from sqlglot.dialects.dialect import DialectType
1515
from sqlglot.dialects import DuckDB, Snowflake
16+
import sqlglot.dialects.athena as athena
1617
from sqlglot.helper import seq_get
1718
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1819
from sqlglot.optimizer.qualify_columns import quote_identifiers
@@ -1014,6 +1015,14 @@ def extend_sqlglot() -> None:
10141015
generators = {Generator}
10151016

10161017
for dialect in Dialect.classes.values():
1018+
# Athena picks a different Tokenizer / Parser / Generator depending on the query
1019+
# so this ensures that the extra ones it defines are also extended
1020+
if dialect == athena.Athena:
1021+
tokenizers.add(athena._TrinoTokenizer)
1022+
parsers.add(athena._TrinoParser)
1023+
generators.add(athena._TrinoGenerator)
1024+
generators.add(athena._HiveGenerator)
1025+
10171026
if hasattr(dialect, "Tokenizer"):
10181027
tokenizers.add(dialect.Tokenizer)
10191028
if hasattr(dialect, "Parser"):

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -324,14 +324,26 @@ def create_mapping_schema(
324324
bq_table = self._get_table(table)
325325
columns = create_mapping_schema(bq_table.schema)
326326

327-
if (
328-
include_pseudo_columns
329-
and bq_table.time_partitioning
330-
and not bq_table.time_partitioning.field
331-
):
332-
columns["_PARTITIONTIME"] = exp.DataType.build("TIMESTAMP", dialect="bigquery")
333-
if bq_table.time_partitioning.type_ == "DAY":
334-
columns["_PARTITIONDATE"] = exp.DataType.build("DATE")
327+
if include_pseudo_columns:
328+
if bq_table.time_partitioning and not bq_table.time_partitioning.field:
329+
columns["_PARTITIONTIME"] = exp.DataType.build("TIMESTAMP", dialect="bigquery")
330+
if bq_table.time_partitioning.type_ == "DAY":
331+
columns["_PARTITIONDATE"] = exp.DataType.build("DATE")
332+
if bq_table.table_id.endswith("*"):
333+
columns["_TABLE_SUFFIX"] = exp.DataType.build("STRING", dialect="bigquery")
334+
if (
335+
bq_table.external_data_configuration is not None
336+
and bq_table.external_data_configuration.source_format
337+
in (
338+
"CSV",
339+
"NEWLINE_DELIMITED_JSON",
340+
"AVRO",
341+
"PARQUET",
342+
"ORC",
343+
"DATASTORE_BACKUP",
344+
)
345+
):
346+
columns["_FILE_NAME"] = exp.DataType.build("STRING", dialect="bigquery")
335347

336348
return columns
337349

sqlmesh/core/linter/rule.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import abc
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, field
55
from pathlib import Path
66

77
from sqlmesh.core.model import Model
@@ -49,12 +49,21 @@ class TextEdit:
4949
new_text: str
5050

5151

52+
@dataclass(frozen=True)
53+
class CreateFile:
54+
"""Create a new file with the provided text."""
55+
56+
path: Path
57+
text: str
58+
59+
5260
@dataclass(frozen=True)
5361
class Fix:
5462
"""A fix that can be applied to resolve a rule violation."""
5563

5664
title: str
57-
edits: t.List[TextEdit]
65+
edits: t.List[TextEdit] = field(default_factory=list)
66+
create_files: t.List[CreateFile] = field(default_factory=list)
5867

5968

6069
class _Rule(abc.ABCMeta):

sqlmesh/core/linter/rules/builtin.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,15 @@
1414
get_range_of_model_block,
1515
read_range_from_string,
1616
)
17-
from sqlmesh.core.linter.rule import Rule, RuleViolation, Range, Fix, TextEdit, Position
17+
from sqlmesh.core.linter.rule import (
18+
Rule,
19+
RuleViolation,
20+
Range,
21+
Fix,
22+
TextEdit,
23+
Position,
24+
CreateFile,
25+
)
1826
from sqlmesh.core.linter.definition import RuleSet
1927
from sqlmesh.core.model import Model, SqlModel, ExternalModel
2028
from sqlmesh.utils.lineage import extract_references_from_query, ExternalModelReference
@@ -227,7 +235,16 @@ def create_fix(self, model_name: str) -> t.Optional[Fix]:
227235

228236
external_models_path = root / EXTERNAL_MODELS_YAML
229237
if not external_models_path.exists():
230-
return None
238+
return Fix(
239+
title="Add external model file",
240+
edits=[],
241+
create_files=[
242+
CreateFile(
243+
path=external_models_path,
244+
text=f"- name: '{model_name}'\n",
245+
)
246+
],
247+
)
231248

232249
# Figure out the position to insert the new external model at the end of the file, whether
233250
# needs new line or not.

sqlmesh/core/loader.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,8 @@ def _load_materializations(self) -> None:
710710
def _load_signals(self) -> UniqueKeyDict[str, signal]:
711711
"""Loads signals for the built-in scheduler."""
712712

713+
base_signals = signal.get_registry()
714+
713715
signals_max_mtime: t.Optional[float] = None
714716

715717
for path in self._glob_paths(
@@ -729,7 +731,10 @@ def _load_signals(self) -> UniqueKeyDict[str, signal]:
729731

730732
self._signals_max_mtime = signals_max_mtime
731733

732-
return signal.get_registry()
734+
signals = signal.get_registry()
735+
signal.set_registry(base_signals)
736+
737+
return signals
733738

734739
def _load_audits(
735740
self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry

sqlmesh/core/model/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def _add_variables_to_python_env(
157157

158158
if blueprint_variables:
159159
blueprint_variables = {
160-
k: SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v
160+
k.lower(): SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v
161161
for k, v in blueprint_variables.items()
162162
}
163163
python_env[c.SQLMESH_BLUEPRINT_VARS] = Executable.value(

sqlmesh/core/test/definition.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,13 @@ def _to_hashable(x: t.Any) -> t.Any:
317317
#
318318
# This is a bit of a hack, but it's a way to get the best of both worlds.
319319
args: t.List[t.Any] = []
320+
321+
failed_subtest = ""
322+
323+
if subtest := getattr(self, "_subtest", None):
324+
if cte := subtest.params.get("cte"):
325+
failed_subtest = f" (CTE {cte})"
326+
320327
if expected.shape != actual.shape:
321328
_raise_if_unexpected_columns(expected.columns, actual.columns)
322329

@@ -325,13 +332,13 @@ def _to_hashable(x: t.Any) -> t.Any:
325332
missing_rows = _row_difference(expected, actual)
326333
if not missing_rows.empty:
327334
args[0] += f"\n\nMissing rows:\n\n{missing_rows}"
328-
args.append(df_to_table("Missing rows", missing_rows))
335+
args.append(df_to_table(f"Missing rows{failed_subtest}", missing_rows))
329336

330337
unexpected_rows = _row_difference(actual, expected)
331338

332339
if not unexpected_rows.empty:
333340
args[0] += f"\n\nUnexpected rows:\n\n{unexpected_rows}"
334-
args.append(df_to_table("Unexpected rows", unexpected_rows))
341+
args.append(df_to_table(f"Unexpected rows{failed_subtest}", unexpected_rows))
335342

336343
else:
337344
diff = expected.compare(actual).rename(columns={"self": "exp", "other": "act"})
@@ -341,7 +348,8 @@ def _to_hashable(x: t.Any) -> t.Any:
341348
diff.rename(columns={"exp": "Expected", "act": "Actual"}, inplace=True)
342349
if self.verbosity == Verbosity.DEFAULT:
343350
args.extend(
344-
df_to_table("Data mismatch", df) for df in _split_df_by_column_pairs(diff)
351+
df_to_table(f"Data mismatch{failed_subtest}", df)
352+
for df in _split_df_by_column_pairs(diff)
345353
)
346354
else:
347355
from pandas import MultiIndex
@@ -351,7 +359,8 @@ def _to_hashable(x: t.Any) -> t.Any:
351359
col_diff = diff[col]
352360
if not col_diff.empty:
353361
table = df_to_table(
354-
f"[bold red]Column '{col}' mismatch[/bold red]", col_diff
362+
f"[bold red]Column '{col}' mismatch{failed_subtest}[/bold red]",
363+
col_diff,
355364
)
356365
args.append(table)
357366

sqlmesh/core/test/result.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import typing as t
55
import unittest
66

7+
from sqlmesh.core.test.definition import ModelTest
8+
79
if t.TYPE_CHECKING:
810
ErrorType = t.Union[
911
t.Tuple[type[BaseException], BaseException, types.TracebackType],
@@ -42,7 +44,10 @@ def addSubTest(
4244
exctype, value, tb = err
4345
err = (exctype, value, None) # type: ignore
4446

45-
super().addSubTest(test, subtest, err)
47+
if err[0] and issubclass(err[0], test.failureException):
48+
self.addFailure(test, err)
49+
else:
50+
self.addError(test, err)
4651

4752
def _print_char(self, char: str) -> None:
4853
from sqlmesh.core.console import TerminalConsole
@@ -117,4 +122,14 @@ def merge(self, other: ModelTextTestResult) -> None:
117122
skipped_args = other.skipped[0]
118123
self.addSkip(skipped_args[0], skipped_args[1])
119124

120-
self.testsRun += 1
125+
self.testsRun += other.testsRun
126+
127+
def get_fail_and_error_tests(self) -> t.List[ModelTest]:
128+
# If tests contain failed subtests (e.g testing CTE outputs) we don't want
129+
# to report it as different test failures
130+
test_name_to_test = {
131+
test.test_name: test
132+
for test, _ in self.failures + self.errors
133+
if isinstance(test, ModelTest)
134+
}
135+
return list(test_name_to_test.values())

0 commit comments

Comments
 (0)