Skip to content

Commit ff45e03

Browse files
Fix!: Add project in environment statements for consistent multi-repo plans (#4966)
Parent commit: 107fe25 · This commit: ff45e03

File tree

6 files changed: 96 additions and 6 deletions.

sqlmesh/core/context.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -622,6 +622,15 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
622622
BUILTIN_RULES.union(project.user_rules), config.linter
623623
)
624624

625+
# Load environment statements from state for projects not in current load
626+
if any(self._projects):
627+
prod = self.state_reader.get_environment(c.PROD)
628+
if prod:
629+
existing_statements = self.state_reader.get_environment_statements(c.PROD)
630+
for stmt in existing_statements:
631+
if stmt.project and stmt.project not in self._projects:
632+
self._environment_statements.append(stmt)
633+
625634
uncached = set()
626635

627636
if any(self._projects):

sqlmesh/core/context_diff.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -311,7 +311,9 @@ def has_requirement_changes(self) -> bool:
311311

312312
@property
313313
def has_environment_statements_changes(self) -> bool:
314-
return self.environment_statements != self.previous_environment_statements
314+
return sorted(self.environment_statements, key=lambda s: s.project or "") != sorted(
315+
self.previous_environment_statements, key=lambda s: s.project or ""
316+
)
315317

316318
@property
317319
def has_snapshot_changes(self) -> bool:

sqlmesh/core/environment.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -266,6 +266,7 @@ class EnvironmentStatements(PydanticModel):
266266
after_all: t.List[str]
267267
python_env: t.Dict[str, Executable]
268268
jinja_macros: t.Optional[JinjaMacroRegistry] = None
269+
project: t.Optional[str] = None
269270

270271
def render_before_all(
271272
self,

sqlmesh/core/loader.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -815,7 +815,11 @@ def _load_environment_statements(self, macros: MacroRegistry) -> t.List[Environm
815815
path=self.config_path,
816816
)
817817

818-
return [EnvironmentStatements(**statements, python_env=python_env)]
818+
return [
819+
EnvironmentStatements(
820+
**statements, python_env=python_env, project=self.config.project or None
821+
)
822+
]
819823
return []
820824

821825
def _load_linting_rules(self) -> RuleSet:

sqlmesh/dbt/loader.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -277,6 +277,7 @@ def _load_environment_statements(self, macros: MacroRegistry) -> t.List[Environm
277277
],
278278
python_env={},
279279
jinja_macros=jinja_registry,
280+
project=package_name,
280281
)
281282
project_names.add(package_name)
282283

tests/core/test_integration.py

Lines changed: 77 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4944,15 +4944,88 @@ def test_multi(mocker):
49444944
context.apply(plan)
49454945
validate_apply_basics(context, c.PROD, plan.snapshots.values())
49464946

4947-
# Ensure only repo_1's environment statements have executed in this context
4947+
# Ensure that before_all and after_all statements of both repos are there despite planning with repo_1
49484948
environment_statements = context.state_reader.get_environment_statements(c.PROD)
4949-
assert len(environment_statements) == 1
4950-
assert environment_statements[0].before_all == [
4949+
assert len(environment_statements) == 2
4950+
4951+
# Ensure that environment statements have the project field set correctly
4952+
sorted_env_statements = sorted(environment_statements, key=lambda es: es.project)
4953+
assert sorted_env_statements[0].project == "repo_1"
4954+
assert sorted_env_statements[1].project == "repo_2"
4955+
4956+
# Assert before_all and after_all for each project
4957+
assert sorted_env_statements[0].before_all == [
49514958
"CREATE TABLE IF NOT EXISTS before_1 AS select @one()"
49524959
]
4953-
assert environment_statements[0].after_all == [
4960+
assert sorted_env_statements[0].after_all == [
49544961
"CREATE TABLE IF NOT EXISTS after_1 AS select @dup()"
49554962
]
4963+
assert sorted_env_statements[1].before_all == [
4964+
"CREATE TABLE IF NOT EXISTS before_2 AS select @two()"
4965+
]
4966+
assert sorted_env_statements[1].after_all == [
4967+
"CREATE TABLE IF NOT EXISTS after_2 AS select @dup()"
4968+
]
4969+
4970+
4971+
@use_terminal_console
4972+
def test_multi_repo_single_project_environment_statements_update(copy_to_temp_path):
4973+
paths = copy_to_temp_path("examples/multi")
4974+
repo_1_path = f"{paths[0]}/repo_1"
4975+
repo_2_path = f"{paths[0]}/repo_2"
4976+
4977+
context = Context(paths=[repo_1_path, repo_2_path], gateway="memory")
4978+
context._new_state_sync().reset(default_catalog=context.default_catalog)
4979+
4980+
initial_plan = context.plan_builder().build()
4981+
context.apply(initial_plan)
4982+
4983+
# Get initial statements
4984+
initial_statements = context.state_reader.get_environment_statements(c.PROD)
4985+
assert len(initial_statements) == 2
4986+
4987+
# Modify repo_1's config to add a new before_all statement
4988+
repo_1_config_path = f"{repo_1_path}/config.yaml"
4989+
with open(repo_1_config_path, "r") as f:
4990+
config_content = f.read()
4991+
4992+
# Add a new before_all statement to repo_1 only
4993+
modified_config = config_content.replace(
4994+
"CREATE TABLE IF NOT EXISTS before_1 AS select @one()",
4995+
"CREATE TABLE IF NOT EXISTS before_1 AS select @one()\n - CREATE TABLE IF NOT EXISTS before_1_modified AS select 999",
4996+
)
4997+
4998+
with open(repo_1_config_path, "w") as f:
4999+
f.write(modified_config)
5000+
5001+
# Create new context with modified config but only for repo_1
5002+
context_repo_1_only = Context(
5003+
paths=[repo_1_path], state_sync=context.state_sync, gateway="memory"
5004+
)
5005+
5006+
# Plan with only repo_1, this should preserve repo_2's statements from state
5007+
repo_1_plan = context_repo_1_only.plan_builder(environment="dev").build()
5008+
context_repo_1_only.apply(repo_1_plan)
5009+
updated_statements = context_repo_1_only.state_reader.get_environment_statements("dev")
5010+
5011+
# Should still have statements from both projects
5012+
assert len(updated_statements) == 2
5013+
5014+
# Sort by project
5015+
sorted_updated = sorted(updated_statements, key=lambda es: es.project or "")
5016+
5017+
# Verify repo_1 has the new statement
5018+
repo_1_updated = sorted_updated[0]
5019+
assert repo_1_updated.project == "repo_1"
5020+
assert len(repo_1_updated.before_all) == 2
5021+
assert "CREATE TABLE IF NOT EXISTS before_1_modified" in repo_1_updated.before_all[1]
5022+
5023+
# Verify repo_2 statements are preserved from state
5024+
repo_2_preserved = sorted_updated[1]
5025+
assert repo_2_preserved.project == "repo_2"
5026+
assert len(repo_2_preserved.before_all) == 1
5027+
assert "CREATE TABLE IF NOT EXISTS before_2" in repo_2_preserved.before_all[0]
5028+
assert "CREATE TABLE IF NOT EXISTS after_2 AS select @dup()" in repo_2_preserved.after_all[0]
49565029

49575030

49585031
@use_terminal_console

0 commit comments

Comments (0)