Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 75 additions & 45 deletions sqlmesh/core/snapshot/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from sqlmesh.core import dialect as d
from sqlmesh.core.audit import Audit, StandaloneAudit
from sqlmesh.core.dialect import schema_
from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, DataObjectType
from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, DataObjectType, DataObject
from sqlmesh.core.macros import RuntimeStage
from sqlmesh.core.model import (
AuditResult,
Expand Down Expand Up @@ -422,50 +422,14 @@ def get_snapshots_to_create(
target_snapshots: Target snapshots.
deployability_index: Determines snapshots that are deployable / representative in the context of this creation.
"""
snapshots_with_table_names = defaultdict(set)
tables_by_gateway_and_schema: t.Dict[t.Union[str, None], t.Dict[exp.Table, set[str]]] = (
defaultdict(lambda: defaultdict(set))
)

existing_data_objects = self._get_data_objects(target_snapshots, deployability_index)
snapshots_to_create = []
for snapshot in target_snapshots:
if not snapshot.is_model or snapshot.is_symbolic:
continue
is_deployable = deployability_index.is_deployable(snapshot)
table = exp.to_table(snapshot.table_name(is_deployable), dialect=snapshot.model.dialect)
snapshots_with_table_names[snapshot].add(table.name)
table_schema = d.schema_(table.db, catalog=table.catalog)
tables_by_gateway_and_schema[snapshot.model_gateway][table_schema].add(table.name)

def _get_data_objects(
schema: exp.Table,
object_names: t.Optional[t.Set[str]] = None,
gateway: t.Optional[str] = None,
) -> t.Set[str]:
logger.info("Listing data objects in schema %s", schema.sql())
objs = self.get_adapter(gateway).get_data_objects(schema, object_names)
return {obj.name for obj in objs}

with self.concurrent_context():
existing_objects: t.Set[str] = set()
# A schema can be shared across multiple engines, so we need to group tables by both gateway and schema
for gateway, tables_by_schema in tables_by_gateway_and_schema.items():
objs_for_gateway = {
obj
for objs in concurrent_apply_to_values(
list(tables_by_schema),
lambda s: _get_data_objects(
schema=s, object_names=tables_by_schema.get(s), gateway=gateway
),
self.ddl_concurrent_tasks,
)
for obj in objs
}
existing_objects.update(objs_for_gateway)

snapshots_to_create = []
for snapshot, table_names in snapshots_with_table_names.items():
missing_tables = table_names - existing_objects
if missing_tables or (snapshot.is_seed and not snapshot.intervals):
if snapshot.snapshot_id not in existing_data_objects or (
snapshot.is_seed and not snapshot.intervals
):
snapshots_to_create.append(snapshot)

return snapshots_to_create
Expand Down Expand Up @@ -514,16 +478,26 @@ def migrate(
allow_additive_snapshots: Set of snapshots that are allowed to have additive schema changes.
deployability_index: Determines snapshots that are deployable in the context of this evaluation.
"""
deployability_index = deployability_index or DeployabilityIndex.all_deployable()
target_data_objects = self._get_data_objects(target_snapshots, deployability_index)
if not target_data_objects:
return

if not snapshots:
snapshots = {s.snapshot_id: s for s in target_snapshots}

allow_destructive_snapshots = allow_destructive_snapshots or set()
allow_additive_snapshots = allow_additive_snapshots or set()
deployability_index = deployability_index or DeployabilityIndex.all_deployable()
snapshots_by_name = {s.name: s for s in snapshots.values()}
snapshots_with_data_objects = [snapshots[s_id] for s_id in target_data_objects]
with self.concurrent_context():
# Only migrate snapshots for which there's an existing data object
concurrent_apply_to_snapshots(
target_snapshots,
snapshots_with_data_objects,
lambda s: self._migrate_snapshot(
s,
snapshots_by_name,
target_data_objects[s.snapshot_id],
allow_destructive_snapshots,
allow_additive_snapshots,
self.get_adapter(s.model_gateway),
Expand Down Expand Up @@ -1074,6 +1048,7 @@ def _migrate_snapshot(
self,
snapshot: Snapshot,
snapshots: t.Dict[str, Snapshot],
target_data_object: t.Optional[DataObject],
allow_destructive_snapshots: t.Set[str],
allow_additive_snapshots: t.Set[str],
adapter: EngineAdapter,
Expand All @@ -1095,7 +1070,6 @@ def _migrate_snapshot(
adapter.transaction(),
adapter.session(snapshot.model.render_session_properties(**render_kwargs)),
):
target_data_object = adapter.get_data_object(target_table_name)
table_exists = target_data_object is not None
if adapter.drop_data_object_on_type_mismatch(
target_data_object, _snapshot_to_data_object_type(snapshot)
Expand Down Expand Up @@ -1447,6 +1421,62 @@ def _can_clone(self, snapshot: Snapshot, deployability_index: DeployabilityIndex
and not deployability_index.is_deployable(snapshot)
)

def _get_data_objects(
self,
target_snapshots: t.Iterable[Snapshot],
deployability_index: DeployabilityIndex,
) -> t.Dict[SnapshotId, DataObject]:
"""Returns a dictionary of snapshot IDs to existing data objects of their physical tables.
Args:
target_snapshots: Target snapshots.
deployability_index: The deployability index to determine whether to look for a deployable or
a non-deployable physical table.
Returns:
A dictionary of snapshot IDs to existing data objects of their physical tables. If the data object
for a snapshot is not found, it will not be included in the dictionary.
"""
tables_by_gateway_and_schema: t.Dict[t.Union[str, None], t.Dict[exp.Table, set[str]]] = (
defaultdict(lambda: defaultdict(set))
)
snapshots_by_table_name: t.Dict[str, Snapshot] = {}
for snapshot in target_snapshots:
if not snapshot.is_model or snapshot.is_symbolic:
continue
is_deployable = deployability_index.is_deployable(snapshot)
table = exp.to_table(snapshot.table_name(is_deployable), dialect=snapshot.model.dialect)
table_schema = d.schema_(table.db, catalog=table.catalog)
tables_by_gateway_and_schema[snapshot.model_gateway][table_schema].add(table.name)
snapshots_by_table_name[table.name] = snapshot

def _get_data_objects_in_schema(
schema: exp.Table,
object_names: t.Optional[t.Set[str]] = None,
gateway: t.Optional[str] = None,
) -> t.List[DataObject]:
logger.info("Listing data objects in schema %s", schema.sql())
return self.get_adapter(gateway).get_data_objects(schema, object_names)

with self.concurrent_context():
existing_objects: t.List[DataObject] = []
# A schema can be shared across multiple engines, so we need to group tables by both gateway and schema
for gateway, tables_by_schema in tables_by_gateway_and_schema.items():
objs_for_gateway = [
obj
for objs in concurrent_apply_to_values(
list(tables_by_schema),
lambda s: _get_data_objects_in_schema(
schema=s, object_names=tables_by_schema.get(s), gateway=gateway
),
self.ddl_concurrent_tasks,
)
for obj in objs
]
existing_objects.extend(objs_for_gateway)

return {snapshots_by_table_name[obj.name].snapshot_id: obj for obj in existing_objects}


def _evaluation_strategy(snapshot: SnapshotInfoLike, adapter: EngineAdapter) -> EvaluationStrategy:
klass: t.Type
Expand Down
Loading