Skip to content

Commit e9efeff

Browse files
Fix: Add adapters for commands that are not using the snapshot evaluator (#3531)
1 parent 7f650e1 commit e9efeff

File tree

7 files changed

+115
-20
lines changed

7 files changed

+115
-20
lines changed

sqlmesh/core/context.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -418,12 +418,10 @@ def engine_adapter(self) -> EngineAdapter:
418418
@property
419419
def snapshot_evaluator(self) -> SnapshotEvaluator:
420420
if not self._snapshot_evaluator:
421-
if self._snapshot_gateways:
422-
self._create_engine_adapters(set(self._snapshot_gateways.values()))
423421
self._snapshot_evaluator = SnapshotEvaluator(
424422
{
425423
gateway: adapter.with_log_level(logging.INFO)
426-
for gateway, adapter in self._engine_adapters.items()
424+
for gateway, adapter in self.engine_adapters.items()
427425
},
428426
ddl_concurrent_tasks=self.concurrent_tasks,
429427
selected_gateway=self.selected_gateway,
@@ -1476,6 +1474,7 @@ def table_diff(
14761474
source_alias, target_alias = source, target
14771475

14781476
adapter = self.engine_adapter
1477+
14791478
if model_or_snapshot:
14801479
model = self.get_model(model_or_snapshot, raise_if_missing=True)
14811480
adapter = self._get_engine_adapter(model.gateway)
@@ -1641,6 +1640,7 @@ def create_test(
16411640
test_adapter = self._test_connection_config.create_engine_adapter(
16421641
register_comments_override=False
16431642
)
1643+
16441644
generate_test(
16451645
model=model_to_test,
16461646
input_queries=input_queries,
@@ -2021,21 +2021,19 @@ def _snapshot_gateways(self) -> t.Dict[str, str]:
20212021
if snapshot.is_model and snapshot.model.gateway
20222022
}
20232023

2024-
def _create_engine_adapters(self, gateways: t.Optional[t.Set] = None) -> None:
2025-
"""Create engine adapters for the gateways, when none provided include all defined in the configs."""
2026-
2024+
@cached_property
2025+
def engine_adapters(self) -> t.Dict[str, EngineAdapter]:
2026+
"""Returns all the engine adapters for the gateways defined in the configuration."""
20272027
for gateway_name in self.config.gateways:
2028-
if gateway_name != self.selected_gateway and (
2029-
gateways is None or gateway_name in gateways
2030-
):
2028+
if gateway_name != self.selected_gateway:
20312029
connection = self.config.get_connection(gateway_name)
20322030
adapter = connection.create_engine_adapter()
2033-
self.concurrent_tasks = min(self.concurrent_tasks, connection.concurrent_tasks)
20342031
self._engine_adapters[gateway_name] = adapter
2032+
return self._engine_adapters
20352033

20362034
def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
20372035
if gateway:
2038-
if adapter := self._engine_adapters.get(gateway):
2036+
if adapter := self.engine_adapters.get(gateway):
20392037
return adapter
20402038
raise SQLMeshError(f"Gateway '{gateway}' not found in the available engine adapters.")
20412039
return self.engine_adapter

tests/core/engine_adapter/integration/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,7 @@ def create_context(
471471
)
472472
if config_mutator:
473473
config_mutator(self.gateway, config)
474+
config.gateways = {self.gateway: config.gateways[self.gateway]}
474475

475476
gateway_config = config.gateways[self.gateway]
476477
if (

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1328,6 +1328,8 @@ def test_sushi(ctx: TestContext, tmp_path_factory: pytest.TempPathFactory):
13281328
personal_paths=[pathlib.Path("~/.sqlmesh/config.yaml").expanduser()],
13291329
)
13301330

1331+
# To enable parallelism in integration tests
1332+
config.gateways = {ctx.gateway: config.gateways[ctx.gateway]}
13311333
current_gateway_config = config.gateways[ctx.gateway]
13321334
current_gateway_config.state_schema = sushi_state_schema
13331335

@@ -1730,6 +1732,8 @@ def _normalize_snowflake(name: str, prefix_regex: str = "(sqlmesh__)(.*)"):
17301732
if config.model_defaults.dialect != ctx.dialect:
17311733
config.model_defaults = config.model_defaults.copy(update={"dialect": ctx.dialect})
17321734

1735+
# To enable parallelism in integration tests
1736+
config.gateways = {ctx.gateway: config.gateways[ctx.gateway]}
17331737
current_gateway_config = config.gateways[ctx.gateway]
17341738

17351739
if ctx.dialect == "athena":

tests/core/test_config.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -757,10 +757,8 @@ def test_multi_gateway_config(tmp_path, mocker: MockerFixture):
757757
new_callable=mocker.PropertyMock(return_value={"snapshot": "athena"}),
758758
)
759759

760-
ctx._create_engine_adapters()
761-
762760
assert isinstance(ctx._connection_config, RedshiftConnectionConfig)
763-
assert len(ctx._engine_adapters) == 2
764-
assert isinstance(ctx._engine_adapters["athena"], AthenaEngineAdapter)
765-
assert isinstance(ctx._engine_adapters["redshift"], RedshiftEngineAdapter)
761+
assert len(ctx.engine_adapters) == 2
762+
assert isinstance(ctx.engine_adapters["athena"], AthenaEngineAdapter)
763+
assert isinstance(ctx.engine_adapters["redshift"], RedshiftEngineAdapter)
766764
assert ctx.engine_adapter == ctx._get_engine_adapter("redshift")

tests/core/test_context.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,8 +322,12 @@ def test_gateway_specific_adapters(copy_to_temp_path, mocker):
322322
ctx = Context(paths=path, config="isolated_systems_config", gateway="prod")
323323
assert len(ctx._engine_adapters) == 1
324324
assert ctx.engine_adapter == ctx._engine_adapters["prod"]
325+
325326
with pytest.raises(SQLMeshError):
326-
assert ctx._get_engine_adapter("dev")
327+
assert ctx._get_engine_adapter("non_existing")
328+
329+
# This will create the requested engine adapter
330+
assert ctx._get_engine_adapter("dev") == ctx._engine_adapters["dev"]
327331

328332
ctx = Context(paths=path, config="isolated_systems_config")
329333
assert len(ctx._engine_adapters) == 1
@@ -337,8 +341,7 @@ def test_gateway_specific_adapters(copy_to_temp_path, mocker):
337341

338342
ctx = Context(paths=path, config="isolated_systems_config")
339343

340-
ctx._create_engine_adapters({"test"})
341-
assert len(ctx._engine_adapters) == 2
344+
assert len(ctx.engine_adapters) == 3
342345
assert ctx.engine_adapter == ctx._get_engine_adapter()
343346
assert ctx._get_engine_adapter("test") == ctx._engine_adapters["test"]
344347

tests/core/test_model.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from sqlmesh.core.context import Context, ExecutionContext
2727
from sqlmesh.core.dialect import parse
2828
from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS
29+
from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter
2930
from sqlmesh.core.macros import MacroEvaluator, macro
3031
from sqlmesh.core.model import (
3132
CustomKind,
@@ -6620,3 +6621,43 @@ def test_auto_restatement():
66206621
)
66216622
with pytest.raises(ValueError, match="Invalid cron expression '@invalid'.*"):
66226623
load_sql_based_model(parsed_definition)
6624+
6625+
6626+
def test_gateway_specific_render(assert_exp_eq) -> None:
6627+
gateways = {
6628+
"main": GatewayConfig(connection=DuckDBConnectionConfig()),
6629+
"duckdb": GatewayConfig(connection=DuckDBConnectionConfig()),
6630+
}
6631+
config = Config(
6632+
gateways=gateways,
6633+
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
6634+
default_gateway="main",
6635+
)
6636+
context = Context(config=config)
6637+
assert context.engine_adapter == context._engine_adapters["main"]
6638+
6639+
@model(
6640+
name="dummy_model",
6641+
is_sql=True,
6642+
kind="full",
6643+
gateway="duckdb",
6644+
grain='"x"',
6645+
)
6646+
def dummy_model_entry(evaluator: MacroEvaluator) -> exp.Select:
6647+
return exp.select("x").from_(exp.values([("1", 2)], "_v", ["x"]))
6648+
6649+
dummy_model = model.get_registry()["dummy_model"].model(module_path=Path("."), path=Path("."))
6650+
context.upsert_model(dummy_model)
6651+
assert isinstance(dummy_model, SqlModel)
6652+
assert dummy_model.gateway == "duckdb"
6653+
6654+
assert_exp_eq(
6655+
context.render("dummy_model"),
6656+
"""
6657+
SELECT
6658+
"_v"."x" AS "x",
6659+
FROM (VALUES ('1', 2)) AS "_v"("x")
6660+
""",
6661+
)
6662+
assert isinstance(context._get_engine_adapter("duckdb"), DuckDBEngineAdapter)
6663+
assert len(context._engine_adapters) == 2

tests/core/test_test.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from sqlmesh.core.macros import MacroEvaluator, macro
2626
from sqlmesh.core.model import Model, SqlModel, load_sql_based_model, model
2727
from sqlmesh.core.test.definition import ModelTest, PythonModelTest, SqlModelTest
28-
from sqlmesh.utils.errors import ConfigError, TestError
28+
from sqlmesh.utils.errors import ConfigError, SQLMeshError, TestError
2929
from sqlmesh.utils.yaml import dump as dump_yaml
3030
from sqlmesh.utils.yaml import load as load_yaml
3131

@@ -1989,3 +1989,53 @@ def test_test_generation_with_recursive_ctes(tmp_path: Path) -> None:
19891989
}
19901990

19911991
_check_successful_or_raise(context.test())
1992+
1993+
1994+
def test_test_with_gateway_specific_model(tmp_path: Path, mocker: MockerFixture) -> None:
1995+
init_example_project(tmp_path, dialect="duckdb")
1996+
1997+
config = Config(
1998+
gateways={
1999+
"main": GatewayConfig(connection=DuckDBConnectionConfig()),
2000+
"second": GatewayConfig(connection=DuckDBConnectionConfig()),
2001+
},
2002+
default_gateway="main",
2003+
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
2004+
)
2005+
gw_model_sql_file = tmp_path / "models" / "gw_model.sql"
2006+
2007+
# The model has a gateway specified which isn't the default
2008+
gw_model_sql_file.write_text(
2009+
"MODEL (name sqlmesh_example.gw_model, gateway second); SELECT c FROM sqlmesh_example.input_model;"
2010+
)
2011+
input_model_sql_file = tmp_path / "models" / "input_model.sql"
2012+
input_model_sql_file.write_text(
2013+
"MODEL (name sqlmesh_example.input_model); SELECT c FROM external_table;"
2014+
)
2015+
2016+
context = Context(paths=tmp_path, config=config)
2017+
input_queries = {'"memory"."sqlmesh_example"."input_model"': "SELECT 5 AS c"}
2018+
mocker.patch(
2019+
"sqlmesh.core.engine_adapter.base.EngineAdapter.fetchdf",
2020+
return_value=pd.DataFrame({"c": [5]}),
2021+
)
2022+
2023+
assert context.engine_adapter == context._engine_adapters["main"]
2024+
with pytest.raises(
2025+
SQLMeshError, match=r"Gateway 'wrong' not found in the available engine adapters."
2026+
):
2027+
context._get_engine_adapter("wrong")
2028+
2029+
# Create test should use the gateway specific engine adapter
2030+
context.create_test("sqlmesh_example.gw_model", input_queries=input_queries, overwrite=True)
2031+
assert context._get_engine_adapter("second") == context._engine_adapters["second"]
2032+
assert len(context._engine_adapters) == 2
2033+
2034+
test = load_yaml(context.path / c.TESTS / "test_gw_model.yaml")
2035+
2036+
assert len(test) == 1
2037+
assert "test_gw_model" in test
2038+
assert test["test_gw_model"]["inputs"] == {
2039+
'"memory"."sqlmesh_example"."input_model"': [{"c": 5}]
2040+
}
2041+
assert test["test_gw_model"]["outputs"] == {"query": [{"c": 5}]}

0 commit comments

Comments
 (0)