Skip to content

Commit ae21dba

Browse files
pr feedback
1 parent 471b459 commit ae21dba

File tree

9 files changed

+123
-100
lines changed

9 files changed

+123
-100
lines changed

sqlmesh/core/context.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,6 @@
117117
run_tests,
118118
)
119119
from sqlmesh.core.user import User
120-
from sqlmesh.dbt.builtin import set_selected_resources
121120
from sqlmesh.utils import UniqueKeyDict, Verbosity
122121
from sqlmesh.utils.concurrency import concurrent_apply_to_values
123122
from sqlmesh.utils.dag import DAG
@@ -1583,11 +1582,6 @@ def plan_builder(
15831582
"Selector did not return any models. Please check your model selection and try again."
15841583
)
15851584

1586-
if self._project_type != c.NATIVE:
1587-
set_selected_resources(
1588-
models=model_selector.expand_model_selections(select_models or "*")
1589-
)
1590-
15911585
snapshots = self._snapshots(models_override)
15921586
context_diff = self._context_diff(
15931587
environment or c.PROD,
@@ -1682,6 +1676,7 @@ def plan_builder(
16821676
end_override_per_model=max_interval_end_per_model,
16831677
console=self.console,
16841678
user_provided_flags=user_provided_flags,
1679+
selected_models=model_selector.expand_model_selections(select_models or "*"),
16851680
explain=explain or False,
16861681
ignore_cron=ignore_cron or False,
16871682
)
@@ -2488,9 +2483,6 @@ def _run(
24882483
select_models, no_auto_upstream, snapshots.values()
24892484
)
24902485

2491-
if self._project_type != c.NATIVE:
2492-
set_selected_resources(models=select_models or set([s.name for s in snapshots.keys()]))
2493-
24942486
completion_status = scheduler.run(
24952487
environment,
24962488
start=start,

sqlmesh/core/environment.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ def execute_environment_statements(
312312
start: t.Optional[TimeLike] = None,
313313
end: t.Optional[TimeLike] = None,
314314
execution_time: t.Optional[TimeLike] = None,
315+
selected_models: t.Optional[t.Set[str]] = None,
315316
) -> None:
316317
try:
317318
rendered_expressions = [
@@ -327,6 +328,7 @@ def execute_environment_statements(
327328
execution_time=execution_time,
328329
environment_naming_info=environment_naming_info,
329330
engine_adapter=adapter,
331+
selected_models=selected_models,
330332
)
331333
]
332334
except Exception as e:

sqlmesh/core/plan/builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def __init__(
129129
end_override_per_model: t.Optional[t.Dict[str, datetime]] = None,
130130
console: t.Optional[PlanBuilderConsole] = None,
131131
user_provided_flags: t.Optional[t.Dict[str, UserProvidedFlags]] = None,
132+
selected_models: t.Optional[t.Set[str]] = None,
132133
):
133134
self._context_diff = context_diff
134135
self._no_gaps = no_gaps
@@ -169,6 +170,7 @@ def __init__(
169170
self._console = console or get_console()
170171
self._choices: t.Dict[SnapshotId, SnapshotChangeCategory] = {}
171172
self._user_provided_flags = user_provided_flags
173+
self._selected_models = selected_models
172174
self._explain = explain
173175

174176
self._start = start
@@ -347,6 +349,7 @@ def build(self) -> Plan:
347349
ensure_finalized_snapshots=self._ensure_finalized_snapshots,
348350
ignore_cron=self._ignore_cron,
349351
user_provided_flags=self._user_provided_flags,
352+
selected_models=self._selected_models,
350353
)
351354
self._latest_plan = plan
352355
return plan

sqlmesh/core/plan/definition.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ class Plan(PydanticModel, frozen=True):
7070
execution_time_: t.Optional[TimeLike] = Field(default=None, alias="execution_time")
7171

7272
user_provided_flags: t.Optional[t.Dict[str, UserProvidedFlags]] = None
73+
selected_models: t.Optional[t.Set[str]] = None
74+
"""Models that have been selected for this plan (used for dbt selected_resouces)"""
7375

7476
@cached_property
7577
def start(self) -> TimeLike:
@@ -282,6 +284,7 @@ def to_evaluatable(self) -> EvaluatablePlan:
282284
},
283285
environment_statements=self.context_diff.environment_statements,
284286
user_provided_flags=self.user_provided_flags,
287+
selected_models=self.selected_models,
285288
)
286289

287290
@cached_property
@@ -319,6 +322,7 @@ class EvaluatablePlan(PydanticModel):
319322
disabled_restatement_models: t.Set[str]
320323
environment_statements: t.Optional[t.List[EnvironmentStatements]] = None
321324
user_provided_flags: t.Optional[t.Dict[str, UserProvidedFlags]] = None
325+
selected_models: t.Optional[t.Set[str]] = None
322326

323327
def is_selected_for_backfill(self, model_fqn: str) -> bool:
324328
return self.models_to_backfill is None or model_fqn in self.models_to_backfill

sqlmesh/core/plan/evaluator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def visit_before_all_stage(self, stage: stages.BeforeAllStage, plan: Evaluatable
137137
start=plan.start,
138138
end=plan.end,
139139
execution_time=plan.execution_time,
140+
selected_models=plan.selected_models,
140141
)
141142

142143
def visit_after_all_stage(self, stage: stages.AfterAllStage, plan: EvaluatablePlan) -> None:
@@ -150,6 +151,7 @@ def visit_after_all_stage(self, stage: stages.AfterAllStage, plan: EvaluatablePl
150151
start=plan.start,
151152
end=plan.end,
152153
execution_time=plan.execution_time,
154+
selected_models=plan.selected_models,
153155
)
154156

155157
def visit_create_snapshot_records_stage(
@@ -257,6 +259,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
257259
allow_destructive_snapshots=plan.allow_destructive_models,
258260
allow_additive_snapshots=plan.allow_additive_models,
259261
selected_snapshot_ids=stage.selected_snapshot_ids,
262+
selected_models=plan.selected_models,
260263
)
261264
if errors:
262265
raise PlanError("Plan application failed.")

sqlmesh/core/scheduler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ def run_merged_intervals(
416416
start: t.Optional[TimeLike] = None,
417417
end: t.Optional[TimeLike] = None,
418418
allow_destructive_snapshots: t.Optional[t.Set[str]] = None,
419+
selected_models: t.Optional[t.Set[str]] = None,
419420
allow_additive_snapshots: t.Optional[t.Set[str]] = None,
420421
selected_snapshot_ids: t.Optional[t.Set[SnapshotId]] = None,
421422
run_environment_statements: bool = False,
@@ -472,6 +473,7 @@ def run_merged_intervals(
472473
start=start,
473474
end=end,
474475
execution_time=execution_time,
476+
selected_models=selected_models,
475477
)
476478

477479
snapshots_to_create = {
@@ -526,6 +528,7 @@ def run_node(node: SchedulingUnit) -> None:
526528
allow_destructive_snapshots=allow_destructive_snapshots,
527529
allow_additive_snapshots=allow_additive_snapshots,
528530
target_table_exists=snapshot.snapshot_id not in snapshots_to_create,
531+
selected_models=selected_models,
529532
)
530533

531534
evaluation_duration_ms = now_timestamp() - execution_start_ts
@@ -595,6 +598,7 @@ def run_node(node: SchedulingUnit) -> None:
595598
start=start,
596599
end=end,
597600
execution_time=execution_time,
601+
selected_models=selected_models,
598602
)
599603

600604
self.state_sync.recycle()
@@ -798,6 +802,7 @@ def _run_or_audit(
798802
run_environment_statements=run_environment_statements,
799803
audit_only=audit_only,
800804
auto_restatement_triggers=auto_restatement_triggers,
805+
selected_models=selected_snapshots or {s.name for s in merged_intervals},
801806
)
802807

803808
return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS

sqlmesh/dbt/adapter.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,10 @@ def graph(self) -> t.Any:
180180
}
181181
)
182182

183+
@property
184+
def selected_resources(self) -> t.List[str]:
185+
return []
186+
183187

184188
class ParsetimeAdapter(BaseAdapter):
185189
def get_relation(self, database: str, schema: str, identifier: str) -> t.Optional[BaseRelation]:
@@ -280,6 +284,13 @@ def __init__(
280284
def graph(self) -> t.Any:
281285
return self.jinja_globals.get("flat_graph", super().graph)
282286

287+
@property
288+
def selected_resources(self) -> t.List[str]:
289+
selected_models = self.jinja_globals.get("selected_models")
290+
if selected_models:
291+
return [self._dbt_model_id(model) for model in sorted(selected_models)]
292+
return []
293+
283294
def get_relation(
284295
self, database: t.Optional[str], schema: str, identifier: str
285296
) -> t.Optional[BaseRelation]:
@@ -504,3 +515,7 @@ def _normalize(self, input_table: exp.Table) -> exp.Table:
504515
normalized_table.set("db", normalized_table.this)
505516
normalized_table.set("this", None)
506517
return normalized_table
518+
519+
def _dbt_model_id(self, sqlmesh_model_name: str) -> str:
520+
parts = [part.strip('"') for part in sqlmesh_model_name.split(".")]
521+
return f"model.{parts[0]}.{parts[-1]}"

sqlmesh/dbt/builtin.py

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def create_builtin_globals(
459459
"run_query": sql_execution.run_query,
460460
"statement": sql_execution.statement,
461461
"graph": adapter.graph,
462-
"selected_resources": get_selected_resources(),
462+
"selected_resources": adapter.selected_resources,
463463
}
464464
)
465465

@@ -486,33 +486,3 @@ def _relation_info_to_relation(
486486
}
487487
)
488488
return relation_type.create(**relation_info, quote_policy=quote_policy)
489-
490-
491-
_selected_resources: t.List[str] = []
492-
493-
494-
def set_selected_resources(
495-
models: t.Optional[t.Set[str]] = None,
496-
) -> None:
497-
global _selected_resources
498-
resources = []
499-
500-
if models:
501-
for model in models:
502-
resources.append(dbt_model_id(model))
503-
504-
_selected_resources = sorted(resources)
505-
506-
507-
def dbt_model_id(sqlmesh_model_name: str) -> str:
508-
parts = [part.strip('"') for part in sqlmesh_model_name.split(".")]
509-
return f"model.{parts[0]}.{parts[-1]}"
510-
511-
512-
def get_selected_resources() -> t.List[str]:
513-
return _selected_resources
514-
515-
516-
def clear_selected_resources() -> None:
517-
global _selected_resources
518-
_selected_resources = []

0 commit comments

Comments
 (0)