Skip to content

Commit 96e8e89

Browse files
committed
use PandasCursor with Athena
for %%fetchdf magic only
1 parent 6afa17c commit 96e8e89

File tree

1 file changed

+73
-1
lines changed

1 file changed

+73
-1
lines changed

sqlmesh/magics.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,13 @@ def render(self, context: Context, line: str) -> None:
611611
def fetchdf(self, context: Context, line: str, sql: str) -> None:
    """Run ``sql`` and display the resulting dataframe.

    If the parsed magic arguments carry a ``df_var`` name, the dataframe is
    also stored under that name in the notebook's user namespace.
    """
    args = parse_argstring(self.fetchdf, line)

    # Athena is routed through the dedicated PandasCursor helper; every other
    # engine keeps using the context's regular fetchdf.
    uses_athena = getattr(context.engine_adapter, "DIALECT", None) == "athena"
    df = (
        self._fetchdf_athena_pandas_cursor(context, sql)
        if uses_athena
        else context.fetchdf(sql)
    )

    if args.df_var:
        self._shell.user_ns[args.df_var] = df
    self.display(df)
@@ -1147,6 +1153,72 @@ def destroy(self, context: Context, line: str) -> None:
11471153
"""Removes all project resources, engine-managed objects, state tables and clears the SQLMesh cache."""
11481154
context.destroy()
11491155

1156+
def _fetchdf_athena_pandas_cursor(self, context: Context, sql: str) -> "pd.DataFrame":
    """Fetch a dataframe from Athena through PyAthena's ``PandasCursor``.

    The SQL is parsed using the project's configured dialect and rendered by
    the engine adapter before being executed on a dedicated PyAthena
    connection. Only the first parsed statement is executed.

    This path is best-effort: if transpilation fails, or if running the query
    through ``PandasCursor`` fails, the error is logged to the console and the
    original SQL is run through ``context.fetchdf`` instead.

    Args:
        context: The SQLMesh context supplying config, console and engine adapter.
        sql: The SQL text to execute.

    Returns:
        The query result as a pandas DataFrame.

    Raises:
        MagicError: If PyAthena with pandas support cannot be imported.
    """
    try:
        from pyathena import connect
        from pyathena.pandas.cursor import PandasCursor
    except ImportError as e:
        # Chain the cause so the underlying import failure stays in the traceback.
        raise MagicError(f"PyAthena with pandas support is required: {e}") from e

    # Render the SQL through the engine adapter so constructs that need
    # transpilation (e.g. QUALIFY) are rewritten before reaching Athena.
    try:
        from sqlmesh.core.dialect import parse

        parsed_expressions = parse(sql, default_dialect=context.config.dialect)
        if not parsed_expressions:
            raise ValueError("No valid SQL expressions found")

        # Only the first statement is used; any further statements are ignored.
        transpiled_sql = context.engine_adapter._to_sql(parsed_expressions[0], quote=False)
    except Exception as e:
        context.console.log_error(f"SQL transpilation failed: {e}")
        # Fall back to the regular fetchdf method if transpilation fails.
        return context.fetchdf(sql)

    # NOTE(review): `default_connection` is passed as the lookup argument of
    # `get_connection` -- confirm this resolves the active Athena connection.
    conn_config = context.config.get_connection(context.config.default_connection)

    # Forward only the settings the connection config declares as connection
    # kwargs, skipping the ones that are unset.
    connection_kwargs = {
        k: v
        for k, v in conn_config.dict().items()
        if k in conn_config._connection_kwargs_keys and v is not None
    }

    try:
        with connect(cursor_class=PandasCursor, **connection_kwargs) as conn:
            with conn.cursor() as cursor:
                cursor.execute(transpiled_sql)
                # PandasCursor already materializes the result as a DataFrame;
                # as_pandas() returns it with PyAthena's dtype mapping instead
                # of rebuilding an untyped frame from fetchall() rows.
                return cursor.as_pandas()
    except Exception as e:
        # Fall back to the regular fetchdf method if PandasCursor fails.
        context.console.log_error(f"PandasCursor failed, falling back to standard method: {e}")
        return context.fetchdf(sql)
1221+
11501222

11511223
def register_magics() -> None:
11521224
try:

0 commit comments

Comments
 (0)