diff --git a/src/clgraph/orchestrators/__init__.py b/src/clgraph/orchestrators/__init__.py
index b98bc7b..9c93c00 100644
--- a/src/clgraph/orchestrators/__init__.py
+++ b/src/clgraph/orchestrators/__init__.py
@@ -9,6 +9,7 @@
 - Dagster (1.x)
 - Prefect (2.x and 3.x)
 - Kestra (YAML-based declarative workflows)
+- Mage (notebook-style block-based pipelines)
 
 Example:
     from clgraph import Pipeline
@@ -17,6 +18,7 @@
         DagsterOrchestrator,
         PrefectOrchestrator,
         KestraOrchestrator,
+        MageOrchestrator,
     )
 
     pipeline = Pipeline.from_sql_files("queries/", dialect="bigquery")
@@ -36,12 +38,17 @@
     # Generate Kestra flow YAML
     kestra = KestraOrchestrator(pipeline)
     yaml_content = kestra.to_flow(flow_id="my_pipeline", namespace="clgraph")
+
+    # Generate Mage pipeline
+    mage = MageOrchestrator(pipeline)
+    files = mage.to_pipeline_files(pipeline_name="my_pipeline")
 """
 
 from .airflow import AirflowOrchestrator
 from .base import BaseOrchestrator
 from .dagster import DagsterOrchestrator
 from .kestra import KestraOrchestrator
+from .mage import MageOrchestrator
 from .prefect import PrefectOrchestrator
 
 __all__ = [
@@ -49,5 +56,6 @@
     "AirflowOrchestrator",
     "DagsterOrchestrator",
     "KestraOrchestrator",
+    "MageOrchestrator",
     "PrefectOrchestrator",
 ]
diff --git a/src/clgraph/orchestrators/mage.py b/src/clgraph/orchestrators/mage.py
new file mode 100644
index 0000000..3aac17c
--- /dev/null
+++ b/src/clgraph/orchestrators/mage.py
@@ -0,0 +1,304 @@
+"""
+Mage orchestrator integration for clgraph.
+
+Converts clgraph pipelines to Mage pipeline files (metadata.yaml and block Python files).
+Mage is a modern data pipeline tool with a notebook-style UI and block-based architecture.
+"""
+
+import re
+from typing import Any, Dict, List, Optional, Tuple
+
+from .base import BaseOrchestrator
+
+# Mapping of db_connector name to (import_statement, class_name)
+DB_CONNECTOR_MAP: Dict[str, Tuple[str, str]] = {
+    "clickhouse": (
+        "from mage_ai.io.clickhouse import ClickHouse",
+        "ClickHouse",
+    ),
+    "postgres": (
+        "from mage_ai.io.postgres import Postgres",
+        "Postgres",
+    ),
+    "bigquery": (
+        "from mage_ai.io.bigquery import BigQuery",
+        "BigQuery",
+    ),
+    "snowflake": (
+        "from mage_ai.io.snowflake import Snowflake",
+        "Snowflake",
+    ),
+}
+
+_CONNECTION_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+$")
+
+
+class MageOrchestrator(BaseOrchestrator):
+    """
+    Converts clgraph pipelines to Mage pipelines.
+
+    Mage uses a block-based architecture where each SQL query becomes a block
+    (either data_loader or transformer). Dependencies are expressed via
+    upstream_blocks and downstream_blocks.
+
+    Example:
+        from clgraph.orchestrators import MageOrchestrator
+
+        orchestrator = MageOrchestrator(pipeline)
+        files = orchestrator.to_pipeline_files(
+            pipeline_name="my_pipeline",
+        )
+
+        # Write files to Mage project
+        import yaml
+        with open("pipelines/my_pipeline/metadata.yaml", "w") as f:
+            yaml.dump(files["metadata.yaml"], f)
+        for name, code in files["blocks"].items():
+            with open(f"pipelines/my_pipeline/{name}.py", "w") as f:
+                f.write(code)
+    """
+
+    def to_pipeline_config(
+        self,
+        pipeline_name: str,
+        description: Optional[str] = None,
+        pipeline_type: str = "python",
+        **kwargs,
+    ) -> Dict[str, Any]:
+        """
+        Generate Mage pipeline configuration (metadata.yaml content).
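+
+        Blocks with no upstream dependencies are emitted as data_loader
+        blocks; all others are transformers. Illustrative shape of the
+        returned config (values here are examples only):
+
+            {
+                "name": "my_pipeline",
+                "uuid": "my_pipeline",
+                "description": "...",
+                "type": "python",
+                "blocks": [
+                    {
+                        "name": "staging",
+                        "uuid": "staging",
+                        "type": "data_loader",
+                        "upstream_blocks": [],
+                        "downstream_blocks": ["analytics"],
+                    },
+                ],
+            }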
+ + Args: + pipeline_name: Name for the Mage pipeline + description: Optional pipeline description (auto-generated if not provided) + pipeline_type: Pipeline type (default: "python") + **kwargs: Additional configuration + + Returns: + Dictionary representing metadata.yaml content + """ + table_graph = self.table_graph + + if description is None: + query_count = len(table_graph.queries) + table_count = len(table_graph.tables) + description = ( + f"Pipeline with {query_count} queries operating on " + f"{table_count} tables. Generated by clgraph." + ) + + # Build a map of block_name -> list of downstream block names + downstream_map: Dict[str, List[str]] = {} + + block_entries = [] + for query_id in table_graph.topological_sort(): + query = table_graph.queries[query_id] + block_name = self._sanitize_name(query_id) + + upstream_blocks = [] + for source_table in query.source_tables: + if source_table in table_graph.tables: + table_node = table_graph.tables[source_table] + if table_node.created_by: + upstream_blocks.append(self._sanitize_name(table_node.created_by)) + + # Register this block as downstream of each upstream block + for upstream_name in upstream_blocks: + if upstream_name not in downstream_map: + downstream_map[upstream_name] = [] + downstream_map[upstream_name].append(block_name) + + block_entries.append( + { + "name": block_name, + "upstream_blocks": upstream_blocks, + } + ) + + # Build final block configs with downstream_blocks already set + blocks = [ + { + "name": entry["name"], + "uuid": entry["name"], + "type": "data_loader" if not entry["upstream_blocks"] else "transformer", + "upstream_blocks": entry["upstream_blocks"], + "downstream_blocks": list(downstream_map.get(entry["name"], [])), + } + for entry in block_entries + ] + + config = { + "name": pipeline_name, + "uuid": pipeline_name, + "description": description, + "type": pipeline_type, + "blocks": blocks, + **kwargs, + } + + return config + + def to_blocks( + self, + connection_name: str = "clickhouse_default", + db_connector: str = "clickhouse", + ) -> Dict[str, str]: + """ + Generate Mage block Python files. 
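+
+        Each block embeds its SQL and executes it through the selected
+        mage_ai.io connector. Illustrative usage (names are examples only):
+
+            blocks = orchestrator.to_blocks(
+                connection_name="pg_default",
+                db_connector="postgres",
+            )
+            # blocks["staging"] -> Python source for a data_loader block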
+ + Args: + connection_name: Name of database connection in Mage io_config.yaml + db_connector: Database connector type (clickhouse, postgres, bigquery, snowflake) + + Returns: + Dictionary mapping block_name -> block_code + """ + table_graph = self.table_graph + blocks = {} + + for query_id in table_graph.topological_sort(): + query = table_graph.queries[query_id] + block_name = self._sanitize_name(query_id) + + # Determine upstream blocks + upstream_blocks = [] + for source_table in query.source_tables: + if source_table in table_graph.tables: + table_node = table_graph.tables[source_table] + if table_node.created_by: + upstream_blocks.append(self._sanitize_name(table_node.created_by)) + + # Determine block type + block_type = "data_loader" if not upstream_blocks else "transformer" + + # Generate block code + code = self._generate_block_code( + block_name=block_name, + block_type=block_type, + sql=query.sql, + query_id=query_id, + upstream_blocks=upstream_blocks, + connection_name=connection_name, + db_connector=db_connector, + ) + + blocks[block_name] = code + + return blocks + + def _generate_block_code( + self, + block_name: str, + block_type: str, + sql: str, + query_id: str, + upstream_blocks: List[str], + connection_name: str, + db_connector: str = "clickhouse", + ) -> str: + """Generate Python code for a Mage block.""" + + # Validate connection_name to prevent code injection + if not _CONNECTION_NAME_PATTERN.match(connection_name): + raise ValueError( + f"Invalid connection_name '{connection_name}': " + "must contain only alphanumeric characters, underscores, and hyphens" + ) + + # Resolve db connector + if db_connector not in DB_CONNECTOR_MAP: + raise ValueError( + f"Unsupported db_connector '{db_connector}'. " + f"Supported: {', '.join(sorted(DB_CONNECTOR_MAP))}" + ) + connector_import, connector_class = DB_CONNECTOR_MAP[db_connector] + + # Mage requires specific function names for each block type + if block_type == "data_loader": + decorator = "@data_loader" + func_name = "load_data" + imports = "from mage_ai.data_preparation.decorators import data_loader" + else: + decorator = "@transformer" + func_name = "transform" + imports = "from mage_ai.data_preparation.decorators import transformer" + + # Build function signature + if upstream_blocks: + args = ", ".join([f"data_{i}" for i in range(len(upstream_blocks))]) + func_args = f"({args}, *args, **kwargs)" + else: + func_args = "(*args, **kwargs)" + + # Escape triple quotes in SQL if present + escaped_sql = sql.replace('"""', '\\"\\"\\"') + + code = f'''""" +Block: {block_name} +Query ID: {query_id} +Type: {block_type} +Generated by clgraph +""" +{imports} +{connector_import} + + +{decorator} +def {func_name}{func_args}: + """ + Execute SQL query via {connector_class}. + """ + sql = """ +{escaped_sql} +""" + + with {connector_class}.with_config(config_profile="{connection_name}") as loader: + loader.execute(sql) + + return {{"status": "success", "query_id": "{query_id}"}} +''' + return code + + def to_pipeline_files( + self, + pipeline_name: str, + description: Optional[str] = None, + connection_name: str = "clickhouse_default", + db_connector: str = "clickhouse", + ) -> Dict[str, Any]: + """ + Generate complete Mage pipeline file structure. 
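+
+        Combines to_pipeline_config() and to_blocks() into a single bundle
+        that can be written out as a Mage pipeline directory.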
+ + Args: + pipeline_name: Name for the Mage pipeline + description: Optional pipeline description + connection_name: Database connection name in Mage io_config.yaml + db_connector: Database connector type (clickhouse, postgres, bigquery, snowflake) + + Returns: + Dictionary with file structure: + { + "metadata.yaml": , + "blocks": { + "block1.py": , + "block2.py": , + } + } + """ + config = self.to_pipeline_config( + pipeline_name=pipeline_name, + description=description, + ) + + blocks = self.to_blocks( + connection_name=connection_name, + db_connector=db_connector, + ) + + return { + "metadata.yaml": config, + "blocks": blocks, + } + + +__all__ = ["MageOrchestrator"] diff --git a/src/clgraph/pipeline.py b/src/clgraph/pipeline.py index beaf54f..629157c 100644 --- a/src/clgraph/pipeline.py +++ b/src/clgraph/pipeline.py @@ -2612,10 +2612,10 @@ def to_kestra_flow( def to_mage_pipeline( self, - executor: Callable[[str], None], pipeline_name: str, description: Optional[str] = None, connection_name: str = "clickhouse_default", + db_connector: str = "clickhouse", ) -> Dict[str, Any]: """ Generate Mage pipeline files from this pipeline. @@ -2625,10 +2625,10 @@ def to_mage_pipeline( data_loader or transformer). Args: - executor: Function that executes SQL (for code reference) pipeline_name: Name for the Mage pipeline description: Optional pipeline description (auto-generated if not provided) connection_name: Database connection name in Mage io_config.yaml + db_connector: Database connector type (clickhouse, postgres, bigquery, snowflake) Returns: Dictionary with pipeline file structure: @@ -2640,7 +2640,6 @@ def to_mage_pipeline( Examples: # Generate Mage pipeline files files = pipeline.to_mage_pipeline( - executor=execute_sql, pipeline_name="enterprise_pipeline", ) @@ -2656,15 +2655,15 @@ def to_mage_pipeline( - First query (no dependencies) becomes data_loader block - Subsequent queries become transformer blocks - Dependencies are managed via upstream_blocks/downstream_blocks - - Requires mage-ai package and ClickHouse connection in io_config.yaml + - Requires mage-ai package and database connection in io_config.yaml """ from .orchestrators import MageOrchestrator return MageOrchestrator(self).to_pipeline_files( - executor=executor, pipeline_name=pipeline_name, description=description, connection_name=connection_name, + db_connector=db_connector, ) # ======================================================================== diff --git a/tests/test_mage_integration.py b/tests/test_mage_integration.py new file mode 100644 index 0000000..2e6b9fc --- /dev/null +++ b/tests/test_mage_integration.py @@ -0,0 +1,674 @@ +""" +Tests for Mage orchestrator integration. + +Tests the to_mage_pipeline() method and MageOrchestrator class. 
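+
+These tests only inspect the generated metadata and block source strings;
+they do not execute SQL and do not require mage-ai to be installed.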
+""" + +import pytest + +from clgraph import Pipeline + + +class TestToMagePipelineBasic: + """Basic tests for to_mage_pipeline method.""" + + def test_basic_pipeline_generation(self): + """Test basic Mage pipeline generation.""" + pipeline = Pipeline( + [ + ("staging", "CREATE TABLE staging AS SELECT 1 as id"), + ("analytics", "CREATE TABLE analytics AS SELECT * FROM staging"), + ] + ) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + ) + + assert "metadata.yaml" in result + assert "blocks" in result + + def test_metadata_has_required_fields(self): + """Test metadata.yaml has required fields.""" + pipeline = Pipeline( + [ + ("staging", "CREATE TABLE staging AS SELECT 1 as id"), + ("analytics", "CREATE TABLE analytics AS SELECT * FROM staging"), + ] + ) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + ) + + metadata = result["metadata.yaml"] + assert "name" in metadata + assert "uuid" in metadata + assert "description" in metadata + assert "type" in metadata + assert "blocks" in metadata + + def test_block_count_matches_queries(self): + """Test that block count matches number of queries.""" + pipeline = Pipeline( + [ + ("q1", "CREATE TABLE t1 AS SELECT 1"), + ("q2", "CREATE TABLE t2 AS SELECT * FROM t1"), + ("q3", "CREATE TABLE t3 AS SELECT * FROM t2"), + ] + ) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + ) + + assert len(result["metadata.yaml"]["blocks"]) == 3 + assert len(result["blocks"]) == 3 + + def test_custom_description(self): + """Test pipeline with custom description.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + description="Custom description for testing", + ) + + assert result["metadata.yaml"]["description"] == "Custom description for testing" + + def test_auto_generated_description(self): + """Test auto-generated description.""" + pipeline = Pipeline( + [ + ("q1", "CREATE TABLE t1 AS SELECT 1"), + ("q2", "CREATE TABLE t2 AS SELECT * FROM t1"), + ] + ) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + ) + + description = result["metadata.yaml"]["description"] + assert "2 queries" in description + assert "clgraph" in description + + def test_default_pipeline_type(self): + """Test default pipeline type is python.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + ) + + assert result["metadata.yaml"]["type"] == "python" + + def test_blocks_are_python_code_strings(self): + """Test that block values are Python code strings.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + ) + + for block_code in result["blocks"].values(): + assert isinstance(block_code, str) + assert "def " in block_code + + +class TestMageBlockTypes: + """Test block type assignment.""" + + def test_source_query_is_data_loader(self): + """Test that source query (no upstream) gets data_loader type.""" + pipeline = Pipeline( + [ + ("source", "CREATE TABLE source AS SELECT 1 as id"), + ("derived", "CREATE TABLE derived AS SELECT * FROM source"), + ] + ) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + ) + + blocks = result["metadata.yaml"]["blocks"] + source_block = next(b for b in blocks if b["name"] == "source") + assert source_block["type"] == "data_loader" + + def test_dependent_query_is_transformer(self): + 
"""Test that dependent query gets transformer type.""" + pipeline = Pipeline( + [ + ("source", "CREATE TABLE source AS SELECT 1 as id"), + ("derived", "CREATE TABLE derived AS SELECT * FROM source"), + ] + ) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + ) + + blocks = result["metadata.yaml"]["blocks"] + derived_block = next(b for b in blocks if b["name"] == "derived") + assert derived_block["type"] == "transformer" + + def test_data_loader_decorator_in_code(self): + """Test that data_loader block code contains correct decorator.""" + pipeline = Pipeline( + [ + ("source", "CREATE TABLE source AS SELECT 1 as id"), + ("derived", "CREATE TABLE derived AS SELECT * FROM source"), + ] + ) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + ) + + source_code = result["blocks"]["source"] + assert "@data_loader" in source_code + assert "def load_data" in source_code + + def test_transformer_decorator_in_code(self): + """Test that transformer block code contains correct decorator.""" + pipeline = Pipeline( + [ + ("source", "CREATE TABLE source AS SELECT 1 as id"), + ("derived", "CREATE TABLE derived AS SELECT * FROM source"), + ] + ) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + ) + + derived_code = result["blocks"]["derived"] + assert "@transformer" in derived_code + assert "def transform" in derived_code + + +class TestMageDependencies: + """Test upstream/downstream block wiring.""" + + def test_linear_chain_upstream(self): + """Test linear chain has correct upstream_blocks.""" + pipeline = Pipeline( + [ + ("step1", "CREATE TABLE step1 AS SELECT 1"), + ("step2", "CREATE TABLE step2 AS SELECT * FROM step1"), + ("step3", "CREATE TABLE step3 AS SELECT * FROM step2"), + ] + ) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + ) + + blocks = result["metadata.yaml"]["blocks"] + block_map = {b["name"]: b for b in blocks} + + assert block_map["step1"]["upstream_blocks"] == [] + assert block_map["step2"]["upstream_blocks"] == ["step1"] + assert block_map["step3"]["upstream_blocks"] == ["step2"] + + def test_linear_chain_downstream(self): + """Test linear chain has correct downstream_blocks.""" + pipeline = Pipeline( + [ + ("step1", "CREATE TABLE step1 AS SELECT 1"), + ("step2", "CREATE TABLE step2 AS SELECT * FROM step1"), + ("step3", "CREATE TABLE step3 AS SELECT * FROM step2"), + ] + ) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + ) + + blocks = result["metadata.yaml"]["blocks"] + block_map = {b["name"]: b for b in blocks} + + assert block_map["step1"]["downstream_blocks"] == ["step2"] + assert block_map["step2"]["downstream_blocks"] == ["step3"] + assert block_map["step3"]["downstream_blocks"] == [] + + def test_diamond_pattern_dependencies(self): + """Test diamond pattern wiring for all 4 blocks.""" + pipeline = Pipeline( + [ + ("source", "CREATE TABLE source AS SELECT 1 as id"), + ("left_branch", "CREATE TABLE left_branch AS SELECT * FROM source"), + ("right_branch", "CREATE TABLE right_branch AS SELECT * FROM source"), + ( + "final", + "CREATE TABLE final AS SELECT * FROM left_branch, right_branch", + ), + ] + ) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + ) + + blocks = result["metadata.yaml"]["blocks"] + block_map = {b["name"]: b for b in blocks} + + # Source has no upstream + assert block_map["source"]["upstream_blocks"] == [] + # Branches depend on source + assert block_map["left_branch"]["upstream_blocks"] == ["source"] + assert 
block_map["right_branch"]["upstream_blocks"] == ["source"] + # Final depends on both branches + assert sorted(block_map["final"]["upstream_blocks"]) == [ + "left_branch", + "right_branch", + ] + + # Source has both branches as downstream + assert sorted(block_map["source"]["downstream_blocks"]) == [ + "left_branch", + "right_branch", + ] + # Final has no downstream + assert block_map["final"]["downstream_blocks"] == [] + + def test_source_has_empty_upstream(self): + """Test that source blocks have empty upstream_blocks.""" + pipeline = Pipeline( + [ + ("source1", "CREATE TABLE source1 AS SELECT 1"), + ("source2", "CREATE TABLE source2 AS SELECT 2"), + ( + "combined", + "CREATE TABLE combined AS SELECT * FROM source1, source2", + ), + ] + ) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + ) + + blocks = result["metadata.yaml"]["blocks"] + block_map = {b["name"]: b for b in blocks} + + assert block_map["source1"]["upstream_blocks"] == [] + assert block_map["source2"]["upstream_blocks"] == [] + + def test_leaf_has_empty_downstream(self): + """Test that leaf blocks have empty downstream_blocks.""" + pipeline = Pipeline( + [ + ("source", "CREATE TABLE source AS SELECT 1"), + ("leaf", "CREATE TABLE leaf AS SELECT * FROM source"), + ] + ) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + ) + + blocks = result["metadata.yaml"]["blocks"] + block_map = {b["name"]: b for b in blocks} + + assert block_map["leaf"]["downstream_blocks"] == [] + + +class TestMageConfiguration: + """Test connection and configuration handling.""" + + def test_default_connection_name_in_code(self): + """Test default connection name clickhouse_default in block code.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + ) + + block_code = list(result["blocks"].values())[0] + assert "clickhouse_default" in block_code + + def test_custom_connection_name_in_code(self): + """Test custom connection name embedded in block code.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + connection_name="my_custom_conn", + ) + + block_code = list(result["blocks"].values())[0] + assert "my_custom_conn" in block_code + + def test_sql_embedded_in_block_code(self): + """Test that SQL is embedded in block code.""" + sql = "CREATE TABLE test_table AS SELECT 42 as value" + pipeline = Pipeline([("q1", sql)]) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + ) + + block_code = list(result["blocks"].values())[0] + assert sql in block_code + + def test_clickhouse_import_in_block_code(self): + """Test that ClickHouse import is present in block code by default.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + ) + + block_code = list(result["blocks"].values())[0] + assert "from mage_ai.io.clickhouse import ClickHouse" in block_code + + +class TestMageDbConnector: + """Test configurable database connector.""" + + def test_postgres_connector(self): + """Test postgres connector generates correct import and class.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + db_connector="postgres", + connection_name="pg_default", + ) + + block_code = list(result["blocks"].values())[0] + assert "from mage_ai.io.postgres import Postgres" in 
block_code + assert "Postgres.with_config" in block_code + assert "pg_default" in block_code + + def test_bigquery_connector(self): + """Test bigquery connector generates correct import and class.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + db_connector="bigquery", + connection_name="bq_default", + ) + + block_code = list(result["blocks"].values())[0] + assert "from mage_ai.io.bigquery import BigQuery" in block_code + assert "BigQuery.with_config" in block_code + assert "bq_default" in block_code + + def test_snowflake_connector(self): + """Test snowflake connector generates correct import and class.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + db_connector="snowflake", + connection_name="sf_default", + ) + + block_code = list(result["blocks"].values())[0] + assert "from mage_ai.io.snowflake import Snowflake" in block_code + assert "Snowflake.with_config" in block_code + assert "sf_default" in block_code + + def test_default_connector_is_clickhouse(self): + """Test that default db_connector is clickhouse.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + ) + + block_code = list(result["blocks"].values())[0] + assert "ClickHouse.with_config" in block_code + + def test_unsupported_connector_raises_error(self): + """Test that unsupported db_connector raises ValueError.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + with pytest.raises(ValueError, match="Unsupported db_connector"): + pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + db_connector="mysql", + ) + + +class TestMageConnectionNameValidation: + """Test connection_name validation to prevent code injection.""" + + def test_valid_connection_names(self): + """Test that valid connection names are accepted.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + for name in ["clickhouse_default", "my-conn", "conn123", "a_b-c"]: + result = pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + connection_name=name, + ) + block_code = list(result["blocks"].values())[0] + assert name in block_code + + def test_invalid_connection_name_raises_error(self): + """Test that invalid connection_name raises ValueError.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + with pytest.raises(ValueError, match="Invalid connection_name"): + pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + connection_name='"; import os; os.system("rm -rf /")', + ) + + def test_connection_name_with_spaces_raises_error(self): + """Test that connection_name with spaces raises ValueError.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + with pytest.raises(ValueError, match="Invalid connection_name"): + pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + connection_name="my connection", + ) + + def test_connection_name_with_dots_raises_error(self): + """Test that connection_name with dots raises ValueError.""" + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + + with pytest.raises(ValueError, match="Invalid connection_name"): + pipeline.to_mage_pipeline( + pipeline_name="test_pipeline", + connection_name="my.connection", + ) + + +class TestMageOrchestrator: + """Test MageOrchestrator class directly.""" + + def test_orchestrator_initialization(self): + """Test 
MageOrchestrator initialization.""" + from clgraph.orchestrators import MageOrchestrator + + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + orchestrator = MageOrchestrator(pipeline) + + assert orchestrator.pipeline == pipeline + assert orchestrator.table_graph == pipeline.table_graph + + def test_to_pipeline_config_returns_dict(self): + """Test to_pipeline_config returns dictionary.""" + from clgraph.orchestrators import MageOrchestrator + + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + orchestrator = MageOrchestrator(pipeline) + + result = orchestrator.to_pipeline_config(pipeline_name="test") + + assert isinstance(result, dict) + assert result["name"] == "test" + assert result["uuid"] == "test" + + def test_to_blocks_returns_dict(self): + """Test to_blocks returns dictionary.""" + from clgraph.orchestrators import MageOrchestrator + + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + orchestrator = MageOrchestrator(pipeline) + + result = orchestrator.to_blocks() + + assert isinstance(result, dict) + assert len(result) == 1 + + def test_to_pipeline_files_returns_combined_dict(self): + """Test to_pipeline_files returns combined dictionary.""" + from clgraph.orchestrators import MageOrchestrator + + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + orchestrator = MageOrchestrator(pipeline) + + result = orchestrator.to_pipeline_files( + pipeline_name="test", + ) + + assert "metadata.yaml" in result + assert "blocks" in result + assert isinstance(result["metadata.yaml"], dict) + assert isinstance(result["blocks"], dict) + + def test_sanitize_name(self): + """Test _sanitize_name works correctly.""" + from clgraph.orchestrators import MageOrchestrator + + pipeline = Pipeline([("q1", "CREATE TABLE t1 AS SELECT 1")]) + orchestrator = MageOrchestrator(pipeline) + + assert orchestrator._sanitize_name("my.table") == "my_table" + assert orchestrator._sanitize_name("my-table") == "my_table" + assert orchestrator._sanitize_name("my_table") == "my_table" + + +class TestMageComplexPipeline: + """Test Mage pipeline generation with complex pipelines.""" + + def test_enterprise_like_pipeline(self): + """Test with a pipeline similar to enterprise demo.""" + pipeline = Pipeline( + [ + ( + "raw_sales", + """ + CREATE TABLE raw_sales AS + SELECT + toDate('2024-01-01') + number as date, + number % 100 as product_id, + number % 10 as region_id, + rand() % 1000 as amount + FROM numbers(1000) + """, + ), + ( + "raw_products", + """ + CREATE TABLE raw_products AS + SELECT + number as product_id, + concat('Product ', toString(number)) as product_name, + rand() % 5 as category_id + FROM numbers(100) + """, + ), + ( + "sales_with_products", + """ + CREATE TABLE sales_with_products AS + SELECT + s.date, + s.product_id, + p.product_name, + s.region_id, + s.amount + FROM raw_sales s + JOIN raw_products p ON s.product_id = p.product_id + """, + ), + ( + "daily_summary", + """ + CREATE TABLE daily_summary AS + SELECT + date, + count() as num_sales, + sum(amount) as total_amount + FROM sales_with_products + GROUP BY date + """, + ), + ], + dialect="clickhouse", + ) + + result = pipeline.to_mage_pipeline( + pipeline_name="enterprise_pipeline", + ) + + metadata = result["metadata.yaml"] + assert metadata["name"] == "enterprise_pipeline" + assert len(metadata["blocks"]) == 4 + + block_map = {b["name"]: b for b in metadata["blocks"]} + + # Source blocks are data_loaders + assert block_map["raw_sales"]["type"] == "data_loader" + assert block_map["raw_products"]["type"] == 
"data_loader" + + # Dependent blocks are transformers + assert block_map["sales_with_products"]["type"] == "transformer" + assert block_map["daily_summary"]["type"] == "transformer" + + # sales_with_products depends on both raw tables + assert sorted(block_map["sales_with_products"]["upstream_blocks"]) == [ + "raw_products", + "raw_sales", + ] + + # daily_summary depends on sales_with_products + assert block_map["daily_summary"]["upstream_blocks"] == ["sales_with_products"] + + def test_10_step_linear_chain(self): + """Test 10-step linear chain has correct sequential dependencies.""" + queries = [] + for i in range(10): + if i == 0: + queries.append((f"step_{i}", f"CREATE TABLE step_{i} AS SELECT {i}")) + else: + queries.append( + ( + f"step_{i}", + f"CREATE TABLE step_{i} AS SELECT * FROM step_{i - 1}", + ) + ) + + pipeline = Pipeline(queries) + + result = pipeline.to_mage_pipeline( + pipeline_name="long_chain", + ) + + metadata = result["metadata.yaml"] + assert len(metadata["blocks"]) == 10 + + block_map = {b["name"]: b for b in metadata["blocks"]} + + # First step is data_loader with no upstream + assert block_map["step_0"]["type"] == "data_loader" + assert block_map["step_0"]["upstream_blocks"] == [] + + # All other steps are transformers with correct upstream + for i in range(1, 10): + assert block_map[f"step_{i}"]["type"] == "transformer" + assert block_map[f"step_{i}"]["upstream_blocks"] == [f"step_{i - 1}"] + + # Check downstream wiring + for i in range(9): + assert block_map[f"step_{i}"]["downstream_blocks"] == [f"step_{i + 1}"] + assert block_map["step_9"]["downstream_blocks"] == []