Skip to content

Commit fca603b

Browse files
committed
Chore: Cache results of get_data_objects
1 parent a255e17 commit fca603b

File tree

6 files changed

+404
-17
lines changed

6 files changed

+404
-17
lines changed

sqlmesh/core/engine_adapter/base.py

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def __init__(
161161
self.correlation_id = correlation_id
162162
self._schema_differ_overrides = schema_differ_overrides
163163
self._query_execution_tracker = query_execution_tracker
164+
self._data_object_cache: t.Dict[str, t.Optional[DataObject]] = {}
164165

165166
def with_settings(self, **kwargs: t.Any) -> EngineAdapter:
166167
extra_kwargs = {
@@ -983,6 +984,13 @@ def _create_table(
983984
),
984985
track_rows_processed=track_rows_processed,
985986
)
987+
# Extract table name to clear cache
988+
table_name = (
989+
table_name_or_schema.this
990+
if isinstance(table_name_or_schema, exp.Schema)
991+
else table_name_or_schema
992+
)
993+
self._clear_data_object_cache(table_name)
986994

987995
def _build_create_table_exp(
988996
self,
@@ -1074,6 +1082,7 @@ def clone_table(
10741082
**kwargs,
10751083
)
10761084
)
1085+
self._clear_data_object_cache(target_table_name)
10771086

10781087
def drop_data_object(self, data_object: DataObject, ignore_if_not_exists: bool = True) -> None:
10791088
"""Drops a data object of arbitrary type.
@@ -1139,6 +1148,7 @@ def _drop_object(
11391148
drop_args["cascade"] = cascade
11401149

11411150
self.execute(exp.Drop(this=exp.to_table(name), kind=kind, exists=exists, **drop_args))
1151+
self._clear_data_object_cache(name)
11421152

11431153
def get_alter_operations(
11441154
self,
@@ -1329,6 +1339,8 @@ def create_view(
13291339
quote_identifiers=self.QUOTE_IDENTIFIERS_IN_VIEWS,
13301340
)
13311341

1342+
self._clear_data_object_cache(view_name)
1343+
13321344
# Register table comment with commands if the engine doesn't support doing it in CREATE
13331345
if (
13341346
table_description
@@ -2278,14 +2290,51 @@ def get_data_objects(
22782290
if object_names is not None:
22792291
if not object_names:
22802292
return []
2281-
object_names_list = list(object_names)
2282-
batches = [
2283-
object_names_list[i : i + self.DATA_OBJECT_FILTER_BATCH_SIZE]
2284-
for i in range(0, len(object_names_list), self.DATA_OBJECT_FILTER_BATCH_SIZE)
2285-
]
2286-
return [
2287-
obj for batch in batches for obj in self._get_data_objects(schema_name, set(batch))
2288-
]
2293+
2294+
# Check cache for each object name
2295+
target_schema = to_schema(schema_name)
2296+
cached_objects = []
2297+
missing_names = set()
2298+
2299+
for name in object_names:
2300+
cache_key = _get_data_object_cache_key(
2301+
target_schema.catalog, target_schema.db, name
2302+
)
2303+
if cache_key in self._data_object_cache:
2304+
data_object = self._data_object_cache[cache_key]
2305+
# If the object is none, then the table was previously looked for but not found
2306+
if data_object:
2307+
cached_objects.append(data_object)
2308+
else:
2309+
missing_names.add(name)
2310+
2311+
# Fetch missing objects from database
2312+
if missing_names:
2313+
object_names_list = list(missing_names)
2314+
batches = [
2315+
object_names_list[i : i + self.DATA_OBJECT_FILTER_BATCH_SIZE]
2316+
for i in range(0, len(object_names_list), self.DATA_OBJECT_FILTER_BATCH_SIZE)
2317+
]
2318+
fetched_objects = [
2319+
obj
2320+
for batch in batches
2321+
for obj in self._get_data_objects(schema_name, set(batch))
2322+
]
2323+
2324+
# Cache the fetched objects
2325+
for obj in fetched_objects:
2326+
cache_key = _get_data_object_cache_key(obj.catalog, obj.schema_name, obj.name)
2327+
self._data_object_cache[cache_key] = obj
2328+
2329+
for missing_name in missing_names - {o.name for o in fetched_objects}:
2330+
cache_key = _get_data_object_cache_key(
2331+
target_schema.catalog, target_schema.db, missing_name
2332+
)
2333+
self._data_object_cache[cache_key] = None
2334+
2335+
return cached_objects + fetched_objects
2336+
2337+
return cached_objects
22892338
return self._get_data_objects(schema_name)
22902339

22912340
def fetchone(
@@ -2693,6 +2742,15 @@ def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.An
26932742

26942743
return expression.sql(**sql_gen_kwargs, copy=False) # type: ignore
26952744

2745+
def _clear_data_object_cache(self, table_name: t.Optional[TableName] = None) -> None:
2746+
"""Clears the cache entry for the given table name, or clears the entire cache if table_name is None."""
2747+
if table_name is None:
2748+
self._data_object_cache.clear()
2749+
else:
2750+
table = exp.to_table(table_name)
2751+
cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name)
2752+
self._data_object_cache.pop(cache_key, None)
2753+
26962754
def _get_data_objects(
26972755
self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
26982756
) -> t.List[DataObject]:
@@ -2940,3 +2998,11 @@ def _decoded_str(value: t.Union[str, bytes]) -> str:
29402998
if isinstance(value, bytes):
29412999
return value.decode("utf-8")
29423000
return value
3001+
3002+
3003+
def _get_data_object_cache_key(catalog: t.Optional[str], schema_name: str, object_name: str) -> str:
3004+
"""Returns a cache key for a data object based on its fully qualified name."""
3005+
catalog_part = catalog.lower() if catalog else ""
3006+
schema_part = schema_name.lower() if schema_name else ""
3007+
object_part = object_name.lower()
3008+
return f"{catalog_part}.{schema_part}.{object_part}"

sqlmesh/core/snapshot/evaluator.py

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,9 @@ def promote(
307307
]
308308
self._create_schemas(gateway_table_pairs=gateway_table_pairs)
309309

310+
# Fetch the view data objects for the promoted snapshots to get them cached
311+
self._get_virtual_data_objects(target_snapshots, environment_naming_info)
312+
310313
deployability_index = deployability_index or DeployabilityIndex.all_deployable()
311314
with self.concurrent_context():
312315
concurrent_apply_to_snapshots(
@@ -425,7 +428,9 @@ def get_snapshots_to_create(
425428
target_snapshots: Target snapshots.
426429
deployability_index: Determines snapshots that are deployable / representative in the context of this creation.
427430
"""
428-
existing_data_objects = self._get_data_objects(target_snapshots, deployability_index)
431+
existing_data_objects = self._get_physical_data_objects(
432+
target_snapshots, deployability_index
433+
)
429434
snapshots_to_create = []
430435
for snapshot in target_snapshots:
431436
if not snapshot.is_model or snapshot.is_symbolic:
@@ -482,7 +487,7 @@ def migrate(
482487
deployability_index: Determines snapshots that are deployable in the context of this evaluation.
483488
"""
484489
deployability_index = deployability_index or DeployabilityIndex.all_deployable()
485-
target_data_objects = self._get_data_objects(target_snapshots, deployability_index)
490+
target_data_objects = self._get_physical_data_objects(target_snapshots, deployability_index)
486491
if not target_data_objects:
487492
return
488493

@@ -1472,7 +1477,7 @@ def _can_clone(self, snapshot: Snapshot, deployability_index: DeployabilityIndex
14721477
and adapter.table_exists(snapshot.table_name())
14731478
)
14741479

1475-
def _get_data_objects(
1480+
def _get_physical_data_objects(
14761481
self,
14771482
target_snapshots: t.Iterable[Snapshot],
14781483
deployability_index: DeployabilityIndex,
@@ -1488,15 +1493,67 @@ def _get_data_objects(
14881493
A dictionary of snapshot IDs to existing data objects of their physical tables. If the data object
14891494
for a snapshot is not found, it will not be included in the dictionary.
14901495
"""
1496+
return self._get_data_objects(
1497+
target_snapshots,
1498+
lambda s: exp.to_table(
1499+
s.table_name(deployability_index.is_deployable(s)), dialect=s.model.dialect
1500+
),
1501+
)
1502+
1503+
def _get_virtual_data_objects(
1504+
self,
1505+
target_snapshots: t.Iterable[Snapshot],
1506+
environment_naming_info: EnvironmentNamingInfo,
1507+
) -> t.Dict[SnapshotId, DataObject]:
1508+
"""Returns a dictionary of snapshot IDs to existing data objects of their virtual views.
1509+
1510+
Args:
1511+
target_snapshots: Target snapshots.
1512+
environment_naming_info: The environment naming info of the target virtual environment.
1513+
1514+
Returns:
1515+
A dictionary of snapshot IDs to existing data objects of their virtual views. If the data object
1516+
for a snapshot is not found, it will not be included in the dictionary.
1517+
"""
1518+
1519+
def _get_view_name(s: Snapshot) -> exp.Table:
1520+
adapter = (
1521+
self.get_adapter(s.model_gateway)
1522+
if environment_naming_info.gateway_managed
1523+
else self.adapter
1524+
)
1525+
return exp.to_table(
1526+
s.qualified_view_name.for_environment(
1527+
environment_naming_info, dialect=adapter.dialect
1528+
),
1529+
dialect=adapter.dialect,
1530+
)
1531+
1532+
return self._get_data_objects(target_snapshots, _get_view_name)
1533+
1534+
def _get_data_objects(
1535+
self,
1536+
target_snapshots: t.Iterable[Snapshot],
1537+
table_name_callable: t.Callable[[Snapshot], exp.Table],
1538+
) -> t.Dict[SnapshotId, DataObject]:
1539+
"""Returns a dictionary of snapshot IDs to existing data objects.
1540+
1541+
Args:
1542+
target_snapshots: Target snapshots.
1543+
table_name_callable: A function that takes a snapshot and returns the table to look for.
1544+
1545+
Returns:
1546+
A dictionary of snapshot IDs to existing data objects. If the data object for a snapshot is not found,
1547+
it will not be included in the dictionary.
1548+
"""
14911549
tables_by_gateway_and_schema: t.Dict[t.Union[str, None], t.Dict[exp.Table, set[str]]] = (
14921550
defaultdict(lambda: defaultdict(set))
14931551
)
14941552
snapshots_by_table_name: t.Dict[str, Snapshot] = {}
14951553
for snapshot in target_snapshots:
14961554
if not snapshot.is_model or snapshot.is_symbolic:
14971555
continue
1498-
is_deployable = deployability_index.is_deployable(snapshot)
1499-
table = exp.to_table(snapshot.table_name(is_deployable), dialect=snapshot.model.dialect)
1556+
table = table_name_callable(snapshot)
15001557
table_schema = d.schema_(table.db, catalog=table.catalog)
15011558
tables_by_gateway_and_schema[snapshot.model_gateway][table_schema].add(table.name)
15021559
snapshots_by_table_name[table.name] = snapshot

tests/core/engine_adapter/test_athena.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ def test_replace_query(adapter: AthenaEngineAdapter, mocker: MockerFixture):
312312
)
313313
mocker.patch.object(adapter, "_get_data_objects", return_value=[])
314314
adapter.cursor.execute.reset_mock()
315+
adapter._clear_data_object_cache()
315316

316317
adapter.s3_warehouse_location = "s3://foo"
317318
adapter.replace_query(

0 commit comments

Comments
 (0)