From 0105501f043d2a908ff21dc17c69c81a8f2e9d6c Mon Sep 17 00:00:00 2001 From: Ming-Jer Lee Date: Fri, 16 Jan 2026 19:23:10 -0800 Subject: [PATCH] fix: Resolve unqualified columns in JOIN queries using schema info When a SQL query joins multiple tables and columns are unqualified (no table prefix), the lineage builder previously defaulted to the first table, which was often incorrect. For example, in: SELECT DATE_TRUNC(order_date, MONTH) as month FROM analytics.user_metrics JOIN staging.user_orders USING (user_id) The `order_date` column would be incorrectly attributed to `user_metrics` instead of `user_orders`, causing lineage edges to be dropped. This fix: 1. Uses sqlglot's qualify_columns optimizer with upstream table schemas to add correct table prefixes before building lineage 2. Fixes _extract_select_from_query to use dialect when serializing SQL, which was causing DATE_TRUNC arguments to be reordered incorrectly The 3-layer pipeline example now correctly shows: - staging.user_orders.order_date -> reports.monthly_revenue.month - trace_column_backward returns raw.orders.order_date Co-Authored-By: Claude Opus 4.5 --- src/clgraph/lineage_builder.py | 109 +++++- src/clgraph/pipeline.py | 80 ++++- tests/test_unqualified_column_resolution.py | 369 ++++++++++++++++++++ 3 files changed, 552 insertions(+), 6 deletions(-) create mode 100644 tests/test_unqualified_column_resolution.py diff --git a/src/clgraph/lineage_builder.py b/src/clgraph/lineage_builder.py index 72f3253..086f6cc 100644 --- a/src/clgraph/lineage_builder.py +++ b/src/clgraph/lineage_builder.py @@ -9,6 +9,7 @@ import sqlglot from sqlglot import exp +from sqlglot.optimizer import qualify_columns from .metadata_parser import MetadataExtractor from .models import ( @@ -463,6 +464,106 @@ def build_parent_map(node: exp.Expression, parent: Optional[exp.Expression] = No return outermost_nested +def _convert_to_nested_schema( + flat_schema: Dict[str, List[str]], +) -> Dict[str, Dict[str, Dict[str, str]]]: + """ + Convert flat table schema to nested format for sqlglot optimizer. + + The sqlglot optimizer.qualify_columns requires a nested schema format: + { + "schema_name": { + "table_name": { + "column_name": "type" + } + } + } + + Our flat format is: + { + "schema.table": ["col1", "col2", ...] + } + + Args: + flat_schema: Dict mapping "schema.table" to list of column names + + Returns: + Nested schema dict suitable for sqlglot optimizer + """ + nested: Dict[str, Dict[str, Dict[str, str]]] = {} + + for qualified_table, columns in flat_schema.items(): + parts = qualified_table.split(".") + + if len(parts) >= 2: + # Has schema prefix: "schema.table" or "catalog.schema.table" + schema_name = parts[-2] # Second to last part + table_name = parts[-1] # Last part + else: + # No schema prefix - use empty string as schema + schema_name = "" + table_name = qualified_table + + if schema_name not in nested: + nested[schema_name] = {} + + if table_name not in nested[schema_name]: + nested[schema_name][table_name] = {} + + for col in columns: + # Use "UNKNOWN" as type since we don't have type info + nested[schema_name][table_name][col] = "UNKNOWN" + + return nested + + +def _qualify_sql_with_schema( + sql_query: str, + external_table_columns: Dict[str, List[str]], + dialect: str, +) -> str: + """ + Qualify unqualified column references in SQL using schema information. 
+ + When a SQL query has multiple tables joined and columns are unqualified + (no table prefix), this function uses the schema to determine which table + each column belongs to and adds the appropriate table prefix. + + Args: + sql_query: The SQL query to qualify + external_table_columns: Dict mapping table names to column lists + dialect: SQL dialect for parsing + + Returns: + The SQL query with qualified column references + """ + if not external_table_columns: + return sql_query + + try: + # Parse the SQL + parsed = sqlglot.parse_one(sql_query, read=dialect) + + # Convert to nested schema format + nested_schema = _convert_to_nested_schema(external_table_columns) + + # Use sqlglot's qualify_columns to add table prefixes + qualified = qualify_columns.qualify_columns( + parsed, + schema=nested_schema, + dialect=dialect, + infer_schema=True, + ) + + # Return the qualified SQL + return qualified.sql(dialect=dialect) + + except Exception: + # If qualification fails, return original SQL + # The lineage builder will handle unqualified columns as before + return sql_query + + # ============================================================================ # Part 1: Recursive Lineage Builder # ============================================================================ @@ -485,8 +586,12 @@ def __init__( self.dialect = dialect self.query_id = query_id - # Parse query structure first - parser = RecursiveQueryParser(sql_query, dialect=dialect) + # Qualify unqualified columns using schema info before parsing + # This ensures columns like "order_date" in a JOIN get the correct table prefix + qualified_sql = _qualify_sql_with_schema(sql_query, self.external_table_columns, dialect) + + # Parse query structure using qualified SQL + parser = RecursiveQueryParser(qualified_sql, dialect=dialect) self.unit_graph = parser.parse() # Column lineage graph (to be built) diff --git a/src/clgraph/pipeline.py b/src/clgraph/pipeline.py index edc6192..beaf54f 100644 --- a/src/clgraph/pipeline.py +++ b/src/clgraph/pipeline.py @@ -77,7 +77,7 @@ def build(self, pipeline_or_graph) -> "Pipeline": # Step 2a: Run single-query lineage try: # Extract SELECT statement from DDL/DML if needed - sql_for_lineage = self._extract_select_from_query(query) + sql_for_lineage = self._extract_select_from_query(query, pipeline.dialect) if sql_for_lineage: # Collect upstream table schemas from already-processed queries @@ -708,23 +708,34 @@ def _is_physical_table_column( return False - def _extract_select_from_query(self, query: ParsedQuery) -> Optional[str]: + def _extract_select_from_query( + self, query: ParsedQuery, dialect: str = "bigquery" + ) -> Optional[str]: """ Extract SELECT statement from DDL/DML queries. Single-query lineage only works on SELECT statements, so we need to extract the SELECT from CREATE TABLE AS SELECT, INSERT INTO ... SELECT, etc. + + Args: + query: The parsed query to extract SELECT from + dialect: SQL dialect for proper SQL serialization (important for functions + like DATE_TRUNC which have different argument orders in different dialects) + + Returns: + The SELECT SQL string, or None if no SELECT found """ ast = query.ast # CREATE TABLE/VIEW AS SELECT if isinstance(ast, exp.Create): if ast.expression and isinstance(ast.expression, exp.Select): - return ast.expression.sql() + # Use dialect to ensure proper SQL serialization + return ast.expression.sql(dialect=dialect) # INSERT INTO ... 
SELECT elif isinstance(ast, exp.Insert): if ast.expression and isinstance(ast.expression, exp.Select): - return ast.expression.sql() + return ast.expression.sql(dialect=dialect) # MERGE INTO statement - pass full SQL to lineage builder elif isinstance(ast, exp.Merge): @@ -2595,6 +2606,67 @@ def to_kestra_flow( **kwargs, ) + # ======================================================================== + # Orchestrator Methods - Mage + # ======================================================================== + + def to_mage_pipeline( + self, + executor: Callable[[str], None], + pipeline_name: str, + description: Optional[str] = None, + connection_name: str = "clickhouse_default", + ) -> Dict[str, Any]: + """ + Generate Mage pipeline files from this pipeline. + + Mage is a modern data pipeline tool with a notebook-style UI and + block-based architecture. Each SQL query becomes a block (either + data_loader or transformer). + + Args: + executor: Function that executes SQL (for code reference) + pipeline_name: Name for the Mage pipeline + description: Optional pipeline description (auto-generated if not provided) + connection_name: Database connection name in Mage io_config.yaml + + Returns: + Dictionary with pipeline file structure: + { + "metadata.yaml": , + "blocks": {"block_name": , ...} + } + + Examples: + # Generate Mage pipeline files + files = pipeline.to_mage_pipeline( + executor=execute_sql, + pipeline_name="enterprise_pipeline", + ) + + # Write files to Mage project + import yaml + with open("pipelines/enterprise_pipeline/metadata.yaml", "w") as f: + yaml.dump(files["metadata.yaml"], f) + for name, code in files["blocks"].items(): + with open(f"pipelines/enterprise_pipeline/{name}.py", "w") as f: + f.write(code) + + Note: + - First query (no dependencies) becomes data_loader block + - Subsequent queries become transformer blocks + - Dependencies are managed via upstream_blocks/downstream_blocks + - Requires mage-ai package and ClickHouse connection in io_config.yaml + """ + from .orchestrators import MageOrchestrator + + return MageOrchestrator(self).to_pipeline_files( + executor=executor, + pipeline_name=pipeline_name, + description=description, + connection_name=connection_name, + ) + # ======================================================================== # Validation Methods # ======================================================================== diff --git a/tests/test_unqualified_column_resolution.py b/tests/test_unqualified_column_resolution.py new file mode 100644 index 0000000..a0ff60c --- /dev/null +++ b/tests/test_unqualified_column_resolution.py @@ -0,0 +1,369 @@ +""" +Tests for Unqualified Column Resolution in Multi-Table Queries + +This test suite covers the fix for resolving unqualified column references +when multiple tables are joined. The fix uses sqlglot's qualify_columns +optimizer with schema information to determine which table each column +belongs to. + +Issue: When a query has multiple tables joined and columns are unqualified +(no table prefix), the lineage builder needs to determine which table +each column comes from. Previously, it would default to the first table, +which was often wrong. + +Fix: Use schema information from upstream queries to qualify columns +before building lineage. 
+""" + +from clgraph import Pipeline +from clgraph.lineage_builder import ( + RecursiveLineageBuilder, + _convert_to_nested_schema, + _qualify_sql_with_schema, +) + +# ============================================================================ +# Part 1: Helper Function Tests +# ============================================================================ + + +class TestSchemaConversion: + """Test the flat-to-nested schema conversion helper""" + + def test_convert_simple_schema(self): + """Test converting a simple flat schema to nested format""" + flat = { + "staging.orders": ["order_id", "user_id", "amount"], + "analytics.users": ["user_id", "name", "email"], + } + + nested = _convert_to_nested_schema(flat) + + assert "staging" in nested + assert "analytics" in nested + assert "orders" in nested["staging"] + assert "users" in nested["analytics"] + assert "order_id" in nested["staging"]["orders"] + assert "user_id" in nested["analytics"]["users"] + + def test_convert_schema_without_prefix(self): + """Test converting schema without schema prefix""" + flat = { + "orders": ["order_id", "user_id"], + } + + nested = _convert_to_nested_schema(flat) + + # Should use empty string as schema + assert "" in nested + assert "orders" in nested[""] + + def test_convert_three_part_name(self): + """Test converting three-part table names (catalog.schema.table)""" + flat = { + "myproject.staging.orders": ["order_id", "amount"], + } + + nested = _convert_to_nested_schema(flat) + + # Should use the last two parts + assert "staging" in nested + assert "orders" in nested["staging"] + + +class TestSqlQualification: + """Test the SQL qualification helper""" + + def test_qualify_single_table(self): + """Test that single-table queries are unchanged""" + sql = "SELECT order_id, amount FROM orders" + schema = {"orders": ["order_id", "amount"]} + + result = _qualify_sql_with_schema(sql, schema, "bigquery") + + # Should still work (columns may be qualified) + assert "order_id" in result + assert "amount" in result + + def test_qualify_multi_table_join(self): + """Test qualification with multiple tables joined""" + sql = """ + SELECT order_date, total_revenue + FROM analytics.user_metrics + JOIN staging.user_orders USING (user_id) + """ + schema = { + "staging.user_orders": ["user_id", "order_date", "amount"], + "analytics.user_metrics": ["user_id", "total_revenue"], + } + + result = _qualify_sql_with_schema(sql, schema, "bigquery") + + # order_date should be qualified with user_orders + assert "user_orders" in result and "order_date" in result + # total_revenue should be qualified with user_metrics + assert "user_metrics" in result and "total_revenue" in result + + def test_qualify_empty_schema(self): + """Test that empty schema returns original SQL""" + sql = "SELECT order_id FROM orders" + + result = _qualify_sql_with_schema(sql, {}, "bigquery") + + assert result == sql + + def test_qualify_date_trunc_bigquery(self): + """Test DATE_TRUNC qualification in BigQuery dialect""" + sql = "SELECT DATE_TRUNC(order_date, MONTH) as month FROM staging.orders" + schema = { + "staging.orders": ["order_id", "order_date"], + } + + result = _qualify_sql_with_schema(sql, schema, "bigquery") + + # Should preserve DATE_TRUNC with proper column + assert "DATE_TRUNC" in result.upper() + assert "order_date" in result.lower() + + +# ============================================================================ +# Part 2: RecursiveLineageBuilder Tests +# ============================================================================ + + +class 
TestLineageBuilderWithSchema: + """Test RecursiveLineageBuilder with schema-based qualification""" + + def test_lineage_with_qualified_columns(self): + """Test lineage building when columns are already qualified""" + sql = """ + SELECT user_orders.order_date, user_metrics.total_revenue + FROM analytics.user_metrics + JOIN staging.user_orders USING (user_id) + """ + schema = { + "staging.user_orders": ["user_id", "order_date"], + "analytics.user_metrics": ["user_id", "total_revenue"], + } + + builder = RecursiveLineageBuilder(sql, external_table_columns=schema, dialect="bigquery") + lineage = builder.build() + + # Should have both columns as inputs + input_nodes = [n for n in lineage.nodes.values() if n.layer == "input"] + input_names = [n.column_name for n in input_nodes] + + assert "order_date" in input_names + assert "total_revenue" in input_names + + def test_lineage_with_unqualified_columns(self): + """Test lineage building when columns are unqualified""" + sql = """ + SELECT order_date, total_revenue + FROM analytics.user_metrics + JOIN staging.user_orders USING (user_id) + """ + schema = { + "staging.user_orders": ["user_id", "order_date"], + "analytics.user_metrics": ["user_id", "total_revenue"], + } + + builder = RecursiveLineageBuilder(sql, external_table_columns=schema, dialect="bigquery") + lineage = builder.build() + + # Should have both columns as inputs with correct table attribution + input_nodes = [n for n in lineage.nodes.values() if n.layer == "input"] + + # Find order_date node - should be from user_orders + order_date_nodes = [n for n in input_nodes if n.column_name == "order_date"] + assert len(order_date_nodes) == 1 + assert order_date_nodes[0].table_name == "user_orders" + + # Find total_revenue node - should be from user_metrics + revenue_nodes = [n for n in input_nodes if n.column_name == "total_revenue"] + assert len(revenue_nodes) == 1 + assert revenue_nodes[0].table_name == "user_metrics" + + def test_lineage_date_trunc_expression(self): + """Test lineage for DATE_TRUNC expression with unqualified column""" + sql = """ + SELECT DATE_TRUNC(order_date, MONTH) as month + FROM staging.orders + """ + schema = { + "staging.orders": ["order_id", "order_date", "amount"], + } + + builder = RecursiveLineageBuilder(sql, external_table_columns=schema, dialect="bigquery") + lineage = builder.build() + + # Should have edge from order_date to month + edges = list(lineage.edges) + assert len(edges) == 1 + assert edges[0].from_node.column_name == "order_date" + assert edges[0].to_node.column_name == "month" + + +# ============================================================================ +# Part 3: Pipeline Tests (Cross-Query Lineage) +# ============================================================================ + + +class TestPipelineUnqualifiedColumns: + """Test Pipeline with unqualified column resolution""" + + def test_three_layer_pipeline_date_trunc(self): + """Test the 3-layer pipeline example with DATE_TRUNC""" + queries = [ + """CREATE TABLE staging.user_orders AS + SELECT user_id, order_id, amount, order_date + FROM raw.orders + WHERE status = 'completed'""", + """CREATE TABLE analytics.user_metrics AS + SELECT user_id, COUNT(*) as order_count, SUM(amount) as total_revenue + FROM staging.user_orders + GROUP BY user_id""", + """CREATE TABLE reports.monthly_revenue AS + SELECT DATE_TRUNC(order_date, MONTH) as month, SUM(total_revenue) as revenue + FROM analytics.user_metrics + JOIN staging.user_orders USING (user_id) + GROUP BY month""", + ] + + pipeline = 
Pipeline.from_sql_list(queries, dialect="bigquery") + + # Check that month has lineage to order_date + month_edges = [ + e for e in pipeline.edges if e.to_node.full_name == "reports.monthly_revenue.month" + ] + assert len(month_edges) == 1 + assert month_edges[0].from_node.column_name == "order_date" + assert month_edges[0].from_node.table_name == "staging.user_orders" + + def test_trace_column_backward_through_join(self): + """Test tracing a column backward through a JOIN""" + queries = [ + """CREATE TABLE staging.orders AS + SELECT order_id, customer_id, amount, order_date + FROM raw.orders""", + """CREATE TABLE staging.customers AS + SELECT customer_id, name, email + FROM raw.customers""", + """CREATE TABLE reports.customer_orders AS + SELECT name, order_date, amount + FROM staging.customers + JOIN staging.orders USING (customer_id)""", + ] + + pipeline = Pipeline.from_sql_list(queries, dialect="bigquery") + + # Trace order_date backward + sources = pipeline.trace_column_backward("reports.customer_orders", "order_date") + source_names = [f"{s.table_name}.{s.column_name}" for s in sources] + + assert "raw.orders.order_date" in source_names + + # Trace name backward + sources = pipeline.trace_column_backward("reports.customer_orders", "name") + source_names = [f"{s.table_name}.{s.column_name}" for s in sources] + + assert "raw.customers.name" in source_names + + def test_ambiguous_column_resolved_correctly(self): + """Test that ambiguous columns are resolved to the correct table""" + queries = [ + """CREATE TABLE staging.table_a AS + SELECT id, value_a FROM raw.source_a""", + """CREATE TABLE staging.table_b AS + SELECT id, value_b FROM raw.source_b""", + """CREATE TABLE reports.combined AS + SELECT value_a, value_b + FROM staging.table_a + JOIN staging.table_b USING (id)""", + ] + + pipeline = Pipeline.from_sql_list(queries, dialect="bigquery") + + # value_a should come from table_a + value_a_edges = [ + e for e in pipeline.edges if e.to_node.full_name == "reports.combined.value_a" + ] + assert len(value_a_edges) == 1 + assert value_a_edges[0].from_node.table_name == "staging.table_a" + + # value_b should come from table_b + value_b_edges = [ + e for e in pipeline.edges if e.to_node.full_name == "reports.combined.value_b" + ] + assert len(value_b_edges) == 1 + assert value_b_edges[0].from_node.table_name == "staging.table_b" + + +# ============================================================================ +# Part 4: Edge Cases +# ============================================================================ + + +class TestEdgeCases: + """Test edge cases for unqualified column resolution""" + + def test_column_in_both_tables(self): + """Test when a column name exists in both tables (like 'id')""" + queries = [ + """CREATE TABLE staging.orders AS + SELECT id, amount FROM raw.orders""", + """CREATE TABLE staging.users AS + SELECT id, name FROM raw.users""", + """CREATE TABLE reports.summary AS + SELECT o.id as order_id, u.id as user_id, amount, name + FROM staging.orders o + JOIN staging.users u ON o.id = u.id""", + ] + + pipeline = Pipeline.from_sql_list(queries, dialect="bigquery") + + # amount should come from orders + amount_edges = [ + e for e in pipeline.edges if e.to_node.full_name == "reports.summary.amount" + ] + assert len(amount_edges) == 1 + assert amount_edges[0].from_node.table_name == "staging.orders" + + # name should come from users + name_edges = [e for e in pipeline.edges if e.to_node.full_name == "reports.summary.name"] + assert len(name_edges) == 1 + assert 
name_edges[0].from_node.table_name == "staging.users" + + def test_aggregate_with_unqualified_column(self): + """Test aggregate functions with unqualified columns""" + queries = [ + """CREATE TABLE staging.orders AS + SELECT user_id, amount FROM raw.orders""", + """CREATE TABLE staging.users AS + SELECT user_id, name FROM raw.users""", + """CREATE TABLE reports.totals AS + SELECT SUM(amount) as total_amount, COUNT(name) as user_count + FROM staging.orders + JOIN staging.users USING (user_id)""", + ] + + pipeline = Pipeline.from_sql_list(queries, dialect="bigquery") + + # total_amount should come from orders.amount + amount_edges = [ + e for e in pipeline.edges if e.to_node.full_name == "reports.totals.total_amount" + ] + assert len(amount_edges) == 1 + assert amount_edges[0].from_node.column_name == "amount" + + def test_no_schema_fallback(self): + """Test fallback behavior when no schema is available""" + sql = "SELECT order_id, amount FROM orders" + + # Without external_table_columns, should still work (default behavior) + builder = RecursiveLineageBuilder(sql, dialect="bigquery") + lineage = builder.build() + + # Should have output columns + output_nodes = [n for n in lineage.nodes.values() if n.layer == "output"] + assert len(output_nodes) == 2
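
For reference, a minimal standalone sketch of the qualification step this patch introduces: it hand-builds the nested {schema: {table: {column: type}}} mapping that _convert_to_nested_schema produces from the flat {"schema.table": [columns]} form, then runs sqlglot's qualify_columns over the commit-message example so the unqualified order_date picks up the user_orders prefix. This is an illustration, not part of the diff above; it assumes a recent sqlglot release where qualify_columns accepts a plain nested dict plus the infer_schema flag, and exact output may vary by version.

import sqlglot
from sqlglot.optimizer import qualify_columns

sql = """
SELECT DATE_TRUNC(order_date, MONTH) AS month
FROM analytics.user_metrics
JOIN staging.user_orders USING (user_id)
"""

# Nested schema in the shape _convert_to_nested_schema builds; "UNKNOWN"
# stands in for column types, which the lineage builder does not track.
nested_schema = {
    "staging": {
        "user_orders": {"user_id": "UNKNOWN", "order_date": "UNKNOWN", "amount": "UNKNOWN"},
    },
    "analytics": {
        "user_metrics": {"user_id": "UNKNOWN", "total_revenue": "UNKNOWN"},
    },
}

parsed = sqlglot.parse_one(sql, read="bigquery")
qualified = qualify_columns.qualify_columns(parsed, schema=nested_schema, infer_schema=True)

# order_date should now be attributed to staging.user_orders, which is what
# lets the lineage builder emit the user_orders.order_date -> month edge
# instead of dropping it or pinning it to the first table in the FROM clause.
print(qualified.sql(dialect="bigquery"))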
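
And a small sketch of the serialization issue behind the _extract_select_from_query change: round-tripping a BigQuery DATE_TRUNC through .sql() without the dialect lets sqlglot's default generator emit its canonical unit-first argument order, whereas passing dialect="bigquery" preserves the column-first form that the lineage builder re-parses downstream. Again illustrative only; the exact strings printed depend on the sqlglot version.

import sqlglot

expr = sqlglot.parse_one(
    "SELECT DATE_TRUNC(order_date, MONTH) AS month FROM staging.user_orders",
    read="bigquery",
)

# Default generator: typically DATE_TRUNC('MONTH', order_date), unit first.
print(expr.sql())

# BigQuery generator: DATE_TRUNC(order_date, MONTH), column first, matching
# the dialect the query was written in.
print(expr.sql(dialect="bigquery"))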