109 changes: 107 additions & 2 deletions src/clgraph/lineage_builder.py
@@ -9,6 +9,7 @@

import sqlglot
from sqlglot import exp
from sqlglot.optimizer import qualify_columns

from .metadata_parser import MetadataExtractor
from .models import (
@@ -463,6 +464,106 @@ def build_parent_map(node: exp.Expression, parent: Optional[exp.Expression] = None
return outermost_nested


def _convert_to_nested_schema(
flat_schema: Dict[str, List[str]],
) -> Dict[str, Dict[str, Dict[str, str]]]:
"""
Convert flat table schema to nested format for sqlglot optimizer.

The sqlglot optimizer.qualify_columns requires a nested schema format:
{
"schema_name": {
"table_name": {
"column_name": "type"
}
}
}

Our flat format is:
{
"schema.table": ["col1", "col2", ...]
}

Args:
flat_schema: Dict mapping "schema.table" to list of column names

Returns:
Nested schema dict suitable for sqlglot optimizer
"""
nested: Dict[str, Dict[str, Dict[str, str]]] = {}

for qualified_table, columns in flat_schema.items():
parts = qualified_table.split(".")

if len(parts) >= 2:
# Has schema prefix: "schema.table" or "catalog.schema.table"
schema_name = parts[-2] # Second to last part
table_name = parts[-1] # Last part
else:
# No schema prefix - use empty string as schema
schema_name = ""
table_name = qualified_table

if schema_name not in nested:
nested[schema_name] = {}

if table_name not in nested[schema_name]:
nested[schema_name][table_name] = {}

for col in columns:
# Use "UNKNOWN" as type since we don't have type info
nested[schema_name][table_name][col] = "UNKNOWN"

return nested
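
For illustration only (not part of the diff): a minimal sketch of what this conversion produces, assuming the package is importable as clgraph:

from clgraph.lineage_builder import _convert_to_nested_schema

flat = {
    "analytics.orders": ["order_id", "order_date"],
    "customers": ["customer_id"],  # no schema prefix
}
nested = _convert_to_nested_schema(flat)
# {
#   "analytics": {"orders": {"order_id": "UNKNOWN", "order_date": "UNKNOWN"}},
#   "": {"customers": {"customer_id": "UNKNOWN"}},
# }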


def _qualify_sql_with_schema(
sql_query: str,
external_table_columns: Dict[str, List[str]],
dialect: str,
) -> str:
"""
Qualify unqualified column references in SQL using schema information.

When a SQL query has multiple tables joined and columns are unqualified
(no table prefix), this function uses the schema to determine which table
each column belongs to and adds the appropriate table prefix.

Args:
sql_query: The SQL query to qualify
external_table_columns: Dict mapping table names to column lists
dialect: SQL dialect for parsing

Returns:
The SQL query with qualified column references
"""
if not external_table_columns:
return sql_query

try:
# Parse the SQL
parsed = sqlglot.parse_one(sql_query, read=dialect)

# Convert to nested schema format
nested_schema = _convert_to_nested_schema(external_table_columns)

# Use sqlglot's qualify_columns to add table prefixes
qualified = qualify_columns.qualify_columns(
parsed,
schema=nested_schema,
dialect=dialect,
infer_schema=True,
)

# Return the qualified SQL
return qualified.sql(dialect=dialect)

except Exception:
# If qualification fails, return original SQL
# The lineage builder will handle unqualified columns as before
return sql_query
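
Again for illustration (not part of the diff), a rough usage sketch of the qualification step; the exact output depends on sqlglot's qualify_columns behavior:

sql = (
    "SELECT order_date, c.name "
    "FROM orders AS o JOIN customers AS c ON o.customer_id = c.id"
)
schema = {
    "orders": ["order_id", "customer_id", "order_date"],
    "customers": ["id", "name"],
}
qualified = _qualify_sql_with_schema(sql, schema, dialect="bigquery")
# The bare "order_date" should come back prefixed with the orders alias, roughly:
# SELECT o.order_date, c.name FROM orders AS o JOIN customers AS c ON o.customer_id = c.id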


# ============================================================================
# Part 1: Recursive Lineage Builder
# ============================================================================
@@ -485,8 +586,12 @@ def __init__(
self.dialect = dialect
self.query_id = query_id

# Parse query structure first
parser = RecursiveQueryParser(sql_query, dialect=dialect)
# Qualify unqualified columns using schema info before parsing
# This ensures columns like "order_date" in a JOIN get the correct table prefix
qualified_sql = _qualify_sql_with_schema(sql_query, self.external_table_columns, dialect)

# Parse query structure using qualified SQL
parser = RecursiveQueryParser(qualified_sql, dialect=dialect)
self.unit_graph = parser.parse()

# Column lineage graph (to be built)
80 changes: 76 additions & 4 deletions src/clgraph/pipeline.py
@@ -77,7 +77,7 @@ def build(self, pipeline_or_graph) -> "Pipeline":
# Step 2a: Run single-query lineage
try:
# Extract SELECT statement from DDL/DML if needed
sql_for_lineage = self._extract_select_from_query(query)
sql_for_lineage = self._extract_select_from_query(query, pipeline.dialect)

if sql_for_lineage:
# Collect upstream table schemas from already-processed queries
@@ -708,23 +708,34 @@ def _is_physical_table_column(

return False

def _extract_select_from_query(self, query: ParsedQuery) -> Optional[str]:
def _extract_select_from_query(
self, query: ParsedQuery, dialect: str = "bigquery"
) -> Optional[str]:
"""
Extract SELECT statement from DDL/DML queries.
Single-query lineage only works on SELECT statements, so we need to extract
the SELECT from CREATE TABLE AS SELECT, INSERT INTO ... SELECT, etc.

Args:
query: The parsed query to extract SELECT from
dialect: SQL dialect for proper SQL serialization (important for functions
like DATE_TRUNC which have different argument orders in different dialects)

Returns:
The SELECT SQL string, or None if no SELECT found
"""
ast = query.ast

# CREATE TABLE/VIEW AS SELECT
if isinstance(ast, exp.Create):
if ast.expression and isinstance(ast.expression, exp.Select):
return ast.expression.sql()
# Use dialect to ensure proper SQL serialization
return ast.expression.sql(dialect=dialect)

# INSERT INTO ... SELECT
elif isinstance(ast, exp.Insert):
if ast.expression and isinstance(ast.expression, exp.Select):
return ast.expression.sql()
return ast.expression.sql(dialect=dialect)

# MERGE INTO statement - pass full SQL to lineage builder
elif isinstance(ast, exp.Merge):
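
To make the dialect note above concrete (illustrative only, not part of the diff): BigQuery's DATE_TRUNC takes the date expression first and the unit second, while most other dialects take the unit first, so serializing the extracted SELECT without a dialect can flip the argument order. A minimal sqlglot sketch:

import sqlglot

ddl = "CREATE TABLE monthly AS SELECT DATE_TRUNC(order_date, MONTH) AS m FROM orders"
ast = sqlglot.parse_one(ddl, read="bigquery")

inner_select = ast.expression  # the SELECT inside CREATE ... AS, as used above
print(inner_select.sql())  # generic dialect; argument order may not match BigQuery
print(inner_select.sql(dialect="bigquery"))  # keeps BigQuery's date-first argument order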
@@ -2595,6 +2606,67 @@ def to_kestra_flow(
**kwargs,
)

# ========================================================================
# Orchestrator Methods - Mage
# ========================================================================

def to_mage_pipeline(
self,
executor: Callable[[str], None],
pipeline_name: str,
description: Optional[str] = None,
connection_name: str = "clickhouse_default",
) -> Dict[str, Any]:
"""
Generate Mage pipeline files from this pipeline.

Mage is a modern data pipeline tool with a notebook-style UI and
block-based architecture. Each SQL query becomes a block (either
data_loader or transformer).

Args:
executor: Function that executes SQL (for code reference)
pipeline_name: Name for the Mage pipeline
description: Optional pipeline description (auto-generated if not provided)
connection_name: Database connection name in Mage io_config.yaml

Returns:
Dictionary with pipeline file structure:
{
"metadata.yaml": <dict>,
"blocks": {"block_name": <code>, ...}
}

Examples:
# Generate Mage pipeline files
files = pipeline.to_mage_pipeline(
executor=execute_sql,
pipeline_name="enterprise_pipeline",
)

# Write files to Mage project
import yaml
with open("pipelines/enterprise_pipeline/metadata.yaml", "w") as f:
yaml.dump(files["metadata.yaml"], f)
for name, code in files["blocks"].items():
with open(f"pipelines/enterprise_pipeline/{name}.py", "w") as f:
f.write(code)

Note:
- First query (no dependencies) becomes data_loader block
- Subsequent queries become transformer blocks
- Dependencies are managed via upstream_blocks/downstream_blocks
- Requires mage-ai package and ClickHouse connection in io_config.yaml
"""
from .orchestrators import MageOrchestrator

return MageOrchestrator(self).to_pipeline_files(
executor=executor,
pipeline_name=pipeline_name,
description=description,
connection_name=connection_name,
)

# ========================================================================
# Validation Methods
# ========================================================================