From d8b0496de5305471cbd3afe734b66892e042666e Mon Sep 17 00:00:00 2001 From: Ming-Jer Lee Date: Mon, 2 Feb 2026 03:19:36 -0800 Subject: [PATCH 1/2] feat: Add Mage orchestrator support Add MageOrchestrator class that converts clgraph pipelines to Mage block-based pipelines with metadata.yaml and Python block files. Includes 27 integration tests covering block types, dependencies, configuration, and complex pipeline scenarios. Co-Authored-By: Claude Opus 4.5 --- src/clgraph/orchestrators/__init__.py | 8 + src/clgraph/orchestrators/mage.py | 257 ++++++++++++ tests/test_mage_integration.py | 584 ++++++++++++++++++++++++++ 3 files changed, 849 insertions(+) create mode 100644 src/clgraph/orchestrators/mage.py create mode 100644 tests/test_mage_integration.py diff --git a/src/clgraph/orchestrators/__init__.py b/src/clgraph/orchestrators/__init__.py index b98bc7b..9c93c00 100644 --- a/src/clgraph/orchestrators/__init__.py +++ b/src/clgraph/orchestrators/__init__.py @@ -9,6 +9,7 @@ - Dagster (1.x) - Prefect (2.x and 3.x) - Kestra (YAML-based declarative workflows) +- Mage (notebook-style block-based pipelines) Example: from clgraph import Pipeline @@ -17,6 +18,7 @@ DagsterOrchestrator, PrefectOrchestrator, KestraOrchestrator, + MageOrchestrator, ) pipeline = Pipeline.from_sql_files("queries/", dialect="bigquery") @@ -36,12 +38,17 @@ # Generate Kestra flow YAML kestra = KestraOrchestrator(pipeline) yaml_content = kestra.to_flow(flow_id="my_pipeline", namespace="clgraph") + + # Generate Mage pipeline + mage = MageOrchestrator(pipeline) + files = mage.to_pipeline_files(executor=execute_sql, pipeline_name="my_pipeline") """ from .airflow import AirflowOrchestrator from .base import BaseOrchestrator from .dagster import DagsterOrchestrator from .kestra import KestraOrchestrator +from .mage import MageOrchestrator from .prefect import PrefectOrchestrator __all__ = [ @@ -49,5 +56,6 @@ "AirflowOrchestrator", "DagsterOrchestrator", "KestraOrchestrator", + "MageOrchestrator", "PrefectOrchestrator", ] diff --git a/src/clgraph/orchestrators/mage.py b/src/clgraph/orchestrators/mage.py new file mode 100644 index 0000000..d940bf8 --- /dev/null +++ b/src/clgraph/orchestrators/mage.py @@ -0,0 +1,257 @@ +""" +Mage orchestrator integration for clgraph. + +Converts clgraph pipelines to Mage pipeline files (metadata.yaml and block Python files). +Mage is a modern data pipeline tool with a notebook-style UI and block-based architecture. +""" + +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional + +from .base import BaseOrchestrator + +if TYPE_CHECKING: + pass + + +class MageOrchestrator(BaseOrchestrator): + """ + Converts clgraph pipelines to Mage pipelines. + + Mage uses a block-based architecture where each SQL query becomes a block + (either data_loader or transformer). Dependencies are expressed via + upstream_blocks and downstream_blocks. 
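+    Blocks with no upstream dependencies become data_loader blocks; all other
+    blocks become transformers.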
+ + Example: + from clgraph.orchestrators import MageOrchestrator + + orchestrator = MageOrchestrator(pipeline) + files = orchestrator.to_pipeline_files( + executor=execute_sql, + pipeline_name="my_pipeline", + ) + + # Write files to Mage project + import yaml + with open("pipelines/my_pipeline/metadata.yaml", "w") as f: + yaml.dump(files["metadata.yaml"], f) + for name, code in files["blocks"].items(): + with open(f"pipelines/my_pipeline/{name}.py", "w") as f: + f.write(code) + """ + + def to_pipeline_config( + self, + pipeline_name: str, + description: Optional[str] = None, + pipeline_type: str = "python", + **kwargs, + ) -> Dict[str, Any]: + """ + Generate Mage pipeline configuration (metadata.yaml content). + + Args: + pipeline_name: Name for the Mage pipeline + description: Optional pipeline description (auto-generated if not provided) + pipeline_type: Pipeline type (default: "python") + **kwargs: Additional configuration + + Returns: + Dictionary representing metadata.yaml content + """ + table_graph = self.table_graph + + if description is None: + query_count = len(table_graph.queries) + table_count = len(table_graph.tables) + description = ( + f"Pipeline with {query_count} queries operating on " + f"{table_count} tables. Generated by clgraph." + ) + + blocks = [] + for query_id in table_graph.topological_sort(): + query = table_graph.queries[query_id] + block_name = self._sanitize_name(query_id) + + # Determine block type based on dependencies + upstream_blocks = [] + for source_table in query.source_tables: + if source_table in table_graph.tables: + table_node = table_graph.tables[source_table] + if table_node.created_by: + upstream_blocks.append(self._sanitize_name(table_node.created_by)) + + block_config = { + "name": block_name, + "uuid": block_name, + "type": "data_loader" if not upstream_blocks else "transformer", + "upstream_blocks": upstream_blocks, + "downstream_blocks": [], + } + blocks.append(block_config) + + # Set downstream_blocks for each block + block_map = {b["name"]: b for b in blocks} + for block in blocks: + for upstream_name in block["upstream_blocks"]: + if upstream_name in block_map: + block_map[upstream_name]["downstream_blocks"].append(block["name"]) + + config = { + "name": pipeline_name, + "uuid": pipeline_name, + "description": description, + "type": pipeline_type, + "blocks": blocks, + **kwargs, + } + + return config + + def to_blocks( + self, + executor: Callable[[str], None], + connection_name: str = "clickhouse_default", + ) -> Dict[str, str]: + """ + Generate Mage block Python files. 
+ + Args: + executor: Function that executes SQL (for code reference) + connection_name: Name of database connection in Mage io_config.yaml + + Returns: + Dictionary mapping block_name -> block_code + """ + table_graph = self.table_graph + blocks = {} + + for query_id in table_graph.topological_sort(): + query = table_graph.queries[query_id] + block_name = self._sanitize_name(query_id) + + # Determine upstream blocks + upstream_blocks = [] + for source_table in query.source_tables: + if source_table in table_graph.tables: + table_node = table_graph.tables[source_table] + if table_node.created_by: + upstream_blocks.append(self._sanitize_name(table_node.created_by)) + + # Determine block type + block_type = "data_loader" if not upstream_blocks else "transformer" + + # Generate block code + code = self._generate_block_code( + block_name=block_name, + block_type=block_type, + sql=query.sql, + query_id=query_id, + upstream_blocks=upstream_blocks, + connection_name=connection_name, + ) + + blocks[block_name] = code + + return blocks + + def _generate_block_code( + self, + block_name: str, + block_type: str, + sql: str, + query_id: str, + upstream_blocks: List[str], + connection_name: str, + ) -> str: + """Generate Python code for a Mage block.""" + + # Mage requires specific function names for each block type + if block_type == "data_loader": + decorator = "@data_loader" + func_name = "load_data" + imports = "from mage_ai.data_preparation.decorators import data_loader" + else: + decorator = "@transformer" + func_name = "transform" + imports = "from mage_ai.data_preparation.decorators import transformer" + + # Build function signature + if upstream_blocks: + args = ", ".join([f"data_{i}" for i in range(len(upstream_blocks))]) + func_args = f"({args}, *args, **kwargs)" + else: + func_args = "(*args, **kwargs)" + + # Escape triple quotes in SQL if present + escaped_sql = sql.replace('"""', '\\"\\"\\"') + + code = f'''""" +Block: {block_name} +Query ID: {query_id} +Type: {block_type} +Generated by clgraph +""" +{imports} +from mage_ai.io.clickhouse import ClickHouse + + +{decorator} +def {func_name}{func_args}: + """ + Execute SQL query in ClickHouse. + """ + sql = """ +{escaped_sql} +""" + + with ClickHouse.with_config(config_profile="{connection_name}") as loader: + loader.execute(sql) + + return {{"status": "success", "query_id": "{query_id}"}} +''' + return code + + def to_pipeline_files( + self, + executor: Callable[[str], None], + pipeline_name: str, + description: Optional[str] = None, + connection_name: str = "clickhouse_default", + ) -> Dict[str, Any]: + """ + Generate complete Mage pipeline file structure. + + Args: + executor: Function that executes SQL (for code reference) + pipeline_name: Name for the Mage pipeline + description: Optional pipeline description + connection_name: Database connection name in Mage io_config.yaml + + Returns: + Dictionary with file structure: + { + "metadata.yaml": , + "blocks": { + "block1.py": , + "block2.py": , + } + } + """ + config = self.to_pipeline_config( + pipeline_name=pipeline_name, + description=description, + ) + + blocks = self.to_blocks( + executor=executor, + connection_name=connection_name, + ) + + return { + "metadata.yaml": config, + "blocks": blocks, + } + + +__all__ = ["MageOrchestrator"] diff --git a/tests/test_mage_integration.py b/tests/test_mage_integration.py new file mode 100644 index 0000000..d0f38a1 --- /dev/null +++ b/tests/test_mage_integration.py @@ -0,0 +1,584 @@ +""" +Tests for Mage orchestrator integration. 
+ +Tests the to_mage_pipeline() method and MageOrchestrator class. +""" + +from clgraph import Pipeline + + +def mock_executor(sql: str) -> None: + """Mock executor for testing.""" + pass + + +class TestToMagePipelineBasic: + """Basic tests for to_mage_pipeline method.""" + + def test_basic_pipeline_generation(self): + """Test basic Mage pipeline generation.""" + pipeline = Pipeline( + [ + ("staging", "CREATE TABLE staging AS SELECT 1 as id"), + ("analytics", "CREATE TABLE analytics AS SELECT * FROM staging"), + ] + ) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + pipeline_name="test_pipeline", + ) + + assert "metadata.yaml" in result + assert "blocks" in result + + def test_metadata_has_required_fields(self): + """Test metadata.yaml has required fields.""" + pipeline = Pipeline( + [ + ("staging", "CREATE TABLE staging AS SELECT 1 as id"), + ("analytics", "CREATE TABLE analytics AS SELECT * FROM staging"), + ] + ) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + pipeline_name="test_pipeline", + ) + + metadata = result["metadata.yaml"] + assert "name" in metadata + assert "uuid" in metadata + assert "description" in metadata + assert "type" in metadata + assert "blocks" in metadata + + def test_block_count_matches_queries(self): + """Test that block count matches number of queries.""" + pipeline = Pipeline( + [ + ("q1", "CREATE TABLE t1 AS SELECT 1"), + ("q2", "CREATE TABLE t2 AS SELECT * FROM t1"), + ("q3", "CREATE TABLE t3 AS SELECT * FROM t2"), + ] + ) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + pipeline_name="test_pipeline", + ) + + assert len(result["metadata.yaml"]["blocks"]) == 3 + assert len(result["blocks"]) == 3 + + def test_custom_description(self): + """Test pipeline with custom description.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + pipeline_name="test_pipeline", + description="Custom description for testing", + ) + + assert result["metadata.yaml"]["description"] == "Custom description for testing" + + def test_auto_generated_description(self): + """Test auto-generated description.""" + pipeline = Pipeline( + [ + ("q1", "CREATE TABLE t1 AS SELECT 1"), + ("q2", "CREATE TABLE t2 AS SELECT * FROM t1"), + ] + ) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + pipeline_name="test_pipeline", + ) + + description = result["metadata.yaml"]["description"] + assert "2 queries" in description + assert "clgraph" in description + + def test_default_pipeline_type(self): + """Test default pipeline type is python.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + pipeline_name="test_pipeline", + ) + + assert result["metadata.yaml"]["type"] == "python" + + def test_blocks_are_python_code_strings(self): + """Test that block values are Python code strings.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + pipeline_name="test_pipeline", + ) + + for block_code in result["blocks"].values(): + assert isinstance(block_code, str) + assert "def " in block_code + + +class TestMageBlockTypes: + """Test block type assignment.""" + + def test_source_query_is_data_loader(self): + """Test that source query (no upstream) gets data_loader type.""" + pipeline = Pipeline( + [ + ("source", "CREATE TABLE source AS SELECT 1 as id"), + ("derived", "CREATE TABLE derived AS SELECT * 
FROM source"), + ] + ) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + pipeline_name="test_pipeline", + ) + + blocks = result["metadata.yaml"]["blocks"] + source_block = next(b for b in blocks if b["name"] == "source") + assert source_block["type"] == "data_loader" + + def test_dependent_query_is_transformer(self): + """Test that dependent query gets transformer type.""" + pipeline = Pipeline( + [ + ("source", "CREATE TABLE source AS SELECT 1 as id"), + ("derived", "CREATE TABLE derived AS SELECT * FROM source"), + ] + ) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + pipeline_name="test_pipeline", + ) + + blocks = result["metadata.yaml"]["blocks"] + derived_block = next(b for b in blocks if b["name"] == "derived") + assert derived_block["type"] == "transformer" + + def test_data_loader_decorator_in_code(self): + """Test that data_loader block code contains correct decorator.""" + pipeline = Pipeline( + [ + ("source", "CREATE TABLE source AS SELECT 1 as id"), + ("derived", "CREATE TABLE derived AS SELECT * FROM source"), + ] + ) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + pipeline_name="test_pipeline", + ) + + source_code = result["blocks"]["source"] + assert "@data_loader" in source_code + assert "def load_data" in source_code + + def test_transformer_decorator_in_code(self): + """Test that transformer block code contains correct decorator.""" + pipeline = Pipeline( + [ + ("source", "CREATE TABLE source AS SELECT 1 as id"), + ("derived", "CREATE TABLE derived AS SELECT * FROM source"), + ] + ) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + pipeline_name="test_pipeline", + ) + + derived_code = result["blocks"]["derived"] + assert "@transformer" in derived_code + assert "def transform" in derived_code + + +class TestMageDependencies: + """Test upstream/downstream block wiring.""" + + def test_linear_chain_upstream(self): + """Test linear chain has correct upstream_blocks.""" + pipeline = Pipeline( + [ + ("step1", "CREATE TABLE step1 AS SELECT 1"), + ("step2", "CREATE TABLE step2 AS SELECT * FROM step1"), + ("step3", "CREATE TABLE step3 AS SELECT * FROM step2"), + ] + ) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + pipeline_name="test_pipeline", + ) + + blocks = result["metadata.yaml"]["blocks"] + block_map = {b["name"]: b for b in blocks} + + assert block_map["step1"]["upstream_blocks"] == [] + assert block_map["step2"]["upstream_blocks"] == ["step1"] + assert block_map["step3"]["upstream_blocks"] == ["step2"] + + def test_linear_chain_downstream(self): + """Test linear chain has correct downstream_blocks.""" + pipeline = Pipeline( + [ + ("step1", "CREATE TABLE step1 AS SELECT 1"), + ("step2", "CREATE TABLE step2 AS SELECT * FROM step1"), + ("step3", "CREATE TABLE step3 AS SELECT * FROM step2"), + ] + ) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + pipeline_name="test_pipeline", + ) + + blocks = result["metadata.yaml"]["blocks"] + block_map = {b["name"]: b for b in blocks} + + assert block_map["step1"]["downstream_blocks"] == ["step2"] + assert block_map["step2"]["downstream_blocks"] == ["step3"] + assert block_map["step3"]["downstream_blocks"] == [] + + def test_diamond_pattern_dependencies(self): + """Test diamond pattern wiring for all 4 blocks.""" + pipeline = Pipeline( + [ + ("source", "CREATE TABLE source AS SELECT 1 as id"), + ("left_branch", "CREATE TABLE left_branch AS SELECT * FROM source"), + ("right_branch", "CREATE TABLE right_branch AS SELECT * FROM 
source"), + ( + "final", + "CREATE TABLE final AS SELECT * FROM left_branch, right_branch", + ), + ] + ) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + pipeline_name="test_pipeline", + ) + + blocks = result["metadata.yaml"]["blocks"] + block_map = {b["name"]: b for b in blocks} + + # Source has no upstream + assert block_map["source"]["upstream_blocks"] == [] + # Branches depend on source + assert block_map["left_branch"]["upstream_blocks"] == ["source"] + assert block_map["right_branch"]["upstream_blocks"] == ["source"] + # Final depends on both branches + assert sorted(block_map["final"]["upstream_blocks"]) == [ + "left_branch", + "right_branch", + ] + + # Source has both branches as downstream + assert sorted(block_map["source"]["downstream_blocks"]) == [ + "left_branch", + "right_branch", + ] + # Final has no downstream + assert block_map["final"]["downstream_blocks"] == [] + + def test_source_has_empty_upstream(self): + """Test that source blocks have empty upstream_blocks.""" + pipeline = Pipeline( + [ + ("source1", "CREATE TABLE source1 AS SELECT 1"), + ("source2", "CREATE TABLE source2 AS SELECT 2"), + ( + "combined", + "CREATE TABLE combined AS SELECT * FROM source1, source2", + ), + ] + ) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + pipeline_name="test_pipeline", + ) + + blocks = result["metadata.yaml"]["blocks"] + block_map = {b["name"]: b for b in blocks} + + assert block_map["source1"]["upstream_blocks"] == [] + assert block_map["source2"]["upstream_blocks"] == [] + + def test_leaf_has_empty_downstream(self): + """Test that leaf blocks have empty downstream_blocks.""" + pipeline = Pipeline( + [ + ("source", "CREATE TABLE source AS SELECT 1"), + ("leaf", "CREATE TABLE leaf AS SELECT * FROM source"), + ] + ) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + pipeline_name="test_pipeline", + ) + + blocks = result["metadata.yaml"]["blocks"] + block_map = {b["name"]: b for b in blocks} + + assert block_map["leaf"]["downstream_blocks"] == [] + + +class TestMageConfiguration: + """Test connection and configuration handling.""" + + def test_default_connection_name_in_code(self): + """Test default connection name clickhouse_default in block code.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + pipeline_name="test_pipeline", + ) + + block_code = list(result["blocks"].values())[0] + assert "clickhouse_default" in block_code + + def test_custom_connection_name_in_code(self): + """Test custom connection name embedded in block code.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + pipeline_name="test_pipeline", + connection_name="my_custom_conn", + ) + + block_code = list(result["blocks"].values())[0] + assert "my_custom_conn" in block_code + + def test_sql_embedded_in_block_code(self): + """Test that SQL is embedded in block code.""" + sql = "CREATE TABLE test_table AS SELECT 42 as value" + pipeline = Pipeline([("q1", sql)]) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + pipeline_name="test_pipeline", + ) + + block_code = list(result["blocks"].values())[0] + assert sql in block_code + + def test_clickhouse_import_in_block_code(self): + """Test that ClickHouse import is present in block code.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + 
pipeline_name="test_pipeline", + ) + + block_code = list(result["blocks"].values())[0] + assert "from mage_ai.io.clickhouse import ClickHouse" in block_code + + +class TestMageOrchestrator: + """Test MageOrchestrator class directly.""" + + def test_orchestrator_initialization(self): + """Test MageOrchestrator initialization.""" + from clgraph.orchestrators import MageOrchestrator + + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + orchestrator = MageOrchestrator(pipeline) + + assert orchestrator.pipeline == pipeline + assert orchestrator.table_graph == pipeline.table_graph + + def test_to_pipeline_config_returns_dict(self): + """Test to_pipeline_config returns dictionary.""" + from clgraph.orchestrators import MageOrchestrator + + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + orchestrator = MageOrchestrator(pipeline) + + result = orchestrator.to_pipeline_config(pipeline_name="test") + + assert isinstance(result, dict) + assert result["name"] == "test" + assert result["uuid"] == "test" + + def test_to_blocks_returns_dict(self): + """Test to_blocks returns dictionary.""" + from clgraph.orchestrators import MageOrchestrator + + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + orchestrator = MageOrchestrator(pipeline) + + result = orchestrator.to_blocks(executor=mock_executor) + + assert isinstance(result, dict) + assert len(result) == 1 + + def test_to_pipeline_files_returns_combined_dict(self): + """Test to_pipeline_files returns combined dictionary.""" + from clgraph.orchestrators import MageOrchestrator + + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + orchestrator = MageOrchestrator(pipeline) + + result = orchestrator.to_pipeline_files( + executor=mock_executor, + pipeline_name="test", + ) + + assert "metadata.yaml" in result + assert "blocks" in result + assert isinstance(result["metadata.yaml"], dict) + assert isinstance(result["blocks"], dict) + + def test_sanitize_name(self): + """Test _sanitize_name works correctly.""" + from clgraph.orchestrators import MageOrchestrator + + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + orchestrator = MageOrchestrator(pipeline) + + assert orchestrator._sanitize_name("my.table") == "my_table" + assert orchestrator._sanitize_name("my-table") == "my_table" + assert orchestrator._sanitize_name("my_table") == "my_table" + + +class TestMageComplexPipeline: + """Test Mage pipeline generation with complex pipelines.""" + + def test_enterprise_like_pipeline(self): + """Test with a pipeline similar to enterprise demo.""" + pipeline = Pipeline( + [ + ( + "raw_sales", + """ + CREATE TABLE raw_sales AS + SELECT + toDate('2024-01-01') + number as date, + number % 100 as product_id, + number % 10 as region_id, + rand() % 1000 as amount + FROM numbers(1000) + """, + ), + ( + "raw_products", + """ + CREATE TABLE raw_products AS + SELECT + number as product_id, + concat('Product ', toString(number)) as product_name, + rand() % 5 as category_id + FROM numbers(100) + """, + ), + ( + "sales_with_products", + """ + CREATE TABLE sales_with_products AS + SELECT + s.date, + s.product_id, + p.product_name, + s.region_id, + s.amount + FROM raw_sales s + JOIN raw_products p ON s.product_id = p.product_id + """, + ), + ( + "daily_summary", + """ + CREATE TABLE daily_summary AS + SELECT + date, + count() as num_sales, + sum(amount) as total_amount + FROM sales_with_products + GROUP BY date + """, + ), + ], + dialect="clickhouse", + ) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + 
pipeline_name="enterprise_pipeline", + ) + + metadata = result["metadata.yaml"] + assert metadata["name"] == "enterprise_pipeline" + assert len(metadata["blocks"]) == 4 + + block_map = {b["name"]: b for b in metadata["blocks"]} + + # Source blocks are data_loaders + assert block_map["raw_sales"]["type"] == "data_loader" + assert block_map["raw_products"]["type"] == "data_loader" + + # Dependent blocks are transformers + assert block_map["sales_with_products"]["type"] == "transformer" + assert block_map["daily_summary"]["type"] == "transformer" + + # sales_with_products depends on both raw tables + assert sorted(block_map["sales_with_products"]["upstream_blocks"]) == [ + "raw_products", + "raw_sales", + ] + + # daily_summary depends on sales_with_products + assert block_map["daily_summary"]["upstream_blocks"] == ["sales_with_products"] + + def test_10_step_linear_chain(self): + """Test 10-step linear chain has correct sequential dependencies.""" + queries = [] + for i in range(10): + if i == 0: + queries.append((f"step_{i}", f"CREATE TABLE step_{i} AS SELECT {i}")) + else: + queries.append( + ( + f"step_{i}", + f"CREATE TABLE step_{i} AS SELECT * FROM step_{i - 1}", + ) + ) + + pipeline = Pipeline(queries) + + result = pipeline.to_mage_pipeline( + executor=mock_executor, + pipeline_name="long_chain", + ) + + metadata = result["metadata.yaml"] + assert len(metadata["blocks"]) == 10 + + block_map = {b["name"]: b for b in metadata["blocks"]} + + # First step is data_loader with no upstream + assert block_map["step_0"]["type"] == "data_loader" + assert block_map["step_0"]["upstream_blocks"] == [] + + # All other steps are transformers with correct upstream + for i in range(1, 10): + assert block_map[f"step_{i}"]["type"] == "transformer" + assert block_map[f"step_{i}"]["upstream_blocks"] == [f"step_{i - 1}"] + + # Check downstream wiring + for i in range(9): + assert block_map[f"step_{i}"]["downstream_blocks"] == [f"step_{i + 1}"] + assert block_map["step_9"]["downstream_blocks"] == [] From f1cfb171d13513d2cea4c09ed37b7f12b535051d Mon Sep 17 00:00:00 2001 From: Ming-Jer Lee Date: Mon, 2 Feb 2026 18:56:26 -0800 Subject: [PATCH 2/2] fix: Address PR #60 review issues for Mage orchestrator - Remove unused `executor` parameter from all Mage methods - Remove empty TYPE_CHECKING block - Make database connector configurable via `db_connector` param (supports clickhouse, postgres, bigquery, snowflake) - Fix mutation in to_pipeline_config() with immutable downstream_map - Validate connection_name to prevent code injection - Add tests for new connectors and input validation --- src/clgraph/orchestrators/mage.py | 103 ++++++++++++++------ src/clgraph/pipeline.py | 9 +- tests/test_mage_integration.py | 150 ++++++++++++++++++++++++------ 3 files changed, 199 insertions(+), 63 deletions(-) diff --git a/src/clgraph/orchestrators/mage.py b/src/clgraph/orchestrators/mage.py index d940bf8..3aac17c 100644 --- a/src/clgraph/orchestrators/mage.py +++ b/src/clgraph/orchestrators/mage.py @@ -5,12 +5,32 @@ Mage is a modern data pipeline tool with a notebook-style UI and block-based architecture. 
""" -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional +import re +from typing import Any, Dict, List, Optional from .base import BaseOrchestrator -if TYPE_CHECKING: - pass +# Mapping of db_connector name to (import_statement, class_name) +DB_CONNECTOR_MAP: Dict[str, tuple] = { + "clickhouse": ( + "from mage_ai.io.clickhouse import ClickHouse", + "ClickHouse", + ), + "postgres": ( + "from mage_ai.io.postgres import Postgres", + "Postgres", + ), + "bigquery": ( + "from mage_ai.io.bigquery import BigQuery", + "BigQuery", + ), + "snowflake": ( + "from mage_ai.io.snowflake import Snowflake", + "Snowflake", + ), +} + +_CONNECTION_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+$") class MageOrchestrator(BaseOrchestrator): @@ -26,7 +46,6 @@ class MageOrchestrator(BaseOrchestrator): orchestrator = MageOrchestrator(pipeline) files = orchestrator.to_pipeline_files( - executor=execute_sql, pipeline_name="my_pipeline", ) @@ -68,12 +87,14 @@ def to_pipeline_config( f"{table_count} tables. Generated by clgraph." ) - blocks = [] + # Build a map of block_name -> list of downstream block names + downstream_map: Dict[str, List[str]] = {} + + block_entries = [] for query_id in table_graph.topological_sort(): query = table_graph.queries[query_id] block_name = self._sanitize_name(query_id) - # Determine block type based on dependencies upstream_blocks = [] for source_table in query.source_tables: if source_table in table_graph.tables: @@ -81,21 +102,30 @@ def to_pipeline_config( if table_node.created_by: upstream_blocks.append(self._sanitize_name(table_node.created_by)) - block_config = { - "name": block_name, - "uuid": block_name, - "type": "data_loader" if not upstream_blocks else "transformer", - "upstream_blocks": upstream_blocks, - "downstream_blocks": [], - } - blocks.append(block_config) + # Register this block as downstream of each upstream block + for upstream_name in upstream_blocks: + if upstream_name not in downstream_map: + downstream_map[upstream_name] = [] + downstream_map[upstream_name].append(block_name) - # Set downstream_blocks for each block - block_map = {b["name"]: b for b in blocks} - for block in blocks: - for upstream_name in block["upstream_blocks"]: - if upstream_name in block_map: - block_map[upstream_name]["downstream_blocks"].append(block["name"]) + block_entries.append( + { + "name": block_name, + "upstream_blocks": upstream_blocks, + } + ) + + # Build final block configs with downstream_blocks already set + blocks = [ + { + "name": entry["name"], + "uuid": entry["name"], + "type": "data_loader" if not entry["upstream_blocks"] else "transformer", + "upstream_blocks": entry["upstream_blocks"], + "downstream_blocks": list(downstream_map.get(entry["name"], [])), + } + for entry in block_entries + ] config = { "name": pipeline_name, @@ -110,15 +140,15 @@ def to_pipeline_config( def to_blocks( self, - executor: Callable[[str], None], connection_name: str = "clickhouse_default", + db_connector: str = "clickhouse", ) -> Dict[str, str]: """ Generate Mage block Python files. 
Args: - executor: Function that executes SQL (for code reference) connection_name: Name of database connection in Mage io_config.yaml + db_connector: Database connector type (clickhouse, postgres, bigquery, snowflake) Returns: Dictionary mapping block_name -> block_code @@ -149,6 +179,7 @@ def to_blocks( query_id=query_id, upstream_blocks=upstream_blocks, connection_name=connection_name, + db_connector=db_connector, ) blocks[block_name] = code @@ -163,9 +194,25 @@ def _generate_block_code( query_id: str, upstream_blocks: List[str], connection_name: str, + db_connector: str = "clickhouse", ) -> str: """Generate Python code for a Mage block.""" + # Validate connection_name to prevent code injection + if not _CONNECTION_NAME_PATTERN.match(connection_name): + raise ValueError( + f"Invalid connection_name '{connection_name}': " + "must contain only alphanumeric characters, underscores, and hyphens" + ) + + # Resolve db connector + if db_connector not in DB_CONNECTOR_MAP: + raise ValueError( + f"Unsupported db_connector '{db_connector}'. " + f"Supported: {', '.join(sorted(DB_CONNECTOR_MAP))}" + ) + connector_import, connector_class = DB_CONNECTOR_MAP[db_connector] + # Mage requires specific function names for each block type if block_type == "data_loader": decorator = "@data_loader" @@ -193,19 +240,19 @@ def _generate_block_code( Generated by clgraph """ {imports} -from mage_ai.io.clickhouse import ClickHouse +{connector_import} {decorator} def {func_name}{func_args}: """ - Execute SQL query in ClickHouse. + Execute SQL query via {connector_class}. """ sql = """ {escaped_sql} """ - with ClickHouse.with_config(config_profile="{connection_name}") as loader: + with {connector_class}.with_config(config_profile="{connection_name}") as loader: loader.execute(sql) return {{"status": "success", "query_id": "{query_id}"}} @@ -214,19 +261,19 @@ def {func_name}{func_args}: def to_pipeline_files( self, - executor: Callable[[str], None], pipeline_name: str, description: Optional[str] = None, connection_name: str = "clickhouse_default", + db_connector: str = "clickhouse", ) -> Dict[str, Any]: """ Generate complete Mage pipeline file structure. Args: - executor: Function that executes SQL (for code reference) pipeline_name: Name for the Mage pipeline description: Optional pipeline description connection_name: Database connection name in Mage io_config.yaml + db_connector: Database connector type (clickhouse, postgres, bigquery, snowflake) Returns: Dictionary with file structure: @@ -244,8 +291,8 @@ def to_pipeline_files( ) blocks = self.to_blocks( - executor=executor, connection_name=connection_name, + db_connector=db_connector, ) return { diff --git a/src/clgraph/pipeline.py b/src/clgraph/pipeline.py index beaf54f..629157c 100644 --- a/src/clgraph/pipeline.py +++ b/src/clgraph/pipeline.py @@ -2612,10 +2612,10 @@ def to_kestra_flow( def to_mage_pipeline( self, - executor: Callable[[str], None], pipeline_name: str, description: Optional[str] = None, connection_name: str = "clickhouse_default", + db_connector: str = "clickhouse", ) -> Dict[str, Any]: """ Generate Mage pipeline files from this pipeline. @@ -2625,10 +2625,10 @@ def to_mage_pipeline( data_loader or transformer). 
Args: - executor: Function that executes SQL (for code reference) pipeline_name: Name for the Mage pipeline description: Optional pipeline description (auto-generated if not provided) connection_name: Database connection name in Mage io_config.yaml + db_connector: Database connector type (clickhouse, postgres, bigquery, snowflake) Returns: Dictionary with pipeline file structure: @@ -2640,7 +2640,6 @@ def to_mage_pipeline( Examples: # Generate Mage pipeline files files = pipeline.to_mage_pipeline( - executor=execute_sql, pipeline_name="enterprise_pipeline", ) @@ -2656,15 +2655,15 @@ def to_mage_pipeline( - First query (no dependencies) becomes data_loader block - Subsequent queries become transformer blocks - Dependencies are managed via upstream_blocks/downstream_blocks - - Requires mage-ai package and ClickHouse connection in io_config.yaml + - Requires mage-ai package and database connection in io_config.yaml """ from .orchestrators import MageOrchestrator return MageOrchestrator(self).to_pipeline_files( - executor=executor, pipeline_name=pipeline_name, description=description, connection_name=connection_name, + db_connector=db_connector, ) # ======================================================================== diff --git a/tests/test_mage_integration.py b/tests/test_mage_integration.py index d0f38a1..2e6b9fc 100644 --- a/tests/test_mage_integration.py +++ b/tests/test_mage_integration.py @@ -4,12 +4,9 @@ Tests the to_mage_pipeline() method and MageOrchestrator class. """ -from clgraph import Pipeline - +import pytest -def mock_executor(sql: str) -> None: - """Mock executor for testing.""" - pass +from clgraph import Pipeline class TestToMagePipelineBasic: @@ -25,7 +22,6 @@ def test_basic_pipeline_generation(self): ) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="test_pipeline", ) @@ -42,7 +38,6 @@ def test_metadata_has_required_fields(self): ) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="test_pipeline", ) @@ -64,7 +59,6 @@ def test_block_count_matches_queries(self): ) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="test_pipeline", ) @@ -76,7 +70,6 @@ def test_custom_description(self): pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="test_pipeline", description="Custom description for testing", ) @@ -93,7 +86,6 @@ def test_auto_generated_description(self): ) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="test_pipeline", ) @@ -106,7 +98,6 @@ def test_default_pipeline_type(self): pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="test_pipeline", ) @@ -117,7 +108,6 @@ def test_blocks_are_python_code_strings(self): pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="test_pipeline", ) @@ -139,7 +129,6 @@ def test_source_query_is_data_loader(self): ) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="test_pipeline", ) @@ -157,7 +146,6 @@ def test_dependent_query_is_transformer(self): ) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="test_pipeline", ) @@ -175,7 +163,6 @@ def test_data_loader_decorator_in_code(self): ) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="test_pipeline", ) @@ -193,7 +180,6 @@ def 
test_transformer_decorator_in_code(self): ) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="test_pipeline", ) @@ -216,7 +202,6 @@ def test_linear_chain_upstream(self): ) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="test_pipeline", ) @@ -238,7 +223,6 @@ def test_linear_chain_downstream(self): ) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="test_pipeline", ) @@ -264,7 +248,6 @@ def test_diamond_pattern_dependencies(self): ) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="test_pipeline", ) @@ -304,7 +287,6 @@ def test_source_has_empty_upstream(self): ) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="test_pipeline", ) @@ -324,7 +306,6 @@ def test_leaf_has_empty_downstream(self): ) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="test_pipeline", ) @@ -342,7 +323,6 @@ def test_default_connection_name_in_code(self): pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="test_pipeline", ) @@ -354,7 +334,6 @@ def test_custom_connection_name_in_code(self): pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="test_pipeline", connection_name="my_custom_conn", ) @@ -368,7 +347,6 @@ def test_sql_embedded_in_block_code(self): pipeline = Pipeline([("q1", sql)]) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="test_pipeline", ) @@ -376,11 +354,10 @@ def test_sql_embedded_in_block_code(self): assert sql in block_code def test_clickhouse_import_in_block_code(self): - """Test that ClickHouse import is present in block code.""" + """Test that ClickHouse import is present in block code by default.""" pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="test_pipeline", ) @@ -388,6 +365,122 @@ def test_clickhouse_import_in_block_code(self): assert "from mage_ai.io.clickhouse import ClickHouse" in block_code +class TestMageDbConnector: + """Test configurable database connector.""" + + def test_postgres_connector(self): + """Test postgres connector generates correct import and class.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + db_connector="postgres", + connection_name="pg_default", + ) + + block_code = list(result["blocks"].values())[0] + assert "from mage_ai.io.postgres import Postgres" in block_code + assert "Postgres.with_config" in block_code + assert "pg_default" in block_code + + def test_bigquery_connector(self): + """Test bigquery connector generates correct import and class.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + db_connector="bigquery", + connection_name="bq_default", + ) + + block_code = list(result["blocks"].values())[0] + assert "from mage_ai.io.bigquery import BigQuery" in block_code + assert "BigQuery.with_config" in block_code + assert "bq_default" in block_code + + def test_snowflake_connector(self): + """Test snowflake connector generates correct import and class.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + db_connector="snowflake", + 
connection_name="sf_default", + ) + + block_code = list(result["blocks"].values())[0] + assert "from mage_ai.io.snowflake import Snowflake" in block_code + assert "Snowflake.with_config" in block_code + assert "sf_default" in block_code + + def test_default_connector_is_clickhouse(self): + """Test that default db_connector is clickhouse.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + ) + + block_code = list(result["blocks"].values())[0] + assert "ClickHouse.with_config" in block_code + + def test_unsupported_connector_raises_error(self): + """Test that unsupported db_connector raises ValueError.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + with pytest.raises(ValueError, match="Unsupported db_connector"): + pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + db_connector="mysql", + ) + + +class TestMageConnectionNameValidation: + """Test connection_name validation to prevent code injection.""" + + def test_valid_connection_names(self): + """Test that valid connection names are accepted.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + for name in ["clickhouse_default", "my-conn", "conn123", "a_b-c"]: + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + connection_name=name, + ) + block_code = list(result["blocks"].values())[0] + assert name in block_code + + def test_invalid_connection_name_raises_error(self): + """Test that invalid connection_name raises ValueError.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + with pytest.raises(ValueError, match="Invalid connection_name"): + pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + connection_name='"; import os; os.system("rm -rf /")', + ) + + def test_connection_name_with_spaces_raises_error(self): + """Test that connection_name with spaces raises ValueError.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + with pytest.raises(ValueError, match="Invalid connection_name"): + pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + connection_name="my connection", + ) + + def test_connection_name_with_dots_raises_error(self): + """Test that connection_name with dots raises ValueError.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + with pytest.raises(ValueError, match="Invalid connection_name"): + pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + connection_name="my.connection", + ) + + class TestMageOrchestrator: """Test MageOrchestrator class directly.""" @@ -421,7 +514,7 @@ def test_to_blocks_returns_dict(self): pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) orchestrator = MageOrchestrator(pipeline) - result = orchestrator.to_blocks(executor=mock_executor) + result = orchestrator.to_blocks() assert isinstance(result, dict) assert len(result) == 1 @@ -434,7 +527,6 @@ def test_to_pipeline_files_returns_combined_dict(self): orchestrator = MageOrchestrator(pipeline) result = orchestrator.to_pipeline_files( - executor=mock_executor, pipeline_name="test", ) @@ -516,7 +608,6 @@ def test_enterprise_like_pipeline(self): ) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="enterprise_pipeline", ) @@ -560,7 +651,6 @@ def test_10_step_linear_chain(self): pipeline = Pipeline(queries) result = pipeline.to_mage_pipeline( - executor=mock_executor, pipeline_name="long_chain", )