|
1 | | -import typing as t |
2 | 1 | import pytest |
| 2 | +import typing as t |
| 3 | +from datetime import datetime |
| 4 | +from pathlib import Path |
3 | 5 | from pytest import FixtureRequest |
| 6 | +from pytest_mock import MockerFixture |
| 7 | + |
| 8 | +import sqlmesh.core.dialect as d |
4 | 9 | from sqlglot import exp |
5 | | -from pathlib import Path |
6 | | -from sqlglot.optimizer.qualify_columns import quote_identifiers |
| 10 | +from sqlmesh import Config, ExecutionContext, model |
7 | 11 | from sqlglot.helper import seq_get |
| 12 | +from sqlglot.optimizer.qualify_columns import quote_identifiers |
| 13 | +from sqlmesh.core.config import ModelDefaultsConfig |
8 | 14 | from sqlmesh.core.engine_adapter import SnowflakeEngineAdapter |
9 | 15 | from sqlmesh.core.engine_adapter.shared import DataObject |
10 | | -import sqlmesh.core.dialect as d |
11 | | -from sqlmesh.core.model import SqlModel, load_sql_based_model |
| 16 | +from sqlmesh.core.model import ModelKindName, SqlModel, load_sql_based_model |
12 | 17 | from sqlmesh.core.plan import Plan |
13 | | -from tests.core.engine_adapter.integration import TestContext |
14 | | -from sqlmesh import model, ExecutionContext |
15 | | -from pytest_mock import MockerFixture |
16 | 18 | from sqlmesh.core.snapshot import SnapshotId, SnapshotIdBatch |
17 | 19 | from sqlmesh.core.snapshot.execution_tracker import ( |
18 | 20 | QueryExecutionContext, |
19 | 21 | QueryExecutionTracker, |
20 | 22 | ) |
21 | | -from sqlmesh.core.model import ModelKindName |
22 | | -from datetime import datetime |
23 | | - |
24 | 23 | from tests.core.engine_adapter.integration import ( |
25 | 24 | TestContext, |
26 | 25 | generate_pytest_params, |
@@ -337,3 +336,45 @@ def test_rows_tracker( |
337 | 336 | assert stats is not None |
338 | 337 | assert stats.total_rows_processed is None |
339 | 338 | assert stats.total_bytes_processed is None |
| 339 | + |
| 340 | + |
| 341 | +def test_unit_test(tmp_path: Path, ctx: TestContext): |
| 342 | + models_path = tmp_path / "models" |
| 343 | + tests_path = tmp_path / "tests" |
| 344 | + |
| 345 | + models_path.mkdir() |
| 346 | + tests_path.mkdir() |
| 347 | + |
| 348 | + test_payload = """ |
| 349 | +test_dummy_model: |
| 350 | + model: s.dummy |
| 351 | + inputs: |
| 352 | + s.src_table: |
| 353 | + rows: |
| 354 | + - c: 1 |
| 355 | + outputs: |
| 356 | + query: |
| 357 | + - c: 1 |
| 358 | + """ |
| 359 | + |
| 360 | + (models_path / "dummy_model.sql").write_text(f"MODEL (name s.dummy); SELECT c FROM s.src_table") |
| 361 | + (tests_path / "test_dummy_model.yaml").write_text(test_payload) |
| 362 | + |
| 363 | + def _config_mutator(gateway_name: str, config: Config): |
| 364 | + config.model_defaults = ModelDefaultsConfig(dialect="snowflake") |
| 365 | + test_connection = config.gateways[gateway_name].connection.copy() # type: ignore |
| 366 | + |
| 367 | + # Force the database to lowercase to test that we normalize (if we didn't, the test would fail) |
| 368 | + test_connection.database = test_connection.database.lower() # type: ignore |
| 369 | + config.gateways[gateway_name].test_connection = test_connection |
| 370 | + |
| 371 | + sqlmesh = ctx.create_context(path=tmp_path, config_mutator=_config_mutator) |
| 372 | + |
| 373 | + test_conn = sqlmesh.config.get_test_connection(ctx.gateway) |
| 374 | + assert test_conn.type_ == "snowflake" |
| 375 | + |
| 376 | + catalog = test_conn.get_catalog() |
| 377 | + assert catalog is not None and catalog.islower() |
| 378 | + |
| 379 | + test_results = sqlmesh.test() |
| 380 | + assert not test_results.errors |
0 commit comments