diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py
index 78a391d12f..0339f6506c 100644
--- a/sqlmesh/core/context.py
+++ b/sqlmesh/core/context.py
@@ -1677,6 +1677,11 @@ def plan_builder(
             end_override_per_model=max_interval_end_per_model,
             console=self.console,
             user_provided_flags=user_provided_flags,
+            selected_models={
+                dbt_name
+                for model in model_selector.expand_model_selections(select_models or "*")
+                if (dbt_name := snapshots[model].node.dbt_name)
+            },
             explain=explain or False,
             ignore_cron=ignore_cron or False,
         )
diff --git a/sqlmesh/core/environment.py b/sqlmesh/core/environment.py
index 2a0d4f115d..4a1f417468 100644
--- a/sqlmesh/core/environment.py
+++ b/sqlmesh/core/environment.py
@@ -312,6 +312,7 @@ def execute_environment_statements(
     start: t.Optional[TimeLike] = None,
     end: t.Optional[TimeLike] = None,
     execution_time: t.Optional[TimeLike] = None,
+    selected_models: t.Optional[t.Set[str]] = None,
 ) -> None:
     try:
         rendered_expressions = [
@@ -327,6 +328,7 @@ def execute_environment_statements(
                 execution_time=execution_time,
                 environment_naming_info=environment_naming_info,
                 engine_adapter=adapter,
+                selected_models=selected_models,
             )
         ]
     except Exception as e:
diff --git a/sqlmesh/core/node.py b/sqlmesh/core/node.py
index ea2264f7fa..b04a59a39f 100644
--- a/sqlmesh/core/node.py
+++ b/sqlmesh/core/node.py
@@ -199,6 +199,7 @@ class _Node(PydanticModel):
     interval_unit_: t.Optional[IntervalUnit] = Field(alias="interval_unit", default=None)
     tags: t.List[str] = []
     stamp: t.Optional[str] = None
+    dbt_name: t.Optional[str] = None  # dbt node name
     _path: t.Optional[Path] = None
     _data_hash: t.Optional[str] = None
     _metadata_hash: t.Optional[str] = None
diff --git a/sqlmesh/core/plan/builder.py b/sqlmesh/core/plan/builder.py
index a48812d16c..a84b3b60dc 100644
--- a/sqlmesh/core/plan/builder.py
+++ b/sqlmesh/core/plan/builder.py
@@ -129,6 +129,7 @@ def __init__(
         end_override_per_model: t.Optional[t.Dict[str, datetime]] = None,
         console: t.Optional[PlanBuilderConsole] = None,
         user_provided_flags: t.Optional[t.Dict[str, UserProvidedFlags]] = None,
+        selected_models: t.Optional[t.Set[str]] = None,
     ):
         self._context_diff = context_diff
         self._no_gaps = no_gaps
@@ -169,6 +170,7 @@ def __init__(
         self._console = console or get_console()
         self._choices: t.Dict[SnapshotId, SnapshotChangeCategory] = {}
         self._user_provided_flags = user_provided_flags
+        self._selected_models = selected_models
         self._explain = explain

         self._start = start
@@ -347,6 +349,7 @@ def build(self) -> Plan:
             ensure_finalized_snapshots=self._ensure_finalized_snapshots,
             ignore_cron=self._ignore_cron,
             user_provided_flags=self._user_provided_flags,
+            selected_models=self._selected_models,
         )
         self._latest_plan = plan
         return plan
diff --git a/sqlmesh/core/plan/definition.py b/sqlmesh/core/plan/definition.py
index 2f3ddb5990..aaf6ec5dc0 100644
--- a/sqlmesh/core/plan/definition.py
+++ b/sqlmesh/core/plan/definition.py
@@ -70,6 +70,8 @@ class Plan(PydanticModel, frozen=True):
     execution_time_: t.Optional[TimeLike] = Field(default=None, alias="execution_time")
     user_provided_flags: t.Optional[t.Dict[str, UserProvidedFlags]] = None

+    selected_models: t.Optional[t.Set[str]] = None
+    """Models that have been selected for this plan (used for dbt selected_resources)"""

     @cached_property
     def start(self) -> TimeLike:
@@ -282,6 +284,7 @@ def to_evaluatable(self) -> EvaluatablePlan:
             },
             environment_statements=self.context_diff.environment_statements,
             user_provided_flags=self.user_provided_flags,
+            selected_models=self.selected_models,
         )

     @cached_property
@@ -319,6 +322,7 @@ class EvaluatablePlan(PydanticModel):
     disabled_restatement_models: t.Set[str]
     environment_statements: t.Optional[t.List[EnvironmentStatements]] = None
     user_provided_flags: t.Optional[t.Dict[str, UserProvidedFlags]] = None
+    selected_models: t.Optional[t.Set[str]] = None

     def is_selected_for_backfill(self, model_fqn: str) -> bool:
         return self.models_to_backfill is None or model_fqn in self.models_to_backfill
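The set comprehension added in `context.py` filters and extracts in one pass: the walrus operator binds each selected snapshot's `dbt_name` and drops models that have none (i.e. non-dbt models). A minimal runnable sketch of the pattern, using `SimpleNamespace` objects as hypothetical stand-ins for real snapshots:

```python
# Sketch of the filter-and-extract pattern used in plan_builder(); the
# SimpleNamespace objects stand in for real snapshots.
from types import SimpleNamespace

snapshots = {
    '"db"."main"."customers"': SimpleNamespace(
        node=SimpleNamespace(dbt_name="model.jaffle_shop.customers")
    ),
    '"db"."main"."native_model"': SimpleNamespace(node=SimpleNamespace(dbt_name=None)),
}

selected_models = {
    dbt_name
    for fqn in snapshots  # stands in for model_selector.expand_model_selections(...)
    if (dbt_name := snapshots[fqn].node.dbt_name)  # non-dbt models yield None and are skipped
}
assert selected_models == {"model.jaffle_shop.customers"}
```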
diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py
index 298d18a042..03b0b64016 100644
--- a/sqlmesh/core/plan/evaluator.py
+++ b/sqlmesh/core/plan/evaluator.py
@@ -137,6 +137,7 @@ def visit_before_all_stage(self, stage: stages.BeforeAllStage, plan: Evaluatable
             start=plan.start,
             end=plan.end,
             execution_time=plan.execution_time,
+            selected_models=plan.selected_models,
         )

     def visit_after_all_stage(self, stage: stages.AfterAllStage, plan: EvaluatablePlan) -> None:
@@ -150,6 +151,7 @@ def visit_after_all_stage(self, stage: stages.AfterAllStage, plan: EvaluatablePl
             start=plan.start,
             end=plan.end,
             execution_time=plan.execution_time,
+            selected_models=plan.selected_models,
         )

     def visit_create_snapshot_records_stage(
@@ -257,6 +259,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
             allow_destructive_snapshots=plan.allow_destructive_models,
             allow_additive_snapshots=plan.allow_additive_models,
             selected_snapshot_ids=stage.selected_snapshot_ids,
+            selected_models=plan.selected_models,
         )
         if errors:
             raise PlanError("Plan application failed.")
diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py
index 44d6b14c10..ec204927d4 100644
--- a/sqlmesh/core/scheduler.py
+++ b/sqlmesh/core/scheduler.py
@@ -416,6 +416,7 @@ def run_merged_intervals(
         start: t.Optional[TimeLike] = None,
         end: t.Optional[TimeLike] = None,
         allow_destructive_snapshots: t.Optional[t.Set[str]] = None,
+        selected_models: t.Optional[t.Set[str]] = None,
         allow_additive_snapshots: t.Optional[t.Set[str]] = None,
         selected_snapshot_ids: t.Optional[t.Set[SnapshotId]] = None,
         run_environment_statements: bool = False,
@@ -472,6 +473,7 @@ def run_merged_intervals(
                 start=start,
                 end=end,
                 execution_time=execution_time,
+                selected_models=selected_models,
             )

         # We only need to create physical tables if the snapshot is not representative or if it
@@ -533,6 +535,7 @@ def run_node(node: SchedulingUnit) -> None:
                     allow_destructive_snapshots=allow_destructive_snapshots,
                     allow_additive_snapshots=allow_additive_snapshots,
                     target_table_exists=snapshot.snapshot_id not in snapshots_to_create,
+                    selected_models=selected_models,
                 )

                 evaluation_duration_ms = now_timestamp() - execution_start_ts
@@ -602,6 +605,7 @@ def run_node(node: SchedulingUnit) -> None:
                 start=start,
                 end=end,
                 execution_time=execution_time,
+                selected_models=selected_models,
             )

         self.state_sync.recycle()
@@ -808,6 +812,7 @@ def _run_or_audit(
             run_environment_statements=run_environment_statements,
             audit_only=audit_only,
             auto_restatement_triggers=auto_restatement_triggers,
+            selected_models={s.node.dbt_name for s in merged_intervals if s.node.dbt_name},
         )

         return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS
diff --git a/sqlmesh/dbt/builtin.py b/sqlmesh/dbt/builtin.py
index 0503f1dc92..e284c11797 100644
--- a/sqlmesh/dbt/builtin.py
+++ b/sqlmesh/dbt/builtin.py
@@ -545,6 +545,7 @@ def create_builtin_globals(
             "run_query": sql_execution.run_query,
             "statement": sql_execution.statement,
             "graph": adapter.graph,
+            "selected_resources": list(jinja_globals.get("selected_models") or []),
         }
     )
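The new `selected_resources` builtin normalizes whatever arrives in `jinja_globals`: a missing key, an explicit `None`, or a set of names all become a plain list, matching dbt's list-valued `selected_resources` context variable. A standalone sketch of that normalization (the function name here is hypothetical, purely for illustration):

```python
import typing as t


def selected_resources_builtin(jinja_globals: t.Dict[str, t.Any]) -> t.List[str]:
    # Mirrors list(jinja_globals.get("selected_models") or []): both an absent
    # key and a None value fall back to an empty list.
    return list(jinja_globals.get("selected_models") or [])


assert selected_resources_builtin({}) == []
assert selected_resources_builtin({"selected_models": None}) == []
assert selected_resources_builtin({"selected_models": {"model.pkg.a"}}) == ["model.pkg.a"]
```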
diff --git a/sqlmesh/dbt/model.py b/sqlmesh/dbt/model.py
index a4ebf93ae5..3d5da1beaa 100644
--- a/sqlmesh/dbt/model.py
+++ b/sqlmesh/dbt/model.py
@@ -689,6 +689,7 @@ def to_sqlmesh(
             extract_dependencies_from_query=False,
             allow_partials=allow_partials,
             virtual_environment_mode=virtual_environment_mode,
+            dbt_name=self.node_name,
             **optional_kwargs,
             **model_kwargs,
         )
diff --git a/sqlmesh/dbt/seed.py b/sqlmesh/dbt/seed.py
index 38cd635d91..d6ecc768f9 100644
--- a/sqlmesh/dbt/seed.py
+++ b/sqlmesh/dbt/seed.py
@@ -92,6 +92,7 @@ def to_sqlmesh(
             audit_definitions=audit_definitions,
             virtual_environment_mode=virtual_environment_mode,
             start=self.start or context.sqlmesh_config.model_defaults.start,
+            dbt_name=self.node_name,
             **kwargs,
         )
diff --git a/sqlmesh/migrations/v0097_add_dbt_name_in_node.py b/sqlmesh/migrations/v0097_add_dbt_name_in_node.py
new file mode 100644
index 0000000000..f8909e4430
--- /dev/null
+++ b/sqlmesh/migrations/v0097_add_dbt_name_in_node.py
@@ -0,0 +1,9 @@
+"""Add 'dbt_name' property to node definition."""
+
+
+def migrate_schemas(state_sync, **kwargs):  # type: ignore
+    pass
+
+
+def migrate_rows(state_sync, **kwargs):  # type: ignore
+    pass
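The v0097 migration can be a pair of no-ops because `dbt_name` is optional with a `None` default: node payloads serialized before the field existed deserialize unchanged, so the migration file exists only to bump the schema version. A minimal pydantic sketch of why this is safe (`NodeSketch` is a made-up stand-in for `_Node`, which has many more fields):

```python
import typing as t

from pydantic import BaseModel


class NodeSketch(BaseModel):
    name: str
    dbt_name: t.Optional[str] = None  # new optional field with a default


old_payload = {"name": "customers"}  # serialized before dbt_name existed
assert NodeSketch(**old_payload).dbt_name is None
```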
"path": str(dbt_data_file)}}, + "target": "duckdb", + } + } + db_profile_file = dbt_project_dir / "profiles.yml" + with open(db_profile_file, "w", encoding="utf-8") as f: + yaml.dump(dbt_profile_config, f) + + context = Context(paths=dbt_project_dir) + + # find the model by its sqlmesh fully qualified name + model_fqn = '"local"."main"."simple_model"' + assert model_fqn in context.snapshots + + # Verify that node_name is the equivalent dbt one + model = context.snapshots[model_fqn].model + assert model.dbt_name == "model.test_project.simple_model" diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index 6779e196df..551c6cc16f 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -45,6 +45,7 @@ from sqlmesh.core.state_sync.db.snapshot import _snapshot_to_json from sqlmesh.dbt.builtin import _relation_info_to_relation, Config from sqlmesh.dbt.common import Dependencies +from sqlmesh.dbt.builtin import _relation_info_to_relation from sqlmesh.dbt.column import ( ColumnConfig, column_descriptions_to_sqlmesh, @@ -2375,3 +2376,84 @@ def test_dynamic_var_names_in_macro(sushi_test_project: Project): ) converted_model = model_config.to_sqlmesh(context) assert "dynamic_test_var" in converted_model.jinja_macros.global_objs["vars"] # type: ignore + + +def test_selected_resources_with_selectors(): + sushi_context = Context(paths=["tests/fixtures/dbt/sushi_test"]) + + # A plan with a specific model selection + plan_builder = sushi_context.plan_builder(select_models=["sushi.customers"]) + plan = plan_builder.build() + assert len(plan.selected_models) == 1 + selected_model = list(plan.selected_models)[0] + assert "customers" in selected_model + + # Plan without model selections should include all models + plan_builder = sushi_context.plan_builder() + plan = plan_builder.build() + assert plan.selected_models is not None + assert len(plan.selected_models) > 10 + + # with downstream models should select customers and at least one downstream model + plan_builder = sushi_context.plan_builder(select_models=["sushi.customers+"]) + plan = plan_builder.build() + assert plan.selected_models is not None + assert len(plan.selected_models) >= 2 + assert any("customers" in model for model in plan.selected_models) + + # Test wildcard selection + plan_builder = sushi_context.plan_builder(select_models=["sushi.waiter_*"]) + plan = plan_builder.build() + assert plan.selected_models is not None + assert len(plan.selected_models) >= 4 + assert all("waiter" in model for model in plan.selected_models) + + +@pytest.mark.xdist_group("dbt_manifest") +def test_selected_resources_context_variable( + sushi_test_project: Project, sushi_test_dbt_context: Context +): + context = sushi_test_project.context + + # empty selected resources + direct_access = context.render("{{ selected_resources }}") + assert direct_access == "[]" + + # selected_resources is iterable and count items + test_jinja = """ + {%- set resources = [] -%} + {%- for resource in selected_resources -%} + {%- do resources.append(resource) -%} + {%- endfor -%} + {{ resources | length }} + """ + result = context.render(test_jinja) + assert result.strip() == "0" + + # selected_resources in conditions + test_condition = """ + {%- if selected_resources -%} + has_resources + {%- else -%} + no_resources + {%- endif -%} + """ + result = context.render(test_condition) + assert result.strip() == "no_resources" + + # selected resources in dbt format + selected_resources = [ + "model.jaffle_shop.customers", + 
"model.jaffle_shop.items", + "model.jaffle_shop.orders", + ] + + # check the jinja macros rendering + result = context.render("{{ selected_resources }}", selected_resources=selected_resources) + assert result == selected_resources.__repr__() + + result = context.render(test_jinja, selected_resources=selected_resources) + assert result.strip() == "3" + + result = context.render(test_condition, selected_resources=selected_resources) + assert result.strip() == "has_resources"