Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 19 additions & 20 deletions sqlmesh/dbt/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from abc import abstractmethod
from enum import Enum
from pathlib import Path
import logging

from pydantic import Field
from sqlglot.helper import ensure_list
Expand Down Expand Up @@ -38,6 +39,9 @@
BMC = t.TypeVar("BMC", bound="BaseModelConfig")


logger = logging.getLogger(__name__)


class Materialization(str, Enum):
"""DBT model materializations"""

Expand Down Expand Up @@ -261,37 +265,32 @@ def remove_tests_with_invalid_refs(self, context: DbtContext) -> None:
if all(ref in context.refs for ref in test.dependencies.refs)
]

def fix_circular_test_refs(self, context: DbtContext) -> None:
    """
    Moves tests that would create a direct circular reference to the downstream model.

    A test attached to this model may reference another model that itself depends on this
    model, either through its own refs or through the refs of its tests. Such a test is
    relocated to that other (downstream) model. This addresses the most common circular
    reference - relationship tests in both directions. In the future, we may want to increase
    coverage by checking for indirect circular references.

    Args:
        context: The dbt context this model resides within.

    Returns:
        None
    """
    # Iterate over a copy because tests are removed from self.tests inside the loop.
    for test in self.tests.copy():
        for ref in test.dependencies.refs:
            # References to this model itself, or to models this model already depends on,
            # cannot introduce a cycle.
            if ref == self.name or ref in self.dependencies.refs:
                continue

            model = context.refs[ref]
            if (
                self.name in model.dependencies.refs
                or self.name in model.tests_ref_source_dependencies.refs
            ):
                logger.info(
                    "Moving test '%s' from model '%s' to '%s' to avoid circular reference.",
                    test.name,
                    self.name,
                    model.name,
                )
                model.tests.append(test)
                self.tests.remove(test)
                # The test no longer belongs to this model. Stop scanning its remaining
                # refs: a second match would append it again and make list.remove raise
                # ValueError because the test has already been removed.
                break

@property
def sqlmesh_config_fields(self) -> t.Set[str]:
Expand All @@ -312,7 +311,7 @@ def sqlmesh_model_kwargs(
) -> t.Dict[str, t.Any]:
"""Get common sqlmesh model parameters"""
self.remove_tests_with_invalid_refs(context)
self.check_for_circular_test_refs(context)
self.fix_circular_test_refs(context)

dependencies = self.dependencies.copy()
if dependencies.has_dynamic_var_names:
Expand Down
37 changes: 26 additions & 11 deletions tests/dbt/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from sqlmesh.dbt.model import ModelConfig
from sqlmesh.dbt.target import PostgresConfig
from sqlmesh.dbt.test import TestConfig
from sqlmesh.utils.errors import ConfigError
from sqlmesh.utils.yaml import YAML

pytestmark = pytest.mark.dbt
Expand All @@ -30,25 +29,41 @@ def test_model_test_circular_references() -> None:
sql="",
dependencies=Dependencies(refs={"upstream", "downstream"}),
)

# No circular reference
downstream_model.tests = [downstream_test]
downstream_model.check_for_circular_test_refs(context)
downstream_model.fix_circular_test_refs(context)
assert upstream_model.tests == []
assert downstream_model.tests == [downstream_test]

# Upstream model reference in downstream model
downstream_model.tests = []
upstream_model.tests = [upstream_test]
with pytest.raises(ConfigError, match="downstream model"):
upstream_model.check_for_circular_test_refs(context)
upstream_model.fix_circular_test_refs(context)
assert upstream_model.tests == []
assert downstream_model.tests == [upstream_test]

upstream_model.tests = [upstream_test]
downstream_model.tests = [downstream_test]
with pytest.raises(ConfigError, match="downstream model"):
upstream_model.check_for_circular_test_refs(context)
downstream_model.check_for_circular_test_refs(context)
upstream_model.fix_circular_test_refs(context)
assert upstream_model.tests == []
assert downstream_model.tests == [downstream_test, upstream_test]

downstream_model.fix_circular_test_refs(context)
assert upstream_model.tests == []
assert downstream_model.tests == [downstream_test, upstream_test]

# Test only references
upstream_model.tests = [upstream_test]
downstream_model.tests = [downstream_test]
downstream_model.dependencies = Dependencies()
with pytest.raises(ConfigError, match="between tests"):
upstream_model.check_for_circular_test_refs(context)
with pytest.raises(ConfigError, match="between tests"):
downstream_model.check_for_circular_test_refs(context)
upstream_model.fix_circular_test_refs(context)
assert upstream_model.tests == []
assert downstream_model.tests == [downstream_test, upstream_test]

downstream_model.fix_circular_test_refs(context)
assert upstream_model.tests == []
assert downstream_model.tests == [downstream_test, upstream_test]


@pytest.mark.slow
Expand Down