diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 0c398af412..31147dec0e 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -622,6 +622,15 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]: BUILTIN_RULES.union(project.user_rules), config.linter ) + # Load environment statements from state for projects not in current load + if any(self._projects): + prod = self.state_reader.get_environment(c.PROD) + if prod: + existing_statements = self.state_reader.get_environment_statements(c.PROD) + for stmt in existing_statements: + if stmt.project and stmt.project not in self._projects: + self._environment_statements.append(stmt) + uncached = set() if any(self._projects): diff --git a/sqlmesh/core/context_diff.py b/sqlmesh/core/context_diff.py index f97edec5da..ff19a3c7c6 100644 --- a/sqlmesh/core/context_diff.py +++ b/sqlmesh/core/context_diff.py @@ -311,7 +311,9 @@ def has_requirement_changes(self) -> bool: @property def has_environment_statements_changes(self) -> bool: - return self.environment_statements != self.previous_environment_statements + return sorted(self.environment_statements, key=lambda s: s.project or "") != sorted( + self.previous_environment_statements, key=lambda s: s.project or "" + ) @property def has_snapshot_changes(self) -> bool: diff --git a/sqlmesh/core/environment.py b/sqlmesh/core/environment.py index 13ca1c5485..2a0d4f115d 100644 --- a/sqlmesh/core/environment.py +++ b/sqlmesh/core/environment.py @@ -266,6 +266,7 @@ class EnvironmentStatements(PydanticModel): after_all: t.List[str] python_env: t.Dict[str, Executable] jinja_macros: t.Optional[JinjaMacroRegistry] = None + project: t.Optional[str] = None def render_before_all( self, diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index 2b40be0230..30c74884c8 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -815,7 +815,11 @@ def _load_environment_statements(self, macros: MacroRegistry) -> t.List[Environm path=self.config_path, ) - return [EnvironmentStatements(**statements, python_env=python_env)] + return [ + EnvironmentStatements( + **statements, python_env=python_env, project=self.config.project or None + ) + ] return [] def _load_linting_rules(self) -> RuleSet: diff --git a/sqlmesh/dbt/loader.py b/sqlmesh/dbt/loader.py index 0f896d5bec..23d34afa31 100644 --- a/sqlmesh/dbt/loader.py +++ b/sqlmesh/dbt/loader.py @@ -277,6 +277,7 @@ def _load_environment_statements(self, macros: MacroRegistry) -> t.List[Environm ], python_env={}, jinja_macros=jinja_registry, + project=package_name, ) project_names.add(package_name) diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 8923c4c75b..d03db7af91 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -4944,15 +4944,88 @@ def test_multi(mocker): context.apply(plan) validate_apply_basics(context, c.PROD, plan.snapshots.values()) - # Ensure only repo_1's environment statements have executed in this context + # Ensure that before_all and after_all statements of both repos are there despite planning with repo_1 environment_statements = context.state_reader.get_environment_statements(c.PROD) - assert len(environment_statements) == 1 - assert environment_statements[0].before_all == [ + assert len(environment_statements) == 2 + + # Ensure that environment statements have the project field set correctly + sorted_env_statements = sorted(environment_statements, key=lambda es: es.project) + assert sorted_env_statements[0].project == "repo_1" + assert sorted_env_statements[1].project == "repo_2" + + # Assert before_all and after_all for each project + assert sorted_env_statements[0].before_all == [ "CREATE TABLE IF NOT EXISTS before_1 AS select @one()" ] - assert environment_statements[0].after_all == [ + assert sorted_env_statements[0].after_all == [ "CREATE TABLE IF NOT EXISTS after_1 AS select @dup()" ] + assert sorted_env_statements[1].before_all == [ + "CREATE TABLE IF NOT EXISTS before_2 AS select @two()" + ] + assert sorted_env_statements[1].after_all == [ + "CREATE TABLE IF NOT EXISTS after_2 AS select @dup()" + ] + + +@use_terminal_console +def test_multi_repo_single_project_environment_statements_update(copy_to_temp_path): + paths = copy_to_temp_path("examples/multi") + repo_1_path = f"{paths[0]}/repo_1" + repo_2_path = f"{paths[0]}/repo_2" + + context = Context(paths=[repo_1_path, repo_2_path], gateway="memory") + context._new_state_sync().reset(default_catalog=context.default_catalog) + + initial_plan = context.plan_builder().build() + context.apply(initial_plan) + + # Get initial statements + initial_statements = context.state_reader.get_environment_statements(c.PROD) + assert len(initial_statements) == 2 + + # Modify repo_1's config to add a new before_all statement + repo_1_config_path = f"{repo_1_path}/config.yaml" + with open(repo_1_config_path, "r") as f: + config_content = f.read() + + # Add a new before_all statement to repo_1 only + modified_config = config_content.replace( + "CREATE TABLE IF NOT EXISTS before_1 AS select @one()", + "CREATE TABLE IF NOT EXISTS before_1 AS select @one()\n - CREATE TABLE IF NOT EXISTS before_1_modified AS select 999", + ) + + with open(repo_1_config_path, "w") as f: + f.write(modified_config) + + # Create new context with modified config but only for repo_1 + context_repo_1_only = Context( + paths=[repo_1_path], state_sync=context.state_sync, gateway="memory" + ) + + # Plan with only repo_1, this should preserve repo_2's statements from state + repo_1_plan = context_repo_1_only.plan_builder(environment="dev").build() + context_repo_1_only.apply(repo_1_plan) + updated_statements = context_repo_1_only.state_reader.get_environment_statements("dev") + + # Should still have statements from both projects + assert len(updated_statements) == 2 + + # Sort by project + sorted_updated = sorted(updated_statements, key=lambda es: es.project or "") + + # Verify repo_1 has the new statement + repo_1_updated = sorted_updated[0] + assert repo_1_updated.project == "repo_1" + assert len(repo_1_updated.before_all) == 2 + assert "CREATE TABLE IF NOT EXISTS before_1_modified" in repo_1_updated.before_all[1] + + # Verify repo_2 statements are preserved from state + repo_2_preserved = sorted_updated[1] + assert repo_2_preserved.project == "repo_2" + assert len(repo_2_preserved.before_all) == 1 + assert "CREATE TABLE IF NOT EXISTS before_2" in repo_2_preserved.before_all[0] + assert "CREATE TABLE IF NOT EXISTS after_2 AS select @dup()" in repo_2_preserved.after_all[0] @use_terminal_console