Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 0 additions & 28 deletions sqlmesh/dbt/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,33 +268,6 @@ def remove_tests_with_invalid_refs(self, context: DbtContext) -> None:
and all(source in context.sources for source in test.dependencies.sources)
]

def fix_circular_test_refs(self, context: DbtContext) -> None:
"""
Checks for direct circular references between two models and moves the test to the downstream
model if found. This addresses the most common circular reference - relationship tests in both
directions. In the future, we may want to increase coverage by checking for indirect circular references.
Args:
context: The dbt context this model resides within.
Returns:
None
"""
for test in self.tests.copy():
for ref in test.dependencies.refs:
if ref == self.name or ref in self.dependencies.refs:
continue
model = context.refs[ref]
if (
self.name in model.dependencies.refs
or self.name in model.tests_ref_source_dependencies.refs
):
logger.info(
f"Moving test '{test.name}' from model '{self.name}' to '{model.name}' to avoid circular reference."
)
model.tests.append(test)
self.tests.remove(test)

@property
def sqlmesh_config_fields(self) -> t.Set[str]:
return {"description", "owner", "stamp", "storage_format"}
Expand All @@ -314,7 +287,6 @@ def sqlmesh_model_kwargs(
) -> t.Dict[str, t.Any]:
"""Get common sqlmesh model parameters"""
self.remove_tests_with_invalid_refs(context)
self.fix_circular_test_refs(context)

dependencies = self.dependencies.copy()
if dependencies.has_dynamic_var_names:
Expand Down
4 changes: 3 additions & 1 deletion sqlmesh/dbt/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,12 @@ def _load_models_and_seeds(self) -> None:
continue

macro_references = _macro_references(self._manifest, node)
tests = (
all_tests = (
self._tests_by_owner[node.name]
+ self._tests_by_owner[f"{node.package_name}.{node.name}"]
)
# Only include non-standalone tests (tests that don't reference other models)
tests = [test for test in all_tests if not test.is_standalone]
node_config = _node_base_config(node)

node_name = node.name
Expand Down
10 changes: 9 additions & 1 deletion sqlmesh/dbt/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,15 @@ def _lowercase_name(cls, v: str) -> str:

@property
def is_standalone(self) -> bool:
return not self.model_name
# A test is standalone if:
# 1. It has no model_name (already standalone), OR
# 2. It references other models besides its own model
if not self.model_name:
return True

# Check if test has references to other models
other_refs = {ref for ref in self.dependencies.refs if ref != self.model_name}
return bool(other_refs)

@property
def sqlmesh_config_fields(self) -> t.Set[str]:
Expand Down
1 change: 0 additions & 1 deletion tests/dbt/cli/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def test_list_select(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Resul
assert result.exit_code == 0
assert not result.exception

assert "main.orders" in result.output
assert "main.customers" in result.output
assert "main.stg_customers" in result.output
assert "main.raw_customers" in result.output
Expand Down
15 changes: 11 additions & 4 deletions tests/dbt/cli/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def test_run_option_mapping(jaffle_shop_duckdb: Path):
operations.context.console = console

plan = operations.run()
standalone_audit_name = "relationships_orders_customer_id__customer_id__ref_customers_"
assert plan.environment.name == "prod"
assert console.no_prompts is True
assert console.no_diff is True
Expand All @@ -149,7 +150,9 @@ def test_run_option_mapping(jaffle_shop_duckdb: Path):
'"jaffle_shop"."main"."orders"',
'"jaffle_shop"."main"."stg_orders"',
}
assert {s.name for s in plan.snapshots} == plan.selected_models_to_backfill
assert {s.name for s in plan.snapshots} == (
plan.selected_models_to_backfill | {standalone_audit_name}
)

plan = operations.run(select=["main.stg_orders+"], exclude=["main.customers"])
assert plan.environment.name == "prod"
Expand All @@ -163,7 +166,9 @@ def test_run_option_mapping(jaffle_shop_duckdb: Path):
'"jaffle_shop"."main"."orders"',
'"jaffle_shop"."main"."stg_orders"',
}
assert {s.name for s in plan.snapshots} == plan.selected_models_to_backfill
assert {s.name for s in plan.snapshots} == (
plan.selected_models_to_backfill | {standalone_audit_name}
)

plan = operations.run(exclude=["main.customers"])
assert plan.environment.name == "prod"
Expand All @@ -175,8 +180,10 @@ def test_run_option_mapping(jaffle_shop_duckdb: Path):
assert plan.skip_backfill is False
assert plan.selected_models_to_backfill == {k for k in operations.context.snapshots} - {
'"jaffle_shop"."main"."customers"'
}
assert {s.name for s in plan.snapshots} == plan.selected_models_to_backfill
} - {standalone_audit_name}
assert {s.name for s in plan.snapshots} == (
plan.selected_models_to_backfill | {standalone_audit_name}
)

plan = operations.run(empty=True)
assert plan.environment.name == "prod"
Expand Down
240 changes: 192 additions & 48 deletions tests/dbt/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,57 +63,201 @@ def _create_empty_project() -> t.Tuple[Path, Path]:
return _create_empty_project


def test_model_test_circular_references() -> None:
upstream_model = ModelConfig(name="upstream")
downstream_model = ModelConfig(name="downstream", dependencies=Dependencies(refs={"upstream"}))
context = DbtContext(_refs={"upstream": upstream_model, "downstream": downstream_model})

# Test and downstream model references
downstream_test = TestConfig(
name="downstream_with_upstream",
sql="",
dependencies=Dependencies(refs={"upstream", "downstream"}),
def test_test_config_is_standalone_behavior() -> None:
"""Test that TestConfig.is_standalone correctly identifies tests with cross-model references"""

# Test with no model_name (should be standalone)
standalone_test = TestConfig(
name="standalone_test",
sql="SELECT 1",
model_name=None,
dependencies=Dependencies(refs={"some_model"}),
)
upstream_test = TestConfig(
name="upstream_with_downstream",
sql="",
dependencies=Dependencies(refs={"upstream", "downstream"}),
assert standalone_test.is_standalone is True

# Test with only self-reference (should not be standalone)
self_ref_test = TestConfig(
name="self_ref_test",
sql="SELECT * FROM {{ this }}",
model_name="my_model",
dependencies=Dependencies(refs={"my_model"}),
)
assert self_ref_test.is_standalone is False

# Test with no references (should not be standalone)
no_ref_test = TestConfig(
name="no_ref_test",
sql="SELECT 1",
model_name="my_model",
dependencies=Dependencies(),
)
assert no_ref_test.is_standalone is False

# Test with references to other models (should be standalone)
cross_ref_test = TestConfig(
name="cross_ref_test",
sql="SELECT * FROM {{ ref('other_model') }}",
model_name="my_model",
dependencies=Dependencies(refs={"my_model", "other_model"}),
)
assert cross_ref_test.is_standalone is True

# Test with only references to other models, no self-reference (should be standalone)
other_only_test = TestConfig(
name="other_only_test",
sql="SELECT * FROM {{ ref('other_model') }}",
model_name="my_model",
dependencies=Dependencies(refs={"other_model"}),
)
assert other_only_test.is_standalone is True


def test_test_to_sqlmesh_creates_correct_audit_type(
dbt_dummy_postgres_config: PostgresConfig,
) -> None:
"""Test that TestConfig.to_sqlmesh creates the correct audit type based on is_standalone"""
from sqlmesh.core.audit.definition import StandaloneAudit, ModelAudit

# Set up models in context
my_model = ModelConfig(
name="my_model", sql="SELECT 1", schema="test_schema", database="test_db", alias="my_model"
)
other_model = ModelConfig(
name="other_model",
sql="SELECT 2",
schema="test_schema",
database="test_db",
alias="other_model",
)
context = DbtContext(
_refs={"my_model": my_model, "other_model": other_model},
_target=dbt_dummy_postgres_config,
)

# Test with only self-reference (should create ModelAudit)
self_ref_test = TestConfig(
name="self_ref_test",
sql="SELECT * FROM {{ this }}",
model_name="my_model",
dependencies=Dependencies(refs={"my_model"}),
)
audit = self_ref_test.to_sqlmesh(context)
assert isinstance(audit, ModelAudit)
assert audit.name == "self_ref_test"

# Test with references to other models (should create StandaloneAudit)
cross_ref_test = TestConfig(
name="cross_ref_test",
sql="SELECT * FROM {{ ref('other_model') }}",
model_name="my_model",
dependencies=Dependencies(refs={"my_model", "other_model"}),
)
audit = cross_ref_test.to_sqlmesh(context)
assert isinstance(audit, StandaloneAudit)
assert audit.name == "cross_ref_test"

# Test with no model_name (should create StandaloneAudit)
standalone_test = TestConfig(
name="standalone_test",
sql="SELECT 1",
model_name=None,
dependencies=Dependencies(),
)
audit = standalone_test.to_sqlmesh(context)
assert isinstance(audit, StandaloneAudit)
assert audit.name == "standalone_test"


@pytest.mark.slow
def test_manifest_filters_standalone_tests_from_models(
tmp_path: Path, create_empty_project
) -> None:
"""Integration test that verifies models only contain non-standalone tests after manifest loading."""
yaml = YAML()
project_dir, model_dir = create_empty_project()

# Create two models
model1_contents = "SELECT 1 as id"
model1_file = model_dir / "model1.sql"
with open(model1_file, "w", encoding="utf-8") as f:
f.write(model1_contents)

model2_contents = "SELECT 2 as id"
model2_file = model_dir / "model2.sql"
with open(model2_file, "w", encoding="utf-8") as f:
f.write(model2_contents)

# Create schema with both standalone and non-standalone tests
schema_yaml = {
"version": 2,
"models": [
{
"name": "model1",
"columns": [
{
"name": "id",
"tests": [
"not_null", # Non-standalone test - only references model1
{
"relationships": { # Standalone test - references model2
"to": "ref('model2')",
"field": "id",
}
},
],
}
],
},
{
"name": "model2",
"columns": [
{"name": "id", "tests": ["not_null"]} # Non-standalone test
],
},
],
}

schema_file = model_dir / "schema.yml"
with open(schema_file, "w", encoding="utf-8") as f:
yaml.dump(schema_yaml, f)

# Load the project through SQLMesh Context
from sqlmesh.core.context import Context

context = Context(paths=project_dir)

# No circular reference
downstream_model.tests = [downstream_test]
downstream_model.fix_circular_test_refs(context)
assert upstream_model.tests == []
assert downstream_model.tests == [downstream_test]

# Upstream model reference in downstream model
downstream_model.tests = []
upstream_model.tests = [upstream_test]
upstream_model.fix_circular_test_refs(context)
assert upstream_model.tests == []
assert downstream_model.tests == [upstream_test]

upstream_model.tests = [upstream_test]
downstream_model.tests = [downstream_test]
upstream_model.fix_circular_test_refs(context)
assert upstream_model.tests == []
assert downstream_model.tests == [downstream_test, upstream_test]

downstream_model.fix_circular_test_refs(context)
assert upstream_model.tests == []
assert downstream_model.tests == [downstream_test, upstream_test]

# Test only references
upstream_model.tests = [upstream_test]
downstream_model.tests = [downstream_test]
downstream_model.dependencies = Dependencies()
upstream_model.fix_circular_test_refs(context)
assert upstream_model.tests == []
assert downstream_model.tests == [downstream_test, upstream_test]

downstream_model.fix_circular_test_refs(context)
assert upstream_model.tests == []
assert downstream_model.tests == [downstream_test, upstream_test]
model1_snapshot = context.snapshots['"local"."main"."model1"']
model2_snapshot = context.snapshots['"local"."main"."model2"']

# Verify model1 only has non-standalone test in its audits
# Should only have "not_null" test, not the "relationships" test
model1_audit_names = [audit[0] for audit in model1_snapshot.model.audits]
assert len(model1_audit_names) == 1
assert model1_audit_names[0] == "not_null_model1_id"

# Verify model2 has its non-standalone test
model2_audit_names = [audit[0] for audit in model2_snapshot.model.audits]
assert len(model2_audit_names) == 1
assert model2_audit_names[0] == "not_null_model2_id"

# Verify the standalone test (relationships) exists as a StandaloneAudit
all_non_standalone_audits = [name for name in context._audits]
assert sorted(all_non_standalone_audits) == [
"not_null_model1_id",
"not_null_model2_id",
]

standalone_audits = [name for name in context._standalone_audits]
assert len(standalone_audits) == 1
assert standalone_audits[0] == "relationships_model1_id__id__ref_model2_"

plan_builder = context.plan_builder()
dag = plan_builder._build_dag()
assert [x.name for x in dag.sorted] == [
'"local"."main"."model1"',
'"local"."main"."model2"',
"relationships_model1_id__id__ref_model2_",
]


@pytest.mark.slow
Expand Down