From 5fdd45913b213451a6a809b7f63ca44c556e723c Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Wed, 10 Dec 2025 11:10:23 +0400 Subject: [PATCH] Update `sqlparser` to 0.60, streamline `just` builds --- Cargo.toml | 2 +- examples/depgraph.py | 10 +++------- justfile | 14 +++++++++----- pyproject.toml | 13 +++++++++++++ sqloxide.pyi | 6 ++++-- sqloxide/__init__.py | 2 +- src/lib.rs | 2 +- src/visitor.rs | 37 +++++++++++++++++++------------------ tests/__init__.py | 0 tests/benchmark.py | 12 ------------ tests/test_sqloxide.py | 4 ++-- 11 files changed, 53 insertions(+), 49 deletions(-) create mode 100644 tests/__init__.py diff --git a/Cargo.toml b/Cargo.toml index 2f88351..38bc7d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,5 +17,5 @@ version = "0.22" features = ["extension-module"] [dependencies.sqlparser] -version = "0.56.0" +version = "0.60.0" features = ["serde", "visitor"] diff --git a/examples/depgraph.py b/examples/depgraph.py index cae9c37..9d58ecc 100755 --- a/examples/depgraph.py +++ b/examples/depgraph.py @@ -4,8 +4,6 @@ """ import argparse -import json -import os from glob import glob from typing import List @@ -16,6 +14,7 @@ parser.add_argument("--path", "-p", type=str, help="The path to process queries for.") parser.add_argument("--dialect", "-d", type=str, help="The dialect to use.") + def get_sql_files(path: str) -> List[str]: return glob(path + "/**/*.sql") @@ -31,7 +30,6 @@ def get_key_recursive(search_dict, field): fields_found = [] for key, value in search_dict.items(): - if key == field: fields_found.append(value) @@ -51,7 +49,6 @@ def get_key_recursive(search_dict, field): def get_tables_in_query(SQL: str, dialect: str) -> List[str]: - res = sqloxide.parse_sql(sql=SQL, dialect=dialect) tables = get_key_recursive(res[0]["Query"], "Table") @@ -64,11 +61,10 @@ def get_tables_in_query(SQL: str, dialect: str) -> List[str]: if __name__ == "__main__": - args = parser.parse_args() files = get_sql_files(args.path) - print(f'Parsing using dialect: {args.dialect}') + print(f"Parsing using dialect: {args.dialect}") result_dict = dict() @@ -87,7 +83,7 @@ def get_tables_in_query(SQL: str, dialect: str) -> List[str]: dot = Digraph(engine="dot") dot.attr(rankdir="LR") dot.attr(splines="ortho") - dot.node_attr['shape'] = 'box' + dot.node_attr["shape"] = "box" for view, tables in result_dict.items(): view = view[:-4] diff --git a/justfile b/justfile index 21f02c5..9311ea0 100644 --- a/justfile +++ b/justfile @@ -1,8 +1,12 @@ -benchmark: build - uvx poetry run pytest tests/benchmark.py +benchmark: + uv sync + uv run maturin develop --release + uv run pytest tests/benchmark.py -test: - uvx poetry run pytest tests/ +test: + uv sync + uv run maturin develop + uv run pytest tests/ build: - uvx poetry build + uv run maturin build --release diff --git a/pyproject.toml b/pyproject.toml index a369aeb..ba631ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,19 @@ classifiers = [ "License :: OSI Approved :: MIT License", ] +[dependency-groups] +dev = [ + # build + "maturin", + # test + "pytest", + "pytest-benchmark", + "pytest-subtests", + # benchmark + "sqlglot", + "sqlparse", +] + [build-system] requires = ["maturin>=1.0,<2.0"] build-backend = "maturin" diff --git a/sqloxide.pyi b/sqloxide.pyi index b0bc4ba..5a1063c 100644 --- a/sqloxide.pyi +++ b/sqloxide.pyi @@ -148,7 +148,7 @@ class Select(TypedDict("Select", {"from": list[TableWithJoins]})): class Insert(TypedDict("Insert", {"or": Any | None})): """ An INSERT statement. - + See https://docs.rs/sqlparser/0.51.0/sqlparser/ast/struct.Insert.html """ @@ -163,7 +163,9 @@ class Insert(TypedDict("Insert", {"or": Any | None})): partitioned: Any | None after_columns: list[Any] table: bool - on: dict[str, Any] | None # e.g. {"OnConflict": {"conflict_target": None, "action": "DoNothing"}}, + on: ( + dict[str, Any] | None + ) # e.g. {"OnConflict": {"conflict_target": None, "action": "DoNothing"}}, returning: Any | None replace_into: bool priority: Any | None diff --git a/sqloxide/__init__.py b/sqloxide/__init__.py index 22e8cc0..c65ad00 100644 --- a/sqloxide/__init__.py +++ b/sqloxide/__init__.py @@ -1 +1 @@ -from .sqloxide import * +from .sqloxide import * # noqa: F403 diff --git a/src/lib.rs b/src/lib.rs index 7002d6d..63e5dd2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,7 +19,7 @@ use visitor::{extract_expressions, extract_relations, mutate_expressions, mutate /// Available `dialects`: https://github.com/sqlparser-rs/sqlparser-rs/blob/main/src/dialect/mod.rs#L189-L206 #[pyfunction] #[pyo3(text_signature = "(sql, dialect)")] -fn parse_sql(py: Python, sql: String, dialect: String) -> PyResult { +fn parse_sql(py: Python, sql: String, dialect: String) -> PyResult> { let chosen_dialect = dialect_from_str(dialect).unwrap_or_else(|| { println!("The dialect you chose was not recognized, falling back to 'generic'"); Box::new(GenericDialect {}) diff --git a/src/visitor.rs b/src/visitor.rs index 97dafa2..4a30f53 100644 --- a/src/visitor.rs +++ b/src/visitor.rs @@ -39,12 +39,12 @@ where #[pyfunction] #[pyo3(text_signature = "(parsed_query)")] -pub fn extract_relations(py: Python, parsed_query: &Bound<'_, PyAny>) -> PyResult { +pub fn extract_relations(py: Python, parsed_query: &Bound<'_, PyAny>) -> PyResult> { let statements = depythonize_query(parsed_query)?; let mut relations = Vec::new(); for statement in statements { - visit_relations(&statement, |relation| { + let _ = visit_relations(&statement, |relation| { relations.push(relation.clone()); ControlFlow::<()>::Continue(()) }); @@ -59,20 +59,21 @@ pub fn mutate_relations(_py: Python, parsed_query: &Bound<'_, PyAny>, func: &Bou let mut statements = depythonize_query(parsed_query)?; for statement in &mut statements { - visit_relations_mut(statement, |table| { + let _ = visit_relations_mut(statement, |table| { for section in &mut table.0 { - let ObjectNamePart::Identifier(ident) = section; - let val = match func.call1((ident.value.clone(),)) { - Ok(val) => val, - Err(e) => { - let msg = e.to_string(); - return ControlFlow::Break(PyValueError::new_err(format!( - "Python object serialization failed.\n\t{msg}" - ))); - } - }; - - ident.value = val.to_string(); + if let ObjectNamePart::Identifier(ident) = section { + let val = match func.call1((ident.value.clone(),)) { + Ok(val) => val, + Err(e) => { + let msg = e.to_string(); + return ControlFlow::Break(PyValueError::new_err(format!( + "Python object serialization failed.\n\t{msg}" + ))); + } + }; + + ident.value = val.to_string(); + } } ControlFlow::Continue(()) }); @@ -90,7 +91,7 @@ pub fn mutate_expressions(py: Python, parsed_query: &Bound<'_, PyAny>, func: &Bo let mut statements: Vec = depythonize_query(parsed_query)?; for statement in &mut statements { - visit_expressions_mut(statement, |expr| { + let _ = visit_expressions_mut(statement, |expr| { let converted_expr = match pythonize::pythonize(py, expr) { Ok(val) => val, Err(e) => { @@ -133,12 +134,12 @@ pub fn mutate_expressions(py: Python, parsed_query: &Bound<'_, PyAny>, func: &Bo #[pyfunction] #[pyo3(text_signature = "(parsed_query)")] -pub fn extract_expressions(py: Python, parsed_query: &Bound<'_, PyAny>) -> PyResult { +pub fn extract_expressions(py: Python, parsed_query: &Bound<'_, PyAny>) -> PyResult> { let statements: Vec = depythonize_query(parsed_query)?; let mut expressions = Vec::new(); for statement in statements { - visit_expressions(&statement, |expr| { + let _ = visit_expressions(&statement, |expr| { expressions.push(expr.clone()); ControlFlow::<()>::Continue(()) }); diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/benchmark.py b/tests/benchmark.py index 3dcac8d..97a35a3 100644 --- a/tests/benchmark.py +++ b/tests/benchmark.py @@ -1,10 +1,6 @@ -import pytest - from sqloxide import parse_sql import sqlparse import sqlglot -import json -import moz_sql_parser TEST_SQL = """ SELECT employee.first_name, employee.last_name, @@ -24,10 +20,6 @@ def bench_sqlparser(): return sqlparse.parse(TEST_SQL)[0] -def bench_mozsqlparser(): - return json.dumps(moz_sql_parser.parse(TEST_SQL)) - - def bench_sqlglot(): return sqlglot.parse(TEST_SQL, error_level=sqlglot.ErrorLevel.IGNORE) @@ -40,9 +32,5 @@ def test_sqlparser(benchmark): benchmark(bench_sqlparser) -def test_mozsqlparser(benchmark): - benchmark(bench_mozsqlparser) - - def test_sqlglot(benchmark): benchmark(bench_sqlglot) diff --git a/tests/test_sqloxide.py b/tests/test_sqloxide.py index 2014345..e4c64ed 100644 --- a/tests/test_sqloxide.py +++ b/tests/test_sqloxide.py @@ -60,7 +60,7 @@ def func(x): ast = parse_sql(sql=SQL, dialect="ansi") assert mutate_relations(parsed_query=ast, func=func) == [ - 'SELECT employee.first_name, employee.last_name, c.start_time, c.end_time, call_outcome.outcome_text FROM employee INNER JOIN "call2"."call2"."call2" AS c ON c.employee_id = employee.id INNER JOIN call2_outcome ON c.call_outcome_id = call_outcome.id ORDER BY c.start_time ASC' + 'SELECT employee.first_name, employee.last_name, c.start_time, c.end_time, call_outcome.outcome_text FROM employee INNER JOIN "call2"."call2"."call2" c ON c.employee_id = employee.id INNER JOIN call2_outcome ON c.call_outcome_id = call_outcome.id ORDER BY c.start_time ASC' ] @@ -87,7 +87,7 @@ def func(x): ast = parse_sql(sql=SQL, dialect="ansi") result = mutate_expressions(parsed_query=ast, func=func) assert result == [ - 'SELECT EMPLOYEE.FIRST_NAME, EMPLOYEE.LAST_NAME, C.START_TIME, C.END_TIME, CALL_OUTCOME.OUTCOME_TEXT FROM employee INNER JOIN "call"."call"."call" AS c ON C.EMPLOYEE_ID = EMPLOYEE.ID INNER JOIN call_outcome ON C.CALL_OUTCOME_ID = CALL_OUTCOME.ID ORDER BY C.START_TIME ASC' + 'SELECT EMPLOYEE.FIRST_NAME, EMPLOYEE.LAST_NAME, C.START_TIME, C.END_TIME, CALL_OUTCOME.OUTCOME_TEXT FROM employee INNER JOIN "call"."call"."call" c ON C.EMPLOYEE_ID = EMPLOYEE.ID INNER JOIN call_outcome ON C.CALL_OUTCOME_ID = CALL_OUTCOME.ID ORDER BY C.START_TIME ASC' ]