Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 67 additions & 12 deletions pydough/conversion/column_bubbler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import re

import pydough.pydough_operators as pydop
from pydough.relational import (
Aggregate,
CallExpression,
Expand All @@ -17,6 +18,7 @@
Filter,
Join,
Limit,
LiteralExpression,
Project,
RelationalExpression,
RelationalNode,
Expand Down Expand Up @@ -48,7 +50,63 @@ def name_sort_key(name: str) -> tuple[bool, bool, str]:
)


def generate_cleaner_names(expr: RelationalExpression, current_name: str) -> list[str]:
binop_namings: dict[pydop.PyDoughExpressionOperator, str] = {
    pydop.ADD: "plus",
    pydop.SUB: "minus",
    pydop.MUL: "times",
    pydop.DIV: "div",
    pydop.MOD: "mod",
    pydop.POW: "pow",
    pydop.STARTSWITH: "startswith",
    pydop.ENDSWITH: "endswith",
    pydop.CONTAINS: "contains",
}
"""
Mapping of each two-argument operator that can be given a readable name to the
word placed between the names of its two operands when a cleaner name is
generated for a call to that operator (e.g. `pydop.ADD` contributes `plus`, so
a call on operands named `a` and `b` is named `a_plus_b`). Two-argument calls
whose operator is not a key of this mapping do not get a generated name.
"""


def make_cleaner_name(expr: CallExpression) -> str | None:
"""
TODO
"""
input_names: list[str] = []
arg_name: str | None
for arg in expr.inputs:
if isinstance(arg, ColumnReference):
arg_name = arg.name
# Remove any non-alphanumeric characters to make a cleaner name
# and underscores
arg_name = re.sub(r"[^a-zA-Z0-9_]", "", arg_name)
input_names.append(arg_name)
elif isinstance(arg, CallExpression):
arg_name = make_cleaner_name(arg)
if arg_name is None:
return None
input_names.append(arg_name)
elif isinstance(arg, LiteralExpression):
# For literals, use their value directly in the name if it's
# a simple type
if isinstance(arg.value, (str, int, float, bool)):
arg_name = str(arg.value)
arg_name = re.sub(r"[^a-zA-Z0-9_]", "", arg_name)
input_names.append(arg_name)
else:
return None
else:
return None
cleaner_name: str | None = None
if len(expr.inputs) == 1:
cleaner_name = f"{expr.op.function_name.lower()}_{input_names[0]}"
elif len(expr.inputs) == 2 and expr.op in binop_namings:
cleaner_name = f"{input_names[0]}_{binop_namings[expr.op]}_{input_names[1]}"

if cleaner_name is not None and cleaner_name.isidentifier():
return cleaner_name
return None


def generate_cleaner_names(expr: RelationalExpression, current_name) -> list[str]:
"""
Generates more readable names for an expression based on its operator and inputs, if applicable.
The patterns of name generation are:
Expand All @@ -73,21 +131,18 @@ def generate_cleaner_names(expr: RelationalExpression, current_name: str) -> lis
"""
result: list[str] = []
if isinstance(expr, CallExpression):
if len(expr.inputs) == 1:
input_expr = expr.inputs[0]
if isinstance(input_expr, ColumnReference):
input_name: str = input_expr.name
# Remove any non-alphanumeric characters to make a cleaner name
# and underscores
input_name = re.sub(r"[^a-zA-Z0-9_]", "", input_name)
cleaner_name: str = f"{expr.op.function_name.lower()}_{input_name}"

result.append(cleaner_name)
cleaner_name: str | None = make_cleaner_name(expr)
if cleaner_name is not None:
result.append(cleaner_name)

if len(expr.inputs) == 0 and expr.op.function_name.lower() == "count":
result.append("n_rows")

if not (current_name.startswith("agg") or current_name.startswith("expr")):
if not (
current_name is None
or current_name.startswith("agg")
or current_name.startswith("expr")
):
if re.match(r"^(.*)_[0-9]+$", current_name):
result.append(re.findall(r"^(.*)_[0-9]+$", current_name)[0])
return result
Expand Down
220 changes: 220 additions & 0 deletions pydough/conversion/hybrid_filter_merger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
"""
Logic to merge multiple subtrees in the hybrid tree into one if they are the
same except one of them has more filters than the other and is only used in
a COUNT aggregation, meaning the filter can be implemented by doing a SUM on
the less-filtered subtree where the SUM argument is the additional filters.
"""

import copy
from typing import TYPE_CHECKING

import pydough.pydough_operators as pydop
from pydough.qdag import Literal
from pydough.types import BooleanType, NumericType

from .hybrid_connection import ConnectionType
from .hybrid_expressions import (
HybridChildRefExpr,
HybridExpr,
HybridFunctionExpr,
HybridLiteralExpr,
)
from .hybrid_operations import (
HybridCalculate,
HybridFilter,
HybridLimit,
)
from .hybrid_tree import HybridTree

if TYPE_CHECKING:
from .hybrid_translator import HybridTranslator


class HybridFilterMerger:
    """
    Merges child subtrees of a hybrid tree that are identical except that one
    of them has additional trailing filters and is only used for a COUNT
    aggregation. Such a child is replaced by a SUM, computed on the
    less-filtered child, of an expression that is 1 when the additional
    filters hold and 0 otherwise.
    """

    def __init__(self, translator: "HybridTranslator") -> None:
        # The translator is used to generate names for new aggregations.
        self.translator: HybridTranslator = translator

    def merge_filters(self, tree: HybridTree) -> None:
        """
        Runs the filter-merging procedure in-place on `tree`, then recursively
        on its parent and on the subtree of each of its children.

        Args:
            `tree`: the hybrid tree to transform.
        """
        # Merging is only possible when there are multiple children.
        if len(tree.children) > 1:
            self._merge_children(tree)

        # Run the procedure recursively on the parent tree and the child
        # subtrees.
        if tree.parent is not None:
            self.merge_filters(tree.parent)
        for child in tree.children:
            self.merge_filters(child.subtree)

    def _merge_children(self, tree: HybridTree) -> None:
        """
        Merges the children of a single tree level: every mergeable child is
        redirected to an equivalent child with strictly fewer filters.
        """
        # Identify which children are aggregation connections whose
        # aggregations are all COUNT().
        mergeable_children: set[int] = self.identify_mergeable_children(tree)

        # The trailing filter conditions of each child, by child index.
        child_filters: list[set[HybridExpr]] = [
            self.get_final_filters(child.subtree) for child in tree.children
        ]

        # For each child, the other children that are identical to it once
        # the trailing filters are ignored.
        child_isomorphisms: list[set[int]] = self.get_child_isomorphisms(tree)

        # For each child, the index of the child it is merged into (if any).
        filter_dag: list[int | None] = self.make_filter_dag(
            mergeable_children, child_filters, child_isomorphisms
        )

        # Build the new aggregations and record how references to the old
        # COUNT aggregations must be rewritten.
        replacement_map: dict[HybridExpr, HybridExpr] = {}
        for source_idx, target_idx in enumerate(filter_dag):
            if target_idx is None:
                continue
            source = tree.children[source_idx]
            target = tree.children[target_idx]
            # Guaranteed non-empty because the target's filters are a strict
            # subset of the source's filters.
            extra_filters: set[HybridExpr] = (
                child_filters[source_idx] - child_filters[target_idx]
            )
            assert len(extra_filters) > 0
            sum_expr: HybridFunctionExpr = self._build_filtered_count(extra_filters)
            agg_name: str = self.translator.gen_agg_name(target)
            target.aggs[agg_name] = sum_expr
            agg_ref: HybridExpr = HybridChildRefExpr(
                agg_name, target_idx, NumericType()
            )
            # Every aggregation of a mergeable child is COUNT(), so each of
            # its names is redirected to the new SUM (previously only the
            # first name was redirected).
            for old_agg_name in source.aggs:
                old_agg_ref = HybridChildRefExpr(
                    old_agg_name, source_idx, NumericType()
                )
                replacement_map[old_agg_ref] = agg_ref
            # NOTE(review): both bounds are combined with min(), exactly as
            # before; confirm min() is the intended combination for
            # `min_steps` as well.
            target.max_steps = min(target.max_steps, source.max_steps)
            target.min_steps = min(target.min_steps, source.min_steps)

        # Point every operation of this level at the new aggregations.
        for operation in tree.pipeline:
            operation.replace_expressions(replacement_map)

        # Presumably drops the children that are no longer referenced after
        # the rewrite - TODO confirm against `remove_dead_children`.
        tree.remove_dead_children(set())

    def _build_filtered_count(
        self, extra_filters: set[HybridExpr]
    ) -> HybridFunctionExpr:
        """
        Builds `SUM(IFF(cond, 1, 0))` where `cond` is the single extra filter,
        or the `pydop.BAN` combination of all extra filters ordered by their
        repr so the result is deterministic.
        """
        new_cond: HybridExpr
        if len(extra_filters) == 1:
            new_cond = next(iter(extra_filters))
        else:
            new_cond = HybridFunctionExpr(
                pydop.BAN,
                sorted(extra_filters, key=repr),
                BooleanType(),
            )
        numeric_expr: HybridExpr = HybridFunctionExpr(
            pydop.IFF,
            [
                new_cond,
                HybridLiteralExpr(Literal(1, NumericType())),
                HybridLiteralExpr(Literal(0, NumericType())),
            ],
            NumericType(),
        )
        return HybridFunctionExpr(pydop.SUM, [numeric_expr], NumericType())

    def identify_mergeable_children(self, tree: HybridTree) -> set[int]:
        """
        Finds the children of `tree` that may be merged into another child:
        those whose connection type is `ConnectionType.AGGREGATION` and whose
        aggregations all have the repr `COUNT()`.

        Args:
            `tree`: the hybrid tree whose children are inspected.

        Returns:
            The set of indices of the mergeable children.
        """
        return {
            idx
            for idx, child in enumerate(tree.children)
            if (
                child.connection_type == ConnectionType.AGGREGATION
                and {repr(v) for v in child.aggs.values()} == {"COUNT()"}
            )
        }

    def _trailing_filters(self, tree: HybridTree) -> list[tuple[int, HybridFilter]]:
        """
        Scans the pipeline of `tree` backwards and collects the filters that
        can be stripped from its end. The scan stops at the first filter that
        contains a window function, the first limit, or the first calculate
        that defines an expression containing a window function; every other
        operation is skipped over.

        Returns:
            `(pipeline index, filter)` pairs, in decreasing index order.
        """
        found: list[tuple[int, HybridFilter]] = []
        for idx, operation in reversed(list(enumerate(tree.pipeline))):
            if isinstance(operation, HybridFilter):
                if operation.condition.contains_window_functions():
                    break
                found.append((idx, operation))
            elif isinstance(operation, HybridLimit):
                break
            elif isinstance(operation, HybridCalculate) and any(
                expr.contains_window_functions()
                for expr in operation.new_expressions.values()
            ):
                break
        return found

    def get_final_filters(self, tree: HybridTree) -> set[HybridExpr]:
        """
        Collects the conditions of the trailing filters of `tree` (see
        `_trailing_filters` for which filters qualify).

        Args:
            `tree`: the hybrid tree whose pipeline is inspected.

        Returns:
            The union of `condition.get_conjunction()` over the trailing
            filters (presumably the individual conjuncts of each condition).
        """
        result: set[HybridExpr] = set()
        for _, operation in self._trailing_filters(tree):
            result.update(operation.condition.get_conjunction())
        return result

    def get_child_isomorphisms(self, tree: HybridTree) -> list[set[int]]:
        """
        Determines which children of `tree` are identical to one another once
        their trailing filters are ignored.

        Args:
            `tree`: the hybrid tree whose children are compared.

        Returns:
            A list where entry `i` is the set of indices of the other children
            with the same filter-stripped form as child `i`.
        """
        filter_stripped_forms: list[str] = [
            self.get_filter_stripped_form(child.subtree) for child in tree.children
        ]
        # Group the child indices by form so each child is visited only once.
        groups: dict[str, set[int]] = {}
        for idx, form in enumerate(filter_stripped_forms):
            groups.setdefault(form, set()).add(idx)
        return [groups[form] - {idx} for idx, form in enumerate(filter_stripped_forms)]

    def get_filter_stripped_form(self, tree: HybridTree) -> str:
        """
        Computes the repr of a deep copy of `tree` from which the trailing
        filters (see `_trailing_filters`) have been removed. `tree` itself is
        not modified.

        Args:
            `tree`: the hybrid tree to strip.

        Returns:
            The repr of the stripped copy.
        """
        stripped_tree = copy.deepcopy(tree)
        # Indices arrive in decreasing order, so popping keeps the remaining
        # indices valid.
        for idx, _ in self._trailing_filters(stripped_tree):
            stripped_tree.pipeline.pop(idx)
        return repr(stripped_tree)

    def make_filter_dag(
        self,
        mergeable_children: set[int],
        child_filters: list[set[HybridExpr]],
        child_isomorphisms: list[set[int]],
    ) -> list[int | None]:
        """
        Decides which child each mergeable child is merged into.

        Args:
            `mergeable_children`: indices of the children that may be merged
            into another child.
            `child_filters`: the trailing filter conditions of each child.
            `child_isomorphisms`: for each child, the indices of the children
            that are identical to it when trailing filters are ignored.

        Returns:
            A list where entry `i` is None if child `i` is kept as-is,
            otherwise the index of the child it is merged into. Chains are
            collapsed so that every target is itself a child that is kept.
        """
        dag: list[int | None] = [None] * len(child_filters)
        # Iterate in sorted order so the chosen target never depends on set
        # iteration order.
        for idx in sorted(mergeable_children):
            dag[idx] = next(
                (
                    other_idx
                    for other_idx in sorted(child_isomorphisms[idx])
                    # Strict subset: the target has fewer filters.
                    if child_filters[other_idx] < child_filters[idx]
                ),
                None,
            )
        # Follow each chain to its end. Chains cannot loop because each hop
        # moves to a strictly smaller filter set.
        for idx, target_idx in enumerate(dag):
            while target_idx is not None and dag[target_idx] is not None:
                target_idx = dag[target_idx]
            dag[idx] = target_idx
        return dag
23 changes: 19 additions & 4 deletions pydough/conversion/hybrid_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
HybridSidedRefExpr,
HybridWindowExpr,
)
from .hybrid_filter_merger import HybridFilterMerger
from .hybrid_operations import (
HybridCalculate,
HybridCollectionAccess,
Expand Down Expand Up @@ -1701,6 +1702,19 @@ def run_hybrid_decorrelation(self, hybrid: "HybridTree") -> None:
decorr.find_correlated_children(hybrid)
decorr.decorrelate_hybrid_tree(hybrid)

def run_filter_merging(self, hybrid: "HybridTree") -> None:
    """
    Runs the filter-merging procedure on the hybrid tree, in-place. Child
    subtrees that are identical apart from the filters they have are merged,
    with the extra filters emulated via a SUM on a predicate.

    Args:
        `hybrid`: The hybrid tree to run filter merging on.
    """
    HybridFilterMerger(self).merge_filters(hybrid)
def convert_qdag_to_hybrid(self, node: PyDoughCollectionQDAG) -> HybridTree:
"""
Convert a PyDough QDAG node to a hybrid tree, including any necessary
Expand All @@ -1725,10 +1739,11 @@ def convert_qdag_to_hybrid(self, node: PyDoughCollectionQDAG) -> HybridTree:
self.run_correlation_extraction(hybrid)
# 5. Run the de-correlation procedure.
self.run_hybrid_decorrelation(hybrid)
# 6. Run any final rewrites, such as turning MEDIAN into an average
# 6. Run the filter-merging procedure, then re-run ejecting aggregate
# inputs to clean up any new aggregates created by filter merging.
self.run_filter_merging(hybrid)
self.eject_aggregate_inputs(hybrid)
# 7. Run any final rewrites, such as turning MEDIAN into an average
# of the 1-2 median rows, that must happen after de-correlation.
self.run_rewrites(hybrid)
# 7. Remove any dead children in the hybrid tree that are no longer
# being used.
hybrid.remove_dead_children(set())
return hybrid
9 changes: 8 additions & 1 deletion pydough/conversion/relational_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,10 +731,17 @@ def handle_children(
child_output = self.apply_aggregations(
child, child_output, child.subtree.agg_keys
)
# Optimize SEMI to INNER for singular subtrees
join_type = child.connection_type.join_type
if (
child.connection_type == ConnectionType.SEMI
and child.subtree.is_singular()
):
join_type = JoinType.INNER
context = self.join_outputs(
context,
child_output,
child.connection_type.join_type,
join_type,
cardinality,
child.reverse_cardinality,
join_keys,
Expand Down
Loading