Skip to content

Commit ea860e4

Browse files
authored
Fix: Inference of python model names from the file system (#3844)
1 parent 64af77b commit ea860e4

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

sqlmesh/core/model/decorator.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
render_meta_fields,
2121
)
2222
from sqlmesh.core.model.kind import ModelKindName, _ModelKind
23-
from sqlmesh.utils import registry_decorator
23+
from sqlmesh.utils import registry_decorator, DECORATOR_RETURN_TYPE
2424
from sqlmesh.utils.errors import ConfigError
2525
from sqlmesh.utils.metaprogramming import build_env, serialize_env
2626

@@ -39,6 +39,7 @@ def __init__(self, name: t.Optional[str] = None, is_sql: bool = False, **kwargs:
3939
if not is_sql and "columns" not in kwargs:
4040
raise ConfigError("Python model must define column schema.")
4141

42+
self.name_provided = bool(name)
4243
self.name = name or ""
4344
self.is_sql = is_sql
4445
self.kwargs = kwargs
@@ -76,6 +77,13 @@ def __init__(self, name: t.Optional[str] = None, is_sql: bool = False, **kwargs:
7677
for column_name, column_type in self.kwargs.pop("columns", {}).items()
7778
}
7879

80+
def __call__(
81+
self, func: t.Callable[..., DECORATOR_RETURN_TYPE]
82+
) -> t.Callable[..., DECORATOR_RETURN_TYPE]:
83+
if not self.name_provided:
84+
self.name = get_model_name(Path(inspect.getfile(func)))
85+
return super().__call__(func)
86+
7987
def model(
8088
self,
8189
*,
@@ -97,10 +105,7 @@ def model(
97105
env: t.Dict[str, t.Any] = {}
98106
entrypoint = self.func.__name__
99107

100-
if not self.name and infer_names:
101-
self.name = get_model_name(Path(inspect.getfile(self.func)))
102-
103-
if not self.name:
108+
if not self.name_provided and not infer_names:
104109
raise ConfigError("Python model must have a name.")
105110

106111
kind = self.kwargs.get("kind", None)

tests/core/test_model.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6024,6 +6024,32 @@ def my_model(context, **kwargs):
60246024
assert isinstance(context.get_model(expected_name), PythonModel)
60256025

60266026

6027+
def test_python_model_name_inference_multiple_models(tmp_path: Path) -> None:
6028+
init_example_project(tmp_path, dialect="duckdb")
6029+
config = Config(
6030+
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
6031+
model_naming=NameInferenceConfig(infer_names=True),
6032+
)
6033+
6034+
path_a = tmp_path / "models/test_schema/test_model_a.py"
6035+
path_b = tmp_path / "models/test_schema/test_model_b.py"
6036+
6037+
model_payload = """from sqlmesh import model
6038+
@model(
6039+
columns={'"COL"': "int"},
6040+
)
6041+
def my_model(context, **kwargs):
6042+
pass"""
6043+
6044+
path_a.parent.mkdir(parents=True, exist_ok=True)
6045+
path_a.write_text(model_payload)
6046+
path_b.write_text(model_payload)
6047+
6048+
context = Context(paths=tmp_path, config=config)
6049+
assert context.get_model("test_schema.test_model_a").name == "test_schema.test_model_a"
6050+
assert context.get_model("test_schema.test_model_b").name == "test_schema.test_model_b"
6051+
6052+
60276053
def test_custom_kind():
60286054
from sqlmesh import CustomMaterialization
60296055

0 commit comments

Comments
 (0)