Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
1 change: 1 addition & 0 deletions pydough/database_connectors/database_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class DatabaseDialect(Enum):
ANSI = "ansi"
SQLITE = "sqlite"
SNOWFLAKE = "snowflake"
TRINO = "trino"
MYSQL = "mysql"
POSTGRES = "postgres"
BODOSQL = "bodosql"
Expand Down
3 changes: 3 additions & 0 deletions pydough/sqlglot/execute_relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sqlglot.dialects import Postgres as PostgresDialect
from sqlglot.dialects import Snowflake as SnowflakeDialect
from sqlglot.dialects import SQLite as SQLiteDialect
from sqlglot.dialects import Trino as TrinoDialect
from sqlglot.dialects.mysql import MySQL
from sqlglot.errors import SqlglotError
from sqlglot.expressions import (
Expand Down Expand Up @@ -486,6 +487,8 @@ def convert_dialect_to_sqlglot(dialect: DatabaseDialect) -> SQLGlotDialect:
# The BodoSQL dialect is essentially a subset of the Snowflake SQL
# dialect without many of the extraneous features.
return SnowflakeDialect()
case DatabaseDialect.TRINO:
return TrinoDialect()
case DatabaseDialect.MYSQL:
return MySQLDialect()
case DatabaseDialect.POSTGRES:
Expand Down
4 changes: 4 additions & 0 deletions pydough/sqlglot/transform_bindings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"PostgresTransformBindings",
"SQLiteTransformBindings",
"SnowflakeTransformBindings",
"TrinoTransformBindings",
"bindings_from_dialect",
]

Expand All @@ -24,6 +25,7 @@
from .postgres_transform_bindings import PostgresTransformBindings
from .sf_transform_bindings import SnowflakeTransformBindings
from .sqlite_transform_bindings import SQLiteTransformBindings
from .trino_transform_bindings import TrinoTransformBindings

if TYPE_CHECKING:
from pydough.sqlglot.sqlglot_relational_visitor import SQLGlotRelationalVisitor
Expand Down Expand Up @@ -53,6 +55,8 @@ def bindings_from_dialect(
return SQLiteTransformBindings(configs, visitor)
case DatabaseDialect.SNOWFLAKE:
return SnowflakeTransformBindings(configs, visitor)
case DatabaseDialect.TRINO:
return TrinoTransformBindings(configs, visitor)
case DatabaseDialect.BODOSQL:
return BodoSQLTransformBindings(configs, visitor)
case DatabaseDialect.MYSQL:
Expand Down
134 changes: 134 additions & 0 deletions pydough/sqlglot/transform_bindings/trino_transform_bindings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
Definition of SQLGlot transformation bindings for the Trino dialect.
"""

__all__ = ["TrinoTransformBindings"]


import sqlglot.expressions as sqlglot_expressions
from sqlglot.expressions import Expression as SQLGlotExpression

import pydough.pydough_operators as pydop
from pydough.configs import DayOfWeek
from pydough.types import PyDoughType

from .base_transform_bindings import BaseTransformBindings
from .sqlglot_transform_utils import DateTimeUnit, apply_parens


class TrinoTransformBindings(BaseTransformBindings):
"""
Subclass of BaseTransformBindings for the Trino dialect.
"""

@property
def values_alias_column(self) -> bool:
return False

PYDOP_TO_TRINO_FUNC: dict[pydop.PyDoughExpressionOperator, str] = {
pydop.STARTSWITH: "STARTS_WITH",
pydop.LPAD: "LPAD",
pydop.RPAD: "RPAD",
pydop.SIGN: "SIGN",
pydop.SMALLEST: "LEAST",
pydop.LARGEST: "GREATEST",
pydop.GETPART: "SPLIT_PART",
}
"""
Mapping of PyDough operators to equivalent Trino SQL function names
These are used to generate anonymous function calls in SQLGlot
"""

@property
def dialect_start_of_week(self) -> DayOfWeek:
"""
Which day of the week is considered the start of the week within the
SQL dialect. Individual dialects may override this.
"""
return DayOfWeek.MONDAY

@property
def dialect_dow_mapping(self) -> dict[str, int]:
return {
"Monday": 1,
"Tuesday": 2,
"Wednesday": 3,
"Thursday": 4,
"Friday": 5,
"Saturday": 6,
"Sunday": 7,
}

def convert_call_to_sqlglot(
self,
operator: pydop.PyDoughExpressionOperator,
args: list[SQLGlotExpression],
types: list[PyDoughType],
) -> SQLGlotExpression:
if operator in self.PYDOP_TO_TRINO_FUNC:
return sqlglot_expressions.Anonymous(
this=self.PYDOP_TO_TRINO_FUNC[operator], expressions=args
)

return super().convert_call_to_sqlglot(operator, args, types)

def convert_extract_datetime(
self,
args: list[SQLGlotExpression],
types: list[PyDoughType],
unit: DateTimeUnit,
) -> SQLGlotExpression:
# Update argument type to fit datetime
dt_expr: SQLGlotExpression = self.handle_datetime_base_arg(args[0])
func_expr: SQLGlotExpression
match unit:
case DateTimeUnit.YEAR:
func_expr = sqlglot_expressions.Year(this=dt_expr)
case DateTimeUnit.QUARTER:
func_expr = sqlglot_expressions.Quarter(this=dt_expr)
case DateTimeUnit.MONTH:
func_expr = sqlglot_expressions.Month(this=dt_expr)
case DateTimeUnit.DAY:
func_expr = sqlglot_expressions.Day(this=dt_expr)
case DateTimeUnit.HOUR | DateTimeUnit.MINUTE | DateTimeUnit.SECOND:
func_expr = sqlglot_expressions.Anonymous(
this=unit.value.upper(), expressions=[dt_expr]
)
return func_expr

def apply_datetime_truncation(
self, base: SQLGlotExpression, unit: DateTimeUnit
) -> SQLGlotExpression:
if unit is DateTimeUnit.WEEK:
# 1. Get shifted_weekday (# of days since the start of week)
# 2. Subtract shifted_weekday DAYS from the datetime
# 3. Truncate the result to the nearest day
shifted_weekday: SQLGlotExpression = self.days_from_start_of_week(base)
date_sub: SQLGlotExpression = sqlglot_expressions.DateSub(
this=base,
expression=shifted_weekday,
unit=sqlglot_expressions.Var(this="DAY"),
)
return sqlglot_expressions.DateTrunc(
this=date_sub,
unit=sqlglot_expressions.Var(this="DAY"),
)
else:
# For other units, use the standard SQLGlot truncation
return super().apply_datetime_truncation(base, unit)

def days_from_start_of_week(self, base: SQLGlotExpression) -> SQLGlotExpression:
offset: int = (-self.start_of_week_offset) % 7
dow_expr: SQLGlotExpression = self.dialect_day_of_week(base)
if offset == 1:
return dow_expr
breakpoint()
return sqlglot_expressions.Mod(
this=apply_parens(
sqlglot_expressions.Add(
this=dow_expr,
expression=sqlglot_expressions.Literal.number(offset - 1),
)
),
expression=sqlglot_expressions.Literal.number(7),
)
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dev-dependencies = [
"pytest-repeat",
"boto3",
"pydough[snowflake]",
"pydough[trino]",
"pydough[mysql]",
"pydough[postgres]",
"pydough[server]",
Expand All @@ -42,6 +43,7 @@ dev-dependencies = [

[project.optional-dependencies]
snowflake = ["snowflake-connector-python[pandas]==4.1.1"]
trino = ["trino"]
mysql = ["mysql-connector-python==9.5.0"]
postgres = ["psycopg2-binary"]
server = ["fastapi", "httpx", "uvicorn"]
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
markers =
execute: marks tests that do runtime execution (deselect with '-m "not execute"')
snowflake: marks tests that require Snowflake credentials
trino: marks tests that require Trino credentials
mysql: marks tests that require MySQL credentials
postgres: marks tests that require PostgreSQL credentials
server: marks tests that require api mock server
Expand Down
110 changes: 106 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import httpx
import pandas as pd
import pytest
import trino
from botocore.exceptions import ClientError

import pydough
Expand Down Expand Up @@ -180,6 +181,14 @@ def sf_sample_graph_path() -> str:
return f"{os.path.dirname(__file__)}/test_metadata/snowflake_sample_graphs.json"


@pytest.fixture(scope="session")
def trino_graph_path() -> str:
    """
    Returns the path to the JSON file containing the Trino sample graphs.
    """
    return f"{os.path.dirname(__file__)}/test_metadata/trino_graphs.json"


@pytest.fixture(scope="session")
def udf_graph_path() -> str:
"""
Expand Down Expand Up @@ -431,6 +440,7 @@ def sqlite_dialects(request) -> DatabaseDialect:
pytest.param(DatabaseDialect.ANSI, id="ansi"),
pytest.param(DatabaseDialect.SQLITE, id="sqlite"),
pytest.param(DatabaseDialect.SNOWFLAKE, id="snowflake"),
pytest.param(DatabaseDialect.TRINO, id="trino"),
pytest.param(DatabaseDialect.MYSQL, id="mysql"),
pytest.param(DatabaseDialect.POSTGRES, id="postgres"),
]
Expand Down Expand Up @@ -559,6 +569,11 @@ def sqlite_tpch_session(
id="snowflake",
marks=[pytest.mark.snowflake],
),
pytest.param(
"trino",
id="trino",
marks=[pytest.mark.trino],
),
pytest.param(
"mysql",
id="mysql",
Expand All @@ -575,6 +590,7 @@ def all_dialects_tpch_db_context(
request,
get_sample_graph: graph_fetcher,
get_sf_sample_graph: graph_fetcher,
get_trino_graphs: graph_fetcher,
) -> tuple[DatabaseContext, GraphMetadata]:
"""
General fixture providing TPCH database context and graph metadata
Expand All @@ -594,6 +610,9 @@ def all_dialects_tpch_db_context(
sf_conn("SNOWFLAKE_SAMPLE_DATA", "TPCH_SF1"),
get_sf_sample_graph("TPCH"),
)
case "trino":
trino_conn = request.getfixturevalue("trino_conn_db_context")
return trino_conn, get_trino_graphs("TPCH")
case "mysql":
mysql_conn = request.getfixturevalue("mysql_conn_db_context")
return mysql_conn("tpch"), get_sample_graph("TPCH")
Expand Down Expand Up @@ -622,10 +641,11 @@ def impl(name: str) -> GraphMetadata:

@pytest.fixture(scope="session")
def get_dialect_defog_graphs(
defog_graphs,
get_mysql_defog_graphs,
get_sf_defog_graphs,
get_postgres_defog_graphs,
defog_graphs: graph_fetcher,
get_mysql_defog_graphs: graph_fetcher,
get_sf_defog_graphs: graph_fetcher,
get_trino_graphs: graph_fetcher,
get_postgres_defog_graphs: graph_fetcher,
) -> Callable[[DatabaseDialect, str], GraphMetadata]:
"""
Returns the graphs for the defog database based on the dialect
Expand All @@ -638,6 +658,8 @@ def impl(dialect: DatabaseDialect, name: str) -> GraphMetadata:
return get_mysql_defog_graphs(name)
case DatabaseDialect.SNOWFLAKE:
return get_sf_defog_graphs(name)
case DatabaseDialect.TRINO:
return get_trino_graphs(name)
case DatabaseDialect.POSTGRES:
return get_postgres_defog_graphs(name)
case _:
Expand Down Expand Up @@ -1098,6 +1120,86 @@ def container_is_running(name: str) -> bool:
return name in result.stdout.splitlines()


@pytest.fixture(scope="session")
def get_trino_graphs(trino_graph_path: str) -> graph_fetcher:
    """
    Session fixture returning a callable that maps the name of one of the
    supported Trino graphs to its parsed PyDough graph metadata. Results are
    memoized per graph name.
    """

    @cache
    def fetch(name: str) -> GraphMetadata:
        # Parse the requested graph out of the shared Trino metadata file.
        metadata: GraphMetadata = pydough.parse_json_metadata_from_file(
            file_path=trino_graph_path, graph_name=name
        )
        return metadata

    return fetch


@pytest.fixture(scope="session")
def get_trino_defog_graphs() -> graph_fetcher:
    """
    Session fixture returning a callable that fetches the metadata for one of
    the defog graphs as defined for the Trino database. Results are memoized
    per graph name.
    """
    # The defog metadata file lives alongside the other test metadata.
    metadata_path: str = (
        f"{os.path.dirname(__file__)}/test_metadata/trino_defog_graphs.json"
    )

    @cache
    def fetch(name: str) -> GraphMetadata:
        return pydough.parse_json_metadata_from_file(
            file_path=metadata_path, graph_name=name
        )

    return fetch


def is_trino_env_set() -> bool:
    """
    Check if the Trino environment variables are set.

    Returns:
        bool: True if all required Trino environment variables are set, False
        otherwise. The required list is currently empty (see TODO), so this
        vacuously returns True.
    """
    # TODO: add environment variables for Trino connection
    required_envs: list[str] = []
    for env_name in required_envs:
        if not os.getenv(env_name):
            return False
    return True


@pytest.fixture
def trino_conn_db_context() -> Callable[[str, str], DatabaseContext]:
    """
    This fixture is used to connect to the Trino TPCH database using
    a connection object.
    Return a DatabaseContext for the Trino TPCH database; skips the test
    when the required Trino environment variables are not set.
    """

    def _impl(database_name: str, schema_name: str) -> DatabaseContext:
        # NOTE(review): `database_name` and `schema_name` are currently
        # unused -- presumably they should feed the connect() call once the
        # TODO below is resolved; confirm before relying on them.
        if not is_trino_env_set():
            pytest.skip("Skipping Trino tests: environment variables not set.")

        # NOTE(review): while is_trino_env_set() has an empty required list,
        # the skip above never triggers, so this connect() runs with no args.
        connection: trino.dbapi.Connection = trino.dbapi.connect(
            # TODO: use the keyword arguments fetched from environment variables
        )

        return load_database_context("trino", connection=connection)

    return _impl


@pytest.fixture
def trino_params_tpch_db_context() -> DatabaseContext:
    """
    This fixture is used to connect to the Trino TPCH database using
    parameters instead of a connection object.
    Return a DatabaseContext for the Trino TPCH database; skips the test
    when the required Trino environment variables are not set.
    """
    # NOTE(review): while is_trino_env_set() has an empty required list, this
    # skip never triggers, so the context is built with no connection params.
    if not is_trino_env_set():
        pytest.skip("Skipping Trino tests: environment variables not set.")
    # TODO: add keyword arguments fetched from environment variables
    return load_database_context(
        "trino",
        # TODO: use the keyword arguments
    )


MYSQL_ENVS = ["MYSQL_USERNAME", "MYSQL_PASSWORD"]
"""
The MySQL environment variables required for connection.
Expand Down
Loading