diff --git a/src/clgraph/__init__.py b/src/clgraph/__init__.py index cb435d8..29c3b07 100644 --- a/src/clgraph/__init__.py +++ b/src/clgraph/__init__.py @@ -18,11 +18,17 @@ from .agent import AgentResult, LineageAgent, QuestionType from .diff import ColumnDiff, PipelineDiff +# Import execution functionality +from .execution import PipelineExecutor + # Import export functionality from .export import CSVExporter, JSONExporter # Import validation models from .models import IssueCategory, IssueSeverity, ValidationIssue + +# Import orchestrator integrations +from .orchestrators import AirflowOrchestrator, DagsterOrchestrator from .parser import ( ColumnEdge, ColumnLineageGraph, @@ -177,4 +183,9 @@ "BASIC_TOOLS", "LLM_TOOLS", "ALL_TOOLS", + # Orchestrator integrations + "AirflowOrchestrator", + "DagsterOrchestrator", + # Execution + "PipelineExecutor", ] diff --git a/src/clgraph/execution.py b/src/clgraph/execution.py new file mode 100644 index 0000000..a897b8a --- /dev/null +++ b/src/clgraph/execution.py @@ -0,0 +1,285 @@ +""" +Pipeline execution module for clgraph. + +Provides synchronous and asynchronous execution of SQL pipelines +with concurrent execution within dependency levels. +""" + +import asyncio +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple + +if TYPE_CHECKING: + from .pipeline import Pipeline + + +class PipelineExecutor: + """ + Executes clgraph pipelines with concurrent execution support. + + Provides both synchronous and asynchronous execution modes, + with automatic parallelization within dependency levels. + + Example: + from clgraph.execution import PipelineExecutor + + executor = PipelineExecutor(pipeline) + result = executor.run(execute_sql, max_workers=4) + + # Or async + result = await executor.async_run(async_execute_sql, max_workers=4) + """ + + def __init__(self, pipeline: "Pipeline") -> None: + """ + Initialize executor with a Pipeline instance. + + Args: + pipeline: The clgraph Pipeline to execute + """ + self.pipeline = pipeline + self.table_graph = pipeline.table_graph + + def get_execution_levels(self) -> List[List[str]]: + """ + Group queries into levels for concurrent execution. + + Level 0: Queries with no dependencies + Level 1: Queries that depend only on Level 0 + Level 2: Queries that depend on Level 0 or 1 + etc. + + Queries in the same level can run concurrently. + + Returns: + List of levels, where each level is a list of query IDs + """ + levels = [] + completed = set() + + while len(completed) < len(self.table_graph.queries): + current_level = [] + + for query_id, query in self.table_graph.queries.items(): + if query_id in completed: + continue + + # Check if all dependencies are completed + dependencies_met = True + for source_table in query.source_tables: + # Find query that creates this table + table_node = self.table_graph.tables.get(source_table) + if table_node and table_node.created_by: + if table_node.created_by not in completed: + dependencies_met = False + break + + if dependencies_met: + current_level.append(query_id) + + if not current_level: + # No progress - circular dependency + raise RuntimeError("Circular dependency detected in pipeline") + + levels.append(current_level) + completed.update(current_level) + + return levels + + def run( + self, + executor: Callable[[str], None], + max_workers: int = 4, + verbose: bool = True, + ) -> Dict[str, Any]: + """ + Execute pipeline synchronously with concurrent execution. 
+ + Args: + executor: Function that executes SQL (takes sql string) + max_workers: Max concurrent workers (default: 4) + verbose: Print progress (default: True) + + Returns: + dict with execution results: { + "completed": list of completed query IDs, + "failed": list of (query_id, error) tuples, + "elapsed_seconds": total execution time, + "total_queries": total number of queries + } + + Example: + def execute_sql(sql: str): + import duckdb + conn = duckdb.connect() + conn.execute(sql) + + result = executor.run(execute_sql, max_workers=4) + print(f"Completed {len(result['completed'])} queries") + """ + if verbose: + print(f"šŸš€ Starting pipeline execution ({len(self.table_graph.queries)} queries)") + print() + + # Track completed queries + completed = set() + failed: List[Tuple[str, str]] = [] + start_time = time.time() + + # Group queries by level for concurrent execution + levels = self.get_execution_levels() + + # Execute level by level + for level_num, level_queries in enumerate(levels, 1): + if verbose: + print(f"šŸ“Š Level {level_num}: {len(level_queries)} queries") + + # Execute queries in this level concurrently + with ThreadPoolExecutor(max_workers=max_workers) as pool: + futures = {} + + for query_id in level_queries: + query = self.table_graph.queries[query_id] + future = pool.submit(executor, query.sql) + futures[future] = query_id + + # Wait for completion + for future in as_completed(futures): + query_id = futures[future] + + try: + future.result() + completed.add(query_id) + + if verbose: + print(f" āœ… {query_id}") + except Exception as e: + failed.append((query_id, str(e))) + + if verbose: + print(f" āŒ {query_id}: {e}") + + if verbose: + print() + + elapsed = time.time() - start_time + + # Summary + if verbose: + print("=" * 60) + print(f"āœ… Pipeline completed in {elapsed:.2f}s") + print(f" Successful: {len(completed)}") + print(f" Failed: {len(failed)}") + if failed: + print("\nāš ļø Failed queries:") + for query_id, error in failed: + print(f" - {query_id}: {error}") + print("=" * 60) + + return { + "completed": list(completed), + "failed": failed, + "elapsed_seconds": elapsed, + "total_queries": len(self.table_graph.queries), + } + + async def async_run( + self, + executor: Callable[[str], Awaitable[None]], + max_workers: int = 4, + verbose: bool = True, + ) -> Dict[str, Any]: + """ + Execute pipeline asynchronously with concurrent execution. 
+ + Args: + executor: Async function that executes SQL (takes sql string) + max_workers: Max concurrent workers (controls semaphore, default: 4) + verbose: Print progress (default: True) + + Returns: + dict with execution results: { + "completed": list of completed query IDs, + "failed": list of (query_id, error) tuples, + "elapsed_seconds": total execution time, + "total_queries": total number of queries + } + + Example: + async def execute_sql(sql: str): + # Your async database connection + await async_conn.execute(sql) + + result = await executor.async_run(execute_sql, max_workers=4) + print(f"Completed {len(result['completed'])} queries") + """ + if verbose: + print(f"šŸš€ Starting async pipeline execution ({len(self.table_graph.queries)} queries)") + print() + + # Track completed queries + completed = set() + failed: List[Tuple[str, str]] = [] + start_time = time.time() + + # Group queries by level for concurrent execution + levels = self.get_execution_levels() + + # Create semaphore to limit concurrency + semaphore = asyncio.Semaphore(max_workers) + + # Execute level by level + for level_num, level_queries in enumerate(levels, 1): + if verbose: + print(f"šŸ“Š Level {level_num}: {len(level_queries)} queries") + + async def execute_with_semaphore(query_id: str, sql: str): + """Execute query with semaphore for concurrency control""" + async with semaphore: + try: + await executor(sql) + completed.add(query_id) + if verbose: + print(f" āœ… {query_id}") + except Exception as e: + failed.append((query_id, str(e))) + if verbose: + print(f" āŒ {query_id}: {e}") + + # Execute queries in this level concurrently + tasks = [] + for query_id in level_queries: + query = self.table_graph.queries[query_id] + task = execute_with_semaphore(query_id, query.sql) + tasks.append(task) + + # Wait for all tasks in this level to complete + await asyncio.gather(*tasks) + + if verbose: + print() + + elapsed = time.time() - start_time + + # Summary + if verbose: + print("=" * 60) + print(f"āœ… Pipeline completed in {elapsed:.2f}s") + print(f" Successful: {len(completed)}") + print(f" Failed: {len(failed)}") + if failed: + print("\nāš ļø Failed queries:") + for query_id, error in failed: + print(f" - {query_id}: {error}") + print("=" * 60) + + return { + "completed": list(completed), + "failed": failed, + "elapsed_seconds": elapsed, + "total_queries": len(self.table_graph.queries), + } + + +__all__ = ["PipelineExecutor"] diff --git a/src/clgraph/orchestrators/__init__.py b/src/clgraph/orchestrators/__init__.py new file mode 100644 index 0000000..edf0311 --- /dev/null +++ b/src/clgraph/orchestrators/__init__.py @@ -0,0 +1,34 @@ +""" +Orchestrator integrations for clgraph. + +This package provides integrations with various workflow orchestrators, +allowing clgraph pipelines to be deployed to production environments. 
+ +Supported orchestrators: +- Airflow (2.x and 3.x) +- Dagster (1.x) + +Example: + from clgraph import Pipeline + from clgraph.orchestrators import AirflowOrchestrator, DagsterOrchestrator + + pipeline = Pipeline.from_sql_files("queries/", dialect="bigquery") + + # Generate Airflow DAG + airflow = AirflowOrchestrator(pipeline) + dag = airflow.to_dag(executor=execute_sql, dag_id="my_pipeline") + + # Generate Dagster assets + dagster = DagsterOrchestrator(pipeline) + assets = dagster.to_assets(executor=execute_sql, group_name="analytics") +""" + +from .airflow import AirflowOrchestrator +from .base import BaseOrchestrator +from .dagster import DagsterOrchestrator + +__all__ = [ + "BaseOrchestrator", + "AirflowOrchestrator", + "DagsterOrchestrator", +] diff --git a/src/clgraph/orchestrators/airflow.py b/src/clgraph/orchestrators/airflow.py new file mode 100644 index 0000000..1025165 --- /dev/null +++ b/src/clgraph/orchestrators/airflow.py @@ -0,0 +1,183 @@ +""" +Airflow orchestrator integration for clgraph. + +Converts clgraph pipelines to Airflow DAGs using the TaskFlow API. +Supports both Airflow 2.x and 3.x. +""" + +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Callable, Optional + +from .base import BaseOrchestrator + +if TYPE_CHECKING: + pass + + +class AirflowOrchestrator(BaseOrchestrator): + """ + Converts clgraph pipelines to Airflow DAGs. + + Uses the TaskFlow API (@dag and @task decorators) which is compatible + across both Airflow 2.x and 3.x versions. + + Example: + from clgraph.orchestrators import AirflowOrchestrator + + orchestrator = AirflowOrchestrator(pipeline) + dag = orchestrator.to_dag( + executor=execute_sql, + dag_id="my_pipeline", + schedule="@daily", + ) + """ + + def to_dag( + self, + executor: Callable[[str], None], + dag_id: str, + schedule: str = "@daily", + start_date: Optional[datetime] = None, + default_args: Optional[dict] = None, + airflow_version: Optional[str] = None, + **dag_kwargs, + ): + """ + Create Airflow DAG from the pipeline using TaskFlow API. + + Supports both Airflow 2.x and 3.x. The TaskFlow API (@dag and @task decorators) + is fully compatible across both versions. + + Args: + executor: Function that executes SQL (takes sql string) + dag_id: Airflow DAG ID + schedule: Schedule interval (default: "@daily") + start_date: DAG start date (default: datetime(2024, 1, 1)) + default_args: Airflow default_args (default: owner='data_team', retries=2) + airflow_version: Optional Airflow version ("2" or "3"). + Auto-detected from installed Airflow if not provided. + **dag_kwargs: Additional DAG parameters (catchup, tags, max_active_runs, + description, max_active_tasks, dagrun_timeout, etc.) + See Airflow DAG documentation for all available parameters. 
+ + Returns: + Airflow DAG instance + + Examples: + # Basic usage (auto-detects Airflow version) + dag = orchestrator.to_dag( + executor=execute_sql, + dag_id="my_pipeline" + ) + + # Explicit version specification (for testing) + dag = orchestrator.to_dag( + executor=execute_sql, + dag_id="my_pipeline", + airflow_version="3" + ) + + # Advanced usage with all DAG parameters + dag = orchestrator.to_dag( + executor=execute_sql, + dag_id="my_pipeline", + schedule="0 0 * * *", # Daily at midnight + description="Customer analytics pipeline", + catchup=False, + max_active_runs=3, + max_active_tasks=10, + tags=["analytics", "daily"], + ) + + Note: + - Airflow 2.x: Fully supported (2.7.0+) + - Airflow 3.x: Fully supported (3.0.0+) + - TaskFlow API is compatible across both versions + """ + try: + import airflow # type: ignore[import-untyped] + from airflow.decorators import dag, task # type: ignore[import-untyped] + except ImportError as e: + raise ImportError( + "Airflow is required for DAG generation. " + "Install it with:\n" + " - Airflow 2.x: pip install 'apache-airflow>=2.7.0,<3.0.0'\n" + " - Airflow 3.x: pip install 'apache-airflow>=3.0.0'" + ) from e + + # Detect Airflow version if not specified + if airflow_version is None: + detected_version = airflow.__version__ + major_version = int(detected_version.split(".")[0]) + airflow_version = str(major_version) + + # Validate version + if airflow_version not in ("2", "3"): + raise ValueError( + f"Unsupported Airflow version: {airflow_version}. Supported versions: 2, 3" + ) + + if start_date is None: + start_date = datetime(2024, 1, 1) + + if default_args is None: + default_args = { + "owner": "data_team", + "retries": 2, + "retry_delay": timedelta(minutes=5), + } + + # Build DAG parameters + dag_params = { + "dag_id": dag_id, + "schedule": schedule, + "start_date": start_date, + "default_args": default_args, + **dag_kwargs, # Allow user to override any parameter + } + + # Set default values only if not provided by user + dag_params.setdefault("catchup", False) + dag_params.setdefault("tags", ["clgraph"]) + + table_graph = self.table_graph + + @dag(**dag_params) + def pipeline_dag(): + """Generated pipeline DAG""" + + # Create task callables for each query + task_callables = {} + + for query_id in table_graph.topological_sort(): + query = table_graph.queries[query_id] + sql_to_execute = query.sql + + # Create task with unique function name using closure + def make_task(qid, sql): + @task(task_id=qid.replace("-", "_")) + def execute_query(): + """Execute SQL query""" + executor(sql) + return f"Completed: {qid}" + + return execute_query + + task_callables[query_id] = make_task(query_id, sql_to_execute) + + # Instantiate all tasks once before wiring dependencies + task_instances = {qid: callable() for qid, callable in task_callables.items()} + + # Set up dependencies based on table lineage + for _table_name, table_node in table_graph.tables.items(): + if table_node.created_by: + upstream_id = table_node.created_by + for downstream_id in table_node.read_by: + if upstream_id in task_instances and downstream_id in task_instances: + # Airflow: downstream >> upstream means upstream runs first + task_instances[upstream_id] >> task_instances[downstream_id] + + return pipeline_dag() + + +__all__ = ["AirflowOrchestrator"] diff --git a/src/clgraph/orchestrators/base.py b/src/clgraph/orchestrators/base.py new file mode 100644 index 0000000..fdfd8b6 --- /dev/null +++ b/src/clgraph/orchestrators/base.py @@ -0,0 +1,106 @@ +""" +Base classes and protocols for 
orchestrator integrations. + +This module defines the interface that all orchestrator integrations must follow. +""" + +from typing import TYPE_CHECKING, List, Protocol + +if TYPE_CHECKING: + from ..pipeline import Pipeline + + +class OrchestratorProtocol(Protocol): + """Protocol defining the interface for orchestrator integrations.""" + + def __init__(self, pipeline: "Pipeline") -> None: + """Initialize with a Pipeline instance.""" + ... + + +class BaseOrchestrator: + """ + Base class for orchestrator integrations. + + Provides common functionality and enforces the interface for + converting clgraph pipelines to orchestrator-specific formats. + + Subclasses should implement orchestrator-specific methods like: + - to_dag() for Airflow + - to_assets() / to_job() for Dagster + - to_flow() for Prefect + """ + + def __init__(self, pipeline: "Pipeline") -> None: + """ + Initialize orchestrator with a Pipeline instance. + + Args: + pipeline: The clgraph Pipeline to convert + """ + self.pipeline = pipeline + self.table_graph = pipeline.table_graph + + def _get_execution_levels(self) -> List[List[str]]: + """ + Group queries into levels for concurrent execution. + + Level 0: Queries with no dependencies + Level 1: Queries that depend only on Level 0 + Level 2: Queries that depend on Level 0 or 1 + etc. + + Queries in the same level can run concurrently. + + Returns: + List of levels, where each level is a list of query IDs + """ + levels = [] + completed = set() + + while len(completed) < len(self.table_graph.queries): + current_level = [] + + for query_id, query in self.table_graph.queries.items(): + if query_id in completed: + continue + + # Check if all dependencies are completed + dependencies_met = True + for source_table in query.source_tables: + # Find query that creates this table + table_node = self.table_graph.tables.get(source_table) + if table_node and table_node.created_by: + if table_node.created_by not in completed: + dependencies_met = False + break + + if dependencies_met: + current_level.append(query_id) + + if not current_level: + # No progress - circular dependency + raise RuntimeError("Circular dependency detected in pipeline") + + levels.append(current_level) + completed.update(current_level) + + return levels + + def _sanitize_name(self, name: str) -> str: + """ + Sanitize a name for use in orchestrator identifiers. + + Replaces dots and dashes with underscores to ensure compatibility + with orchestrator naming requirements. + + Args: + name: The name to sanitize + + Returns: + Sanitized name safe for use as identifier + """ + return name.replace(".", "_").replace("-", "_") + + +__all__ = ["BaseOrchestrator", "OrchestratorProtocol"] diff --git a/src/clgraph/orchestrators/dagster.py b/src/clgraph/orchestrators/dagster.py new file mode 100644 index 0000000..545a9a1 --- /dev/null +++ b/src/clgraph/orchestrators/dagster.py @@ -0,0 +1,297 @@ +""" +Dagster orchestrator integration for clgraph. + +Converts clgraph pipelines to Dagster assets and jobs. +""" + +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union + +from .base import BaseOrchestrator + +if TYPE_CHECKING: + pass + + +class DagsterOrchestrator(BaseOrchestrator): + """ + Converts clgraph pipelines to Dagster assets and jobs. + + Supports both asset-based (recommended) and job-based approaches. + Assets provide better lineage tracking and observability in Dagster. 
+ + Example: + from clgraph.orchestrators import DagsterOrchestrator + + orchestrator = DagsterOrchestrator(pipeline) + assets = orchestrator.to_assets( + executor=execute_sql, + group_name="analytics", + ) + + # Or for job-based approach + job = orchestrator.to_job( + executor=execute_sql, + job_name="analytics_pipeline", + ) + """ + + def to_assets( + self, + executor: Callable[[str], None], + group_name: Optional[str] = None, + key_prefix: Optional[Union[str, List[str]]] = None, + compute_kind: str = "sql", + **asset_kwargs, + ) -> List: + """ + Create Dagster Assets from the pipeline. + + Converts the pipeline's table dependency graph into Dagster assets + where each target table becomes an asset with proper dependencies. + This is the recommended approach for Dagster as it provides better + lineage tracking and observability. + + Args: + executor: Function that executes SQL (takes sql string) + group_name: Optional asset group name for organization in Dagster UI + key_prefix: Optional prefix for asset keys (e.g., ["warehouse", "analytics"]) + compute_kind: Compute kind tag for assets (default: "sql") + **asset_kwargs: Additional asset parameters (owners, tags, etc.) + + Returns: + List of Dagster Asset definitions + + Examples: + # Basic usage + assets = orchestrator.to_assets( + executor=execute_sql, + group_name="analytics" + ) + + # Create Dagster Definitions + from dagster import Definitions + defs = Definitions(assets=assets) + + # Advanced usage with prefixes and metadata + assets = orchestrator.to_assets( + executor=execute_sql, + group_name="warehouse", + key_prefix=["prod", "analytics"], + compute_kind="clickhouse", + owners=["team:data-eng"], + tags={"domain": "finance"}, + ) + + Note: + - Requires Dagster 1.x: pip install 'dagster>=1.5.0' + - Each target table becomes a Dagster asset + - Dependencies are automatically inferred from table lineage + - Deployment: Drop the definitions.py file in your Dagster workspace + """ + try: + import dagster as dg + except ImportError as e: + raise ImportError( + "Dagster is required for asset generation. 
" + "Install it with: pip install 'dagster>=1.5.0'" + ) from e + + table_graph = self.table_graph + + assets = [] + asset_key_mapping: Dict[str, Any] = {} # query_id -> AssetKey + + # Process each query that creates a table + for query_id in table_graph.topological_sort(): + query = table_graph.queries[query_id] + target_table = query.destination_table + + if target_table is None: + continue # Skip queries that don't create tables + + # Determine upstream dependencies (source tables created by this pipeline) + upstream_asset_keys = [] + for source_table in query.source_tables: + if source_table in table_graph.tables: + table_node = table_graph.tables[source_table] + # Only add as dep if it's created by another query in this pipeline + if table_node.created_by and table_node.created_by in asset_key_mapping: + upstream_asset_keys.append(asset_key_mapping[table_node.created_by]) + + # Build asset key (sanitize table name for Dagster compatibility) + # Dagster names must match ^[A-Za-z0-9_]+$ + safe_table_name = self._sanitize_name(target_table) + if key_prefix: + if isinstance(key_prefix, str): + prefix_list = [key_prefix] + else: + prefix_list = list(key_prefix) + asset_key = dg.AssetKey([*prefix_list, safe_table_name]) + else: + asset_key = dg.AssetKey(safe_table_name) + + # Store mapping for dependency resolution + asset_key_mapping[query_id] = asset_key + + # Capture SQL in closure + sql_to_execute = query.sql + table_name = target_table + query_identifier = query_id + + # Build asset configuration + asset_config: Dict[str, Any] = { + "key": asset_key, + "compute_kind": compute_kind, + **asset_kwargs, + } + + if group_name: + asset_config["group_name"] = group_name + + if upstream_asset_keys: + asset_config["deps"] = upstream_asset_keys + + # Create asset factory function + def make_asset(qid: str, sql: str, tbl: str, config: Dict[str, Any], exec_fn: Callable): + @dg.asset(**config) + def sql_asset(context: dg.AssetExecutionContext): + """Execute SQL to materialize asset.""" + context.log.info(f"Materializing table: {tbl}") + context.log.info(f"Query ID: {qid}") + context.log.debug(f"SQL: {sql[:500]}...") + exec_fn(sql) + return dg.MaterializeResult( + metadata={ + "query_id": dg.MetadataValue.text(qid), + "table": dg.MetadataValue.text(tbl), + } + ) + + # Rename function for better debugging in Dagster UI + safe_name = tbl.replace(".", "_").replace("-", "_") + sql_asset.__name__ = safe_name + sql_asset.__qualname__ = safe_name + + return sql_asset + + asset = make_asset(query_identifier, sql_to_execute, table_name, asset_config, executor) + assets.append(asset) + + return assets + + def to_job( + self, + executor: Callable[[str], None], + job_name: str, + description: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + **job_kwargs, + ): + """ + Create Dagster Job from the pipeline using ops. + + Converts the pipeline's table dependency graph into a Dagster job + where each SQL query becomes an op with proper dependencies. + + Note: For new pipelines, consider using to_assets() instead, + which provides better lineage tracking and observability in Dagster. 
+ + Args: + executor: Function that executes SQL (takes sql string) + job_name: Name for the Dagster job + description: Optional job description (auto-generated if not provided) + tags: Optional job tags for filtering in Dagster UI + **job_kwargs: Additional job parameters + + Returns: + Dagster Job definition + + Examples: + # Basic usage + job = orchestrator.to_job( + executor=execute_sql, + job_name="analytics_pipeline" + ) + + # Create Dagster Definitions + from dagster import Definitions + defs = Definitions(jobs=[job]) + + # Execute the job locally + result = job.execute_in_process() + + Note: + - Requires Dagster 1.x: pip install 'dagster>=1.5.0' + - Consider using to_assets() for better Dagster integration + - Deployment: Drop the definitions.py file in your Dagster workspace + """ + try: + import dagster as dg + except ImportError as e: + raise ImportError( + "Dagster is required for job generation. " + "Install it with: pip install 'dagster>=1.5.0'" + ) from e + + table_graph = self.table_graph + + # Generate description if not provided + if description is None: + query_count = len(table_graph.queries) + table_count = len(table_graph.tables) + description = ( + f"Pipeline with {query_count} queries operating on {table_count} tables. " + f"Generated by clgraph." + ) + + # Create ops for each query + ops: Dict[str, Any] = {} + op_mapping: Dict[str, str] = {} # query_id -> op_name + + for query_id in table_graph.topological_sort(): + query = table_graph.queries[query_id] + sql_to_execute = query.sql + + # Generate safe op name + op_name = self._sanitize_name(query_id) + op_mapping[query_id] = op_name + + def make_op(qid: str, sql: str, name: str, exec_fn: Callable): + @dg.op(name=name) + def sql_op(context: dg.OpExecutionContext): + """Execute SQL query.""" + context.log.info(f"Executing query: {qid}") + exec_fn(sql) + return qid + + return sql_op + + ops[query_id] = make_op(query_id, sql_to_execute, op_name, executor) + + # Build the job graph + @dg.job(name=job_name, description=description, tags=tags or {}, **job_kwargs) + def pipeline_job(): + """Generated pipeline job.""" + op_results: Dict[str, Any] = {} + + for query_id in table_graph.topological_sort(): + # Find upstream dependencies + query = table_graph.queries[query_id] + upstream_results = [] + + for source_table in query.source_tables: + if source_table in table_graph.tables: + table_node = table_graph.tables[source_table] + if table_node.created_by and table_node.created_by in op_results: + upstream_results.append(op_results[table_node.created_by]) + + # Execute op - dependencies are implicit via the graph structure + # In Dagster, we need to wire dependencies differently + op_results[query_id] = ops[query_id]() + + return op_results + + return pipeline_job + + +__all__ = ["DagsterOrchestrator"] diff --git a/src/clgraph/pipeline.py b/src/clgraph/pipeline.py index a89c5d1..eb1e926 100644 --- a/src/clgraph/pipeline.py +++ b/src/clgraph/pipeline.py @@ -9,11 +9,8 @@ - Airflow DAG generation """ -import asyncio -import time -from concurrent.futures import ThreadPoolExecutor, as_completed -from datetime import datetime, timedelta -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple +from datetime import datetime +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union from sqlglot import exp @@ -2087,37 +2084,9 @@ def _get_execution_levels(self) -> List[List[str]]: Returns: List of levels, where each level is a list of query IDs """ - levels = [] - completed = set() + from 
.execution import PipelineExecutor - while len(completed) < len(self.table_graph.queries): - current_level = [] - - for query_id, query in self.table_graph.queries.items(): - if query_id in completed: - continue - - # Check if all dependencies are completed - dependencies_met = True - for source_table in query.source_tables: - # Find query that creates this table - table_node = self.table_graph.tables.get(source_table) - if table_node and table_node.created_by: - if table_node.created_by not in completed: - dependencies_met = False - break - - if dependencies_met: - current_level.append(query_id) - - if not current_level: - # No progress - circular dependency - raise RuntimeError("Circular dependency detected in pipeline") - - levels.append(current_level) - completed.update(current_level) - - return levels + return PipelineExecutor(self).get_execution_levels() def to_airflow_dag( self, @@ -2126,12 +2095,14 @@ def to_airflow_dag( schedule: str = "@daily", start_date: Optional[datetime] = None, default_args: Optional[dict] = None, + airflow_version: Optional[str] = None, **dag_kwargs, ): """ Create Airflow DAG from this pipeline using TaskFlow API. - Supports all Airflow DAG parameters via **dag_kwargs for complete flexibility. + Supports both Airflow 2.x and 3.x. The TaskFlow API (@dag and @task decorators) + is fully compatible across both versions. Args: executor: Function that executes SQL (takes sql string) @@ -2139,6 +2110,8 @@ def to_airflow_dag( schedule: Schedule interval (default: "@daily") start_date: DAG start date (default: datetime(2024, 1, 1)) default_args: Airflow default_args (default: owner='data_team', retries=2) + airflow_version: Optional Airflow version ("2" or "3"). + Auto-detected from installed Airflow if not provided. **dag_kwargs: Additional DAG parameters (catchup, tags, max_active_runs, description, max_active_tasks, dagrun_timeout, etc.) See Airflow DAG documentation for all available parameters. @@ -2147,7 +2120,7 @@ def to_airflow_dag( Airflow DAG instance Examples: - # Basic usage + # Basic usage (auto-detects Airflow version) def execute_sql(sql: str): from google.cloud import bigquery client = bigquery.Client() @@ -2158,6 +2131,13 @@ def execute_sql(sql: str): dag_id="my_pipeline" ) + # Explicit version specification (for testing) + dag = pipeline.to_airflow_dag( + executor=execute_sql, + dag_id="my_pipeline", + airflow_version="3" + ) + # Advanced usage with all DAG parameters dag = pipeline.to_airflow_dag( executor=execute_sql, @@ -2168,82 +2148,24 @@ def execute_sql(sql: str): max_active_runs=3, max_active_tasks=10, tags=["analytics", "daily"], - default_view="graph", # Airflow 2.x only - orientation="LR", # Airflow 2.x only ) Note: - Currently supports Airflow 2.x only. Airflow 3.x support is planned. - """ - try: - from airflow.decorators import dag, task # type: ignore[import-untyped] - except ImportError as e: - raise ImportError( - "Airflow is required for DAG generation. 
" - "Install it with: pip install 'apache-airflow>=2.7.0,<3.0.0'" - ) from e - - if start_date is None: - start_date = datetime(2024, 1, 1) - - if default_args is None: - default_args = { - "owner": "data_team", - "retries": 2, - "retry_delay": timedelta(minutes=5), - } - - # Build DAG parameters - dag_params = { - "dag_id": dag_id, - "schedule": schedule, - "start_date": start_date, - "default_args": default_args, - **dag_kwargs, # Allow user to override any parameter - } - - # Set default values only if not provided by user - dag_params.setdefault("catchup", False) - dag_params.setdefault("tags", ["clgraph"]) - - table_graph = self.table_graph - - @dag(**dag_params) - def pipeline_dag(): - """Generated pipeline DAG""" - - # Create task callables for each query - task_callables = {} - - for query_id in table_graph.topological_sort(): - query = table_graph.queries[query_id] - sql_to_execute = query.sql - - # Create task with unique function name using closure - def make_task(qid, sql): - @task(task_id=qid.replace("-", "_")) - def execute_query(): - """Execute SQL query""" - executor(sql) - return f"Completed: {qid}" - - return execute_query - - task_callables[query_id] = make_task(query_id, sql_to_execute) - - # Instantiate all tasks once before wiring dependencies - task_instances = {qid: callable() for qid, callable in task_callables.items()} - - # Set up dependencies based on table lineage - for _table_name, table_node in table_graph.tables.items(): - if table_node.created_by: - upstream_id = table_node.created_by - for downstream_id in table_node.read_by: - if upstream_id in task_instances and downstream_id in task_instances: - # Airflow: downstream >> upstream means upstream runs first - task_instances[upstream_id] >> task_instances[downstream_id] - - return pipeline_dag() + - Airflow 2.x: Fully supported (2.7.0+) + - Airflow 3.x: Fully supported (3.0.0+) + - TaskFlow API is compatible across both versions + """ + from .orchestrators import AirflowOrchestrator + + return AirflowOrchestrator(self).to_dag( + executor=executor, + dag_id=dag_id, + schedule=schedule, + start_date=start_date, + default_args=default_args, + airflow_version=airflow_version, + **dag_kwargs, + ) def run( self, @@ -2276,71 +2198,13 @@ def execute_sql(sql: str): result = pipeline.run(executor=execute_sql, max_workers=4) print(f"Completed {len(result['completed'])} queries") """ - if verbose: - print(f"šŸš€ Starting pipeline execution ({len(self.table_graph.queries)} queries)") - print() - - # Track completed queries - completed = set() - failed = [] - start_time = time.time() - - # Group queries by level for concurrent execution - levels = self._get_execution_levels() - - # Execute level by level - for level_num, level_queries in enumerate(levels, 1): - if verbose: - print(f"šŸ“Š Level {level_num}: {len(level_queries)} queries") - - # Execute queries in this level concurrently - with ThreadPoolExecutor(max_workers=max_workers) as pool: - futures = {} - - for query_id in level_queries: - query = self.table_graph.queries[query_id] - future = pool.submit(executor, query.sql) - futures[future] = query_id + from .execution import PipelineExecutor - # Wait for completion - for future in as_completed(futures): - query_id = futures[future] - - try: - future.result() - completed.add(query_id) - - if verbose: - print(f" āœ… {query_id}") - except Exception as e: - failed.append((query_id, str(e))) - - if verbose: - print(f" āŒ {query_id}: {e}") - - if verbose: - print() - - elapsed = time.time() - start_time - - # Summary 
- if verbose: - print("=" * 60) - print(f"āœ… Pipeline completed in {elapsed:.2f}s") - print(f" Successful: {len(completed)}") - print(f" Failed: {len(failed)}") - if failed: - print("\nāš ļø Failed queries:") - for query_id, error in failed: - print(f" - {query_id}: {error}") - print("=" * 60) - - return { - "completed": list(completed), - "failed": failed, - "elapsed_seconds": elapsed, - "total_queries": len(self.table_graph.queries), - } + return PipelineExecutor(self).run( + executor=executor, + max_workers=max_workers, + verbose=verbose, + ) async def async_run( self, @@ -2372,72 +2236,140 @@ async def execute_sql(sql: str): result = await pipeline.async_run(executor=execute_sql, max_workers=4) print(f"Completed {len(result['completed'])} queries") """ - if verbose: - print(f"šŸš€ Starting async pipeline execution ({len(self.table_graph.queries)} queries)") - print() - - # Track completed queries - completed = set() - failed = [] - start_time = time.time() - - # Group queries by level for concurrent execution - levels = self._get_execution_levels() - - # Create semaphore to limit concurrency - semaphore = asyncio.Semaphore(max_workers) - - # Execute level by level - for level_num, level_queries in enumerate(levels, 1): - if verbose: - print(f"šŸ“Š Level {level_num}: {len(level_queries)} queries") - - async def execute_with_semaphore(query_id: str, sql: str): - """Execute query with semaphore for concurrency control""" - async with semaphore: - try: - await executor(sql) - completed.add(query_id) - if verbose: - print(f" āœ… {query_id}") - except Exception as e: - failed.append((query_id, str(e))) - if verbose: - print(f" āŒ {query_id}: {e}") - - # Execute queries in this level concurrently - tasks = [] - for query_id in level_queries: - query = self.table_graph.queries[query_id] - task = execute_with_semaphore(query_id, query.sql) - tasks.append(task) - - # Wait for all tasks in this level to complete - await asyncio.gather(*tasks) - - if verbose: - print() - - elapsed = time.time() - start_time - - # Summary - if verbose: - print("=" * 60) - print(f"āœ… Pipeline completed in {elapsed:.2f}s") - print(f" Successful: {len(completed)}") - print(f" Failed: {len(failed)}") - if failed: - print("\nāš ļø Failed queries:") - for query_id, error in failed: - print(f" - {query_id}: {error}") - print("=" * 60) - - return { - "completed": list(completed), - "failed": failed, - "elapsed_seconds": elapsed, - "total_queries": len(self.table_graph.queries), - } + from .execution import PipelineExecutor + + return await PipelineExecutor(self).async_run( + executor=executor, + max_workers=max_workers, + verbose=verbose, + ) + + # ======================================================================== + # Orchestrator Methods - Dagster + # ======================================================================== + + def to_dagster_assets( + self, + executor: Callable[[str], None], + group_name: Optional[str] = None, + key_prefix: Optional[Union[str, List[str]]] = None, + compute_kind: str = "sql", + **asset_kwargs, + ) -> List: + """ + Create Dagster Assets from this pipeline. + + Converts the pipeline's table dependency graph into Dagster assets + where each target table becomes an asset with proper dependencies. + This is the recommended approach for Dagster as it provides better + lineage tracking and observability. 
+ + Args: + executor: Function that executes SQL (takes sql string) + group_name: Optional asset group name for organization in Dagster UI + key_prefix: Optional prefix for asset keys (e.g., ["warehouse", "analytics"]) + compute_kind: Compute kind tag for assets (default: "sql") + **asset_kwargs: Additional asset parameters (owners, tags, etc.) + + Returns: + List of Dagster Asset definitions + + Examples: + # Basic usage + def execute_sql(sql: str): + from clickhouse_driver import Client + Client('localhost').execute(sql) + + assets = pipeline.to_dagster_assets( + executor=execute_sql, + group_name="analytics" + ) + + # Create Dagster Definitions + from dagster import Definitions + defs = Definitions(assets=assets) + + # Advanced usage with prefixes and metadata + assets = pipeline.to_dagster_assets( + executor=execute_sql, + group_name="warehouse", + key_prefix=["prod", "analytics"], + compute_kind="clickhouse", + owners=["team:data-eng"], + tags={"domain": "finance"}, + ) + + Note: + - Requires Dagster 1.x: pip install 'dagster>=1.5.0' + - Each target table becomes a Dagster asset + - Dependencies are automatically inferred from table lineage + - Deployment: Drop the definitions.py file in your Dagster workspace + """ + from .orchestrators import DagsterOrchestrator + + return DagsterOrchestrator(self).to_assets( + executor=executor, + group_name=group_name, + key_prefix=key_prefix, + compute_kind=compute_kind, + **asset_kwargs, + ) + + def to_dagster_job( + self, + executor: Callable[[str], None], + job_name: str, + description: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + **job_kwargs, + ): + """ + Create Dagster Job from this pipeline using ops. + + Converts the pipeline's table dependency graph into a Dagster job + where each SQL query becomes an op with proper dependencies. + + Note: For new pipelines, consider using to_dagster_assets() instead, + which provides better lineage tracking and observability in Dagster. + + Args: + executor: Function that executes SQL (takes sql string) + job_name: Name for the Dagster job + description: Optional job description (auto-generated if not provided) + tags: Optional job tags for filtering in Dagster UI + **job_kwargs: Additional job parameters + + Returns: + Dagster Job definition + + Examples: + # Basic usage + job = pipeline.to_dagster_job( + executor=execute_sql, + job_name="analytics_pipeline" + ) + + # Create Dagster Definitions + from dagster import Definitions + defs = Definitions(jobs=[job]) + + # Execute the job locally + result = job.execute_in_process() + + Note: + - Requires Dagster 1.x: pip install 'dagster>=1.5.0' + - Consider using to_dagster_assets() for better Dagster integration + - Deployment: Drop the definitions.py file in your Dagster workspace + """ + from .orchestrators import DagsterOrchestrator + + return DagsterOrchestrator(self).to_job( + executor=executor, + job_name=job_name, + description=description, + tags=tags, + **job_kwargs, + ) # ======================================================================== # Validation Methods diff --git a/tests/test_dagster_integration.py b/tests/test_dagster_integration.py new file mode 100644 index 0000000..c5b2c18 --- /dev/null +++ b/tests/test_dagster_integration.py @@ -0,0 +1,348 @@ +""" +Tests for Dagster integration (to_dagster_assets, to_dagster_job). 
+""" + +import pytest + +from clgraph.pipeline import Pipeline + +# Check if Dagster is available +try: + import dagster # noqa: F401 + + DAGSTER_AVAILABLE = True +except ImportError: + DAGSTER_AVAILABLE = False + + +class TestToDagsterAssetsBasic: + """Basic tests for to_dagster_assets method.""" + + def test_requires_dagster(self): + """Test that to_dagster_assets raises error when Dagster is not installed.""" + queries = [("query1", "CREATE TABLE table1 AS SELECT 1 as id")] + pipeline = Pipeline(queries, dialect="bigquery") + + def mock_executor(sql: str): + pass + + try: + assets = pipeline.to_dagster_assets(executor=mock_executor, group_name="test") + # If we get here, dagster is installed + assert assets is not None + assert len(assets) == 1 + except ImportError as e: + # Expected if dagster not installed + assert "Dagster is required" in str(e) + + @pytest.mark.skipif(not DAGSTER_AVAILABLE, reason="Dagster not installed") + def test_basic_asset_generation(self): + """Test basic asset generation from pipeline.""" + queries = [ + ("staging", "CREATE TABLE staging AS SELECT 1 as id, 'Alice' as name"), + ( + "analytics", + "CREATE TABLE analytics AS SELECT id, name FROM staging WHERE id = 1", + ), + ] + pipeline = Pipeline(queries, dialect="bigquery") + + def mock_executor(sql: str): + pass + + assets = pipeline.to_dagster_assets(executor=mock_executor, group_name="test_group") + + # Should create 2 assets + assert len(assets) == 2 + + # Check asset keys + asset_keys = [asset.key.path[-1] for asset in assets] + assert "staging" in asset_keys + assert "analytics" in asset_keys + + @pytest.mark.skipif(not DAGSTER_AVAILABLE, reason="Dagster not installed") + def test_asset_dependencies(self): + """Test that asset dependencies are correctly wired.""" + queries = [ + ("raw", "CREATE TABLE raw AS SELECT 1 as id"), + ("staging", "CREATE TABLE staging AS SELECT * FROM raw"), + ("analytics", "CREATE TABLE analytics AS SELECT * FROM staging"), + ] + pipeline = Pipeline(queries, dialect="bigquery") + + def mock_executor(sql: str): + pass + + assets = pipeline.to_dagster_assets(executor=mock_executor) + + # Find analytics asset (should depend on staging) + analytics_asset = next(a for a in assets if "analytics" in a.key.path) + + # Check it has dependencies + assert len(analytics_asset.deps) > 0 + + # The dependency should be staging + dep_keys = [str(d.asset_key) for d in analytics_asset.deps] + assert any("staging" in k for k in dep_keys) + + @pytest.mark.skipif(not DAGSTER_AVAILABLE, reason="Dagster not installed") + def test_asset_with_key_prefix(self): + """Test asset key prefix is applied correctly.""" + queries = [("table1", "CREATE TABLE table1 AS SELECT 1 as id")] + pipeline = Pipeline(queries, dialect="bigquery") + + def mock_executor(sql: str): + pass + + # Test with string prefix + assets = pipeline.to_dagster_assets(executor=mock_executor, key_prefix="warehouse") + assert assets[0].key.path == ["warehouse", "table1"] + + # Test with list prefix + assets = pipeline.to_dagster_assets( + executor=mock_executor, key_prefix=["prod", "analytics"] + ) + assert assets[0].key.path == ["prod", "analytics", "table1"] + + @pytest.mark.skipif(not DAGSTER_AVAILABLE, reason="Dagster not installed") + def test_asset_group_name(self): + """Test asset group name is set correctly.""" + queries = [("table1", "CREATE TABLE table1 AS SELECT 1 as id")] + pipeline = Pipeline(queries, dialect="bigquery") + + def mock_executor(sql: str): + pass + + assets = pipeline.to_dagster_assets(executor=mock_executor, 
group_name="my_group") + + assert assets[0].group_names_by_key[assets[0].key] == "my_group" + + @pytest.mark.skipif(not DAGSTER_AVAILABLE, reason="Dagster not installed") + def test_asset_compute_kind(self): + """Test asset compute kind is set correctly.""" + queries = [("table1", "CREATE TABLE table1 AS SELECT 1 as id")] + pipeline = Pipeline(queries, dialect="bigquery") + + def mock_executor(sql: str): + pass + + assets = pipeline.to_dagster_assets(executor=mock_executor, compute_kind="clickhouse") + + # Check compute_kind in asset's op + assert assets[0].op.tags.get("dagster/compute_kind") == "clickhouse" + + @pytest.mark.skipif(not DAGSTER_AVAILABLE, reason="Dagster not installed") + def test_asset_skips_queries_without_target_table(self): + """Test that queries without target tables are skipped.""" + queries = [ + ("insert_query", "INSERT INTO existing_table SELECT 1"), # No target table + ("create_query", "CREATE TABLE new_table AS SELECT 1 as id"), + ] + pipeline = Pipeline(queries, dialect="bigquery") + + def mock_executor(sql: str): + pass + + assets = pipeline.to_dagster_assets(executor=mock_executor) + + # Should only create 1 asset (for CREATE TABLE) + assert len(assets) == 1 + assert "new_table" in assets[0].key.path + + +class TestToDagsterAssetsExecution: + """Tests for Dagster asset execution.""" + + @pytest.mark.skipif(not DAGSTER_AVAILABLE, reason="Dagster not installed") + def test_asset_execution_calls_executor(self): + """Test that materializing an asset calls the executor.""" + queries = [("table1", "CREATE TABLE table1 AS SELECT 1 as id")] + pipeline = Pipeline(queries, dialect="bigquery") + + executed_sql = [] + + def tracking_executor(sql: str): + executed_sql.append(sql) + + assets = pipeline.to_dagster_assets(executor=tracking_executor, group_name="test") + + # Materialize the asset + from dagster import materialize + + result = materialize(assets) + + assert result.success + assert len(executed_sql) == 1 + assert "table1" in executed_sql[0] + + @pytest.mark.skipif(not DAGSTER_AVAILABLE, reason="Dagster not installed") + def test_asset_execution_order(self): + """Test that assets execute in correct dependency order.""" + queries = [ + ("level1", "CREATE TABLE level1 AS SELECT 1 as id"), + ("level2", "CREATE TABLE level2 AS SELECT * FROM level1"), + ("level3", "CREATE TABLE level3 AS SELECT * FROM level2"), + ] + pipeline = Pipeline(queries, dialect="bigquery") + + execution_order = [] + + def tracking_executor(sql: str): + if "level1" in sql: + execution_order.append("level1") + elif "level2" in sql: + execution_order.append("level2") + elif "level3" in sql: + execution_order.append("level3") + + assets = pipeline.to_dagster_assets(executor=tracking_executor) + + from dagster import materialize + + result = materialize(assets) + + assert result.success + assert execution_order == ["level1", "level2", "level3"] + + +class TestToDagsterJob: + """Tests for to_dagster_job method.""" + + def test_requires_dagster(self): + """Test that to_dagster_job raises error when Dagster is not installed.""" + queries = [("query1", "CREATE TABLE table1 AS SELECT 1 as id")] + pipeline = Pipeline(queries, dialect="bigquery") + + def mock_executor(sql: str): + pass + + try: + job = pipeline.to_dagster_job(executor=mock_executor, job_name="test_job") + # If we get here, dagster is installed + assert job is not None + assert job.name == "test_job" + except ImportError as e: + # Expected if dagster not installed + assert "Dagster is required" in str(e) + + @pytest.mark.skipif(not 
DAGSTER_AVAILABLE, reason="Dagster not installed") + def test_basic_job_generation(self): + """Test basic job generation from pipeline.""" + queries = [ + ("staging", "CREATE TABLE staging AS SELECT 1 as id"), + ("analytics", "CREATE TABLE analytics AS SELECT * FROM staging"), + ] + pipeline = Pipeline(queries, dialect="bigquery") + + def mock_executor(sql: str): + pass + + job = pipeline.to_dagster_job(executor=mock_executor, job_name="my_pipeline") + + assert job.name == "my_pipeline" + assert job.description is not None + assert "2 queries" in job.description + + @pytest.mark.skipif(not DAGSTER_AVAILABLE, reason="Dagster not installed") + def test_job_with_custom_description(self): + """Test job with custom description.""" + queries = [("table1", "CREATE TABLE table1 AS SELECT 1 as id")] + pipeline = Pipeline(queries, dialect="bigquery") + + def mock_executor(sql: str): + pass + + job = pipeline.to_dagster_job( + executor=mock_executor, + job_name="custom_job", + description="My custom description", + ) + + assert job.description == "My custom description" + + @pytest.mark.skipif(not DAGSTER_AVAILABLE, reason="Dagster not installed") + def test_job_with_tags(self): + """Test job with custom tags.""" + queries = [("table1", "CREATE TABLE table1 AS SELECT 1 as id")] + pipeline = Pipeline(queries, dialect="bigquery") + + def mock_executor(sql: str): + pass + + job = pipeline.to_dagster_job( + executor=mock_executor, + job_name="tagged_job", + tags={"team": "data-eng", "env": "prod"}, + ) + + assert job.tags.get("team") == "data-eng" + assert job.tags.get("env") == "prod" + + @pytest.mark.skipif(not DAGSTER_AVAILABLE, reason="Dagster not installed") + def test_job_execution_in_process(self): + """Test job can be executed in process.""" + queries = [ + ("table1", "CREATE TABLE table1 AS SELECT 1 as id"), + ("table2", "CREATE TABLE table2 AS SELECT * FROM table1"), + ] + pipeline = Pipeline(queries, dialect="bigquery") + + executed_sql = [] + + def tracking_executor(sql: str): + executed_sql.append(sql) + + job = pipeline.to_dagster_job(executor=tracking_executor, job_name="test_execution") + + result = job.execute_in_process() + + assert result.success + assert len(executed_sql) == 2 + + +class TestDagsterDefinitionsIntegration: + """Tests for creating complete Dagster Definitions.""" + + @pytest.mark.skipif(not DAGSTER_AVAILABLE, reason="Dagster not installed") + def test_assets_with_definitions(self): + """Test creating Dagster Definitions from assets.""" + from dagster import Definitions + + queries = [ + ("staging", "CREATE TABLE staging AS SELECT 1 as id"), + ("analytics", "CREATE TABLE analytics AS SELECT * FROM staging"), + ] + pipeline = Pipeline(queries, dialect="bigquery") + + def mock_executor(sql: str): + pass + + assets = pipeline.to_dagster_assets(executor=mock_executor, group_name="demo") + + # Create Definitions (this is what users do in definitions.py) + defs = Definitions(assets=assets) + + assert defs is not None + # Check that assets are registered + all_asset_keys = defs.get_all_asset_keys() + assert len(all_asset_keys) == 2 + + @pytest.mark.skipif(not DAGSTER_AVAILABLE, reason="Dagster not installed") + def test_job_with_definitions(self): + """Test creating Dagster Definitions from job.""" + from dagster import Definitions + + queries = [("table1", "CREATE TABLE table1 AS SELECT 1 as id")] + pipeline = Pipeline(queries, dialect="bigquery") + + def mock_executor(sql: str): + pass + + job = pipeline.to_dagster_job(executor=mock_executor, job_name="demo_job") + + # Create 
Definitions + defs = Definitions(jobs=[job]) + + assert defs is not None + # Check that job is registered + assert defs.get_job_def("demo_job") is not None
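
A minimal end-to-end sketch of the new execution module, assuming DuckDB is installed and that "duckdb" is an accepted dialect string (any sqlglot dialect name should pass through); the query ids and SQL are illustrative:

import duckdb

from clgraph import Pipeline, PipelineExecutor

# Queries are (query_id, sql) pairs, as in the tests above.
queries = [
    ("staging", "CREATE TABLE staging AS SELECT 1 AS id, 'Alice' AS name"),
    ("analytics", "CREATE TABLE analytics AS SELECT id, name FROM staging"),
]
pipeline = Pipeline(queries, dialect="duckdb")

conn = duckdb.connect()  # shared in-memory database

def execute_sql(sql: str) -> None:
    # One cursor per call so concurrent workers never share a cursor.
    conn.cursor().execute(sql)

executor = PipelineExecutor(pipeline)
result = executor.run(execute_sql, max_workers=4)
print(result["completed"], result["failed"], result["elapsed_seconds"])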
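
A sketch of an Airflow DAG file built on the new to_airflow_dag() path, composed from the docstring examples in this diff; the queries/ directory and BigQuery project are assumptions:

# dags/clgraph_pipeline.py
from clgraph import Pipeline

pipeline = Pipeline.from_sql_files("queries/", dialect="bigquery")

def execute_sql(sql: str) -> None:
    from google.cloud import bigquery
    client = bigquery.Client()
    client.query(sql).result()

# The Airflow major version (2 or 3) is auto-detected from the installed package.
dag = pipeline.to_airflow_dag(
    executor=execute_sql,
    dag_id="clgraph_pipeline",
    schedule="@daily",
    tags=["analytics", "clgraph"],
)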
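
And a sketch of the definitions.py file the Dagster docstrings point at, under the same assumptions (queries/ directory, BigQuery executor):

# definitions.py
from dagster import Definitions

from clgraph import Pipeline

pipeline = Pipeline.from_sql_files("queries/", dialect="bigquery")

def execute_sql(sql: str) -> None:
    from google.cloud import bigquery
    bigquery.Client().query(sql).result()

defs = Definitions(
    # Asset-based path (recommended): each target table becomes an asset,
    # with dependencies inferred from table lineage.
    assets=pipeline.to_dagster_assets(executor=execute_sql, group_name="analytics"),
    # Op-based alternative for teams that prefer jobs.
    jobs=[pipeline.to_dagster_job(executor=execute_sql, job_name="analytics_pipeline")],
)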