From 53c477c15c222be22cfac42418cdd8917b10ebde Mon Sep 17 00:00:00 2001
From: Iaroslav Zeigerman
Date: Mon, 7 Jul 2025 10:51:10 -0700
Subject: [PATCH] Chore: Remove dead spark integration code

---
 sqlmesh/engines/commands.py  | 184 -----------------------------------
 sqlmesh/engines/spark/app.py | 114 ----------------------
 2 files changed, 298 deletions(-)
 delete mode 100644 sqlmesh/engines/commands.py
 delete mode 100644 sqlmesh/engines/spark/app.py

diff --git a/sqlmesh/engines/commands.py b/sqlmesh/engines/commands.py
deleted file mode 100644
index c16e759cbf..0000000000
--- a/sqlmesh/engines/commands.py
+++ /dev/null
@@ -1,184 +0,0 @@
-import typing as t
-from enum import Enum
-
-from sqlglot import exp
-from sqlmesh.core.environment import Environment, EnvironmentNamingInfo
-from sqlmesh.core.snapshot import (
-    DeployabilityIndex,
-    Snapshot,
-    SnapshotEvaluator,
-    SnapshotId,
-    SnapshotTableCleanupTask,
-    SnapshotTableInfo,
-)
-from sqlmesh.core.state_sync import cleanup_expired_views
-from sqlmesh.utils.date import TimeLike
-from sqlmesh.utils.errors import AuditError
-from sqlmesh.utils.pydantic import PydanticModel
-
-COMMAND_PAYLOAD_FILE_NAME = "payload.json"
-
-
-class CommandType(str, Enum):
-    EVALUATE = "evaluate"
-    PROMOTE = "promote"
-    DEMOTE = "demote"
-    CLEANUP = "cleanup"
-    CREATE_TABLES = "create_tables"
-    MIGRATE_TABLES = "migrate_tables"
-
-    # This makes it easy to integrate with argparse
-    def __str__(self) -> str:
-        return self.value
-
-
-class EvaluateCommandPayload(PydanticModel):
-    snapshot: Snapshot
-    parent_snapshots: t.Dict[str, Snapshot]
-    start: TimeLike
-    end: TimeLike
-    execution_time: TimeLike
-    deployability_index: DeployabilityIndex
-    batch_index: int
-
-
-class PromoteCommandPayload(PydanticModel):
-    snapshots: t.List[Snapshot]
-    environment_naming_info: EnvironmentNamingInfo
-    deployability_index: DeployabilityIndex
-
-
-class DemoteCommandPayload(PydanticModel):
-    snapshots: t.List[SnapshotTableInfo]
-    environment_naming_info: EnvironmentNamingInfo
-
-
-class CleanupCommandPayload(PydanticModel):
-    environments: t.List[Environment]
-    tasks: t.List[SnapshotTableCleanupTask]
-
-
-class CreateTablesCommandPayload(PydanticModel):
-    target_snapshot_ids: t.List[SnapshotId]
-    snapshots: t.List[Snapshot]
-    deployability_index: DeployabilityIndex
-    allow_destructive_snapshots: t.Set[str]
-
-
-class MigrateTablesCommandPayload(PydanticModel):
-    target_snapshot_ids: t.List[SnapshotId]
-    snapshots: t.List[Snapshot]
-    allow_destructive_snapshots: t.Set[str]
-
-
-def evaluate(
-    evaluator: SnapshotEvaluator, command_payload: t.Union[str, EvaluateCommandPayload]
-) -> None:
-    if isinstance(command_payload, str):
-        command_payload = EvaluateCommandPayload.parse_raw(command_payload)
-
-    parent_snapshots = command_payload.parent_snapshots
-    parent_snapshots[command_payload.snapshot.name] = command_payload.snapshot
-
-    wap_id = evaluator.evaluate(
-        command_payload.snapshot,
-        start=command_payload.start,
-        end=command_payload.end,
-        execution_time=command_payload.execution_time,
-        snapshots=parent_snapshots,
-        deployability_index=command_payload.deployability_index,
-        batch_index=command_payload.batch_index,
-    )
-    audit_results = evaluator.audit(
-        snapshot=command_payload.snapshot,
-        start=command_payload.start,
-        end=command_payload.end,
-        execution_time=command_payload.execution_time,
-        snapshots=parent_snapshots,
-        deployability_index=command_payload.deployability_index,
-        wap_id=wap_id,
-    )
-
-    failed_audit_result = next((r for r in audit_results if r.count and r.blocking), None)
-    if failed_audit_result:
-        raise AuditError(
-            audit_name=failed_audit_result.audit.name,
-            audit_args=failed_audit_result.audit_args,
-            model=command_payload.snapshot.model_or_none,
-            count=t.cast(int, failed_audit_result.count),
-            query=t.cast(exp.Query, failed_audit_result.query),
-            adapter_dialect=evaluator.adapter.dialect,
-        )
-
-
-def promote(
-    evaluator: SnapshotEvaluator, command_payload: t.Union[str, PromoteCommandPayload]
-) -> None:
-    if isinstance(command_payload, str):
-        command_payload = PromoteCommandPayload.parse_raw(command_payload)
-    evaluator.promote(
-        command_payload.snapshots,
-        command_payload.environment_naming_info,
-        deployability_index=command_payload.deployability_index,
-    )
-
-
-def demote(
-    evaluator: SnapshotEvaluator, command_payload: t.Union[str, DemoteCommandPayload]
-) -> None:
-    if isinstance(command_payload, str):
-        command_payload = DemoteCommandPayload.parse_raw(command_payload)
-    evaluator.demote(
-        command_payload.snapshots,
-        command_payload.environment_naming_info,
-    )
-
-
-def cleanup(
-    evaluator: SnapshotEvaluator, command_payload: t.Union[str, CleanupCommandPayload]
-) -> None:
-    if isinstance(command_payload, str):
-        command_payload = CleanupCommandPayload.parse_raw(command_payload)
-
-    cleanup_expired_views(evaluator.adapter, evaluator.adapters, command_payload.environments)
-    evaluator.cleanup(command_payload.tasks)
-
-
-def create_tables(
-    evaluator: SnapshotEvaluator,
-    command_payload: t.Union[str, CreateTablesCommandPayload],
-) -> None:
-    if isinstance(command_payload, str):
-        command_payload = CreateTablesCommandPayload.parse_raw(command_payload)
-
-    snapshots_by_id = {s.snapshot_id: s for s in command_payload.snapshots}
-    target_snapshots = [snapshots_by_id[sid] for sid in command_payload.target_snapshot_ids]
-    evaluator.create(
-        target_snapshots,
-        snapshots_by_id,
-        deployability_index=command_payload.deployability_index,
-        allow_destructive_snapshots=command_payload.allow_destructive_snapshots,
-    )
-
-
-def migrate_tables(
-    evaluator: SnapshotEvaluator,
-    command_payload: t.Union[str, MigrateTablesCommandPayload],
-) -> None:
-    if isinstance(command_payload, str):
-        command_payload = MigrateTablesCommandPayload.parse_raw(command_payload)
-    snapshots_by_id = {s.snapshot_id: s for s in command_payload.snapshots}
-    target_snapshots = [snapshots_by_id[sid] for sid in command_payload.target_snapshot_ids]
-    evaluator.migrate(
-        target_snapshots, snapshots_by_id, command_payload.allow_destructive_snapshots
-    )
-
-
-COMMAND_HANDLERS: t.Dict[CommandType, t.Callable[[SnapshotEvaluator, str], None]] = {
-    CommandType.EVALUATE: evaluate,
-    CommandType.PROMOTE: promote,
-    CommandType.DEMOTE: demote,
-    CommandType.CLEANUP: cleanup,
-    CommandType.CREATE_TABLES: create_tables,
-    CommandType.MIGRATE_TABLES: migrate_tables,
-}
diff --git a/sqlmesh/engines/spark/app.py b/sqlmesh/engines/spark/app.py
deleted file mode 100644
index a8709361fa..0000000000
--- a/sqlmesh/engines/spark/app.py
+++ /dev/null
@@ -1,114 +0,0 @@
-import argparse
-import logging
-import os
-import tempfile
-
-from pyspark import SparkFiles
-from pyspark.sql import SparkSession
-
-from sqlmesh.core.engine_adapter import create_engine_adapter
-from sqlmesh.core.snapshot import SnapshotEvaluator
-from sqlmesh.engines import commands
-from sqlmesh.engines.spark.db_api import spark_session as spark_session_db
-from sqlmesh.engines.spark.db_api.errors import NotSupportedError
-from sqlmesh.utils.errors import SQLMeshError
-
-logger = logging.getLogger(__name__)
-
-
-def get_or_create_spark_session(dialect: str) -> SparkSession:
-    if dialect == "databricks":
-        spark = SparkSession.getActiveSession()
-        if not spark:
-            raise SQLMeshError("Could not find an active SparkSession.")
-        return spark
-    return (
-        SparkSession.builder.config("spark.scheduler.mode", "FAIR")
-        .enableHiveSupport()
-        .getOrCreate()
-    )
-
-
-def main(
-    dialect: str,
-    default_catalog: str,
-    command_type: commands.CommandType,
-    ddl_concurrent_tasks: int,
-    payload_path: str,
-) -> None:
-    if dialect not in ("databricks", "spark"):
-        raise NotSupportedError(
-            f"Dialect '{dialect}' not supported. Must be either 'databricks' or 'spark'"
-        )
-    logging.basicConfig(
-        format="%(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)",
-        level=logging.INFO,
-    )
-    command_handler = commands.COMMAND_HANDLERS.get(command_type)
-    if not command_handler:
-        raise NotSupportedError(f"Command '{command_type.value}' not supported")
-
-    spark = get_or_create_spark_session(dialect)
-
-    evaluator = SnapshotEvaluator(
-        create_engine_adapter(
-            lambda: spark_session_db.connection(spark),
-            dialect,
-            default_catalog=default_catalog,
-            multithreaded=ddl_concurrent_tasks > 1,
-            execute_log_level=logging.INFO,
-        ),
-        ddl_concurrent_tasks=ddl_concurrent_tasks,
-    )
-    if dialect == "spark":
-        with open(SparkFiles.get(payload_path), "r", encoding="utf-8") as payload_fd:
-            command_payload = payload_fd.read()
-    else:
-        from pyspark.dbutils import DBUtils  # type: ignore
-
-        dbutils = DBUtils(spark)
-        with tempfile.TemporaryDirectory() as tmp:
-            local_payload_path = os.path.join(tmp, commands.COMMAND_PAYLOAD_FILE_NAME)
-            dbutils.fs.cp(payload_path, f"file://{local_payload_path}")
-            with open(local_payload_path, "r", encoding="utf-8") as payload_fd:
-                command_payload = payload_fd.read()
-    logger.info("Command payload:\n %s", command_payload)
-    command_handler(evaluator, command_payload)
-
-    evaluator.close()
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="SQLMesh Spark Submit App")
-    parser.add_argument(
-        "--dialect",
-        help="The dialect to use when creating the engine adapter.",
-    )
-    parser.add_argument(
-        "--default_catalog",
-        help="The default catalog to use when creating the engine adapter.",
-    )
-    parser.add_argument(
-        "--command_type",
-        type=commands.CommandType,
-        choices=list(commands.CommandType),
-        help="The type of command that is being run",
-    )
-    parser.add_argument(
-        "--ddl_concurrent_tasks",
-        type=int,
-        default=1,
-        help="The number of ddl concurrent tasks to use. Default to 1.",
-    )
-    parser.add_argument(
-        "--payload_path",
-        help="Path to the payload object. Can be a local or remote path.",
-    )
-    args = parser.parse_args()
-    main(
-        args.dialect,
-        args.default_catalog,
-        args.command_type,
-        args.ddl_concurrent_tasks,
-        args.payload_path,
-    )
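
Reviewer note (not part of the patch): since both modules are deleted
outright, it is worth confirming before merge that nothing else in the
repository still imports them. Below is a minimal sketch of such a check,
assuming it is run from the repository root; the script name and the
regex set are illustrative assumptions, and only the two module paths are
taken from the diff above.

    # check_dead_refs.py -- illustrative helper, not part of this patch.
    # Walks the working tree and flags Python files that still reference
    # the deleted modules, in either import style.
    import pathlib
    import re

    # Module paths come from the diff above; the patterns themselves are
    # an assumption about how those modules might be referenced.
    PATTERNS = [
        re.compile(r"from\s+sqlmesh\.engines\s+import\s+commands"),
        re.compile(r"sqlmesh\.engines\.commands"),
        re.compile(r"sqlmesh\.engines\.spark\.app"),
    ]

    for path in pathlib.Path(".").rglob("*.py"):
        text = path.read_text(encoding="utf-8", errors="ignore")
        for pattern in PATTERNS:
            if pattern.search(text):
                print(f"{path}: still references '{pattern.pattern}'")

If the patch is truly removing dead code, the script should print nothing.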