diff --git a/sqlmesh/dbt/basemodel.py b/sqlmesh/dbt/basemodel.py index fad86f618e..d953a8dedf 100644 --- a/sqlmesh/dbt/basemodel.py +++ b/sqlmesh/dbt/basemodel.py @@ -4,6 +4,7 @@ from abc import abstractmethod from enum import Enum from pathlib import Path +import logging from pydantic import Field from sqlglot.helper import ensure_list @@ -38,6 +39,9 @@ BMC = t.TypeVar("BMC", bound="BaseModelConfig") +logger = logging.getLogger(__name__) + + class Materialization(str, Enum): """DBT model materializations""" @@ -261,12 +265,11 @@ def remove_tests_with_invalid_refs(self, context: DbtContext) -> None: if all(ref in context.refs for ref in test.dependencies.refs) ] - def check_for_circular_test_refs(self, context: DbtContext) -> None: + def fix_circular_test_refs(self, context: DbtContext) -> None: """ - Checks for direct circular references between two models and raises an exception if found. - This addresses the most common circular reference seen when importing a dbt project - - relationship tests in both directions. In the future, we may want to increase coverage by - checking for indirect circular references. + Checks for direct circular references between two models and moves the test to the downstream + model if found. This addresses the most common circular reference - relationship tests in both + directions. In the future, we may want to increase coverage by checking for indirect circular references. Args: context: The dbt context this model resides within. @@ -274,24 +277,20 @@ def check_for_circular_test_refs(self, context: DbtContext) -> None: Returns: None """ - for test in self.tests: + for test in self.tests.copy(): for ref in test.dependencies.refs: - model = context.refs[ref] if ref == self.name or ref in self.dependencies.refs: continue - elif self.name in model.dependencies.refs: - raise ConfigError( - f"Test '{test.name}' for model '{self.name}' depends on downstream model '{model.name}'." - " Move the test to the downstream model to avoid circular references." - ) - elif self.name in model.tests_ref_source_dependencies.refs: - circular_test = next( - test.name for test in model.tests if ref in test.dependencies.refs - ) - raise ConfigError( - f"Circular reference detected between tests for models '{self.name}' and '{model.name}':" - f" '{test.name}' ({self.name}), '{circular_test}' ({model.name})." + model = context.refs[ref] + if ( + self.name in model.dependencies.refs + or self.name in model.tests_ref_source_dependencies.refs + ): + logger.info( + f"Moving test '{test.name}' from model '{self.name}' to '{model.name}' to avoid circular reference." ) + model.tests.append(test) + self.tests.remove(test) @property def sqlmesh_config_fields(self) -> t.Set[str]: @@ -312,7 +311,7 @@ def sqlmesh_model_kwargs( ) -> t.Dict[str, t.Any]: """Get common sqlmesh model parameters""" self.remove_tests_with_invalid_refs(context) - self.check_for_circular_test_refs(context) + self.fix_circular_test_refs(context) dependencies = self.dependencies.copy() if dependencies.has_dynamic_var_names: diff --git a/tests/dbt/test_model.py b/tests/dbt/test_model.py index df9f229900..a0861e7bbd 100644 --- a/tests/dbt/test_model.py +++ b/tests/dbt/test_model.py @@ -8,7 +8,6 @@ from sqlmesh.dbt.model import ModelConfig from sqlmesh.dbt.target import PostgresConfig from sqlmesh.dbt.test import TestConfig -from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.yaml import YAML pytestmark = pytest.mark.dbt @@ -30,25 +29,41 @@ def test_model_test_circular_references() -> None: sql="", dependencies=Dependencies(refs={"upstream", "downstream"}), ) + + # No circular reference downstream_model.tests = [downstream_test] - downstream_model.check_for_circular_test_refs(context) + downstream_model.fix_circular_test_refs(context) + assert upstream_model.tests == [] + assert downstream_model.tests == [downstream_test] + # Upstream model reference in downstream model downstream_model.tests = [] upstream_model.tests = [upstream_test] - with pytest.raises(ConfigError, match="downstream model"): - upstream_model.check_for_circular_test_refs(context) + upstream_model.fix_circular_test_refs(context) + assert upstream_model.tests == [] + assert downstream_model.tests == [upstream_test] + upstream_model.tests = [upstream_test] downstream_model.tests = [downstream_test] - with pytest.raises(ConfigError, match="downstream model"): - upstream_model.check_for_circular_test_refs(context) - downstream_model.check_for_circular_test_refs(context) + upstream_model.fix_circular_test_refs(context) + assert upstream_model.tests == [] + assert downstream_model.tests == [downstream_test, upstream_test] + + downstream_model.fix_circular_test_refs(context) + assert upstream_model.tests == [] + assert downstream_model.tests == [downstream_test, upstream_test] # Test only references + upstream_model.tests = [upstream_test] + downstream_model.tests = [downstream_test] downstream_model.dependencies = Dependencies() - with pytest.raises(ConfigError, match="between tests"): - upstream_model.check_for_circular_test_refs(context) - with pytest.raises(ConfigError, match="between tests"): - downstream_model.check_for_circular_test_refs(context) + upstream_model.fix_circular_test_refs(context) + assert upstream_model.tests == [] + assert downstream_model.tests == [downstream_test, upstream_test] + + downstream_model.fix_circular_test_refs(context) + assert upstream_model.tests == [] + assert downstream_model.tests == [downstream_test, upstream_test] @pytest.mark.slow