Skip to content

Commit 526af33

Browse files
committed
Fix bug in table mapping for python models in unit tests
1 parent 98304fb commit 526af33

File tree

2 files changed

+89
-2
lines changed

2 files changed

+89
-2
lines changed

sqlmesh/core/test/context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ def _model_tables(self) -> t.Dict[str, str]:
4343
# Include upstream dependencies to ensure they can be resolved during test execution
4444
return {
4545
name: self._test._test_fixture_table(name).sql()
46-
for model in self._models.values()
47-
for name in [model.name, *model.depends_on]
46+
for normalized_model_name, model in self._models.items()
47+
for name in [normalized_model_name, *model.depends_on]
4848
}
4949

5050
def with_variables(

tests/core/test_test.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from sqlmesh.core.model import Model, SqlModel, load_sql_based_model, model
3232
from sqlmesh.core.test.definition import ModelTest, PythonModelTest, SqlModelTest
3333
from sqlmesh.core.test.result import ModelTextTestResult
34+
from sqlmesh.core.test.context import TestExecutionContext
3435
from sqlmesh.utils import Verbosity
3536
from sqlmesh.utils.errors import ConfigError, SQLMeshError, TestError
3637
from sqlmesh.utils.yaml import dump as dump_yaml
@@ -2996,3 +2997,89 @@ def test_parameterized_name_self_referential_model():
29962997

29972998
_check_successful_or_raise(test1.run())
29982999
_check_successful_or_raise(test2.run())
3000+
3001+
3002+
def test_parameterized_name_self_referential_python_model():
3003+
variables = {"table_catalog": "gold"}
3004+
3005+
@model(
3006+
name="@{table_catalog}.sushi.foo",
3007+
columns={
3008+
"id": "int",
3009+
},
3010+
depends_on=["@{table_catalog}.sushi.bar"],
3011+
dialect="snowflake",
3012+
)
3013+
def execute(
3014+
context: ExecutionContext,
3015+
**kwargs: t.Any,
3016+
) -> pd.DataFrame:
3017+
current_table = context.resolve_table(f"{context.var('table_catalog')}.sushi.foo")
3018+
current_df = context.fetchdf(f"select id from {current_table}")
3019+
upstream_table = context.resolve_table(f"{context.var('table_catalog')}.sushi.bar")
3020+
upstream_df = context.fetchdf(f"select id from {upstream_table}")
3021+
3022+
return pd.DataFrame([{"ID": upstream_df["ID"].sum() + current_df["ID"].sum()}])
3023+
3024+
@model(
3025+
name="@{table_catalog}.sushi.bar",
3026+
columns={
3027+
"id": "int",
3028+
},
3029+
dialect="snowflake",
3030+
)
3031+
def execute(
3032+
context: ExecutionContext,
3033+
**kwargs: t.Any,
3034+
) -> pd.DataFrame:
3035+
return pd.DataFrame([{"ID": 1}])
3036+
3037+
model_foo = model.get_registry()["@{table_catalog}.sushi.foo"].model(
3038+
module_path=Path("."), path=Path("."), variables=variables
3039+
)
3040+
model_bar = model.get_registry()["@{table_catalog}.sushi.bar"].model(
3041+
module_path=Path("."), path=Path("."), variables=variables
3042+
)
3043+
3044+
assert model_foo.fqn == '"GOLD"."SUSHI"."FOO"'
3045+
assert model_bar.fqn == '"GOLD"."SUSHI"."BAR"'
3046+
3047+
ctx = Context(
3048+
config=Config(model_defaults=ModelDefaultsConfig(dialect="snowflake"), variables=variables)
3049+
)
3050+
ctx.upsert_model(model_foo)
3051+
ctx.upsert_model(model_bar)
3052+
3053+
test = _create_test(
3054+
body=load_yaml(
3055+
"""
3056+
test_foo:
3057+
model: {{ var('table_catalog') }}.sushi.foo
3058+
inputs:
3059+
{{ var('table_catalog') }}.sushi.foo:
3060+
rows:
3061+
- id: 3
3062+
{{ var('table_catalog') }}.sushi.bar:
3063+
rows:
3064+
- id: 5
3065+
outputs:
3066+
query:
3067+
- id: 8
3068+
""",
3069+
variables=variables,
3070+
),
3071+
test_name="test_foo",
3072+
model=model_foo,
3073+
context=ctx,
3074+
)
3075+
3076+
assert isinstance(test, PythonModelTest)
3077+
3078+
assert test.body["model"] == '"GOLD"."SUSHI"."FOO"'
3079+
assert '"GOLD"."SUSHI"."BAR"' in test.body["inputs"]
3080+
3081+
assert isinstance(test.context, TestExecutionContext)
3082+
assert '"GOLD"."SUSHI"."FOO"' in test.context._model_tables
3083+
assert '"GOLD"."SUSHI"."BAR"' in test.context._model_tables
3084+
3085+
_check_successful_or_raise(test.run())

0 commit comments

Comments
 (0)