From 5971ef6f7430e4baf9cdd8d4e356b168402eb77c Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 31 Dec 2025 11:00:42 -0800 Subject: [PATCH 01/17] feat: use C -> Rust, bypass protobuf --- build.rs | 170 +++- src/ast/convert.rs | 1892 ++++++++++++++++++++++++++++++++++ src/ast/mod.rs | 31 + src/ast/nodes.rs | 1617 +++++++++++++++++++++++++++++ src/bindings_raw.rs | 12 + src/lib.rs | 4 + src/query.rs | 63 ++ src/raw_parse.rs | 1188 ++++++++++++++++++++++ tests/ast_tests.rs | 374 +++++++ tests/raw_parse_tests.rs | 2070 ++++++++++++++++++++++++++++++++++++++ 10 files changed, 7419 insertions(+), 2 deletions(-) create mode 100644 src/ast/convert.rs create mode 100644 src/ast/mod.rs create mode 100644 src/ast/nodes.rs create mode 100644 src/bindings_raw.rs create mode 100644 src/raw_parse.rs create mode 100644 tests/ast_tests.rs create mode 100644 tests/raw_parse_tests.rs diff --git a/build.rs b/build.rs index a722ad0..0cdaf8c 100644 --- a/build.rs +++ b/build.rs @@ -11,6 +11,7 @@ fn main() -> Result<(), Box> { let out_dir = PathBuf::from(env::var("OUT_DIR")?); let build_path = Path::new(".").join("libpg_query"); let out_header_path = out_dir.join("pg_query").with_extension("h"); + let out_raw_header_path = out_dir.join("pg_query_raw").with_extension("h"); let out_protobuf_path = out_dir.join("protobuf"); let target = env::var("TARGET").unwrap(); @@ -21,7 +22,7 @@ fn main() -> Result<(), Box> { // Copy the relevant source files to the OUT_DIR let source_paths = vec![ build_path.join("pg_query").with_extension("h"), - build_path.join("postgres_deparse").with_extension("h"), + build_path.join("pg_query_raw.h"), build_path.join("Makefile"), build_path.join("src"), build_path.join("protobuf"), @@ -56,13 +57,178 @@ fn main() -> Result<(), Box> { } build.compile("pg_query"); - // Generate bindings for Rust + // Generate bindings for Rust (basic API) bindgen::Builder::default() .header(out_header_path.to_str().ok_or("Invalid header path")?) + // Blocklist raw parse functions that are used via bindings_raw + .blocklist_function("pg_query_parse_raw") + .blocklist_function("pg_query_parse_raw_opts") + .blocklist_function("pg_query_free_raw_parse_result") + .blocklist_type("PgQueryRawParseResult") .generate() .map_err(|_| "Unable to generate bindings")? .write_to_file(out_dir.join("bindings.rs"))?; + // Generate bindings for raw parse tree access (includes PostgreSQL internal types) + let mut raw_builder = bindgen::Builder::default() + .header(out_raw_header_path.to_str().ok_or("Invalid raw header path")?) + .clang_arg(format!("-I{}", out_dir.display())) + .clang_arg(format!("-I{}", out_dir.join("src/postgres/include").display())) + .clang_arg(format!("-I{}", out_dir.join("src/include").display())); + + if target.contains("windows") { + raw_builder = raw_builder.clang_arg(format!("-I{}", out_dir.join("src/postgres/include/port/win32").display())); + if target.contains("msvc") { + raw_builder = raw_builder.clang_arg(format!("-I{}", out_dir.join("src/postgres/include/port/win32_msvc").display())); + } + } + + raw_builder + // Allowlist only the types we need for parse tree traversal + .allowlist_type("List") + .allowlist_type("ListCell") + .allowlist_type("Node") + .allowlist_type("NodeTag") + .allowlist_type("RawStmt") + .allowlist_type("SelectStmt") + .allowlist_type("InsertStmt") + .allowlist_type("UpdateStmt") + .allowlist_type("DeleteStmt") + .allowlist_type("MergeStmt") + .allowlist_type("CreateStmt") + .allowlist_type("AlterTableStmt") + .allowlist_type("DropStmt") + .allowlist_type("TruncateStmt") + .allowlist_type("IndexStmt") + .allowlist_type("ViewStmt") + .allowlist_type("RangeVar") + .allowlist_type("ColumnRef") + .allowlist_type("ResTarget") + .allowlist_type("A_Expr") + .allowlist_type("FuncCall") + .allowlist_type("TypeCast") + .allowlist_type("TypeName") + .allowlist_type("ColumnDef") + .allowlist_type("Constraint") + .allowlist_type("JoinExpr") + .allowlist_type("SortBy") + .allowlist_type("WindowDef") + .allowlist_type("WithClause") + .allowlist_type("CommonTableExpr") + .allowlist_type("IntoClause") + .allowlist_type("OnConflictClause") + .allowlist_type("InferClause") + .allowlist_type("Alias") + .allowlist_type("A_Const") + .allowlist_type("A_Star") + .allowlist_type("A_Indices") + .allowlist_type("A_Indirection") + .allowlist_type("A_ArrayExpr") + .allowlist_type("SubLink") + .allowlist_type("BoolExpr") + .allowlist_type("NullTest") + .allowlist_type("BooleanTest") + .allowlist_type("CaseExpr") + .allowlist_type("CaseWhen") + .allowlist_type("CoalesceExpr") + .allowlist_type("MinMaxExpr") + .allowlist_type("RowExpr") + .allowlist_type("SetToDefault") + .allowlist_type("MultiAssignRef") + .allowlist_type("ParamRef") + .allowlist_type("CollateClause") + .allowlist_type("PartitionSpec") + .allowlist_type("PartitionBoundSpec") + .allowlist_type("PartitionRangeDatum") + .allowlist_type("CTESearchClause") + .allowlist_type("CTECycleClause") + .allowlist_type("RangeSubselect") + .allowlist_type("RangeFunction") + .allowlist_type("DefElem") + .allowlist_type("IndexElem") + .allowlist_type("SortGroupClause") + .allowlist_type("GroupingSet") + .allowlist_type("LockingClause") + .allowlist_type("MergeWhenClause") + .allowlist_type("TransactionStmt") + .allowlist_type("VariableSetStmt") + .allowlist_type("VariableShowStmt") + .allowlist_type("ExplainStmt") + .allowlist_type("CopyStmt") + .allowlist_type("GrantStmt") + .allowlist_type("RoleSpec") + .allowlist_type("FunctionParameter") + .allowlist_type("AlterTableCmd") + .allowlist_type("AccessPriv") + .allowlist_type("ObjectWithArgs") + .allowlist_type("CreateFunctionStmt") + .allowlist_type("CreateSchemaStmt") + .allowlist_type("CreateSeqStmt") + .allowlist_type("CreateTrigStmt") + .allowlist_type("RuleStmt") + .allowlist_type("CreateDomainStmt") + .allowlist_type("CreateTableAsStmt") + .allowlist_type("RefreshMatViewStmt") + .allowlist_type("VacuumStmt") + .allowlist_type("DoStmt") + .allowlist_type("RenameStmt") + .allowlist_type("NotifyStmt") + .allowlist_type("ListenStmt") + .allowlist_type("UnlistenStmt") + .allowlist_type("PrepareStmt") + .allowlist_type("ExecuteStmt") + .allowlist_type("DeallocateStmt") + .allowlist_type("FetchStmt") + .allowlist_type("ClosePortalStmt") + .allowlist_type("String") + .allowlist_type("Integer") + .allowlist_type("Float") + .allowlist_type("Boolean") + .allowlist_type("BitString") + // Allowlist enums + .allowlist_type("SetOperation") + .allowlist_type("LimitOption") + .allowlist_type("A_Expr_Kind") + .allowlist_type("BoolExprType") + .allowlist_type("SubLinkType") + .allowlist_type("NullTestType") + .allowlist_type("BoolTestType") + .allowlist_type("MinMaxOp") + .allowlist_type("JoinType") + .allowlist_type("SortByDir") + .allowlist_type("SortByNulls") + .allowlist_type("CTEMaterialize") + .allowlist_type("OnCommitAction") + .allowlist_type("ObjectType") + .allowlist_type("DropBehavior") + .allowlist_type("OnConflictAction") + .allowlist_type("GroupingSetKind") + .allowlist_type("CmdType") + .allowlist_type("TransactionStmtKind") + .allowlist_type("ConstrType") + .allowlist_type("DefElemAction") + .allowlist_type("RoleSpecType") + .allowlist_type("CoercionForm") + .allowlist_type("VariableSetKind") + .allowlist_type("LockClauseStrength") + .allowlist_type("LockWaitPolicy") + .allowlist_type("ViewCheckOption") + .allowlist_type("DiscardMode") + .allowlist_type("FetchDirection") + .allowlist_type("FunctionParameterMode") + .allowlist_type("AlterTableType") + .allowlist_type("GrantTargetType") + .allowlist_type("OverridingKind") + .allowlist_type("PartitionStrategy") + .allowlist_type("PartitionRangeDatumKind") + // Allowlist raw parse functions + .allowlist_function("pg_query_parse_raw") + .allowlist_function("pg_query_parse_raw_opts") + .allowlist_function("pg_query_free_raw_parse_result") + .generate() + .map_err(|_| "Unable to generate raw bindings")? + .write_to_file(out_dir.join("bindings_raw.rs"))?; + // Only generate protobuf bindings if protoc is available let protoc_exists = Command::new("protoc").arg("--version").status().is_ok(); // If the package is being built by docs.rs, we don't want to regenerate the protobuf bindings diff --git a/src/ast/convert.rs b/src/ast/convert.rs new file mode 100644 index 0000000..65d8e1e --- /dev/null +++ b/src/ast/convert.rs @@ -0,0 +1,1892 @@ +//! Conversion implementations between protobuf types and native AST types. + +use crate::protobuf; +use crate::ast::nodes::*; + +// ============================================================================ +// From protobuf to native AST types +// ============================================================================ + +impl ParseResult { + /// Create a new ParseResult from a protobuf result. + /// This stores the original protobuf for later deparsing. + pub fn from_protobuf(pb: protobuf::ParseResult) -> Self { + let stmts = pb.stmts.iter().map(|s| s.into()).collect(); + ParseResult { + version: pb.version, + stmts, + original_protobuf: pb, + } + } + + /// Get a reference to the original protobuf for deparsing. + pub fn as_protobuf(&self) -> &protobuf::ParseResult { + &self.original_protobuf + } +} + +impl From for ParseResult { + fn from(pb: protobuf::ParseResult) -> Self { + ParseResult::from_protobuf(pb) + } +} + +impl From<&protobuf::ParseResult> for ParseResult { + fn from(pb: &protobuf::ParseResult) -> Self { + ParseResult::from_protobuf(pb.clone()) + } +} + +impl From for RawStmt { + fn from(pb: protobuf::RawStmt) -> Self { + RawStmt { + stmt: pb.stmt.map(|n| (*n).into()).unwrap_or(Node::Null), + stmt_location: pb.stmt_location, + stmt_len: pb.stmt_len, + } + } +} + +impl From<&protobuf::RawStmt> for RawStmt { + fn from(pb: &protobuf::RawStmt) -> Self { + RawStmt { + stmt: pb.stmt.as_ref().map(|n| n.as_ref().into()).unwrap_or(Node::Null), + stmt_location: pb.stmt_location, + stmt_len: pb.stmt_len, + } + } +} + +impl From for Node { + fn from(pb: protobuf::Node) -> Self { + match pb.node { + Some(node) => node.into(), + None => Node::Null, + } + } +} + +impl From<&protobuf::Node> for Node { + fn from(pb: &protobuf::Node) -> Self { + match &pb.node { + Some(node) => node.into(), + None => Node::Null, + } + } +} + +impl From for Node { + fn from(pb: protobuf::node::Node) -> Self { + use protobuf::node::Node as PbNode; + match pb { + // Primitive types (not boxed) + PbNode::Integer(v) => Node::Integer(v.into()), + PbNode::Float(v) => Node::Float(v.into()), + PbNode::Boolean(v) => Node::Boolean(v.into()), + PbNode::String(v) => Node::String(v.into()), + PbNode::BitString(v) => Node::BitString(v.into()), + PbNode::List(v) => Node::List(v.items.into_iter().map(|n| n.into()).collect()), + + // Statement types (boxed in protobuf) + PbNode::SelectStmt(v) => Node::SelectStmt(Box::new((*v).into())), + PbNode::InsertStmt(v) => Node::InsertStmt(Box::new((*v).into())), + PbNode::UpdateStmt(v) => Node::UpdateStmt(Box::new((*v).into())), + PbNode::DeleteStmt(v) => Node::DeleteStmt(Box::new((*v).into())), + PbNode::MergeStmt(v) => Node::MergeStmt(Box::new((*v).into())), + + // DDL statements (not boxed in protobuf) + PbNode::CreateStmt(v) => Node::CreateStmt(Box::new(v.into())), + PbNode::AlterTableStmt(v) => Node::AlterTableStmt(Box::new(v.into())), + PbNode::DropStmt(v) => Node::DropStmt(Box::new(v.into())), + PbNode::TruncateStmt(v) => Node::TruncateStmt(Box::new(v.into())), + PbNode::IndexStmt(v) => Node::IndexStmt(Box::new((*v).into())), + PbNode::CreateSchemaStmt(v) => Node::CreateSchemaStmt(Box::new(v.into())), + PbNode::ViewStmt(v) => Node::ViewStmt(Box::new((*v).into())), + PbNode::CreateFunctionStmt(v) => Node::CreateFunctionStmt(Box::new((*v).into())), + PbNode::AlterFunctionStmt(v) => Node::AlterFunctionStmt(Box::new(v.into())), + PbNode::CreateSeqStmt(v) => Node::CreateSeqStmt(Box::new(v.into())), + PbNode::AlterSeqStmt(v) => Node::AlterSeqStmt(Box::new(v.into())), + PbNode::CreateTrigStmt(v) => Node::CreateTrigStmt(Box::new((*v).into())), + PbNode::RuleStmt(v) => Node::RuleStmt(Box::new((*v).into())), + PbNode::CreateDomainStmt(v) => Node::CreateDomainStmt(Box::new((*v).into())), + PbNode::CreateTableAsStmt(v) => Node::CreateTableAsStmt(Box::new((*v).into())), + PbNode::RefreshMatViewStmt(v) => Node::RefreshMatViewStmt(Box::new(v.into())), + + // Transaction statements (not boxed in protobuf) + PbNode::TransactionStmt(v) => Node::TransactionStmt(Box::new(v.into())), + + // Expression types (mixed boxing) + PbNode::AExpr(v) => Node::AExpr(Box::new((*v).into())), + PbNode::ColumnRef(v) => Node::ColumnRef(Box::new(v.into())), + PbNode::ParamRef(v) => Node::ParamRef(Box::new(v.into())), + PbNode::AConst(v) => Node::AConst(Box::new(v.into())), + PbNode::TypeCast(v) => Node::TypeCast(Box::new((*v).into())), + PbNode::CollateClause(v) => Node::CollateClause(Box::new((*v).into())), + PbNode::FuncCall(v) => Node::FuncCall(Box::new((*v).into())), + PbNode::AStar(_) => Node::AStar(AStar), + PbNode::AIndices(v) => Node::AIndices(Box::new((*v).into())), + PbNode::AIndirection(v) => Node::AIndirection(Box::new((*v).into())), + PbNode::AArrayExpr(v) => Node::AArrayExpr(Box::new(v.into())), + PbNode::SubLink(v) => Node::SubLink(Box::new((*v).into())), + PbNode::BoolExpr(v) => Node::BoolExpr(Box::new((*v).into())), + PbNode::NullTest(v) => Node::NullTest(Box::new((*v).into())), + PbNode::BooleanTest(v) => Node::BooleanTest(Box::new((*v).into())), + PbNode::CaseExpr(v) => Node::CaseExpr(Box::new((*v).into())), + PbNode::CaseWhen(v) => Node::CaseWhen(Box::new((*v).into())), + PbNode::CoalesceExpr(v) => Node::CoalesceExpr(Box::new((*v).into())), + PbNode::MinMaxExpr(v) => Node::MinMaxExpr(Box::new((*v).into())), + PbNode::RowExpr(v) => Node::RowExpr(Box::new((*v).into())), + + // Target/Result types (boxed in protobuf) + PbNode::ResTarget(v) => Node::ResTarget(Box::new((*v).into())), + + // Table/Range types (mixed) + PbNode::RangeVar(v) => Node::RangeVar(Box::new(v.into())), + PbNode::RangeSubselect(v) => Node::RangeSubselect(Box::new((*v).into())), + PbNode::RangeFunction(v) => Node::RangeFunction(Box::new(v.into())), + PbNode::JoinExpr(v) => Node::JoinExpr(Box::new((*v).into())), + + // Clause types (mixed) + PbNode::SortBy(v) => Node::SortBy(Box::new((*v).into())), + PbNode::WindowDef(v) => Node::WindowDef(Box::new((*v).into())), + PbNode::WithClause(v) => Node::WithClause(Box::new(v.into())), + PbNode::CommonTableExpr(v) => Node::CommonTableExpr(Box::new((*v).into())), + PbNode::IntoClause(v) => Node::IntoClause(Box::new((*v).into())), + PbNode::OnConflictClause(v) => Node::OnConflictClause(Box::new((*v).into())), + PbNode::LockingClause(v) => Node::LockingClause(Box::new(v.into())), + PbNode::GroupingSet(v) => Node::GroupingSet(Box::new(v.into())), + PbNode::MergeWhenClause(v) => Node::MergeWhenClause(Box::new((*v).into())), + + // Type-related (mixed) + PbNode::TypeName(v) => Node::TypeName(Box::new(v.into())), + PbNode::ColumnDef(v) => Node::ColumnDef(Box::new((*v).into())), + PbNode::Constraint(v) => Node::Constraint(Box::new((*v).into())), + PbNode::DefElem(v) => Node::DefElem(Box::new((*v).into())), + PbNode::IndexElem(v) => Node::IndexElem(Box::new((*v).into())), + + // Alias and role types (not boxed) + PbNode::Alias(v) => Node::Alias(Box::new(v.into())), + PbNode::RoleSpec(v) => Node::RoleSpec(Box::new(v.into())), + + // Other commonly used types (mixed) + PbNode::SortGroupClause(v) => Node::SortGroupClause(Box::new(v.into())), + PbNode::FunctionParameter(v) => Node::FunctionParameter(Box::new((*v).into())), + PbNode::AlterTableCmd(v) => Node::AlterTableCmd(Box::new((*v).into())), + PbNode::AccessPriv(v) => Node::AccessPriv(Box::new(v.into())), + PbNode::ObjectWithArgs(v) => Node::ObjectWithArgs(Box::new(v.into())), + + // Administrative statements (mixed) + PbNode::VariableSetStmt(v) => Node::VariableSetStmt(Box::new(v.into())), + PbNode::VariableShowStmt(v) => Node::VariableShowStmt(Box::new(v.into())), + PbNode::ExplainStmt(v) => Node::ExplainStmt(Box::new((*v).into())), + PbNode::CopyStmt(v) => Node::CopyStmt(Box::new((*v).into())), + PbNode::GrantStmt(v) => Node::GrantStmt(Box::new(v.into())), + PbNode::GrantRoleStmt(v) => Node::GrantRoleStmt(Box::new(v.into())), + PbNode::LockStmt(v) => Node::LockStmt(Box::new(v.into())), + PbNode::VacuumStmt(v) => Node::VacuumStmt(Box::new(v.into())), + + // Other statements (mixed) + PbNode::DoStmt(v) => Node::DoStmt(Box::new(v.into())), + PbNode::RenameStmt(v) => Node::RenameStmt(Box::new((*v).into())), + PbNode::NotifyStmt(v) => Node::NotifyStmt(Box::new(v.into())), + PbNode::ListenStmt(v) => Node::ListenStmt(Box::new(v.into())), + PbNode::UnlistenStmt(v) => Node::UnlistenStmt(Box::new(v.into())), + PbNode::CheckPointStmt(_) => Node::CheckPointStmt(Box::new(CheckPointStmt)), + PbNode::DiscardStmt(v) => Node::DiscardStmt(Box::new(v.into())), + PbNode::PrepareStmt(v) => Node::PrepareStmt(Box::new((*v).into())), + PbNode::ExecuteStmt(v) => Node::ExecuteStmt(Box::new(v.into())), + PbNode::DeallocateStmt(v) => Node::DeallocateStmt(Box::new(v.into())), + PbNode::ClosePortalStmt(v) => Node::ClosePortalStmt(Box::new(v.into())), + PbNode::FetchStmt(v) => Node::FetchStmt(Box::new(v.into())), + + // Fallback for any unhandled node types + other => Node::Other(protobuf::Node { node: Some(other) }), + } + } +} + +impl From<&protobuf::node::Node> for Node { + fn from(pb: &protobuf::node::Node) -> Self { + pb.clone().into() + } +} + +// Conversions from Box for boxed protobuf fields +impl From> for Node { + fn from(pb: Box) -> Self { + (*pb).into() + } +} + +impl From> for IntoClause { + fn from(pb: Box) -> Self { + (*pb).into() + } +} + +impl From> for OnConflictClause { + fn from(pb: Box) -> Self { + (*pb).into() + } +} + +impl From> for CollateClause { + fn from(pb: Box) -> Self { + (*pb).into() + } +} + +impl From> for SelectStmt { + fn from(pb: Box) -> Self { + (*pb).into() + } +} + +// Primitive type conversions +impl From for Integer { + fn from(pb: protobuf::Integer) -> Self { + Integer { ival: pb.ival } + } +} + +impl From for Float { + fn from(pb: protobuf::Float) -> Self { + Float { fval: pb.fval } + } +} + +impl From for Boolean { + fn from(pb: protobuf::Boolean) -> Self { + Boolean { boolval: pb.boolval } + } +} + +impl From for StringValue { + fn from(pb: protobuf::String) -> Self { + StringValue { sval: pb.sval } + } +} + +impl From for BitString { + fn from(pb: protobuf::BitString) -> Self { + BitString { bsval: pb.bsval } + } +} + +// Statement type conversions +impl From for SelectStmt { + fn from(pb: protobuf::SelectStmt) -> Self { + SelectStmt { + distinct_clause: pb.distinct_clause.into_iter().map(|n| n.into()).collect(), + into_clause: pb.into_clause.map(|v| v.into()), + target_list: pb.target_list.into_iter().map(|n| n.into()).collect(), + from_clause: pb.from_clause.into_iter().map(|n| n.into()).collect(), + where_clause: pb.where_clause.map(|n| n.into()), + group_clause: pb.group_clause.into_iter().map(|n| n.into()).collect(), + group_distinct: pb.group_distinct, + having_clause: pb.having_clause.map(|n| n.into()), + window_clause: pb.window_clause.into_iter().map(|n| n.into()).collect(), + values_lists: pb.values_lists.into_iter().map(|n| n.into()).collect(), + sort_clause: pb.sort_clause.into_iter().map(|n| n.into()).collect(), + limit_offset: pb.limit_offset.map(|n| n.into()), + limit_count: pb.limit_count.map(|n| n.into()), + limit_option: pb.limit_option.into(), + locking_clause: pb.locking_clause.into_iter().map(|n| n.into()).collect(), + with_clause: pb.with_clause.map(|v| v.into()), + op: pb.op.into(), + all: pb.all, + larg: pb.larg.map(|v| Box::new((*v).into())), + rarg: pb.rarg.map(|v| Box::new((*v).into())), + } + } +} + +impl From for InsertStmt { + fn from(pb: protobuf::InsertStmt) -> Self { + InsertStmt { + relation: pb.relation.map(|v| v.into()), + cols: pb.cols.into_iter().map(|n| n.into()).collect(), + select_stmt: pb.select_stmt.map(|n| n.into()), + on_conflict_clause: pb.on_conflict_clause.map(|v| v.into()), + returning_list: pb.returning_list.into_iter().map(|n| n.into()).collect(), + with_clause: pb.with_clause.map(|v| v.into()), + override_: pb.r#override.into(), + } + } +} + +impl From for UpdateStmt { + fn from(pb: protobuf::UpdateStmt) -> Self { + UpdateStmt { + relation: pb.relation.map(|v| v.into()), + target_list: pb.target_list.into_iter().map(|n| n.into()).collect(), + where_clause: pb.where_clause.map(|n| n.into()), + from_clause: pb.from_clause.into_iter().map(|n| n.into()).collect(), + returning_list: pb.returning_list.into_iter().map(|n| n.into()).collect(), + with_clause: pb.with_clause.map(|v| v.into()), + } + } +} + +impl From for DeleteStmt { + fn from(pb: protobuf::DeleteStmt) -> Self { + DeleteStmt { + relation: pb.relation.map(|v| v.into()), + using_clause: pb.using_clause.into_iter().map(|n| n.into()).collect(), + where_clause: pb.where_clause.map(|n| n.into()), + returning_list: pb.returning_list.into_iter().map(|n| n.into()).collect(), + with_clause: pb.with_clause.map(|v| v.into()), + } + } +} + +impl From for MergeStmt { + fn from(pb: protobuf::MergeStmt) -> Self { + MergeStmt { + relation: pb.relation.map(|v| v.into()), + source_relation: pb.source_relation.map(|n| n.into()), + join_condition: pb.join_condition.map(|n| n.into()), + merge_when_clauses: pb.merge_when_clauses.into_iter().map(|n| n.into()).collect(), + with_clause: pb.with_clause.map(|v| v.into()), + } + } +} + +// DDL statement conversions +impl From for CreateStmt { + fn from(pb: protobuf::CreateStmt) -> Self { + CreateStmt { + relation: pb.relation.map(|v| v.into()), + table_elts: pb.table_elts.into_iter().map(|n| n.into()).collect(), + inh_relations: pb.inh_relations.into_iter().map(|n| n.into()).collect(), + partbound: pb.partbound.map(|n| Node::Other(protobuf::Node { node: Some(protobuf::node::Node::PartitionBoundSpec(n)) })), + partspec: pb.partspec.map(|n| Node::Other(protobuf::Node { node: Some(protobuf::node::Node::PartitionSpec(n)) })), + of_typename: pb.of_typename.map(|v| v.into()), + constraints: pb.constraints.into_iter().map(|n| n.into()).collect(), + options: pb.options.into_iter().map(|n| n.into()).collect(), + oncommit: pb.oncommit.into(), + tablespacename: pb.tablespacename, + access_method: pb.access_method, + if_not_exists: pb.if_not_exists, + } + } +} + +impl From for AlterTableStmt { + fn from(pb: protobuf::AlterTableStmt) -> Self { + AlterTableStmt { + relation: pb.relation.map(|v| v.into()), + cmds: pb.cmds.into_iter().map(|n| n.into()).collect(), + objtype: pb.objtype.into(), + missing_ok: pb.missing_ok, + } + } +} + +impl From for DropStmt { + fn from(pb: protobuf::DropStmt) -> Self { + DropStmt { + objects: pb.objects.into_iter().map(|n| n.into()).collect(), + remove_type: pb.remove_type.into(), + behavior: pb.behavior.into(), + missing_ok: pb.missing_ok, + concurrent: pb.concurrent, + } + } +} + +impl From for TruncateStmt { + fn from(pb: protobuf::TruncateStmt) -> Self { + TruncateStmt { + relations: pb.relations.into_iter().map(|n| n.into()).collect(), + restart_seqs: pb.restart_seqs, + behavior: pb.behavior.into(), + } + } +} + +impl From for IndexStmt { + fn from(pb: protobuf::IndexStmt) -> Self { + IndexStmt { + idxname: pb.idxname, + relation: pb.relation.map(|v| v.into()), + access_method: pb.access_method, + table_space: pb.table_space, + index_params: pb.index_params.into_iter().map(|n| n.into()).collect(), + index_including_params: pb.index_including_params.into_iter().map(|n| n.into()).collect(), + options: pb.options.into_iter().map(|n| n.into()).collect(), + where_clause: pb.where_clause.map(|n| n.into()), + exclude_op_names: pb.exclude_op_names.into_iter().map(|n| n.into()).collect(), + idxcomment: pb.idxcomment, + index_oid: pb.index_oid, + old_number: pb.old_number, + old_first_relfilelocator: pb.old_first_relfilelocator_subid, + unique: pb.unique, + nulls_not_distinct: pb.nulls_not_distinct, + primary: pb.primary, + is_constraint: pb.isconstraint, + deferrable: pb.deferrable, + initdeferred: pb.initdeferred, + transformed: pb.transformed, + concurrent: pb.concurrent, + if_not_exists: pb.if_not_exists, + reset_default_tblspc: pb.reset_default_tblspc, + } + } +} + +impl From for CreateSchemaStmt { + fn from(pb: protobuf::CreateSchemaStmt) -> Self { + CreateSchemaStmt { + schemaname: pb.schemaname, + authrole: pb.authrole.map(|v| v.into()), + schema_elts: pb.schema_elts.into_iter().map(|n| n.into()).collect(), + if_not_exists: pb.if_not_exists, + } + } +} + +impl From for ViewStmt { + fn from(pb: protobuf::ViewStmt) -> Self { + ViewStmt { + view: pb.view.map(|v| v.into()), + aliases: pb.aliases.into_iter().map(|n| n.into()).collect(), + query: pb.query.map(|n| n.into()), + replace: pb.replace, + options: pb.options.into_iter().map(|n| n.into()).collect(), + with_check_option: pb.with_check_option.into(), + } + } +} + +impl From for CreateFunctionStmt { + fn from(pb: protobuf::CreateFunctionStmt) -> Self { + CreateFunctionStmt { + is_procedure: pb.is_procedure, + replace: pb.replace, + funcname: pb.funcname.into_iter().map(|n| n.into()).collect(), + parameters: pb.parameters.into_iter().map(|n| n.into()).collect(), + return_type: pb.return_type.map(|v| v.into()), + options: pb.options.into_iter().map(|n| n.into()).collect(), + sql_body: pb.sql_body.map(|n| n.into()), + } + } +} + +impl From for AlterFunctionStmt { + fn from(pb: protobuf::AlterFunctionStmt) -> Self { + AlterFunctionStmt { + objtype: pb.objtype.into(), + func: pb.func.map(|v| v.into()), + actions: pb.actions.into_iter().map(|n| n.into()).collect(), + } + } +} + +impl From for CreateSeqStmt { + fn from(pb: protobuf::CreateSeqStmt) -> Self { + CreateSeqStmt { + sequence: pb.sequence.map(|v| v.into()), + options: pb.options.into_iter().map(|n| n.into()).collect(), + owner_id: pb.owner_id, + for_identity: pb.for_identity, + if_not_exists: pb.if_not_exists, + } + } +} + +impl From for AlterSeqStmt { + fn from(pb: protobuf::AlterSeqStmt) -> Self { + AlterSeqStmt { + sequence: pb.sequence.map(|v| v.into()), + options: pb.options.into_iter().map(|n| n.into()).collect(), + for_identity: pb.for_identity, + missing_ok: pb.missing_ok, + } + } +} + +impl From for CreateTrigStmt { + fn from(pb: protobuf::CreateTrigStmt) -> Self { + CreateTrigStmt { + replace: pb.replace, + isconstraint: pb.isconstraint, + trigname: pb.trigname, + relation: pb.relation.map(|v| v.into()), + funcname: pb.funcname.into_iter().map(|n| n.into()).collect(), + args: pb.args.into_iter().map(|n| n.into()).collect(), + row: pb.row, + timing: pb.timing, + events: pb.events, + columns: pb.columns.into_iter().map(|n| n.into()).collect(), + when_clause: pb.when_clause.map(|n| n.into()), + transition_rels: pb.transition_rels.into_iter().map(|n| n.into()).collect(), + deferrable: pb.deferrable, + initdeferred: pb.initdeferred, + constrrel: pb.constrrel.map(|v| v.into()), + } + } +} + +impl From for RuleStmt { + fn from(pb: protobuf::RuleStmt) -> Self { + RuleStmt { + relation: pb.relation.map(|v| v.into()), + rulename: pb.rulename, + where_clause: pb.where_clause.map(|n| n.into()), + event: pb.event.into(), + instead: pb.instead, + actions: pb.actions.into_iter().map(|n| n.into()).collect(), + replace: pb.replace, + } + } +} + +impl From for CreateDomainStmt { + fn from(pb: protobuf::CreateDomainStmt) -> Self { + CreateDomainStmt { + domainname: pb.domainname.into_iter().map(|n| n.into()).collect(), + type_name: pb.type_name.map(|v| v.into()), + coll_clause: pb.coll_clause.map(|v| v.into()), + constraints: pb.constraints.into_iter().map(|n| n.into()).collect(), + } + } +} + +impl From for CreateTableAsStmt { + fn from(pb: protobuf::CreateTableAsStmt) -> Self { + CreateTableAsStmt { + query: pb.query.map(|n| n.into()), + into: pb.into.map(|v| v.into()), + objtype: pb.objtype.into(), + is_select_into: pb.is_select_into, + if_not_exists: pb.if_not_exists, + } + } +} + +impl From for RefreshMatViewStmt { + fn from(pb: protobuf::RefreshMatViewStmt) -> Self { + RefreshMatViewStmt { + concurrent: pb.concurrent, + skip_data: pb.skip_data, + relation: pb.relation.map(|v| v.into()), + } + } +} + +// Transaction statement +impl From for TransactionStmt { + fn from(pb: protobuf::TransactionStmt) -> Self { + TransactionStmt { + kind: pb.kind.into(), + options: pb.options.into_iter().map(|n| n.into()).collect(), + savepoint_name: pb.savepoint_name, + gid: pb.gid, + chain: pb.chain, + } + } +} + +// Expression type conversions +impl From for AExpr { + fn from(pb: protobuf::AExpr) -> Self { + AExpr { + kind: pb.kind.into(), + name: pb.name.into_iter().map(|n| n.into()).collect(), + lexpr: pb.lexpr.map(|n| n.into()), + rexpr: pb.rexpr.map(|n| n.into()), + location: pb.location, + } + } +} + +impl From for ColumnRef { + fn from(pb: protobuf::ColumnRef) -> Self { + ColumnRef { + fields: pb.fields.into_iter().map(|n| n.into()).collect(), + location: pb.location, + } + } +} + +impl From for ParamRef { + fn from(pb: protobuf::ParamRef) -> Self { + ParamRef { + number: pb.number, + location: pb.location, + } + } +} + +impl From for AConst { + fn from(pb: protobuf::AConst) -> Self { + AConst { + val: pb.val.map(|v| v.into()), + isnull: pb.isnull, + location: pb.location, + } + } +} + +impl From for AConstValue { + fn from(pb: protobuf::a_const::Val) -> Self { + use protobuf::a_const::Val; + match pb { + Val::Ival(v) => AConstValue::Integer(v.into()), + Val::Fval(v) => AConstValue::Float(v.into()), + Val::Boolval(v) => AConstValue::Boolean(v.into()), + Val::Sval(v) => AConstValue::String(v.into()), + Val::Bsval(v) => AConstValue::BitString(v.into()), + } + } +} + +impl From for TypeCast { + fn from(pb: protobuf::TypeCast) -> Self { + TypeCast { + arg: pb.arg.map(|n| n.into()), + type_name: pb.type_name.map(|v| v.into()), + location: pb.location, + } + } +} + +impl From for CollateClause { + fn from(pb: protobuf::CollateClause) -> Self { + CollateClause { + arg: pb.arg.map(|n| n.into()), + collname: pb.collname.into_iter().map(|n| n.into()).collect(), + location: pb.location, + } + } +} + +impl From for FuncCall { + fn from(pb: protobuf::FuncCall) -> Self { + FuncCall { + funcname: pb.funcname.into_iter().map(|n| n.into()).collect(), + args: pb.args.into_iter().map(|n| n.into()).collect(), + agg_order: pb.agg_order.into_iter().map(|n| n.into()).collect(), + agg_filter: pb.agg_filter.map(|n| n.into()), + over: pb.over.map(|v| (*v).into()), + agg_within_group: pb.agg_within_group, + agg_star: pb.agg_star, + agg_distinct: pb.agg_distinct, + func_variadic: pb.func_variadic, + funcformat: pb.funcformat.into(), + location: pb.location, + } + } +} + +impl From for AIndices { + fn from(pb: protobuf::AIndices) -> Self { + AIndices { + is_slice: pb.is_slice, + lidx: pb.lidx.map(|n| n.into()), + uidx: pb.uidx.map(|n| n.into()), + } + } +} + +impl From for AIndirection { + fn from(pb: protobuf::AIndirection) -> Self { + AIndirection { + arg: pb.arg.map(|n| n.into()), + indirection: pb.indirection.into_iter().map(|n| n.into()).collect(), + } + } +} + +impl From for AArrayExpr { + fn from(pb: protobuf::AArrayExpr) -> Self { + AArrayExpr { + elements: pb.elements.into_iter().map(|n| n.into()).collect(), + location: pb.location, + } + } +} + +impl From for SubLink { + fn from(pb: protobuf::SubLink) -> Self { + SubLink { + sub_link_type: pb.sub_link_type.into(), + sub_link_id: pb.sub_link_id, + testexpr: pb.testexpr.map(|n| n.into()), + oper_name: pb.oper_name.into_iter().map(|n| n.into()).collect(), + subselect: pb.subselect.map(|n| n.into()), + location: pb.location, + } + } +} + +impl From for BoolExpr { + fn from(pb: protobuf::BoolExpr) -> Self { + BoolExpr { + boolop: pb.boolop.into(), + args: pb.args.into_iter().map(|n| n.into()).collect(), + location: pb.location, + } + } +} + +impl From for NullTest { + fn from(pb: protobuf::NullTest) -> Self { + NullTest { + arg: pb.arg.map(|n| n.into()), + nulltesttype: pb.nulltesttype.into(), + argisrow: pb.argisrow, + location: pb.location, + } + } +} + +impl From for BooleanTest { + fn from(pb: protobuf::BooleanTest) -> Self { + BooleanTest { + arg: pb.arg.map(|n| n.into()), + booltesttype: pb.booltesttype.into(), + location: pb.location, + } + } +} + +impl From for CaseExpr { + fn from(pb: protobuf::CaseExpr) -> Self { + CaseExpr { + arg: pb.arg.map(|n| n.into()), + args: pb.args.into_iter().map(|n| n.into()).collect(), + defresult: pb.defresult.map(|n| n.into()), + location: pb.location, + } + } +} + +impl From for CaseWhen { + fn from(pb: protobuf::CaseWhen) -> Self { + CaseWhen { + expr: pb.expr.map(|n| n.into()), + result: pb.result.map(|n| n.into()), + location: pb.location, + } + } +} + +impl From for CoalesceExpr { + fn from(pb: protobuf::CoalesceExpr) -> Self { + CoalesceExpr { + args: pb.args.into_iter().map(|n| n.into()).collect(), + location: pb.location, + } + } +} + +impl From for MinMaxExpr { + fn from(pb: protobuf::MinMaxExpr) -> Self { + MinMaxExpr { + op: pb.op.into(), + args: pb.args.into_iter().map(|n| n.into()).collect(), + location: pb.location, + } + } +} + +impl From for RowExpr { + fn from(pb: protobuf::RowExpr) -> Self { + RowExpr { + args: pb.args.into_iter().map(|n| n.into()).collect(), + row_format: pb.row_format.into(), + colnames: pb.colnames.into_iter().map(|n| n.into()).collect(), + location: pb.location, + } + } +} + +// Target/Result type conversions +impl From for ResTarget { + fn from(pb: protobuf::ResTarget) -> Self { + ResTarget { + name: pb.name, + indirection: pb.indirection.into_iter().map(|n| n.into()).collect(), + val: pb.val.map(|n| n.into()), + location: pb.location, + } + } +} + +// Table/Range type conversions +impl From for RangeVar { + fn from(pb: protobuf::RangeVar) -> Self { + RangeVar { + catalogname: pb.catalogname, + schemaname: pb.schemaname, + relname: pb.relname, + inh: pb.inh, + relpersistence: pb.relpersistence, + alias: pb.alias.map(|v| v.into()), + location: pb.location, + } + } +} + +impl From for RangeSubselect { + fn from(pb: protobuf::RangeSubselect) -> Self { + RangeSubselect { + lateral: pb.lateral, + subquery: pb.subquery.map(|n| n.into()), + alias: pb.alias.map(|v| v.into()), + } + } +} + +impl From for RangeFunction { + fn from(pb: protobuf::RangeFunction) -> Self { + RangeFunction { + lateral: pb.lateral, + ordinality: pb.ordinality, + is_rowsfrom: pb.is_rowsfrom, + functions: pb.functions.into_iter().map(|n| n.into()).collect(), + alias: pb.alias.map(|v| v.into()), + coldeflist: pb.coldeflist.into_iter().map(|n| n.into()).collect(), + } + } +} + +impl From for JoinExpr { + fn from(pb: protobuf::JoinExpr) -> Self { + JoinExpr { + jointype: pb.jointype.into(), + is_natural: pb.is_natural, + larg: pb.larg.map(|n| n.into()), + rarg: pb.rarg.map(|n| n.into()), + using_clause: pb.using_clause.into_iter().map(|n| n.into()).collect(), + join_using_alias: pb.join_using_alias.map(|v| v.into()), + quals: pb.quals.map(|n| n.into()), + alias: pb.alias.map(|v| v.into()), + rtindex: pb.rtindex, + } + } +} + +// Clause type conversions +impl From for SortBy { + fn from(pb: protobuf::SortBy) -> Self { + SortBy { + node: pb.node.map(|n| n.into()), + sortby_dir: pb.sortby_dir.into(), + sortby_nulls: pb.sortby_nulls.into(), + use_op: pb.use_op.into_iter().map(|n| n.into()).collect(), + location: pb.location, + } + } +} + +impl From for WindowDef { + fn from(pb: protobuf::WindowDef) -> Self { + WindowDef { + name: pb.name, + refname: pb.refname, + partition_clause: pb.partition_clause.into_iter().map(|n| n.into()).collect(), + order_clause: pb.order_clause.into_iter().map(|n| n.into()).collect(), + frame_options: pb.frame_options, + start_offset: pb.start_offset.map(|n| n.into()), + end_offset: pb.end_offset.map(|n| n.into()), + location: pb.location, + } + } +} + +impl From for WithClause { + fn from(pb: protobuf::WithClause) -> Self { + WithClause { + ctes: pb.ctes.into_iter().map(|n| n.into()).collect(), + recursive: pb.recursive, + location: pb.location, + } + } +} + +impl From for CommonTableExpr { + fn from(pb: protobuf::CommonTableExpr) -> Self { + CommonTableExpr { + ctename: pb.ctename, + aliascolnames: pb.aliascolnames.into_iter().map(|n| n.into()).collect(), + ctematerialized: pb.ctematerialized.into(), + ctequery: pb.ctequery.map(|n| n.into()), + search_clause: pb.search_clause.map(|n| Node::Other(protobuf::Node { node: Some(protobuf::node::Node::CtesearchClause(n)) })), + cycle_clause: pb.cycle_clause.map(|n| Node::Other(protobuf::Node { node: Some(protobuf::node::Node::CtecycleClause(n)) })), + location: pb.location, + cterecursive: pb.cterecursive, + cterefcount: pb.cterefcount, + ctecolnames: pb.ctecolnames.into_iter().map(|n| n.into()).collect(), + ctecoltypes: pb.ctecoltypes.into_iter().map(|n| n.into()).collect(), + ctecoltypmods: pb.ctecoltypmods.into_iter().map(|n| n.into()).collect(), + ctecolcollations: pb.ctecolcollations.into_iter().map(|n| n.into()).collect(), + } + } +} + +impl From for IntoClause { + fn from(pb: protobuf::IntoClause) -> Self { + IntoClause { + rel: pb.rel.map(|v| v.into()), + col_names: pb.col_names.into_iter().map(|n| n.into()).collect(), + access_method: pb.access_method, + options: pb.options.into_iter().map(|n| n.into()).collect(), + on_commit: pb.on_commit.into(), + table_space_name: pb.table_space_name, + view_query: pb.view_query.map(|n| n.into()), + skip_data: pb.skip_data, + } + } +} + +impl From for OnConflictClause { + fn from(pb: protobuf::OnConflictClause) -> Self { + OnConflictClause { + action: pb.action.into(), + infer: pb.infer.map(|n| Node::Other(protobuf::Node { node: Some(protobuf::node::Node::InferClause(n)) })), + target_list: pb.target_list.into_iter().map(|n| n.into()).collect(), + where_clause: pb.where_clause.map(|n| n.into()), + location: pb.location, + } + } +} + +impl From for LockingClause { + fn from(pb: protobuf::LockingClause) -> Self { + LockingClause { + locked_rels: pb.locked_rels.into_iter().map(|n| n.into()).collect(), + strength: pb.strength.into(), + wait_policy: pb.wait_policy.into(), + } + } +} + +impl From for GroupingSet { + fn from(pb: protobuf::GroupingSet) -> Self { + GroupingSet { + kind: pb.kind.into(), + content: pb.content.into_iter().map(|n| n.into()).collect(), + location: pb.location, + } + } +} + +impl From for MergeWhenClause { + fn from(pb: protobuf::MergeWhenClause) -> Self { + MergeWhenClause { + matched: pb.matched, + command_type: pb.command_type.into(), + override_: pb.r#override.into(), + condition: pb.condition.map(|n| n.into()), + target_list: pb.target_list.into_iter().map(|n| n.into()).collect(), + values: pb.values.into_iter().map(|n| n.into()).collect(), + } + } +} + +// Type-related conversions +impl From for TypeName { + fn from(pb: protobuf::TypeName) -> Self { + TypeName { + names: pb.names.into_iter().map(|n| n.into()).collect(), + type_oid: pb.type_oid, + setof: pb.setof, + pct_type: pb.pct_type, + typmods: pb.typmods.into_iter().map(|n| n.into()).collect(), + typemod: pb.typemod, + array_bounds: pb.array_bounds.into_iter().map(|n| n.into()).collect(), + location: pb.location, + } + } +} + +impl From for ColumnDef { + fn from(pb: protobuf::ColumnDef) -> Self { + ColumnDef { + colname: pb.colname, + type_name: pb.type_name.map(|v| v.into()), + compression: pb.compression, + inhcount: pb.inhcount, + is_local: pb.is_local, + is_not_null: pb.is_not_null, + is_from_type: pb.is_from_type, + storage: pb.storage, + storage_name: pb.storage_name, + raw_default: pb.raw_default.map(|n| n.into()), + cooked_default: pb.cooked_default.map(|n| n.into()), + identity: pb.identity, + identity_sequence: pb.identity_sequence.map(|v| v.into()), + generated: pb.generated, + coll_clause: pb.coll_clause.map(|v| v.into()), + coll_oid: pb.coll_oid, + constraints: pb.constraints.into_iter().map(|n| n.into()).collect(), + fdwoptions: pb.fdwoptions.into_iter().map(|n| n.into()).collect(), + location: pb.location, + } + } +} + +impl From for Constraint { + fn from(pb: protobuf::Constraint) -> Self { + Constraint { + contype: pb.contype.into(), + conname: pb.conname, + deferrable: pb.deferrable, + initdeferred: pb.initdeferred, + location: pb.location, + is_no_inherit: pb.is_no_inherit, + raw_expr: pb.raw_expr.map(|n| n.into()), + cooked_expr: pb.cooked_expr, + generated_when: pb.generated_when, + nulls_not_distinct: pb.nulls_not_distinct, + keys: pb.keys.into_iter().map(|n| n.into()).collect(), + including: pb.including.into_iter().map(|n| n.into()).collect(), + exclusions: pb.exclusions.into_iter().map(|n| n.into()).collect(), + options: pb.options.into_iter().map(|n| n.into()).collect(), + indexname: pb.indexname, + indexspace: pb.indexspace, + reset_default_tblspc: pb.reset_default_tblspc, + access_method: pb.access_method, + where_clause: pb.where_clause.map(|n| n.into()), + pktable: pb.pktable.map(|v| v.into()), + fk_attrs: pb.fk_attrs.into_iter().map(|n| n.into()).collect(), + pk_attrs: pb.pk_attrs.into_iter().map(|n| n.into()).collect(), + fk_matchtype: pb.fk_matchtype, + fk_upd_action: pb.fk_upd_action, + fk_del_action: pb.fk_del_action, + fk_del_set_cols: pb.fk_del_set_cols.into_iter().map(|n| n.into()).collect(), + old_conpfeqop: pb.old_conpfeqop.into_iter().map(|n| n.into()).collect(), + old_pktable_oid: pb.old_pktable_oid, + skip_validation: pb.skip_validation, + initially_valid: pb.initially_valid, + } + } +} + +impl From for DefElem { + fn from(pb: protobuf::DefElem) -> Self { + DefElem { + defnamespace: pb.defnamespace, + defname: pb.defname, + arg: pb.arg.map(|n| n.into()), + defaction: pb.defaction.into(), + location: pb.location, + } + } +} + +impl From for IndexElem { + fn from(pb: protobuf::IndexElem) -> Self { + IndexElem { + name: pb.name, + expr: pb.expr.map(|n| n.into()), + indexcolname: pb.indexcolname, + collation: pb.collation.into_iter().map(|n| n.into()).collect(), + opclass: pb.opclass.into_iter().map(|n| n.into()).collect(), + opclassopts: pb.opclassopts.into_iter().map(|n| n.into()).collect(), + ordering: pb.ordering.into(), + nulls_ordering: pb.nulls_ordering.into(), + } + } +} + +// Alias and role type conversions +impl From for Alias { + fn from(pb: protobuf::Alias) -> Self { + Alias { + aliasname: pb.aliasname, + colnames: pb.colnames.into_iter().map(|n| n.into()).collect(), + } + } +} + +impl From for RoleSpec { + fn from(pb: protobuf::RoleSpec) -> Self { + RoleSpec { + roletype: pb.roletype.into(), + rolename: pb.rolename, + location: pb.location, + } + } +} + +// Other type conversions +impl From for SortGroupClause { + fn from(pb: protobuf::SortGroupClause) -> Self { + SortGroupClause { + tle_sort_group_ref: pb.tle_sort_group_ref, + eqop: pb.eqop, + sortop: pb.sortop, + nulls_first: pb.nulls_first, + hashable: pb.hashable, + } + } +} + +impl From for FunctionParameter { + fn from(pb: protobuf::FunctionParameter) -> Self { + FunctionParameter { + name: pb.name, + arg_type: pb.arg_type.map(|v| v.into()), + mode: pb.mode.into(), + defexpr: pb.defexpr.map(|n| n.into()), + } + } +} + +impl From for AlterTableCmd { + fn from(pb: protobuf::AlterTableCmd) -> Self { + AlterTableCmd { + subtype: pb.subtype.into(), + name: pb.name, + num: pb.num as i16, + newowner: pb.newowner.map(|v| v.into()), + def: pb.def.map(|n| n.into()), + behavior: pb.behavior.into(), + missing_ok: pb.missing_ok, + recurse: pb.recurse, + } + } +} + +impl From for AccessPriv { + fn from(pb: protobuf::AccessPriv) -> Self { + AccessPriv { + priv_name: pb.priv_name, + cols: pb.cols.into_iter().map(|n| n.into()).collect(), + } + } +} + +impl From for ObjectWithArgs { + fn from(pb: protobuf::ObjectWithArgs) -> Self { + ObjectWithArgs { + objname: pb.objname.into_iter().map(|n| n.into()).collect(), + objargs: pb.objargs.into_iter().map(|n| n.into()).collect(), + objfuncargs: pb.objfuncargs.into_iter().map(|n| n.into()).collect(), + args_unspecified: pb.args_unspecified, + } + } +} + +// Administrative statement conversions +impl From for VariableSetStmt { + fn from(pb: protobuf::VariableSetStmt) -> Self { + VariableSetStmt { + kind: pb.kind.into(), + name: pb.name, + args: pb.args.into_iter().map(|n| n.into()).collect(), + is_local: pb.is_local, + } + } +} + +impl From for VariableShowStmt { + fn from(pb: protobuf::VariableShowStmt) -> Self { + VariableShowStmt { name: pb.name } + } +} + +impl From for ExplainStmt { + fn from(pb: protobuf::ExplainStmt) -> Self { + ExplainStmt { + query: pb.query.map(|n| n.into()), + options: pb.options.into_iter().map(|n| n.into()).collect(), + } + } +} + +impl From for CopyStmt { + fn from(pb: protobuf::CopyStmt) -> Self { + CopyStmt { + relation: pb.relation.map(|v| v.into()), + query: pb.query.map(|n| n.into()), + attlist: pb.attlist.into_iter().map(|n| n.into()).collect(), + is_from: pb.is_from, + is_program: pb.is_program, + filename: pb.filename, + options: pb.options.into_iter().map(|n| n.into()).collect(), + where_clause: pb.where_clause.map(|n| n.into()), + } + } +} + +impl From for GrantStmt { + fn from(pb: protobuf::GrantStmt) -> Self { + GrantStmt { + is_grant: pb.is_grant, + targtype: pb.targtype.into(), + objtype: pb.objtype.into(), + objects: pb.objects.into_iter().map(|n| n.into()).collect(), + privileges: pb.privileges.into_iter().map(|n| n.into()).collect(), + grantees: pb.grantees.into_iter().map(|n| n.into()).collect(), + grant_option: pb.grant_option, + grantor: pb.grantor.map(|v| v.into()), + behavior: pb.behavior.into(), + } + } +} + +impl From for GrantRoleStmt { + fn from(pb: protobuf::GrantRoleStmt) -> Self { + GrantRoleStmt { + granted_roles: pb.granted_roles.into_iter().map(|n| n.into()).collect(), + grantee_roles: pb.grantee_roles.into_iter().map(|n| n.into()).collect(), + is_grant: pb.is_grant, + opt: pb.opt.into_iter().map(|n| n.into()).collect(), + grantor: pb.grantor.map(|v| v.into()), + behavior: pb.behavior.into(), + } + } +} + +impl From for LockStmt { + fn from(pb: protobuf::LockStmt) -> Self { + LockStmt { + relations: pb.relations.into_iter().map(|n| n.into()).collect(), + mode: pb.mode, + nowait: pb.nowait, + } + } +} + +impl From for VacuumStmt { + fn from(pb: protobuf::VacuumStmt) -> Self { + VacuumStmt { + options: pb.options.into_iter().map(|n| n.into()).collect(), + rels: pb.rels.into_iter().map(|n| n.into()).collect(), + is_vacuumcmd: pb.is_vacuumcmd, + } + } +} + +// Other statement conversions +impl From for DoStmt { + fn from(pb: protobuf::DoStmt) -> Self { + DoStmt { + args: pb.args.into_iter().map(|n| n.into()).collect(), + } + } +} + +impl From for RenameStmt { + fn from(pb: protobuf::RenameStmt) -> Self { + RenameStmt { + rename_type: pb.rename_type.into(), + relation_type: pb.relation_type.into(), + relation: pb.relation.map(|v| v.into()), + object: pb.object.map(|n| n.into()), + subname: pb.subname, + newname: pb.newname, + behavior: pb.behavior.into(), + missing_ok: pb.missing_ok, + } + } +} + +impl From for NotifyStmt { + fn from(pb: protobuf::NotifyStmt) -> Self { + NotifyStmt { + conditionname: pb.conditionname, + payload: pb.payload, + } + } +} + +impl From for ListenStmt { + fn from(pb: protobuf::ListenStmt) -> Self { + ListenStmt { + conditionname: pb.conditionname, + } + } +} + +impl From for UnlistenStmt { + fn from(pb: protobuf::UnlistenStmt) -> Self { + UnlistenStmt { + conditionname: pb.conditionname, + } + } +} + +impl From for DiscardStmt { + fn from(pb: protobuf::DiscardStmt) -> Self { + DiscardStmt { + target: pb.target.into(), + } + } +} + +impl From for PrepareStmt { + fn from(pb: protobuf::PrepareStmt) -> Self { + PrepareStmt { + name: pb.name, + argtypes: pb.argtypes.into_iter().map(|n| n.into()).collect(), + query: pb.query.map(|n| n.into()), + } + } +} + +impl From for ExecuteStmt { + fn from(pb: protobuf::ExecuteStmt) -> Self { + ExecuteStmt { + name: pb.name, + params: pb.params.into_iter().map(|n| n.into()).collect(), + } + } +} + +impl From for DeallocateStmt { + fn from(pb: protobuf::DeallocateStmt) -> Self { + DeallocateStmt { name: pb.name } + } +} + +impl From for ClosePortalStmt { + fn from(pb: protobuf::ClosePortalStmt) -> Self { + ClosePortalStmt { + portalname: pb.portalname, + } + } +} + +impl From for FetchStmt { + fn from(pb: protobuf::FetchStmt) -> Self { + FetchStmt { + direction: pb.direction.into(), + how_many: pb.how_many, + portalname: pb.portalname, + ismove: pb.ismove, + } + } +} + +// ============================================================================ +// Enum conversions +// ============================================================================ + +impl From for SetOperation { + fn from(v: i32) -> Self { + match v { + 1 => SetOperation::None, // SETOP_NONE + 2 => SetOperation::Union, // SETOP_UNION + 3 => SetOperation::Intersect, // SETOP_INTERSECT + 4 => SetOperation::Except, // SETOP_EXCEPT + _ => SetOperation::None, + } + } +} + +impl From for LimitOption { + fn from(v: i32) -> Self { + match v { + 1 => LimitOption::Default, // LIMIT_OPTION_DEFAULT + 2 => LimitOption::Count, // LIMIT_OPTION_COUNT + 3 => LimitOption::WithTies, // LIMIT_OPTION_WITH_TIES + _ => LimitOption::Default, + } + } +} + +impl From for AExprKind { + fn from(v: i32) -> Self { + match v { + 1 => AExprKind::Op, // AEXPR_OP + 2 => AExprKind::OpAny, // AEXPR_OP_ANY + 3 => AExprKind::OpAll, // AEXPR_OP_ALL + 4 => AExprKind::Distinct, // AEXPR_DISTINCT + 5 => AExprKind::NotDistinct, // AEXPR_NOT_DISTINCT + 6 => AExprKind::NullIf, // AEXPR_NULLIF + 7 => AExprKind::In, // AEXPR_IN + 8 => AExprKind::Like, // AEXPR_LIKE + 9 => AExprKind::ILike, // AEXPR_ILIKE + 10 => AExprKind::Similar, // AEXPR_SIMILAR + 11 => AExprKind::Between, // AEXPR_BETWEEN + 12 => AExprKind::NotBetween, // AEXPR_NOT_BETWEEN + 13 => AExprKind::BetweenSym, // AEXPR_BETWEEN_SYM + 14 => AExprKind::NotBetweenSym, // AEXPR_NOT_BETWEEN_SYM + _ => AExprKind::Op, + } + } +} + +impl From for BoolExprType { + fn from(v: i32) -> Self { + match v { + 1 => BoolExprType::And, // AND_EXPR + 2 => BoolExprType::Or, // OR_EXPR + 3 => BoolExprType::Not, // NOT_EXPR + _ => BoolExprType::And, + } + } +} + +impl From for SubLinkType { + fn from(v: i32) -> Self { + match v { + 1 => SubLinkType::Exists, + 2 => SubLinkType::All, + 3 => SubLinkType::Any, + 4 => SubLinkType::RowCompare, + 5 => SubLinkType::Expr, + 6 => SubLinkType::MultiExpr, + 7 => SubLinkType::Array, + 8 => SubLinkType::Cte, + _ => SubLinkType::Exists, + } + } +} + +impl From for NullTestType { + fn from(v: i32) -> Self { + match v { + 1 => NullTestType::IsNull, + 2 => NullTestType::IsNotNull, + _ => NullTestType::IsNull, + } + } +} + +impl From for BoolTestType { + fn from(v: i32) -> Self { + match v { + 1 => BoolTestType::IsTrue, + 2 => BoolTestType::IsNotTrue, + 3 => BoolTestType::IsFalse, + 4 => BoolTestType::IsNotFalse, + 5 => BoolTestType::IsUnknown, + 6 => BoolTestType::IsNotUnknown, + _ => BoolTestType::IsTrue, + } + } +} + +impl From for MinMaxOp { + fn from(v: i32) -> Self { + match v { + 1 => MinMaxOp::Greatest, + 2 => MinMaxOp::Least, + _ => MinMaxOp::Greatest, + } + } +} + +impl From for JoinType { + fn from(v: i32) -> Self { + match v { + 1 => JoinType::Inner, + 2 => JoinType::Left, + 3 => JoinType::Full, + 4 => JoinType::Right, + 5 => JoinType::Semi, + 6 => JoinType::Anti, + 7 => JoinType::RightAnti, + 8 => JoinType::UniqueOuter, + 9 => JoinType::UniqueInner, + _ => JoinType::Inner, + } + } +} + +impl From for SortByDir { + fn from(v: i32) -> Self { + match v { + 1 => SortByDir::Default, + 2 => SortByDir::Asc, + 3 => SortByDir::Desc, + 4 => SortByDir::Using, + _ => SortByDir::Default, + } + } +} + +impl From for SortByNulls { + fn from(v: i32) -> Self { + match v { + 1 => SortByNulls::Default, + 2 => SortByNulls::First, + 3 => SortByNulls::Last, + _ => SortByNulls::Default, + } + } +} + +impl From for CTEMaterialize { + fn from(v: i32) -> Self { + match v { + 1 => CTEMaterialize::Default, + 2 => CTEMaterialize::Always, + 3 => CTEMaterialize::Never, + _ => CTEMaterialize::Default, + } + } +} + +impl From for OnCommitAction { + fn from(v: i32) -> Self { + match v { + 1 => OnCommitAction::Noop, + 2 => OnCommitAction::PreserveRows, + 3 => OnCommitAction::DeleteRows, + 4 => OnCommitAction::Drop, + _ => OnCommitAction::Noop, + } + } +} + +impl From for ObjectType { + fn from(v: i32) -> Self { + // Use direct integer matching + // Values from protobuf ObjectType enum + match v { + 1 => ObjectType::AccessMethod, + 2 => ObjectType::Aggregate, + 11 => ObjectType::Cast, + 12 => ObjectType::Column, + 13 => ObjectType::Collation, + 14 => ObjectType::Conversion, + 15 => ObjectType::Database, + 16 => ObjectType::Default, + 17 => ObjectType::Constraint, + 18 => ObjectType::Domain, + 19 => ObjectType::EventTrigger, + 20 => ObjectType::Extension, + 21 => ObjectType::Fdw, + 22 => ObjectType::ForeignServer, + 23 => ObjectType::ForeignTable, + 24 => ObjectType::Function, + 25 => ObjectType::Index, + 26 => ObjectType::Language, + 27 => ObjectType::LargeObject, + 28 => ObjectType::MatView, + 29 => ObjectType::Operator, + 37 => ObjectType::Policy, + 38 => ObjectType::Procedure, + 39 => ObjectType::Publication, + 44 => ObjectType::Role, + 45 => ObjectType::Routine, + 46 => ObjectType::Rule, + 47 => ObjectType::Schema, + 48 => ObjectType::Sequence, + 49 => ObjectType::Subscription, + 50 => ObjectType::StatisticsObject, + 54 => ObjectType::Table, + 55 => ObjectType::Tablespace, + 57 => ObjectType::Transform, + 58 => ObjectType::Trigger, + 60 => ObjectType::Type, + 62 => ObjectType::View, + _ => ObjectType::Table, + } + } +} + +impl From for DropBehavior { + fn from(v: i32) -> Self { + match v { + 1 => DropBehavior::Restrict, + 2 => DropBehavior::Cascade, + _ => DropBehavior::Restrict, + } + } +} + +impl From for OnConflictAction { + fn from(v: i32) -> Self { + match v { + 1 => OnConflictAction::None, + 2 => OnConflictAction::Nothing, + 3 => OnConflictAction::Update, + _ => OnConflictAction::None, + } + } +} + +impl From for GroupingSetKind { + fn from(v: i32) -> Self { + match v { + 1 => GroupingSetKind::Empty, + 2 => GroupingSetKind::Simple, + 3 => GroupingSetKind::Rollup, + 4 => GroupingSetKind::Cube, + 5 => GroupingSetKind::Sets, + _ => GroupingSetKind::Empty, + } + } +} + +impl From for CmdType { + fn from(v: i32) -> Self { + match v { + 1 => CmdType::Unknown, + 2 => CmdType::Select, + 3 => CmdType::Update, + 4 => CmdType::Insert, + 5 => CmdType::Delete, + 6 => CmdType::Merge, + 7 => CmdType::Utility, + 8 => CmdType::Nothing, + _ => CmdType::Unknown, + } + } +} + +impl From for TransactionStmtKind { + fn from(v: i32) -> Self { + match v { + 1 => TransactionStmtKind::Begin, + 2 => TransactionStmtKind::Start, + 3 => TransactionStmtKind::Commit, + 4 => TransactionStmtKind::Rollback, + 5 => TransactionStmtKind::Savepoint, + 6 => TransactionStmtKind::Release, + 7 => TransactionStmtKind::RollbackTo, + 8 => TransactionStmtKind::Prepare, + 9 => TransactionStmtKind::CommitPrepared, + 10 => TransactionStmtKind::RollbackPrepared, + _ => TransactionStmtKind::Begin, + } + } +} + +impl From for ConstrType { + fn from(v: i32) -> Self { + match v { + 1 => ConstrType::Null, + 2 => ConstrType::NotNull, + 3 => ConstrType::Default, + 4 => ConstrType::Identity, + 5 => ConstrType::Generated, + 6 => ConstrType::Check, + 7 => ConstrType::Primary, + 8 => ConstrType::Unique, + 9 => ConstrType::Exclusion, + 10 => ConstrType::Foreign, + 11 => ConstrType::AttrDeferrable, + 12 => ConstrType::AttrNotDeferrable, + 13 => ConstrType::AttrDeferred, + 14 => ConstrType::AttrImmediate, + _ => ConstrType::Null, + } + } +} + +impl From for DefElemAction { + fn from(v: i32) -> Self { + match v { + 1 => DefElemAction::Unspec, + 2 => DefElemAction::Set, + 3 => DefElemAction::Add, + 4 => DefElemAction::Drop, + _ => DefElemAction::Unspec, + } + } +} + +impl From for RoleSpecType { + fn from(v: i32) -> Self { + match v { + 1 => RoleSpecType::CString, + 2 => RoleSpecType::CurrentRole, + 3 => RoleSpecType::CurrentUser, + 4 => RoleSpecType::SessionUser, + 5 => RoleSpecType::Public, + _ => RoleSpecType::CString, + } + } +} + +impl From for CoercionForm { + fn from(v: i32) -> Self { + match v { + 1 => CoercionForm::ExplicitCall, + 2 => CoercionForm::ExplicitCast, + 3 => CoercionForm::ImplicitCast, + 4 => CoercionForm::SqlSyntax, + _ => CoercionForm::ExplicitCall, + } + } +} + +impl From for VariableSetKind { + fn from(v: i32) -> Self { + match v { + 1 => VariableSetKind::Value, + 2 => VariableSetKind::Default, + 3 => VariableSetKind::Current, + 4 => VariableSetKind::Multi, + 5 => VariableSetKind::Reset, + 6 => VariableSetKind::ResetAll, + _ => VariableSetKind::Value, + } + } +} + +impl From for LockClauseStrength { + fn from(v: i32) -> Self { + match v { + 1 => LockClauseStrength::None, + 2 => LockClauseStrength::ForKeyShare, + 3 => LockClauseStrength::ForShare, + 4 => LockClauseStrength::ForNoKeyUpdate, + 5 => LockClauseStrength::ForUpdate, + _ => LockClauseStrength::None, + } + } +} + +impl From for LockWaitPolicy { + fn from(v: i32) -> Self { + match v { + 1 => LockWaitPolicy::Block, + 2 => LockWaitPolicy::Skip, + 3 => LockWaitPolicy::Error, + _ => LockWaitPolicy::Block, + } + } +} + +impl From for ViewCheckOption { + fn from(v: i32) -> Self { + match v { + 1 => ViewCheckOption::NoCheckOption, + 2 => ViewCheckOption::Local, + 3 => ViewCheckOption::Cascaded, + _ => ViewCheckOption::NoCheckOption, + } + } +} + +impl From for DiscardMode { + fn from(v: i32) -> Self { + match v { + 1 => DiscardMode::All, + 2 => DiscardMode::Plans, + 3 => DiscardMode::Sequences, + 4 => DiscardMode::Temp, + _ => DiscardMode::All, + } + } +} + +impl From for FetchDirection { + fn from(v: i32) -> Self { + match v { + 1 => FetchDirection::Forward, + 2 => FetchDirection::Backward, + 3 => FetchDirection::Absolute, + 4 => FetchDirection::Relative, + _ => FetchDirection::Forward, + } + } +} + +impl From for FunctionParameterMode { + fn from(v: i32) -> Self { + match v { + 105 => FunctionParameterMode::In, // 'i' + 111 => FunctionParameterMode::Out, // 'o' + 98 => FunctionParameterMode::InOut, // 'b' + 118 => FunctionParameterMode::Variadic, // 'v' + 116 => FunctionParameterMode::Table, // 't' + _ => FunctionParameterMode::In, + } + } +} + +impl From for AlterTableType { + fn from(v: i32) -> Self { + // AlterTableType has many variants, use default for simplicity + // The values start at 1 and go up + match v { + 1 => AlterTableType::AddColumn, + 2 => AlterTableType::AddColumnToView, + 3 => AlterTableType::ColumnDefault, + 4 => AlterTableType::CookedColumnDefault, + 5 => AlterTableType::DropNotNull, + 6 => AlterTableType::SetNotNull, + 7 => AlterTableType::DropExpression, + 8 => AlterTableType::CheckNotNull, + 9 => AlterTableType::SetStatistics, + 10 => AlterTableType::SetOptions, + 11 => AlterTableType::ResetOptions, + 12 => AlterTableType::SetStorage, + 13 => AlterTableType::SetCompression, + 14 => AlterTableType::DropColumn, + 15 => AlterTableType::AddIndex, + 16 => AlterTableType::ReAddIndex, + 17 => AlterTableType::AddConstraint, + 18 => AlterTableType::ReAddConstraint, + 19 => AlterTableType::AddIndexConstraint, + 20 => AlterTableType::AlterConstraint, + 21 => AlterTableType::ValidateConstraint, + 22 => AlterTableType::DropConstraint, + 23 => AlterTableType::ClusterOn, + 24 => AlterTableType::DropCluster, + 25 => AlterTableType::SetLogged, + 26 => AlterTableType::SetUnLogged, + 27 => AlterTableType::SetAccessMethod, + 28 => AlterTableType::DropOids, + 29 => AlterTableType::SetTableSpace, + 30 => AlterTableType::SetRelOptions, + 31 => AlterTableType::ResetRelOptions, + 32 => AlterTableType::ReplaceRelOptions, + 33 => AlterTableType::EnableTrig, + 34 => AlterTableType::EnableAlwaysTrig, + 35 => AlterTableType::EnableReplicaTrig, + 36 => AlterTableType::DisableTrig, + 37 => AlterTableType::EnableTrigAll, + 38 => AlterTableType::DisableTrigAll, + 39 => AlterTableType::EnableTrigUser, + 40 => AlterTableType::DisableTrigUser, + 41 => AlterTableType::EnableRule, + 42 => AlterTableType::EnableAlwaysRule, + 43 => AlterTableType::EnableReplicaRule, + 44 => AlterTableType::DisableRule, + 45 => AlterTableType::AddInherit, + 46 => AlterTableType::DropInherit, + 47 => AlterTableType::AddOf, + 48 => AlterTableType::DropOf, + 49 => AlterTableType::ReplicaIdentity, + 50 => AlterTableType::EnableRowSecurity, + 51 => AlterTableType::DisableRowSecurity, + 52 => AlterTableType::ForceRowSecurity, + 53 => AlterTableType::NoForceRowSecurity, + 54 => AlterTableType::GenericOptions, + 55 => AlterTableType::AttachPartition, + 56 => AlterTableType::DetachPartition, + 57 => AlterTableType::DetachPartitionFinalize, + 58 => AlterTableType::AddIdentity, + 59 => AlterTableType::SetIdentity, + 60 => AlterTableType::DropIdentity, + 61 => AlterTableType::ReAddStatistics, + _ => AlterTableType::AddColumn, + } + } +} + +impl From for GrantTargetType { + fn from(v: i32) -> Self { + match v { + 1 => GrantTargetType::Object, + 2 => GrantTargetType::AllInSchema, + 3 => GrantTargetType::Defaults, + _ => GrantTargetType::Object, + } + } +} + +impl From for OverridingKind { + fn from(v: i32) -> Self { + match v { + 1 => OverridingKind::NotSet, + 2 => OverridingKind::UserValue, + 3 => OverridingKind::SystemValue, + _ => OverridingKind::NotSet, + } + } +} + diff --git a/src/ast/mod.rs b/src/ast/mod.rs new file mode 100644 index 0000000..8cc595c --- /dev/null +++ b/src/ast/mod.rs @@ -0,0 +1,31 @@ +//! Native Rust AST types for PostgreSQL parse trees. +//! +//! This module provides ergonomic Rust types that wrap the PostgreSQL parse tree +//! structure. These types make it easier to work with parsed SQL queries without +//! the complexity of deeply nested protobuf Option> wrappers. +//! +//! # Example +//! +//! ```rust +//! use pg_query::ast::Node; +//! +//! let result = pg_query::parse_to_ast("SELECT * FROM users WHERE id = 1").unwrap(); +//! for stmt in &result.stmts { +//! match &stmt.stmt { +//! Node::SelectStmt(select) => { +//! // Access fields more directly +//! for table in &select.from_clause { +//! if let Node::RangeVar(rv) = table { +//! println!("Table: {}", rv.relname); +//! } +//! } +//! } +//! _ => {} +//! } +//! } +//! ``` + +mod nodes; +mod convert; + +pub use nodes::*; diff --git a/src/ast/nodes.rs b/src/ast/nodes.rs new file mode 100644 index 0000000..046d16a --- /dev/null +++ b/src/ast/nodes.rs @@ -0,0 +1,1617 @@ +//! Native Rust AST node types for PostgreSQL parse trees. +//! +//! These types mirror the PostgreSQL parse tree structure but use idiomatic Rust +//! patterns instead of protobuf-style Option> wrappers. + +use crate::protobuf; + +/// Top-level parse result containing all parsed statements. +#[derive(Debug, Clone)] +pub struct ParseResult { + /// PostgreSQL version number (e.g., 160001 for 16.0.1) + pub version: i32, + /// List of parsed statements + pub stmts: Vec, + /// Original protobuf for deparsing (hidden implementation detail) + pub(crate) original_protobuf: protobuf::ParseResult, +} + +// as_protobuf method is defined in convert.rs + +/// A raw statement wrapper with location information. +#[derive(Debug, Clone)] +pub struct RawStmt { + /// The statement node + pub stmt: Node, + /// Character offset in source where statement starts + pub stmt_location: i32, + /// Length of statement in characters (0 means "rest of string") + pub stmt_len: i32, +} + +/// The main AST node enum containing all possible node types. +/// +/// This enum eliminates the need for `Option>` wrappers throughout +/// the AST by using a flat enum with all node types as variants. +#[derive(Debug, Clone)] +pub enum Node { + // Primitive value types + Integer(Integer), + Float(Float), + Boolean(Boolean), + String(StringValue), + BitString(BitString), + Null, + + // List type + List(Vec), + + // Statement types + SelectStmt(Box), + InsertStmt(Box), + UpdateStmt(Box), + DeleteStmt(Box), + MergeStmt(Box), + + // DDL statements + CreateStmt(Box), + AlterTableStmt(Box), + DropStmt(Box), + TruncateStmt(Box), + IndexStmt(Box), + CreateSchemaStmt(Box), + ViewStmt(Box), + CreateFunctionStmt(Box), + AlterFunctionStmt(Box), + CreateSeqStmt(Box), + AlterSeqStmt(Box), + CreateTrigStmt(Box), + RuleStmt(Box), + CreateDomainStmt(Box), + CreateTableAsStmt(Box), + RefreshMatViewStmt(Box), + + // Transaction statement + TransactionStmt(Box), + + // Expression types + AExpr(Box), + ColumnRef(Box), + ParamRef(Box), + AConst(Box), + TypeCast(Box), + CollateClause(Box), + FuncCall(Box), + AStar(AStar), + AIndices(Box), + AIndirection(Box), + AArrayExpr(Box), + SubLink(Box), + BoolExpr(Box), + NullTest(Box), + BooleanTest(Box), + CaseExpr(Box), + CaseWhen(Box), + CoalesceExpr(Box), + MinMaxExpr(Box), + RowExpr(Box), + + // Target/Result types + ResTarget(Box), + + // Table/Range types + RangeVar(Box), + RangeSubselect(Box), + RangeFunction(Box), + JoinExpr(Box), + + // Clause types + SortBy(Box), + WindowDef(Box), + WithClause(Box), + CommonTableExpr(Box), + IntoClause(Box), + OnConflictClause(Box), + LockingClause(Box), + GroupingSet(Box), + MergeWhenClause(Box), + + // Type-related + TypeName(Box), + ColumnDef(Box), + Constraint(Box), + DefElem(Box), + IndexElem(Box), + + // Alias and role types + Alias(Box), + RoleSpec(Box), + + // Other commonly used types + SortGroupClause(Box), + FunctionParameter(Box), + AlterTableCmd(Box), + AccessPriv(Box), + ObjectWithArgs(Box), + + // Administrative statements + VariableSetStmt(Box), + VariableShowStmt(Box), + ExplainStmt(Box), + CopyStmt(Box), + GrantStmt(Box), + GrantRoleStmt(Box), + LockStmt(Box), + VacuumStmt(Box), + + // Other statements + DoStmt(Box), + RenameStmt(Box), + NotifyStmt(Box), + ListenStmt(Box), + UnlistenStmt(Box), + CheckPointStmt(Box), + DiscardStmt(Box), + PrepareStmt(Box), + ExecuteStmt(Box), + DeallocateStmt(Box), + ClosePortalStmt(Box), + FetchStmt(Box), + + // Fallback for unhandled node types - stores the original protobuf + Other(protobuf::Node), +} + +// ============================================================================ +// Primitive value types +// ============================================================================ + +/// Integer value +#[derive(Debug, Clone, Default)] +pub struct Integer { + pub ival: i32, +} + +/// Float value (stored as string) +#[derive(Debug, Clone, Default)] +pub struct Float { + pub fval: String, +} + +/// Boolean value +#[derive(Debug, Clone, Default)] +pub struct Boolean { + pub boolval: bool, +} + +/// String value +#[derive(Debug, Clone, Default)] +pub struct StringValue { + pub sval: String, +} + +/// Bit string value +#[derive(Debug, Clone, Default)] +pub struct BitString { + pub bsval: String, +} + +/// A star (*) in column reference +#[derive(Debug, Clone, Default)] +pub struct AStar; + +// ============================================================================ +// Core statement types +// ============================================================================ + +/// SELECT statement +#[derive(Debug, Clone, Default)] +pub struct SelectStmt { + pub distinct_clause: Vec, + pub into_clause: Option, + pub target_list: Vec, + pub from_clause: Vec, + pub where_clause: Option, + pub group_clause: Vec, + pub group_distinct: bool, + pub having_clause: Option, + pub window_clause: Vec, + pub values_lists: Vec, + pub sort_clause: Vec, + pub limit_offset: Option, + pub limit_count: Option, + pub limit_option: LimitOption, + pub locking_clause: Vec, + pub with_clause: Option, + pub op: SetOperation, + pub all: bool, + pub larg: Option>, + pub rarg: Option>, +} + +/// INSERT statement +#[derive(Debug, Clone, Default)] +pub struct InsertStmt { + pub relation: Option, + pub cols: Vec, + pub select_stmt: Option, + pub on_conflict_clause: Option, + pub returning_list: Vec, + pub with_clause: Option, + pub override_: OverridingKind, +} + +/// UPDATE statement +#[derive(Debug, Clone, Default)] +pub struct UpdateStmt { + pub relation: Option, + pub target_list: Vec, + pub where_clause: Option, + pub from_clause: Vec, + pub returning_list: Vec, + pub with_clause: Option, +} + +/// DELETE statement +#[derive(Debug, Clone, Default)] +pub struct DeleteStmt { + pub relation: Option, + pub using_clause: Vec, + pub where_clause: Option, + pub returning_list: Vec, + pub with_clause: Option, +} + +/// MERGE statement +#[derive(Debug, Clone, Default)] +pub struct MergeStmt { + pub relation: Option, + pub source_relation: Option, + pub join_condition: Option, + pub merge_when_clauses: Vec, + pub with_clause: Option, +} + +// ============================================================================ +// DDL statement types +// ============================================================================ + +/// CREATE TABLE statement +#[derive(Debug, Clone, Default)] +pub struct CreateStmt { + pub relation: Option, + pub table_elts: Vec, + pub inh_relations: Vec, + pub partbound: Option, + pub partspec: Option, + pub of_typename: Option, + pub constraints: Vec, + pub options: Vec, + pub oncommit: OnCommitAction, + pub tablespacename: String, + pub access_method: String, + pub if_not_exists: bool, +} + +/// ALTER TABLE statement +#[derive(Debug, Clone, Default)] +pub struct AlterTableStmt { + pub relation: Option, + pub cmds: Vec, + pub objtype: ObjectType, + pub missing_ok: bool, +} + +/// DROP statement +#[derive(Debug, Clone, Default)] +pub struct DropStmt { + pub objects: Vec, + pub remove_type: ObjectType, + pub behavior: DropBehavior, + pub missing_ok: bool, + pub concurrent: bool, +} + +/// TRUNCATE statement +#[derive(Debug, Clone, Default)] +pub struct TruncateStmt { + pub relations: Vec, + pub restart_seqs: bool, + pub behavior: DropBehavior, +} + +/// CREATE INDEX statement +#[derive(Debug, Clone, Default)] +pub struct IndexStmt { + pub idxname: String, + pub relation: Option, + pub access_method: String, + pub table_space: String, + pub index_params: Vec, + pub index_including_params: Vec, + pub options: Vec, + pub where_clause: Option, + pub exclude_op_names: Vec, + pub idxcomment: String, + pub index_oid: u32, + pub old_number: u32, + pub old_first_relfilelocator: u32, + pub unique: bool, + pub nulls_not_distinct: bool, + pub primary: bool, + pub is_constraint: bool, + pub deferrable: bool, + pub initdeferred: bool, + pub transformed: bool, + pub concurrent: bool, + pub if_not_exists: bool, + pub reset_default_tblspc: bool, +} + +/// CREATE SCHEMA statement +#[derive(Debug, Clone, Default)] +pub struct CreateSchemaStmt { + pub schemaname: String, + pub authrole: Option, + pub schema_elts: Vec, + pub if_not_exists: bool, +} + +/// CREATE VIEW statement +#[derive(Debug, Clone, Default)] +pub struct ViewStmt { + pub view: Option, + pub aliases: Vec, + pub query: Option, + pub replace: bool, + pub options: Vec, + pub with_check_option: ViewCheckOption, +} + +/// CREATE FUNCTION statement +#[derive(Debug, Clone, Default)] +pub struct CreateFunctionStmt { + pub is_procedure: bool, + pub replace: bool, + pub funcname: Vec, + pub parameters: Vec, + pub return_type: Option, + pub options: Vec, + pub sql_body: Option, +} + +/// ALTER FUNCTION statement +#[derive(Debug, Clone, Default)] +pub struct AlterFunctionStmt { + pub objtype: ObjectType, + pub func: Option, + pub actions: Vec, +} + +/// CREATE SEQUENCE statement +#[derive(Debug, Clone, Default)] +pub struct CreateSeqStmt { + pub sequence: Option, + pub options: Vec, + pub owner_id: u32, + pub for_identity: bool, + pub if_not_exists: bool, +} + +/// ALTER SEQUENCE statement +#[derive(Debug, Clone, Default)] +pub struct AlterSeqStmt { + pub sequence: Option, + pub options: Vec, + pub for_identity: bool, + pub missing_ok: bool, +} + +/// CREATE TRIGGER statement +#[derive(Debug, Clone, Default)] +pub struct CreateTrigStmt { + pub replace: bool, + pub isconstraint: bool, + pub trigname: String, + pub relation: Option, + pub funcname: Vec, + pub args: Vec, + pub row: bool, + pub timing: i32, + pub events: i32, + pub columns: Vec, + pub when_clause: Option, + pub transition_rels: Vec, + pub deferrable: bool, + pub initdeferred: bool, + pub constrrel: Option, +} + +/// CREATE RULE statement +#[derive(Debug, Clone, Default)] +pub struct RuleStmt { + pub relation: Option, + pub rulename: String, + pub where_clause: Option, + pub event: CmdType, + pub instead: bool, + pub actions: Vec, + pub replace: bool, +} + +/// CREATE DOMAIN statement +#[derive(Debug, Clone, Default)] +pub struct CreateDomainStmt { + pub domainname: Vec, + pub type_name: Option, + pub coll_clause: Option, + pub constraints: Vec, +} + +/// CREATE TABLE AS statement +#[derive(Debug, Clone, Default)] +pub struct CreateTableAsStmt { + pub query: Option, + pub into: Option, + pub objtype: ObjectType, + pub is_select_into: bool, + pub if_not_exists: bool, +} + +/// REFRESH MATERIALIZED VIEW statement +#[derive(Debug, Clone, Default)] +pub struct RefreshMatViewStmt { + pub concurrent: bool, + pub skip_data: bool, + pub relation: Option, +} + +// ============================================================================ +// Transaction statement +// ============================================================================ + +/// Transaction statement (BEGIN, COMMIT, ROLLBACK, etc.) +#[derive(Debug, Clone, Default)] +pub struct TransactionStmt { + pub kind: TransactionStmtKind, + pub options: Vec, + pub savepoint_name: String, + pub gid: String, + pub chain: bool, +} + +// ============================================================================ +// Expression types +// ============================================================================ + +/// An expression with an operator (e.g., "a + b", "x = 1") +#[derive(Debug, Clone, Default)] +pub struct AExpr { + pub kind: AExprKind, + pub name: Vec, + pub lexpr: Option, + pub rexpr: Option, + pub location: i32, +} + +/// Column reference (e.g., "table.column") +#[derive(Debug, Clone, Default)] +pub struct ColumnRef { + pub fields: Vec, + pub location: i32, +} + +/// Parameter reference ($1, $2, etc.) +#[derive(Debug, Clone, Default)] +pub struct ParamRef { + pub number: i32, + pub location: i32, +} + +/// A constant value +#[derive(Debug, Clone, Default)] +pub struct AConst { + pub val: Option, + pub isnull: bool, + pub location: i32, +} + +/// Value types for AConst +#[derive(Debug, Clone)] +pub enum AConstValue { + Integer(Integer), + Float(Float), + Boolean(Boolean), + String(StringValue), + BitString(BitString), +} + +/// Type cast expression +#[derive(Debug, Clone, Default)] +pub struct TypeCast { + pub arg: Option, + pub type_name: Option, + pub location: i32, +} + +/// COLLATE clause +#[derive(Debug, Clone, Default)] +pub struct CollateClause { + pub arg: Option, + pub collname: Vec, + pub location: i32, +} + +/// Function call +#[derive(Debug, Clone, Default)] +pub struct FuncCall { + pub funcname: Vec, + pub args: Vec, + pub agg_order: Vec, + pub agg_filter: Option, + pub over: Option, + pub agg_within_group: bool, + pub agg_star: bool, + pub agg_distinct: bool, + pub func_variadic: bool, + pub funcformat: CoercionForm, + pub location: i32, +} + +/// Array subscript indices +#[derive(Debug, Clone, Default)] +pub struct AIndices { + pub is_slice: bool, + pub lidx: Option, + pub uidx: Option, +} + +/// Array subscript or field selection +#[derive(Debug, Clone, Default)] +pub struct AIndirection { + pub arg: Option, + pub indirection: Vec, +} + +/// ARRAY[] constructor +#[derive(Debug, Clone, Default)] +pub struct AArrayExpr { + pub elements: Vec, + pub location: i32, +} + +/// Subquery link (subquery in expression context) +#[derive(Debug, Clone, Default)] +pub struct SubLink { + pub sub_link_type: SubLinkType, + pub sub_link_id: i32, + pub testexpr: Option, + pub oper_name: Vec, + pub subselect: Option, + pub location: i32, +} + +/// Boolean expression (AND, OR, NOT) +#[derive(Debug, Clone, Default)] +pub struct BoolExpr { + pub boolop: BoolExprType, + pub args: Vec, + pub location: i32, +} + +/// NULL test expression +#[derive(Debug, Clone, Default)] +pub struct NullTest { + pub arg: Option, + pub nulltesttype: NullTestType, + pub argisrow: bool, + pub location: i32, +} + +/// Boolean test (IS TRUE, IS FALSE, etc.) +#[derive(Debug, Clone, Default)] +pub struct BooleanTest { + pub arg: Option, + pub booltesttype: BoolTestType, + pub location: i32, +} + +/// CASE expression +#[derive(Debug, Clone, Default)] +pub struct CaseExpr { + pub arg: Option, + pub args: Vec, + pub defresult: Option, + pub location: i32, +} + +/// WHEN clause of CASE +#[derive(Debug, Clone, Default)] +pub struct CaseWhen { + pub expr: Option, + pub result: Option, + pub location: i32, +} + +/// COALESCE expression +#[derive(Debug, Clone, Default)] +pub struct CoalesceExpr { + pub args: Vec, + pub location: i32, +} + +/// GREATEST or LEAST expression +#[derive(Debug, Clone, Default)] +pub struct MinMaxExpr { + pub op: MinMaxOp, + pub args: Vec, + pub location: i32, +} + +/// ROW() expression +#[derive(Debug, Clone, Default)] +pub struct RowExpr { + pub args: Vec, + pub row_format: CoercionForm, + pub colnames: Vec, + pub location: i32, +} + +// ============================================================================ +// Target/Result types +// ============================================================================ + +/// Result target (column in SELECT list or assignment target) +#[derive(Debug, Clone, Default)] +pub struct ResTarget { + pub name: String, + pub indirection: Vec, + pub val: Option, + pub location: i32, +} + +// ============================================================================ +// Table/Range types +// ============================================================================ + +/// Table/relation reference +#[derive(Debug, Clone, Default)] +pub struct RangeVar { + pub catalogname: String, + pub schemaname: String, + pub relname: String, + pub inh: bool, + pub relpersistence: String, + pub alias: Option, + pub location: i32, +} + +/// Subquery in FROM clause +#[derive(Debug, Clone, Default)] +pub struct RangeSubselect { + pub lateral: bool, + pub subquery: Option, + pub alias: Option, +} + +/// Function call in FROM clause +#[derive(Debug, Clone, Default)] +pub struct RangeFunction { + pub lateral: bool, + pub ordinality: bool, + pub is_rowsfrom: bool, + pub functions: Vec, + pub alias: Option, + pub coldeflist: Vec, +} + +/// JOIN expression +#[derive(Debug, Clone, Default)] +pub struct JoinExpr { + pub jointype: JoinType, + pub is_natural: bool, + pub larg: Option, + pub rarg: Option, + pub using_clause: Vec, + pub join_using_alias: Option, + pub quals: Option, + pub alias: Option, + pub rtindex: i32, +} + +// ============================================================================ +// Clause types +// ============================================================================ + +/// ORDER BY clause element +#[derive(Debug, Clone, Default)] +pub struct SortBy { + pub node: Option, + pub sortby_dir: SortByDir, + pub sortby_nulls: SortByNulls, + pub use_op: Vec, + pub location: i32, +} + +/// WINDOW definition +#[derive(Debug, Clone, Default)] +pub struct WindowDef { + pub name: String, + pub refname: String, + pub partition_clause: Vec, + pub order_clause: Vec, + pub frame_options: i32, + pub start_offset: Option, + pub end_offset: Option, + pub location: i32, +} + +/// WITH clause +#[derive(Debug, Clone, Default)] +pub struct WithClause { + pub ctes: Vec, + pub recursive: bool, + pub location: i32, +} + +/// Common Table Expression (CTE) +#[derive(Debug, Clone, Default)] +pub struct CommonTableExpr { + pub ctename: String, + pub aliascolnames: Vec, + pub ctematerialized: CTEMaterialize, + pub ctequery: Option, + pub search_clause: Option, + pub cycle_clause: Option, + pub location: i32, + pub cterecursive: bool, + pub cterefcount: i32, + pub ctecolnames: Vec, + pub ctecoltypes: Vec, + pub ctecoltypmods: Vec, + pub ctecolcollations: Vec, +} + +/// INTO clause for SELECT INTO +#[derive(Debug, Clone, Default)] +pub struct IntoClause { + pub rel: Option, + pub col_names: Vec, + pub access_method: String, + pub options: Vec, + pub on_commit: OnCommitAction, + pub table_space_name: String, + pub view_query: Option, + pub skip_data: bool, +} + +/// ON CONFLICT clause for INSERT +#[derive(Debug, Clone, Default)] +pub struct OnConflictClause { + pub action: OnConflictAction, + pub infer: Option, + pub target_list: Vec, + pub where_clause: Option, + pub location: i32, +} + +/// FOR UPDATE/SHARE clause +#[derive(Debug, Clone, Default)] +pub struct LockingClause { + pub locked_rels: Vec, + pub strength: LockClauseStrength, + pub wait_policy: LockWaitPolicy, +} + +/// GROUPING SETS clause element +#[derive(Debug, Clone, Default)] +pub struct GroupingSet { + pub kind: GroupingSetKind, + pub content: Vec, + pub location: i32, +} + +/// MERGE WHEN clause +#[derive(Debug, Clone, Default)] +pub struct MergeWhenClause { + pub matched: bool, + pub command_type: CmdType, + pub override_: OverridingKind, + pub condition: Option, + pub target_list: Vec, + pub values: Vec, +} + +// ============================================================================ +// Type-related +// ============================================================================ + +/// Type name +#[derive(Debug, Clone, Default)] +pub struct TypeName { + pub names: Vec, + pub type_oid: u32, + pub setof: bool, + pub pct_type: bool, + pub typmods: Vec, + pub typemod: i32, + pub array_bounds: Vec, + pub location: i32, +} + +/// Column definition +#[derive(Debug, Clone, Default)] +pub struct ColumnDef { + pub colname: String, + pub type_name: Option, + pub compression: String, + pub inhcount: i32, + pub is_local: bool, + pub is_not_null: bool, + pub is_from_type: bool, + pub storage: String, + pub storage_name: String, + pub raw_default: Option, + pub cooked_default: Option, + pub identity: String, + pub identity_sequence: Option, + pub generated: String, + pub coll_clause: Option, + pub coll_oid: u32, + pub constraints: Vec, + pub fdwoptions: Vec, + pub location: i32, +} + +/// Constraint definition +#[derive(Debug, Clone, Default)] +pub struct Constraint { + pub contype: ConstrType, + pub conname: String, + pub deferrable: bool, + pub initdeferred: bool, + pub location: i32, + pub is_no_inherit: bool, + pub raw_expr: Option, + pub cooked_expr: String, + pub generated_when: String, + pub nulls_not_distinct: bool, + pub keys: Vec, + pub including: Vec, + pub exclusions: Vec, + pub options: Vec, + pub indexname: String, + pub indexspace: String, + pub reset_default_tblspc: bool, + pub access_method: String, + pub where_clause: Option, + pub pktable: Option, + pub fk_attrs: Vec, + pub pk_attrs: Vec, + pub fk_matchtype: String, + pub fk_upd_action: String, + pub fk_del_action: String, + pub fk_del_set_cols: Vec, + pub old_conpfeqop: Vec, + pub old_pktable_oid: u32, + pub skip_validation: bool, + pub initially_valid: bool, +} + +/// Definition element (generic) +#[derive(Debug, Clone, Default)] +pub struct DefElem { + pub defnamespace: String, + pub defname: String, + pub arg: Option, + pub defaction: DefElemAction, + pub location: i32, +} + +/// Index element +#[derive(Debug, Clone, Default)] +pub struct IndexElem { + pub name: String, + pub expr: Option, + pub indexcolname: String, + pub collation: Vec, + pub opclass: Vec, + pub opclassopts: Vec, + pub ordering: SortByDir, + pub nulls_ordering: SortByNulls, +} + +// ============================================================================ +// Alias and role types +// ============================================================================ + +/// Alias +#[derive(Debug, Clone, Default)] +pub struct Alias { + pub aliasname: String, + pub colnames: Vec, +} + +/// Role specification +#[derive(Debug, Clone, Default)] +pub struct RoleSpec { + pub roletype: RoleSpecType, + pub rolename: String, + pub location: i32, +} + +// ============================================================================ +// Other commonly used types +// ============================================================================ + +/// Sort/Group clause +#[derive(Debug, Clone, Default)] +pub struct SortGroupClause { + pub tle_sort_group_ref: u32, + pub eqop: u32, + pub sortop: u32, + pub nulls_first: bool, + pub hashable: bool, +} + +/// Function parameter +#[derive(Debug, Clone, Default)] +pub struct FunctionParameter { + pub name: String, + pub arg_type: Option, + pub mode: FunctionParameterMode, + pub defexpr: Option, +} + +/// ALTER TABLE command +#[derive(Debug, Clone, Default)] +pub struct AlterTableCmd { + pub subtype: AlterTableType, + pub name: String, + pub num: i16, + pub newowner: Option, + pub def: Option, + pub behavior: DropBehavior, + pub missing_ok: bool, + pub recurse: bool, +} + +/// Access privilege +#[derive(Debug, Clone, Default)] +pub struct AccessPriv { + pub priv_name: String, + pub cols: Vec, +} + +/// Object with arguments +#[derive(Debug, Clone, Default)] +pub struct ObjectWithArgs { + pub objname: Vec, + pub objargs: Vec, + pub objfuncargs: Vec, + pub args_unspecified: bool, +} + +// ============================================================================ +// Administrative statements +// ============================================================================ + +/// SET variable statement +#[derive(Debug, Clone, Default)] +pub struct VariableSetStmt { + pub kind: VariableSetKind, + pub name: String, + pub args: Vec, + pub is_local: bool, +} + +/// SHOW variable statement +#[derive(Debug, Clone, Default)] +pub struct VariableShowStmt { + pub name: String, +} + +/// EXPLAIN statement +#[derive(Debug, Clone, Default)] +pub struct ExplainStmt { + pub query: Option, + pub options: Vec, +} + +/// COPY statement +#[derive(Debug, Clone, Default)] +pub struct CopyStmt { + pub relation: Option, + pub query: Option, + pub attlist: Vec, + pub is_from: bool, + pub is_program: bool, + pub filename: String, + pub options: Vec, + pub where_clause: Option, +} + +/// GRANT/REVOKE statement +#[derive(Debug, Clone, Default)] +pub struct GrantStmt { + pub is_grant: bool, + pub targtype: GrantTargetType, + pub objtype: ObjectType, + pub objects: Vec, + pub privileges: Vec, + pub grantees: Vec, + pub grant_option: bool, + pub grantor: Option, + pub behavior: DropBehavior, +} + +/// GRANT/REVOKE role statement +#[derive(Debug, Clone, Default)] +pub struct GrantRoleStmt { + pub granted_roles: Vec, + pub grantee_roles: Vec, + pub is_grant: bool, + pub opt: Vec, + pub grantor: Option, + pub behavior: DropBehavior, +} + +/// LOCK statement +#[derive(Debug, Clone, Default)] +pub struct LockStmt { + pub relations: Vec, + pub mode: i32, + pub nowait: bool, +} + +/// VACUUM/ANALYZE statement +#[derive(Debug, Clone, Default)] +pub struct VacuumStmt { + pub options: Vec, + pub rels: Vec, + pub is_vacuumcmd: bool, +} + +// ============================================================================ +// Other statements +// ============================================================================ + +/// DO statement +#[derive(Debug, Clone, Default)] +pub struct DoStmt { + pub args: Vec, +} + +/// RENAME statement +#[derive(Debug, Clone, Default)] +pub struct RenameStmt { + pub rename_type: ObjectType, + pub relation_type: ObjectType, + pub relation: Option, + pub object: Option, + pub subname: String, + pub newname: String, + pub behavior: DropBehavior, + pub missing_ok: bool, +} + +/// NOTIFY statement +#[derive(Debug, Clone, Default)] +pub struct NotifyStmt { + pub conditionname: String, + pub payload: String, +} + +/// LISTEN statement +#[derive(Debug, Clone, Default)] +pub struct ListenStmt { + pub conditionname: String, +} + +/// UNLISTEN statement +#[derive(Debug, Clone, Default)] +pub struct UnlistenStmt { + pub conditionname: String, +} + +/// CHECKPOINT statement +#[derive(Debug, Clone, Default)] +pub struct CheckPointStmt; + +/// DISCARD statement +#[derive(Debug, Clone, Default)] +pub struct DiscardStmt { + pub target: DiscardMode, +} + +/// PREPARE statement +#[derive(Debug, Clone, Default)] +pub struct PrepareStmt { + pub name: String, + pub argtypes: Vec, + pub query: Option, +} + +/// EXECUTE statement +#[derive(Debug, Clone, Default)] +pub struct ExecuteStmt { + pub name: String, + pub params: Vec, +} + +/// DEALLOCATE statement +#[derive(Debug, Clone, Default)] +pub struct DeallocateStmt { + pub name: String, +} + +/// CLOSE cursor statement +#[derive(Debug, Clone, Default)] +pub struct ClosePortalStmt { + pub portalname: String, +} + +/// FETCH/MOVE statement +#[derive(Debug, Clone, Default)] +pub struct FetchStmt { + pub direction: FetchDirection, + pub how_many: i64, + pub portalname: String, + pub ismove: bool, +} + +// ============================================================================ +// Enums +// ============================================================================ + +/// SET operation type +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum SetOperation { + #[default] + None, + Union, + Intersect, + Except, +} + +/// LIMIT option +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum LimitOption { + #[default] + Default, + Count, + WithTies, +} + +/// A_Expr kind +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum AExprKind { + #[default] + Op, + OpAny, + OpAll, + Distinct, + NotDistinct, + NullIf, + In, + Like, + ILike, + Similar, + Between, + NotBetween, + BetweenSym, + NotBetweenSym, +} + +/// Boolean expression type +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum BoolExprType { + #[default] + And, + Or, + Not, +} + +/// Sublink type +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum SubLinkType { + #[default] + Exists, + All, + Any, + RowCompare, + Expr, + MultiExpr, + Array, + Cte, +} + +/// NULL test type +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum NullTestType { + #[default] + IsNull, + IsNotNull, +} + +/// Boolean test type +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum BoolTestType { + #[default] + IsTrue, + IsNotTrue, + IsFalse, + IsNotFalse, + IsUnknown, + IsNotUnknown, +} + +/// Min/Max operation +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum MinMaxOp { + #[default] + Greatest, + Least, +} + +/// JOIN type +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum JoinType { + #[default] + Inner, + Left, + Full, + Right, + Semi, + Anti, + RightAnti, + UniqueOuter, + UniqueInner, +} + +/// SORT BY direction +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum SortByDir { + #[default] + Default, + Asc, + Desc, + Using, +} + +/// SORT BY nulls +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum SortByNulls { + #[default] + Default, + First, + Last, +} + +/// CTE materialization +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum CTEMaterialize { + #[default] + Default, + Always, + Never, +} + +/// ON COMMIT action +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum OnCommitAction { + #[default] + Noop, + PreserveRows, + DeleteRows, + Drop, +} + +/// Object type for DDL +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum ObjectType { + #[default] + Table, + Index, + Sequence, + View, + MatView, + Type, + Schema, + Function, + Procedure, + Routine, + Aggregate, + Operator, + Language, + Cast, + Trigger, + EventTrigger, + Rule, + Database, + Tablespace, + Role, + Extension, + Fdw, + ForeignServer, + ForeignTable, + Policy, + Publication, + Subscription, + Collation, + Conversion, + Default, + Domain, + Constraint, + Column, + AccessMethod, + LargeObject, + Transform, + StatisticsObject, +} + +/// DROP behavior +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum DropBehavior { + #[default] + Restrict, + Cascade, +} + +/// ON CONFLICT action +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum OnConflictAction { + #[default] + None, + Nothing, + Update, +} + +/// GROUPING SET kind +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum GroupingSetKind { + #[default] + Empty, + Simple, + Rollup, + Cube, + Sets, +} + +/// Command type +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum CmdType { + #[default] + Unknown, + Select, + Update, + Insert, + Delete, + Merge, + Utility, + Nothing, +} + +/// Transaction statement kind +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum TransactionStmtKind { + #[default] + Begin, + Start, + Commit, + Rollback, + Savepoint, + Release, + RollbackTo, + Prepare, + CommitPrepared, + RollbackPrepared, +} + +/// Constraint type +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum ConstrType { + #[default] + Null, + NotNull, + Default, + Identity, + Generated, + Check, + Primary, + Unique, + Exclusion, + Foreign, + AttrDeferrable, + AttrNotDeferrable, + AttrDeferred, + AttrImmediate, +} + +/// DefElem action +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum DefElemAction { + #[default] + Unspec, + Set, + Add, + Drop, +} + +/// Role spec type +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum RoleSpecType { + #[default] + CString, + CurrentRole, + CurrentUser, + SessionUser, + Public, +} + +/// Coercion form +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum CoercionForm { + #[default] + ExplicitCall, + ExplicitCast, + ImplicitCast, + SqlSyntax, +} + +/// Variable SET kind +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum VariableSetKind { + #[default] + Value, + Default, + Current, + Multi, + Reset, + ResetAll, +} + +/// Lock clause strength +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum LockClauseStrength { + #[default] + None, + ForKeyShare, + ForShare, + ForNoKeyUpdate, + ForUpdate, +} + +/// Lock wait policy +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum LockWaitPolicy { + #[default] + Block, + Skip, + Error, +} + +/// View check option +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum ViewCheckOption { + #[default] + NoCheckOption, + Local, + Cascaded, +} + +/// Discard mode +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum DiscardMode { + #[default] + All, + Plans, + Sequences, + Temp, +} + +/// Fetch direction +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum FetchDirection { + #[default] + Forward, + Backward, + Absolute, + Relative, +} + +/// Function parameter mode +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum FunctionParameterMode { + #[default] + In, + Out, + InOut, + Variadic, + Table, +} + +/// ALTER TABLE command type +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum AlterTableType { + #[default] + AddColumn, + AddColumnToView, + ColumnDefault, + CookedColumnDefault, + DropNotNull, + SetNotNull, + DropExpression, + CheckNotNull, + SetStatistics, + SetOptions, + ResetOptions, + SetStorage, + SetCompression, + DropColumn, + AddIndex, + ReAddIndex, + AddConstraint, + ReAddConstraint, + AddIndexConstraint, + AlterConstraint, + ValidateConstraint, + DropConstraint, + ClusterOn, + DropCluster, + SetLogged, + SetUnLogged, + SetAccessMethod, + DropOids, + SetTableSpace, + SetRelOptions, + ResetRelOptions, + ReplaceRelOptions, + EnableTrig, + EnableAlwaysTrig, + EnableReplicaTrig, + DisableTrig, + EnableTrigAll, + DisableTrigAll, + EnableTrigUser, + DisableTrigUser, + EnableRule, + EnableAlwaysRule, + EnableReplicaRule, + DisableRule, + AddInherit, + DropInherit, + AddOf, + DropOf, + ReplicaIdentity, + EnableRowSecurity, + DisableRowSecurity, + ForceRowSecurity, + NoForceRowSecurity, + GenericOptions, + AttachPartition, + DetachPartition, + DetachPartitionFinalize, + AddIdentity, + SetIdentity, + DropIdentity, + ReAddStatistics, +} + +/// GRANT target type +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum GrantTargetType { + #[default] + Object, + AllInSchema, + Defaults, +} + +/// Overriding kind +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum OverridingKind { + #[default] + NotSet, + UserValue, + SystemValue, +} diff --git a/src/bindings_raw.rs b/src/bindings_raw.rs new file mode 100644 index 0000000..0be69dd --- /dev/null +++ b/src/bindings_raw.rs @@ -0,0 +1,12 @@ +//! Raw FFI bindings for PostgreSQL parse tree types. +//! +//! These bindings provide direct access to PostgreSQL's internal node structures, +//! allowing us to convert them to Rust types without going through protobuf serialization. + +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] +#![allow(unused)] +#![allow(clippy::all)] +#![allow(dead_code)] +include!(concat!(env!("OUT_DIR"), "/bindings_raw.rs")); diff --git a/src/lib.rs b/src/lib.rs index 9efe44d..a7c6405 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,7 +38,9 @@ //! ``` //! +pub mod ast; mod bindings; +mod bindings_raw; mod error; mod node_enum; mod node_mut; @@ -48,6 +50,7 @@ mod parse_result; #[rustfmt::skip] pub mod protobuf; mod query; +mod raw_parse; mod summary; mod summary_result; mod truncate; @@ -58,6 +61,7 @@ pub use node_mut::*; pub use node_ref::*; pub use parse_result::*; pub use query::*; +pub use raw_parse::parse_raw; pub use summary::*; pub use summary_result::*; pub use truncate::*; diff --git a/src/query.rs b/src/query.rs index 1f8cd20..eeb665f 100644 --- a/src/query.rs +++ b/src/query.rs @@ -3,6 +3,7 @@ use std::os::raw::c_char; use prost::Message; +use crate::ast; use crate::bindings::*; use crate::error::*; use crate::parse_result::ParseResult; @@ -279,3 +280,65 @@ pub fn split_with_scanner(query: &str) -> Result> { unsafe { pg_query_free_split_result(result) }; split_result } + +/// Parses the given SQL statement into native Rust AST types. +/// +/// This function provides an ergonomic alternative to [`parse`] that returns +/// native Rust types instead of protobuf-generated types. The native types +/// are easier to work with as they don't require unwrapping `Option>` +/// at every level. +/// +/// # Example +/// +/// ```rust +/// use pg_query::ast::{Node, SelectStmt}; +/// +/// let result = pg_query::parse_to_ast("SELECT * FROM users WHERE id = 1").unwrap(); +/// +/// // Direct access to statements without unwrapping +/// for stmt in &result.stmts { +/// if let Node::SelectStmt(select) = &stmt.stmt { +/// // Access fields directly +/// for node in &select.from_clause { +/// if let Node::RangeVar(range_var) = node { +/// println!("Table: {}", range_var.relname); +/// } +/// } +/// } +/// } +/// ``` +pub fn parse_to_ast(statement: &str) -> Result { + let input = CString::new(statement)?; + let result = unsafe { pg_query_parse_protobuf(input.as_ptr()) }; + let parse_result = if !result.error.is_null() { + let message = unsafe { CStr::from_ptr((*result.error).message) }.to_string_lossy().to_string(); + Err(Error::Parse(message)) + } else { + let data = unsafe { std::slice::from_raw_parts(result.parse_tree.data as *const u8, result.parse_tree.len as usize) }; + protobuf::ParseResult::decode(data) + .map_err(Error::Decode) + .map(|pb| ast::ParseResult::from(pb)) + }; + unsafe { pg_query_free_protobuf_parse_result(result) }; + parse_result +} + +/// Converts a native AST parse result back into a SQL string. +/// +/// This function uses the original protobuf stored in the AST to deparse. +/// Note: Any modifications made to the AST fields will NOT be reflected +/// in the deparsed output. This function is primarily useful for round-trip +/// testing and verification. +/// +/// # Example +/// +/// ```rust +/// use pg_query::ast::Node; +/// +/// let result = pg_query::parse_to_ast("SELECT * FROM users").unwrap(); +/// let sql = pg_query::deparse_ast(&result).unwrap(); +/// assert_eq!(sql, "SELECT * FROM users"); +/// ``` +pub fn deparse_ast(parse_result: &ast::ParseResult) -> Result { + deparse(parse_result.as_protobuf()) +} diff --git a/src/raw_parse.rs b/src/raw_parse.rs new file mode 100644 index 0000000..8455db3 --- /dev/null +++ b/src/raw_parse.rs @@ -0,0 +1,1188 @@ +//! Direct parsing that bypasses protobuf serialization/deserialization. +//! +//! This module provides a faster alternative to the protobuf-based parsing by +//! directly reading PostgreSQL's internal parse tree structures and converting +//! them to Rust protobuf types. + +use crate::bindings; +use crate::bindings_raw; +use crate::parse_result::ParseResult; +use crate::protobuf; +use crate::{Error, Result}; +use std::ffi::{CStr, CString}; + +/// Parses a SQL statement directly into protobuf types without going through protobuf serialization. +/// +/// This function is faster than `parse` because it skips the protobuf encode/decode step. +/// The parse tree is read directly from PostgreSQL's internal C structures. +/// +/// # Example +/// +/// ```rust +/// let result = pg_query::parse_raw("SELECT * FROM users").unwrap(); +/// assert_eq!(result.tables(), vec!["users"]); +/// ``` +pub fn parse_raw(statement: &str) -> Result { + let input = CString::new(statement)?; + let result = unsafe { bindings_raw::pg_query_parse_raw(input.as_ptr()) }; + + let parse_result = if !result.error.is_null() { + let message = unsafe { CStr::from_ptr((*result.error).message) } + .to_string_lossy() + .to_string(); + Err(Error::Parse(message)) + } else { + // Convert the C parse tree to protobuf types + let tree = result.tree; + let stmts = unsafe { convert_list_to_raw_stmts(tree) }; + let protobuf = protobuf::ParseResult { + version: bindings::PG_VERSION_NUM as i32, + stmts, + }; + Ok(ParseResult::new(protobuf, String::new())) + }; + + unsafe { bindings_raw::pg_query_free_raw_parse_result(result) }; + parse_result +} + +/// Converts a PostgreSQL List of RawStmt nodes to protobuf RawStmt vector. +unsafe fn convert_list_to_raw_stmts(list: *mut bindings_raw::List) -> Vec { + if list.is_null() { + return Vec::new(); + } + + let list_ref = &*list; + let length = list_ref.length as usize; + let mut stmts = Vec::with_capacity(length); + + for i in 0..length { + let cell = list_ref.elements.add(i); + let node_ptr = (*cell).ptr_value as *mut bindings_raw::Node; + + if !node_ptr.is_null() { + let node_tag = (*node_ptr).type_; + if node_tag == bindings_raw::NodeTag_T_RawStmt { + let raw_stmt = node_ptr as *mut bindings_raw::RawStmt; + stmts.push(convert_raw_stmt(&*raw_stmt)); + } + } + } + + stmts +} + +/// Converts a C RawStmt to a protobuf RawStmt. +unsafe fn convert_raw_stmt(raw_stmt: &bindings_raw::RawStmt) -> protobuf::RawStmt { + protobuf::RawStmt { + stmt: convert_node_boxed(raw_stmt.stmt), + stmt_location: raw_stmt.stmt_location, + stmt_len: raw_stmt.stmt_len, + } +} + +/// Converts a C Node pointer to a boxed protobuf Node (for fields that expect Option>). +unsafe fn convert_node_boxed(node_ptr: *mut bindings_raw::Node) -> Option> { + convert_node(node_ptr).map(Box::new) +} + +/// Converts a C Node pointer to a protobuf Node. +unsafe fn convert_node(node_ptr: *mut bindings_raw::Node) -> Option { + if node_ptr.is_null() { + return None; + } + + let node_tag = (*node_ptr).type_; + let node = match node_tag { + // Types that need Box + bindings_raw::NodeTag_T_SelectStmt => { + let stmt = node_ptr as *mut bindings_raw::SelectStmt; + Some(protobuf::node::Node::SelectStmt(Box::new(convert_select_stmt(&*stmt)))) + } + bindings_raw::NodeTag_T_InsertStmt => { + let stmt = node_ptr as *mut bindings_raw::InsertStmt; + Some(protobuf::node::Node::InsertStmt(Box::new(convert_insert_stmt(&*stmt)))) + } + bindings_raw::NodeTag_T_UpdateStmt => { + let stmt = node_ptr as *mut bindings_raw::UpdateStmt; + Some(protobuf::node::Node::UpdateStmt(Box::new(convert_update_stmt(&*stmt)))) + } + bindings_raw::NodeTag_T_DeleteStmt => { + let stmt = node_ptr as *mut bindings_raw::DeleteStmt; + Some(protobuf::node::Node::DeleteStmt(Box::new(convert_delete_stmt(&*stmt)))) + } + bindings_raw::NodeTag_T_ResTarget => { + let rt = node_ptr as *mut bindings_raw::ResTarget; + Some(protobuf::node::Node::ResTarget(Box::new(convert_res_target(&*rt)))) + } + bindings_raw::NodeTag_T_A_Expr => { + let expr = node_ptr as *mut bindings_raw::A_Expr; + Some(protobuf::node::Node::AExpr(Box::new(convert_a_expr(&*expr)))) + } + bindings_raw::NodeTag_T_A_Const => { + let aconst = node_ptr as *mut bindings_raw::A_Const; + Some(protobuf::node::Node::AConst(convert_a_const(&*aconst))) + } + bindings_raw::NodeTag_T_FuncCall => { + let fc = node_ptr as *mut bindings_raw::FuncCall; + Some(protobuf::node::Node::FuncCall(Box::new(convert_func_call(&*fc)))) + } + bindings_raw::NodeTag_T_TypeCast => { + let tc = node_ptr as *mut bindings_raw::TypeCast; + Some(protobuf::node::Node::TypeCast(Box::new(convert_type_cast(&*tc)))) + } + bindings_raw::NodeTag_T_JoinExpr => { + let je = node_ptr as *mut bindings_raw::JoinExpr; + Some(protobuf::node::Node::JoinExpr(Box::new(convert_join_expr(&*je)))) + } + bindings_raw::NodeTag_T_SortBy => { + let sb = node_ptr as *mut bindings_raw::SortBy; + Some(protobuf::node::Node::SortBy(Box::new(convert_sort_by(&*sb)))) + } + bindings_raw::NodeTag_T_BoolExpr => { + let be = node_ptr as *mut bindings_raw::BoolExpr; + Some(protobuf::node::Node::BoolExpr(Box::new(convert_bool_expr(&*be)))) + } + bindings_raw::NodeTag_T_SubLink => { + let sl = node_ptr as *mut bindings_raw::SubLink; + Some(protobuf::node::Node::SubLink(Box::new(convert_sub_link(&*sl)))) + } + bindings_raw::NodeTag_T_NullTest => { + let nt = node_ptr as *mut bindings_raw::NullTest; + Some(protobuf::node::Node::NullTest(Box::new(convert_null_test(&*nt)))) + } + bindings_raw::NodeTag_T_CaseExpr => { + let ce = node_ptr as *mut bindings_raw::CaseExpr; + Some(protobuf::node::Node::CaseExpr(Box::new(convert_case_expr(&*ce)))) + } + bindings_raw::NodeTag_T_CaseWhen => { + let cw = node_ptr as *mut bindings_raw::CaseWhen; + Some(protobuf::node::Node::CaseWhen(Box::new(convert_case_when(&*cw)))) + } + bindings_raw::NodeTag_T_CoalesceExpr => { + let ce = node_ptr as *mut bindings_raw::CoalesceExpr; + Some(protobuf::node::Node::CoalesceExpr(Box::new(convert_coalesce_expr(&*ce)))) + } + bindings_raw::NodeTag_T_CommonTableExpr => { + let cte = node_ptr as *mut bindings_raw::CommonTableExpr; + Some(protobuf::node::Node::CommonTableExpr(Box::new(convert_common_table_expr(&*cte)))) + } + bindings_raw::NodeTag_T_ColumnDef => { + let cd = node_ptr as *mut bindings_raw::ColumnDef; + Some(protobuf::node::Node::ColumnDef(Box::new(convert_column_def(&*cd)))) + } + bindings_raw::NodeTag_T_Constraint => { + let c = node_ptr as *mut bindings_raw::Constraint; + Some(protobuf::node::Node::Constraint(Box::new(convert_constraint(&*c)))) + } + bindings_raw::NodeTag_T_DropStmt => { + let ds = node_ptr as *mut bindings_raw::DropStmt; + Some(protobuf::node::Node::DropStmt(convert_drop_stmt(&*ds))) + } + bindings_raw::NodeTag_T_IndexStmt => { + let is = node_ptr as *mut bindings_raw::IndexStmt; + Some(protobuf::node::Node::IndexStmt(Box::new(convert_index_stmt(&*is)))) + } + bindings_raw::NodeTag_T_IndexElem => { + let ie = node_ptr as *mut bindings_raw::IndexElem; + Some(protobuf::node::Node::IndexElem(Box::new(convert_index_elem(&*ie)))) + } + bindings_raw::NodeTag_T_DefElem => { + let de = node_ptr as *mut bindings_raw::DefElem; + Some(protobuf::node::Node::DefElem(Box::new(convert_def_elem(&*de)))) + } + bindings_raw::NodeTag_T_WindowDef => { + let wd = node_ptr as *mut bindings_raw::WindowDef; + Some(protobuf::node::Node::WindowDef(Box::new(convert_window_def(&*wd)))) + } + // Types that don't need Box + bindings_raw::NodeTag_T_RangeVar => { + let rv = node_ptr as *mut bindings_raw::RangeVar; + Some(protobuf::node::Node::RangeVar(convert_range_var(&*rv))) + } + bindings_raw::NodeTag_T_ColumnRef => { + let cr = node_ptr as *mut bindings_raw::ColumnRef; + Some(protobuf::node::Node::ColumnRef(convert_column_ref(&*cr))) + } + bindings_raw::NodeTag_T_A_Star => { + Some(protobuf::node::Node::AStar(protobuf::AStar {})) + } + bindings_raw::NodeTag_T_TypeName => { + let tn = node_ptr as *mut bindings_raw::TypeName; + Some(protobuf::node::Node::TypeName(convert_type_name(&*tn))) + } + bindings_raw::NodeTag_T_Alias => { + let alias = node_ptr as *mut bindings_raw::Alias; + Some(protobuf::node::Node::Alias(convert_alias(&*alias))) + } + bindings_raw::NodeTag_T_String => { + let s = node_ptr as *mut bindings_raw::String; + Some(protobuf::node::Node::String(convert_string(&*s))) + } + bindings_raw::NodeTag_T_Integer => { + let i = node_ptr as *mut bindings_raw::Integer; + Some(protobuf::node::Node::Integer(protobuf::Integer { ival: (*i).ival })) + } + bindings_raw::NodeTag_T_Float => { + let f = node_ptr as *mut bindings_raw::Float; + let fval = if (*f).fval.is_null() { + String::new() + } else { + CStr::from_ptr((*f).fval).to_string_lossy().to_string() + }; + Some(protobuf::node::Node::Float(protobuf::Float { fval })) + } + bindings_raw::NodeTag_T_Boolean => { + let b = node_ptr as *mut bindings_raw::Boolean; + Some(protobuf::node::Node::Boolean(protobuf::Boolean { boolval: (*b).boolval })) + } + bindings_raw::NodeTag_T_ParamRef => { + let pr = node_ptr as *mut bindings_raw::ParamRef; + Some(protobuf::node::Node::ParamRef(protobuf::ParamRef { + number: (*pr).number, + location: (*pr).location, + })) + } + bindings_raw::NodeTag_T_WithClause => { + let wc = node_ptr as *mut bindings_raw::WithClause; + Some(protobuf::node::Node::WithClause(convert_with_clause(&*wc))) + } + bindings_raw::NodeTag_T_CreateStmt => { + let cs = node_ptr as *mut bindings_raw::CreateStmt; + Some(protobuf::node::Node::CreateStmt(convert_create_stmt(&*cs))) + } + bindings_raw::NodeTag_T_List => { + let list = node_ptr as *mut bindings_raw::List; + Some(protobuf::node::Node::List(convert_list(&*list))) + } + bindings_raw::NodeTag_T_LockingClause => { + let lc = node_ptr as *mut bindings_raw::LockingClause; + Some(protobuf::node::Node::LockingClause(convert_locking_clause(&*lc))) + } + bindings_raw::NodeTag_T_MinMaxExpr => { + let mme = node_ptr as *mut bindings_raw::MinMaxExpr; + Some(protobuf::node::Node::MinMaxExpr(Box::new(convert_min_max_expr(&*mme)))) + } + bindings_raw::NodeTag_T_GroupingSet => { + let gs = node_ptr as *mut bindings_raw::GroupingSet; + Some(protobuf::node::Node::GroupingSet(convert_grouping_set(&*gs))) + } + bindings_raw::NodeTag_T_RangeSubselect => { + let rs = node_ptr as *mut bindings_raw::RangeSubselect; + Some(protobuf::node::Node::RangeSubselect(Box::new(convert_range_subselect(&*rs)))) + } + bindings_raw::NodeTag_T_A_ArrayExpr => { + let ae = node_ptr as *mut bindings_raw::A_ArrayExpr; + Some(protobuf::node::Node::AArrayExpr(convert_a_array_expr(&*ae))) + } + bindings_raw::NodeTag_T_A_Indirection => { + let ai = node_ptr as *mut bindings_raw::A_Indirection; + Some(protobuf::node::Node::AIndirection(Box::new(convert_a_indirection(&*ai)))) + } + bindings_raw::NodeTag_T_A_Indices => { + let ai = node_ptr as *mut bindings_raw::A_Indices; + Some(protobuf::node::Node::AIndices(Box::new(convert_a_indices(&*ai)))) + } + bindings_raw::NodeTag_T_AlterTableStmt => { + let ats = node_ptr as *mut bindings_raw::AlterTableStmt; + Some(protobuf::node::Node::AlterTableStmt(convert_alter_table_stmt(&*ats))) + } + bindings_raw::NodeTag_T_AlterTableCmd => { + let atc = node_ptr as *mut bindings_raw::AlterTableCmd; + Some(protobuf::node::Node::AlterTableCmd(Box::new(convert_alter_table_cmd(&*atc)))) + } + bindings_raw::NodeTag_T_CopyStmt => { + let cs = node_ptr as *mut bindings_raw::CopyStmt; + Some(protobuf::node::Node::CopyStmt(Box::new(convert_copy_stmt(&*cs)))) + } + bindings_raw::NodeTag_T_TruncateStmt => { + let ts = node_ptr as *mut bindings_raw::TruncateStmt; + Some(protobuf::node::Node::TruncateStmt(convert_truncate_stmt(&*ts))) + } + bindings_raw::NodeTag_T_ViewStmt => { + let vs = node_ptr as *mut bindings_raw::ViewStmt; + Some(protobuf::node::Node::ViewStmt(Box::new(convert_view_stmt(&*vs)))) + } + bindings_raw::NodeTag_T_ExplainStmt => { + let es = node_ptr as *mut bindings_raw::ExplainStmt; + Some(protobuf::node::Node::ExplainStmt(Box::new(convert_explain_stmt(&*es)))) + } + bindings_raw::NodeTag_T_CreateTableAsStmt => { + let ctas = node_ptr as *mut bindings_raw::CreateTableAsStmt; + Some(protobuf::node::Node::CreateTableAsStmt(Box::new(convert_create_table_as_stmt(&*ctas)))) + } + bindings_raw::NodeTag_T_PrepareStmt => { + let ps = node_ptr as *mut bindings_raw::PrepareStmt; + Some(protobuf::node::Node::PrepareStmt(Box::new(convert_prepare_stmt(&*ps)))) + } + bindings_raw::NodeTag_T_ExecuteStmt => { + let es = node_ptr as *mut bindings_raw::ExecuteStmt; + Some(protobuf::node::Node::ExecuteStmt(convert_execute_stmt(&*es))) + } + bindings_raw::NodeTag_T_DeallocateStmt => { + let ds = node_ptr as *mut bindings_raw::DeallocateStmt; + Some(protobuf::node::Node::DeallocateStmt(convert_deallocate_stmt(&*ds))) + } + bindings_raw::NodeTag_T_SetToDefault => { + let std = node_ptr as *mut bindings_raw::SetToDefault; + Some(protobuf::node::Node::SetToDefault(Box::new(convert_set_to_default(&*std)))) + } + bindings_raw::NodeTag_T_MultiAssignRef => { + let mar = node_ptr as *mut bindings_raw::MultiAssignRef; + Some(protobuf::node::Node::MultiAssignRef(Box::new(convert_multi_assign_ref(&*mar)))) + } + bindings_raw::NodeTag_T_RowExpr => { + let re = node_ptr as *mut bindings_raw::RowExpr; + Some(protobuf::node::Node::RowExpr(Box::new(convert_row_expr(&*re)))) + } + _ => { + // For unhandled node types, return None + // In the future, we could add more node types here + None + } + }; + + node.map(|n| protobuf::Node { node: Some(n) }) +} + +/// Converts a PostgreSQL List to a protobuf List of Nodes. +unsafe fn convert_list(list: &bindings_raw::List) -> protobuf::List { + let items = convert_list_to_nodes(list as *const bindings_raw::List as *mut bindings_raw::List); + protobuf::List { items } +} + +/// Converts a PostgreSQL List pointer to a Vec of protobuf Nodes. +unsafe fn convert_list_to_nodes(list: *mut bindings_raw::List) -> Vec { + if list.is_null() { + return Vec::new(); + } + + let list_ref = &*list; + let length = list_ref.length as usize; + let mut nodes = Vec::with_capacity(length); + + for i in 0..length { + let cell = list_ref.elements.add(i); + let node_ptr = (*cell).ptr_value as *mut bindings_raw::Node; + + if let Some(node) = convert_node(node_ptr) { + nodes.push(node); + } + } + + nodes +} + +// ============================================================================ +// Statement Conversions +// ============================================================================ + +unsafe fn convert_select_stmt(stmt: &bindings_raw::SelectStmt) -> protobuf::SelectStmt { + protobuf::SelectStmt { + distinct_clause: convert_list_to_nodes(stmt.distinctClause), + into_clause: convert_into_clause(stmt.intoClause), + target_list: convert_list_to_nodes(stmt.targetList), + from_clause: convert_list_to_nodes(stmt.fromClause), + where_clause: convert_node_boxed(stmt.whereClause), + group_clause: convert_list_to_nodes(stmt.groupClause), + group_distinct: stmt.groupDistinct, + having_clause: convert_node_boxed(stmt.havingClause), + window_clause: convert_list_to_nodes(stmt.windowClause), + values_lists: convert_list_to_nodes(stmt.valuesLists), + sort_clause: convert_list_to_nodes(stmt.sortClause), + limit_offset: convert_node_boxed(stmt.limitOffset), + limit_count: convert_node_boxed(stmt.limitCount), + limit_option: stmt.limitOption as i32 + 1, // Protobuf enums have UNDEFINED=0, so C values need +1 + locking_clause: convert_list_to_nodes(stmt.lockingClause), + with_clause: convert_with_clause_opt(stmt.withClause), + op: stmt.op as i32 + 1, // Protobuf SetOperation has UNDEFINED=0, so C values need +1 + all: stmt.all, + larg: if stmt.larg.is_null() { None } else { Some(Box::new(convert_select_stmt(&*stmt.larg))) }, + rarg: if stmt.rarg.is_null() { None } else { Some(Box::new(convert_select_stmt(&*stmt.rarg))) }, + } +} + +unsafe fn convert_insert_stmt(stmt: &bindings_raw::InsertStmt) -> protobuf::InsertStmt { + protobuf::InsertStmt { + relation: if stmt.relation.is_null() { None } else { Some(convert_range_var(&*stmt.relation)) }, + cols: convert_list_to_nodes(stmt.cols), + select_stmt: convert_node_boxed(stmt.selectStmt), + on_conflict_clause: convert_on_conflict_clause(stmt.onConflictClause), + returning_list: convert_list_to_nodes(stmt.returningList), + with_clause: convert_with_clause_opt(stmt.withClause), + r#override: stmt.override_ as i32 + 1, + } +} + +unsafe fn convert_update_stmt(stmt: &bindings_raw::UpdateStmt) -> protobuf::UpdateStmt { + protobuf::UpdateStmt { + relation: if stmt.relation.is_null() { None } else { Some(convert_range_var(&*stmt.relation)) }, + target_list: convert_list_to_nodes(stmt.targetList), + where_clause: convert_node_boxed(stmt.whereClause), + from_clause: convert_list_to_nodes(stmt.fromClause), + returning_list: convert_list_to_nodes(stmt.returningList), + with_clause: convert_with_clause_opt(stmt.withClause), + } +} + +unsafe fn convert_delete_stmt(stmt: &bindings_raw::DeleteStmt) -> protobuf::DeleteStmt { + protobuf::DeleteStmt { + relation: if stmt.relation.is_null() { None } else { Some(convert_range_var(&*stmt.relation)) }, + using_clause: convert_list_to_nodes(stmt.usingClause), + where_clause: convert_node_boxed(stmt.whereClause), + returning_list: convert_list_to_nodes(stmt.returningList), + with_clause: convert_with_clause_opt(stmt.withClause), + } +} + +unsafe fn convert_create_stmt(stmt: &bindings_raw::CreateStmt) -> protobuf::CreateStmt { + protobuf::CreateStmt { + relation: if stmt.relation.is_null() { None } else { Some(convert_range_var(&*stmt.relation)) }, + table_elts: convert_list_to_nodes(stmt.tableElts), + inh_relations: convert_list_to_nodes(stmt.inhRelations), + partbound: convert_partition_bound_spec_opt(stmt.partbound), + partspec: convert_partition_spec_opt(stmt.partspec), + of_typename: if stmt.ofTypename.is_null() { None } else { Some(convert_type_name(&*stmt.ofTypename)) }, + constraints: convert_list_to_nodes(stmt.constraints), + options: convert_list_to_nodes(stmt.options), + oncommit: stmt.oncommit as i32 + 1, + tablespacename: convert_c_string(stmt.tablespacename), + access_method: convert_c_string(stmt.accessMethod), + if_not_exists: stmt.if_not_exists, + } +} + +unsafe fn convert_drop_stmt(stmt: &bindings_raw::DropStmt) -> protobuf::DropStmt { + protobuf::DropStmt { + objects: convert_list_to_nodes(stmt.objects), + remove_type: stmt.removeType as i32 + 1, + behavior: stmt.behavior as i32 + 1, + missing_ok: stmt.missing_ok, + concurrent: stmt.concurrent, + } +} + +unsafe fn convert_index_stmt(stmt: &bindings_raw::IndexStmt) -> protobuf::IndexStmt { + protobuf::IndexStmt { + idxname: convert_c_string(stmt.idxname), + relation: if stmt.relation.is_null() { None } else { Some(convert_range_var(&*stmt.relation)) }, + access_method: convert_c_string(stmt.accessMethod), + table_space: convert_c_string(stmt.tableSpace), + index_params: convert_list_to_nodes(stmt.indexParams), + index_including_params: convert_list_to_nodes(stmt.indexIncludingParams), + options: convert_list_to_nodes(stmt.options), + where_clause: convert_node_boxed(stmt.whereClause), + exclude_op_names: convert_list_to_nodes(stmt.excludeOpNames), + idxcomment: convert_c_string(stmt.idxcomment), + index_oid: stmt.indexOid, + old_number: stmt.oldNumber, + old_create_subid: stmt.oldCreateSubid, + old_first_relfilelocator_subid: stmt.oldFirstRelfilelocatorSubid, + unique: stmt.unique, + nulls_not_distinct: stmt.nulls_not_distinct, + primary: stmt.primary, + isconstraint: stmt.isconstraint, + deferrable: stmt.deferrable, + initdeferred: stmt.initdeferred, + transformed: stmt.transformed, + concurrent: stmt.concurrent, + if_not_exists: stmt.if_not_exists, + reset_default_tblspc: stmt.reset_default_tblspc, + } +} + +// ============================================================================ +// Expression/Clause Conversions +// ============================================================================ + +unsafe fn convert_range_var(rv: &bindings_raw::RangeVar) -> protobuf::RangeVar { + protobuf::RangeVar { + catalogname: convert_c_string(rv.catalogname), + schemaname: convert_c_string(rv.schemaname), + relname: convert_c_string(rv.relname), + inh: rv.inh, + relpersistence: String::from_utf8_lossy(&[rv.relpersistence as u8]).to_string(), + alias: if rv.alias.is_null() { None } else { Some(convert_alias(&*rv.alias)) }, + location: rv.location, + } +} + +unsafe fn convert_column_ref(cr: &bindings_raw::ColumnRef) -> protobuf::ColumnRef { + protobuf::ColumnRef { + fields: convert_list_to_nodes(cr.fields), + location: cr.location, + } +} + +unsafe fn convert_res_target(rt: &bindings_raw::ResTarget) -> protobuf::ResTarget { + protobuf::ResTarget { + name: convert_c_string(rt.name), + indirection: convert_list_to_nodes(rt.indirection), + val: convert_node_boxed(rt.val), + location: rt.location, + } +} + +unsafe fn convert_a_expr(expr: &bindings_raw::A_Expr) -> protobuf::AExpr { + protobuf::AExpr { + kind: expr.kind as i32 + 1, + name: convert_list_to_nodes(expr.name), + lexpr: convert_node_boxed(expr.lexpr), + rexpr: convert_node_boxed(expr.rexpr), + location: expr.location, + } +} + +unsafe fn convert_a_const(aconst: &bindings_raw::A_Const) -> protobuf::AConst { + let val = if aconst.isnull { + None + } else { + // Check the node tag in the val union to determine the type + let node_tag = aconst.val.node.type_; + match node_tag { + bindings_raw::NodeTag_T_Integer => { + Some(protobuf::a_const::Val::Ival(protobuf::Integer { + ival: aconst.val.ival.ival, + })) + } + bindings_raw::NodeTag_T_Float => { + let fval = if aconst.val.fval.fval.is_null() { + std::string::String::new() + } else { + CStr::from_ptr(aconst.val.fval.fval).to_string_lossy().to_string() + }; + Some(protobuf::a_const::Val::Fval(protobuf::Float { fval })) + } + bindings_raw::NodeTag_T_Boolean => { + Some(protobuf::a_const::Val::Boolval(protobuf::Boolean { + boolval: aconst.val.boolval.boolval, + })) + } + bindings_raw::NodeTag_T_String => { + let sval = if aconst.val.sval.sval.is_null() { + std::string::String::new() + } else { + CStr::from_ptr(aconst.val.sval.sval).to_string_lossy().to_string() + }; + Some(protobuf::a_const::Val::Sval(protobuf::String { sval })) + } + bindings_raw::NodeTag_T_BitString => { + let bsval = if aconst.val.bsval.bsval.is_null() { + std::string::String::new() + } else { + CStr::from_ptr(aconst.val.bsval.bsval).to_string_lossy().to_string() + }; + Some(protobuf::a_const::Val::Bsval(protobuf::BitString { bsval })) + } + _ => None, + } + }; + + protobuf::AConst { + isnull: aconst.isnull, + val, + location: aconst.location, + } +} + +unsafe fn convert_func_call(fc: &bindings_raw::FuncCall) -> protobuf::FuncCall { + protobuf::FuncCall { + funcname: convert_list_to_nodes(fc.funcname), + args: convert_list_to_nodes(fc.args), + agg_order: convert_list_to_nodes(fc.agg_order), + agg_filter: convert_node_boxed(fc.agg_filter), + over: if fc.over.is_null() { None } else { Some(Box::new(convert_window_def(&*fc.over))) }, + agg_within_group: fc.agg_within_group, + agg_star: fc.agg_star, + agg_distinct: fc.agg_distinct, + func_variadic: fc.func_variadic, + funcformat: fc.funcformat as i32 + 1, + location: fc.location, + } +} + +unsafe fn convert_type_cast(tc: &bindings_raw::TypeCast) -> protobuf::TypeCast { + protobuf::TypeCast { + arg: convert_node_boxed(tc.arg), + type_name: if tc.typeName.is_null() { None } else { Some(convert_type_name(&*tc.typeName)) }, + location: tc.location, + } +} + +unsafe fn convert_type_name(tn: &bindings_raw::TypeName) -> protobuf::TypeName { + protobuf::TypeName { + names: convert_list_to_nodes(tn.names), + type_oid: tn.typeOid, + setof: tn.setof, + pct_type: tn.pct_type, + typmods: convert_list_to_nodes(tn.typmods), + typemod: tn.typemod, + array_bounds: convert_list_to_nodes(tn.arrayBounds), + location: tn.location, + } +} + +unsafe fn convert_alias(alias: &bindings_raw::Alias) -> protobuf::Alias { + protobuf::Alias { + aliasname: convert_c_string(alias.aliasname), + colnames: convert_list_to_nodes(alias.colnames), + } +} + +unsafe fn convert_join_expr(je: &bindings_raw::JoinExpr) -> protobuf::JoinExpr { + protobuf::JoinExpr { + jointype: je.jointype as i32 + 1, + is_natural: je.isNatural, + larg: convert_node_boxed(je.larg), + rarg: convert_node_boxed(je.rarg), + using_clause: convert_list_to_nodes(je.usingClause), + join_using_alias: if je.join_using_alias.is_null() { None } else { Some(convert_alias(&*je.join_using_alias)) }, + quals: convert_node_boxed(je.quals), + alias: if je.alias.is_null() { None } else { Some(convert_alias(&*je.alias)) }, + rtindex: je.rtindex, + } +} + +unsafe fn convert_sort_by(sb: &bindings_raw::SortBy) -> protobuf::SortBy { + protobuf::SortBy { + node: convert_node_boxed(sb.node), + sortby_dir: sb.sortby_dir as i32 + 1, + sortby_nulls: sb.sortby_nulls as i32 + 1, + use_op: convert_list_to_nodes(sb.useOp), + location: sb.location, + } +} + +unsafe fn convert_bool_expr(be: &bindings_raw::BoolExpr) -> protobuf::BoolExpr { + protobuf::BoolExpr { + xpr: None, // Xpr is internal + boolop: be.boolop as i32 + 1, + args: convert_list_to_nodes(be.args), + location: be.location, + } +} + +unsafe fn convert_sub_link(sl: &bindings_raw::SubLink) -> protobuf::SubLink { + protobuf::SubLink { + xpr: None, + sub_link_type: sl.subLinkType as i32 + 1, + sub_link_id: sl.subLinkId, + testexpr: convert_node_boxed(sl.testexpr), + oper_name: convert_list_to_nodes(sl.operName), + subselect: convert_node_boxed(sl.subselect), + location: sl.location, + } +} + +unsafe fn convert_null_test(nt: &bindings_raw::NullTest) -> protobuf::NullTest { + protobuf::NullTest { + xpr: None, + arg: convert_node_boxed(nt.arg as *mut bindings_raw::Node), + nulltesttype: nt.nulltesttype as i32 + 1, + argisrow: nt.argisrow, + location: nt.location, + } +} + +unsafe fn convert_case_expr(ce: &bindings_raw::CaseExpr) -> protobuf::CaseExpr { + protobuf::CaseExpr { + xpr: None, + casetype: ce.casetype, + casecollid: ce.casecollid, + arg: convert_node_boxed(ce.arg as *mut bindings_raw::Node), + args: convert_list_to_nodes(ce.args), + defresult: convert_node_boxed(ce.defresult as *mut bindings_raw::Node), + location: ce.location, + } +} + +unsafe fn convert_case_when(cw: &bindings_raw::CaseWhen) -> protobuf::CaseWhen { + protobuf::CaseWhen { + xpr: None, + expr: convert_node_boxed(cw.expr as *mut bindings_raw::Node), + result: convert_node_boxed(cw.result as *mut bindings_raw::Node), + location: cw.location, + } +} + +unsafe fn convert_coalesce_expr(ce: &bindings_raw::CoalesceExpr) -> protobuf::CoalesceExpr { + protobuf::CoalesceExpr { + xpr: None, + coalescetype: ce.coalescetype, + coalescecollid: ce.coalescecollid, + args: convert_list_to_nodes(ce.args), + location: ce.location, + } +} + +unsafe fn convert_with_clause(wc: &bindings_raw::WithClause) -> protobuf::WithClause { + protobuf::WithClause { + ctes: convert_list_to_nodes(wc.ctes), + recursive: wc.recursive, + location: wc.location, + } +} + +unsafe fn convert_with_clause_opt(wc: *mut bindings_raw::WithClause) -> Option { + if wc.is_null() { + None + } else { + Some(convert_with_clause(&*wc)) + } +} + +unsafe fn convert_common_table_expr(cte: &bindings_raw::CommonTableExpr) -> protobuf::CommonTableExpr { + protobuf::CommonTableExpr { + ctename: convert_c_string(cte.ctename), + aliascolnames: convert_list_to_nodes(cte.aliascolnames), + ctematerialized: cte.ctematerialized as i32 + 1, + ctequery: convert_node_boxed(cte.ctequery), + search_clause: convert_cte_search_clause_opt(cte.search_clause), + cycle_clause: convert_cte_cycle_clause_opt(cte.cycle_clause), + location: cte.location, + cterecursive: cte.cterecursive, + cterefcount: cte.cterefcount, + ctecolnames: convert_list_to_nodes(cte.ctecolnames), + ctecoltypes: convert_list_to_nodes(cte.ctecoltypes), + ctecoltypmods: convert_list_to_nodes(cte.ctecoltypmods), + ctecolcollations: convert_list_to_nodes(cte.ctecolcollations), + } +} + +unsafe fn convert_window_def(wd: &bindings_raw::WindowDef) -> protobuf::WindowDef { + protobuf::WindowDef { + name: convert_c_string(wd.name), + refname: convert_c_string(wd.refname), + partition_clause: convert_list_to_nodes(wd.partitionClause), + order_clause: convert_list_to_nodes(wd.orderClause), + frame_options: wd.frameOptions, + start_offset: convert_node_boxed(wd.startOffset), + end_offset: convert_node_boxed(wd.endOffset), + location: wd.location, + } +} + +unsafe fn convert_into_clause(ic: *mut bindings_raw::IntoClause) -> Option> { + if ic.is_null() { + return None; + } + let ic_ref = &*ic; + Some(Box::new(protobuf::IntoClause { + rel: if ic_ref.rel.is_null() { None } else { Some(convert_range_var(&*ic_ref.rel)) }, + col_names: convert_list_to_nodes(ic_ref.colNames), + access_method: convert_c_string(ic_ref.accessMethod), + options: convert_list_to_nodes(ic_ref.options), + on_commit: ic_ref.onCommit as i32 + 1, + table_space_name: convert_c_string(ic_ref.tableSpaceName), + view_query: convert_node_boxed(ic_ref.viewQuery), + skip_data: ic_ref.skipData, + })) +} + +unsafe fn convert_infer_clause(ic: *mut bindings_raw::InferClause) -> Option> { + if ic.is_null() { + return None; + } + let ic_ref = &*ic; + Some(Box::new(protobuf::InferClause { + index_elems: convert_list_to_nodes(ic_ref.indexElems), + where_clause: convert_node_boxed(ic_ref.whereClause), + conname: convert_c_string(ic_ref.conname), + location: ic_ref.location, + })) +} + +unsafe fn convert_on_conflict_clause(oc: *mut bindings_raw::OnConflictClause) -> Option> { + if oc.is_null() { + return None; + } + let oc_ref = &*oc; + Some(Box::new(protobuf::OnConflictClause { + action: oc_ref.action as i32 + 1, + infer: convert_infer_clause(oc_ref.infer), + target_list: convert_list_to_nodes(oc_ref.targetList), + where_clause: convert_node_boxed(oc_ref.whereClause), + location: oc_ref.location, + })) +} + +unsafe fn convert_column_def(cd: &bindings_raw::ColumnDef) -> protobuf::ColumnDef { + protobuf::ColumnDef { + colname: convert_c_string(cd.colname), + type_name: if cd.typeName.is_null() { None } else { Some(convert_type_name(&*cd.typeName)) }, + compression: convert_c_string(cd.compression), + inhcount: cd.inhcount, + is_local: cd.is_local, + is_not_null: cd.is_not_null, + is_from_type: cd.is_from_type, + storage: if cd.storage == 0 { String::new() } else { String::from_utf8_lossy(&[cd.storage as u8]).to_string() }, + storage_name: convert_c_string(cd.storage_name), + raw_default: convert_node_boxed(cd.raw_default), + cooked_default: convert_node_boxed(cd.cooked_default), + identity: if cd.identity == 0 { String::new() } else { String::from_utf8_lossy(&[cd.identity as u8]).to_string() }, + identity_sequence: if cd.identitySequence.is_null() { None } else { Some(convert_range_var(&*cd.identitySequence)) }, + generated: if cd.generated == 0 { String::new() } else { String::from_utf8_lossy(&[cd.generated as u8]).to_string() }, + coll_clause: convert_collate_clause_opt(cd.collClause), + coll_oid: cd.collOid, + constraints: convert_list_to_nodes(cd.constraints), + fdwoptions: convert_list_to_nodes(cd.fdwoptions), + location: cd.location, + } +} + +unsafe fn convert_constraint(c: &bindings_raw::Constraint) -> protobuf::Constraint { + protobuf::Constraint { + contype: c.contype as i32 + 1, + conname: convert_c_string(c.conname), + deferrable: c.deferrable, + initdeferred: c.initdeferred, + location: c.location, + is_no_inherit: c.is_no_inherit, + raw_expr: convert_node_boxed(c.raw_expr), + cooked_expr: convert_c_string(c.cooked_expr), + generated_when: if c.generated_when == 0 { String::new() } else { String::from_utf8_lossy(&[c.generated_when as u8]).to_string() }, + nulls_not_distinct: c.nulls_not_distinct, + keys: convert_list_to_nodes(c.keys), + including: convert_list_to_nodes(c.including), + exclusions: convert_list_to_nodes(c.exclusions), + options: convert_list_to_nodes(c.options), + indexname: convert_c_string(c.indexname), + indexspace: convert_c_string(c.indexspace), + reset_default_tblspc: c.reset_default_tblspc, + access_method: convert_c_string(c.access_method), + where_clause: convert_node_boxed(c.where_clause), + pktable: if c.pktable.is_null() { None } else { Some(convert_range_var(&*c.pktable)) }, + fk_attrs: convert_list_to_nodes(c.fk_attrs), + pk_attrs: convert_list_to_nodes(c.pk_attrs), + fk_matchtype: if c.fk_matchtype == 0 { String::new() } else { String::from_utf8_lossy(&[c.fk_matchtype as u8]).to_string() }, + fk_upd_action: if c.fk_upd_action == 0 { String::new() } else { String::from_utf8_lossy(&[c.fk_upd_action as u8]).to_string() }, + fk_del_action: if c.fk_del_action == 0 { String::new() } else { String::from_utf8_lossy(&[c.fk_del_action as u8]).to_string() }, + fk_del_set_cols: convert_list_to_nodes(c.fk_del_set_cols), + old_conpfeqop: convert_list_to_nodes(c.old_conpfeqop), + old_pktable_oid: c.old_pktable_oid, + skip_validation: c.skip_validation, + initially_valid: c.initially_valid, + } +} + +unsafe fn convert_index_elem(ie: &bindings_raw::IndexElem) -> protobuf::IndexElem { + protobuf::IndexElem { + name: convert_c_string(ie.name), + expr: convert_node_boxed(ie.expr), + indexcolname: convert_c_string(ie.indexcolname), + collation: convert_list_to_nodes(ie.collation), + opclass: convert_list_to_nodes(ie.opclass), + opclassopts: convert_list_to_nodes(ie.opclassopts), + ordering: ie.ordering as i32 + 1, + nulls_ordering: ie.nulls_ordering as i32 + 1, + } +} + +unsafe fn convert_def_elem(de: &bindings_raw::DefElem) -> protobuf::DefElem { + protobuf::DefElem { + defnamespace: convert_c_string(de.defnamespace), + defname: convert_c_string(de.defname), + arg: convert_node_boxed(de.arg), + defaction: de.defaction as i32 + 1, + location: de.location, + } +} + +unsafe fn convert_string(s: &bindings_raw::String) -> protobuf::String { + protobuf::String { + sval: convert_c_string(s.sval), + } +} + +unsafe fn convert_locking_clause(lc: &bindings_raw::LockingClause) -> protobuf::LockingClause { + protobuf::LockingClause { + locked_rels: convert_list_to_nodes(lc.lockedRels), + strength: lc.strength as i32 + 1, + wait_policy: lc.waitPolicy as i32 + 1, + } +} + +unsafe fn convert_min_max_expr(mme: &bindings_raw::MinMaxExpr) -> protobuf::MinMaxExpr { + protobuf::MinMaxExpr { + xpr: None, // Expression type info, not needed for parse tree + minmaxtype: mme.minmaxtype, + minmaxcollid: mme.minmaxcollid, + inputcollid: mme.inputcollid, + op: mme.op as i32 + 1, + args: convert_list_to_nodes(mme.args), + location: mme.location, + } +} + +unsafe fn convert_grouping_set(gs: &bindings_raw::GroupingSet) -> protobuf::GroupingSet { + protobuf::GroupingSet { + kind: gs.kind as i32 + 1, + content: convert_list_to_nodes(gs.content), + location: gs.location, + } +} + +unsafe fn convert_range_subselect(rs: &bindings_raw::RangeSubselect) -> protobuf::RangeSubselect { + protobuf::RangeSubselect { + lateral: rs.lateral, + subquery: convert_node_boxed(rs.subquery), + alias: if rs.alias.is_null() { None } else { Some(convert_alias(&*rs.alias)) }, + } +} + +unsafe fn convert_a_array_expr(ae: &bindings_raw::A_ArrayExpr) -> protobuf::AArrayExpr { + protobuf::AArrayExpr { + elements: convert_list_to_nodes(ae.elements), + location: ae.location, + } +} + +unsafe fn convert_a_indirection(ai: &bindings_raw::A_Indirection) -> protobuf::AIndirection { + protobuf::AIndirection { + arg: convert_node_boxed(ai.arg), + indirection: convert_list_to_nodes(ai.indirection), + } +} + +unsafe fn convert_a_indices(ai: &bindings_raw::A_Indices) -> protobuf::AIndices { + protobuf::AIndices { + is_slice: ai.is_slice, + lidx: convert_node_boxed(ai.lidx), + uidx: convert_node_boxed(ai.uidx), + } +} + +unsafe fn convert_alter_table_stmt(ats: &bindings_raw::AlterTableStmt) -> protobuf::AlterTableStmt { + protobuf::AlterTableStmt { + relation: if ats.relation.is_null() { None } else { Some(convert_range_var(&*ats.relation)) }, + cmds: convert_list_to_nodes(ats.cmds), + objtype: ats.objtype as i32 + 1, + missing_ok: ats.missing_ok, + } +} + +unsafe fn convert_alter_table_cmd(atc: &bindings_raw::AlterTableCmd) -> protobuf::AlterTableCmd { + protobuf::AlterTableCmd { + subtype: atc.subtype as i32 + 1, + name: convert_c_string(atc.name), + num: atc.num as i32, + newowner: if atc.newowner.is_null() { None } else { Some(convert_role_spec(&*atc.newowner)) }, + def: convert_node_boxed(atc.def), + behavior: atc.behavior as i32 + 1, + missing_ok: atc.missing_ok, + recurse: atc.recurse, + } +} + +unsafe fn convert_role_spec(rs: &bindings_raw::RoleSpec) -> protobuf::RoleSpec { + protobuf::RoleSpec { + roletype: rs.roletype as i32 + 1, + rolename: convert_c_string(rs.rolename), + location: rs.location, + } +} + +unsafe fn convert_copy_stmt(cs: &bindings_raw::CopyStmt) -> protobuf::CopyStmt { + protobuf::CopyStmt { + relation: if cs.relation.is_null() { None } else { Some(convert_range_var(&*cs.relation)) }, + query: convert_node_boxed(cs.query), + attlist: convert_list_to_nodes(cs.attlist), + is_from: cs.is_from, + is_program: cs.is_program, + filename: convert_c_string(cs.filename), + options: convert_list_to_nodes(cs.options), + where_clause: convert_node_boxed(cs.whereClause), + } +} + +unsafe fn convert_truncate_stmt(ts: &bindings_raw::TruncateStmt) -> protobuf::TruncateStmt { + protobuf::TruncateStmt { + relations: convert_list_to_nodes(ts.relations), + restart_seqs: ts.restart_seqs, + behavior: ts.behavior as i32 + 1, + } +} + +unsafe fn convert_view_stmt(vs: &bindings_raw::ViewStmt) -> protobuf::ViewStmt { + protobuf::ViewStmt { + view: if vs.view.is_null() { None } else { Some(convert_range_var(&*vs.view)) }, + aliases: convert_list_to_nodes(vs.aliases), + query: convert_node_boxed(vs.query), + replace: vs.replace, + options: convert_list_to_nodes(vs.options), + with_check_option: vs.withCheckOption as i32 + 1, + } +} + +unsafe fn convert_explain_stmt(es: &bindings_raw::ExplainStmt) -> protobuf::ExplainStmt { + protobuf::ExplainStmt { + query: convert_node_boxed(es.query), + options: convert_list_to_nodes(es.options), + } +} + +unsafe fn convert_create_table_as_stmt(ctas: &bindings_raw::CreateTableAsStmt) -> protobuf::CreateTableAsStmt { + protobuf::CreateTableAsStmt { + query: convert_node_boxed(ctas.query), + into: convert_into_clause(ctas.into), + objtype: ctas.objtype as i32 + 1, + is_select_into: ctas.is_select_into, + if_not_exists: ctas.if_not_exists, + } +} + +unsafe fn convert_prepare_stmt(ps: &bindings_raw::PrepareStmt) -> protobuf::PrepareStmt { + protobuf::PrepareStmt { + name: convert_c_string(ps.name), + argtypes: convert_list_to_nodes(ps.argtypes), + query: convert_node_boxed(ps.query), + } +} + +unsafe fn convert_execute_stmt(es: &bindings_raw::ExecuteStmt) -> protobuf::ExecuteStmt { + protobuf::ExecuteStmt { + name: convert_c_string(es.name), + params: convert_list_to_nodes(es.params), + } +} + +unsafe fn convert_deallocate_stmt(ds: &bindings_raw::DeallocateStmt) -> protobuf::DeallocateStmt { + protobuf::DeallocateStmt { + name: convert_c_string(ds.name), + } +} + +unsafe fn convert_set_to_default(std: &bindings_raw::SetToDefault) -> protobuf::SetToDefault { + protobuf::SetToDefault { + xpr: None, // Expression type info, not needed for parse tree + type_id: std.typeId, + type_mod: std.typeMod, + collation: std.collation, + location: std.location, + } +} + +unsafe fn convert_multi_assign_ref(mar: &bindings_raw::MultiAssignRef) -> protobuf::MultiAssignRef { + protobuf::MultiAssignRef { + source: convert_node_boxed(mar.source), + colno: mar.colno, + ncolumns: mar.ncolumns, + } +} + +unsafe fn convert_row_expr(re: &bindings_raw::RowExpr) -> protobuf::RowExpr { + protobuf::RowExpr { + xpr: None, // Expression type info, not needed for parse tree + args: convert_list_to_nodes(re.args), + row_typeid: re.row_typeid, + row_format: re.row_format as i32 + 1, + colnames: convert_list_to_nodes(re.colnames), + location: re.location, + } +} + +unsafe fn convert_collate_clause(cc: &bindings_raw::CollateClause) -> protobuf::CollateClause { + protobuf::CollateClause { + arg: convert_node_boxed(cc.arg), + collname: convert_list_to_nodes(cc.collname), + location: cc.location, + } +} + +unsafe fn convert_collate_clause_opt(cc: *mut bindings_raw::CollateClause) -> Option> { + if cc.is_null() { + None + } else { + Some(Box::new(convert_collate_clause(&*cc))) + } +} + +unsafe fn convert_partition_spec(ps: &bindings_raw::PartitionSpec) -> protobuf::PartitionSpec { + protobuf::PartitionSpec { + strategy: ps.strategy as i32 + 1, + part_params: convert_list_to_nodes(ps.partParams), + location: ps.location, + } +} + +unsafe fn convert_partition_spec_opt(ps: *mut bindings_raw::PartitionSpec) -> Option> { + if ps.is_null() { + None + } else { + Some(Box::new(convert_partition_spec(&*ps))) + } +} + +unsafe fn convert_partition_bound_spec(pbs: &bindings_raw::PartitionBoundSpec) -> protobuf::PartitionBoundSpec { + protobuf::PartitionBoundSpec { + strategy: if pbs.strategy == 0 { String::new() } else { String::from_utf8_lossy(&[pbs.strategy as u8]).to_string() }, + is_default: pbs.is_default, + modulus: pbs.modulus, + remainder: pbs.remainder, + listdatums: convert_list_to_nodes(pbs.listdatums), + lowerdatums: convert_list_to_nodes(pbs.lowerdatums), + upperdatums: convert_list_to_nodes(pbs.upperdatums), + location: pbs.location, + } +} + +unsafe fn convert_partition_bound_spec_opt(pbs: *mut bindings_raw::PartitionBoundSpec) -> Option> { + if pbs.is_null() { + None + } else { + Some(Box::new(convert_partition_bound_spec(&*pbs))) + } +} + +unsafe fn convert_cte_search_clause(csc: &bindings_raw::CTESearchClause) -> protobuf::CtesearchClause { + protobuf::CtesearchClause { + search_col_list: convert_list_to_nodes(csc.search_col_list), + search_breadth_first: csc.search_breadth_first, + search_seq_column: convert_c_string(csc.search_seq_column), + location: csc.location, + } +} + +unsafe fn convert_cte_search_clause_opt(csc: *mut bindings_raw::CTESearchClause) -> Option> { + if csc.is_null() { + None + } else { + Some(Box::new(convert_cte_search_clause(&*csc))) + } +} + +unsafe fn convert_cte_cycle_clause(ccc: &bindings_raw::CTECycleClause) -> protobuf::CtecycleClause { + protobuf::CtecycleClause { + cycle_col_list: convert_list_to_nodes(ccc.cycle_col_list), + cycle_mark_column: convert_c_string(ccc.cycle_mark_column), + cycle_mark_value: convert_node_boxed(ccc.cycle_mark_value), + cycle_mark_default: convert_node_boxed(ccc.cycle_mark_default), + cycle_path_column: convert_c_string(ccc.cycle_path_column), + location: ccc.location, + cycle_mark_type: ccc.cycle_mark_type, + cycle_mark_typmod: ccc.cycle_mark_typmod, + cycle_mark_collation: ccc.cycle_mark_collation, + cycle_mark_neop: ccc.cycle_mark_neop, + } +} + +unsafe fn convert_cte_cycle_clause_opt(ccc: *mut bindings_raw::CTECycleClause) -> Option> { + if ccc.is_null() { + None + } else { + Some(Box::new(convert_cte_cycle_clause(&*ccc))) + } +} + +// ============================================================================ +// Utility Functions +// ============================================================================ + +/// Converts a C string pointer to a Rust String. +unsafe fn convert_c_string(ptr: *const i8) -> std::string::String { + if ptr.is_null() { + std::string::String::new() + } else { + CStr::from_ptr(ptr).to_string_lossy().to_string() + } +} diff --git a/tests/ast_tests.rs b/tests/ast_tests.rs new file mode 100644 index 0000000..f985d04 --- /dev/null +++ b/tests/ast_tests.rs @@ -0,0 +1,374 @@ +#![allow(non_snake_case)] +#![cfg(test)] + +use pg_query::ast::{Node, SelectStmt, InsertStmt, UpdateStmt, DeleteStmt, SetOperation, JoinType}; +use pg_query::{parse_to_ast, deparse_ast}; + +#[macro_use] +mod support; + +/// Test that parse_to_ast successfully parses a simple SELECT query +#[test] +fn it_parses_simple_select_to_ast() { + let result = parse_to_ast("SELECT * FROM users").unwrap(); + assert_eq!(result.stmts.len(), 1); + + if let Node::SelectStmt(select) = &result.stmts[0].stmt { + // Check from_clause contains users table + assert_eq!(select.from_clause.len(), 1); + if let Node::RangeVar(range_var) = &select.from_clause[0] { + assert_eq!(range_var.relname, "users"); + } else { + panic!("Expected RangeVar in from_clause"); + } + + // Check target_list contains * + assert_eq!(select.target_list.len(), 1); + if let Node::ResTarget(res_target) = &select.target_list[0] { + assert!(res_target.val.is_some()); + if let Some(Node::ColumnRef(col_ref)) = &res_target.val { + assert_eq!(col_ref.fields.len(), 1); + assert!(matches!(&col_ref.fields[0], Node::AStar(_))); + } else { + panic!("Expected ColumnRef with AStar"); + } + } else { + panic!("Expected ResTarget in target_list"); + } + } else { + panic!("Expected SelectStmt"); + } +} + +/// Test that parse_to_ast handles errors correctly +#[test] +fn it_handles_parse_errors() { + let result = parse_to_ast("SELECT * FORM users"); + assert!(result.is_err()); +} + +/// Test parsing SELECT with WHERE clause +#[test] +fn it_parses_select_with_where_clause() { + let result = parse_to_ast("SELECT id, name FROM users WHERE id = 1").unwrap(); + assert_eq!(result.stmts.len(), 1); + + if let Node::SelectStmt(select) = &result.stmts[0].stmt { + assert!(select.where_clause.is_some()); + assert_eq!(select.target_list.len(), 2); + assert_eq!(select.from_clause.len(), 1); + } else { + panic!("Expected SelectStmt"); + } +} + +/// Test parsing INSERT statement +#[test] +fn it_parses_insert_to_ast() { + let result = parse_to_ast("INSERT INTO users (name, email) VALUES ('test', 'test@example.com')").unwrap(); + assert_eq!(result.stmts.len(), 1); + + if let Node::InsertStmt(insert) = &result.stmts[0].stmt { + // Check relation + if let Some(rel) = &insert.relation { + assert_eq!(rel.relname, "users"); + } else { + panic!("Expected relation"); + } + + // Check columns + assert_eq!(insert.cols.len(), 2); + } else { + panic!("Expected InsertStmt"); + } +} + +/// Test parsing UPDATE statement +#[test] +fn it_parses_update_to_ast() { + let result = parse_to_ast("UPDATE users SET name = 'bob' WHERE id = 1").unwrap(); + assert_eq!(result.stmts.len(), 1); + + if let Node::UpdateStmt(update) = &result.stmts[0].stmt { + // Check relation + if let Some(rel) = &update.relation { + assert_eq!(rel.relname, "users"); + } else { + panic!("Expected relation"); + } + + // Check target_list (SET clause) + assert_eq!(update.target_list.len(), 1); + + // Check where_clause + assert!(update.where_clause.is_some()); + } else { + panic!("Expected UpdateStmt"); + } +} + +/// Test parsing DELETE statement +#[test] +fn it_parses_delete_to_ast() { + let result = parse_to_ast("DELETE FROM users WHERE id = 1").unwrap(); + assert_eq!(result.stmts.len(), 1); + + if let Node::DeleteStmt(delete) = &result.stmts[0].stmt { + // Check relation + if let Some(rel) = &delete.relation { + assert_eq!(rel.relname, "users"); + } else { + panic!("Expected relation"); + } + + // Check where_clause + assert!(delete.where_clause.is_some()); + } else { + panic!("Expected DeleteStmt"); + } +} + +/// Test parsing SELECT with JOIN +#[test] +fn it_parses_select_with_join() { + let result = parse_to_ast("SELECT u.id, o.total FROM users u INNER JOIN orders o ON u.id = o.user_id").unwrap(); + assert_eq!(result.stmts.len(), 1); + + if let Node::SelectStmt(select) = &result.stmts[0].stmt { + assert_eq!(select.from_clause.len(), 1); + + if let Node::JoinExpr(join) = &select.from_clause[0] { + assert_eq!(join.jointype, JoinType::Inner); + assert!(join.larg.is_some()); + assert!(join.rarg.is_some()); + assert!(join.quals.is_some()); + } else { + panic!("Expected JoinExpr in from_clause"); + } + } else { + panic!("Expected SelectStmt"); + } +} + +/// Test parsing UNION query +#[test] +fn it_parses_union_query() { + let result = parse_to_ast("SELECT id FROM users UNION SELECT id FROM admins").unwrap(); + assert_eq!(result.stmts.len(), 1); + + if let Node::SelectStmt(select) = &result.stmts[0].stmt { + assert_eq!(select.op, SetOperation::Union); + assert!(select.larg.is_some()); + assert!(select.rarg.is_some()); + } else { + panic!("Expected SelectStmt"); + } +} + +/// Test round-trip: parse to AST then deparse back to SQL +#[test] +fn it_roundtrips_simple_select() { + let original = "SELECT * FROM users"; + let ast = parse_to_ast(original).unwrap(); + let deparsed = deparse_ast(&ast).unwrap(); + assert_eq!(deparsed, original); +} + +/// Test round-trip: SELECT with WHERE clause +#[test] +fn it_roundtrips_select_with_where() { + let original = "SELECT id, name FROM users WHERE id = 1"; + let ast = parse_to_ast(original).unwrap(); + let deparsed = deparse_ast(&ast).unwrap(); + assert_eq!(deparsed, original); +} + +/// Test round-trip: INSERT statement +#[test] +fn it_roundtrips_insert() { + let original = "INSERT INTO users (name) VALUES ('test')"; + let ast = parse_to_ast(original).unwrap(); + let deparsed = deparse_ast(&ast).unwrap(); + assert_eq!(deparsed, original); +} + +/// Test round-trip: UPDATE statement +#[test] +fn it_roundtrips_update() { + let original = "UPDATE users SET name = 'bob' WHERE id = 1"; + let ast = parse_to_ast(original).unwrap(); + let deparsed = deparse_ast(&ast).unwrap(); + assert_eq!(deparsed, original); +} + +/// Test round-trip: DELETE statement +#[test] +fn it_roundtrips_delete() { + let original = "DELETE FROM users WHERE id = 1"; + let ast = parse_to_ast(original).unwrap(); + let deparsed = deparse_ast(&ast).unwrap(); + assert_eq!(deparsed, original); +} + +/// Test round-trip: SELECT with JOIN +#[test] +fn it_roundtrips_join() { + let original = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id"; + let ast = parse_to_ast(original).unwrap(); + let deparsed = deparse_ast(&ast).unwrap(); + assert_eq!(deparsed, original); +} + +/// Test round-trip: UNION query +#[test] +fn it_roundtrips_union() { + let original = "SELECT id FROM users UNION SELECT id FROM admins"; + let ast = parse_to_ast(original).unwrap(); + let deparsed = deparse_ast(&ast).unwrap(); + assert_eq!(deparsed, original); +} + +/// Test round-trip: complex SELECT +#[test] +fn it_roundtrips_complex_select() { + let original = "SELECT u.id, u.name, count(*) AS order_count FROM users u LEFT JOIN orders o ON u.id = o.user_id WHERE u.active = true GROUP BY u.id, u.name HAVING count(*) > 0 ORDER BY order_count DESC LIMIT 10"; + let ast = parse_to_ast(original).unwrap(); + let deparsed = deparse_ast(&ast).unwrap(); + assert_eq!(deparsed, original); +} + +/// Test round-trip: WITH clause (CTE) +#[test] +fn it_roundtrips_cte() { + let original = "WITH active_users AS (SELECT * FROM users WHERE active = true) SELECT * FROM active_users"; + let ast = parse_to_ast(original).unwrap(); + let deparsed = deparse_ast(&ast).unwrap(); + assert_eq!(deparsed, original); +} + +/// Test round-trip: CREATE TABLE +#[test] +fn it_roundtrips_create_table() { + let original = "CREATE TABLE test (id integer, name text)"; + let ast = parse_to_ast(original).unwrap(); + let deparsed = deparse_ast(&ast).unwrap(); + // pg_query uses "int" instead of "integer" in its canonical form + assert_eq!(deparsed, "CREATE TABLE test (id int, name text)"); +} + +/// Test round-trip: DROP TABLE +#[test] +fn it_roundtrips_drop_table() { + let original = "DROP TABLE users"; + let ast = parse_to_ast(original).unwrap(); + let deparsed = deparse_ast(&ast).unwrap(); + assert_eq!(deparsed, original); +} + +/// Test round-trip: CREATE INDEX +#[test] +fn it_roundtrips_create_index() { + let original = "CREATE INDEX idx_users_name ON users (name)"; + let ast = parse_to_ast(original).unwrap(); + let deparsed = deparse_ast(&ast).unwrap(); + // pg_query adds explicit "USING btree" in its canonical form + assert_eq!(deparsed, "CREATE INDEX idx_users_name ON users USING btree (name)"); +} + +/// Test that the AST types are ergonomic (no deep Option> unwrapping) +#[test] +fn ast_types_are_ergonomic() { + let result = parse_to_ast("SELECT id FROM users WHERE active = true").unwrap(); + + // Direct pattern matching without .as_ref().unwrap() chains + if let Node::SelectStmt(select) = &result.stmts[0].stmt { + // Direct access to from_clause vector + for table in &select.from_clause { + if let Node::RangeVar(rv) = table { + assert_eq!(rv.relname, "users"); + } + } + + // Direct access to target_list + for target in &select.target_list { + if let Node::ResTarget(rt) = target { + if let Some(Node::ColumnRef(cr)) = &rt.val { + // Can access fields directly + assert!(!cr.fields.is_empty()); + } + } + } + } +} + +/// Test parsing multiple statements +#[test] +fn it_parses_multiple_statements() { + let result = parse_to_ast("SELECT 1; SELECT 2; SELECT 3").unwrap(); + assert_eq!(result.stmts.len(), 3); + + for stmt in &result.stmts { + assert!(matches!(&stmt.stmt, Node::SelectStmt(_))); + } +} + +/// Test parsing empty query (comment only) +#[test] +fn it_parses_empty_query() { + let result = parse_to_ast("-- just a comment").unwrap(); + assert_eq!(result.stmts.len(), 0); +} + +/// Test round-trip: subquery in SELECT +#[test] +fn it_roundtrips_subquery() { + let original = "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)"; + let ast = parse_to_ast(original).unwrap(); + let deparsed = deparse_ast(&ast).unwrap(); + assert_eq!(deparsed, original); +} + +/// Test round-trip: aggregate functions +#[test] +fn it_roundtrips_aggregates() { + let original = "SELECT count(*), sum(amount), avg(price) FROM orders"; + let ast = parse_to_ast(original).unwrap(); + let deparsed = deparse_ast(&ast).unwrap(); + assert_eq!(deparsed, original); +} + +/// Test round-trip: CASE expression +#[test] +fn it_roundtrips_case_expression() { + let original = "SELECT CASE WHEN x > 0 THEN 'positive' ELSE 'non-positive' END FROM t"; + let ast = parse_to_ast(original).unwrap(); + let deparsed = deparse_ast(&ast).unwrap(); + assert_eq!(deparsed, original); +} + +/// Test round-trip: INSERT with RETURNING +#[test] +fn it_roundtrips_insert_returning() { + let original = "INSERT INTO users (name) VALUES ('test') RETURNING id"; + let ast = parse_to_ast(original).unwrap(); + let deparsed = deparse_ast(&ast).unwrap(); + assert_eq!(deparsed, original); +} + +/// Test round-trip: UPDATE with FROM +#[test] +fn it_roundtrips_update_from() { + let original = "UPDATE users SET name = o.name FROM other_users o WHERE users.id = o.id"; + let ast = parse_to_ast(original).unwrap(); + let deparsed = deparse_ast(&ast).unwrap(); + assert_eq!(deparsed, original); +} + +/// Test round-trip: DELETE with USING +#[test] +fn it_roundtrips_delete_using() { + let original = "DELETE FROM users USING orders WHERE users.id = orders.user_id"; + let ast = parse_to_ast(original).unwrap(); + let deparsed = deparse_ast(&ast).unwrap(); + assert_eq!(deparsed, original); +} diff --git a/tests/raw_parse_tests.rs b/tests/raw_parse_tests.rs new file mode 100644 index 0000000..12afc8d --- /dev/null +++ b/tests/raw_parse_tests.rs @@ -0,0 +1,2070 @@ +#![allow(non_snake_case)] +#![cfg(test)] + +use pg_query::{parse, parse_raw, Error}; +use pg_query::protobuf::{a_const, node, ParseResult as ProtobufParseResult}; + +#[macro_use] +mod support; + +// ============================================================================ +// Helper functions +// ============================================================================ + +/// Helper to extract AConst from a SELECT statement's first target +fn get_first_const(result: &ProtobufParseResult) -> Option<&pg_query::protobuf::AConst> { + let stmt = result.stmts.first()?; + let raw_stmt = stmt.stmt.as_ref()?; + let node = raw_stmt.node.as_ref()?; + + if let node::Node::SelectStmt(select) = node { + let target = select.target_list.first()?; + if let Some(node::Node::ResTarget(res_target)) = target.node.as_ref() { + if let Some(val_node) = res_target.val.as_ref() { + if let Some(node::Node::AConst(aconst)) = val_node.node.as_ref() { + return Some(aconst); + } + } + } + } + None +} + +// ============================================================================ +// Basic parsing tests +// ============================================================================ + +/// Test that parse_raw successfully parses a simple SELECT query +#[test] +fn it_parses_simple_select() { + let query = "SELECT 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf.stmts.len(), 1); + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test that parse_raw handles syntax errors +#[test] +fn it_handles_parse_errors() { + let query = "SELECT * FORM users"; + let raw_error = parse_raw(query).err().unwrap(); + let proto_error = parse(query).err().unwrap(); + + assert!(matches!(raw_error, Error::Parse(_))); + assert!(matches!(proto_error, Error::Parse(_))); +} + +/// Test that parse_raw and parse produce equivalent results for simple SELECT +#[test] +fn it_matches_parse_for_simple_select() { + let query = "SELECT 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test that parse_raw and parse produce equivalent results for SELECT with table +#[test] +fn it_matches_parse_for_select_from_table() { + let query = "SELECT * FROM users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify tables are extracted correctly + let mut raw_tables = raw_result.tables(); + let mut proto_tables = proto_result.tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["users"]); +} + +/// Test that parse_raw handles empty queries (comments only) +#[test] +fn it_handles_empty_queries() { + let query = "-- just a comment"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf.stmts.len(), 0); + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test that parse_raw parses multiple statements +#[test] +fn it_parses_multiple_statements() { + let query = "SELECT 1; SELECT 2; SELECT 3"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf.stmts.len(), 3); + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// DML statement tests +// ============================================================================ + +/// Test parsing INSERT statement +#[test] +fn it_parses_insert() { + let query = "INSERT INTO users (name) VALUES ('test')"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify the INSERT target table + let mut raw_tables = raw_result.dml_tables(); + let mut proto_tables = proto_result.dml_tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["users"]); +} + +/// Test parsing UPDATE statement +#[test] +fn it_parses_update() { + let query = "UPDATE users SET name = 'bob' WHERE id = 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify the UPDATE target table + let mut raw_tables = raw_result.dml_tables(); + let mut proto_tables = proto_result.dml_tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["users"]); +} + +/// Test parsing DELETE statement +#[test] +fn it_parses_delete() { + let query = "DELETE FROM users WHERE id = 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify the DELETE target table + let mut raw_tables = raw_result.dml_tables(); + let mut proto_tables = proto_result.dml_tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["users"]); +} + +// ============================================================================ +// DDL statement tests +// ============================================================================ + +/// Test parsing CREATE TABLE +#[test] +fn it_parses_create_table() { + let query = "CREATE TABLE test (id int, name text)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify statement types match + assert_eq!(raw_result.statement_types(), proto_result.statement_types()); + assert_eq!(raw_result.statement_types(), vec!["CreateStmt"]); +} + +/// Test parsing DROP TABLE +#[test] +fn it_parses_drop_table() { + let query = "DROP TABLE users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify DDL tables match + let mut raw_tables = raw_result.ddl_tables(); + let mut proto_tables = proto_result.ddl_tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["users"]); +} + +/// Test parsing CREATE INDEX +#[test] +fn it_parses_create_index() { + let query = "CREATE INDEX idx_users_name ON users (name)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify statement types match + assert_eq!(raw_result.statement_types(), proto_result.statement_types()); + assert_eq!(raw_result.statement_types(), vec!["IndexStmt"]); +} + +// ============================================================================ +// JOIN and complex SELECT tests +// ============================================================================ + +/// Test parsing SELECT with JOIN +#[test] +fn it_parses_join() { + let query = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify tables are extracted correctly + let mut raw_tables = raw_result.tables(); + let mut proto_tables = proto_result.tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["orders", "users"]); +} + +/// Test parsing UNION query +#[test] +fn it_parses_union() { + let query = "SELECT id FROM users UNION SELECT id FROM admins"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify tables from both sides of UNION + let mut raw_tables = raw_result.tables(); + let mut proto_tables = proto_result.tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["admins", "users"]); +} + +/// Test parsing WITH clause (CTE) +#[test] +fn it_parses_cte() { + let query = "WITH active_users AS (SELECT * FROM users WHERE active = true) SELECT * FROM active_users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify CTE names match + assert_eq!(raw_result.cte_names, proto_result.cte_names); + assert!(raw_result.cte_names.contains(&"active_users".to_string())); + + // Verify tables (should only include actual tables, not CTEs) + let mut raw_tables = raw_result.tables(); + let mut proto_tables = proto_result.tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["users"]); +} + +/// Test parsing subquery in SELECT +#[test] +fn it_parses_subquery() { + let query = "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify all tables are found + let mut raw_tables = raw_result.tables(); + let mut proto_tables = proto_result.tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["orders", "users"]); +} + +/// Test parsing aggregate functions +#[test] +fn it_parses_aggregates() { + let query = "SELECT count(*), sum(amount), avg(price) FROM orders"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify functions are extracted correctly + let mut raw_funcs = raw_result.functions(); + let mut proto_funcs = proto_result.functions(); + raw_funcs.sort(); + proto_funcs.sort(); + assert_eq!(raw_funcs, proto_funcs); + assert!(raw_funcs.contains(&"count".to_string())); + assert!(raw_funcs.contains(&"sum".to_string())); + assert!(raw_funcs.contains(&"avg".to_string())); +} + +/// Test parsing CASE expression +#[test] +fn it_parses_case_expression() { + let query = "SELECT CASE WHEN x > 0 THEN 'positive' ELSE 'non-positive' END FROM t"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify table is found + let raw_tables = raw_result.tables(); + let proto_tables = proto_result.tables(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["t"]); +} + +/// Test parsing complex SELECT with multiple clauses +#[test] +fn it_parses_complex_select() { + let query = "SELECT u.id, u.name, count(*) AS order_count FROM users u LEFT JOIN orders o ON u.id = o.user_id WHERE u.active = true GROUP BY u.id, u.name HAVING count(*) > 0 ORDER BY order_count DESC LIMIT 10"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify tables + let mut raw_tables = raw_result.tables(); + let mut proto_tables = proto_result.tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["orders", "users"]); + + // Verify functions + let mut raw_funcs = raw_result.functions(); + let mut proto_funcs = proto_result.functions(); + raw_funcs.sort(); + proto_funcs.sort(); + assert_eq!(raw_funcs, proto_funcs); + assert!(raw_funcs.contains(&"count".to_string())); +} + +// ============================================================================ +// INSERT variations +// ============================================================================ + +/// Test parsing INSERT with ON CONFLICT +#[test] +fn it_parses_insert_on_conflict() { + let query = "INSERT INTO users (id, name) VALUES (1, 'test') ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify DML tables + let raw_tables = raw_result.dml_tables(); + let proto_tables = proto_result.dml_tables(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["users"]); +} + +/// Test parsing INSERT with RETURNING +#[test] +fn it_parses_insert_returning() { + let query = "INSERT INTO users (name) VALUES ('test') RETURNING id"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Literal value tests +// ============================================================================ + +/// Test parsing float with leading dot +#[test] +fn it_parses_floats_with_leading_dot() { + let query = "SELECT .1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify the float value + let raw_const = get_first_const(&raw_result.protobuf).expect("should have const"); + let proto_const = get_first_const(&proto_result.protobuf).expect("should have const"); + assert_eq!(raw_const, proto_const); +} + +/// Test parsing bit string in hex notation +#[test] +fn it_parses_bit_strings_hex() { + let query = "SELECT X'EFFF'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify the bit string value + let raw_const = get_first_const(&raw_result.protobuf).expect("should have const"); + let proto_const = get_first_const(&proto_result.protobuf).expect("should have const"); + assert_eq!(raw_const, proto_const); +} + +/// Test parsing real-world query with multiple joins +#[test] +fn it_parses_real_world_query() { + let query = " + SELECT memory_total_bytes, memory_free_bytes, memory_pagecache_bytes, + (memory_swap_total_bytes - memory_swap_free_bytes) AS swap + FROM snapshots s JOIN system_snapshots ON (snapshot_id = s.id) + WHERE s.database_id = 1 AND s.collected_at BETWEEN '2021-01-01' AND '2021-12-31' + ORDER BY collected_at"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify tables + let mut raw_tables = raw_result.tables(); + let mut proto_tables = proto_result.tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["snapshots", "system_snapshots"]); +} + +// ============================================================================ +// A_Const value extraction tests +// ============================================================================ + +/// Test that parse_raw extracts integer values correctly and matches parse +#[test] +fn it_extracts_integer_const() { + let query = "SELECT 42"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); + let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); + + assert_eq!(raw_const, proto_const); + assert!(!raw_const.isnull); + match &raw_const.val { + Some(a_const::Val::Ival(int_val)) => { + assert_eq!(int_val.ival, 42); + } + other => panic!("Expected Ival, got {:?}", other), + } +} + +/// Test that parse_raw extracts negative integer values correctly +#[test] +fn it_extracts_negative_integer_const() { + let query = "SELECT -123"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test that parse_raw extracts string values correctly and matches parse +#[test] +fn it_extracts_string_const() { + let query = "SELECT 'hello world'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); + let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); + + assert_eq!(raw_const, proto_const); + assert!(!raw_const.isnull); + match &raw_const.val { + Some(a_const::Val::Sval(str_val)) => { + assert_eq!(str_val.sval, "hello world"); + } + other => panic!("Expected Sval, got {:?}", other), + } +} + +/// Test that parse_raw extracts float values correctly and matches parse +#[test] +fn it_extracts_float_const() { + let query = "SELECT 3.14159"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); + let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); + + assert_eq!(raw_const, proto_const); + assert!(!raw_const.isnull); + match &raw_const.val { + Some(a_const::Val::Fval(float_val)) => { + assert_eq!(float_val.fval, "3.14159"); + } + other => panic!("Expected Fval, got {:?}", other), + } +} + +/// Test that parse_raw extracts boolean TRUE correctly and matches parse +#[test] +fn it_extracts_boolean_true_const() { + let query = "SELECT TRUE"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); + let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); + + assert_eq!(raw_const, proto_const); + assert!(!raw_const.isnull); + match &raw_const.val { + Some(a_const::Val::Boolval(bool_val)) => { + assert!(bool_val.boolval); + } + other => panic!("Expected Boolval(true), got {:?}", other), + } +} + +/// Test that parse_raw extracts boolean FALSE correctly and matches parse +#[test] +fn it_extracts_boolean_false_const() { + let query = "SELECT FALSE"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); + let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); + + assert_eq!(raw_const, proto_const); + assert!(!raw_const.isnull); + match &raw_const.val { + Some(a_const::Val::Boolval(bool_val)) => { + assert!(!bool_val.boolval); + } + other => panic!("Expected Boolval(false), got {:?}", other), + } +} + +/// Test that parse_raw extracts NULL correctly and matches parse +#[test] +fn it_extracts_null_const() { + let query = "SELECT NULL"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); + let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); + + assert_eq!(raw_const, proto_const); + assert!(raw_const.isnull); + assert!(raw_const.val.is_none()); +} + +/// Test that parse_raw extracts bit string values correctly and matches parse +#[test] +fn it_extracts_bit_string_const() { + let query = "SELECT B'1010'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); + let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); + + assert_eq!(raw_const, proto_const); + assert!(!raw_const.isnull); + match &raw_const.val { + Some(a_const::Val::Bsval(bit_val)) => { + assert_eq!(bit_val.bsval, "b1010"); + } + other => panic!("Expected Bsval, got {:?}", other), + } +} + +/// Test that parse_raw extracts hex bit string correctly and matches parse +#[test] +fn it_extracts_hex_bit_string_const() { + let query = "SELECT X'FF'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); + let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); + + assert_eq!(raw_const, proto_const); + assert!(!raw_const.isnull); + match &raw_const.val { + Some(a_const::Val::Bsval(bit_val)) => { + assert_eq!(bit_val.bsval, "xFF"); + } + other => panic!("Expected Bsval, got {:?}", other), + } +} + +// ============================================================================ +// ParseResult method equivalence tests +// ============================================================================ + +/// Test that tables() returns the same results for both parsers +#[test] +fn it_returns_tables_like_parse() { + let query = "SELECT * FROM users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Both should have the same tables + let mut raw_tables = raw_result.tables(); + let mut proto_tables = proto_result.tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["users"]); +} + +/// Test that functions() returns the same results for both parsers +#[test] +fn it_returns_functions_like_parse() { + let query = "SELECT count(*), sum(amount) FROM orders"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Both should have the same functions + let mut raw_funcs = raw_result.functions(); + let mut proto_funcs = proto_result.functions(); + raw_funcs.sort(); + proto_funcs.sort(); + assert_eq!(raw_funcs, proto_funcs); + assert_eq!(raw_funcs, vec!["count", "sum"]); +} + +/// Test that statement_types() returns the same results for both parsers +#[test] +fn it_returns_statement_types_like_parse() { + let query = "SELECT 1; INSERT INTO t VALUES (1); UPDATE t SET x = 1; DELETE FROM t"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + assert_eq!(raw_result.statement_types(), proto_result.statement_types()); + assert_eq!(raw_result.statement_types(), vec!["SelectStmt", "InsertStmt", "UpdateStmt", "DeleteStmt"]); +} + +// ============================================================================ +// Advanced JOIN tests +// ============================================================================ + +/// Test LEFT JOIN +#[test] +fn it_parses_left_join() { + let query = "SELECT * FROM users u LEFT JOIN orders o ON u.id = o.user_id"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let mut raw_tables = raw_result.tables(); + raw_tables.sort(); + assert_eq!(raw_tables, vec!["orders", "users"]); +} + +/// Test RIGHT JOIN +#[test] +fn it_parses_right_join() { + let query = "SELECT * FROM users u RIGHT JOIN orders o ON u.id = o.user_id"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test FULL OUTER JOIN +#[test] +fn it_parses_full_outer_join() { + let query = "SELECT * FROM users u FULL OUTER JOIN orders o ON u.id = o.user_id"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CROSS JOIN +#[test] +fn it_parses_cross_join() { + let query = "SELECT * FROM users CROSS JOIN products"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let mut raw_tables = raw_result.tables(); + raw_tables.sort(); + assert_eq!(raw_tables, vec!["products", "users"]); +} + +/// Test NATURAL JOIN +#[test] +fn it_parses_natural_join() { + let query = "SELECT * FROM users NATURAL JOIN user_profiles"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test multiple JOINs +#[test] +fn it_parses_multiple_joins() { + let query = "SELECT u.name, o.id, p.name FROM users u + JOIN orders o ON u.id = o.user_id + JOIN order_items oi ON o.id = oi.order_id + JOIN products p ON oi.product_id = p.id"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let mut raw_tables = raw_result.tables(); + raw_tables.sort(); + assert_eq!(raw_tables, vec!["order_items", "orders", "products", "users"]); +} + +/// Test JOIN with USING clause +#[test] +fn it_parses_join_using() { + let query = "SELECT * FROM users u JOIN user_profiles p USING (user_id)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test LATERAL JOIN +#[test] +fn it_parses_lateral_join() { + let query = "SELECT * FROM users u, LATERAL (SELECT * FROM orders o WHERE o.user_id = u.id LIMIT 3) AS recent_orders"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let mut raw_tables = raw_result.tables(); + raw_tables.sort(); + assert_eq!(raw_tables, vec!["orders", "users"]); +} + +// ============================================================================ +// Advanced subquery tests +// ============================================================================ + +/// Test correlated subquery +#[test] +fn it_parses_correlated_subquery() { + let query = "SELECT * FROM users u WHERE EXISTS (SELECT 1 FROM orders o WHERE o.user_id = u.id)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let mut raw_tables = raw_result.tables(); + raw_tables.sort(); + assert_eq!(raw_tables, vec!["orders", "users"]); +} + +/// Test NOT EXISTS subquery +#[test] +fn it_parses_not_exists_subquery() { + let query = "SELECT * FROM users u WHERE NOT EXISTS (SELECT 1 FROM banned b WHERE b.user_id = u.id)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test scalar subquery in SELECT +#[test] +fn it_parses_scalar_subquery() { + let query = "SELECT u.name, (SELECT COUNT(*) FROM orders o WHERE o.user_id = u.id) AS order_count FROM users u"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test subquery in FROM clause +#[test] +fn it_parses_derived_table() { + let query = "SELECT * FROM (SELECT id, name FROM users WHERE active = true) AS active_users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ANY/SOME subquery +#[test] +fn it_parses_any_subquery() { + let query = "SELECT * FROM products WHERE price > ANY (SELECT avg_price FROM categories)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ALL subquery +#[test] +fn it_parses_all_subquery() { + let query = "SELECT * FROM products WHERE price > ALL (SELECT price FROM discounted_products)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Window function tests +// ============================================================================ + +/// Test basic window function +#[test] +fn it_parses_window_function() { + let query = "SELECT name, salary, ROW_NUMBER() OVER (ORDER BY salary DESC) AS rank FROM employees"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test window function with PARTITION BY +#[test] +fn it_parses_window_function_partition() { + let query = "SELECT department, name, salary, RANK() OVER (PARTITION BY department ORDER BY salary DESC) AS dept_rank FROM employees"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test window function with frame clause +#[test] +fn it_parses_window_function_frame() { + let query = "SELECT date, amount, SUM(amount) OVER (ORDER BY date ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS moving_sum FROM transactions"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test named window +#[test] +fn it_parses_named_window() { + let query = "SELECT name, salary, SUM(salary) OVER w, AVG(salary) OVER w FROM employees WINDOW w AS (PARTITION BY department ORDER BY salary)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test LAG and LEAD functions +#[test] +fn it_parses_lag_lead() { + let query = "SELECT date, price, LAG(price, 1) OVER (ORDER BY date) AS prev_price, LEAD(price, 1) OVER (ORDER BY date) AS next_price FROM stock_prices"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// CTE variations +// ============================================================================ + +/// Test multiple CTEs +#[test] +fn it_parses_multiple_ctes() { + let query = "WITH + active_users AS (SELECT * FROM users WHERE active = true), + premium_users AS (SELECT * FROM active_users WHERE plan = 'premium') + SELECT * FROM premium_users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + assert!(raw_result.cte_names.contains(&"active_users".to_string())); + assert!(raw_result.cte_names.contains(&"premium_users".to_string())); +} + +/// Test recursive CTE +#[test] +fn it_parses_recursive_cte() { + let query = "WITH RECURSIVE subordinates AS ( + SELECT id, name, manager_id FROM employees WHERE id = 1 + UNION ALL + SELECT e.id, e.name, e.manager_id FROM employees e INNER JOIN subordinates s ON e.manager_id = s.id + ) SELECT * FROM subordinates"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CTE with column list +#[test] +fn it_parses_cte_with_columns() { + let query = "WITH regional_sales(region, total) AS (SELECT region, SUM(amount) FROM orders GROUP BY region) SELECT * FROM regional_sales"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CTE with MATERIALIZED +#[test] +fn it_parses_cte_materialized() { + let query = "WITH t AS MATERIALIZED (SELECT * FROM large_table WHERE x > 100) SELECT * FROM t"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Set operations +// ============================================================================ + +/// Test INTERSECT +#[test] +fn it_parses_intersect() { + let query = "SELECT id FROM users INTERSECT SELECT user_id FROM orders"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test EXCEPT +#[test] +fn it_parses_except() { + let query = "SELECT id FROM users EXCEPT SELECT user_id FROM banned_users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UNION ALL +#[test] +fn it_parses_union_all() { + let query = "SELECT name FROM users UNION ALL SELECT name FROM admins"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test compound set operations +#[test] +fn it_parses_compound_set_operations() { + let query = "(SELECT id FROM a UNION SELECT id FROM b) INTERSECT SELECT id FROM c"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// GROUP BY variations +// ============================================================================ + +/// Test GROUP BY ROLLUP +#[test] +fn it_parses_group_by_rollup() { + let query = "SELECT region, product, SUM(sales) FROM sales_data GROUP BY ROLLUP(region, product)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test GROUP BY CUBE +#[test] +fn it_parses_group_by_cube() { + let query = "SELECT region, product, SUM(sales) FROM sales_data GROUP BY CUBE(region, product)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test GROUP BY GROUPING SETS +#[test] +fn it_parses_grouping_sets() { + let query = "SELECT region, product, SUM(sales) FROM sales_data GROUP BY GROUPING SETS ((region), (product), ())"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// DISTINCT and ORDER BY variations +// ============================================================================ + +/// Test DISTINCT ON +#[test] +fn it_parses_distinct_on() { + let query = "SELECT DISTINCT ON (user_id) * FROM orders ORDER BY user_id, created_at DESC"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ORDER BY with NULLS FIRST/LAST +#[test] +fn it_parses_order_by_nulls() { + let query = "SELECT * FROM users ORDER BY last_login DESC NULLS LAST"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test FETCH FIRST +#[test] +fn it_parses_fetch_first() { + let query = "SELECT * FROM users ORDER BY id FETCH FIRST 10 ROWS ONLY"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test OFFSET with FETCH +#[test] +fn it_parses_offset_fetch() { + let query = "SELECT * FROM users ORDER BY id OFFSET 20 ROWS FETCH NEXT 10 ROWS ONLY"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Locking clauses +// ============================================================================ + +/// Test FOR UPDATE +#[test] +fn it_parses_for_update() { + let query = "SELECT * FROM users WHERE id = 1 FOR UPDATE"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test FOR SHARE +#[test] +fn it_parses_for_share() { + let query = "SELECT * FROM users WHERE id = 1 FOR SHARE"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test FOR UPDATE NOWAIT +#[test] +fn it_parses_for_update_nowait() { + let query = "SELECT * FROM users WHERE id = 1 FOR UPDATE NOWAIT"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test FOR UPDATE SKIP LOCKED +#[test] +fn it_parses_for_update_skip_locked() { + let query = "SELECT * FROM jobs WHERE status = 'pending' LIMIT 1 FOR UPDATE SKIP LOCKED"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Expression tests +// ============================================================================ + +/// Test COALESCE +#[test] +fn it_parses_coalesce() { + let query = "SELECT COALESCE(nickname, name, 'Unknown') FROM users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test NULLIF +#[test] +fn it_parses_nullif() { + let query = "SELECT NULLIF(status, 'deleted') FROM records"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test GREATEST and LEAST +#[test] +fn it_parses_greatest_least() { + let query = "SELECT GREATEST(a, b, c), LEAST(x, y, z) FROM t"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test IS NULL and IS NOT NULL +#[test] +fn it_parses_null_tests() { + let query = "SELECT * FROM users WHERE deleted_at IS NULL AND email IS NOT NULL"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test IS DISTINCT FROM +#[test] +fn it_parses_is_distinct_from() { + let query = "SELECT * FROM t WHERE a IS DISTINCT FROM b"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test BETWEEN +#[test] +fn it_parses_between() { + let query = "SELECT * FROM events WHERE created_at BETWEEN '2023-01-01' AND '2023-12-31'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test LIKE and ILIKE +#[test] +fn it_parses_like_ilike() { + let query = "SELECT * FROM users WHERE name LIKE 'John%' OR email ILIKE '%@EXAMPLE.COM'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test SIMILAR TO +#[test] +fn it_parses_similar_to() { + let query = "SELECT * FROM products WHERE name SIMILAR TO '%(phone|tablet)%'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test complex boolean expressions +#[test] +fn it_parses_complex_boolean() { + let query = "SELECT * FROM users WHERE (active = true AND verified = true) OR (role = 'admin' AND NOT suspended)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Type cast tests +// ============================================================================ + +/// Test PostgreSQL-style type cast +#[test] +fn it_parses_pg_type_cast() { + let query = "SELECT '123'::integer, '2023-01-01'::date, 'true'::boolean"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test SQL-style CAST +#[test] +fn it_parses_sql_cast() { + let query = "SELECT CAST('123' AS integer), CAST(created_at AS date) FROM t"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test array type cast +#[test] +fn it_parses_array_cast() { + let query = "SELECT ARRAY[1, 2, 3]::text[]"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Array and JSON tests +// ============================================================================ + +/// Test array constructor +#[test] +fn it_parses_array_constructor() { + let query = "SELECT ARRAY[1, 2, 3], ARRAY['a', 'b', 'c']"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test array subscript +#[test] +fn it_parses_array_subscript() { + let query = "SELECT tags[1], matrix[1][2] FROM t"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test array slice +#[test] +fn it_parses_array_slice() { + let query = "SELECT arr[2:4], arr[:3], arr[2:] FROM t"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test unnest +#[test] +fn it_parses_unnest() { + let query = "SELECT unnest(ARRAY[1, 2, 3])"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test JSON operators +#[test] +fn it_parses_json_operators() { + let query = "SELECT data->'name', data->>'email', data#>'{address,city}' FROM users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test JSONB containment +#[test] +fn it_parses_jsonb_containment() { + let query = "SELECT * FROM products WHERE metadata @> '{\"featured\": true}'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// DDL statements +// ============================================================================ + +/// Test CREATE TABLE with constraints +#[test] +fn it_parses_create_table_with_constraints() { + let query = "CREATE TABLE orders ( + id SERIAL PRIMARY KEY, + user_id INTEGER NOT NULL REFERENCES users(id), + amount DECIMAL(10, 2) CHECK (amount > 0), + status TEXT DEFAULT 'pending', + created_at TIMESTAMP DEFAULT NOW(), + UNIQUE (user_id, created_at) + )"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CREATE TABLE AS +#[test] +fn it_parses_create_table_as() { + let query = "CREATE TABLE active_users AS SELECT * FROM users WHERE active = true"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CREATE VIEW +#[test] +fn it_parses_create_view() { + let query = "CREATE VIEW active_users AS SELECT id, name FROM users WHERE active = true"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CREATE MATERIALIZED VIEW +#[test] +fn it_parses_create_materialized_view() { + let query = "CREATE MATERIALIZED VIEW monthly_sales AS SELECT date_trunc('month', created_at) AS month, SUM(amount) FROM orders GROUP BY 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ALTER TABLE ADD COLUMN +#[test] +fn it_parses_alter_table_add_column() { + let query = "ALTER TABLE users ADD COLUMN email TEXT NOT NULL"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ALTER TABLE DROP COLUMN +#[test] +fn it_parses_alter_table_drop_column() { + let query = "ALTER TABLE users DROP COLUMN deprecated_field"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ALTER TABLE ADD CONSTRAINT +#[test] +fn it_parses_alter_table_add_constraint() { + let query = "ALTER TABLE orders ADD CONSTRAINT fk_user FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CREATE INDEX with expression +#[test] +fn it_parses_create_index_expression() { + let query = "CREATE INDEX idx_lower_email ON users (lower(email))"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CREATE UNIQUE INDEX with WHERE +#[test] +fn it_parses_partial_unique_index() { + let query = "CREATE UNIQUE INDEX idx_active_email ON users (email) WHERE active = true"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CREATE INDEX CONCURRENTLY +#[test] +fn it_parses_create_index_concurrently() { + let query = "CREATE INDEX CONCURRENTLY idx_name ON users (name)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test TRUNCATE +#[test] +fn it_parses_truncate() { + let query = "TRUNCATE TABLE logs, audit_logs RESTART IDENTITY CASCADE"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Transaction and utility statements +// ============================================================================ + +/// Test EXPLAIN +#[test] +fn it_parses_explain() { + let query = "EXPLAIN SELECT * FROM users WHERE id = 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test EXPLAIN ANALYZE +#[test] +fn it_parses_explain_analyze() { + let query = "EXPLAIN (ANALYZE, BUFFERS, FORMAT JSON) SELECT * FROM users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test COPY +#[test] +fn it_parses_copy() { + let query = "COPY users (id, name, email) FROM STDIN WITH (FORMAT csv, HEADER true)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test PREPARE +#[test] +fn it_parses_prepare() { + let query = "PREPARE user_by_id (int) AS SELECT * FROM users WHERE id = $1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test EXECUTE +#[test] +fn it_parses_execute() { + let query = "EXECUTE user_by_id(42)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DEALLOCATE +#[test] +fn it_parses_deallocate() { + let query = "DEALLOCATE user_by_id"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Parameter placeholder tests +// ============================================================================ + +/// Test positional parameters +#[test] +fn it_parses_positional_params() { + let query = "SELECT * FROM users WHERE id = $1 AND status = $2"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test parameters in INSERT +#[test] +fn it_parses_params_in_insert() { + let query = "INSERT INTO users (name, email) VALUES ($1, $2) RETURNING id"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Complex real-world queries +// ============================================================================ + +/// Test analytics query with window functions +#[test] +fn it_parses_analytics_query() { + let query = " + SELECT + date_trunc('day', created_at) AS day, + COUNT(*) AS daily_orders, + SUM(amount) AS daily_revenue, + AVG(amount) OVER (ORDER BY date_trunc('day', created_at) ROWS BETWEEN 6 PRECEDING AND CURRENT ROW) AS weekly_avg + FROM orders + WHERE created_at >= NOW() - INTERVAL '30 days' + GROUP BY date_trunc('day', created_at) + ORDER BY day"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test hierarchical query with recursive CTE +#[test] +fn it_parses_hierarchy_query() { + let query = " + WITH RECURSIVE category_tree AS ( + SELECT id, name, parent_id, 0 AS level, ARRAY[id] AS path + FROM categories + WHERE parent_id IS NULL + UNION ALL + SELECT c.id, c.name, c.parent_id, ct.level + 1, ct.path || c.id + FROM categories c + JOIN category_tree ct ON c.parent_id = ct.id + ) + SELECT * FROM category_tree ORDER BY path"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test complex report query +#[test] +fn it_parses_complex_report_query() { + let query = " + WITH monthly_data AS ( + SELECT + date_trunc('month', o.created_at) AS month, + u.region, + p.category, + SUM(oi.quantity * oi.unit_price) AS revenue, + COUNT(DISTINCT o.id) AS order_count, + COUNT(DISTINCT o.user_id) AS customer_count + FROM orders o + JOIN users u ON o.user_id = u.id + JOIN order_items oi ON o.id = oi.order_id + JOIN products p ON oi.product_id = p.id + WHERE o.created_at >= '2023-01-01' AND o.status = 'completed' + GROUP BY 1, 2, 3 + ) + SELECT + month, + region, + category, + revenue, + order_count, + customer_count, + revenue / NULLIF(order_count, 0) AS avg_order_value, + SUM(revenue) OVER (PARTITION BY region ORDER BY month) AS cumulative_revenue + FROM monthly_data + ORDER BY month DESC, region, revenue DESC"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test query with multiple subqueries and CTEs +#[test] +fn it_parses_mixed_subqueries_and_ctes() { + let query = " + WITH high_value_customers AS ( + SELECT user_id FROM orders GROUP BY user_id HAVING SUM(amount) > 1000 + ) + SELECT u.*, + (SELECT COUNT(*) FROM orders o WHERE o.user_id = u.id) AS total_orders, + (SELECT MAX(created_at) FROM orders o WHERE o.user_id = u.id) AS last_order + FROM users u + WHERE u.id IN (SELECT user_id FROM high_value_customers) + AND EXISTS (SELECT 1 FROM orders o WHERE o.user_id = u.id AND o.created_at > NOW() - INTERVAL '90 days') + ORDER BY (SELECT SUM(amount) FROM orders o WHERE o.user_id = u.id) DESC + LIMIT 100"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Complex INSERT tests +// ============================================================================ + +/// Test INSERT with multiple tuples +#[test] +fn it_parses_insert_multiple_rows() { + let query = "INSERT INTO users (name, email, age) VALUES ('Alice', 'alice@example.com', 25), ('Bob', 'bob@example.com', 30), ('Charlie', 'charlie@example.com', 35)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test INSERT ... SELECT +#[test] +fn it_parses_insert_select() { + let query = "INSERT INTO archived_users (id, name, email) SELECT id, name, email FROM users WHERE deleted_at IS NOT NULL"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test INSERT ... SELECT with complex query +#[test] +fn it_parses_insert_select_complex() { + let query = "INSERT INTO monthly_stats (month, user_count, order_count, total_revenue) + SELECT date_trunc('month', created_at) AS month, + COUNT(DISTINCT user_id), + COUNT(*), + SUM(amount) + FROM orders + WHERE created_at >= '2023-01-01' + GROUP BY date_trunc('month', created_at)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test INSERT with CTE +#[test] +fn it_parses_insert_with_cte() { + let query = "WITH new_data AS ( + SELECT name, email FROM temp_imports WHERE valid = true + ) + INSERT INTO users (name, email) SELECT name, email FROM new_data"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test INSERT with DEFAULT values +#[test] +fn it_parses_insert_default_values() { + let query = "INSERT INTO users (name, created_at) VALUES ('test', DEFAULT)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test INSERT with ON CONFLICT DO NOTHING +#[test] +fn it_parses_insert_on_conflict_do_nothing() { + let query = "INSERT INTO users (id, name) VALUES (1, 'test') ON CONFLICT (id) DO NOTHING"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test INSERT with ON CONFLICT with WHERE clause +#[test] +fn it_parses_insert_on_conflict_with_where() { + let query = "INSERT INTO users (id, name, updated_at) VALUES (1, 'test', NOW()) + ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, updated_at = EXCLUDED.updated_at + WHERE users.updated_at < EXCLUDED.updated_at"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test INSERT with multiple columns in ON CONFLICT +#[test] +fn it_parses_insert_on_conflict_multiple_columns() { + let query = "INSERT INTO user_settings (user_id, key, value) VALUES (1, 'theme', 'dark') + ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test INSERT with RETURNING multiple columns +#[test] +fn it_parses_insert_returning_multiple() { + let query = "INSERT INTO users (name, email) VALUES ('test', 'test@example.com') RETURNING id, created_at, name"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test INSERT with subquery in VALUES +#[test] +fn it_parses_insert_with_subquery_value() { + let query = "INSERT INTO orders (user_id, total) VALUES ((SELECT id FROM users WHERE email = 'test@example.com'), 100.00)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test INSERT with OVERRIDING +#[test] +fn it_parses_insert_overriding() { + let query = "INSERT INTO users (id, name) OVERRIDING SYSTEM VALUE VALUES (1, 'test')"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Complex UPDATE tests +// ============================================================================ + +/// Test UPDATE with multiple columns +#[test] +fn it_parses_update_multiple_columns() { + let query = "UPDATE users SET name = 'new_name', email = 'new@example.com', updated_at = NOW() WHERE id = 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UPDATE with subquery in SET +#[test] +fn it_parses_update_with_subquery_set() { + let query = "UPDATE orders SET total = (SELECT SUM(price * quantity) FROM order_items WHERE order_id = orders.id) WHERE id = 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UPDATE with FROM clause (PostgreSQL-specific JOIN update) +#[test] +fn it_parses_update_from() { + let query = "UPDATE orders o SET status = 'shipped', shipped_at = NOW() + FROM shipments s + WHERE o.id = s.order_id AND s.status = 'delivered'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UPDATE with FROM and multiple tables +#[test] +fn it_parses_update_from_multiple_tables() { + let query = "UPDATE products p SET price = p.price * (1 + d.percentage / 100) + FROM discounts d + JOIN categories c ON d.category_id = c.id + WHERE p.category_id = c.id AND d.active = true"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UPDATE with CTE +#[test] +fn it_parses_update_with_cte() { + let query = "WITH inactive_users AS ( + SELECT id FROM users WHERE last_login < NOW() - INTERVAL '1 year' + ) + UPDATE users SET status = 'inactive' WHERE id IN (SELECT id FROM inactive_users)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UPDATE with RETURNING +#[test] +fn it_parses_update_returning() { + let query = "UPDATE users SET name = 'updated' WHERE id = 1 RETURNING id, name, updated_at"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UPDATE with complex WHERE clause +#[test] +fn it_parses_update_complex_where() { + let query = "UPDATE orders SET status = 'cancelled' + WHERE created_at < NOW() - INTERVAL '30 days' + AND status = 'pending' + AND NOT EXISTS (SELECT 1 FROM payments WHERE payments.order_id = orders.id)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UPDATE with row value comparison +#[test] +fn it_parses_update_row_comparison() { + let query = "UPDATE users SET (name, email) = ('new_name', 'new@example.com') WHERE id = 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UPDATE with CASE expression +#[test] +fn it_parses_update_with_case() { + let query = "UPDATE products SET price = CASE + WHEN category = 'electronics' THEN price * 0.9 + WHEN category = 'clothing' THEN price * 0.8 + ELSE price * 0.95 + END + WHERE sale_active = true"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UPDATE with array operations +#[test] +fn it_parses_update_array() { + let query = "UPDATE users SET tags = array_append(tags, 'verified') WHERE id = 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Complex DELETE tests +// ============================================================================ + +/// Test DELETE with subquery in WHERE +#[test] +fn it_parses_delete_with_subquery() { + let query = "DELETE FROM orders WHERE user_id IN (SELECT id FROM users WHERE status = 'deleted')"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DELETE with USING clause (PostgreSQL-specific JOIN delete) +#[test] +fn it_parses_delete_using() { + let query = "DELETE FROM order_items oi USING orders o + WHERE oi.order_id = o.id AND o.status = 'cancelled'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DELETE with USING and multiple tables +#[test] +fn it_parses_delete_using_multiple_tables() { + let query = "DELETE FROM notifications n + USING users u, user_settings s + WHERE n.user_id = u.id + AND u.id = s.user_id + AND s.key = 'email_notifications' + AND s.value = 'false'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DELETE with CTE +#[test] +fn it_parses_delete_with_cte() { + let query = "WITH old_orders AS ( + SELECT id FROM orders WHERE created_at < NOW() - INTERVAL '5 years' + ) + DELETE FROM order_items WHERE order_id IN (SELECT id FROM old_orders)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DELETE with RETURNING +#[test] +fn it_parses_delete_returning() { + let query = "DELETE FROM users WHERE id = 1 RETURNING id, name, email"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DELETE with EXISTS +#[test] +fn it_parses_delete_with_exists() { + let query = "DELETE FROM products p + WHERE NOT EXISTS (SELECT 1 FROM order_items oi WHERE oi.product_id = p.id) + AND p.created_at < NOW() - INTERVAL '1 year'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DELETE with complex boolean conditions +#[test] +fn it_parses_delete_complex_conditions() { + let query = "DELETE FROM logs + WHERE (level = 'debug' AND created_at < NOW() - INTERVAL '7 days') + OR (level = 'info' AND created_at < NOW() - INTERVAL '30 days') + OR (level IN ('warning', 'error') AND created_at < NOW() - INTERVAL '90 days')"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DELETE with LIMIT (PostgreSQL extension) +#[test] +fn it_parses_delete_only() { + let query = "DELETE FROM ONLY parent_table WHERE id = 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Combined DML with CTEs +// ============================================================================ + +/// Test data modification CTE (INSERT in CTE) +#[test] +fn it_parses_insert_cte_returning() { + let query = "WITH inserted AS ( + INSERT INTO users (name, email) VALUES ('test', 'test@example.com') RETURNING id, name + ) + SELECT * FROM inserted"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UPDATE in CTE with final SELECT +#[test] +fn it_parses_update_cte_returning() { + let query = "WITH updated AS ( + UPDATE users SET last_login = NOW() WHERE id = 1 RETURNING id, name, last_login + ) + SELECT * FROM updated"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DELETE in CTE with final SELECT +#[test] +fn it_parses_delete_cte_returning() { + let query = "WITH deleted AS ( + DELETE FROM expired_sessions WHERE expires_at < NOW() RETURNING user_id + ) + SELECT COUNT(*) FROM deleted"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test chained CTEs with multiple DML operations +#[test] +fn it_parses_chained_dml_ctes() { + let query = "WITH + to_archive AS ( + SELECT id FROM users WHERE last_login < NOW() - INTERVAL '2 years' + ), + archived AS ( + INSERT INTO archived_users SELECT * FROM users WHERE id IN (SELECT id FROM to_archive) RETURNING id + ), + deleted AS ( + DELETE FROM users WHERE id IN (SELECT id FROM archived) RETURNING id + ) + SELECT COUNT(*) as archived_count FROM deleted"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} From 9a8e2ca916b81ab7204b60c2a237a3707b501f1e Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 31 Dec 2025 11:11:36 -0800 Subject: [PATCH 02/17] save --- build.rs | 1 + src/raw_parse.rs | 68 +++++++-- tests/raw_parse_tests.rs | 317 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 374 insertions(+), 12 deletions(-) diff --git a/build.rs b/build.rs index 0cdaf8c..f2269e3 100644 --- a/build.rs +++ b/build.rs @@ -140,6 +140,7 @@ fn main() -> Result<(), Box> { .allowlist_type("PartitionSpec") .allowlist_type("PartitionBoundSpec") .allowlist_type("PartitionRangeDatum") + .allowlist_type("PartitionElem") .allowlist_type("CTESearchClause") .allowlist_type("CTECycleClause") .allowlist_type("RangeSubselect") diff --git a/src/raw_parse.rs b/src/raw_parse.rs index 8455db3..5d8d3c1 100644 --- a/src/raw_parse.rs +++ b/src/raw_parse.rs @@ -335,6 +335,14 @@ unsafe fn convert_node(node_ptr: *mut bindings_raw::Node) -> Option { + let pe = node_ptr as *mut bindings_raw::PartitionElem; + Some(protobuf::node::Node::PartitionElem(Box::new(convert_partition_elem(&*pe)))) + } + bindings_raw::NodeTag_T_PartitionRangeDatum => { + let prd = node_ptr as *mut bindings_raw::PartitionRangeDatum; + Some(protobuf::node::Node::PartitionRangeDatum(Box::new(convert_partition_range_datum(&*prd)))) + } _ => { // For unhandled node types, return None // In the future, we could add more node types here @@ -1098,18 +1106,27 @@ unsafe fn convert_collate_clause_opt(cc: *mut bindings_raw::CollateClause) -> Op } unsafe fn convert_partition_spec(ps: &bindings_raw::PartitionSpec) -> protobuf::PartitionSpec { + // Map from C char values to protobuf enum values + // C: 'l'=108, 'r'=114, 'h'=104 + // Protobuf: LIST=1, RANGE=2, HASH=3 + let strategy = match ps.strategy as u8 as char { + 'l' => 1, // LIST + 'r' => 2, // RANGE + 'h' => 3, // HASH + _ => 0, // UNDEFINED + }; protobuf::PartitionSpec { - strategy: ps.strategy as i32 + 1, + strategy, part_params: convert_list_to_nodes(ps.partParams), location: ps.location, } } -unsafe fn convert_partition_spec_opt(ps: *mut bindings_raw::PartitionSpec) -> Option> { +unsafe fn convert_partition_spec_opt(ps: *mut bindings_raw::PartitionSpec) -> Option { if ps.is_null() { None } else { - Some(Box::new(convert_partition_spec(&*ps))) + Some(convert_partition_spec(&*ps)) } } @@ -1126,16 +1143,43 @@ unsafe fn convert_partition_bound_spec(pbs: &bindings_raw::PartitionBoundSpec) - } } -unsafe fn convert_partition_bound_spec_opt(pbs: *mut bindings_raw::PartitionBoundSpec) -> Option> { +unsafe fn convert_partition_bound_spec_opt(pbs: *mut bindings_raw::PartitionBoundSpec) -> Option { if pbs.is_null() { None } else { - Some(Box::new(convert_partition_bound_spec(&*pbs))) + Some(convert_partition_bound_spec(&*pbs)) + } +} + +unsafe fn convert_partition_elem(pe: &bindings_raw::PartitionElem) -> protobuf::PartitionElem { + protobuf::PartitionElem { + name: convert_c_string(pe.name), + expr: convert_node_boxed(pe.expr), + collation: convert_list_to_nodes(pe.collation), + opclass: convert_list_to_nodes(pe.opclass), + location: pe.location, + } +} + +unsafe fn convert_partition_range_datum(prd: &bindings_raw::PartitionRangeDatum) -> protobuf::PartitionRangeDatum { + // Map from C enum to protobuf enum + // C: PARTITION_RANGE_DATUM_MINVALUE=-1, PARTITION_RANGE_DATUM_VALUE=0, PARTITION_RANGE_DATUM_MAXVALUE=1 + // Protobuf: UNDEFINED=0, MINVALUE=1, VALUE=2, MAXVALUE=3 + let kind = match prd.kind { + bindings_raw::PartitionRangeDatumKind_PARTITION_RANGE_DATUM_MINVALUE => 1, + bindings_raw::PartitionRangeDatumKind_PARTITION_RANGE_DATUM_VALUE => 2, + bindings_raw::PartitionRangeDatumKind_PARTITION_RANGE_DATUM_MAXVALUE => 3, + _ => 0, + }; + protobuf::PartitionRangeDatum { + kind, + value: convert_node_boxed(prd.value), + location: prd.location, } } -unsafe fn convert_cte_search_clause(csc: &bindings_raw::CTESearchClause) -> protobuf::CtesearchClause { - protobuf::CtesearchClause { +unsafe fn convert_cte_search_clause(csc: &bindings_raw::CTESearchClause) -> protobuf::CteSearchClause { + protobuf::CteSearchClause { search_col_list: convert_list_to_nodes(csc.search_col_list), search_breadth_first: csc.search_breadth_first, search_seq_column: convert_c_string(csc.search_seq_column), @@ -1143,16 +1187,16 @@ unsafe fn convert_cte_search_clause(csc: &bindings_raw::CTESearchClause) -> prot } } -unsafe fn convert_cte_search_clause_opt(csc: *mut bindings_raw::CTESearchClause) -> Option> { +unsafe fn convert_cte_search_clause_opt(csc: *mut bindings_raw::CTESearchClause) -> Option { if csc.is_null() { None } else { - Some(Box::new(convert_cte_search_clause(&*csc))) + Some(convert_cte_search_clause(&*csc)) } } -unsafe fn convert_cte_cycle_clause(ccc: &bindings_raw::CTECycleClause) -> protobuf::CtecycleClause { - protobuf::CtecycleClause { +unsafe fn convert_cte_cycle_clause(ccc: &bindings_raw::CTECycleClause) -> protobuf::CteCycleClause { + protobuf::CteCycleClause { cycle_col_list: convert_list_to_nodes(ccc.cycle_col_list), cycle_mark_column: convert_c_string(ccc.cycle_mark_column), cycle_mark_value: convert_node_boxed(ccc.cycle_mark_value), @@ -1166,7 +1210,7 @@ unsafe fn convert_cte_cycle_clause(ccc: &bindings_raw::CTECycleClause) -> protob } } -unsafe fn convert_cte_cycle_clause_opt(ccc: *mut bindings_raw::CTECycleClause) -> Option> { +unsafe fn convert_cte_cycle_clause_opt(ccc: *mut bindings_raw::CTECycleClause) -> Option> { if ccc.is_null() { None } else { diff --git a/tests/raw_parse_tests.rs b/tests/raw_parse_tests.rs index 12afc8d..c60eb99 100644 --- a/tests/raw_parse_tests.rs +++ b/tests/raw_parse_tests.rs @@ -2068,3 +2068,320 @@ fn it_parses_chained_dml_ctes() { assert_eq!(raw_result.protobuf, proto_result.protobuf); } + +// ============================================================================ +// Tests for previously stubbed fields +// ============================================================================ + +/// Test column with COLLATE clause +#[test] +fn it_parses_column_with_collate() { + let query = "CREATE TABLE test_collate ( + name TEXT COLLATE \"C\", + description VARCHAR(255) COLLATE \"en_US.UTF-8\" + )"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test partitioned table with PARTITION BY RANGE +#[test] +fn it_parses_partition_by_range() { + let query = "CREATE TABLE measurements ( + id SERIAL, + logdate DATE NOT NULL, + peaktemp INT + ) PARTITION BY RANGE (logdate)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test partitioned table with PARTITION BY LIST +#[test] +fn it_parses_partition_by_list() { + let query = "CREATE TABLE orders ( + id SERIAL, + region TEXT NOT NULL, + order_date DATE + ) PARTITION BY LIST (region)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test partitioned table with PARTITION BY HASH +#[test] +fn it_parses_partition_by_hash() { + let query = "CREATE TABLE users_partitioned ( + id SERIAL, + username TEXT + ) PARTITION BY HASH (id)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test partition with FOR VALUES (range) +#[test] +fn it_parses_partition_for_values_range() { + let query = "CREATE TABLE measurements_2023 PARTITION OF measurements + FOR VALUES FROM ('2023-01-01') TO ('2024-01-01')"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test partition with FOR VALUES (list) +#[test] +fn it_parses_partition_for_values_list() { + let query = "CREATE TABLE orders_west PARTITION OF orders + FOR VALUES IN ('west', 'northwest', 'southwest')"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test partition with FOR VALUES (hash) +#[test] +fn it_parses_partition_for_values_hash() { + let query = "CREATE TABLE users_part_0 PARTITION OF users_partitioned + FOR VALUES WITH (MODULUS 4, REMAINDER 0)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test partition with DEFAULT +#[test] +fn it_parses_partition_default() { + let query = "CREATE TABLE orders_other PARTITION OF orders DEFAULT"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test recursive CTE with SEARCH BREADTH FIRST +#[test] +fn it_parses_cte_search_breadth_first() { + let query = "WITH RECURSIVE search_tree(id, parent_id, data, depth) AS ( + SELECT id, parent_id, data, 0 FROM tree WHERE parent_id IS NULL + UNION ALL + SELECT t.id, t.parent_id, t.data, st.depth + 1 + FROM tree t, search_tree st WHERE t.parent_id = st.id + ) SEARCH BREADTH FIRST BY id SET ordercol + SELECT * FROM search_tree ORDER BY ordercol"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test recursive CTE with SEARCH DEPTH FIRST +#[test] +fn it_parses_cte_search_depth_first() { + let query = "WITH RECURSIVE search_tree(id, parent_id, data) AS ( + SELECT id, parent_id, data FROM tree WHERE parent_id IS NULL + UNION ALL + SELECT t.id, t.parent_id, t.data + FROM tree t, search_tree st WHERE t.parent_id = st.id + ) SEARCH DEPTH FIRST BY id SET ordercol + SELECT * FROM search_tree ORDER BY ordercol"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test recursive CTE with CYCLE detection +#[test] +fn it_parses_cte_cycle() { + let query = "WITH RECURSIVE search_graph(id, link, data, depth) AS ( + SELECT g.id, g.link, g.data, 0 FROM graph g + UNION ALL + SELECT g.id, g.link, g.data, sg.depth + 1 + FROM graph g, search_graph sg WHERE g.id = sg.link + ) CYCLE id SET is_cycle USING path + SELECT * FROM search_graph"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test recursive CTE with both SEARCH and CYCLE +#[test] +fn it_parses_cte_search_and_cycle() { + let query = "WITH RECURSIVE search_graph(id, link, data, depth) AS ( + SELECT g.id, g.link, g.data, 0 FROM graph g WHERE id = 1 + UNION ALL + SELECT g.id, g.link, g.data, sg.depth + 1 + FROM graph g, search_graph sg WHERE g.id = sg.link + ) SEARCH DEPTH FIRST BY id SET ordercol + CYCLE id SET is_cycle USING path + SELECT * FROM search_graph"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Benchmark +// ============================================================================ + +/// Benchmark comparing parse_raw vs parse performance +#[test] +fn benchmark_parse_raw_vs_parse() { + use std::time::{Duration, Instant}; + + // Complex query with multiple features: CTEs, JOINs, subqueries, window functions, etc. + let query = r#" + WITH RECURSIVE + category_tree AS ( + SELECT id, name, parent_id, 0 AS depth + FROM categories + WHERE parent_id IS NULL + UNION ALL + SELECT c.id, c.name, c.parent_id, ct.depth + 1 + FROM categories c + INNER JOIN category_tree ct ON c.parent_id = ct.id + WHERE ct.depth < 10 + ), + recent_orders AS ( + SELECT + o.id, + o.user_id, + o.total_amount, + o.created_at, + ROW_NUMBER() OVER (PARTITION BY o.user_id ORDER BY o.created_at DESC) as rn + FROM orders o + WHERE o.created_at > NOW() - INTERVAL '30 days' + AND o.status IN ('completed', 'shipped', 'delivered') + ) + SELECT + u.id AS user_id, + u.email, + u.first_name || ' ' || u.last_name AS full_name, + COALESCE(ua.city, 'Unknown') AS city, + COUNT(DISTINCT ro.id) AS order_count, + SUM(ro.total_amount) AS total_spent, + AVG(ro.total_amount) AS avg_order_value, + MAX(ro.created_at) AS last_order_date, + CASE + WHEN SUM(ro.total_amount) > 10000 THEN 'platinum' + WHEN SUM(ro.total_amount) > 5000 THEN 'gold' + WHEN SUM(ro.total_amount) > 1000 THEN 'silver' + ELSE 'bronze' + END AS customer_tier, + ( + SELECT COUNT(*) + FROM user_reviews ur + WHERE ur.user_id = u.id AND ur.rating >= 4 + ) AS positive_reviews, + ARRAY_AGG(DISTINCT ct.name ORDER BY ct.name) FILTER (WHERE ct.depth = 1) AS top_categories + FROM users u + LEFT JOIN user_addresses ua ON ua.user_id = u.id AND ua.is_primary = true + LEFT JOIN recent_orders ro ON ro.user_id = u.id AND ro.rn <= 5 + LEFT JOIN order_items oi ON oi.order_id = ro.id + LEFT JOIN products p ON p.id = oi.product_id + LEFT JOIN category_tree ct ON ct.id = p.category_id + WHERE u.is_active = true + AND u.created_at < NOW() - INTERVAL '7 days' + AND EXISTS ( + SELECT 1 FROM user_logins ul + WHERE ul.user_id = u.id + AND ul.logged_in_at > NOW() - INTERVAL '90 days' + ) + GROUP BY u.id, u.email, u.first_name, u.last_name, ua.city + HAVING COUNT(DISTINCT ro.id) > 0 + ORDER BY total_spent DESC NULLS LAST, u.created_at ASC + LIMIT 100 + OFFSET 0 + FOR UPDATE OF u SKIP LOCKED + "#; + + // Verify both produce the same result first + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Warm up + for _ in 0..100 { + let _ = parse_raw(query).unwrap(); + let _ = parse(query).unwrap(); + } + + // Target ~2 seconds per benchmark (4 seconds total) + let target_duration = Duration::from_millis(2000); + + // Benchmark parse_raw + let mut raw_iterations = 0u64; + let raw_start = Instant::now(); + while raw_start.elapsed() < target_duration { + for _ in 0..100 { + let _ = parse_raw(query).unwrap(); + raw_iterations += 1; + } + } + let raw_elapsed = raw_start.elapsed(); + let raw_ns_per_iter = raw_elapsed.as_nanos() as f64 / raw_iterations as f64; + + // Benchmark parse (protobuf) + let mut proto_iterations = 0u64; + let proto_start = Instant::now(); + while proto_start.elapsed() < target_duration { + for _ in 0..100 { + let _ = parse(query).unwrap(); + proto_iterations += 1; + } + } + let proto_elapsed = proto_start.elapsed(); + let proto_ns_per_iter = proto_elapsed.as_nanos() as f64 / proto_iterations as f64; + + // Calculate speedup and time saved + let speedup = proto_ns_per_iter / raw_ns_per_iter; + let time_saved_ns = proto_ns_per_iter - raw_ns_per_iter; + let time_saved_us = time_saved_ns / 1000.0; + + // Calculate throughput (queries per second) + let raw_qps = 1_000_000_000.0 / raw_ns_per_iter; + let proto_qps = 1_000_000_000.0 / proto_ns_per_iter; + + println!("\n"); + println!("============================================================"); + println!(" parse_raw vs parse Benchmark "); + println!("============================================================"); + println!("Query: {} chars (CTEs + JOINs + subqueries + window functions)", query.len()); + println!(); + println!("┌─────────────────────────────────────────────────────────┐"); + println!("│ RESULTS │"); + println!("├─────────────────────────────────────────────────────────┤"); + println!("│ parse_raw (direct C struct reading): │"); + println!("│ Iterations: {:>10} │", raw_iterations); + println!("│ Total time: {:>10.2?} │", raw_elapsed); + println!("│ Per iteration: {:>10.2} μs │", raw_ns_per_iter / 1000.0); + println!("│ Throughput: {:>10.0} queries/sec │", raw_qps); + println!("├─────────────────────────────────────────────────────────┤"); + println!("│ parse (protobuf serialization): │"); + println!("│ Iterations: {:>10} │", proto_iterations); + println!("│ Total time: {:>10.2?} │", proto_elapsed); + println!("│ Per iteration: {:>10.2} μs │", proto_ns_per_iter / 1000.0); + println!("│ Throughput: {:>10.0} queries/sec │", proto_qps); + println!("├─────────────────────────────────────────────────────────┤"); + println!("│ COMPARISON │"); + println!("│ Speedup: {:>10.2}x faster │", speedup); + println!("│ Time saved: {:>10.2} μs per parse │", time_saved_us); + println!("│ Extra queries: {:>10.0} more queries/sec │", raw_qps - proto_qps); + println!("└─────────────────────────────────────────────────────────┘"); + println!(); +} From 76298957e01047437b285779305f964d861dbf92 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 31 Dec 2025 11:15:05 -0800 Subject: [PATCH 03/17] Use our remote --- libpg_query | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libpg_query b/libpg_query index 03e2f43..0946937 160000 --- a/libpg_query +++ b/libpg_query @@ -1 +1 @@ -Subproject commit 03e2f436c999a1d22dbce439573e8cfabced5720 +Subproject commit 09469376d81131912d61374709b8331c85831837 From 14d2b718b87d11b2e6e113375e8d20c29771b3cd Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 31 Dec 2025 11:28:11 -0800 Subject: [PATCH 04/17] rebase --- build.rs | 1 + src/ast/convert.rs | 14 +++++++++++++- src/ast/nodes.rs | 13 ++++++++++++- src/raw_parse.rs | 3 +++ 4 files changed, 29 insertions(+), 2 deletions(-) diff --git a/build.rs b/build.rs index f2269e3..27658a1 100644 --- a/build.rs +++ b/build.rs @@ -23,6 +23,7 @@ fn main() -> Result<(), Box> { let source_paths = vec![ build_path.join("pg_query").with_extension("h"), build_path.join("pg_query_raw.h"), + build_path.join("postgres_deparse").with_extension("h"), build_path.join("Makefile"), build_path.join("src"), build_path.join("protobuf"), diff --git a/src/ast/convert.rs b/src/ast/convert.rs index 65d8e1e..410e19f 100644 --- a/src/ast/convert.rs +++ b/src/ast/convert.rs @@ -977,7 +977,7 @@ impl From for GroupingSet { impl From for MergeWhenClause { fn from(pb: protobuf::MergeWhenClause) -> Self { MergeWhenClause { - matched: pb.matched, + match_kind: pb.match_kind.into(), command_type: pb.command_type.into(), override_: pb.r#override.into(), condition: pb.condition.map(|n| n.into()), @@ -1041,6 +1041,7 @@ impl From for Constraint { raw_expr: pb.raw_expr.map(|n| n.into()), cooked_expr: pb.cooked_expr, generated_when: pb.generated_when, + inhcount: pb.inhcount, nulls_not_distinct: pb.nulls_not_distinct, keys: pb.keys.into_iter().map(|n| n.into()).collect(), including: pb.including.into_iter().map(|n| n.into()).collect(), @@ -1634,6 +1635,17 @@ impl From for CmdType { } } +impl From for MergeMatchKind { + fn from(v: i32) -> Self { + match v { + 1 => MergeMatchKind::Matched, + 2 => MergeMatchKind::NotMatchedBySource, + 3 => MergeMatchKind::NotMatchedByTarget, + _ => MergeMatchKind::Undefined, + } + } +} + impl From for TransactionStmtKind { fn from(v: i32) -> Self { match v { diff --git a/src/ast/nodes.rs b/src/ast/nodes.rs index 046d16a..92f4726 100644 --- a/src/ast/nodes.rs +++ b/src/ast/nodes.rs @@ -814,7 +814,7 @@ pub struct GroupingSet { /// MERGE WHEN clause #[derive(Debug, Clone, Default)] pub struct MergeWhenClause { - pub matched: bool, + pub match_kind: MergeMatchKind, pub command_type: CmdType, pub override_: OverridingKind, pub condition: Option, @@ -875,6 +875,7 @@ pub struct Constraint { pub raw_expr: Option, pub cooked_expr: String, pub generated_when: String, + pub inhcount: i32, pub nulls_not_distinct: bool, pub keys: Vec, pub including: Vec, @@ -1392,6 +1393,16 @@ pub enum CmdType { Nothing, } +/// MERGE match kind +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum MergeMatchKind { + #[default] + Undefined, + Matched, + NotMatchedBySource, + NotMatchedByTarget, +} + /// Transaction statement kind #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum TransactionStmtKind { diff --git a/src/raw_parse.rs b/src/raw_parse.rs index 5d8d3c1..0ad61ae 100644 --- a/src/raw_parse.rs +++ b/src/raw_parse.rs @@ -849,6 +849,7 @@ unsafe fn convert_constraint(c: &bindings_raw::Constraint) -> protobuf::Constrai raw_expr: convert_node_boxed(c.raw_expr), cooked_expr: convert_c_string(c.cooked_expr), generated_when: if c.generated_when == 0 { String::new() } else { String::from_utf8_lossy(&[c.generated_when as u8]).to_string() }, + inhcount: c.inhcount, nulls_not_distinct: c.nulls_not_distinct, keys: convert_list_to_nodes(c.keys), including: convert_list_to_nodes(c.including), @@ -1057,6 +1058,8 @@ unsafe fn convert_execute_stmt(es: &bindings_raw::ExecuteStmt) -> protobuf::Exec unsafe fn convert_deallocate_stmt(ds: &bindings_raw::DeallocateStmt) -> protobuf::DeallocateStmt { protobuf::DeallocateStmt { name: convert_c_string(ds.name), + isall: ds.isall, + location: ds.location, } } From 1719b5e6aa6757999c4abd23a05e96f83fe1e781 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 31 Dec 2025 11:34:12 -0800 Subject: [PATCH 05/17] remove dead code --- src/ast/convert.rs | 1904 -------------------------------------- src/ast/mod.rs | 31 - src/ast/nodes.rs | 1628 -------------------------------- src/lib.rs | 1 - src/query.rs | 63 -- src/raw_parse.rs | 148 +-- tests/ast_tests.rs | 374 -------- tests/raw_parse_tests.rs | 5 +- 8 files changed, 30 insertions(+), 4124 deletions(-) delete mode 100644 src/ast/convert.rs delete mode 100644 src/ast/mod.rs delete mode 100644 src/ast/nodes.rs delete mode 100644 tests/ast_tests.rs diff --git a/src/ast/convert.rs b/src/ast/convert.rs deleted file mode 100644 index 410e19f..0000000 --- a/src/ast/convert.rs +++ /dev/null @@ -1,1904 +0,0 @@ -//! Conversion implementations between protobuf types and native AST types. - -use crate::protobuf; -use crate::ast::nodes::*; - -// ============================================================================ -// From protobuf to native AST types -// ============================================================================ - -impl ParseResult { - /// Create a new ParseResult from a protobuf result. - /// This stores the original protobuf for later deparsing. - pub fn from_protobuf(pb: protobuf::ParseResult) -> Self { - let stmts = pb.stmts.iter().map(|s| s.into()).collect(); - ParseResult { - version: pb.version, - stmts, - original_protobuf: pb, - } - } - - /// Get a reference to the original protobuf for deparsing. - pub fn as_protobuf(&self) -> &protobuf::ParseResult { - &self.original_protobuf - } -} - -impl From for ParseResult { - fn from(pb: protobuf::ParseResult) -> Self { - ParseResult::from_protobuf(pb) - } -} - -impl From<&protobuf::ParseResult> for ParseResult { - fn from(pb: &protobuf::ParseResult) -> Self { - ParseResult::from_protobuf(pb.clone()) - } -} - -impl From for RawStmt { - fn from(pb: protobuf::RawStmt) -> Self { - RawStmt { - stmt: pb.stmt.map(|n| (*n).into()).unwrap_or(Node::Null), - stmt_location: pb.stmt_location, - stmt_len: pb.stmt_len, - } - } -} - -impl From<&protobuf::RawStmt> for RawStmt { - fn from(pb: &protobuf::RawStmt) -> Self { - RawStmt { - stmt: pb.stmt.as_ref().map(|n| n.as_ref().into()).unwrap_or(Node::Null), - stmt_location: pb.stmt_location, - stmt_len: pb.stmt_len, - } - } -} - -impl From for Node { - fn from(pb: protobuf::Node) -> Self { - match pb.node { - Some(node) => node.into(), - None => Node::Null, - } - } -} - -impl From<&protobuf::Node> for Node { - fn from(pb: &protobuf::Node) -> Self { - match &pb.node { - Some(node) => node.into(), - None => Node::Null, - } - } -} - -impl From for Node { - fn from(pb: protobuf::node::Node) -> Self { - use protobuf::node::Node as PbNode; - match pb { - // Primitive types (not boxed) - PbNode::Integer(v) => Node::Integer(v.into()), - PbNode::Float(v) => Node::Float(v.into()), - PbNode::Boolean(v) => Node::Boolean(v.into()), - PbNode::String(v) => Node::String(v.into()), - PbNode::BitString(v) => Node::BitString(v.into()), - PbNode::List(v) => Node::List(v.items.into_iter().map(|n| n.into()).collect()), - - // Statement types (boxed in protobuf) - PbNode::SelectStmt(v) => Node::SelectStmt(Box::new((*v).into())), - PbNode::InsertStmt(v) => Node::InsertStmt(Box::new((*v).into())), - PbNode::UpdateStmt(v) => Node::UpdateStmt(Box::new((*v).into())), - PbNode::DeleteStmt(v) => Node::DeleteStmt(Box::new((*v).into())), - PbNode::MergeStmt(v) => Node::MergeStmt(Box::new((*v).into())), - - // DDL statements (not boxed in protobuf) - PbNode::CreateStmt(v) => Node::CreateStmt(Box::new(v.into())), - PbNode::AlterTableStmt(v) => Node::AlterTableStmt(Box::new(v.into())), - PbNode::DropStmt(v) => Node::DropStmt(Box::new(v.into())), - PbNode::TruncateStmt(v) => Node::TruncateStmt(Box::new(v.into())), - PbNode::IndexStmt(v) => Node::IndexStmt(Box::new((*v).into())), - PbNode::CreateSchemaStmt(v) => Node::CreateSchemaStmt(Box::new(v.into())), - PbNode::ViewStmt(v) => Node::ViewStmt(Box::new((*v).into())), - PbNode::CreateFunctionStmt(v) => Node::CreateFunctionStmt(Box::new((*v).into())), - PbNode::AlterFunctionStmt(v) => Node::AlterFunctionStmt(Box::new(v.into())), - PbNode::CreateSeqStmt(v) => Node::CreateSeqStmt(Box::new(v.into())), - PbNode::AlterSeqStmt(v) => Node::AlterSeqStmt(Box::new(v.into())), - PbNode::CreateTrigStmt(v) => Node::CreateTrigStmt(Box::new((*v).into())), - PbNode::RuleStmt(v) => Node::RuleStmt(Box::new((*v).into())), - PbNode::CreateDomainStmt(v) => Node::CreateDomainStmt(Box::new((*v).into())), - PbNode::CreateTableAsStmt(v) => Node::CreateTableAsStmt(Box::new((*v).into())), - PbNode::RefreshMatViewStmt(v) => Node::RefreshMatViewStmt(Box::new(v.into())), - - // Transaction statements (not boxed in protobuf) - PbNode::TransactionStmt(v) => Node::TransactionStmt(Box::new(v.into())), - - // Expression types (mixed boxing) - PbNode::AExpr(v) => Node::AExpr(Box::new((*v).into())), - PbNode::ColumnRef(v) => Node::ColumnRef(Box::new(v.into())), - PbNode::ParamRef(v) => Node::ParamRef(Box::new(v.into())), - PbNode::AConst(v) => Node::AConst(Box::new(v.into())), - PbNode::TypeCast(v) => Node::TypeCast(Box::new((*v).into())), - PbNode::CollateClause(v) => Node::CollateClause(Box::new((*v).into())), - PbNode::FuncCall(v) => Node::FuncCall(Box::new((*v).into())), - PbNode::AStar(_) => Node::AStar(AStar), - PbNode::AIndices(v) => Node::AIndices(Box::new((*v).into())), - PbNode::AIndirection(v) => Node::AIndirection(Box::new((*v).into())), - PbNode::AArrayExpr(v) => Node::AArrayExpr(Box::new(v.into())), - PbNode::SubLink(v) => Node::SubLink(Box::new((*v).into())), - PbNode::BoolExpr(v) => Node::BoolExpr(Box::new((*v).into())), - PbNode::NullTest(v) => Node::NullTest(Box::new((*v).into())), - PbNode::BooleanTest(v) => Node::BooleanTest(Box::new((*v).into())), - PbNode::CaseExpr(v) => Node::CaseExpr(Box::new((*v).into())), - PbNode::CaseWhen(v) => Node::CaseWhen(Box::new((*v).into())), - PbNode::CoalesceExpr(v) => Node::CoalesceExpr(Box::new((*v).into())), - PbNode::MinMaxExpr(v) => Node::MinMaxExpr(Box::new((*v).into())), - PbNode::RowExpr(v) => Node::RowExpr(Box::new((*v).into())), - - // Target/Result types (boxed in protobuf) - PbNode::ResTarget(v) => Node::ResTarget(Box::new((*v).into())), - - // Table/Range types (mixed) - PbNode::RangeVar(v) => Node::RangeVar(Box::new(v.into())), - PbNode::RangeSubselect(v) => Node::RangeSubselect(Box::new((*v).into())), - PbNode::RangeFunction(v) => Node::RangeFunction(Box::new(v.into())), - PbNode::JoinExpr(v) => Node::JoinExpr(Box::new((*v).into())), - - // Clause types (mixed) - PbNode::SortBy(v) => Node::SortBy(Box::new((*v).into())), - PbNode::WindowDef(v) => Node::WindowDef(Box::new((*v).into())), - PbNode::WithClause(v) => Node::WithClause(Box::new(v.into())), - PbNode::CommonTableExpr(v) => Node::CommonTableExpr(Box::new((*v).into())), - PbNode::IntoClause(v) => Node::IntoClause(Box::new((*v).into())), - PbNode::OnConflictClause(v) => Node::OnConflictClause(Box::new((*v).into())), - PbNode::LockingClause(v) => Node::LockingClause(Box::new(v.into())), - PbNode::GroupingSet(v) => Node::GroupingSet(Box::new(v.into())), - PbNode::MergeWhenClause(v) => Node::MergeWhenClause(Box::new((*v).into())), - - // Type-related (mixed) - PbNode::TypeName(v) => Node::TypeName(Box::new(v.into())), - PbNode::ColumnDef(v) => Node::ColumnDef(Box::new((*v).into())), - PbNode::Constraint(v) => Node::Constraint(Box::new((*v).into())), - PbNode::DefElem(v) => Node::DefElem(Box::new((*v).into())), - PbNode::IndexElem(v) => Node::IndexElem(Box::new((*v).into())), - - // Alias and role types (not boxed) - PbNode::Alias(v) => Node::Alias(Box::new(v.into())), - PbNode::RoleSpec(v) => Node::RoleSpec(Box::new(v.into())), - - // Other commonly used types (mixed) - PbNode::SortGroupClause(v) => Node::SortGroupClause(Box::new(v.into())), - PbNode::FunctionParameter(v) => Node::FunctionParameter(Box::new((*v).into())), - PbNode::AlterTableCmd(v) => Node::AlterTableCmd(Box::new((*v).into())), - PbNode::AccessPriv(v) => Node::AccessPriv(Box::new(v.into())), - PbNode::ObjectWithArgs(v) => Node::ObjectWithArgs(Box::new(v.into())), - - // Administrative statements (mixed) - PbNode::VariableSetStmt(v) => Node::VariableSetStmt(Box::new(v.into())), - PbNode::VariableShowStmt(v) => Node::VariableShowStmt(Box::new(v.into())), - PbNode::ExplainStmt(v) => Node::ExplainStmt(Box::new((*v).into())), - PbNode::CopyStmt(v) => Node::CopyStmt(Box::new((*v).into())), - PbNode::GrantStmt(v) => Node::GrantStmt(Box::new(v.into())), - PbNode::GrantRoleStmt(v) => Node::GrantRoleStmt(Box::new(v.into())), - PbNode::LockStmt(v) => Node::LockStmt(Box::new(v.into())), - PbNode::VacuumStmt(v) => Node::VacuumStmt(Box::new(v.into())), - - // Other statements (mixed) - PbNode::DoStmt(v) => Node::DoStmt(Box::new(v.into())), - PbNode::RenameStmt(v) => Node::RenameStmt(Box::new((*v).into())), - PbNode::NotifyStmt(v) => Node::NotifyStmt(Box::new(v.into())), - PbNode::ListenStmt(v) => Node::ListenStmt(Box::new(v.into())), - PbNode::UnlistenStmt(v) => Node::UnlistenStmt(Box::new(v.into())), - PbNode::CheckPointStmt(_) => Node::CheckPointStmt(Box::new(CheckPointStmt)), - PbNode::DiscardStmt(v) => Node::DiscardStmt(Box::new(v.into())), - PbNode::PrepareStmt(v) => Node::PrepareStmt(Box::new((*v).into())), - PbNode::ExecuteStmt(v) => Node::ExecuteStmt(Box::new(v.into())), - PbNode::DeallocateStmt(v) => Node::DeallocateStmt(Box::new(v.into())), - PbNode::ClosePortalStmt(v) => Node::ClosePortalStmt(Box::new(v.into())), - PbNode::FetchStmt(v) => Node::FetchStmt(Box::new(v.into())), - - // Fallback for any unhandled node types - other => Node::Other(protobuf::Node { node: Some(other) }), - } - } -} - -impl From<&protobuf::node::Node> for Node { - fn from(pb: &protobuf::node::Node) -> Self { - pb.clone().into() - } -} - -// Conversions from Box for boxed protobuf fields -impl From> for Node { - fn from(pb: Box) -> Self { - (*pb).into() - } -} - -impl From> for IntoClause { - fn from(pb: Box) -> Self { - (*pb).into() - } -} - -impl From> for OnConflictClause { - fn from(pb: Box) -> Self { - (*pb).into() - } -} - -impl From> for CollateClause { - fn from(pb: Box) -> Self { - (*pb).into() - } -} - -impl From> for SelectStmt { - fn from(pb: Box) -> Self { - (*pb).into() - } -} - -// Primitive type conversions -impl From for Integer { - fn from(pb: protobuf::Integer) -> Self { - Integer { ival: pb.ival } - } -} - -impl From for Float { - fn from(pb: protobuf::Float) -> Self { - Float { fval: pb.fval } - } -} - -impl From for Boolean { - fn from(pb: protobuf::Boolean) -> Self { - Boolean { boolval: pb.boolval } - } -} - -impl From for StringValue { - fn from(pb: protobuf::String) -> Self { - StringValue { sval: pb.sval } - } -} - -impl From for BitString { - fn from(pb: protobuf::BitString) -> Self { - BitString { bsval: pb.bsval } - } -} - -// Statement type conversions -impl From for SelectStmt { - fn from(pb: protobuf::SelectStmt) -> Self { - SelectStmt { - distinct_clause: pb.distinct_clause.into_iter().map(|n| n.into()).collect(), - into_clause: pb.into_clause.map(|v| v.into()), - target_list: pb.target_list.into_iter().map(|n| n.into()).collect(), - from_clause: pb.from_clause.into_iter().map(|n| n.into()).collect(), - where_clause: pb.where_clause.map(|n| n.into()), - group_clause: pb.group_clause.into_iter().map(|n| n.into()).collect(), - group_distinct: pb.group_distinct, - having_clause: pb.having_clause.map(|n| n.into()), - window_clause: pb.window_clause.into_iter().map(|n| n.into()).collect(), - values_lists: pb.values_lists.into_iter().map(|n| n.into()).collect(), - sort_clause: pb.sort_clause.into_iter().map(|n| n.into()).collect(), - limit_offset: pb.limit_offset.map(|n| n.into()), - limit_count: pb.limit_count.map(|n| n.into()), - limit_option: pb.limit_option.into(), - locking_clause: pb.locking_clause.into_iter().map(|n| n.into()).collect(), - with_clause: pb.with_clause.map(|v| v.into()), - op: pb.op.into(), - all: pb.all, - larg: pb.larg.map(|v| Box::new((*v).into())), - rarg: pb.rarg.map(|v| Box::new((*v).into())), - } - } -} - -impl From for InsertStmt { - fn from(pb: protobuf::InsertStmt) -> Self { - InsertStmt { - relation: pb.relation.map(|v| v.into()), - cols: pb.cols.into_iter().map(|n| n.into()).collect(), - select_stmt: pb.select_stmt.map(|n| n.into()), - on_conflict_clause: pb.on_conflict_clause.map(|v| v.into()), - returning_list: pb.returning_list.into_iter().map(|n| n.into()).collect(), - with_clause: pb.with_clause.map(|v| v.into()), - override_: pb.r#override.into(), - } - } -} - -impl From for UpdateStmt { - fn from(pb: protobuf::UpdateStmt) -> Self { - UpdateStmt { - relation: pb.relation.map(|v| v.into()), - target_list: pb.target_list.into_iter().map(|n| n.into()).collect(), - where_clause: pb.where_clause.map(|n| n.into()), - from_clause: pb.from_clause.into_iter().map(|n| n.into()).collect(), - returning_list: pb.returning_list.into_iter().map(|n| n.into()).collect(), - with_clause: pb.with_clause.map(|v| v.into()), - } - } -} - -impl From for DeleteStmt { - fn from(pb: protobuf::DeleteStmt) -> Self { - DeleteStmt { - relation: pb.relation.map(|v| v.into()), - using_clause: pb.using_clause.into_iter().map(|n| n.into()).collect(), - where_clause: pb.where_clause.map(|n| n.into()), - returning_list: pb.returning_list.into_iter().map(|n| n.into()).collect(), - with_clause: pb.with_clause.map(|v| v.into()), - } - } -} - -impl From for MergeStmt { - fn from(pb: protobuf::MergeStmt) -> Self { - MergeStmt { - relation: pb.relation.map(|v| v.into()), - source_relation: pb.source_relation.map(|n| n.into()), - join_condition: pb.join_condition.map(|n| n.into()), - merge_when_clauses: pb.merge_when_clauses.into_iter().map(|n| n.into()).collect(), - with_clause: pb.with_clause.map(|v| v.into()), - } - } -} - -// DDL statement conversions -impl From for CreateStmt { - fn from(pb: protobuf::CreateStmt) -> Self { - CreateStmt { - relation: pb.relation.map(|v| v.into()), - table_elts: pb.table_elts.into_iter().map(|n| n.into()).collect(), - inh_relations: pb.inh_relations.into_iter().map(|n| n.into()).collect(), - partbound: pb.partbound.map(|n| Node::Other(protobuf::Node { node: Some(protobuf::node::Node::PartitionBoundSpec(n)) })), - partspec: pb.partspec.map(|n| Node::Other(protobuf::Node { node: Some(protobuf::node::Node::PartitionSpec(n)) })), - of_typename: pb.of_typename.map(|v| v.into()), - constraints: pb.constraints.into_iter().map(|n| n.into()).collect(), - options: pb.options.into_iter().map(|n| n.into()).collect(), - oncommit: pb.oncommit.into(), - tablespacename: pb.tablespacename, - access_method: pb.access_method, - if_not_exists: pb.if_not_exists, - } - } -} - -impl From for AlterTableStmt { - fn from(pb: protobuf::AlterTableStmt) -> Self { - AlterTableStmt { - relation: pb.relation.map(|v| v.into()), - cmds: pb.cmds.into_iter().map(|n| n.into()).collect(), - objtype: pb.objtype.into(), - missing_ok: pb.missing_ok, - } - } -} - -impl From for DropStmt { - fn from(pb: protobuf::DropStmt) -> Self { - DropStmt { - objects: pb.objects.into_iter().map(|n| n.into()).collect(), - remove_type: pb.remove_type.into(), - behavior: pb.behavior.into(), - missing_ok: pb.missing_ok, - concurrent: pb.concurrent, - } - } -} - -impl From for TruncateStmt { - fn from(pb: protobuf::TruncateStmt) -> Self { - TruncateStmt { - relations: pb.relations.into_iter().map(|n| n.into()).collect(), - restart_seqs: pb.restart_seqs, - behavior: pb.behavior.into(), - } - } -} - -impl From for IndexStmt { - fn from(pb: protobuf::IndexStmt) -> Self { - IndexStmt { - idxname: pb.idxname, - relation: pb.relation.map(|v| v.into()), - access_method: pb.access_method, - table_space: pb.table_space, - index_params: pb.index_params.into_iter().map(|n| n.into()).collect(), - index_including_params: pb.index_including_params.into_iter().map(|n| n.into()).collect(), - options: pb.options.into_iter().map(|n| n.into()).collect(), - where_clause: pb.where_clause.map(|n| n.into()), - exclude_op_names: pb.exclude_op_names.into_iter().map(|n| n.into()).collect(), - idxcomment: pb.idxcomment, - index_oid: pb.index_oid, - old_number: pb.old_number, - old_first_relfilelocator: pb.old_first_relfilelocator_subid, - unique: pb.unique, - nulls_not_distinct: pb.nulls_not_distinct, - primary: pb.primary, - is_constraint: pb.isconstraint, - deferrable: pb.deferrable, - initdeferred: pb.initdeferred, - transformed: pb.transformed, - concurrent: pb.concurrent, - if_not_exists: pb.if_not_exists, - reset_default_tblspc: pb.reset_default_tblspc, - } - } -} - -impl From for CreateSchemaStmt { - fn from(pb: protobuf::CreateSchemaStmt) -> Self { - CreateSchemaStmt { - schemaname: pb.schemaname, - authrole: pb.authrole.map(|v| v.into()), - schema_elts: pb.schema_elts.into_iter().map(|n| n.into()).collect(), - if_not_exists: pb.if_not_exists, - } - } -} - -impl From for ViewStmt { - fn from(pb: protobuf::ViewStmt) -> Self { - ViewStmt { - view: pb.view.map(|v| v.into()), - aliases: pb.aliases.into_iter().map(|n| n.into()).collect(), - query: pb.query.map(|n| n.into()), - replace: pb.replace, - options: pb.options.into_iter().map(|n| n.into()).collect(), - with_check_option: pb.with_check_option.into(), - } - } -} - -impl From for CreateFunctionStmt { - fn from(pb: protobuf::CreateFunctionStmt) -> Self { - CreateFunctionStmt { - is_procedure: pb.is_procedure, - replace: pb.replace, - funcname: pb.funcname.into_iter().map(|n| n.into()).collect(), - parameters: pb.parameters.into_iter().map(|n| n.into()).collect(), - return_type: pb.return_type.map(|v| v.into()), - options: pb.options.into_iter().map(|n| n.into()).collect(), - sql_body: pb.sql_body.map(|n| n.into()), - } - } -} - -impl From for AlterFunctionStmt { - fn from(pb: protobuf::AlterFunctionStmt) -> Self { - AlterFunctionStmt { - objtype: pb.objtype.into(), - func: pb.func.map(|v| v.into()), - actions: pb.actions.into_iter().map(|n| n.into()).collect(), - } - } -} - -impl From for CreateSeqStmt { - fn from(pb: protobuf::CreateSeqStmt) -> Self { - CreateSeqStmt { - sequence: pb.sequence.map(|v| v.into()), - options: pb.options.into_iter().map(|n| n.into()).collect(), - owner_id: pb.owner_id, - for_identity: pb.for_identity, - if_not_exists: pb.if_not_exists, - } - } -} - -impl From for AlterSeqStmt { - fn from(pb: protobuf::AlterSeqStmt) -> Self { - AlterSeqStmt { - sequence: pb.sequence.map(|v| v.into()), - options: pb.options.into_iter().map(|n| n.into()).collect(), - for_identity: pb.for_identity, - missing_ok: pb.missing_ok, - } - } -} - -impl From for CreateTrigStmt { - fn from(pb: protobuf::CreateTrigStmt) -> Self { - CreateTrigStmt { - replace: pb.replace, - isconstraint: pb.isconstraint, - trigname: pb.trigname, - relation: pb.relation.map(|v| v.into()), - funcname: pb.funcname.into_iter().map(|n| n.into()).collect(), - args: pb.args.into_iter().map(|n| n.into()).collect(), - row: pb.row, - timing: pb.timing, - events: pb.events, - columns: pb.columns.into_iter().map(|n| n.into()).collect(), - when_clause: pb.when_clause.map(|n| n.into()), - transition_rels: pb.transition_rels.into_iter().map(|n| n.into()).collect(), - deferrable: pb.deferrable, - initdeferred: pb.initdeferred, - constrrel: pb.constrrel.map(|v| v.into()), - } - } -} - -impl From for RuleStmt { - fn from(pb: protobuf::RuleStmt) -> Self { - RuleStmt { - relation: pb.relation.map(|v| v.into()), - rulename: pb.rulename, - where_clause: pb.where_clause.map(|n| n.into()), - event: pb.event.into(), - instead: pb.instead, - actions: pb.actions.into_iter().map(|n| n.into()).collect(), - replace: pb.replace, - } - } -} - -impl From for CreateDomainStmt { - fn from(pb: protobuf::CreateDomainStmt) -> Self { - CreateDomainStmt { - domainname: pb.domainname.into_iter().map(|n| n.into()).collect(), - type_name: pb.type_name.map(|v| v.into()), - coll_clause: pb.coll_clause.map(|v| v.into()), - constraints: pb.constraints.into_iter().map(|n| n.into()).collect(), - } - } -} - -impl From for CreateTableAsStmt { - fn from(pb: protobuf::CreateTableAsStmt) -> Self { - CreateTableAsStmt { - query: pb.query.map(|n| n.into()), - into: pb.into.map(|v| v.into()), - objtype: pb.objtype.into(), - is_select_into: pb.is_select_into, - if_not_exists: pb.if_not_exists, - } - } -} - -impl From for RefreshMatViewStmt { - fn from(pb: protobuf::RefreshMatViewStmt) -> Self { - RefreshMatViewStmt { - concurrent: pb.concurrent, - skip_data: pb.skip_data, - relation: pb.relation.map(|v| v.into()), - } - } -} - -// Transaction statement -impl From for TransactionStmt { - fn from(pb: protobuf::TransactionStmt) -> Self { - TransactionStmt { - kind: pb.kind.into(), - options: pb.options.into_iter().map(|n| n.into()).collect(), - savepoint_name: pb.savepoint_name, - gid: pb.gid, - chain: pb.chain, - } - } -} - -// Expression type conversions -impl From for AExpr { - fn from(pb: protobuf::AExpr) -> Self { - AExpr { - kind: pb.kind.into(), - name: pb.name.into_iter().map(|n| n.into()).collect(), - lexpr: pb.lexpr.map(|n| n.into()), - rexpr: pb.rexpr.map(|n| n.into()), - location: pb.location, - } - } -} - -impl From for ColumnRef { - fn from(pb: protobuf::ColumnRef) -> Self { - ColumnRef { - fields: pb.fields.into_iter().map(|n| n.into()).collect(), - location: pb.location, - } - } -} - -impl From for ParamRef { - fn from(pb: protobuf::ParamRef) -> Self { - ParamRef { - number: pb.number, - location: pb.location, - } - } -} - -impl From for AConst { - fn from(pb: protobuf::AConst) -> Self { - AConst { - val: pb.val.map(|v| v.into()), - isnull: pb.isnull, - location: pb.location, - } - } -} - -impl From for AConstValue { - fn from(pb: protobuf::a_const::Val) -> Self { - use protobuf::a_const::Val; - match pb { - Val::Ival(v) => AConstValue::Integer(v.into()), - Val::Fval(v) => AConstValue::Float(v.into()), - Val::Boolval(v) => AConstValue::Boolean(v.into()), - Val::Sval(v) => AConstValue::String(v.into()), - Val::Bsval(v) => AConstValue::BitString(v.into()), - } - } -} - -impl From for TypeCast { - fn from(pb: protobuf::TypeCast) -> Self { - TypeCast { - arg: pb.arg.map(|n| n.into()), - type_name: pb.type_name.map(|v| v.into()), - location: pb.location, - } - } -} - -impl From for CollateClause { - fn from(pb: protobuf::CollateClause) -> Self { - CollateClause { - arg: pb.arg.map(|n| n.into()), - collname: pb.collname.into_iter().map(|n| n.into()).collect(), - location: pb.location, - } - } -} - -impl From for FuncCall { - fn from(pb: protobuf::FuncCall) -> Self { - FuncCall { - funcname: pb.funcname.into_iter().map(|n| n.into()).collect(), - args: pb.args.into_iter().map(|n| n.into()).collect(), - agg_order: pb.agg_order.into_iter().map(|n| n.into()).collect(), - agg_filter: pb.agg_filter.map(|n| n.into()), - over: pb.over.map(|v| (*v).into()), - agg_within_group: pb.agg_within_group, - agg_star: pb.agg_star, - agg_distinct: pb.agg_distinct, - func_variadic: pb.func_variadic, - funcformat: pb.funcformat.into(), - location: pb.location, - } - } -} - -impl From for AIndices { - fn from(pb: protobuf::AIndices) -> Self { - AIndices { - is_slice: pb.is_slice, - lidx: pb.lidx.map(|n| n.into()), - uidx: pb.uidx.map(|n| n.into()), - } - } -} - -impl From for AIndirection { - fn from(pb: protobuf::AIndirection) -> Self { - AIndirection { - arg: pb.arg.map(|n| n.into()), - indirection: pb.indirection.into_iter().map(|n| n.into()).collect(), - } - } -} - -impl From for AArrayExpr { - fn from(pb: protobuf::AArrayExpr) -> Self { - AArrayExpr { - elements: pb.elements.into_iter().map(|n| n.into()).collect(), - location: pb.location, - } - } -} - -impl From for SubLink { - fn from(pb: protobuf::SubLink) -> Self { - SubLink { - sub_link_type: pb.sub_link_type.into(), - sub_link_id: pb.sub_link_id, - testexpr: pb.testexpr.map(|n| n.into()), - oper_name: pb.oper_name.into_iter().map(|n| n.into()).collect(), - subselect: pb.subselect.map(|n| n.into()), - location: pb.location, - } - } -} - -impl From for BoolExpr { - fn from(pb: protobuf::BoolExpr) -> Self { - BoolExpr { - boolop: pb.boolop.into(), - args: pb.args.into_iter().map(|n| n.into()).collect(), - location: pb.location, - } - } -} - -impl From for NullTest { - fn from(pb: protobuf::NullTest) -> Self { - NullTest { - arg: pb.arg.map(|n| n.into()), - nulltesttype: pb.nulltesttype.into(), - argisrow: pb.argisrow, - location: pb.location, - } - } -} - -impl From for BooleanTest { - fn from(pb: protobuf::BooleanTest) -> Self { - BooleanTest { - arg: pb.arg.map(|n| n.into()), - booltesttype: pb.booltesttype.into(), - location: pb.location, - } - } -} - -impl From for CaseExpr { - fn from(pb: protobuf::CaseExpr) -> Self { - CaseExpr { - arg: pb.arg.map(|n| n.into()), - args: pb.args.into_iter().map(|n| n.into()).collect(), - defresult: pb.defresult.map(|n| n.into()), - location: pb.location, - } - } -} - -impl From for CaseWhen { - fn from(pb: protobuf::CaseWhen) -> Self { - CaseWhen { - expr: pb.expr.map(|n| n.into()), - result: pb.result.map(|n| n.into()), - location: pb.location, - } - } -} - -impl From for CoalesceExpr { - fn from(pb: protobuf::CoalesceExpr) -> Self { - CoalesceExpr { - args: pb.args.into_iter().map(|n| n.into()).collect(), - location: pb.location, - } - } -} - -impl From for MinMaxExpr { - fn from(pb: protobuf::MinMaxExpr) -> Self { - MinMaxExpr { - op: pb.op.into(), - args: pb.args.into_iter().map(|n| n.into()).collect(), - location: pb.location, - } - } -} - -impl From for RowExpr { - fn from(pb: protobuf::RowExpr) -> Self { - RowExpr { - args: pb.args.into_iter().map(|n| n.into()).collect(), - row_format: pb.row_format.into(), - colnames: pb.colnames.into_iter().map(|n| n.into()).collect(), - location: pb.location, - } - } -} - -// Target/Result type conversions -impl From for ResTarget { - fn from(pb: protobuf::ResTarget) -> Self { - ResTarget { - name: pb.name, - indirection: pb.indirection.into_iter().map(|n| n.into()).collect(), - val: pb.val.map(|n| n.into()), - location: pb.location, - } - } -} - -// Table/Range type conversions -impl From for RangeVar { - fn from(pb: protobuf::RangeVar) -> Self { - RangeVar { - catalogname: pb.catalogname, - schemaname: pb.schemaname, - relname: pb.relname, - inh: pb.inh, - relpersistence: pb.relpersistence, - alias: pb.alias.map(|v| v.into()), - location: pb.location, - } - } -} - -impl From for RangeSubselect { - fn from(pb: protobuf::RangeSubselect) -> Self { - RangeSubselect { - lateral: pb.lateral, - subquery: pb.subquery.map(|n| n.into()), - alias: pb.alias.map(|v| v.into()), - } - } -} - -impl From for RangeFunction { - fn from(pb: protobuf::RangeFunction) -> Self { - RangeFunction { - lateral: pb.lateral, - ordinality: pb.ordinality, - is_rowsfrom: pb.is_rowsfrom, - functions: pb.functions.into_iter().map(|n| n.into()).collect(), - alias: pb.alias.map(|v| v.into()), - coldeflist: pb.coldeflist.into_iter().map(|n| n.into()).collect(), - } - } -} - -impl From for JoinExpr { - fn from(pb: protobuf::JoinExpr) -> Self { - JoinExpr { - jointype: pb.jointype.into(), - is_natural: pb.is_natural, - larg: pb.larg.map(|n| n.into()), - rarg: pb.rarg.map(|n| n.into()), - using_clause: pb.using_clause.into_iter().map(|n| n.into()).collect(), - join_using_alias: pb.join_using_alias.map(|v| v.into()), - quals: pb.quals.map(|n| n.into()), - alias: pb.alias.map(|v| v.into()), - rtindex: pb.rtindex, - } - } -} - -// Clause type conversions -impl From for SortBy { - fn from(pb: protobuf::SortBy) -> Self { - SortBy { - node: pb.node.map(|n| n.into()), - sortby_dir: pb.sortby_dir.into(), - sortby_nulls: pb.sortby_nulls.into(), - use_op: pb.use_op.into_iter().map(|n| n.into()).collect(), - location: pb.location, - } - } -} - -impl From for WindowDef { - fn from(pb: protobuf::WindowDef) -> Self { - WindowDef { - name: pb.name, - refname: pb.refname, - partition_clause: pb.partition_clause.into_iter().map(|n| n.into()).collect(), - order_clause: pb.order_clause.into_iter().map(|n| n.into()).collect(), - frame_options: pb.frame_options, - start_offset: pb.start_offset.map(|n| n.into()), - end_offset: pb.end_offset.map(|n| n.into()), - location: pb.location, - } - } -} - -impl From for WithClause { - fn from(pb: protobuf::WithClause) -> Self { - WithClause { - ctes: pb.ctes.into_iter().map(|n| n.into()).collect(), - recursive: pb.recursive, - location: pb.location, - } - } -} - -impl From for CommonTableExpr { - fn from(pb: protobuf::CommonTableExpr) -> Self { - CommonTableExpr { - ctename: pb.ctename, - aliascolnames: pb.aliascolnames.into_iter().map(|n| n.into()).collect(), - ctematerialized: pb.ctematerialized.into(), - ctequery: pb.ctequery.map(|n| n.into()), - search_clause: pb.search_clause.map(|n| Node::Other(protobuf::Node { node: Some(protobuf::node::Node::CtesearchClause(n)) })), - cycle_clause: pb.cycle_clause.map(|n| Node::Other(protobuf::Node { node: Some(protobuf::node::Node::CtecycleClause(n)) })), - location: pb.location, - cterecursive: pb.cterecursive, - cterefcount: pb.cterefcount, - ctecolnames: pb.ctecolnames.into_iter().map(|n| n.into()).collect(), - ctecoltypes: pb.ctecoltypes.into_iter().map(|n| n.into()).collect(), - ctecoltypmods: pb.ctecoltypmods.into_iter().map(|n| n.into()).collect(), - ctecolcollations: pb.ctecolcollations.into_iter().map(|n| n.into()).collect(), - } - } -} - -impl From for IntoClause { - fn from(pb: protobuf::IntoClause) -> Self { - IntoClause { - rel: pb.rel.map(|v| v.into()), - col_names: pb.col_names.into_iter().map(|n| n.into()).collect(), - access_method: pb.access_method, - options: pb.options.into_iter().map(|n| n.into()).collect(), - on_commit: pb.on_commit.into(), - table_space_name: pb.table_space_name, - view_query: pb.view_query.map(|n| n.into()), - skip_data: pb.skip_data, - } - } -} - -impl From for OnConflictClause { - fn from(pb: protobuf::OnConflictClause) -> Self { - OnConflictClause { - action: pb.action.into(), - infer: pb.infer.map(|n| Node::Other(protobuf::Node { node: Some(protobuf::node::Node::InferClause(n)) })), - target_list: pb.target_list.into_iter().map(|n| n.into()).collect(), - where_clause: pb.where_clause.map(|n| n.into()), - location: pb.location, - } - } -} - -impl From for LockingClause { - fn from(pb: protobuf::LockingClause) -> Self { - LockingClause { - locked_rels: pb.locked_rels.into_iter().map(|n| n.into()).collect(), - strength: pb.strength.into(), - wait_policy: pb.wait_policy.into(), - } - } -} - -impl From for GroupingSet { - fn from(pb: protobuf::GroupingSet) -> Self { - GroupingSet { - kind: pb.kind.into(), - content: pb.content.into_iter().map(|n| n.into()).collect(), - location: pb.location, - } - } -} - -impl From for MergeWhenClause { - fn from(pb: protobuf::MergeWhenClause) -> Self { - MergeWhenClause { - match_kind: pb.match_kind.into(), - command_type: pb.command_type.into(), - override_: pb.r#override.into(), - condition: pb.condition.map(|n| n.into()), - target_list: pb.target_list.into_iter().map(|n| n.into()).collect(), - values: pb.values.into_iter().map(|n| n.into()).collect(), - } - } -} - -// Type-related conversions -impl From for TypeName { - fn from(pb: protobuf::TypeName) -> Self { - TypeName { - names: pb.names.into_iter().map(|n| n.into()).collect(), - type_oid: pb.type_oid, - setof: pb.setof, - pct_type: pb.pct_type, - typmods: pb.typmods.into_iter().map(|n| n.into()).collect(), - typemod: pb.typemod, - array_bounds: pb.array_bounds.into_iter().map(|n| n.into()).collect(), - location: pb.location, - } - } -} - -impl From for ColumnDef { - fn from(pb: protobuf::ColumnDef) -> Self { - ColumnDef { - colname: pb.colname, - type_name: pb.type_name.map(|v| v.into()), - compression: pb.compression, - inhcount: pb.inhcount, - is_local: pb.is_local, - is_not_null: pb.is_not_null, - is_from_type: pb.is_from_type, - storage: pb.storage, - storage_name: pb.storage_name, - raw_default: pb.raw_default.map(|n| n.into()), - cooked_default: pb.cooked_default.map(|n| n.into()), - identity: pb.identity, - identity_sequence: pb.identity_sequence.map(|v| v.into()), - generated: pb.generated, - coll_clause: pb.coll_clause.map(|v| v.into()), - coll_oid: pb.coll_oid, - constraints: pb.constraints.into_iter().map(|n| n.into()).collect(), - fdwoptions: pb.fdwoptions.into_iter().map(|n| n.into()).collect(), - location: pb.location, - } - } -} - -impl From for Constraint { - fn from(pb: protobuf::Constraint) -> Self { - Constraint { - contype: pb.contype.into(), - conname: pb.conname, - deferrable: pb.deferrable, - initdeferred: pb.initdeferred, - location: pb.location, - is_no_inherit: pb.is_no_inherit, - raw_expr: pb.raw_expr.map(|n| n.into()), - cooked_expr: pb.cooked_expr, - generated_when: pb.generated_when, - inhcount: pb.inhcount, - nulls_not_distinct: pb.nulls_not_distinct, - keys: pb.keys.into_iter().map(|n| n.into()).collect(), - including: pb.including.into_iter().map(|n| n.into()).collect(), - exclusions: pb.exclusions.into_iter().map(|n| n.into()).collect(), - options: pb.options.into_iter().map(|n| n.into()).collect(), - indexname: pb.indexname, - indexspace: pb.indexspace, - reset_default_tblspc: pb.reset_default_tblspc, - access_method: pb.access_method, - where_clause: pb.where_clause.map(|n| n.into()), - pktable: pb.pktable.map(|v| v.into()), - fk_attrs: pb.fk_attrs.into_iter().map(|n| n.into()).collect(), - pk_attrs: pb.pk_attrs.into_iter().map(|n| n.into()).collect(), - fk_matchtype: pb.fk_matchtype, - fk_upd_action: pb.fk_upd_action, - fk_del_action: pb.fk_del_action, - fk_del_set_cols: pb.fk_del_set_cols.into_iter().map(|n| n.into()).collect(), - old_conpfeqop: pb.old_conpfeqop.into_iter().map(|n| n.into()).collect(), - old_pktable_oid: pb.old_pktable_oid, - skip_validation: pb.skip_validation, - initially_valid: pb.initially_valid, - } - } -} - -impl From for DefElem { - fn from(pb: protobuf::DefElem) -> Self { - DefElem { - defnamespace: pb.defnamespace, - defname: pb.defname, - arg: pb.arg.map(|n| n.into()), - defaction: pb.defaction.into(), - location: pb.location, - } - } -} - -impl From for IndexElem { - fn from(pb: protobuf::IndexElem) -> Self { - IndexElem { - name: pb.name, - expr: pb.expr.map(|n| n.into()), - indexcolname: pb.indexcolname, - collation: pb.collation.into_iter().map(|n| n.into()).collect(), - opclass: pb.opclass.into_iter().map(|n| n.into()).collect(), - opclassopts: pb.opclassopts.into_iter().map(|n| n.into()).collect(), - ordering: pb.ordering.into(), - nulls_ordering: pb.nulls_ordering.into(), - } - } -} - -// Alias and role type conversions -impl From for Alias { - fn from(pb: protobuf::Alias) -> Self { - Alias { - aliasname: pb.aliasname, - colnames: pb.colnames.into_iter().map(|n| n.into()).collect(), - } - } -} - -impl From for RoleSpec { - fn from(pb: protobuf::RoleSpec) -> Self { - RoleSpec { - roletype: pb.roletype.into(), - rolename: pb.rolename, - location: pb.location, - } - } -} - -// Other type conversions -impl From for SortGroupClause { - fn from(pb: protobuf::SortGroupClause) -> Self { - SortGroupClause { - tle_sort_group_ref: pb.tle_sort_group_ref, - eqop: pb.eqop, - sortop: pb.sortop, - nulls_first: pb.nulls_first, - hashable: pb.hashable, - } - } -} - -impl From for FunctionParameter { - fn from(pb: protobuf::FunctionParameter) -> Self { - FunctionParameter { - name: pb.name, - arg_type: pb.arg_type.map(|v| v.into()), - mode: pb.mode.into(), - defexpr: pb.defexpr.map(|n| n.into()), - } - } -} - -impl From for AlterTableCmd { - fn from(pb: protobuf::AlterTableCmd) -> Self { - AlterTableCmd { - subtype: pb.subtype.into(), - name: pb.name, - num: pb.num as i16, - newowner: pb.newowner.map(|v| v.into()), - def: pb.def.map(|n| n.into()), - behavior: pb.behavior.into(), - missing_ok: pb.missing_ok, - recurse: pb.recurse, - } - } -} - -impl From for AccessPriv { - fn from(pb: protobuf::AccessPriv) -> Self { - AccessPriv { - priv_name: pb.priv_name, - cols: pb.cols.into_iter().map(|n| n.into()).collect(), - } - } -} - -impl From for ObjectWithArgs { - fn from(pb: protobuf::ObjectWithArgs) -> Self { - ObjectWithArgs { - objname: pb.objname.into_iter().map(|n| n.into()).collect(), - objargs: pb.objargs.into_iter().map(|n| n.into()).collect(), - objfuncargs: pb.objfuncargs.into_iter().map(|n| n.into()).collect(), - args_unspecified: pb.args_unspecified, - } - } -} - -// Administrative statement conversions -impl From for VariableSetStmt { - fn from(pb: protobuf::VariableSetStmt) -> Self { - VariableSetStmt { - kind: pb.kind.into(), - name: pb.name, - args: pb.args.into_iter().map(|n| n.into()).collect(), - is_local: pb.is_local, - } - } -} - -impl From for VariableShowStmt { - fn from(pb: protobuf::VariableShowStmt) -> Self { - VariableShowStmt { name: pb.name } - } -} - -impl From for ExplainStmt { - fn from(pb: protobuf::ExplainStmt) -> Self { - ExplainStmt { - query: pb.query.map(|n| n.into()), - options: pb.options.into_iter().map(|n| n.into()).collect(), - } - } -} - -impl From for CopyStmt { - fn from(pb: protobuf::CopyStmt) -> Self { - CopyStmt { - relation: pb.relation.map(|v| v.into()), - query: pb.query.map(|n| n.into()), - attlist: pb.attlist.into_iter().map(|n| n.into()).collect(), - is_from: pb.is_from, - is_program: pb.is_program, - filename: pb.filename, - options: pb.options.into_iter().map(|n| n.into()).collect(), - where_clause: pb.where_clause.map(|n| n.into()), - } - } -} - -impl From for GrantStmt { - fn from(pb: protobuf::GrantStmt) -> Self { - GrantStmt { - is_grant: pb.is_grant, - targtype: pb.targtype.into(), - objtype: pb.objtype.into(), - objects: pb.objects.into_iter().map(|n| n.into()).collect(), - privileges: pb.privileges.into_iter().map(|n| n.into()).collect(), - grantees: pb.grantees.into_iter().map(|n| n.into()).collect(), - grant_option: pb.grant_option, - grantor: pb.grantor.map(|v| v.into()), - behavior: pb.behavior.into(), - } - } -} - -impl From for GrantRoleStmt { - fn from(pb: protobuf::GrantRoleStmt) -> Self { - GrantRoleStmt { - granted_roles: pb.granted_roles.into_iter().map(|n| n.into()).collect(), - grantee_roles: pb.grantee_roles.into_iter().map(|n| n.into()).collect(), - is_grant: pb.is_grant, - opt: pb.opt.into_iter().map(|n| n.into()).collect(), - grantor: pb.grantor.map(|v| v.into()), - behavior: pb.behavior.into(), - } - } -} - -impl From for LockStmt { - fn from(pb: protobuf::LockStmt) -> Self { - LockStmt { - relations: pb.relations.into_iter().map(|n| n.into()).collect(), - mode: pb.mode, - nowait: pb.nowait, - } - } -} - -impl From for VacuumStmt { - fn from(pb: protobuf::VacuumStmt) -> Self { - VacuumStmt { - options: pb.options.into_iter().map(|n| n.into()).collect(), - rels: pb.rels.into_iter().map(|n| n.into()).collect(), - is_vacuumcmd: pb.is_vacuumcmd, - } - } -} - -// Other statement conversions -impl From for DoStmt { - fn from(pb: protobuf::DoStmt) -> Self { - DoStmt { - args: pb.args.into_iter().map(|n| n.into()).collect(), - } - } -} - -impl From for RenameStmt { - fn from(pb: protobuf::RenameStmt) -> Self { - RenameStmt { - rename_type: pb.rename_type.into(), - relation_type: pb.relation_type.into(), - relation: pb.relation.map(|v| v.into()), - object: pb.object.map(|n| n.into()), - subname: pb.subname, - newname: pb.newname, - behavior: pb.behavior.into(), - missing_ok: pb.missing_ok, - } - } -} - -impl From for NotifyStmt { - fn from(pb: protobuf::NotifyStmt) -> Self { - NotifyStmt { - conditionname: pb.conditionname, - payload: pb.payload, - } - } -} - -impl From for ListenStmt { - fn from(pb: protobuf::ListenStmt) -> Self { - ListenStmt { - conditionname: pb.conditionname, - } - } -} - -impl From for UnlistenStmt { - fn from(pb: protobuf::UnlistenStmt) -> Self { - UnlistenStmt { - conditionname: pb.conditionname, - } - } -} - -impl From for DiscardStmt { - fn from(pb: protobuf::DiscardStmt) -> Self { - DiscardStmt { - target: pb.target.into(), - } - } -} - -impl From for PrepareStmt { - fn from(pb: protobuf::PrepareStmt) -> Self { - PrepareStmt { - name: pb.name, - argtypes: pb.argtypes.into_iter().map(|n| n.into()).collect(), - query: pb.query.map(|n| n.into()), - } - } -} - -impl From for ExecuteStmt { - fn from(pb: protobuf::ExecuteStmt) -> Self { - ExecuteStmt { - name: pb.name, - params: pb.params.into_iter().map(|n| n.into()).collect(), - } - } -} - -impl From for DeallocateStmt { - fn from(pb: protobuf::DeallocateStmt) -> Self { - DeallocateStmt { name: pb.name } - } -} - -impl From for ClosePortalStmt { - fn from(pb: protobuf::ClosePortalStmt) -> Self { - ClosePortalStmt { - portalname: pb.portalname, - } - } -} - -impl From for FetchStmt { - fn from(pb: protobuf::FetchStmt) -> Self { - FetchStmt { - direction: pb.direction.into(), - how_many: pb.how_many, - portalname: pb.portalname, - ismove: pb.ismove, - } - } -} - -// ============================================================================ -// Enum conversions -// ============================================================================ - -impl From for SetOperation { - fn from(v: i32) -> Self { - match v { - 1 => SetOperation::None, // SETOP_NONE - 2 => SetOperation::Union, // SETOP_UNION - 3 => SetOperation::Intersect, // SETOP_INTERSECT - 4 => SetOperation::Except, // SETOP_EXCEPT - _ => SetOperation::None, - } - } -} - -impl From for LimitOption { - fn from(v: i32) -> Self { - match v { - 1 => LimitOption::Default, // LIMIT_OPTION_DEFAULT - 2 => LimitOption::Count, // LIMIT_OPTION_COUNT - 3 => LimitOption::WithTies, // LIMIT_OPTION_WITH_TIES - _ => LimitOption::Default, - } - } -} - -impl From for AExprKind { - fn from(v: i32) -> Self { - match v { - 1 => AExprKind::Op, // AEXPR_OP - 2 => AExprKind::OpAny, // AEXPR_OP_ANY - 3 => AExprKind::OpAll, // AEXPR_OP_ALL - 4 => AExprKind::Distinct, // AEXPR_DISTINCT - 5 => AExprKind::NotDistinct, // AEXPR_NOT_DISTINCT - 6 => AExprKind::NullIf, // AEXPR_NULLIF - 7 => AExprKind::In, // AEXPR_IN - 8 => AExprKind::Like, // AEXPR_LIKE - 9 => AExprKind::ILike, // AEXPR_ILIKE - 10 => AExprKind::Similar, // AEXPR_SIMILAR - 11 => AExprKind::Between, // AEXPR_BETWEEN - 12 => AExprKind::NotBetween, // AEXPR_NOT_BETWEEN - 13 => AExprKind::BetweenSym, // AEXPR_BETWEEN_SYM - 14 => AExprKind::NotBetweenSym, // AEXPR_NOT_BETWEEN_SYM - _ => AExprKind::Op, - } - } -} - -impl From for BoolExprType { - fn from(v: i32) -> Self { - match v { - 1 => BoolExprType::And, // AND_EXPR - 2 => BoolExprType::Or, // OR_EXPR - 3 => BoolExprType::Not, // NOT_EXPR - _ => BoolExprType::And, - } - } -} - -impl From for SubLinkType { - fn from(v: i32) -> Self { - match v { - 1 => SubLinkType::Exists, - 2 => SubLinkType::All, - 3 => SubLinkType::Any, - 4 => SubLinkType::RowCompare, - 5 => SubLinkType::Expr, - 6 => SubLinkType::MultiExpr, - 7 => SubLinkType::Array, - 8 => SubLinkType::Cte, - _ => SubLinkType::Exists, - } - } -} - -impl From for NullTestType { - fn from(v: i32) -> Self { - match v { - 1 => NullTestType::IsNull, - 2 => NullTestType::IsNotNull, - _ => NullTestType::IsNull, - } - } -} - -impl From for BoolTestType { - fn from(v: i32) -> Self { - match v { - 1 => BoolTestType::IsTrue, - 2 => BoolTestType::IsNotTrue, - 3 => BoolTestType::IsFalse, - 4 => BoolTestType::IsNotFalse, - 5 => BoolTestType::IsUnknown, - 6 => BoolTestType::IsNotUnknown, - _ => BoolTestType::IsTrue, - } - } -} - -impl From for MinMaxOp { - fn from(v: i32) -> Self { - match v { - 1 => MinMaxOp::Greatest, - 2 => MinMaxOp::Least, - _ => MinMaxOp::Greatest, - } - } -} - -impl From for JoinType { - fn from(v: i32) -> Self { - match v { - 1 => JoinType::Inner, - 2 => JoinType::Left, - 3 => JoinType::Full, - 4 => JoinType::Right, - 5 => JoinType::Semi, - 6 => JoinType::Anti, - 7 => JoinType::RightAnti, - 8 => JoinType::UniqueOuter, - 9 => JoinType::UniqueInner, - _ => JoinType::Inner, - } - } -} - -impl From for SortByDir { - fn from(v: i32) -> Self { - match v { - 1 => SortByDir::Default, - 2 => SortByDir::Asc, - 3 => SortByDir::Desc, - 4 => SortByDir::Using, - _ => SortByDir::Default, - } - } -} - -impl From for SortByNulls { - fn from(v: i32) -> Self { - match v { - 1 => SortByNulls::Default, - 2 => SortByNulls::First, - 3 => SortByNulls::Last, - _ => SortByNulls::Default, - } - } -} - -impl From for CTEMaterialize { - fn from(v: i32) -> Self { - match v { - 1 => CTEMaterialize::Default, - 2 => CTEMaterialize::Always, - 3 => CTEMaterialize::Never, - _ => CTEMaterialize::Default, - } - } -} - -impl From for OnCommitAction { - fn from(v: i32) -> Self { - match v { - 1 => OnCommitAction::Noop, - 2 => OnCommitAction::PreserveRows, - 3 => OnCommitAction::DeleteRows, - 4 => OnCommitAction::Drop, - _ => OnCommitAction::Noop, - } - } -} - -impl From for ObjectType { - fn from(v: i32) -> Self { - // Use direct integer matching - // Values from protobuf ObjectType enum - match v { - 1 => ObjectType::AccessMethod, - 2 => ObjectType::Aggregate, - 11 => ObjectType::Cast, - 12 => ObjectType::Column, - 13 => ObjectType::Collation, - 14 => ObjectType::Conversion, - 15 => ObjectType::Database, - 16 => ObjectType::Default, - 17 => ObjectType::Constraint, - 18 => ObjectType::Domain, - 19 => ObjectType::EventTrigger, - 20 => ObjectType::Extension, - 21 => ObjectType::Fdw, - 22 => ObjectType::ForeignServer, - 23 => ObjectType::ForeignTable, - 24 => ObjectType::Function, - 25 => ObjectType::Index, - 26 => ObjectType::Language, - 27 => ObjectType::LargeObject, - 28 => ObjectType::MatView, - 29 => ObjectType::Operator, - 37 => ObjectType::Policy, - 38 => ObjectType::Procedure, - 39 => ObjectType::Publication, - 44 => ObjectType::Role, - 45 => ObjectType::Routine, - 46 => ObjectType::Rule, - 47 => ObjectType::Schema, - 48 => ObjectType::Sequence, - 49 => ObjectType::Subscription, - 50 => ObjectType::StatisticsObject, - 54 => ObjectType::Table, - 55 => ObjectType::Tablespace, - 57 => ObjectType::Transform, - 58 => ObjectType::Trigger, - 60 => ObjectType::Type, - 62 => ObjectType::View, - _ => ObjectType::Table, - } - } -} - -impl From for DropBehavior { - fn from(v: i32) -> Self { - match v { - 1 => DropBehavior::Restrict, - 2 => DropBehavior::Cascade, - _ => DropBehavior::Restrict, - } - } -} - -impl From for OnConflictAction { - fn from(v: i32) -> Self { - match v { - 1 => OnConflictAction::None, - 2 => OnConflictAction::Nothing, - 3 => OnConflictAction::Update, - _ => OnConflictAction::None, - } - } -} - -impl From for GroupingSetKind { - fn from(v: i32) -> Self { - match v { - 1 => GroupingSetKind::Empty, - 2 => GroupingSetKind::Simple, - 3 => GroupingSetKind::Rollup, - 4 => GroupingSetKind::Cube, - 5 => GroupingSetKind::Sets, - _ => GroupingSetKind::Empty, - } - } -} - -impl From for CmdType { - fn from(v: i32) -> Self { - match v { - 1 => CmdType::Unknown, - 2 => CmdType::Select, - 3 => CmdType::Update, - 4 => CmdType::Insert, - 5 => CmdType::Delete, - 6 => CmdType::Merge, - 7 => CmdType::Utility, - 8 => CmdType::Nothing, - _ => CmdType::Unknown, - } - } -} - -impl From for MergeMatchKind { - fn from(v: i32) -> Self { - match v { - 1 => MergeMatchKind::Matched, - 2 => MergeMatchKind::NotMatchedBySource, - 3 => MergeMatchKind::NotMatchedByTarget, - _ => MergeMatchKind::Undefined, - } - } -} - -impl From for TransactionStmtKind { - fn from(v: i32) -> Self { - match v { - 1 => TransactionStmtKind::Begin, - 2 => TransactionStmtKind::Start, - 3 => TransactionStmtKind::Commit, - 4 => TransactionStmtKind::Rollback, - 5 => TransactionStmtKind::Savepoint, - 6 => TransactionStmtKind::Release, - 7 => TransactionStmtKind::RollbackTo, - 8 => TransactionStmtKind::Prepare, - 9 => TransactionStmtKind::CommitPrepared, - 10 => TransactionStmtKind::RollbackPrepared, - _ => TransactionStmtKind::Begin, - } - } -} - -impl From for ConstrType { - fn from(v: i32) -> Self { - match v { - 1 => ConstrType::Null, - 2 => ConstrType::NotNull, - 3 => ConstrType::Default, - 4 => ConstrType::Identity, - 5 => ConstrType::Generated, - 6 => ConstrType::Check, - 7 => ConstrType::Primary, - 8 => ConstrType::Unique, - 9 => ConstrType::Exclusion, - 10 => ConstrType::Foreign, - 11 => ConstrType::AttrDeferrable, - 12 => ConstrType::AttrNotDeferrable, - 13 => ConstrType::AttrDeferred, - 14 => ConstrType::AttrImmediate, - _ => ConstrType::Null, - } - } -} - -impl From for DefElemAction { - fn from(v: i32) -> Self { - match v { - 1 => DefElemAction::Unspec, - 2 => DefElemAction::Set, - 3 => DefElemAction::Add, - 4 => DefElemAction::Drop, - _ => DefElemAction::Unspec, - } - } -} - -impl From for RoleSpecType { - fn from(v: i32) -> Self { - match v { - 1 => RoleSpecType::CString, - 2 => RoleSpecType::CurrentRole, - 3 => RoleSpecType::CurrentUser, - 4 => RoleSpecType::SessionUser, - 5 => RoleSpecType::Public, - _ => RoleSpecType::CString, - } - } -} - -impl From for CoercionForm { - fn from(v: i32) -> Self { - match v { - 1 => CoercionForm::ExplicitCall, - 2 => CoercionForm::ExplicitCast, - 3 => CoercionForm::ImplicitCast, - 4 => CoercionForm::SqlSyntax, - _ => CoercionForm::ExplicitCall, - } - } -} - -impl From for VariableSetKind { - fn from(v: i32) -> Self { - match v { - 1 => VariableSetKind::Value, - 2 => VariableSetKind::Default, - 3 => VariableSetKind::Current, - 4 => VariableSetKind::Multi, - 5 => VariableSetKind::Reset, - 6 => VariableSetKind::ResetAll, - _ => VariableSetKind::Value, - } - } -} - -impl From for LockClauseStrength { - fn from(v: i32) -> Self { - match v { - 1 => LockClauseStrength::None, - 2 => LockClauseStrength::ForKeyShare, - 3 => LockClauseStrength::ForShare, - 4 => LockClauseStrength::ForNoKeyUpdate, - 5 => LockClauseStrength::ForUpdate, - _ => LockClauseStrength::None, - } - } -} - -impl From for LockWaitPolicy { - fn from(v: i32) -> Self { - match v { - 1 => LockWaitPolicy::Block, - 2 => LockWaitPolicy::Skip, - 3 => LockWaitPolicy::Error, - _ => LockWaitPolicy::Block, - } - } -} - -impl From for ViewCheckOption { - fn from(v: i32) -> Self { - match v { - 1 => ViewCheckOption::NoCheckOption, - 2 => ViewCheckOption::Local, - 3 => ViewCheckOption::Cascaded, - _ => ViewCheckOption::NoCheckOption, - } - } -} - -impl From for DiscardMode { - fn from(v: i32) -> Self { - match v { - 1 => DiscardMode::All, - 2 => DiscardMode::Plans, - 3 => DiscardMode::Sequences, - 4 => DiscardMode::Temp, - _ => DiscardMode::All, - } - } -} - -impl From for FetchDirection { - fn from(v: i32) -> Self { - match v { - 1 => FetchDirection::Forward, - 2 => FetchDirection::Backward, - 3 => FetchDirection::Absolute, - 4 => FetchDirection::Relative, - _ => FetchDirection::Forward, - } - } -} - -impl From for FunctionParameterMode { - fn from(v: i32) -> Self { - match v { - 105 => FunctionParameterMode::In, // 'i' - 111 => FunctionParameterMode::Out, // 'o' - 98 => FunctionParameterMode::InOut, // 'b' - 118 => FunctionParameterMode::Variadic, // 'v' - 116 => FunctionParameterMode::Table, // 't' - _ => FunctionParameterMode::In, - } - } -} - -impl From for AlterTableType { - fn from(v: i32) -> Self { - // AlterTableType has many variants, use default for simplicity - // The values start at 1 and go up - match v { - 1 => AlterTableType::AddColumn, - 2 => AlterTableType::AddColumnToView, - 3 => AlterTableType::ColumnDefault, - 4 => AlterTableType::CookedColumnDefault, - 5 => AlterTableType::DropNotNull, - 6 => AlterTableType::SetNotNull, - 7 => AlterTableType::DropExpression, - 8 => AlterTableType::CheckNotNull, - 9 => AlterTableType::SetStatistics, - 10 => AlterTableType::SetOptions, - 11 => AlterTableType::ResetOptions, - 12 => AlterTableType::SetStorage, - 13 => AlterTableType::SetCompression, - 14 => AlterTableType::DropColumn, - 15 => AlterTableType::AddIndex, - 16 => AlterTableType::ReAddIndex, - 17 => AlterTableType::AddConstraint, - 18 => AlterTableType::ReAddConstraint, - 19 => AlterTableType::AddIndexConstraint, - 20 => AlterTableType::AlterConstraint, - 21 => AlterTableType::ValidateConstraint, - 22 => AlterTableType::DropConstraint, - 23 => AlterTableType::ClusterOn, - 24 => AlterTableType::DropCluster, - 25 => AlterTableType::SetLogged, - 26 => AlterTableType::SetUnLogged, - 27 => AlterTableType::SetAccessMethod, - 28 => AlterTableType::DropOids, - 29 => AlterTableType::SetTableSpace, - 30 => AlterTableType::SetRelOptions, - 31 => AlterTableType::ResetRelOptions, - 32 => AlterTableType::ReplaceRelOptions, - 33 => AlterTableType::EnableTrig, - 34 => AlterTableType::EnableAlwaysTrig, - 35 => AlterTableType::EnableReplicaTrig, - 36 => AlterTableType::DisableTrig, - 37 => AlterTableType::EnableTrigAll, - 38 => AlterTableType::DisableTrigAll, - 39 => AlterTableType::EnableTrigUser, - 40 => AlterTableType::DisableTrigUser, - 41 => AlterTableType::EnableRule, - 42 => AlterTableType::EnableAlwaysRule, - 43 => AlterTableType::EnableReplicaRule, - 44 => AlterTableType::DisableRule, - 45 => AlterTableType::AddInherit, - 46 => AlterTableType::DropInherit, - 47 => AlterTableType::AddOf, - 48 => AlterTableType::DropOf, - 49 => AlterTableType::ReplicaIdentity, - 50 => AlterTableType::EnableRowSecurity, - 51 => AlterTableType::DisableRowSecurity, - 52 => AlterTableType::ForceRowSecurity, - 53 => AlterTableType::NoForceRowSecurity, - 54 => AlterTableType::GenericOptions, - 55 => AlterTableType::AttachPartition, - 56 => AlterTableType::DetachPartition, - 57 => AlterTableType::DetachPartitionFinalize, - 58 => AlterTableType::AddIdentity, - 59 => AlterTableType::SetIdentity, - 60 => AlterTableType::DropIdentity, - 61 => AlterTableType::ReAddStatistics, - _ => AlterTableType::AddColumn, - } - } -} - -impl From for GrantTargetType { - fn from(v: i32) -> Self { - match v { - 1 => GrantTargetType::Object, - 2 => GrantTargetType::AllInSchema, - 3 => GrantTargetType::Defaults, - _ => GrantTargetType::Object, - } - } -} - -impl From for OverridingKind { - fn from(v: i32) -> Self { - match v { - 1 => OverridingKind::NotSet, - 2 => OverridingKind::UserValue, - 3 => OverridingKind::SystemValue, - _ => OverridingKind::NotSet, - } - } -} - diff --git a/src/ast/mod.rs b/src/ast/mod.rs deleted file mode 100644 index 8cc595c..0000000 --- a/src/ast/mod.rs +++ /dev/null @@ -1,31 +0,0 @@ -//! Native Rust AST types for PostgreSQL parse trees. -//! -//! This module provides ergonomic Rust types that wrap the PostgreSQL parse tree -//! structure. These types make it easier to work with parsed SQL queries without -//! the complexity of deeply nested protobuf Option> wrappers. -//! -//! # Example -//! -//! ```rust -//! use pg_query::ast::Node; -//! -//! let result = pg_query::parse_to_ast("SELECT * FROM users WHERE id = 1").unwrap(); -//! for stmt in &result.stmts { -//! match &stmt.stmt { -//! Node::SelectStmt(select) => { -//! // Access fields more directly -//! for table in &select.from_clause { -//! if let Node::RangeVar(rv) = table { -//! println!("Table: {}", rv.relname); -//! } -//! } -//! } -//! _ => {} -//! } -//! } -//! ``` - -mod nodes; -mod convert; - -pub use nodes::*; diff --git a/src/ast/nodes.rs b/src/ast/nodes.rs deleted file mode 100644 index 92f4726..0000000 --- a/src/ast/nodes.rs +++ /dev/null @@ -1,1628 +0,0 @@ -//! Native Rust AST node types for PostgreSQL parse trees. -//! -//! These types mirror the PostgreSQL parse tree structure but use idiomatic Rust -//! patterns instead of protobuf-style Option> wrappers. - -use crate::protobuf; - -/// Top-level parse result containing all parsed statements. -#[derive(Debug, Clone)] -pub struct ParseResult { - /// PostgreSQL version number (e.g., 160001 for 16.0.1) - pub version: i32, - /// List of parsed statements - pub stmts: Vec, - /// Original protobuf for deparsing (hidden implementation detail) - pub(crate) original_protobuf: protobuf::ParseResult, -} - -// as_protobuf method is defined in convert.rs - -/// A raw statement wrapper with location information. -#[derive(Debug, Clone)] -pub struct RawStmt { - /// The statement node - pub stmt: Node, - /// Character offset in source where statement starts - pub stmt_location: i32, - /// Length of statement in characters (0 means "rest of string") - pub stmt_len: i32, -} - -/// The main AST node enum containing all possible node types. -/// -/// This enum eliminates the need for `Option>` wrappers throughout -/// the AST by using a flat enum with all node types as variants. -#[derive(Debug, Clone)] -pub enum Node { - // Primitive value types - Integer(Integer), - Float(Float), - Boolean(Boolean), - String(StringValue), - BitString(BitString), - Null, - - // List type - List(Vec), - - // Statement types - SelectStmt(Box), - InsertStmt(Box), - UpdateStmt(Box), - DeleteStmt(Box), - MergeStmt(Box), - - // DDL statements - CreateStmt(Box), - AlterTableStmt(Box), - DropStmt(Box), - TruncateStmt(Box), - IndexStmt(Box), - CreateSchemaStmt(Box), - ViewStmt(Box), - CreateFunctionStmt(Box), - AlterFunctionStmt(Box), - CreateSeqStmt(Box), - AlterSeqStmt(Box), - CreateTrigStmt(Box), - RuleStmt(Box), - CreateDomainStmt(Box), - CreateTableAsStmt(Box), - RefreshMatViewStmt(Box), - - // Transaction statement - TransactionStmt(Box), - - // Expression types - AExpr(Box), - ColumnRef(Box), - ParamRef(Box), - AConst(Box), - TypeCast(Box), - CollateClause(Box), - FuncCall(Box), - AStar(AStar), - AIndices(Box), - AIndirection(Box), - AArrayExpr(Box), - SubLink(Box), - BoolExpr(Box), - NullTest(Box), - BooleanTest(Box), - CaseExpr(Box), - CaseWhen(Box), - CoalesceExpr(Box), - MinMaxExpr(Box), - RowExpr(Box), - - // Target/Result types - ResTarget(Box), - - // Table/Range types - RangeVar(Box), - RangeSubselect(Box), - RangeFunction(Box), - JoinExpr(Box), - - // Clause types - SortBy(Box), - WindowDef(Box), - WithClause(Box), - CommonTableExpr(Box), - IntoClause(Box), - OnConflictClause(Box), - LockingClause(Box), - GroupingSet(Box), - MergeWhenClause(Box), - - // Type-related - TypeName(Box), - ColumnDef(Box), - Constraint(Box), - DefElem(Box), - IndexElem(Box), - - // Alias and role types - Alias(Box), - RoleSpec(Box), - - // Other commonly used types - SortGroupClause(Box), - FunctionParameter(Box), - AlterTableCmd(Box), - AccessPriv(Box), - ObjectWithArgs(Box), - - // Administrative statements - VariableSetStmt(Box), - VariableShowStmt(Box), - ExplainStmt(Box), - CopyStmt(Box), - GrantStmt(Box), - GrantRoleStmt(Box), - LockStmt(Box), - VacuumStmt(Box), - - // Other statements - DoStmt(Box), - RenameStmt(Box), - NotifyStmt(Box), - ListenStmt(Box), - UnlistenStmt(Box), - CheckPointStmt(Box), - DiscardStmt(Box), - PrepareStmt(Box), - ExecuteStmt(Box), - DeallocateStmt(Box), - ClosePortalStmt(Box), - FetchStmt(Box), - - // Fallback for unhandled node types - stores the original protobuf - Other(protobuf::Node), -} - -// ============================================================================ -// Primitive value types -// ============================================================================ - -/// Integer value -#[derive(Debug, Clone, Default)] -pub struct Integer { - pub ival: i32, -} - -/// Float value (stored as string) -#[derive(Debug, Clone, Default)] -pub struct Float { - pub fval: String, -} - -/// Boolean value -#[derive(Debug, Clone, Default)] -pub struct Boolean { - pub boolval: bool, -} - -/// String value -#[derive(Debug, Clone, Default)] -pub struct StringValue { - pub sval: String, -} - -/// Bit string value -#[derive(Debug, Clone, Default)] -pub struct BitString { - pub bsval: String, -} - -/// A star (*) in column reference -#[derive(Debug, Clone, Default)] -pub struct AStar; - -// ============================================================================ -// Core statement types -// ============================================================================ - -/// SELECT statement -#[derive(Debug, Clone, Default)] -pub struct SelectStmt { - pub distinct_clause: Vec, - pub into_clause: Option, - pub target_list: Vec, - pub from_clause: Vec, - pub where_clause: Option, - pub group_clause: Vec, - pub group_distinct: bool, - pub having_clause: Option, - pub window_clause: Vec, - pub values_lists: Vec, - pub sort_clause: Vec, - pub limit_offset: Option, - pub limit_count: Option, - pub limit_option: LimitOption, - pub locking_clause: Vec, - pub with_clause: Option, - pub op: SetOperation, - pub all: bool, - pub larg: Option>, - pub rarg: Option>, -} - -/// INSERT statement -#[derive(Debug, Clone, Default)] -pub struct InsertStmt { - pub relation: Option, - pub cols: Vec, - pub select_stmt: Option, - pub on_conflict_clause: Option, - pub returning_list: Vec, - pub with_clause: Option, - pub override_: OverridingKind, -} - -/// UPDATE statement -#[derive(Debug, Clone, Default)] -pub struct UpdateStmt { - pub relation: Option, - pub target_list: Vec, - pub where_clause: Option, - pub from_clause: Vec, - pub returning_list: Vec, - pub with_clause: Option, -} - -/// DELETE statement -#[derive(Debug, Clone, Default)] -pub struct DeleteStmt { - pub relation: Option, - pub using_clause: Vec, - pub where_clause: Option, - pub returning_list: Vec, - pub with_clause: Option, -} - -/// MERGE statement -#[derive(Debug, Clone, Default)] -pub struct MergeStmt { - pub relation: Option, - pub source_relation: Option, - pub join_condition: Option, - pub merge_when_clauses: Vec, - pub with_clause: Option, -} - -// ============================================================================ -// DDL statement types -// ============================================================================ - -/// CREATE TABLE statement -#[derive(Debug, Clone, Default)] -pub struct CreateStmt { - pub relation: Option, - pub table_elts: Vec, - pub inh_relations: Vec, - pub partbound: Option, - pub partspec: Option, - pub of_typename: Option, - pub constraints: Vec, - pub options: Vec, - pub oncommit: OnCommitAction, - pub tablespacename: String, - pub access_method: String, - pub if_not_exists: bool, -} - -/// ALTER TABLE statement -#[derive(Debug, Clone, Default)] -pub struct AlterTableStmt { - pub relation: Option, - pub cmds: Vec, - pub objtype: ObjectType, - pub missing_ok: bool, -} - -/// DROP statement -#[derive(Debug, Clone, Default)] -pub struct DropStmt { - pub objects: Vec, - pub remove_type: ObjectType, - pub behavior: DropBehavior, - pub missing_ok: bool, - pub concurrent: bool, -} - -/// TRUNCATE statement -#[derive(Debug, Clone, Default)] -pub struct TruncateStmt { - pub relations: Vec, - pub restart_seqs: bool, - pub behavior: DropBehavior, -} - -/// CREATE INDEX statement -#[derive(Debug, Clone, Default)] -pub struct IndexStmt { - pub idxname: String, - pub relation: Option, - pub access_method: String, - pub table_space: String, - pub index_params: Vec, - pub index_including_params: Vec, - pub options: Vec, - pub where_clause: Option, - pub exclude_op_names: Vec, - pub idxcomment: String, - pub index_oid: u32, - pub old_number: u32, - pub old_first_relfilelocator: u32, - pub unique: bool, - pub nulls_not_distinct: bool, - pub primary: bool, - pub is_constraint: bool, - pub deferrable: bool, - pub initdeferred: bool, - pub transformed: bool, - pub concurrent: bool, - pub if_not_exists: bool, - pub reset_default_tblspc: bool, -} - -/// CREATE SCHEMA statement -#[derive(Debug, Clone, Default)] -pub struct CreateSchemaStmt { - pub schemaname: String, - pub authrole: Option, - pub schema_elts: Vec, - pub if_not_exists: bool, -} - -/// CREATE VIEW statement -#[derive(Debug, Clone, Default)] -pub struct ViewStmt { - pub view: Option, - pub aliases: Vec, - pub query: Option, - pub replace: bool, - pub options: Vec, - pub with_check_option: ViewCheckOption, -} - -/// CREATE FUNCTION statement -#[derive(Debug, Clone, Default)] -pub struct CreateFunctionStmt { - pub is_procedure: bool, - pub replace: bool, - pub funcname: Vec, - pub parameters: Vec, - pub return_type: Option, - pub options: Vec, - pub sql_body: Option, -} - -/// ALTER FUNCTION statement -#[derive(Debug, Clone, Default)] -pub struct AlterFunctionStmt { - pub objtype: ObjectType, - pub func: Option, - pub actions: Vec, -} - -/// CREATE SEQUENCE statement -#[derive(Debug, Clone, Default)] -pub struct CreateSeqStmt { - pub sequence: Option, - pub options: Vec, - pub owner_id: u32, - pub for_identity: bool, - pub if_not_exists: bool, -} - -/// ALTER SEQUENCE statement -#[derive(Debug, Clone, Default)] -pub struct AlterSeqStmt { - pub sequence: Option, - pub options: Vec, - pub for_identity: bool, - pub missing_ok: bool, -} - -/// CREATE TRIGGER statement -#[derive(Debug, Clone, Default)] -pub struct CreateTrigStmt { - pub replace: bool, - pub isconstraint: bool, - pub trigname: String, - pub relation: Option, - pub funcname: Vec, - pub args: Vec, - pub row: bool, - pub timing: i32, - pub events: i32, - pub columns: Vec, - pub when_clause: Option, - pub transition_rels: Vec, - pub deferrable: bool, - pub initdeferred: bool, - pub constrrel: Option, -} - -/// CREATE RULE statement -#[derive(Debug, Clone, Default)] -pub struct RuleStmt { - pub relation: Option, - pub rulename: String, - pub where_clause: Option, - pub event: CmdType, - pub instead: bool, - pub actions: Vec, - pub replace: bool, -} - -/// CREATE DOMAIN statement -#[derive(Debug, Clone, Default)] -pub struct CreateDomainStmt { - pub domainname: Vec, - pub type_name: Option, - pub coll_clause: Option, - pub constraints: Vec, -} - -/// CREATE TABLE AS statement -#[derive(Debug, Clone, Default)] -pub struct CreateTableAsStmt { - pub query: Option, - pub into: Option, - pub objtype: ObjectType, - pub is_select_into: bool, - pub if_not_exists: bool, -} - -/// REFRESH MATERIALIZED VIEW statement -#[derive(Debug, Clone, Default)] -pub struct RefreshMatViewStmt { - pub concurrent: bool, - pub skip_data: bool, - pub relation: Option, -} - -// ============================================================================ -// Transaction statement -// ============================================================================ - -/// Transaction statement (BEGIN, COMMIT, ROLLBACK, etc.) -#[derive(Debug, Clone, Default)] -pub struct TransactionStmt { - pub kind: TransactionStmtKind, - pub options: Vec, - pub savepoint_name: String, - pub gid: String, - pub chain: bool, -} - -// ============================================================================ -// Expression types -// ============================================================================ - -/// An expression with an operator (e.g., "a + b", "x = 1") -#[derive(Debug, Clone, Default)] -pub struct AExpr { - pub kind: AExprKind, - pub name: Vec, - pub lexpr: Option, - pub rexpr: Option, - pub location: i32, -} - -/// Column reference (e.g., "table.column") -#[derive(Debug, Clone, Default)] -pub struct ColumnRef { - pub fields: Vec, - pub location: i32, -} - -/// Parameter reference ($1, $2, etc.) -#[derive(Debug, Clone, Default)] -pub struct ParamRef { - pub number: i32, - pub location: i32, -} - -/// A constant value -#[derive(Debug, Clone, Default)] -pub struct AConst { - pub val: Option, - pub isnull: bool, - pub location: i32, -} - -/// Value types for AConst -#[derive(Debug, Clone)] -pub enum AConstValue { - Integer(Integer), - Float(Float), - Boolean(Boolean), - String(StringValue), - BitString(BitString), -} - -/// Type cast expression -#[derive(Debug, Clone, Default)] -pub struct TypeCast { - pub arg: Option, - pub type_name: Option, - pub location: i32, -} - -/// COLLATE clause -#[derive(Debug, Clone, Default)] -pub struct CollateClause { - pub arg: Option, - pub collname: Vec, - pub location: i32, -} - -/// Function call -#[derive(Debug, Clone, Default)] -pub struct FuncCall { - pub funcname: Vec, - pub args: Vec, - pub agg_order: Vec, - pub agg_filter: Option, - pub over: Option, - pub agg_within_group: bool, - pub agg_star: bool, - pub agg_distinct: bool, - pub func_variadic: bool, - pub funcformat: CoercionForm, - pub location: i32, -} - -/// Array subscript indices -#[derive(Debug, Clone, Default)] -pub struct AIndices { - pub is_slice: bool, - pub lidx: Option, - pub uidx: Option, -} - -/// Array subscript or field selection -#[derive(Debug, Clone, Default)] -pub struct AIndirection { - pub arg: Option, - pub indirection: Vec, -} - -/// ARRAY[] constructor -#[derive(Debug, Clone, Default)] -pub struct AArrayExpr { - pub elements: Vec, - pub location: i32, -} - -/// Subquery link (subquery in expression context) -#[derive(Debug, Clone, Default)] -pub struct SubLink { - pub sub_link_type: SubLinkType, - pub sub_link_id: i32, - pub testexpr: Option, - pub oper_name: Vec, - pub subselect: Option, - pub location: i32, -} - -/// Boolean expression (AND, OR, NOT) -#[derive(Debug, Clone, Default)] -pub struct BoolExpr { - pub boolop: BoolExprType, - pub args: Vec, - pub location: i32, -} - -/// NULL test expression -#[derive(Debug, Clone, Default)] -pub struct NullTest { - pub arg: Option, - pub nulltesttype: NullTestType, - pub argisrow: bool, - pub location: i32, -} - -/// Boolean test (IS TRUE, IS FALSE, etc.) -#[derive(Debug, Clone, Default)] -pub struct BooleanTest { - pub arg: Option, - pub booltesttype: BoolTestType, - pub location: i32, -} - -/// CASE expression -#[derive(Debug, Clone, Default)] -pub struct CaseExpr { - pub arg: Option, - pub args: Vec, - pub defresult: Option, - pub location: i32, -} - -/// WHEN clause of CASE -#[derive(Debug, Clone, Default)] -pub struct CaseWhen { - pub expr: Option, - pub result: Option, - pub location: i32, -} - -/// COALESCE expression -#[derive(Debug, Clone, Default)] -pub struct CoalesceExpr { - pub args: Vec, - pub location: i32, -} - -/// GREATEST or LEAST expression -#[derive(Debug, Clone, Default)] -pub struct MinMaxExpr { - pub op: MinMaxOp, - pub args: Vec, - pub location: i32, -} - -/// ROW() expression -#[derive(Debug, Clone, Default)] -pub struct RowExpr { - pub args: Vec, - pub row_format: CoercionForm, - pub colnames: Vec, - pub location: i32, -} - -// ============================================================================ -// Target/Result types -// ============================================================================ - -/// Result target (column in SELECT list or assignment target) -#[derive(Debug, Clone, Default)] -pub struct ResTarget { - pub name: String, - pub indirection: Vec, - pub val: Option, - pub location: i32, -} - -// ============================================================================ -// Table/Range types -// ============================================================================ - -/// Table/relation reference -#[derive(Debug, Clone, Default)] -pub struct RangeVar { - pub catalogname: String, - pub schemaname: String, - pub relname: String, - pub inh: bool, - pub relpersistence: String, - pub alias: Option, - pub location: i32, -} - -/// Subquery in FROM clause -#[derive(Debug, Clone, Default)] -pub struct RangeSubselect { - pub lateral: bool, - pub subquery: Option, - pub alias: Option, -} - -/// Function call in FROM clause -#[derive(Debug, Clone, Default)] -pub struct RangeFunction { - pub lateral: bool, - pub ordinality: bool, - pub is_rowsfrom: bool, - pub functions: Vec, - pub alias: Option, - pub coldeflist: Vec, -} - -/// JOIN expression -#[derive(Debug, Clone, Default)] -pub struct JoinExpr { - pub jointype: JoinType, - pub is_natural: bool, - pub larg: Option, - pub rarg: Option, - pub using_clause: Vec, - pub join_using_alias: Option, - pub quals: Option, - pub alias: Option, - pub rtindex: i32, -} - -// ============================================================================ -// Clause types -// ============================================================================ - -/// ORDER BY clause element -#[derive(Debug, Clone, Default)] -pub struct SortBy { - pub node: Option, - pub sortby_dir: SortByDir, - pub sortby_nulls: SortByNulls, - pub use_op: Vec, - pub location: i32, -} - -/// WINDOW definition -#[derive(Debug, Clone, Default)] -pub struct WindowDef { - pub name: String, - pub refname: String, - pub partition_clause: Vec, - pub order_clause: Vec, - pub frame_options: i32, - pub start_offset: Option, - pub end_offset: Option, - pub location: i32, -} - -/// WITH clause -#[derive(Debug, Clone, Default)] -pub struct WithClause { - pub ctes: Vec, - pub recursive: bool, - pub location: i32, -} - -/// Common Table Expression (CTE) -#[derive(Debug, Clone, Default)] -pub struct CommonTableExpr { - pub ctename: String, - pub aliascolnames: Vec, - pub ctematerialized: CTEMaterialize, - pub ctequery: Option, - pub search_clause: Option, - pub cycle_clause: Option, - pub location: i32, - pub cterecursive: bool, - pub cterefcount: i32, - pub ctecolnames: Vec, - pub ctecoltypes: Vec, - pub ctecoltypmods: Vec, - pub ctecolcollations: Vec, -} - -/// INTO clause for SELECT INTO -#[derive(Debug, Clone, Default)] -pub struct IntoClause { - pub rel: Option, - pub col_names: Vec, - pub access_method: String, - pub options: Vec, - pub on_commit: OnCommitAction, - pub table_space_name: String, - pub view_query: Option, - pub skip_data: bool, -} - -/// ON CONFLICT clause for INSERT -#[derive(Debug, Clone, Default)] -pub struct OnConflictClause { - pub action: OnConflictAction, - pub infer: Option, - pub target_list: Vec, - pub where_clause: Option, - pub location: i32, -} - -/// FOR UPDATE/SHARE clause -#[derive(Debug, Clone, Default)] -pub struct LockingClause { - pub locked_rels: Vec, - pub strength: LockClauseStrength, - pub wait_policy: LockWaitPolicy, -} - -/// GROUPING SETS clause element -#[derive(Debug, Clone, Default)] -pub struct GroupingSet { - pub kind: GroupingSetKind, - pub content: Vec, - pub location: i32, -} - -/// MERGE WHEN clause -#[derive(Debug, Clone, Default)] -pub struct MergeWhenClause { - pub match_kind: MergeMatchKind, - pub command_type: CmdType, - pub override_: OverridingKind, - pub condition: Option, - pub target_list: Vec, - pub values: Vec, -} - -// ============================================================================ -// Type-related -// ============================================================================ - -/// Type name -#[derive(Debug, Clone, Default)] -pub struct TypeName { - pub names: Vec, - pub type_oid: u32, - pub setof: bool, - pub pct_type: bool, - pub typmods: Vec, - pub typemod: i32, - pub array_bounds: Vec, - pub location: i32, -} - -/// Column definition -#[derive(Debug, Clone, Default)] -pub struct ColumnDef { - pub colname: String, - pub type_name: Option, - pub compression: String, - pub inhcount: i32, - pub is_local: bool, - pub is_not_null: bool, - pub is_from_type: bool, - pub storage: String, - pub storage_name: String, - pub raw_default: Option, - pub cooked_default: Option, - pub identity: String, - pub identity_sequence: Option, - pub generated: String, - pub coll_clause: Option, - pub coll_oid: u32, - pub constraints: Vec, - pub fdwoptions: Vec, - pub location: i32, -} - -/// Constraint definition -#[derive(Debug, Clone, Default)] -pub struct Constraint { - pub contype: ConstrType, - pub conname: String, - pub deferrable: bool, - pub initdeferred: bool, - pub location: i32, - pub is_no_inherit: bool, - pub raw_expr: Option, - pub cooked_expr: String, - pub generated_when: String, - pub inhcount: i32, - pub nulls_not_distinct: bool, - pub keys: Vec, - pub including: Vec, - pub exclusions: Vec, - pub options: Vec, - pub indexname: String, - pub indexspace: String, - pub reset_default_tblspc: bool, - pub access_method: String, - pub where_clause: Option, - pub pktable: Option, - pub fk_attrs: Vec, - pub pk_attrs: Vec, - pub fk_matchtype: String, - pub fk_upd_action: String, - pub fk_del_action: String, - pub fk_del_set_cols: Vec, - pub old_conpfeqop: Vec, - pub old_pktable_oid: u32, - pub skip_validation: bool, - pub initially_valid: bool, -} - -/// Definition element (generic) -#[derive(Debug, Clone, Default)] -pub struct DefElem { - pub defnamespace: String, - pub defname: String, - pub arg: Option, - pub defaction: DefElemAction, - pub location: i32, -} - -/// Index element -#[derive(Debug, Clone, Default)] -pub struct IndexElem { - pub name: String, - pub expr: Option, - pub indexcolname: String, - pub collation: Vec, - pub opclass: Vec, - pub opclassopts: Vec, - pub ordering: SortByDir, - pub nulls_ordering: SortByNulls, -} - -// ============================================================================ -// Alias and role types -// ============================================================================ - -/// Alias -#[derive(Debug, Clone, Default)] -pub struct Alias { - pub aliasname: String, - pub colnames: Vec, -} - -/// Role specification -#[derive(Debug, Clone, Default)] -pub struct RoleSpec { - pub roletype: RoleSpecType, - pub rolename: String, - pub location: i32, -} - -// ============================================================================ -// Other commonly used types -// ============================================================================ - -/// Sort/Group clause -#[derive(Debug, Clone, Default)] -pub struct SortGroupClause { - pub tle_sort_group_ref: u32, - pub eqop: u32, - pub sortop: u32, - pub nulls_first: bool, - pub hashable: bool, -} - -/// Function parameter -#[derive(Debug, Clone, Default)] -pub struct FunctionParameter { - pub name: String, - pub arg_type: Option, - pub mode: FunctionParameterMode, - pub defexpr: Option, -} - -/// ALTER TABLE command -#[derive(Debug, Clone, Default)] -pub struct AlterTableCmd { - pub subtype: AlterTableType, - pub name: String, - pub num: i16, - pub newowner: Option, - pub def: Option, - pub behavior: DropBehavior, - pub missing_ok: bool, - pub recurse: bool, -} - -/// Access privilege -#[derive(Debug, Clone, Default)] -pub struct AccessPriv { - pub priv_name: String, - pub cols: Vec, -} - -/// Object with arguments -#[derive(Debug, Clone, Default)] -pub struct ObjectWithArgs { - pub objname: Vec, - pub objargs: Vec, - pub objfuncargs: Vec, - pub args_unspecified: bool, -} - -// ============================================================================ -// Administrative statements -// ============================================================================ - -/// SET variable statement -#[derive(Debug, Clone, Default)] -pub struct VariableSetStmt { - pub kind: VariableSetKind, - pub name: String, - pub args: Vec, - pub is_local: bool, -} - -/// SHOW variable statement -#[derive(Debug, Clone, Default)] -pub struct VariableShowStmt { - pub name: String, -} - -/// EXPLAIN statement -#[derive(Debug, Clone, Default)] -pub struct ExplainStmt { - pub query: Option, - pub options: Vec, -} - -/// COPY statement -#[derive(Debug, Clone, Default)] -pub struct CopyStmt { - pub relation: Option, - pub query: Option, - pub attlist: Vec, - pub is_from: bool, - pub is_program: bool, - pub filename: String, - pub options: Vec, - pub where_clause: Option, -} - -/// GRANT/REVOKE statement -#[derive(Debug, Clone, Default)] -pub struct GrantStmt { - pub is_grant: bool, - pub targtype: GrantTargetType, - pub objtype: ObjectType, - pub objects: Vec, - pub privileges: Vec, - pub grantees: Vec, - pub grant_option: bool, - pub grantor: Option, - pub behavior: DropBehavior, -} - -/// GRANT/REVOKE role statement -#[derive(Debug, Clone, Default)] -pub struct GrantRoleStmt { - pub granted_roles: Vec, - pub grantee_roles: Vec, - pub is_grant: bool, - pub opt: Vec, - pub grantor: Option, - pub behavior: DropBehavior, -} - -/// LOCK statement -#[derive(Debug, Clone, Default)] -pub struct LockStmt { - pub relations: Vec, - pub mode: i32, - pub nowait: bool, -} - -/// VACUUM/ANALYZE statement -#[derive(Debug, Clone, Default)] -pub struct VacuumStmt { - pub options: Vec, - pub rels: Vec, - pub is_vacuumcmd: bool, -} - -// ============================================================================ -// Other statements -// ============================================================================ - -/// DO statement -#[derive(Debug, Clone, Default)] -pub struct DoStmt { - pub args: Vec, -} - -/// RENAME statement -#[derive(Debug, Clone, Default)] -pub struct RenameStmt { - pub rename_type: ObjectType, - pub relation_type: ObjectType, - pub relation: Option, - pub object: Option, - pub subname: String, - pub newname: String, - pub behavior: DropBehavior, - pub missing_ok: bool, -} - -/// NOTIFY statement -#[derive(Debug, Clone, Default)] -pub struct NotifyStmt { - pub conditionname: String, - pub payload: String, -} - -/// LISTEN statement -#[derive(Debug, Clone, Default)] -pub struct ListenStmt { - pub conditionname: String, -} - -/// UNLISTEN statement -#[derive(Debug, Clone, Default)] -pub struct UnlistenStmt { - pub conditionname: String, -} - -/// CHECKPOINT statement -#[derive(Debug, Clone, Default)] -pub struct CheckPointStmt; - -/// DISCARD statement -#[derive(Debug, Clone, Default)] -pub struct DiscardStmt { - pub target: DiscardMode, -} - -/// PREPARE statement -#[derive(Debug, Clone, Default)] -pub struct PrepareStmt { - pub name: String, - pub argtypes: Vec, - pub query: Option, -} - -/// EXECUTE statement -#[derive(Debug, Clone, Default)] -pub struct ExecuteStmt { - pub name: String, - pub params: Vec, -} - -/// DEALLOCATE statement -#[derive(Debug, Clone, Default)] -pub struct DeallocateStmt { - pub name: String, -} - -/// CLOSE cursor statement -#[derive(Debug, Clone, Default)] -pub struct ClosePortalStmt { - pub portalname: String, -} - -/// FETCH/MOVE statement -#[derive(Debug, Clone, Default)] -pub struct FetchStmt { - pub direction: FetchDirection, - pub how_many: i64, - pub portalname: String, - pub ismove: bool, -} - -// ============================================================================ -// Enums -// ============================================================================ - -/// SET operation type -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum SetOperation { - #[default] - None, - Union, - Intersect, - Except, -} - -/// LIMIT option -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum LimitOption { - #[default] - Default, - Count, - WithTies, -} - -/// A_Expr kind -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum AExprKind { - #[default] - Op, - OpAny, - OpAll, - Distinct, - NotDistinct, - NullIf, - In, - Like, - ILike, - Similar, - Between, - NotBetween, - BetweenSym, - NotBetweenSym, -} - -/// Boolean expression type -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum BoolExprType { - #[default] - And, - Or, - Not, -} - -/// Sublink type -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum SubLinkType { - #[default] - Exists, - All, - Any, - RowCompare, - Expr, - MultiExpr, - Array, - Cte, -} - -/// NULL test type -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum NullTestType { - #[default] - IsNull, - IsNotNull, -} - -/// Boolean test type -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum BoolTestType { - #[default] - IsTrue, - IsNotTrue, - IsFalse, - IsNotFalse, - IsUnknown, - IsNotUnknown, -} - -/// Min/Max operation -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum MinMaxOp { - #[default] - Greatest, - Least, -} - -/// JOIN type -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum JoinType { - #[default] - Inner, - Left, - Full, - Right, - Semi, - Anti, - RightAnti, - UniqueOuter, - UniqueInner, -} - -/// SORT BY direction -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum SortByDir { - #[default] - Default, - Asc, - Desc, - Using, -} - -/// SORT BY nulls -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum SortByNulls { - #[default] - Default, - First, - Last, -} - -/// CTE materialization -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum CTEMaterialize { - #[default] - Default, - Always, - Never, -} - -/// ON COMMIT action -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum OnCommitAction { - #[default] - Noop, - PreserveRows, - DeleteRows, - Drop, -} - -/// Object type for DDL -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum ObjectType { - #[default] - Table, - Index, - Sequence, - View, - MatView, - Type, - Schema, - Function, - Procedure, - Routine, - Aggregate, - Operator, - Language, - Cast, - Trigger, - EventTrigger, - Rule, - Database, - Tablespace, - Role, - Extension, - Fdw, - ForeignServer, - ForeignTable, - Policy, - Publication, - Subscription, - Collation, - Conversion, - Default, - Domain, - Constraint, - Column, - AccessMethod, - LargeObject, - Transform, - StatisticsObject, -} - -/// DROP behavior -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum DropBehavior { - #[default] - Restrict, - Cascade, -} - -/// ON CONFLICT action -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum OnConflictAction { - #[default] - None, - Nothing, - Update, -} - -/// GROUPING SET kind -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum GroupingSetKind { - #[default] - Empty, - Simple, - Rollup, - Cube, - Sets, -} - -/// Command type -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum CmdType { - #[default] - Unknown, - Select, - Update, - Insert, - Delete, - Merge, - Utility, - Nothing, -} - -/// MERGE match kind -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum MergeMatchKind { - #[default] - Undefined, - Matched, - NotMatchedBySource, - NotMatchedByTarget, -} - -/// Transaction statement kind -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum TransactionStmtKind { - #[default] - Begin, - Start, - Commit, - Rollback, - Savepoint, - Release, - RollbackTo, - Prepare, - CommitPrepared, - RollbackPrepared, -} - -/// Constraint type -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum ConstrType { - #[default] - Null, - NotNull, - Default, - Identity, - Generated, - Check, - Primary, - Unique, - Exclusion, - Foreign, - AttrDeferrable, - AttrNotDeferrable, - AttrDeferred, - AttrImmediate, -} - -/// DefElem action -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum DefElemAction { - #[default] - Unspec, - Set, - Add, - Drop, -} - -/// Role spec type -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum RoleSpecType { - #[default] - CString, - CurrentRole, - CurrentUser, - SessionUser, - Public, -} - -/// Coercion form -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum CoercionForm { - #[default] - ExplicitCall, - ExplicitCast, - ImplicitCast, - SqlSyntax, -} - -/// Variable SET kind -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum VariableSetKind { - #[default] - Value, - Default, - Current, - Multi, - Reset, - ResetAll, -} - -/// Lock clause strength -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum LockClauseStrength { - #[default] - None, - ForKeyShare, - ForShare, - ForNoKeyUpdate, - ForUpdate, -} - -/// Lock wait policy -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum LockWaitPolicy { - #[default] - Block, - Skip, - Error, -} - -/// View check option -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum ViewCheckOption { - #[default] - NoCheckOption, - Local, - Cascaded, -} - -/// Discard mode -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum DiscardMode { - #[default] - All, - Plans, - Sequences, - Temp, -} - -/// Fetch direction -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum FetchDirection { - #[default] - Forward, - Backward, - Absolute, - Relative, -} - -/// Function parameter mode -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum FunctionParameterMode { - #[default] - In, - Out, - InOut, - Variadic, - Table, -} - -/// ALTER TABLE command type -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum AlterTableType { - #[default] - AddColumn, - AddColumnToView, - ColumnDefault, - CookedColumnDefault, - DropNotNull, - SetNotNull, - DropExpression, - CheckNotNull, - SetStatistics, - SetOptions, - ResetOptions, - SetStorage, - SetCompression, - DropColumn, - AddIndex, - ReAddIndex, - AddConstraint, - ReAddConstraint, - AddIndexConstraint, - AlterConstraint, - ValidateConstraint, - DropConstraint, - ClusterOn, - DropCluster, - SetLogged, - SetUnLogged, - SetAccessMethod, - DropOids, - SetTableSpace, - SetRelOptions, - ResetRelOptions, - ReplaceRelOptions, - EnableTrig, - EnableAlwaysTrig, - EnableReplicaTrig, - DisableTrig, - EnableTrigAll, - DisableTrigAll, - EnableTrigUser, - DisableTrigUser, - EnableRule, - EnableAlwaysRule, - EnableReplicaRule, - DisableRule, - AddInherit, - DropInherit, - AddOf, - DropOf, - ReplicaIdentity, - EnableRowSecurity, - DisableRowSecurity, - ForceRowSecurity, - NoForceRowSecurity, - GenericOptions, - AttachPartition, - DetachPartition, - DetachPartitionFinalize, - AddIdentity, - SetIdentity, - DropIdentity, - ReAddStatistics, -} - -/// GRANT target type -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum GrantTargetType { - #[default] - Object, - AllInSchema, - Defaults, -} - -/// Overriding kind -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum OverridingKind { - #[default] - NotSet, - UserValue, - SystemValue, -} diff --git a/src/lib.rs b/src/lib.rs index a7c6405..f4e1ff6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,7 +38,6 @@ //! ``` //! -pub mod ast; mod bindings; mod bindings_raw; mod error; diff --git a/src/query.rs b/src/query.rs index eeb665f..1f8cd20 100644 --- a/src/query.rs +++ b/src/query.rs @@ -3,7 +3,6 @@ use std::os::raw::c_char; use prost::Message; -use crate::ast; use crate::bindings::*; use crate::error::*; use crate::parse_result::ParseResult; @@ -280,65 +279,3 @@ pub fn split_with_scanner(query: &str) -> Result> { unsafe { pg_query_free_split_result(result) }; split_result } - -/// Parses the given SQL statement into native Rust AST types. -/// -/// This function provides an ergonomic alternative to [`parse`] that returns -/// native Rust types instead of protobuf-generated types. The native types -/// are easier to work with as they don't require unwrapping `Option>` -/// at every level. -/// -/// # Example -/// -/// ```rust -/// use pg_query::ast::{Node, SelectStmt}; -/// -/// let result = pg_query::parse_to_ast("SELECT * FROM users WHERE id = 1").unwrap(); -/// -/// // Direct access to statements without unwrapping -/// for stmt in &result.stmts { -/// if let Node::SelectStmt(select) = &stmt.stmt { -/// // Access fields directly -/// for node in &select.from_clause { -/// if let Node::RangeVar(range_var) = node { -/// println!("Table: {}", range_var.relname); -/// } -/// } -/// } -/// } -/// ``` -pub fn parse_to_ast(statement: &str) -> Result { - let input = CString::new(statement)?; - let result = unsafe { pg_query_parse_protobuf(input.as_ptr()) }; - let parse_result = if !result.error.is_null() { - let message = unsafe { CStr::from_ptr((*result.error).message) }.to_string_lossy().to_string(); - Err(Error::Parse(message)) - } else { - let data = unsafe { std::slice::from_raw_parts(result.parse_tree.data as *const u8, result.parse_tree.len as usize) }; - protobuf::ParseResult::decode(data) - .map_err(Error::Decode) - .map(|pb| ast::ParseResult::from(pb)) - }; - unsafe { pg_query_free_protobuf_parse_result(result) }; - parse_result -} - -/// Converts a native AST parse result back into a SQL string. -/// -/// This function uses the original protobuf stored in the AST to deparse. -/// Note: Any modifications made to the AST fields will NOT be reflected -/// in the deparsed output. This function is primarily useful for round-trip -/// testing and verification. -/// -/// # Example -/// -/// ```rust -/// use pg_query::ast::Node; -/// -/// let result = pg_query::parse_to_ast("SELECT * FROM users").unwrap(); -/// let sql = pg_query::deparse_ast(&result).unwrap(); -/// assert_eq!(sql, "SELECT * FROM users"); -/// ``` -pub fn deparse_ast(parse_result: &ast::ParseResult) -> Result { - deparse(parse_result.as_protobuf()) -} diff --git a/src/raw_parse.rs b/src/raw_parse.rs index 0ad61ae..c2a02ef 100644 --- a/src/raw_parse.rs +++ b/src/raw_parse.rs @@ -27,18 +27,13 @@ pub fn parse_raw(statement: &str) -> Result { let result = unsafe { bindings_raw::pg_query_parse_raw(input.as_ptr()) }; let parse_result = if !result.error.is_null() { - let message = unsafe { CStr::from_ptr((*result.error).message) } - .to_string_lossy() - .to_string(); + let message = unsafe { CStr::from_ptr((*result.error).message) }.to_string_lossy().to_string(); Err(Error::Parse(message)) } else { // Convert the C parse tree to protobuf types let tree = result.tree; let stmts = unsafe { convert_list_to_raw_stmts(tree) }; - let protobuf = protobuf::ParseResult { - version: bindings::PG_VERSION_NUM as i32, - stmts, - }; + let protobuf = protobuf::ParseResult { version: bindings::PG_VERSION_NUM as i32, stmts }; Ok(ParseResult::new(protobuf, String::new())) }; @@ -74,11 +69,7 @@ unsafe fn convert_list_to_raw_stmts(list: *mut bindings_raw::List) -> Vec protobuf::RawStmt { - protobuf::RawStmt { - stmt: convert_node_boxed(raw_stmt.stmt), - stmt_location: raw_stmt.stmt_location, - stmt_len: raw_stmt.stmt_len, - } + protobuf::RawStmt { stmt: convert_node_boxed(raw_stmt.stmt), stmt_location: raw_stmt.stmt_location, stmt_len: raw_stmt.stmt_len } } /// Converts a C Node pointer to a boxed protobuf Node (for fields that expect Option>). @@ -204,9 +195,7 @@ unsafe fn convert_node(node_ptr: *mut bindings_raw::Node) -> Option { - Some(protobuf::node::Node::AStar(protobuf::AStar {})) - } + bindings_raw::NodeTag_T_A_Star => Some(protobuf::node::Node::AStar(protobuf::AStar {})), bindings_raw::NodeTag_T_TypeName => { let tn = node_ptr as *mut bindings_raw::TypeName; Some(protobuf::node::Node::TypeName(convert_type_name(&*tn))) @@ -225,11 +214,7 @@ unsafe fn convert_node(node_ptr: *mut bindings_raw::Node) -> Option { let f = node_ptr as *mut bindings_raw::Float; - let fval = if (*f).fval.is_null() { - String::new() - } else { - CStr::from_ptr((*f).fval).to_string_lossy().to_string() - }; + let fval = if (*f).fval.is_null() { String::new() } else { CStr::from_ptr((*f).fval).to_string_lossy().to_string() }; Some(protobuf::node::Node::Float(protobuf::Float { fval })) } bindings_raw::NodeTag_T_Boolean => { @@ -238,10 +223,7 @@ unsafe fn convert_node(node_ptr: *mut bindings_raw::Node) -> Option { let pr = node_ptr as *mut bindings_raw::ParamRef; - Some(protobuf::node::Node::ParamRef(protobuf::ParamRef { - number: (*pr).number, - location: (*pr).location, - })) + Some(protobuf::node::Node::ParamRef(protobuf::ParamRef { number: (*pr).number, location: (*pr).location })) } bindings_raw::NodeTag_T_WithClause => { let wc = node_ptr as *mut bindings_raw::WithClause; @@ -516,10 +498,7 @@ unsafe fn convert_range_var(rv: &bindings_raw::RangeVar) -> protobuf::RangeVar { } unsafe fn convert_column_ref(cr: &bindings_raw::ColumnRef) -> protobuf::ColumnRef { - protobuf::ColumnRef { - fields: convert_list_to_nodes(cr.fields), - location: cr.location, - } + protobuf::ColumnRef { fields: convert_list_to_nodes(cr.fields), location: cr.location } } unsafe fn convert_res_target(rt: &bindings_raw::ResTarget) -> protobuf::ResTarget { @@ -548,11 +527,7 @@ unsafe fn convert_a_const(aconst: &bindings_raw::A_Const) -> protobuf::AConst { // Check the node tag in the val union to determine the type let node_tag = aconst.val.node.type_; match node_tag { - bindings_raw::NodeTag_T_Integer => { - Some(protobuf::a_const::Val::Ival(protobuf::Integer { - ival: aconst.val.ival.ival, - })) - } + bindings_raw::NodeTag_T_Integer => Some(protobuf::a_const::Val::Ival(protobuf::Integer { ival: aconst.val.ival.ival })), bindings_raw::NodeTag_T_Float => { let fval = if aconst.val.fval.fval.is_null() { std::string::String::new() @@ -561,11 +536,7 @@ unsafe fn convert_a_const(aconst: &bindings_raw::A_Const) -> protobuf::AConst { }; Some(protobuf::a_const::Val::Fval(protobuf::Float { fval })) } - bindings_raw::NodeTag_T_Boolean => { - Some(protobuf::a_const::Val::Boolval(protobuf::Boolean { - boolval: aconst.val.boolval.boolval, - })) - } + bindings_raw::NodeTag_T_Boolean => Some(protobuf::a_const::Val::Boolval(protobuf::Boolean { boolval: aconst.val.boolval.boolval })), bindings_raw::NodeTag_T_String => { let sval = if aconst.val.sval.sval.is_null() { std::string::String::new() @@ -586,11 +557,7 @@ unsafe fn convert_a_const(aconst: &bindings_raw::A_Const) -> protobuf::AConst { } }; - protobuf::AConst { - isnull: aconst.isnull, - val, - location: aconst.location, - } + protobuf::AConst { isnull: aconst.isnull, val, location: aconst.location } } unsafe fn convert_func_call(fc: &bindings_raw::FuncCall) -> protobuf::FuncCall { @@ -631,10 +598,7 @@ unsafe fn convert_type_name(tn: &bindings_raw::TypeName) -> protobuf::TypeName { } unsafe fn convert_alias(alias: &bindings_raw::Alias) -> protobuf::Alias { - protobuf::Alias { - aliasname: convert_c_string(alias.aliasname), - colnames: convert_list_to_nodes(alias.colnames), - } + protobuf::Alias { aliasname: convert_c_string(alias.aliasname), colnames: convert_list_to_nodes(alias.colnames) } } unsafe fn convert_join_expr(je: &bindings_raw::JoinExpr) -> protobuf::JoinExpr { @@ -724,11 +688,7 @@ unsafe fn convert_coalesce_expr(ce: &bindings_raw::CoalesceExpr) -> protobuf::Co } unsafe fn convert_with_clause(wc: &bindings_raw::WithClause) -> protobuf::WithClause { - protobuf::WithClause { - ctes: convert_list_to_nodes(wc.ctes), - recursive: wc.recursive, - location: wc.location, - } + protobuf::WithClause { ctes: convert_list_to_nodes(wc.ctes), recursive: wc.recursive, location: wc.location } } unsafe fn convert_with_clause_opt(wc: *mut bindings_raw::WithClause) -> Option { @@ -898,9 +858,7 @@ unsafe fn convert_def_elem(de: &bindings_raw::DefElem) -> protobuf::DefElem { } unsafe fn convert_string(s: &bindings_raw::String) -> protobuf::String { - protobuf::String { - sval: convert_c_string(s.sval), - } + protobuf::String { sval: convert_c_string(s.sval) } } unsafe fn convert_locking_clause(lc: &bindings_raw::LockingClause) -> protobuf::LockingClause { @@ -924,11 +882,7 @@ unsafe fn convert_min_max_expr(mme: &bindings_raw::MinMaxExpr) -> protobuf::MinM } unsafe fn convert_grouping_set(gs: &bindings_raw::GroupingSet) -> protobuf::GroupingSet { - protobuf::GroupingSet { - kind: gs.kind as i32 + 1, - content: convert_list_to_nodes(gs.content), - location: gs.location, - } + protobuf::GroupingSet { kind: gs.kind as i32 + 1, content: convert_list_to_nodes(gs.content), location: gs.location } } unsafe fn convert_range_subselect(rs: &bindings_raw::RangeSubselect) -> protobuf::RangeSubselect { @@ -940,25 +894,15 @@ unsafe fn convert_range_subselect(rs: &bindings_raw::RangeSubselect) -> protobuf } unsafe fn convert_a_array_expr(ae: &bindings_raw::A_ArrayExpr) -> protobuf::AArrayExpr { - protobuf::AArrayExpr { - elements: convert_list_to_nodes(ae.elements), - location: ae.location, - } + protobuf::AArrayExpr { elements: convert_list_to_nodes(ae.elements), location: ae.location } } unsafe fn convert_a_indirection(ai: &bindings_raw::A_Indirection) -> protobuf::AIndirection { - protobuf::AIndirection { - arg: convert_node_boxed(ai.arg), - indirection: convert_list_to_nodes(ai.indirection), - } + protobuf::AIndirection { arg: convert_node_boxed(ai.arg), indirection: convert_list_to_nodes(ai.indirection) } } unsafe fn convert_a_indices(ai: &bindings_raw::A_Indices) -> protobuf::AIndices { - protobuf::AIndices { - is_slice: ai.is_slice, - lidx: convert_node_boxed(ai.lidx), - uidx: convert_node_boxed(ai.uidx), - } + protobuf::AIndices { is_slice: ai.is_slice, lidx: convert_node_boxed(ai.lidx), uidx: convert_node_boxed(ai.uidx) } } unsafe fn convert_alter_table_stmt(ats: &bindings_raw::AlterTableStmt) -> protobuf::AlterTableStmt { @@ -984,11 +928,7 @@ unsafe fn convert_alter_table_cmd(atc: &bindings_raw::AlterTableCmd) -> protobuf } unsafe fn convert_role_spec(rs: &bindings_raw::RoleSpec) -> protobuf::RoleSpec { - protobuf::RoleSpec { - roletype: rs.roletype as i32 + 1, - rolename: convert_c_string(rs.rolename), - location: rs.location, - } + protobuf::RoleSpec { roletype: rs.roletype as i32 + 1, rolename: convert_c_string(rs.rolename), location: rs.location } } unsafe fn convert_copy_stmt(cs: &bindings_raw::CopyStmt) -> protobuf::CopyStmt { @@ -1005,11 +945,7 @@ unsafe fn convert_copy_stmt(cs: &bindings_raw::CopyStmt) -> protobuf::CopyStmt { } unsafe fn convert_truncate_stmt(ts: &bindings_raw::TruncateStmt) -> protobuf::TruncateStmt { - protobuf::TruncateStmt { - relations: convert_list_to_nodes(ts.relations), - restart_seqs: ts.restart_seqs, - behavior: ts.behavior as i32 + 1, - } + protobuf::TruncateStmt { relations: convert_list_to_nodes(ts.relations), restart_seqs: ts.restart_seqs, behavior: ts.behavior as i32 + 1 } } unsafe fn convert_view_stmt(vs: &bindings_raw::ViewStmt) -> protobuf::ViewStmt { @@ -1024,10 +960,7 @@ unsafe fn convert_view_stmt(vs: &bindings_raw::ViewStmt) -> protobuf::ViewStmt { } unsafe fn convert_explain_stmt(es: &bindings_raw::ExplainStmt) -> protobuf::ExplainStmt { - protobuf::ExplainStmt { - query: convert_node_boxed(es.query), - options: convert_list_to_nodes(es.options), - } + protobuf::ExplainStmt { query: convert_node_boxed(es.query), options: convert_list_to_nodes(es.options) } } unsafe fn convert_create_table_as_stmt(ctas: &bindings_raw::CreateTableAsStmt) -> protobuf::CreateTableAsStmt { @@ -1041,26 +974,15 @@ unsafe fn convert_create_table_as_stmt(ctas: &bindings_raw::CreateTableAsStmt) - } unsafe fn convert_prepare_stmt(ps: &bindings_raw::PrepareStmt) -> protobuf::PrepareStmt { - protobuf::PrepareStmt { - name: convert_c_string(ps.name), - argtypes: convert_list_to_nodes(ps.argtypes), - query: convert_node_boxed(ps.query), - } + protobuf::PrepareStmt { name: convert_c_string(ps.name), argtypes: convert_list_to_nodes(ps.argtypes), query: convert_node_boxed(ps.query) } } unsafe fn convert_execute_stmt(es: &bindings_raw::ExecuteStmt) -> protobuf::ExecuteStmt { - protobuf::ExecuteStmt { - name: convert_c_string(es.name), - params: convert_list_to_nodes(es.params), - } + protobuf::ExecuteStmt { name: convert_c_string(es.name), params: convert_list_to_nodes(es.params) } } unsafe fn convert_deallocate_stmt(ds: &bindings_raw::DeallocateStmt) -> protobuf::DeallocateStmt { - protobuf::DeallocateStmt { - name: convert_c_string(ds.name), - isall: ds.isall, - location: ds.location, - } + protobuf::DeallocateStmt { name: convert_c_string(ds.name), isall: ds.isall, location: ds.location } } unsafe fn convert_set_to_default(std: &bindings_raw::SetToDefault) -> protobuf::SetToDefault { @@ -1074,11 +996,7 @@ unsafe fn convert_set_to_default(std: &bindings_raw::SetToDefault) -> protobuf:: } unsafe fn convert_multi_assign_ref(mar: &bindings_raw::MultiAssignRef) -> protobuf::MultiAssignRef { - protobuf::MultiAssignRef { - source: convert_node_boxed(mar.source), - colno: mar.colno, - ncolumns: mar.ncolumns, - } + protobuf::MultiAssignRef { source: convert_node_boxed(mar.source), colno: mar.colno, ncolumns: mar.ncolumns } } unsafe fn convert_row_expr(re: &bindings_raw::RowExpr) -> protobuf::RowExpr { @@ -1093,11 +1011,7 @@ unsafe fn convert_row_expr(re: &bindings_raw::RowExpr) -> protobuf::RowExpr { } unsafe fn convert_collate_clause(cc: &bindings_raw::CollateClause) -> protobuf::CollateClause { - protobuf::CollateClause { - arg: convert_node_boxed(cc.arg), - collname: convert_list_to_nodes(cc.collname), - location: cc.location, - } + protobuf::CollateClause { arg: convert_node_boxed(cc.arg), collname: convert_list_to_nodes(cc.collname), location: cc.location } } unsafe fn convert_collate_clause_opt(cc: *mut bindings_raw::CollateClause) -> Option> { @@ -1118,11 +1032,7 @@ unsafe fn convert_partition_spec(ps: &bindings_raw::PartitionSpec) -> protobuf:: 'h' => 3, // HASH _ => 0, // UNDEFINED }; - protobuf::PartitionSpec { - strategy, - part_params: convert_list_to_nodes(ps.partParams), - location: ps.location, - } + protobuf::PartitionSpec { strategy, part_params: convert_list_to_nodes(ps.partParams), location: ps.location } } unsafe fn convert_partition_spec_opt(ps: *mut bindings_raw::PartitionSpec) -> Option { @@ -1174,11 +1084,7 @@ unsafe fn convert_partition_range_datum(prd: &bindings_raw::PartitionRangeDatum) bindings_raw::PartitionRangeDatumKind_PARTITION_RANGE_DATUM_MAXVALUE => 3, _ => 0, }; - protobuf::PartitionRangeDatum { - kind, - value: convert_node_boxed(prd.value), - location: prd.location, - } + protobuf::PartitionRangeDatum { kind, value: convert_node_boxed(prd.value), location: prd.location } } unsafe fn convert_cte_search_clause(csc: &bindings_raw::CTESearchClause) -> protobuf::CteSearchClause { diff --git a/tests/ast_tests.rs b/tests/ast_tests.rs deleted file mode 100644 index f985d04..0000000 --- a/tests/ast_tests.rs +++ /dev/null @@ -1,374 +0,0 @@ -#![allow(non_snake_case)] -#![cfg(test)] - -use pg_query::ast::{Node, SelectStmt, InsertStmt, UpdateStmt, DeleteStmt, SetOperation, JoinType}; -use pg_query::{parse_to_ast, deparse_ast}; - -#[macro_use] -mod support; - -/// Test that parse_to_ast successfully parses a simple SELECT query -#[test] -fn it_parses_simple_select_to_ast() { - let result = parse_to_ast("SELECT * FROM users").unwrap(); - assert_eq!(result.stmts.len(), 1); - - if let Node::SelectStmt(select) = &result.stmts[0].stmt { - // Check from_clause contains users table - assert_eq!(select.from_clause.len(), 1); - if let Node::RangeVar(range_var) = &select.from_clause[0] { - assert_eq!(range_var.relname, "users"); - } else { - panic!("Expected RangeVar in from_clause"); - } - - // Check target_list contains * - assert_eq!(select.target_list.len(), 1); - if let Node::ResTarget(res_target) = &select.target_list[0] { - assert!(res_target.val.is_some()); - if let Some(Node::ColumnRef(col_ref)) = &res_target.val { - assert_eq!(col_ref.fields.len(), 1); - assert!(matches!(&col_ref.fields[0], Node::AStar(_))); - } else { - panic!("Expected ColumnRef with AStar"); - } - } else { - panic!("Expected ResTarget in target_list"); - } - } else { - panic!("Expected SelectStmt"); - } -} - -/// Test that parse_to_ast handles errors correctly -#[test] -fn it_handles_parse_errors() { - let result = parse_to_ast("SELECT * FORM users"); - assert!(result.is_err()); -} - -/// Test parsing SELECT with WHERE clause -#[test] -fn it_parses_select_with_where_clause() { - let result = parse_to_ast("SELECT id, name FROM users WHERE id = 1").unwrap(); - assert_eq!(result.stmts.len(), 1); - - if let Node::SelectStmt(select) = &result.stmts[0].stmt { - assert!(select.where_clause.is_some()); - assert_eq!(select.target_list.len(), 2); - assert_eq!(select.from_clause.len(), 1); - } else { - panic!("Expected SelectStmt"); - } -} - -/// Test parsing INSERT statement -#[test] -fn it_parses_insert_to_ast() { - let result = parse_to_ast("INSERT INTO users (name, email) VALUES ('test', 'test@example.com')").unwrap(); - assert_eq!(result.stmts.len(), 1); - - if let Node::InsertStmt(insert) = &result.stmts[0].stmt { - // Check relation - if let Some(rel) = &insert.relation { - assert_eq!(rel.relname, "users"); - } else { - panic!("Expected relation"); - } - - // Check columns - assert_eq!(insert.cols.len(), 2); - } else { - panic!("Expected InsertStmt"); - } -} - -/// Test parsing UPDATE statement -#[test] -fn it_parses_update_to_ast() { - let result = parse_to_ast("UPDATE users SET name = 'bob' WHERE id = 1").unwrap(); - assert_eq!(result.stmts.len(), 1); - - if let Node::UpdateStmt(update) = &result.stmts[0].stmt { - // Check relation - if let Some(rel) = &update.relation { - assert_eq!(rel.relname, "users"); - } else { - panic!("Expected relation"); - } - - // Check target_list (SET clause) - assert_eq!(update.target_list.len(), 1); - - // Check where_clause - assert!(update.where_clause.is_some()); - } else { - panic!("Expected UpdateStmt"); - } -} - -/// Test parsing DELETE statement -#[test] -fn it_parses_delete_to_ast() { - let result = parse_to_ast("DELETE FROM users WHERE id = 1").unwrap(); - assert_eq!(result.stmts.len(), 1); - - if let Node::DeleteStmt(delete) = &result.stmts[0].stmt { - // Check relation - if let Some(rel) = &delete.relation { - assert_eq!(rel.relname, "users"); - } else { - panic!("Expected relation"); - } - - // Check where_clause - assert!(delete.where_clause.is_some()); - } else { - panic!("Expected DeleteStmt"); - } -} - -/// Test parsing SELECT with JOIN -#[test] -fn it_parses_select_with_join() { - let result = parse_to_ast("SELECT u.id, o.total FROM users u INNER JOIN orders o ON u.id = o.user_id").unwrap(); - assert_eq!(result.stmts.len(), 1); - - if let Node::SelectStmt(select) = &result.stmts[0].stmt { - assert_eq!(select.from_clause.len(), 1); - - if let Node::JoinExpr(join) = &select.from_clause[0] { - assert_eq!(join.jointype, JoinType::Inner); - assert!(join.larg.is_some()); - assert!(join.rarg.is_some()); - assert!(join.quals.is_some()); - } else { - panic!("Expected JoinExpr in from_clause"); - } - } else { - panic!("Expected SelectStmt"); - } -} - -/// Test parsing UNION query -#[test] -fn it_parses_union_query() { - let result = parse_to_ast("SELECT id FROM users UNION SELECT id FROM admins").unwrap(); - assert_eq!(result.stmts.len(), 1); - - if let Node::SelectStmt(select) = &result.stmts[0].stmt { - assert_eq!(select.op, SetOperation::Union); - assert!(select.larg.is_some()); - assert!(select.rarg.is_some()); - } else { - panic!("Expected SelectStmt"); - } -} - -/// Test round-trip: parse to AST then deparse back to SQL -#[test] -fn it_roundtrips_simple_select() { - let original = "SELECT * FROM users"; - let ast = parse_to_ast(original).unwrap(); - let deparsed = deparse_ast(&ast).unwrap(); - assert_eq!(deparsed, original); -} - -/// Test round-trip: SELECT with WHERE clause -#[test] -fn it_roundtrips_select_with_where() { - let original = "SELECT id, name FROM users WHERE id = 1"; - let ast = parse_to_ast(original).unwrap(); - let deparsed = deparse_ast(&ast).unwrap(); - assert_eq!(deparsed, original); -} - -/// Test round-trip: INSERT statement -#[test] -fn it_roundtrips_insert() { - let original = "INSERT INTO users (name) VALUES ('test')"; - let ast = parse_to_ast(original).unwrap(); - let deparsed = deparse_ast(&ast).unwrap(); - assert_eq!(deparsed, original); -} - -/// Test round-trip: UPDATE statement -#[test] -fn it_roundtrips_update() { - let original = "UPDATE users SET name = 'bob' WHERE id = 1"; - let ast = parse_to_ast(original).unwrap(); - let deparsed = deparse_ast(&ast).unwrap(); - assert_eq!(deparsed, original); -} - -/// Test round-trip: DELETE statement -#[test] -fn it_roundtrips_delete() { - let original = "DELETE FROM users WHERE id = 1"; - let ast = parse_to_ast(original).unwrap(); - let deparsed = deparse_ast(&ast).unwrap(); - assert_eq!(deparsed, original); -} - -/// Test round-trip: SELECT with JOIN -#[test] -fn it_roundtrips_join() { - let original = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id"; - let ast = parse_to_ast(original).unwrap(); - let deparsed = deparse_ast(&ast).unwrap(); - assert_eq!(deparsed, original); -} - -/// Test round-trip: UNION query -#[test] -fn it_roundtrips_union() { - let original = "SELECT id FROM users UNION SELECT id FROM admins"; - let ast = parse_to_ast(original).unwrap(); - let deparsed = deparse_ast(&ast).unwrap(); - assert_eq!(deparsed, original); -} - -/// Test round-trip: complex SELECT -#[test] -fn it_roundtrips_complex_select() { - let original = "SELECT u.id, u.name, count(*) AS order_count FROM users u LEFT JOIN orders o ON u.id = o.user_id WHERE u.active = true GROUP BY u.id, u.name HAVING count(*) > 0 ORDER BY order_count DESC LIMIT 10"; - let ast = parse_to_ast(original).unwrap(); - let deparsed = deparse_ast(&ast).unwrap(); - assert_eq!(deparsed, original); -} - -/// Test round-trip: WITH clause (CTE) -#[test] -fn it_roundtrips_cte() { - let original = "WITH active_users AS (SELECT * FROM users WHERE active = true) SELECT * FROM active_users"; - let ast = parse_to_ast(original).unwrap(); - let deparsed = deparse_ast(&ast).unwrap(); - assert_eq!(deparsed, original); -} - -/// Test round-trip: CREATE TABLE -#[test] -fn it_roundtrips_create_table() { - let original = "CREATE TABLE test (id integer, name text)"; - let ast = parse_to_ast(original).unwrap(); - let deparsed = deparse_ast(&ast).unwrap(); - // pg_query uses "int" instead of "integer" in its canonical form - assert_eq!(deparsed, "CREATE TABLE test (id int, name text)"); -} - -/// Test round-trip: DROP TABLE -#[test] -fn it_roundtrips_drop_table() { - let original = "DROP TABLE users"; - let ast = parse_to_ast(original).unwrap(); - let deparsed = deparse_ast(&ast).unwrap(); - assert_eq!(deparsed, original); -} - -/// Test round-trip: CREATE INDEX -#[test] -fn it_roundtrips_create_index() { - let original = "CREATE INDEX idx_users_name ON users (name)"; - let ast = parse_to_ast(original).unwrap(); - let deparsed = deparse_ast(&ast).unwrap(); - // pg_query adds explicit "USING btree" in its canonical form - assert_eq!(deparsed, "CREATE INDEX idx_users_name ON users USING btree (name)"); -} - -/// Test that the AST types are ergonomic (no deep Option> unwrapping) -#[test] -fn ast_types_are_ergonomic() { - let result = parse_to_ast("SELECT id FROM users WHERE active = true").unwrap(); - - // Direct pattern matching without .as_ref().unwrap() chains - if let Node::SelectStmt(select) = &result.stmts[0].stmt { - // Direct access to from_clause vector - for table in &select.from_clause { - if let Node::RangeVar(rv) = table { - assert_eq!(rv.relname, "users"); - } - } - - // Direct access to target_list - for target in &select.target_list { - if let Node::ResTarget(rt) = target { - if let Some(Node::ColumnRef(cr)) = &rt.val { - // Can access fields directly - assert!(!cr.fields.is_empty()); - } - } - } - } -} - -/// Test parsing multiple statements -#[test] -fn it_parses_multiple_statements() { - let result = parse_to_ast("SELECT 1; SELECT 2; SELECT 3").unwrap(); - assert_eq!(result.stmts.len(), 3); - - for stmt in &result.stmts { - assert!(matches!(&stmt.stmt, Node::SelectStmt(_))); - } -} - -/// Test parsing empty query (comment only) -#[test] -fn it_parses_empty_query() { - let result = parse_to_ast("-- just a comment").unwrap(); - assert_eq!(result.stmts.len(), 0); -} - -/// Test round-trip: subquery in SELECT -#[test] -fn it_roundtrips_subquery() { - let original = "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)"; - let ast = parse_to_ast(original).unwrap(); - let deparsed = deparse_ast(&ast).unwrap(); - assert_eq!(deparsed, original); -} - -/// Test round-trip: aggregate functions -#[test] -fn it_roundtrips_aggregates() { - let original = "SELECT count(*), sum(amount), avg(price) FROM orders"; - let ast = parse_to_ast(original).unwrap(); - let deparsed = deparse_ast(&ast).unwrap(); - assert_eq!(deparsed, original); -} - -/// Test round-trip: CASE expression -#[test] -fn it_roundtrips_case_expression() { - let original = "SELECT CASE WHEN x > 0 THEN 'positive' ELSE 'non-positive' END FROM t"; - let ast = parse_to_ast(original).unwrap(); - let deparsed = deparse_ast(&ast).unwrap(); - assert_eq!(deparsed, original); -} - -/// Test round-trip: INSERT with RETURNING -#[test] -fn it_roundtrips_insert_returning() { - let original = "INSERT INTO users (name) VALUES ('test') RETURNING id"; - let ast = parse_to_ast(original).unwrap(); - let deparsed = deparse_ast(&ast).unwrap(); - assert_eq!(deparsed, original); -} - -/// Test round-trip: UPDATE with FROM -#[test] -fn it_roundtrips_update_from() { - let original = "UPDATE users SET name = o.name FROM other_users o WHERE users.id = o.id"; - let ast = parse_to_ast(original).unwrap(); - let deparsed = deparse_ast(&ast).unwrap(); - assert_eq!(deparsed, original); -} - -/// Test round-trip: DELETE with USING -#[test] -fn it_roundtrips_delete_using() { - let original = "DELETE FROM users USING orders WHERE users.id = orders.user_id"; - let ast = parse_to_ast(original).unwrap(); - let deparsed = deparse_ast(&ast).unwrap(); - assert_eq!(deparsed, original); -} diff --git a/tests/raw_parse_tests.rs b/tests/raw_parse_tests.rs index c60eb99..f6d2427 100644 --- a/tests/raw_parse_tests.rs +++ b/tests/raw_parse_tests.rs @@ -1,8 +1,8 @@ #![allow(non_snake_case)] #![cfg(test)] -use pg_query::{parse, parse_raw, Error}; use pg_query::protobuf::{a_const, node, ParseResult as ProtobufParseResult}; +use pg_query::{parse, parse_raw, Error}; #[macro_use] mod support; @@ -934,7 +934,8 @@ fn it_parses_named_window() { /// Test LAG and LEAD functions #[test] fn it_parses_lag_lead() { - let query = "SELECT date, price, LAG(price, 1) OVER (ORDER BY date) AS prev_price, LEAD(price, 1) OVER (ORDER BY date) AS next_price FROM stock_prices"; + let query = + "SELECT date, price, LAG(price, 1) OVER (ORDER BY date) AS prev_price, LEAD(price, 1) OVER (ORDER BY date) AS next_price FROM stock_prices"; let raw_result = parse_raw(query).unwrap(); let proto_result = parse(query).unwrap(); From 11815b1d928657d96c74b2f2d950a59db36ec994 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 31 Dec 2025 12:07:03 -0800 Subject: [PATCH 06/17] Export version --- src/lib.rs | 3 +++ tests/raw_parse_tests.rs | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index f4e1ff6..01bfa6b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -67,6 +67,9 @@ pub use truncate::*; pub use protobuf::Node; +/// PostgreSQL version number (e.g., 170007 for PostgreSQL 17.0.7) +pub use bindings::PG_VERSION_NUM; + // From Postgres source: src/include/storage/lockdefs.h #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] diff --git a/tests/raw_parse_tests.rs b/tests/raw_parse_tests.rs index f6d2427..f8f300c 100644 --- a/tests/raw_parse_tests.rs +++ b/tests/raw_parse_tests.rs @@ -4,6 +4,25 @@ use pg_query::protobuf::{a_const, node, ParseResult as ProtobufParseResult}; use pg_query::{parse, parse_raw, Error}; +/// Test that parse_raw results can be deparsed back to SQL +#[test] +fn it_deparses_parse_raw_result() { + let query = "SELECT * FROM users"; + let result = parse_raw(query).unwrap(); + + // Print version info for debugging + eprintln!("parse_raw protobuf version: {}", result.protobuf.version); + + // Compare with regular parse + let regular_result = parse(query).unwrap(); + eprintln!("parse protobuf version: {}", regular_result.protobuf.version); + + assert_eq!(result.protobuf.version, regular_result.protobuf.version, "Version mismatch between parse_raw and parse"); + + let deparsed = result.deparse().unwrap(); + assert_eq!(deparsed, query); +} + #[macro_use] mod support; From 9554ab369030a416eb921ac474118f011834db3d Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 31 Dec 2025 12:49:22 -0800 Subject: [PATCH 07/17] more nodes --- build.rs | 13 ++ src/raw_parse.rs | 376 ++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 386 insertions(+), 3 deletions(-) diff --git a/build.rs b/build.rs index 27658a1..c12507f 100644 --- a/build.rs +++ b/build.rs @@ -172,11 +172,24 @@ fn main() -> Result<(), Box> { .allowlist_type("CreateTableAsStmt") .allowlist_type("RefreshMatViewStmt") .allowlist_type("VacuumStmt") + .allowlist_type("VacuumRelation") + .allowlist_type("LockStmt") + .allowlist_type("AlterOwnerStmt") + .allowlist_type("AlterSeqStmt") + .allowlist_type("CreateEnumStmt") .allowlist_type("DoStmt") .allowlist_type("RenameStmt") .allowlist_type("NotifyStmt") .allowlist_type("ListenStmt") .allowlist_type("UnlistenStmt") + .allowlist_type("DiscardStmt") + .allowlist_type("CoerceToDomain") + .allowlist_type("CompositeTypeStmt") + .allowlist_type("CreateExtensionStmt") + .allowlist_type("CreatePublicationStmt") + .allowlist_type("AlterPublicationStmt") + .allowlist_type("CreateSubscriptionStmt") + .allowlist_type("AlterSubscriptionStmt") .allowlist_type("PrepareStmt") .allowlist_type("ExecuteStmt") .allowlist_type("DeallocateStmt") diff --git a/src/raw_parse.rs b/src/raw_parse.rs index c2a02ef..8509fdb 100644 --- a/src/raw_parse.rs +++ b/src/raw_parse.rs @@ -325,6 +325,126 @@ unsafe fn convert_node(node_ptr: *mut bindings_raw::Node) -> Option { + let ts = node_ptr as *mut bindings_raw::TransactionStmt; + Some(protobuf::node::Node::TransactionStmt(convert_transaction_stmt(&*ts))) + } + bindings_raw::NodeTag_T_VacuumStmt => { + let vs = node_ptr as *mut bindings_raw::VacuumStmt; + Some(protobuf::node::Node::VacuumStmt(convert_vacuum_stmt(&*vs))) + } + bindings_raw::NodeTag_T_VacuumRelation => { + let vr = node_ptr as *mut bindings_raw::VacuumRelation; + Some(protobuf::node::Node::VacuumRelation(convert_vacuum_relation(&*vr))) + } + bindings_raw::NodeTag_T_VariableSetStmt => { + let vss = node_ptr as *mut bindings_raw::VariableSetStmt; + Some(protobuf::node::Node::VariableSetStmt(convert_variable_set_stmt(&*vss))) + } + bindings_raw::NodeTag_T_VariableShowStmt => { + let vss = node_ptr as *mut bindings_raw::VariableShowStmt; + Some(protobuf::node::Node::VariableShowStmt(convert_variable_show_stmt(&*vss))) + } + bindings_raw::NodeTag_T_CreateSeqStmt => { + let css = node_ptr as *mut bindings_raw::CreateSeqStmt; + Some(protobuf::node::Node::CreateSeqStmt(convert_create_seq_stmt(&*css))) + } + bindings_raw::NodeTag_T_DoStmt => { + let ds = node_ptr as *mut bindings_raw::DoStmt; + Some(protobuf::node::Node::DoStmt(convert_do_stmt(&*ds))) + } + bindings_raw::NodeTag_T_LockStmt => { + let ls = node_ptr as *mut bindings_raw::LockStmt; + Some(protobuf::node::Node::LockStmt(convert_lock_stmt(&*ls))) + } + bindings_raw::NodeTag_T_CreateSchemaStmt => { + let css = node_ptr as *mut bindings_raw::CreateSchemaStmt; + Some(protobuf::node::Node::CreateSchemaStmt(convert_create_schema_stmt(&*css))) + } + bindings_raw::NodeTag_T_RenameStmt => { + let rs = node_ptr as *mut bindings_raw::RenameStmt; + Some(protobuf::node::Node::RenameStmt(Box::new(convert_rename_stmt(&*rs)))) + } + bindings_raw::NodeTag_T_CreateFunctionStmt => { + let cfs = node_ptr as *mut bindings_raw::CreateFunctionStmt; + Some(protobuf::node::Node::CreateFunctionStmt(Box::new(convert_create_function_stmt(&*cfs)))) + } + bindings_raw::NodeTag_T_AlterOwnerStmt => { + let aos = node_ptr as *mut bindings_raw::AlterOwnerStmt; + Some(protobuf::node::Node::AlterOwnerStmt(Box::new(convert_alter_owner_stmt(&*aos)))) + } + bindings_raw::NodeTag_T_AlterSeqStmt => { + let ass = node_ptr as *mut bindings_raw::AlterSeqStmt; + Some(protobuf::node::Node::AlterSeqStmt(convert_alter_seq_stmt(&*ass))) + } + bindings_raw::NodeTag_T_CreateEnumStmt => { + let ces = node_ptr as *mut bindings_raw::CreateEnumStmt; + Some(protobuf::node::Node::CreateEnumStmt(convert_create_enum_stmt(&*ces))) + } + bindings_raw::NodeTag_T_ObjectWithArgs => { + let owa = node_ptr as *mut bindings_raw::ObjectWithArgs; + Some(protobuf::node::Node::ObjectWithArgs(convert_object_with_args(&*owa))) + } + bindings_raw::NodeTag_T_FunctionParameter => { + let fp = node_ptr as *mut bindings_raw::FunctionParameter; + Some(protobuf::node::Node::FunctionParameter(Box::new(convert_function_parameter(&*fp)))) + } + bindings_raw::NodeTag_T_NotifyStmt => { + let ns = node_ptr as *mut bindings_raw::NotifyStmt; + Some(protobuf::node::Node::NotifyStmt(convert_notify_stmt(&*ns))) + } + bindings_raw::NodeTag_T_ListenStmt => { + let ls = node_ptr as *mut bindings_raw::ListenStmt; + Some(protobuf::node::Node::ListenStmt(convert_listen_stmt(&*ls))) + } + bindings_raw::NodeTag_T_UnlistenStmt => { + let us = node_ptr as *mut bindings_raw::UnlistenStmt; + Some(protobuf::node::Node::UnlistenStmt(convert_unlisten_stmt(&*us))) + } + bindings_raw::NodeTag_T_DiscardStmt => { + let ds = node_ptr as *mut bindings_raw::DiscardStmt; + Some(protobuf::node::Node::DiscardStmt(convert_discard_stmt(&*ds))) + } + bindings_raw::NodeTag_T_CollateClause => { + let cc = node_ptr as *mut bindings_raw::CollateClause; + Some(protobuf::node::Node::CollateClause(Box::new(convert_collate_clause(&*cc)))) + } + bindings_raw::NodeTag_T_CoerceToDomain => { + let ctd = node_ptr as *mut bindings_raw::CoerceToDomain; + Some(protobuf::node::Node::CoerceToDomain(Box::new(convert_coerce_to_domain(&*ctd)))) + } + bindings_raw::NodeTag_T_CompositeTypeStmt => { + let cts = node_ptr as *mut bindings_raw::CompositeTypeStmt; + Some(protobuf::node::Node::CompositeTypeStmt(convert_composite_type_stmt(&*cts))) + } + bindings_raw::NodeTag_T_CreateDomainStmt => { + let cds = node_ptr as *mut bindings_raw::CreateDomainStmt; + Some(protobuf::node::Node::CreateDomainStmt(Box::new(convert_create_domain_stmt(&*cds)))) + } + bindings_raw::NodeTag_T_CreateExtensionStmt => { + let ces = node_ptr as *mut bindings_raw::CreateExtensionStmt; + Some(protobuf::node::Node::CreateExtensionStmt(convert_create_extension_stmt(&*ces))) + } + bindings_raw::NodeTag_T_CreatePublicationStmt => { + let cps = node_ptr as *mut bindings_raw::CreatePublicationStmt; + Some(protobuf::node::Node::CreatePublicationStmt(convert_create_publication_stmt(&*cps))) + } + bindings_raw::NodeTag_T_AlterPublicationStmt => { + let aps = node_ptr as *mut bindings_raw::AlterPublicationStmt; + Some(protobuf::node::Node::AlterPublicationStmt(convert_alter_publication_stmt(&*aps))) + } + bindings_raw::NodeTag_T_CreateSubscriptionStmt => { + let css = node_ptr as *mut bindings_raw::CreateSubscriptionStmt; + Some(protobuf::node::Node::CreateSubscriptionStmt(convert_create_subscription_stmt(&*css))) + } + bindings_raw::NodeTag_T_AlterSubscriptionStmt => { + let ass = node_ptr as *mut bindings_raw::AlterSubscriptionStmt; + Some(protobuf::node::Node::AlterSubscriptionStmt(convert_alter_subscription_stmt(&*ass))) + } + bindings_raw::NodeTag_T_CreateTrigStmt => { + let cts = node_ptr as *mut bindings_raw::CreateTrigStmt; + Some(protobuf::node::Node::CreateTrigStmt(Box::new(convert_create_trig_stmt(&*cts)))) + } _ => { // For unhandled node types, return None // In the future, we could add more node types here @@ -342,6 +462,8 @@ unsafe fn convert_list(list: &bindings_raw::List) -> protobuf::List { } /// Converts a PostgreSQL List pointer to a Vec of protobuf Nodes. +/// Note: Preserves placeholder nodes (Node { node: None }) for cases like DISTINCT +/// where the list must retain its structure even if content is not recognized. unsafe fn convert_list_to_nodes(list: *mut bindings_raw::List) -> Vec { if list.is_null() { return Vec::new(); @@ -355,9 +477,11 @@ unsafe fn convert_list_to_nodes(list: *mut bindings_raw::List) -> Vec protobuf::TransactionStmt { + protobuf::TransactionStmt { + kind: ts.kind as i32 + 1, // Protobuf enums have UNDEFINED=0 + options: convert_list_to_nodes(ts.options), + savepoint_name: convert_c_string(ts.savepoint_name), + gid: convert_c_string(ts.gid), + chain: ts.chain, + location: ts.location, + } +} + +unsafe fn convert_vacuum_stmt(vs: &bindings_raw::VacuumStmt) -> protobuf::VacuumStmt { + protobuf::VacuumStmt { options: convert_list_to_nodes(vs.options), rels: convert_list_to_nodes(vs.rels), is_vacuumcmd: vs.is_vacuumcmd } +} + +unsafe fn convert_vacuum_relation(vr: &bindings_raw::VacuumRelation) -> protobuf::VacuumRelation { + protobuf::VacuumRelation { + relation: if vr.relation.is_null() { None } else { Some(convert_range_var(&*vr.relation)) }, + oid: vr.oid, + va_cols: convert_list_to_nodes(vr.va_cols), + } +} + +unsafe fn convert_variable_set_stmt(vss: &bindings_raw::VariableSetStmt) -> protobuf::VariableSetStmt { + protobuf::VariableSetStmt { + kind: vss.kind as i32 + 1, // Protobuf enums have UNDEFINED=0 + name: convert_c_string(vss.name), + args: convert_list_to_nodes(vss.args), + is_local: vss.is_local, + } +} + +unsafe fn convert_variable_show_stmt(vss: &bindings_raw::VariableShowStmt) -> protobuf::VariableShowStmt { + protobuf::VariableShowStmt { name: convert_c_string(vss.name) } +} + +unsafe fn convert_create_seq_stmt(css: &bindings_raw::CreateSeqStmt) -> protobuf::CreateSeqStmt { + protobuf::CreateSeqStmt { + sequence: if css.sequence.is_null() { None } else { Some(convert_range_var(&*css.sequence)) }, + options: convert_list_to_nodes(css.options), + owner_id: css.ownerId, + for_identity: css.for_identity, + if_not_exists: css.if_not_exists, + } +} + +unsafe fn convert_do_stmt(ds: &bindings_raw::DoStmt) -> protobuf::DoStmt { + protobuf::DoStmt { args: convert_list_to_nodes(ds.args) } +} + +unsafe fn convert_lock_stmt(ls: &bindings_raw::LockStmt) -> protobuf::LockStmt { + protobuf::LockStmt { relations: convert_list_to_nodes(ls.relations), mode: ls.mode, nowait: ls.nowait } +} + +unsafe fn convert_create_schema_stmt(css: &bindings_raw::CreateSchemaStmt) -> protobuf::CreateSchemaStmt { + protobuf::CreateSchemaStmt { + schemaname: convert_c_string(css.schemaname), + authrole: if css.authrole.is_null() { None } else { Some(convert_role_spec(&*css.authrole)) }, + schema_elts: convert_list_to_nodes(css.schemaElts), + if_not_exists: css.if_not_exists, + } +} + +unsafe fn convert_rename_stmt(rs: &bindings_raw::RenameStmt) -> protobuf::RenameStmt { + protobuf::RenameStmt { + rename_type: rs.renameType as i32 + 1, // Protobuf ObjectType has UNDEFINED=0 + relation_type: rs.relationType as i32 + 1, + relation: if rs.relation.is_null() { None } else { Some(convert_range_var(&*rs.relation)) }, + object: convert_node_boxed(rs.object), + subname: convert_c_string(rs.subname), + newname: convert_c_string(rs.newname), + behavior: rs.behavior as i32 + 1, + missing_ok: rs.missing_ok, + } +} + +unsafe fn convert_create_function_stmt(cfs: &bindings_raw::CreateFunctionStmt) -> protobuf::CreateFunctionStmt { + protobuf::CreateFunctionStmt { + is_procedure: cfs.is_procedure, + replace: cfs.replace, + funcname: convert_list_to_nodes(cfs.funcname), + parameters: convert_list_to_nodes(cfs.parameters), + return_type: if cfs.returnType.is_null() { None } else { Some(convert_type_name(&*cfs.returnType)) }, + options: convert_list_to_nodes(cfs.options), + sql_body: convert_node_boxed(cfs.sql_body), + } +} + +unsafe fn convert_alter_owner_stmt(aos: &bindings_raw::AlterOwnerStmt) -> protobuf::AlterOwnerStmt { + protobuf::AlterOwnerStmt { + object_type: aos.objectType as i32 + 1, // Protobuf ObjectType has UNDEFINED=0 + relation: if aos.relation.is_null() { None } else { Some(convert_range_var(&*aos.relation)) }, + object: convert_node_boxed(aos.object), + newowner: if aos.newowner.is_null() { None } else { Some(convert_role_spec(&*aos.newowner)) }, + } +} + +unsafe fn convert_alter_seq_stmt(ass: &bindings_raw::AlterSeqStmt) -> protobuf::AlterSeqStmt { + protobuf::AlterSeqStmt { + sequence: if ass.sequence.is_null() { None } else { Some(convert_range_var(&*ass.sequence)) }, + options: convert_list_to_nodes(ass.options), + for_identity: ass.for_identity, + missing_ok: ass.missing_ok, + } +} + +unsafe fn convert_create_enum_stmt(ces: &bindings_raw::CreateEnumStmt) -> protobuf::CreateEnumStmt { + protobuf::CreateEnumStmt { type_name: convert_list_to_nodes(ces.typeName), vals: convert_list_to_nodes(ces.vals) } +} + +unsafe fn convert_object_with_args(owa: &bindings_raw::ObjectWithArgs) -> protobuf::ObjectWithArgs { + protobuf::ObjectWithArgs { + objname: convert_list_to_nodes(owa.objname), + objargs: convert_list_to_nodes(owa.objargs), + objfuncargs: convert_list_to_nodes(owa.objfuncargs), + args_unspecified: owa.args_unspecified, + } +} + +unsafe fn convert_function_parameter(fp: &bindings_raw::FunctionParameter) -> protobuf::FunctionParameter { + protobuf::FunctionParameter { + name: convert_c_string(fp.name), + arg_type: if fp.argType.is_null() { None } else { Some(convert_type_name(&*fp.argType)) }, + mode: fp.mode as i32 + 1, // Protobuf FunctionParameterMode has UNDEFINED=0 + defexpr: convert_node_boxed(fp.defexpr), + } +} + +unsafe fn convert_notify_stmt(ns: &bindings_raw::NotifyStmt) -> protobuf::NotifyStmt { + protobuf::NotifyStmt { conditionname: convert_c_string(ns.conditionname), payload: convert_c_string(ns.payload) } +} + +unsafe fn convert_listen_stmt(ls: &bindings_raw::ListenStmt) -> protobuf::ListenStmt { + protobuf::ListenStmt { conditionname: convert_c_string(ls.conditionname) } +} + +unsafe fn convert_unlisten_stmt(us: &bindings_raw::UnlistenStmt) -> protobuf::UnlistenStmt { + protobuf::UnlistenStmt { conditionname: convert_c_string(us.conditionname) } +} + +unsafe fn convert_discard_stmt(ds: &bindings_raw::DiscardStmt) -> protobuf::DiscardStmt { + protobuf::DiscardStmt { + target: ds.target as i32 + 1, // DiscardMode enum + } +} + +unsafe fn convert_coerce_to_domain(ctd: &bindings_raw::CoerceToDomain) -> protobuf::CoerceToDomain { + // xpr is an embedded Expr, convert it as a node pointer + let xpr_ptr = &ctd.xpr as *const bindings_raw::Expr as *mut bindings_raw::Node; + protobuf::CoerceToDomain { + xpr: convert_node_boxed(xpr_ptr), + arg: convert_node_boxed(ctd.arg as *mut bindings_raw::Node), + resulttype: ctd.resulttype, + resulttypmod: ctd.resulttypmod, + resultcollid: ctd.resultcollid, + coercionformat: ctd.coercionformat as i32 + 1, + location: ctd.location, + } +} + +unsafe fn convert_composite_type_stmt(cts: &bindings_raw::CompositeTypeStmt) -> protobuf::CompositeTypeStmt { + protobuf::CompositeTypeStmt { + typevar: if cts.typevar.is_null() { None } else { Some(convert_range_var(&*cts.typevar)) }, + coldeflist: convert_list_to_nodes(cts.coldeflist), + } +} + +unsafe fn convert_create_domain_stmt(cds: &bindings_raw::CreateDomainStmt) -> protobuf::CreateDomainStmt { + protobuf::CreateDomainStmt { + domainname: convert_list_to_nodes(cds.domainname), + type_name: if cds.typeName.is_null() { None } else { Some(convert_type_name(&*cds.typeName)) }, + coll_clause: convert_collate_clause_opt(cds.collClause), + constraints: convert_list_to_nodes(cds.constraints), + } +} + +unsafe fn convert_create_extension_stmt(ces: &bindings_raw::CreateExtensionStmt) -> protobuf::CreateExtensionStmt { + protobuf::CreateExtensionStmt { + extname: convert_c_string(ces.extname), + if_not_exists: ces.if_not_exists, + options: convert_list_to_nodes(ces.options), + } +} + +unsafe fn convert_create_publication_stmt(cps: &bindings_raw::CreatePublicationStmt) -> protobuf::CreatePublicationStmt { + protobuf::CreatePublicationStmt { + pubname: convert_c_string(cps.pubname), + options: convert_list_to_nodes(cps.options), + pubobjects: convert_list_to_nodes(cps.pubobjects), + for_all_tables: cps.for_all_tables, + } +} + +unsafe fn convert_alter_publication_stmt(aps: &bindings_raw::AlterPublicationStmt) -> protobuf::AlterPublicationStmt { + protobuf::AlterPublicationStmt { + pubname: convert_c_string(aps.pubname), + options: convert_list_to_nodes(aps.options), + pubobjects: convert_list_to_nodes(aps.pubobjects), + for_all_tables: aps.for_all_tables, + action: aps.action as i32 + 1, + } +} + +unsafe fn convert_create_subscription_stmt(css: &bindings_raw::CreateSubscriptionStmt) -> protobuf::CreateSubscriptionStmt { + protobuf::CreateSubscriptionStmt { + subname: convert_c_string(css.subname), + conninfo: convert_c_string(css.conninfo), + publication: convert_list_to_nodes(css.publication), + options: convert_list_to_nodes(css.options), + } +} + +unsafe fn convert_alter_subscription_stmt(ass: &bindings_raw::AlterSubscriptionStmt) -> protobuf::AlterSubscriptionStmt { + protobuf::AlterSubscriptionStmt { + kind: ass.kind as i32 + 1, + subname: convert_c_string(ass.subname), + conninfo: convert_c_string(ass.conninfo), + publication: convert_list_to_nodes(ass.publication), + options: convert_list_to_nodes(ass.options), + } +} + +unsafe fn convert_create_trig_stmt(cts: &bindings_raw::CreateTrigStmt) -> protobuf::CreateTrigStmt { + protobuf::CreateTrigStmt { + replace: cts.replace, + isconstraint: cts.isconstraint, + trigname: convert_c_string(cts.trigname), + relation: if cts.relation.is_null() { None } else { Some(convert_range_var(&*cts.relation)) }, + funcname: convert_list_to_nodes(cts.funcname), + args: convert_list_to_nodes(cts.args), + row: cts.row, + timing: cts.timing as i32, + events: cts.events as i32, + columns: convert_list_to_nodes(cts.columns), + when_clause: convert_node_boxed(cts.whenClause), + transition_rels: convert_list_to_nodes(cts.transitionRels), + deferrable: cts.deferrable, + initdeferred: cts.initdeferred, + constrrel: if cts.constrrel.is_null() { None } else { Some(convert_range_var(&*cts.constrrel)) }, + } +} + // ============================================================================ // Utility Functions // ============================================================================ From 4e2e92a687d30a1c8e9f93333d6a0835bff1b7ec Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 31 Dec 2025 13:14:53 -0800 Subject: [PATCH 08/17] test refactor --- build.rs | 3 + src/raw_parse.rs | 22 + tests/raw_parse/basic.rs | 137 ++ tests/raw_parse/ddl.rs | 423 ++++++ tests/raw_parse/dml.rs | 493 +++++++ tests/raw_parse/expressions.rs | 470 +++++++ tests/raw_parse/mod.rs | 43 + tests/raw_parse/select.rs | 855 ++++++++++++ tests/raw_parse/statements.rs | 450 +++++++ tests/raw_parse_tests.rs | 2284 +------------------------------- 10 files changed, 2912 insertions(+), 2268 deletions(-) create mode 100644 tests/raw_parse/basic.rs create mode 100644 tests/raw_parse/ddl.rs create mode 100644 tests/raw_parse/dml.rs create mode 100644 tests/raw_parse/expressions.rs create mode 100644 tests/raw_parse/mod.rs create mode 100644 tests/raw_parse/select.rs create mode 100644 tests/raw_parse/statements.rs diff --git a/build.rs b/build.rs index c12507f..2accb5d 100644 --- a/build.rs +++ b/build.rs @@ -188,6 +188,9 @@ fn main() -> Result<(), Box> { .allowlist_type("CreateExtensionStmt") .allowlist_type("CreatePublicationStmt") .allowlist_type("AlterPublicationStmt") + .allowlist_type("PublicationObjSpec") + .allowlist_type("PublicationTable") + .allowlist_type("PublicationObjSpecType") .allowlist_type("CreateSubscriptionStmt") .allowlist_type("AlterSubscriptionStmt") .allowlist_type("PrepareStmt") diff --git a/src/raw_parse.rs b/src/raw_parse.rs index 8509fdb..4fbf13e 100644 --- a/src/raw_parse.rs +++ b/src/raw_parse.rs @@ -445,6 +445,14 @@ unsafe fn convert_node(node_ptr: *mut bindings_raw::Node) -> Option { + let pos = node_ptr as *mut bindings_raw::PublicationObjSpec; + Some(protobuf::node::Node::PublicationObjSpec(Box::new(convert_publication_obj_spec(&*pos)))) + } + bindings_raw::NodeTag_T_PublicationTable => { + let pt = node_ptr as *mut bindings_raw::PublicationTable; + Some(protobuf::node::Node::PublicationTable(Box::new(convert_publication_table(&*pt)))) + } _ => { // For unhandled node types, return None // In the future, we could add more node types here @@ -1477,6 +1485,20 @@ unsafe fn convert_alter_subscription_stmt(ass: &bindings_raw::AlterSubscriptionS } } +unsafe fn convert_publication_obj_spec(pos: &bindings_raw::PublicationObjSpec) -> protobuf::PublicationObjSpec { + let pubtable = if pos.pubtable.is_null() { None } else { Some(Box::new(convert_publication_table(&*pos.pubtable))) }; + protobuf::PublicationObjSpec { pubobjtype: pos.pubobjtype as i32 + 1, name: convert_c_string(pos.name), pubtable, location: pos.location } +} + +unsafe fn convert_publication_table(pt: &bindings_raw::PublicationTable) -> protobuf::PublicationTable { + let relation = if pt.relation.is_null() { None } else { Some(convert_range_var(&*pt.relation)) }; + protobuf::PublicationTable { + relation, + where_clause: convert_node_boxed(pt.whereClause as *mut bindings_raw::Node), + columns: convert_list_to_nodes(pt.columns), + } +} + unsafe fn convert_create_trig_stmt(cts: &bindings_raw::CreateTrigStmt) -> protobuf::CreateTrigStmt { protobuf::CreateTrigStmt { replace: cts.replace, diff --git a/tests/raw_parse/basic.rs b/tests/raw_parse/basic.rs new file mode 100644 index 0000000..94f42dd --- /dev/null +++ b/tests/raw_parse/basic.rs @@ -0,0 +1,137 @@ +//! Basic parsing tests for parse_raw. +//! +//! These tests verify fundamental parsing behavior including: +//! - Simple SELECT queries +//! - Error handling +//! - Multiple statements +//! - Empty queries + +use super::*; + +/// Test that parse_raw results can be deparsed back to SQL +#[test] +fn it_deparses_parse_raw_result() { + let query = "SELECT * FROM users"; + let result = parse_raw(query).unwrap(); + + let deparsed = result.deparse().unwrap(); + assert_eq!(deparsed, query); +} + +/// Test that parse_raw successfully parses a simple SELECT query +#[test] +fn it_parses_simple_select() { + let query = "SELECT 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf.stmts.len(), 1); + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test that parse_raw handles syntax errors +#[test] +fn it_handles_parse_errors() { + let query = "SELECT * FORM users"; + let raw_error = parse_raw(query).err().unwrap(); + let proto_error = parse(query).err().unwrap(); + + assert!(matches!(raw_error, Error::Parse(_))); + assert!(matches!(proto_error, Error::Parse(_))); +} + +/// Test that parse_raw and parse produce equivalent results for simple SELECT +#[test] +fn it_matches_parse_for_simple_select() { + let query = "SELECT 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test that parse_raw and parse produce equivalent results for SELECT with table +#[test] +fn it_matches_parse_for_select_from_table() { + let query = "SELECT * FROM users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let mut raw_tables = raw_result.tables(); + let mut proto_tables = proto_result.tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["users"]); +} + +/// Test that parse_raw handles empty queries (comments only) +#[test] +fn it_handles_empty_queries() { + let query = "-- just a comment"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf.stmts.len(), 0); + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test that parse_raw parses multiple statements +#[test] +fn it_parses_multiple_statements() { + let query = "SELECT 1; SELECT 2; SELECT 3"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf.stmts.len(), 3); + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test that tables() returns the same results for both parsers +#[test] +fn it_returns_tables_like_parse() { + let query = "SELECT u.*, o.* FROM users u JOIN orders o ON u.id = o.user_id WHERE o.status = 'active'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let mut raw_tables = raw_result.tables(); + let mut proto_tables = proto_result.tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["orders", "users"]); +} + +/// Test that functions() returns the same results for both parsers +#[test] +fn it_returns_functions_like_parse() { + let query = "SELECT count(*), sum(amount) FROM orders"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let mut raw_funcs = raw_result.functions(); + let mut proto_funcs = proto_result.functions(); + raw_funcs.sort(); + proto_funcs.sort(); + assert_eq!(raw_funcs, proto_funcs); + assert_eq!(raw_funcs, vec!["count", "sum"]); +} + +/// Test that statement_types() returns the same results for both parsers +#[test] +fn it_returns_statement_types_like_parse() { + let query = "SELECT 1; INSERT INTO t VALUES (1); UPDATE t SET x = 1; DELETE FROM t"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + assert_eq!(raw_result.statement_types(), proto_result.statement_types()); + assert_eq!(raw_result.statement_types(), vec!["SelectStmt", "InsertStmt", "UpdateStmt", "DeleteStmt"]); +} diff --git a/tests/raw_parse/ddl.rs b/tests/raw_parse/ddl.rs new file mode 100644 index 0000000..3694181 --- /dev/null +++ b/tests/raw_parse/ddl.rs @@ -0,0 +1,423 @@ +//! DDL statement tests (CREATE, ALTER, DROP, etc.). +//! +//! These tests verify parse_raw correctly handles data definition language statements. + +use super::*; + +// ============================================================================ +// Basic DDL tests +// ============================================================================ + +/// Test parsing CREATE TABLE +#[test] +fn it_parses_create_table() { + let query = "CREATE TABLE test (id int, name text)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + assert_eq!(raw_result.statement_types(), proto_result.statement_types()); + assert_eq!(raw_result.statement_types(), vec!["CreateStmt"]); +} + +/// Test parsing DROP TABLE +#[test] +fn it_parses_drop_table() { + let query = "DROP TABLE users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let mut raw_tables = raw_result.ddl_tables(); + let mut proto_tables = proto_result.ddl_tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["users"]); +} + +/// Test parsing CREATE INDEX +#[test] +fn it_parses_create_index() { + let query = "CREATE INDEX idx_users_name ON users (name)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + assert_eq!(raw_result.statement_types(), proto_result.statement_types()); + assert_eq!(raw_result.statement_types(), vec!["IndexStmt"]); +} + +/// Test CREATE TABLE with constraints +#[test] +fn it_parses_create_table_with_constraints() { + let query = "CREATE TABLE orders ( + id SERIAL PRIMARY KEY, + user_id INTEGER NOT NULL REFERENCES users(id), + amount DECIMAL(10, 2) CHECK (amount > 0), + status TEXT DEFAULT 'pending', + created_at TIMESTAMP DEFAULT NOW(), + UNIQUE (user_id, created_at) + )"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CREATE TABLE AS +#[test] +fn it_parses_create_table_as() { + let query = "CREATE TABLE active_users AS SELECT * FROM users WHERE active = true"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CREATE VIEW +#[test] +fn it_parses_create_view() { + let query = "CREATE VIEW active_users AS SELECT id, name FROM users WHERE active = true"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CREATE MATERIALIZED VIEW +#[test] +fn it_parses_create_materialized_view() { + let query = "CREATE MATERIALIZED VIEW monthly_sales AS SELECT date_trunc('month', created_at) AS month, SUM(amount) FROM orders GROUP BY 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// ALTER TABLE tests +// ============================================================================ + +/// Test ALTER TABLE ADD COLUMN +#[test] +fn it_parses_alter_table_add_column() { + let query = "ALTER TABLE users ADD COLUMN email TEXT NOT NULL"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ALTER TABLE DROP COLUMN +#[test] +fn it_parses_alter_table_drop_column() { + let query = "ALTER TABLE users DROP COLUMN deprecated_field"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ALTER TABLE ADD CONSTRAINT +#[test] +fn it_parses_alter_table_add_constraint() { + let query = "ALTER TABLE orders ADD CONSTRAINT fk_user FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ALTER TABLE RENAME +#[test] +fn it_parses_alter_table_rename() { + let query = "ALTER TABLE users RENAME TO customers"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ALTER TABLE RENAME COLUMN +#[test] +fn it_parses_alter_table_rename_column() { + let query = "ALTER TABLE users RENAME COLUMN name TO full_name"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ALTER TABLE OWNER +#[test] +fn it_parses_alter_owner() { + let query = "ALTER TABLE users OWNER TO postgres"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// INDEX tests +// ============================================================================ + +/// Test CREATE INDEX with expression +#[test] +fn it_parses_create_index_expression() { + let query = "CREATE INDEX idx_lower_email ON users (lower(email))"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CREATE UNIQUE INDEX with WHERE +#[test] +fn it_parses_partial_unique_index() { + let query = "CREATE UNIQUE INDEX idx_active_email ON users (email) WHERE active = true"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CREATE INDEX CONCURRENTLY +#[test] +fn it_parses_create_index_concurrently() { + let query = "CREATE INDEX CONCURRENTLY idx_name ON users (name)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// TRUNCATE test +// ============================================================================ + +/// Test TRUNCATE +#[test] +fn it_parses_truncate() { + let query = "TRUNCATE TABLE logs, audit_logs RESTART IDENTITY CASCADE"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Sequence tests +// ============================================================================ + +/// Test CREATE SEQUENCE +#[test] +fn it_parses_create_sequence() { + let query = "CREATE SEQUENCE my_seq"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CREATE SEQUENCE with options +#[test] +fn it_parses_create_sequence_with_options() { + let query = "CREATE SEQUENCE my_seq START WITH 100 INCREMENT BY 10 MINVALUE 1 MAXVALUE 1000 CYCLE"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CREATE SEQUENCE IF NOT EXISTS +#[test] +fn it_parses_create_sequence_if_not_exists() { + let query = "CREATE SEQUENCE IF NOT EXISTS my_seq"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ALTER SEQUENCE +#[test] +fn it_parses_alter_sequence() { + let query = "ALTER SEQUENCE my_seq RESTART WITH 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Domain tests +// ============================================================================ + +/// Test CREATE DOMAIN +#[test] +fn it_parses_create_domain() { + let query = "CREATE DOMAIN positive_int AS INTEGER CHECK (VALUE > 0)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CREATE DOMAIN with NOT NULL +#[test] +fn it_parses_create_domain_not_null() { + let query = "CREATE DOMAIN non_empty_text AS TEXT NOT NULL CHECK (VALUE <> '')"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CREATE DOMAIN with DEFAULT +#[test] +fn it_parses_create_domain_default() { + let query = "CREATE DOMAIN my_text AS TEXT DEFAULT 'unknown'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Type tests +// ============================================================================ + +/// Test CREATE TYPE AS composite +#[test] +fn it_parses_create_composite_type() { + let query = "CREATE TYPE address AS (street TEXT, city TEXT, zip TEXT)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CREATE TYPE AS ENUM +#[test] +fn it_parses_create_enum_type() { + let query = "CREATE TYPE status AS ENUM ('pending', 'approved', 'rejected')"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Extension tests +// ============================================================================ + +/// Test CREATE EXTENSION +#[test] +fn it_parses_create_extension() { + let query = "CREATE EXTENSION IF NOT EXISTS pg_stat_statements"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CREATE EXTENSION with schema +#[test] +fn it_parses_create_extension_with_schema() { + let query = "CREATE EXTENSION hstore WITH SCHEMA public"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Publication and Subscription tests +// ============================================================================ + +/// Test CREATE PUBLICATION +#[test] +fn it_parses_create_publication() { + let query = "CREATE PUBLICATION my_pub FOR ALL TABLES"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CREATE PUBLICATION for specific tables +#[test] +fn it_parses_create_publication_for_tables() { + let query = "CREATE PUBLICATION my_pub FOR TABLE users, orders"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ALTER PUBLICATION +#[test] +fn it_parses_alter_publication() { + let query = "ALTER PUBLICATION my_pub ADD TABLE products"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CREATE SUBSCRIPTION +#[test] +fn it_parses_create_subscription() { + let query = "CREATE SUBSCRIPTION my_sub CONNECTION 'host=localhost dbname=mydb' PUBLICATION my_pub"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ALTER SUBSCRIPTION +#[test] +fn it_parses_alter_subscription() { + let query = "ALTER SUBSCRIPTION my_sub DISABLE"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Trigger tests +// ============================================================================ + +/// Test CREATE TRIGGER +#[test] +fn it_parses_create_trigger() { + let query = "CREATE TRIGGER my_trigger BEFORE INSERT ON users FOR EACH ROW EXECUTE FUNCTION my_func()"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CREATE TRIGGER AFTER UPDATE +#[test] +fn it_parses_create_trigger_after_update() { + let query = "CREATE TRIGGER audit_trigger AFTER UPDATE ON users FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) EXECUTE FUNCTION audit_log()"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CREATE CONSTRAINT TRIGGER +#[test] +fn it_parses_create_constraint_trigger() { + let query = "CREATE CONSTRAINT TRIGGER check_balance AFTER INSERT OR UPDATE ON accounts DEFERRABLE INITIALLY DEFERRED FOR EACH ROW EXECUTE FUNCTION check_balance()"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} diff --git a/tests/raw_parse/dml.rs b/tests/raw_parse/dml.rs new file mode 100644 index 0000000..d30eadc --- /dev/null +++ b/tests/raw_parse/dml.rs @@ -0,0 +1,493 @@ +//! DML statement tests (INSERT, UPDATE, DELETE). +//! +//! These tests verify parse_raw correctly handles data manipulation language statements. + +use super::*; + +// ============================================================================ +// Basic DML tests +// ============================================================================ + +/// Test parsing INSERT statement +#[test] +fn it_parses_insert() { + let query = "INSERT INTO users (name) VALUES ('test')"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let mut raw_tables = raw_result.dml_tables(); + let mut proto_tables = proto_result.dml_tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["users"]); +} + +/// Test parsing UPDATE statement +#[test] +fn it_parses_update() { + let query = "UPDATE users SET name = 'bob' WHERE id = 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let mut raw_tables = raw_result.dml_tables(); + let mut proto_tables = proto_result.dml_tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["users"]); +} + +/// Test parsing DELETE statement +#[test] +fn it_parses_delete() { + let query = "DELETE FROM users WHERE id = 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let mut raw_tables = raw_result.dml_tables(); + let mut proto_tables = proto_result.dml_tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["users"]); +} + +// ============================================================================ +// INSERT variations +// ============================================================================ + +/// Test parsing INSERT with ON CONFLICT +#[test] +fn it_parses_insert_on_conflict() { + let query = "INSERT INTO users (id, name) VALUES (1, 'test') ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let raw_tables = raw_result.dml_tables(); + let proto_tables = proto_result.dml_tables(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["users"]); +} + +/// Test parsing INSERT with RETURNING +#[test] +fn it_parses_insert_returning() { + let query = "INSERT INTO users (name) VALUES ('test') RETURNING id"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test INSERT with multiple tuples +#[test] +fn it_parses_insert_multiple_rows() { + let query = "INSERT INTO users (name, email, age) VALUES ('Alice', 'alice@example.com', 25), ('Bob', 'bob@example.com', 30), ('Charlie', 'charlie@example.com', 35)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test INSERT ... SELECT +#[test] +fn it_parses_insert_select() { + let query = "INSERT INTO archived_users (id, name, email) SELECT id, name, email FROM users WHERE deleted_at IS NOT NULL"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test INSERT ... SELECT with complex query +#[test] +fn it_parses_insert_select_complex() { + let query = "INSERT INTO monthly_stats (month, user_count, order_count, total_revenue) + SELECT date_trunc('month', created_at) AS month, + COUNT(DISTINCT user_id), + COUNT(*), + SUM(amount) + FROM orders + WHERE created_at >= '2023-01-01' + GROUP BY date_trunc('month', created_at)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test INSERT with CTE +#[test] +fn it_parses_insert_with_cte() { + let query = "WITH new_data AS ( + SELECT name, email FROM temp_imports WHERE valid = true + ) + INSERT INTO users (name, email) SELECT name, email FROM new_data"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test INSERT with DEFAULT values +#[test] +fn it_parses_insert_default_values() { + let query = "INSERT INTO users (name, created_at) VALUES ('test', DEFAULT)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test INSERT with ON CONFLICT DO NOTHING +#[test] +fn it_parses_insert_on_conflict_do_nothing() { + let query = "INSERT INTO users (id, name) VALUES (1, 'test') ON CONFLICT (id) DO NOTHING"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test INSERT with ON CONFLICT with WHERE clause +#[test] +fn it_parses_insert_on_conflict_with_where() { + let query = "INSERT INTO users (id, name, updated_at) VALUES (1, 'test', NOW()) + ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, updated_at = EXCLUDED.updated_at + WHERE users.updated_at < EXCLUDED.updated_at"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test INSERT with multiple columns in ON CONFLICT +#[test] +fn it_parses_insert_on_conflict_multiple_columns() { + let query = "INSERT INTO user_settings (user_id, key, value) VALUES (1, 'theme', 'dark') + ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test INSERT with RETURNING multiple columns +#[test] +fn it_parses_insert_returning_multiple() { + let query = "INSERT INTO users (name, email) VALUES ('test', 'test@example.com') RETURNING id, created_at, name"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test INSERT with subquery in VALUES +#[test] +fn it_parses_insert_with_subquery_value() { + let query = "INSERT INTO orders (user_id, total) VALUES ((SELECT id FROM users WHERE email = 'test@example.com'), 100.00)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test INSERT with OVERRIDING +#[test] +fn it_parses_insert_overriding() { + let query = "INSERT INTO users (id, name) OVERRIDING SYSTEM VALUE VALUES (1, 'test')"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Complex UPDATE tests +// ============================================================================ + +/// Test UPDATE with multiple columns +#[test] +fn it_parses_update_multiple_columns() { + let query = "UPDATE users SET name = 'new_name', email = 'new@example.com', updated_at = NOW() WHERE id = 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UPDATE with subquery in SET +#[test] +fn it_parses_update_with_subquery_set() { + let query = "UPDATE orders SET total = (SELECT SUM(price * quantity) FROM order_items WHERE order_id = orders.id) WHERE id = 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UPDATE with FROM clause (PostgreSQL-specific JOIN update) +#[test] +fn it_parses_update_from() { + let query = "UPDATE orders o SET status = 'shipped', shipped_at = NOW() + FROM shipments s + WHERE o.id = s.order_id AND s.status = 'delivered'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UPDATE with FROM and multiple tables +#[test] +fn it_parses_update_from_multiple_tables() { + let query = "UPDATE products p SET price = p.price * (1 + d.percentage / 100) + FROM discounts d + JOIN categories c ON d.category_id = c.id + WHERE p.category_id = c.id AND d.active = true"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UPDATE with CTE +#[test] +fn it_parses_update_with_cte() { + let query = "WITH inactive_users AS ( + SELECT id FROM users WHERE last_login < NOW() - INTERVAL '1 year' + ) + UPDATE users SET status = 'inactive' WHERE id IN (SELECT id FROM inactive_users)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UPDATE with RETURNING +#[test] +fn it_parses_update_returning() { + let query = "UPDATE users SET name = 'updated' WHERE id = 1 RETURNING id, name, updated_at"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UPDATE with complex WHERE clause +#[test] +fn it_parses_update_complex_where() { + let query = "UPDATE orders SET status = 'cancelled' + WHERE created_at < NOW() - INTERVAL '30 days' + AND status = 'pending' + AND NOT EXISTS (SELECT 1 FROM payments WHERE payments.order_id = orders.id)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UPDATE with row value comparison +#[test] +fn it_parses_update_row_comparison() { + let query = "UPDATE users SET (name, email) = ('new_name', 'new@example.com') WHERE id = 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UPDATE with CASE expression +#[test] +fn it_parses_update_with_case() { + let query = "UPDATE products SET price = CASE + WHEN category = 'electronics' THEN price * 0.9 + WHEN category = 'clothing' THEN price * 0.8 + ELSE price * 0.95 + END + WHERE sale_active = true"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UPDATE with array operations +#[test] +fn it_parses_update_array() { + let query = "UPDATE users SET tags = array_append(tags, 'verified') WHERE id = 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Complex DELETE tests +// ============================================================================ + +/// Test DELETE with subquery in WHERE +#[test] +fn it_parses_delete_with_subquery() { + let query = "DELETE FROM orders WHERE user_id IN (SELECT id FROM users WHERE status = 'deleted')"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DELETE with USING clause (PostgreSQL-specific JOIN delete) +#[test] +fn it_parses_delete_using() { + let query = "DELETE FROM order_items oi USING orders o + WHERE oi.order_id = o.id AND o.status = 'cancelled'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DELETE with USING and multiple tables +#[test] +fn it_parses_delete_using_multiple_tables() { + let query = "DELETE FROM notifications n + USING users u, user_settings s + WHERE n.user_id = u.id + AND u.id = s.user_id + AND s.key = 'email_notifications' + AND s.value = 'false'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DELETE with CTE +#[test] +fn it_parses_delete_with_cte() { + let query = "WITH old_orders AS ( + SELECT id FROM orders WHERE created_at < NOW() - INTERVAL '5 years' + ) + DELETE FROM order_items WHERE order_id IN (SELECT id FROM old_orders)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DELETE with RETURNING +#[test] +fn it_parses_delete_returning() { + let query = "DELETE FROM users WHERE id = 1 RETURNING id, name, email"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DELETE with EXISTS +#[test] +fn it_parses_delete_with_exists() { + let query = "DELETE FROM products p + WHERE NOT EXISTS (SELECT 1 FROM order_items oi WHERE oi.product_id = p.id) + AND p.created_at < NOW() - INTERVAL '1 year'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DELETE with complex boolean conditions +#[test] +fn it_parses_delete_complex_conditions() { + let query = "DELETE FROM logs + WHERE (level = 'debug' AND created_at < NOW() - INTERVAL '7 days') + OR (level = 'info' AND created_at < NOW() - INTERVAL '30 days') + OR (level IN ('warning', 'error') AND created_at < NOW() - INTERVAL '90 days')"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DELETE with ONLY +#[test] +fn it_parses_delete_only() { + let query = "DELETE FROM ONLY parent_table WHERE id = 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +// ============================================================================ +// Combined DML with CTEs +// ============================================================================ + +/// Test data modification CTE (INSERT in CTE) +#[test] +fn it_parses_insert_cte_returning() { + let query = "WITH inserted AS ( + INSERT INTO users (name, email) VALUES ('test', 'test@example.com') RETURNING id, name + ) + SELECT * FROM inserted"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UPDATE in CTE with final SELECT +#[test] +fn it_parses_update_cte_returning() { + let query = "WITH updated AS ( + UPDATE users SET last_login = NOW() WHERE id = 1 RETURNING id, name, last_login + ) + SELECT * FROM updated"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DELETE in CTE with final SELECT +#[test] +fn it_parses_delete_cte_returning() { + let query = "WITH deleted AS ( + DELETE FROM expired_sessions WHERE expires_at < NOW() RETURNING user_id + ) + SELECT COUNT(*) FROM deleted"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test chained CTEs with multiple DML operations +#[test] +fn it_parses_chained_dml_ctes() { + let query = "WITH + to_archive AS ( + SELECT id FROM users WHERE last_login < NOW() - INTERVAL '2 years' + ), + archived AS ( + INSERT INTO archived_users SELECT * FROM users WHERE id IN (SELECT id FROM to_archive) RETURNING id + ), + deleted AS ( + DELETE FROM users WHERE id IN (SELECT id FROM archived) RETURNING id + ) + SELECT COUNT(*) as archived_count FROM deleted"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} diff --git a/tests/raw_parse/expressions.rs b/tests/raw_parse/expressions.rs new file mode 100644 index 0000000..c108bbf --- /dev/null +++ b/tests/raw_parse/expressions.rs @@ -0,0 +1,470 @@ +//! Expression tests: literals, type casts, arrays, JSON, operators. +//! +//! These tests verify parse_raw correctly handles various expressions. + +use super::*; + +// ============================================================================ +// Literal value tests +// ============================================================================ + +/// Test parsing float with leading dot +#[test] +fn it_parses_floats_with_leading_dot() { + let query = "SELECT .1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify the float value + let raw_const = get_first_const(&raw_result.protobuf).expect("should have const"); + let proto_const = get_first_const(&proto_result.protobuf).expect("should have const"); + assert_eq!(raw_const, proto_const); +} + +/// Test parsing bit string in hex notation +#[test] +fn it_parses_bit_strings_hex() { + let query = "SELECT X'EFFF'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify the bit string value + let raw_const = get_first_const(&raw_result.protobuf).expect("should have const"); + let proto_const = get_first_const(&proto_result.protobuf).expect("should have const"); + assert_eq!(raw_const, proto_const); +} + +/// Test parsing real-world query with multiple joins +#[test] +fn it_parses_real_world_query() { + let query = " + SELECT memory_total_bytes, memory_free_bytes, memory_pagecache_bytes, + (memory_swap_total_bytes - memory_swap_free_bytes) AS swap + FROM snapshots s JOIN system_snapshots ON (snapshot_id = s.id) + WHERE s.database_id = 1 AND s.collected_at BETWEEN '2021-01-01' AND '2021-12-31' + ORDER BY collected_at"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify tables + let mut raw_tables = raw_result.tables(); + let mut proto_tables = proto_result.tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["snapshots", "system_snapshots"]); +} +// ============================================================================ +// A_Const value extraction tests +// ============================================================================ + +/// Test that parse_raw extracts integer values correctly and matches parse +#[test] +fn it_extracts_integer_const() { + let query = "SELECT 42"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); + let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); + + assert_eq!(raw_const, proto_const); + assert!(!raw_const.isnull); + match &raw_const.val { + Some(a_const::Val::Ival(int_val)) => { + assert_eq!(int_val.ival, 42); + } + other => panic!("Expected Ival, got {:?}", other), + } +} + +/// Test that parse_raw extracts negative integer values correctly +#[test] +fn it_extracts_negative_integer_const() { + let query = "SELECT -123"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test that parse_raw extracts string values correctly and matches parse +#[test] +fn it_extracts_string_const() { + let query = "SELECT 'hello world'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); + let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); + + assert_eq!(raw_const, proto_const); + assert!(!raw_const.isnull); + match &raw_const.val { + Some(a_const::Val::Sval(str_val)) => { + assert_eq!(str_val.sval, "hello world"); + } + other => panic!("Expected Sval, got {:?}", other), + } +} + +/// Test that parse_raw extracts float values correctly and matches parse +#[test] +fn it_extracts_float_const() { + let query = "SELECT 3.14159"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); + let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); + + assert_eq!(raw_const, proto_const); + assert!(!raw_const.isnull); + match &raw_const.val { + Some(a_const::Val::Fval(float_val)) => { + assert_eq!(float_val.fval, "3.14159"); + } + other => panic!("Expected Fval, got {:?}", other), + } +} + +/// Test that parse_raw extracts boolean TRUE correctly and matches parse +#[test] +fn it_extracts_boolean_true_const() { + let query = "SELECT TRUE"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); + let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); + + assert_eq!(raw_const, proto_const); + assert!(!raw_const.isnull); + match &raw_const.val { + Some(a_const::Val::Boolval(bool_val)) => { + assert!(bool_val.boolval); + } + other => panic!("Expected Boolval(true), got {:?}", other), + } +} + +/// Test that parse_raw extracts boolean FALSE correctly and matches parse +#[test] +fn it_extracts_boolean_false_const() { + let query = "SELECT FALSE"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); + let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); + + assert_eq!(raw_const, proto_const); + assert!(!raw_const.isnull); + match &raw_const.val { + Some(a_const::Val::Boolval(bool_val)) => { + assert!(!bool_val.boolval); + } + other => panic!("Expected Boolval(false), got {:?}", other), + } +} + +/// Test that parse_raw extracts NULL correctly and matches parse +#[test] +fn it_extracts_null_const() { + let query = "SELECT NULL"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); + let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); + + assert_eq!(raw_const, proto_const); + assert!(raw_const.isnull); + assert!(raw_const.val.is_none()); +} + +/// Test that parse_raw extracts bit string values correctly and matches parse +#[test] +fn it_extracts_bit_string_const() { + let query = "SELECT B'1010'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); + let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); + + assert_eq!(raw_const, proto_const); + assert!(!raw_const.isnull); + match &raw_const.val { + Some(a_const::Val::Bsval(bit_val)) => { + assert_eq!(bit_val.bsval, "b1010"); + } + other => panic!("Expected Bsval, got {:?}", other), + } +} + +/// Test that parse_raw extracts hex bit string correctly and matches parse +#[test] +fn it_extracts_hex_bit_string_const() { + let query = "SELECT X'FF'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); + let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); + + assert_eq!(raw_const, proto_const); + assert!(!raw_const.isnull); + match &raw_const.val { + Some(a_const::Val::Bsval(bit_val)) => { + assert_eq!(bit_val.bsval, "xFF"); + } + other => panic!("Expected Bsval, got {:?}", other), + } +} +// ============================================================================ +// Expression tests +// ============================================================================ + +/// Test COALESCE +#[test] +fn it_parses_coalesce() { + let query = "SELECT COALESCE(nickname, name, 'Unknown') FROM users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test NULLIF +#[test] +fn it_parses_nullif() { + let query = "SELECT NULLIF(status, 'deleted') FROM records"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test GREATEST and LEAST +#[test] +fn it_parses_greatest_least() { + let query = "SELECT GREATEST(a, b, c), LEAST(x, y, z) FROM t"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test IS NULL and IS NOT NULL +#[test] +fn it_parses_null_tests() { + let query = "SELECT * FROM users WHERE deleted_at IS NULL AND email IS NOT NULL"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test IS DISTINCT FROM +#[test] +fn it_parses_is_distinct_from() { + let query = "SELECT * FROM t WHERE a IS DISTINCT FROM b"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test BETWEEN +#[test] +fn it_parses_between() { + let query = "SELECT * FROM events WHERE created_at BETWEEN '2023-01-01' AND '2023-12-31'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test LIKE and ILIKE +#[test] +fn it_parses_like_ilike() { + let query = "SELECT * FROM users WHERE name LIKE 'John%' OR email ILIKE '%@EXAMPLE.COM'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test SIMILAR TO +#[test] +fn it_parses_similar_to() { + let query = "SELECT * FROM products WHERE name SIMILAR TO '%(phone|tablet)%'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test complex boolean expressions +#[test] +fn it_parses_complex_boolean() { + let query = "SELECT * FROM users WHERE (active = true AND verified = true) OR (role = 'admin' AND NOT suspended)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} +// ============================================================================ +// Type cast tests +// ============================================================================ + +/// Test PostgreSQL-style type cast +#[test] +fn it_parses_pg_type_cast() { + let query = "SELECT '123'::integer, '2023-01-01'::date, 'true'::boolean"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test SQL-style CAST +#[test] +fn it_parses_sql_cast() { + let query = "SELECT CAST('123' AS integer), CAST(created_at AS date) FROM t"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test array type cast +#[test] +fn it_parses_array_cast() { + let query = "SELECT ARRAY[1, 2, 3]::text[]"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} +// ============================================================================ +// Array and JSON tests +// ============================================================================ + +/// Test array constructor +#[test] +fn it_parses_array_constructor() { + let query = "SELECT ARRAY[1, 2, 3], ARRAY['a', 'b', 'c']"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test array subscript +#[test] +fn it_parses_array_subscript() { + let query = "SELECT tags[1], matrix[1][2] FROM t"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test array slice +#[test] +fn it_parses_array_slice() { + let query = "SELECT arr[2:4], arr[:3], arr[2:] FROM t"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test unnest +#[test] +fn it_parses_unnest() { + let query = "SELECT unnest(ARRAY[1, 2, 3])"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test JSON operators +#[test] +fn it_parses_json_operators() { + let query = "SELECT data->'name', data->>'email', data#>'{address,city}' FROM users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test JSONB containment +#[test] +fn it_parses_jsonb_containment() { + let query = "SELECT * FROM products WHERE metadata @> '{\"featured\": true}'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} +// ============================================================================ +// Parameter placeholder tests +// ============================================================================ + +/// Test positional parameters +#[test] +fn it_parses_positional_params() { + let query = "SELECT * FROM users WHERE id = $1 AND status = $2"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test parameters in INSERT +#[test] +fn it_parses_params_in_insert() { + let query = "INSERT INTO users (name, email) VALUES ($1, $2) RETURNING id"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} diff --git a/tests/raw_parse/mod.rs b/tests/raw_parse/mod.rs new file mode 100644 index 0000000..45fd98c --- /dev/null +++ b/tests/raw_parse/mod.rs @@ -0,0 +1,43 @@ +//! Raw parse tests split into multiple modules for maintainability. +//! +//! This module contains tests that verify parse_raw produces equivalent +//! results to parse (protobuf-based parsing). + +pub use pg_query::protobuf::{a_const, node, ParseResult as ProtobufParseResult}; +pub use pg_query::{parse, parse_raw, Error}; + +/// Helper to extract AConst from a SELECT statement's first target +pub fn get_first_const(result: &ProtobufParseResult) -> Option<&pg_query::protobuf::AConst> { + let stmt = result.stmts.first()?; + let raw_stmt = stmt.stmt.as_ref()?; + let node = raw_stmt.node.as_ref()?; + + if let node::Node::SelectStmt(select) = node { + let target = select.target_list.first()?; + if let Some(node::Node::ResTarget(res_target)) = target.node.as_ref() { + if let Some(val_node) = res_target.val.as_ref() { + if let Some(node::Node::AConst(aconst)) = val_node.node.as_ref() { + return Some(aconst); + } + } + } + } + None +} + +/// Helper macro for simple parse comparison tests +#[macro_export] +macro_rules! parse_test { + ($query:expr) => {{ + let raw_result = parse_raw($query).unwrap(); + let proto_result = parse($query).unwrap(); + assert_eq!(raw_result.protobuf, proto_result.protobuf); + }}; +} + +pub mod basic; +pub mod ddl; +pub mod dml; +pub mod expressions; +pub mod select; +pub mod statements; diff --git a/tests/raw_parse/select.rs b/tests/raw_parse/select.rs new file mode 100644 index 0000000..75ae1e8 --- /dev/null +++ b/tests/raw_parse/select.rs @@ -0,0 +1,855 @@ +//! Complex SELECT tests: JOINs, subqueries, CTEs, window functions, set operations. +//! +//! These tests verify parse_raw correctly handles complex SELECT statements. + +use super::*; + +// ============================================================================ +// JOIN and complex SELECT tests +// ============================================================================ + +/// Test parsing SELECT with JOIN +#[test] +fn it_parses_join() { + let query = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify tables are extracted correctly + let mut raw_tables = raw_result.tables(); + let mut proto_tables = proto_result.tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["orders", "users"]); +} + +/// Test parsing UNION query +#[test] +fn it_parses_union() { + let query = "SELECT id FROM users UNION SELECT id FROM admins"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify tables from both sides of UNION + let mut raw_tables = raw_result.tables(); + let mut proto_tables = proto_result.tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["admins", "users"]); +} + +/// Test parsing WITH clause (CTE) +#[test] +fn it_parses_cte() { + let query = "WITH active_users AS (SELECT * FROM users WHERE active = true) SELECT * FROM active_users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify CTE names match + assert_eq!(raw_result.cte_names, proto_result.cte_names); + assert!(raw_result.cte_names.contains(&"active_users".to_string())); + + // Verify tables (should only include actual tables, not CTEs) + let mut raw_tables = raw_result.tables(); + let mut proto_tables = proto_result.tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["users"]); +} + +/// Test parsing subquery in SELECT +#[test] +fn it_parses_subquery() { + let query = "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify all tables are found + let mut raw_tables = raw_result.tables(); + let mut proto_tables = proto_result.tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["orders", "users"]); +} + +/// Test parsing aggregate functions +#[test] +fn it_parses_aggregates() { + let query = "SELECT count(*), sum(amount), avg(price) FROM orders"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify functions are extracted correctly + let mut raw_funcs = raw_result.functions(); + let mut proto_funcs = proto_result.functions(); + raw_funcs.sort(); + proto_funcs.sort(); + assert_eq!(raw_funcs, proto_funcs); + assert!(raw_funcs.contains(&"count".to_string())); + assert!(raw_funcs.contains(&"sum".to_string())); + assert!(raw_funcs.contains(&"avg".to_string())); +} + +/// Test parsing CASE expression +#[test] +fn it_parses_case_expression() { + let query = "SELECT CASE WHEN x > 0 THEN 'positive' ELSE 'non-positive' END FROM t"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify table is found + let raw_tables = raw_result.tables(); + let proto_tables = proto_result.tables(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["t"]); +} + +/// Test parsing complex SELECT with multiple clauses +#[test] +fn it_parses_complex_select() { + let query = "SELECT u.id, u.name, count(*) AS order_count FROM users u LEFT JOIN orders o ON u.id = o.user_id WHERE u.active = true GROUP BY u.id, u.name HAVING count(*) > 0 ORDER BY order_count DESC LIMIT 10"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + // Full structural equality check + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + // Verify tables + let mut raw_tables = raw_result.tables(); + let mut proto_tables = proto_result.tables(); + raw_tables.sort(); + proto_tables.sort(); + assert_eq!(raw_tables, proto_tables); + assert_eq!(raw_tables, vec!["orders", "users"]); + + // Verify functions + let mut raw_funcs = raw_result.functions(); + let mut proto_funcs = proto_result.functions(); + raw_funcs.sort(); + proto_funcs.sort(); + assert_eq!(raw_funcs, proto_funcs); + assert!(raw_funcs.contains(&"count".to_string())); +} + +// ============================================================================ +// Advanced JOIN tests +// ============================================================================ + +/// Test LEFT JOIN +#[test] +fn it_parses_left_join() { + let query = "SELECT * FROM users u LEFT JOIN orders o ON u.id = o.user_id"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let mut raw_tables = raw_result.tables(); + raw_tables.sort(); + assert_eq!(raw_tables, vec!["orders", "users"]); +} + +/// Test RIGHT JOIN +#[test] +fn it_parses_right_join() { + let query = "SELECT * FROM users u RIGHT JOIN orders o ON u.id = o.user_id"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test FULL OUTER JOIN +#[test] +fn it_parses_full_outer_join() { + let query = "SELECT * FROM users u FULL OUTER JOIN orders o ON u.id = o.user_id"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CROSS JOIN +#[test] +fn it_parses_cross_join() { + let query = "SELECT * FROM users CROSS JOIN products"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let mut raw_tables = raw_result.tables(); + raw_tables.sort(); + assert_eq!(raw_tables, vec!["products", "users"]); +} + +/// Test NATURAL JOIN +#[test] +fn it_parses_natural_join() { + let query = "SELECT * FROM users NATURAL JOIN user_profiles"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test multiple JOINs +#[test] +fn it_parses_multiple_joins() { + let query = "SELECT u.name, o.id, p.name FROM users u + JOIN orders o ON u.id = o.user_id + JOIN order_items oi ON o.id = oi.order_id + JOIN products p ON oi.product_id = p.id"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let mut raw_tables = raw_result.tables(); + raw_tables.sort(); + assert_eq!(raw_tables, vec!["order_items", "orders", "products", "users"]); +} + +/// Test JOIN with USING clause +#[test] +fn it_parses_join_using() { + let query = "SELECT * FROM users u JOIN user_profiles p USING (user_id)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test LATERAL JOIN +#[test] +fn it_parses_lateral_join() { + let query = "SELECT * FROM users u, LATERAL (SELECT * FROM orders o WHERE o.user_id = u.id LIMIT 3) AS recent_orders"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let mut raw_tables = raw_result.tables(); + raw_tables.sort(); + assert_eq!(raw_tables, vec!["orders", "users"]); +} +// ============================================================================ +// Advanced subquery tests +// ============================================================================ + +/// Test correlated subquery +#[test] +fn it_parses_correlated_subquery() { + let query = "SELECT * FROM users u WHERE EXISTS (SELECT 1 FROM orders o WHERE o.user_id = u.id)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + + let mut raw_tables = raw_result.tables(); + raw_tables.sort(); + assert_eq!(raw_tables, vec!["orders", "users"]); +} + +/// Test NOT EXISTS subquery +#[test] +fn it_parses_not_exists_subquery() { + let query = "SELECT * FROM users u WHERE NOT EXISTS (SELECT 1 FROM banned b WHERE b.user_id = u.id)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test scalar subquery in SELECT +#[test] +fn it_parses_scalar_subquery() { + let query = "SELECT u.name, (SELECT COUNT(*) FROM orders o WHERE o.user_id = u.id) AS order_count FROM users u"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test subquery in FROM clause +#[test] +fn it_parses_derived_table() { + let query = "SELECT * FROM (SELECT id, name FROM users WHERE active = true) AS active_users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ANY/SOME subquery +#[test] +fn it_parses_any_subquery() { + let query = "SELECT * FROM products WHERE price > ANY (SELECT avg_price FROM categories)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ALL subquery +#[test] +fn it_parses_all_subquery() { + let query = "SELECT * FROM products WHERE price > ALL (SELECT price FROM discounted_products)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} +// ============================================================================ +// Window function tests +// ============================================================================ + +/// Test basic window function +#[test] +fn it_parses_window_function() { + let query = "SELECT name, salary, ROW_NUMBER() OVER (ORDER BY salary DESC) AS rank FROM employees"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test window function with PARTITION BY +#[test] +fn it_parses_window_function_partition() { + let query = "SELECT department, name, salary, RANK() OVER (PARTITION BY department ORDER BY salary DESC) AS dept_rank FROM employees"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test window function with frame clause +#[test] +fn it_parses_window_function_frame() { + let query = "SELECT date, amount, SUM(amount) OVER (ORDER BY date ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS moving_sum FROM transactions"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test named window +#[test] +fn it_parses_named_window() { + let query = "SELECT name, salary, SUM(salary) OVER w, AVG(salary) OVER w FROM employees WINDOW w AS (PARTITION BY department ORDER BY salary)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test LAG and LEAD functions +#[test] +fn it_parses_lag_lead() { + let query = + "SELECT date, price, LAG(price, 1) OVER (ORDER BY date) AS prev_price, LEAD(price, 1) OVER (ORDER BY date) AS next_price FROM stock_prices"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} +// ============================================================================ +// CTE variations +// ============================================================================ + +/// Test multiple CTEs +#[test] +fn it_parses_multiple_ctes() { + let query = "WITH + active_users AS (SELECT * FROM users WHERE active = true), + premium_users AS (SELECT * FROM active_users WHERE plan = 'premium') + SELECT * FROM premium_users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); + assert!(raw_result.cte_names.contains(&"active_users".to_string())); + assert!(raw_result.cte_names.contains(&"premium_users".to_string())); +} + +/// Test recursive CTE +#[test] +fn it_parses_recursive_cte() { + let query = "WITH RECURSIVE subordinates AS ( + SELECT id, name, manager_id FROM employees WHERE id = 1 + UNION ALL + SELECT e.id, e.name, e.manager_id FROM employees e INNER JOIN subordinates s ON e.manager_id = s.id + ) SELECT * FROM subordinates"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CTE with column list +#[test] +fn it_parses_cte_with_columns() { + let query = "WITH regional_sales(region, total) AS (SELECT region, SUM(amount) FROM orders GROUP BY region) SELECT * FROM regional_sales"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test CTE with MATERIALIZED +#[test] +fn it_parses_cte_materialized() { + let query = "WITH t AS MATERIALIZED (SELECT * FROM large_table WHERE x > 100) SELECT * FROM t"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} +// ============================================================================ +// Set operations +// ============================================================================ + +/// Test INTERSECT +#[test] +fn it_parses_intersect() { + let query = "SELECT id FROM users INTERSECT SELECT user_id FROM orders"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test EXCEPT +#[test] +fn it_parses_except() { + let query = "SELECT id FROM users EXCEPT SELECT user_id FROM banned_users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UNION ALL +#[test] +fn it_parses_union_all() { + let query = "SELECT name FROM users UNION ALL SELECT name FROM admins"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test compound set operations +#[test] +fn it_parses_compound_set_operations() { + let query = "(SELECT id FROM a UNION SELECT id FROM b) INTERSECT SELECT id FROM c"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} +// ============================================================================ +// GROUP BY variations +// ============================================================================ + +/// Test GROUP BY ROLLUP +#[test] +fn it_parses_group_by_rollup() { + let query = "SELECT region, product, SUM(sales) FROM sales_data GROUP BY ROLLUP(region, product)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test GROUP BY CUBE +#[test] +fn it_parses_group_by_cube() { + let query = "SELECT region, product, SUM(sales) FROM sales_data GROUP BY CUBE(region, product)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test GROUP BY GROUPING SETS +#[test] +fn it_parses_grouping_sets() { + let query = "SELECT region, product, SUM(sales) FROM sales_data GROUP BY GROUPING SETS ((region), (product), ())"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} +// ============================================================================ +// DISTINCT and ORDER BY variations +// ============================================================================ + +/// Test DISTINCT ON +#[test] +fn it_parses_distinct_on() { + let query = "SELECT DISTINCT ON (user_id) * FROM orders ORDER BY user_id, created_at DESC"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ORDER BY with NULLS FIRST/LAST +#[test] +fn it_parses_order_by_nulls() { + let query = "SELECT * FROM users ORDER BY last_login DESC NULLS LAST"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test FETCH FIRST +#[test] +fn it_parses_fetch_first() { + let query = "SELECT * FROM users ORDER BY id FETCH FIRST 10 ROWS ONLY"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test OFFSET with FETCH +#[test] +fn it_parses_offset_fetch() { + let query = "SELECT * FROM users ORDER BY id OFFSET 20 ROWS FETCH NEXT 10 ROWS ONLY"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} +// ============================================================================ +// Locking clauses +// ============================================================================ + +/// Test FOR UPDATE +#[test] +fn it_parses_for_update() { + let query = "SELECT * FROM users WHERE id = 1 FOR UPDATE"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test FOR SHARE +#[test] +fn it_parses_for_share() { + let query = "SELECT * FROM users WHERE id = 1 FOR SHARE"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test FOR UPDATE NOWAIT +#[test] +fn it_parses_for_update_nowait() { + let query = "SELECT * FROM users WHERE id = 1 FOR UPDATE NOWAIT"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test FOR UPDATE SKIP LOCKED +#[test] +fn it_parses_for_update_skip_locked() { + let query = "SELECT * FROM jobs WHERE status = 'pending' LIMIT 1 FOR UPDATE SKIP LOCKED"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} +// ============================================================================ +// Complex real-world queries +// ============================================================================ + +/// Test analytics query with window functions +#[test] +fn it_parses_analytics_query() { + let query = " + SELECT + date_trunc('day', created_at) AS day, + COUNT(*) AS daily_orders, + SUM(amount) AS daily_revenue, + AVG(amount) OVER (ORDER BY date_trunc('day', created_at) ROWS BETWEEN 6 PRECEDING AND CURRENT ROW) AS weekly_avg + FROM orders + WHERE created_at >= NOW() - INTERVAL '30 days' + GROUP BY date_trunc('day', created_at) + ORDER BY day"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test hierarchical query with recursive CTE +#[test] +fn it_parses_hierarchy_query() { + let query = " + WITH RECURSIVE category_tree AS ( + SELECT id, name, parent_id, 0 AS level, ARRAY[id] AS path + FROM categories + WHERE parent_id IS NULL + UNION ALL + SELECT c.id, c.name, c.parent_id, ct.level + 1, ct.path || c.id + FROM categories c + JOIN category_tree ct ON c.parent_id = ct.id + ) + SELECT * FROM category_tree ORDER BY path"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test complex report query +#[test] +fn it_parses_complex_report_query() { + let query = " + WITH monthly_data AS ( + SELECT + date_trunc('month', o.created_at) AS month, + u.region, + p.category, + SUM(oi.quantity * oi.unit_price) AS revenue, + COUNT(DISTINCT o.id) AS order_count, + COUNT(DISTINCT o.user_id) AS customer_count + FROM orders o + JOIN users u ON o.user_id = u.id + JOIN order_items oi ON o.id = oi.order_id + JOIN products p ON oi.product_id = p.id + WHERE o.created_at >= '2023-01-01' AND o.status = 'completed' + GROUP BY 1, 2, 3 + ) + SELECT + month, + region, + category, + revenue, + order_count, + customer_count, + revenue / NULLIF(order_count, 0) AS avg_order_value, + SUM(revenue) OVER (PARTITION BY region ORDER BY month) AS cumulative_revenue + FROM monthly_data + ORDER BY month DESC, region, revenue DESC"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test query with multiple subqueries and CTEs +#[test] +fn it_parses_mixed_subqueries_and_ctes() { + let query = " + WITH high_value_customers AS ( + SELECT user_id FROM orders GROUP BY user_id HAVING SUM(amount) > 1000 + ) + SELECT u.*, + (SELECT COUNT(*) FROM orders o WHERE o.user_id = u.id) AS total_orders, + (SELECT MAX(created_at) FROM orders o WHERE o.user_id = u.id) AS last_order + FROM users u + WHERE u.id IN (SELECT user_id FROM high_value_customers) + AND EXISTS (SELECT 1 FROM orders o WHERE o.user_id = u.id AND o.created_at > NOW() - INTERVAL '90 days') + ORDER BY (SELECT SUM(amount) FROM orders o WHERE o.user_id = u.id) DESC + LIMIT 100"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} +// Tests for previously stubbed fields +// ============================================================================ + +/// Test column with COLLATE clause +#[test] +fn it_parses_column_with_collate() { + let query = "CREATE TABLE test_collate ( + name TEXT COLLATE \"C\", + description VARCHAR(255) COLLATE \"en_US.UTF-8\" + )"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test partitioned table with PARTITION BY RANGE +#[test] +fn it_parses_partition_by_range() { + let query = "CREATE TABLE measurements ( + id SERIAL, + logdate DATE NOT NULL, + peaktemp INT + ) PARTITION BY RANGE (logdate)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test partitioned table with PARTITION BY LIST +#[test] +fn it_parses_partition_by_list() { + let query = "CREATE TABLE orders ( + id SERIAL, + region TEXT NOT NULL, + order_date DATE + ) PARTITION BY LIST (region)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test partitioned table with PARTITION BY HASH +#[test] +fn it_parses_partition_by_hash() { + let query = "CREATE TABLE users_partitioned ( + id SERIAL, + username TEXT + ) PARTITION BY HASH (id)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test partition with FOR VALUES (range) +#[test] +fn it_parses_partition_for_values_range() { + let query = "CREATE TABLE measurements_2023 PARTITION OF measurements + FOR VALUES FROM ('2023-01-01') TO ('2024-01-01')"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test partition with FOR VALUES (list) +#[test] +fn it_parses_partition_for_values_list() { + let query = "CREATE TABLE orders_west PARTITION OF orders + FOR VALUES IN ('west', 'northwest', 'southwest')"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test partition with FOR VALUES (hash) +#[test] +fn it_parses_partition_for_values_hash() { + let query = "CREATE TABLE users_part_0 PARTITION OF users_partitioned + FOR VALUES WITH (MODULUS 4, REMAINDER 0)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test partition with DEFAULT +#[test] +fn it_parses_partition_default() { + let query = "CREATE TABLE orders_other PARTITION OF orders DEFAULT"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test recursive CTE with SEARCH BREADTH FIRST +#[test] +fn it_parses_cte_search_breadth_first() { + let query = "WITH RECURSIVE search_tree(id, parent_id, data, depth) AS ( + SELECT id, parent_id, data, 0 FROM tree WHERE parent_id IS NULL + UNION ALL + SELECT t.id, t.parent_id, t.data, st.depth + 1 + FROM tree t, search_tree st WHERE t.parent_id = st.id + ) SEARCH BREADTH FIRST BY id SET ordercol + SELECT * FROM search_tree ORDER BY ordercol"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test recursive CTE with SEARCH DEPTH FIRST +#[test] +fn it_parses_cte_search_depth_first() { + let query = "WITH RECURSIVE search_tree(id, parent_id, data) AS ( + SELECT id, parent_id, data FROM tree WHERE parent_id IS NULL + UNION ALL + SELECT t.id, t.parent_id, t.data + FROM tree t, search_tree st WHERE t.parent_id = st.id + ) SEARCH DEPTH FIRST BY id SET ordercol + SELECT * FROM search_tree ORDER BY ordercol"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test recursive CTE with CYCLE detection +#[test] +fn it_parses_cte_cycle() { + let query = "WITH RECURSIVE search_graph(id, link, data, depth) AS ( + SELECT g.id, g.link, g.data, 0 FROM graph g + UNION ALL + SELECT g.id, g.link, g.data, sg.depth + 1 + FROM graph g, search_graph sg WHERE g.id = sg.link + ) CYCLE id SET is_cycle USING path + SELECT * FROM search_graph"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test recursive CTE with both SEARCH and CYCLE +#[test] +fn it_parses_cte_search_and_cycle() { + let query = "WITH RECURSIVE search_graph(id, link, data, depth) AS ( + SELECT g.id, g.link, g.data, 0 FROM graph g WHERE id = 1 + UNION ALL + SELECT g.id, g.link, g.data, sg.depth + 1 + FROM graph g, search_graph sg WHERE g.id = sg.link + ) SEARCH DEPTH FIRST BY id SET ordercol + CYCLE id SET is_cycle USING path + SELECT * FROM search_graph"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} diff --git a/tests/raw_parse/statements.rs b/tests/raw_parse/statements.rs new file mode 100644 index 0000000..a46a549 --- /dev/null +++ b/tests/raw_parse/statements.rs @@ -0,0 +1,450 @@ +//! Utility statement tests: transactions, VACUUM, SET/SHOW, LOCK, DO, LISTEN, etc. +//! +//! These tests verify parse_raw correctly handles utility statements. + +use super::*; + +// ============================================================================ +// Transaction and utility statements +// ============================================================================ + +/// Test EXPLAIN +#[test] +fn it_parses_explain() { + let query = "EXPLAIN SELECT * FROM users WHERE id = 1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test EXPLAIN ANALYZE +#[test] +fn it_parses_explain_analyze() { + let query = "EXPLAIN (ANALYZE, BUFFERS, FORMAT JSON) SELECT * FROM users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test COPY +#[test] +fn it_parses_copy() { + let query = "COPY users (id, name, email) FROM STDIN WITH (FORMAT csv, HEADER true)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test PREPARE +#[test] +fn it_parses_prepare() { + let query = "PREPARE user_by_id (int) AS SELECT * FROM users WHERE id = $1"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test EXECUTE +#[test] +fn it_parses_execute() { + let query = "EXECUTE user_by_id(42)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DEALLOCATE +#[test] +fn it_parses_deallocate() { + let query = "DEALLOCATE user_by_id"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} +// ============================================================================ +// Transaction statement tests +// ============================================================================ + +/// Test BEGIN transaction +#[test] +fn it_parses_begin() { + let query = "BEGIN"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test BEGIN with options +#[test] +fn it_parses_begin_with_options() { + let query = "BEGIN ISOLATION LEVEL SERIALIZABLE READ ONLY"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test COMMIT transaction +#[test] +fn it_parses_commit() { + let query = "COMMIT"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ROLLBACK transaction +#[test] +fn it_parses_rollback() { + let query = "ROLLBACK"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test START TRANSACTION +#[test] +fn it_parses_start_transaction() { + let query = "START TRANSACTION"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test SAVEPOINT +#[test] +fn it_parses_savepoint() { + let query = "SAVEPOINT my_savepoint"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ROLLBACK TO SAVEPOINT +#[test] +fn it_parses_rollback_to_savepoint() { + let query = "ROLLBACK TO SAVEPOINT my_savepoint"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test RELEASE SAVEPOINT +#[test] +fn it_parses_release_savepoint() { + let query = "RELEASE SAVEPOINT my_savepoint"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} +// ============================================================================ +// VACUUM and ANALYZE statement tests +// ============================================================================ + +/// Test VACUUM +#[test] +fn it_parses_vacuum() { + let query = "VACUUM"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test VACUUM with table +#[test] +fn it_parses_vacuum_table() { + let query = "VACUUM users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test VACUUM ANALYZE +#[test] +fn it_parses_vacuum_analyze() { + let query = "VACUUM ANALYZE users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test VACUUM FULL +#[test] +fn it_parses_vacuum_full() { + let query = "VACUUM FULL users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ANALYZE +#[test] +fn it_parses_analyze() { + let query = "ANALYZE"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ANALYZE with table +#[test] +fn it_parses_analyze_table() { + let query = "ANALYZE users"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test ANALYZE with column list +#[test] +fn it_parses_analyze_columns() { + let query = "ANALYZE users (id, name)"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} +// ============================================================================ +// SET and SHOW statement tests +// ============================================================================ + +/// Test SET statement +#[test] +fn it_parses_set() { + let query = "SET search_path TO public"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test SET with equals +#[test] +fn it_parses_set_equals() { + let query = "SET statement_timeout = 5000"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test SET LOCAL +#[test] +fn it_parses_set_local() { + let query = "SET LOCAL search_path TO myschema"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test SET SESSION +#[test] +fn it_parses_set_session() { + let query = "SET SESSION timezone = 'UTC'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test RESET +#[test] +fn it_parses_reset() { + let query = "RESET search_path"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test RESET ALL +#[test] +fn it_parses_reset_all() { + let query = "RESET ALL"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test SHOW statement +#[test] +fn it_parses_show() { + let query = "SHOW search_path"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test SHOW ALL +#[test] +fn it_parses_show_all() { + let query = "SHOW ALL"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} +// ============================================================================ +// LISTEN, NOTIFY, UNLISTEN statement tests +// ============================================================================ + +/// Test LISTEN statement +#[test] +fn it_parses_listen() { + let query = "LISTEN my_channel"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test NOTIFY statement +#[test] +fn it_parses_notify() { + let query = "NOTIFY my_channel"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test NOTIFY with payload +#[test] +fn it_parses_notify_with_payload() { + let query = "NOTIFY my_channel, 'hello world'"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UNLISTEN statement +#[test] +fn it_parses_unlisten() { + let query = "UNLISTEN my_channel"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test UNLISTEN * +#[test] +fn it_parses_unlisten_all() { + let query = "UNLISTEN *"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} +// ============================================================================ +// DISCARD statement tests +// ============================================================================ + +/// Test DISCARD ALL +#[test] +fn it_parses_discard_all() { + let query = "DISCARD ALL"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DISCARD PLANS +#[test] +fn it_parses_discard_plans() { + let query = "DISCARD PLANS"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DISCARD SEQUENCES +#[test] +fn it_parses_discard_sequences() { + let query = "DISCARD SEQUENCES"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DISCARD TEMP +#[test] +fn it_parses_discard_temp() { + let query = "DISCARD TEMP"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} +// ============================================================================ +// LOCK statement tests +// ============================================================================ + +/// Test LOCK TABLE +#[test] +fn it_parses_lock_table() { + let query = "LOCK TABLE users IN ACCESS EXCLUSIVE MODE"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test LOCK multiple tables +#[test] +fn it_parses_lock_multiple_tables() { + let query = "LOCK TABLE users, orders IN SHARE MODE"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} +// ============================================================================ +// DO statement tests +// ============================================================================ + +/// Test DO statement +#[test] +fn it_parses_do_statement() { + let query = "DO $$ BEGIN RAISE NOTICE 'Hello'; END $$"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} + +/// Test DO statement with language +#[test] +fn it_parses_do_with_language() { + let query = "DO LANGUAGE plpgsql $$ BEGIN NULL; END $$"; + let raw_result = parse_raw(query).unwrap(); + let proto_result = parse(query).unwrap(); + + assert_eq!(raw_result.protobuf, proto_result.protobuf); +} diff --git a/tests/raw_parse_tests.rs b/tests/raw_parse_tests.rs index f8f300c..f19220b 100644 --- a/tests/raw_parse_tests.rs +++ b/tests/raw_parse_tests.rs @@ -1,2269 +1,23 @@ -#![allow(non_snake_case)] -#![cfg(test)] - -use pg_query::protobuf::{a_const, node, ParseResult as ProtobufParseResult}; -use pg_query::{parse, parse_raw, Error}; - -/// Test that parse_raw results can be deparsed back to SQL -#[test] -fn it_deparses_parse_raw_result() { - let query = "SELECT * FROM users"; - let result = parse_raw(query).unwrap(); - - // Print version info for debugging - eprintln!("parse_raw protobuf version: {}", result.protobuf.version); - - // Compare with regular parse - let regular_result = parse(query).unwrap(); - eprintln!("parse protobuf version: {}", regular_result.protobuf.version); - - assert_eq!(result.protobuf.version, regular_result.protobuf.version, "Version mismatch between parse_raw and parse"); - - let deparsed = result.deparse().unwrap(); - assert_eq!(deparsed, query); -} - -#[macro_use] -mod support; - -// ============================================================================ -// Helper functions -// ============================================================================ - -/// Helper to extract AConst from a SELECT statement's first target -fn get_first_const(result: &ProtobufParseResult) -> Option<&pg_query::protobuf::AConst> { - let stmt = result.stmts.first()?; - let raw_stmt = stmt.stmt.as_ref()?; - let node = raw_stmt.node.as_ref()?; - - if let node::Node::SelectStmt(select) = node { - let target = select.target_list.first()?; - if let Some(node::Node::ResTarget(res_target)) = target.node.as_ref() { - if let Some(val_node) = res_target.val.as_ref() { - if let Some(node::Node::AConst(aconst)) = val_node.node.as_ref() { - return Some(aconst); - } - } - } - } - None -} - -// ============================================================================ -// Basic parsing tests -// ============================================================================ - -/// Test that parse_raw successfully parses a simple SELECT query -#[test] -fn it_parses_simple_select() { - let query = "SELECT 1"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf.stmts.len(), 1); - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test that parse_raw handles syntax errors -#[test] -fn it_handles_parse_errors() { - let query = "SELECT * FORM users"; - let raw_error = parse_raw(query).err().unwrap(); - let proto_error = parse(query).err().unwrap(); - - assert!(matches!(raw_error, Error::Parse(_))); - assert!(matches!(proto_error, Error::Parse(_))); -} - -/// Test that parse_raw and parse produce equivalent results for simple SELECT -#[test] -fn it_matches_parse_for_simple_select() { - let query = "SELECT 1"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test that parse_raw and parse produce equivalent results for SELECT with table -#[test] -fn it_matches_parse_for_select_from_table() { - let query = "SELECT * FROM users"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - // Verify tables are extracted correctly - let mut raw_tables = raw_result.tables(); - let mut proto_tables = proto_result.tables(); - raw_tables.sort(); - proto_tables.sort(); - assert_eq!(raw_tables, proto_tables); - assert_eq!(raw_tables, vec!["users"]); -} - -/// Test that parse_raw handles empty queries (comments only) -#[test] -fn it_handles_empty_queries() { - let query = "-- just a comment"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf.stmts.len(), 0); - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test that parse_raw parses multiple statements -#[test] -fn it_parses_multiple_statements() { - let query = "SELECT 1; SELECT 2; SELECT 3"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf.stmts.len(), 3); - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -// ============================================================================ -// DML statement tests -// ============================================================================ - -/// Test parsing INSERT statement -#[test] -fn it_parses_insert() { - let query = "INSERT INTO users (name) VALUES ('test')"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - // Verify the INSERT target table - let mut raw_tables = raw_result.dml_tables(); - let mut proto_tables = proto_result.dml_tables(); - raw_tables.sort(); - proto_tables.sort(); - assert_eq!(raw_tables, proto_tables); - assert_eq!(raw_tables, vec!["users"]); -} - -/// Test parsing UPDATE statement -#[test] -fn it_parses_update() { - let query = "UPDATE users SET name = 'bob' WHERE id = 1"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - // Verify the UPDATE target table - let mut raw_tables = raw_result.dml_tables(); - let mut proto_tables = proto_result.dml_tables(); - raw_tables.sort(); - proto_tables.sort(); - assert_eq!(raw_tables, proto_tables); - assert_eq!(raw_tables, vec!["users"]); -} - -/// Test parsing DELETE statement -#[test] -fn it_parses_delete() { - let query = "DELETE FROM users WHERE id = 1"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - // Verify the DELETE target table - let mut raw_tables = raw_result.dml_tables(); - let mut proto_tables = proto_result.dml_tables(); - raw_tables.sort(); - proto_tables.sort(); - assert_eq!(raw_tables, proto_tables); - assert_eq!(raw_tables, vec!["users"]); -} - -// ============================================================================ -// DDL statement tests -// ============================================================================ - -/// Test parsing CREATE TABLE -#[test] -fn it_parses_create_table() { - let query = "CREATE TABLE test (id int, name text)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - // Verify statement types match - assert_eq!(raw_result.statement_types(), proto_result.statement_types()); - assert_eq!(raw_result.statement_types(), vec!["CreateStmt"]); -} - -/// Test parsing DROP TABLE -#[test] -fn it_parses_drop_table() { - let query = "DROP TABLE users"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - // Verify DDL tables match - let mut raw_tables = raw_result.ddl_tables(); - let mut proto_tables = proto_result.ddl_tables(); - raw_tables.sort(); - proto_tables.sort(); - assert_eq!(raw_tables, proto_tables); - assert_eq!(raw_tables, vec!["users"]); -} - -/// Test parsing CREATE INDEX -#[test] -fn it_parses_create_index() { - let query = "CREATE INDEX idx_users_name ON users (name)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - // Verify statement types match - assert_eq!(raw_result.statement_types(), proto_result.statement_types()); - assert_eq!(raw_result.statement_types(), vec!["IndexStmt"]); -} - -// ============================================================================ -// JOIN and complex SELECT tests -// ============================================================================ - -/// Test parsing SELECT with JOIN -#[test] -fn it_parses_join() { - let query = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - // Verify tables are extracted correctly - let mut raw_tables = raw_result.tables(); - let mut proto_tables = proto_result.tables(); - raw_tables.sort(); - proto_tables.sort(); - assert_eq!(raw_tables, proto_tables); - assert_eq!(raw_tables, vec!["orders", "users"]); -} - -/// Test parsing UNION query -#[test] -fn it_parses_union() { - let query = "SELECT id FROM users UNION SELECT id FROM admins"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - // Verify tables from both sides of UNION - let mut raw_tables = raw_result.tables(); - let mut proto_tables = proto_result.tables(); - raw_tables.sort(); - proto_tables.sort(); - assert_eq!(raw_tables, proto_tables); - assert_eq!(raw_tables, vec!["admins", "users"]); -} - -/// Test parsing WITH clause (CTE) -#[test] -fn it_parses_cte() { - let query = "WITH active_users AS (SELECT * FROM users WHERE active = true) SELECT * FROM active_users"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - // Verify CTE names match - assert_eq!(raw_result.cte_names, proto_result.cte_names); - assert!(raw_result.cte_names.contains(&"active_users".to_string())); - - // Verify tables (should only include actual tables, not CTEs) - let mut raw_tables = raw_result.tables(); - let mut proto_tables = proto_result.tables(); - raw_tables.sort(); - proto_tables.sort(); - assert_eq!(raw_tables, proto_tables); - assert_eq!(raw_tables, vec!["users"]); -} - -/// Test parsing subquery in SELECT -#[test] -fn it_parses_subquery() { - let query = "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - // Verify all tables are found - let mut raw_tables = raw_result.tables(); - let mut proto_tables = proto_result.tables(); - raw_tables.sort(); - proto_tables.sort(); - assert_eq!(raw_tables, proto_tables); - assert_eq!(raw_tables, vec!["orders", "users"]); -} - -/// Test parsing aggregate functions -#[test] -fn it_parses_aggregates() { - let query = "SELECT count(*), sum(amount), avg(price) FROM orders"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - // Verify functions are extracted correctly - let mut raw_funcs = raw_result.functions(); - let mut proto_funcs = proto_result.functions(); - raw_funcs.sort(); - proto_funcs.sort(); - assert_eq!(raw_funcs, proto_funcs); - assert!(raw_funcs.contains(&"count".to_string())); - assert!(raw_funcs.contains(&"sum".to_string())); - assert!(raw_funcs.contains(&"avg".to_string())); -} - -/// Test parsing CASE expression -#[test] -fn it_parses_case_expression() { - let query = "SELECT CASE WHEN x > 0 THEN 'positive' ELSE 'non-positive' END FROM t"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - // Verify table is found - let raw_tables = raw_result.tables(); - let proto_tables = proto_result.tables(); - assert_eq!(raw_tables, proto_tables); - assert_eq!(raw_tables, vec!["t"]); -} - -/// Test parsing complex SELECT with multiple clauses -#[test] -fn it_parses_complex_select() { - let query = "SELECT u.id, u.name, count(*) AS order_count FROM users u LEFT JOIN orders o ON u.id = o.user_id WHERE u.active = true GROUP BY u.id, u.name HAVING count(*) > 0 ORDER BY order_count DESC LIMIT 10"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - // Verify tables - let mut raw_tables = raw_result.tables(); - let mut proto_tables = proto_result.tables(); - raw_tables.sort(); - proto_tables.sort(); - assert_eq!(raw_tables, proto_tables); - assert_eq!(raw_tables, vec!["orders", "users"]); - - // Verify functions - let mut raw_funcs = raw_result.functions(); - let mut proto_funcs = proto_result.functions(); - raw_funcs.sort(); - proto_funcs.sort(); - assert_eq!(raw_funcs, proto_funcs); - assert!(raw_funcs.contains(&"count".to_string())); -} - -// ============================================================================ -// INSERT variations -// ============================================================================ - -/// Test parsing INSERT with ON CONFLICT -#[test] -fn it_parses_insert_on_conflict() { - let query = "INSERT INTO users (id, name) VALUES (1, 'test') ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - // Verify DML tables - let raw_tables = raw_result.dml_tables(); - let proto_tables = proto_result.dml_tables(); - assert_eq!(raw_tables, proto_tables); - assert_eq!(raw_tables, vec!["users"]); -} - -/// Test parsing INSERT with RETURNING -#[test] -fn it_parses_insert_returning() { - let query = "INSERT INTO users (name) VALUES ('test') RETURNING id"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -// ============================================================================ -// Literal value tests -// ============================================================================ - -/// Test parsing float with leading dot -#[test] -fn it_parses_floats_with_leading_dot() { - let query = "SELECT .1"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - // Verify the float value - let raw_const = get_first_const(&raw_result.protobuf).expect("should have const"); - let proto_const = get_first_const(&proto_result.protobuf).expect("should have const"); - assert_eq!(raw_const, proto_const); -} - -/// Test parsing bit string in hex notation -#[test] -fn it_parses_bit_strings_hex() { - let query = "SELECT X'EFFF'"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - // Verify the bit string value - let raw_const = get_first_const(&raw_result.protobuf).expect("should have const"); - let proto_const = get_first_const(&proto_result.protobuf).expect("should have const"); - assert_eq!(raw_const, proto_const); -} - -/// Test parsing real-world query with multiple joins -#[test] -fn it_parses_real_world_query() { - let query = " - SELECT memory_total_bytes, memory_free_bytes, memory_pagecache_bytes, - (memory_swap_total_bytes - memory_swap_free_bytes) AS swap - FROM snapshots s JOIN system_snapshots ON (snapshot_id = s.id) - WHERE s.database_id = 1 AND s.collected_at BETWEEN '2021-01-01' AND '2021-12-31' - ORDER BY collected_at"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - // Verify tables - let mut raw_tables = raw_result.tables(); - let mut proto_tables = proto_result.tables(); - raw_tables.sort(); - proto_tables.sort(); - assert_eq!(raw_tables, proto_tables); - assert_eq!(raw_tables, vec!["snapshots", "system_snapshots"]); -} - -// ============================================================================ -// A_Const value extraction tests -// ============================================================================ - -/// Test that parse_raw extracts integer values correctly and matches parse -#[test] -fn it_extracts_integer_const() { - let query = "SELECT 42"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); - let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); - - assert_eq!(raw_const, proto_const); - assert!(!raw_const.isnull); - match &raw_const.val { - Some(a_const::Val::Ival(int_val)) => { - assert_eq!(int_val.ival, 42); - } - other => panic!("Expected Ival, got {:?}", other), - } -} - -/// Test that parse_raw extracts negative integer values correctly -#[test] -fn it_extracts_negative_integer_const() { - let query = "SELECT -123"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test that parse_raw extracts string values correctly and matches parse -#[test] -fn it_extracts_string_const() { - let query = "SELECT 'hello world'"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); - let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); - - assert_eq!(raw_const, proto_const); - assert!(!raw_const.isnull); - match &raw_const.val { - Some(a_const::Val::Sval(str_val)) => { - assert_eq!(str_val.sval, "hello world"); - } - other => panic!("Expected Sval, got {:?}", other), - } -} - -/// Test that parse_raw extracts float values correctly and matches parse -#[test] -fn it_extracts_float_const() { - let query = "SELECT 3.14159"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); - let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); - - assert_eq!(raw_const, proto_const); - assert!(!raw_const.isnull); - match &raw_const.val { - Some(a_const::Val::Fval(float_val)) => { - assert_eq!(float_val.fval, "3.14159"); - } - other => panic!("Expected Fval, got {:?}", other), - } -} - -/// Test that parse_raw extracts boolean TRUE correctly and matches parse -#[test] -fn it_extracts_boolean_true_const() { - let query = "SELECT TRUE"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); - let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); - - assert_eq!(raw_const, proto_const); - assert!(!raw_const.isnull); - match &raw_const.val { - Some(a_const::Val::Boolval(bool_val)) => { - assert!(bool_val.boolval); - } - other => panic!("Expected Boolval(true), got {:?}", other), - } -} - -/// Test that parse_raw extracts boolean FALSE correctly and matches parse -#[test] -fn it_extracts_boolean_false_const() { - let query = "SELECT FALSE"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); - let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); - - assert_eq!(raw_const, proto_const); - assert!(!raw_const.isnull); - match &raw_const.val { - Some(a_const::Val::Boolval(bool_val)) => { - assert!(!bool_val.boolval); - } - other => panic!("Expected Boolval(false), got {:?}", other), - } -} - -/// Test that parse_raw extracts NULL correctly and matches parse -#[test] -fn it_extracts_null_const() { - let query = "SELECT NULL"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); - let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); - - assert_eq!(raw_const, proto_const); - assert!(raw_const.isnull); - assert!(raw_const.val.is_none()); -} - -/// Test that parse_raw extracts bit string values correctly and matches parse -#[test] -fn it_extracts_bit_string_const() { - let query = "SELECT B'1010'"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); - let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); - - assert_eq!(raw_const, proto_const); - assert!(!raw_const.isnull); - match &raw_const.val { - Some(a_const::Val::Bsval(bit_val)) => { - assert_eq!(bit_val.bsval, "b1010"); - } - other => panic!("Expected Bsval, got {:?}", other), - } -} - -/// Test that parse_raw extracts hex bit string correctly and matches parse -#[test] -fn it_extracts_hex_bit_string_const() { - let query = "SELECT X'FF'"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - let raw_const = get_first_const(&raw_result.protobuf).expect("Should have A_Const"); - let proto_const = get_first_const(&proto_result.protobuf).expect("Should have A_Const"); - - assert_eq!(raw_const, proto_const); - assert!(!raw_const.isnull); - match &raw_const.val { - Some(a_const::Val::Bsval(bit_val)) => { - assert_eq!(bit_val.bsval, "xFF"); - } - other => panic!("Expected Bsval, got {:?}", other), - } -} - -// ============================================================================ -// ParseResult method equivalence tests -// ============================================================================ - -/// Test that tables() returns the same results for both parsers -#[test] -fn it_returns_tables_like_parse() { - let query = "SELECT * FROM users"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - // Both should have the same tables - let mut raw_tables = raw_result.tables(); - let mut proto_tables = proto_result.tables(); - raw_tables.sort(); - proto_tables.sort(); - assert_eq!(raw_tables, proto_tables); - assert_eq!(raw_tables, vec!["users"]); -} - -/// Test that functions() returns the same results for both parsers -#[test] -fn it_returns_functions_like_parse() { - let query = "SELECT count(*), sum(amount) FROM orders"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - // Both should have the same functions - let mut raw_funcs = raw_result.functions(); - let mut proto_funcs = proto_result.functions(); - raw_funcs.sort(); - proto_funcs.sort(); - assert_eq!(raw_funcs, proto_funcs); - assert_eq!(raw_funcs, vec!["count", "sum"]); -} - -/// Test that statement_types() returns the same results for both parsers -#[test] -fn it_returns_statement_types_like_parse() { - let query = "SELECT 1; INSERT INTO t VALUES (1); UPDATE t SET x = 1; DELETE FROM t"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - // Full structural equality check - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - assert_eq!(raw_result.statement_types(), proto_result.statement_types()); - assert_eq!(raw_result.statement_types(), vec!["SelectStmt", "InsertStmt", "UpdateStmt", "DeleteStmt"]); -} - -// ============================================================================ -// Advanced JOIN tests -// ============================================================================ - -/// Test LEFT JOIN -#[test] -fn it_parses_left_join() { - let query = "SELECT * FROM users u LEFT JOIN orders o ON u.id = o.user_id"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - let mut raw_tables = raw_result.tables(); - raw_tables.sort(); - assert_eq!(raw_tables, vec!["orders", "users"]); -} - -/// Test RIGHT JOIN -#[test] -fn it_parses_right_join() { - let query = "SELECT * FROM users u RIGHT JOIN orders o ON u.id = o.user_id"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test FULL OUTER JOIN -#[test] -fn it_parses_full_outer_join() { - let query = "SELECT * FROM users u FULL OUTER JOIN orders o ON u.id = o.user_id"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test CROSS JOIN -#[test] -fn it_parses_cross_join() { - let query = "SELECT * FROM users CROSS JOIN products"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - let mut raw_tables = raw_result.tables(); - raw_tables.sort(); - assert_eq!(raw_tables, vec!["products", "users"]); -} - -/// Test NATURAL JOIN -#[test] -fn it_parses_natural_join() { - let query = "SELECT * FROM users NATURAL JOIN user_profiles"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test multiple JOINs -#[test] -fn it_parses_multiple_joins() { - let query = "SELECT u.name, o.id, p.name FROM users u - JOIN orders o ON u.id = o.user_id - JOIN order_items oi ON o.id = oi.order_id - JOIN products p ON oi.product_id = p.id"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - let mut raw_tables = raw_result.tables(); - raw_tables.sort(); - assert_eq!(raw_tables, vec!["order_items", "orders", "products", "users"]); -} - -/// Test JOIN with USING clause -#[test] -fn it_parses_join_using() { - let query = "SELECT * FROM users u JOIN user_profiles p USING (user_id)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test LATERAL JOIN -#[test] -fn it_parses_lateral_join() { - let query = "SELECT * FROM users u, LATERAL (SELECT * FROM orders o WHERE o.user_id = u.id LIMIT 3) AS recent_orders"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - let mut raw_tables = raw_result.tables(); - raw_tables.sort(); - assert_eq!(raw_tables, vec!["orders", "users"]); -} - -// ============================================================================ -// Advanced subquery tests -// ============================================================================ - -/// Test correlated subquery -#[test] -fn it_parses_correlated_subquery() { - let query = "SELECT * FROM users u WHERE EXISTS (SELECT 1 FROM orders o WHERE o.user_id = u.id)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); - - let mut raw_tables = raw_result.tables(); - raw_tables.sort(); - assert_eq!(raw_tables, vec!["orders", "users"]); -} - -/// Test NOT EXISTS subquery -#[test] -fn it_parses_not_exists_subquery() { - let query = "SELECT * FROM users u WHERE NOT EXISTS (SELECT 1 FROM banned b WHERE b.user_id = u.id)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test scalar subquery in SELECT -#[test] -fn it_parses_scalar_subquery() { - let query = "SELECT u.name, (SELECT COUNT(*) FROM orders o WHERE o.user_id = u.id) AS order_count FROM users u"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test subquery in FROM clause -#[test] -fn it_parses_derived_table() { - let query = "SELECT * FROM (SELECT id, name FROM users WHERE active = true) AS active_users"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test ANY/SOME subquery -#[test] -fn it_parses_any_subquery() { - let query = "SELECT * FROM products WHERE price > ANY (SELECT avg_price FROM categories)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test ALL subquery -#[test] -fn it_parses_all_subquery() { - let query = "SELECT * FROM products WHERE price > ALL (SELECT price FROM discounted_products)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -// ============================================================================ -// Window function tests -// ============================================================================ - -/// Test basic window function -#[test] -fn it_parses_window_function() { - let query = "SELECT name, salary, ROW_NUMBER() OVER (ORDER BY salary DESC) AS rank FROM employees"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test window function with PARTITION BY -#[test] -fn it_parses_window_function_partition() { - let query = "SELECT department, name, salary, RANK() OVER (PARTITION BY department ORDER BY salary DESC) AS dept_rank FROM employees"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test window function with frame clause -#[test] -fn it_parses_window_function_frame() { - let query = "SELECT date, amount, SUM(amount) OVER (ORDER BY date ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS moving_sum FROM transactions"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test named window -#[test] -fn it_parses_named_window() { - let query = "SELECT name, salary, SUM(salary) OVER w, AVG(salary) OVER w FROM employees WINDOW w AS (PARTITION BY department ORDER BY salary)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test LAG and LEAD functions -#[test] -fn it_parses_lag_lead() { - let query = - "SELECT date, price, LAG(price, 1) OVER (ORDER BY date) AS prev_price, LEAD(price, 1) OVER (ORDER BY date) AS next_price FROM stock_prices"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -// ============================================================================ -// CTE variations -// ============================================================================ - -/// Test multiple CTEs -#[test] -fn it_parses_multiple_ctes() { - let query = "WITH - active_users AS (SELECT * FROM users WHERE active = true), - premium_users AS (SELECT * FROM active_users WHERE plan = 'premium') - SELECT * FROM premium_users"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); - assert!(raw_result.cte_names.contains(&"active_users".to_string())); - assert!(raw_result.cte_names.contains(&"premium_users".to_string())); -} - -/// Test recursive CTE -#[test] -fn it_parses_recursive_cte() { - let query = "WITH RECURSIVE subordinates AS ( - SELECT id, name, manager_id FROM employees WHERE id = 1 - UNION ALL - SELECT e.id, e.name, e.manager_id FROM employees e INNER JOIN subordinates s ON e.manager_id = s.id - ) SELECT * FROM subordinates"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test CTE with column list -#[test] -fn it_parses_cte_with_columns() { - let query = "WITH regional_sales(region, total) AS (SELECT region, SUM(amount) FROM orders GROUP BY region) SELECT * FROM regional_sales"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test CTE with MATERIALIZED -#[test] -fn it_parses_cte_materialized() { - let query = "WITH t AS MATERIALIZED (SELECT * FROM large_table WHERE x > 100) SELECT * FROM t"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -// ============================================================================ -// Set operations -// ============================================================================ - -/// Test INTERSECT -#[test] -fn it_parses_intersect() { - let query = "SELECT id FROM users INTERSECT SELECT user_id FROM orders"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test EXCEPT -#[test] -fn it_parses_except() { - let query = "SELECT id FROM users EXCEPT SELECT user_id FROM banned_users"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test UNION ALL -#[test] -fn it_parses_union_all() { - let query = "SELECT name FROM users UNION ALL SELECT name FROM admins"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test compound set operations -#[test] -fn it_parses_compound_set_operations() { - let query = "(SELECT id FROM a UNION SELECT id FROM b) INTERSECT SELECT id FROM c"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -// ============================================================================ -// GROUP BY variations -// ============================================================================ - -/// Test GROUP BY ROLLUP -#[test] -fn it_parses_group_by_rollup() { - let query = "SELECT region, product, SUM(sales) FROM sales_data GROUP BY ROLLUP(region, product)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test GROUP BY CUBE -#[test] -fn it_parses_group_by_cube() { - let query = "SELECT region, product, SUM(sales) FROM sales_data GROUP BY CUBE(region, product)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test GROUP BY GROUPING SETS -#[test] -fn it_parses_grouping_sets() { - let query = "SELECT region, product, SUM(sales) FROM sales_data GROUP BY GROUPING SETS ((region), (product), ())"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -// ============================================================================ -// DISTINCT and ORDER BY variations -// ============================================================================ - -/// Test DISTINCT ON -#[test] -fn it_parses_distinct_on() { - let query = "SELECT DISTINCT ON (user_id) * FROM orders ORDER BY user_id, created_at DESC"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test ORDER BY with NULLS FIRST/LAST -#[test] -fn it_parses_order_by_nulls() { - let query = "SELECT * FROM users ORDER BY last_login DESC NULLS LAST"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test FETCH FIRST -#[test] -fn it_parses_fetch_first() { - let query = "SELECT * FROM users ORDER BY id FETCH FIRST 10 ROWS ONLY"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test OFFSET with FETCH -#[test] -fn it_parses_offset_fetch() { - let query = "SELECT * FROM users ORDER BY id OFFSET 20 ROWS FETCH NEXT 10 ROWS ONLY"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -// ============================================================================ -// Locking clauses -// ============================================================================ - -/// Test FOR UPDATE -#[test] -fn it_parses_for_update() { - let query = "SELECT * FROM users WHERE id = 1 FOR UPDATE"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test FOR SHARE -#[test] -fn it_parses_for_share() { - let query = "SELECT * FROM users WHERE id = 1 FOR SHARE"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test FOR UPDATE NOWAIT -#[test] -fn it_parses_for_update_nowait() { - let query = "SELECT * FROM users WHERE id = 1 FOR UPDATE NOWAIT"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test FOR UPDATE SKIP LOCKED -#[test] -fn it_parses_for_update_skip_locked() { - let query = "SELECT * FROM jobs WHERE status = 'pending' LIMIT 1 FOR UPDATE SKIP LOCKED"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -// ============================================================================ -// Expression tests -// ============================================================================ - -/// Test COALESCE -#[test] -fn it_parses_coalesce() { - let query = "SELECT COALESCE(nickname, name, 'Unknown') FROM users"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test NULLIF -#[test] -fn it_parses_nullif() { - let query = "SELECT NULLIF(status, 'deleted') FROM records"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test GREATEST and LEAST -#[test] -fn it_parses_greatest_least() { - let query = "SELECT GREATEST(a, b, c), LEAST(x, y, z) FROM t"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test IS NULL and IS NOT NULL -#[test] -fn it_parses_null_tests() { - let query = "SELECT * FROM users WHERE deleted_at IS NULL AND email IS NOT NULL"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test IS DISTINCT FROM -#[test] -fn it_parses_is_distinct_from() { - let query = "SELECT * FROM t WHERE a IS DISTINCT FROM b"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test BETWEEN -#[test] -fn it_parses_between() { - let query = "SELECT * FROM events WHERE created_at BETWEEN '2023-01-01' AND '2023-12-31'"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test LIKE and ILIKE -#[test] -fn it_parses_like_ilike() { - let query = "SELECT * FROM users WHERE name LIKE 'John%' OR email ILIKE '%@EXAMPLE.COM'"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test SIMILAR TO -#[test] -fn it_parses_similar_to() { - let query = "SELECT * FROM products WHERE name SIMILAR TO '%(phone|tablet)%'"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test complex boolean expressions -#[test] -fn it_parses_complex_boolean() { - let query = "SELECT * FROM users WHERE (active = true AND verified = true) OR (role = 'admin' AND NOT suspended)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -// ============================================================================ -// Type cast tests -// ============================================================================ - -/// Test PostgreSQL-style type cast -#[test] -fn it_parses_pg_type_cast() { - let query = "SELECT '123'::integer, '2023-01-01'::date, 'true'::boolean"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test SQL-style CAST -#[test] -fn it_parses_sql_cast() { - let query = "SELECT CAST('123' AS integer), CAST(created_at AS date) FROM t"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test array type cast -#[test] -fn it_parses_array_cast() { - let query = "SELECT ARRAY[1, 2, 3]::text[]"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -// ============================================================================ -// Array and JSON tests -// ============================================================================ - -/// Test array constructor -#[test] -fn it_parses_array_constructor() { - let query = "SELECT ARRAY[1, 2, 3], ARRAY['a', 'b', 'c']"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test array subscript -#[test] -fn it_parses_array_subscript() { - let query = "SELECT tags[1], matrix[1][2] FROM t"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test array slice -#[test] -fn it_parses_array_slice() { - let query = "SELECT arr[2:4], arr[:3], arr[2:] FROM t"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test unnest -#[test] -fn it_parses_unnest() { - let query = "SELECT unnest(ARRAY[1, 2, 3])"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test JSON operators -#[test] -fn it_parses_json_operators() { - let query = "SELECT data->'name', data->>'email', data#>'{address,city}' FROM users"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test JSONB containment -#[test] -fn it_parses_jsonb_containment() { - let query = "SELECT * FROM products WHERE metadata @> '{\"featured\": true}'"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -// ============================================================================ -// DDL statements -// ============================================================================ - -/// Test CREATE TABLE with constraints -#[test] -fn it_parses_create_table_with_constraints() { - let query = "CREATE TABLE orders ( - id SERIAL PRIMARY KEY, - user_id INTEGER NOT NULL REFERENCES users(id), - amount DECIMAL(10, 2) CHECK (amount > 0), - status TEXT DEFAULT 'pending', - created_at TIMESTAMP DEFAULT NOW(), - UNIQUE (user_id, created_at) - )"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test CREATE TABLE AS -#[test] -fn it_parses_create_table_as() { - let query = "CREATE TABLE active_users AS SELECT * FROM users WHERE active = true"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test CREATE VIEW -#[test] -fn it_parses_create_view() { - let query = "CREATE VIEW active_users AS SELECT id, name FROM users WHERE active = true"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test CREATE MATERIALIZED VIEW -#[test] -fn it_parses_create_materialized_view() { - let query = "CREATE MATERIALIZED VIEW monthly_sales AS SELECT date_trunc('month', created_at) AS month, SUM(amount) FROM orders GROUP BY 1"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test ALTER TABLE ADD COLUMN -#[test] -fn it_parses_alter_table_add_column() { - let query = "ALTER TABLE users ADD COLUMN email TEXT NOT NULL"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test ALTER TABLE DROP COLUMN -#[test] -fn it_parses_alter_table_drop_column() { - let query = "ALTER TABLE users DROP COLUMN deprecated_field"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test ALTER TABLE ADD CONSTRAINT -#[test] -fn it_parses_alter_table_add_constraint() { - let query = "ALTER TABLE orders ADD CONSTRAINT fk_user FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test CREATE INDEX with expression -#[test] -fn it_parses_create_index_expression() { - let query = "CREATE INDEX idx_lower_email ON users (lower(email))"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); +//! Tests for parse_raw functionality. +//! +//! These tests verify that parse_raw produces equivalent results to parse. +//! Tests are split into modules for maintainability. - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test CREATE UNIQUE INDEX with WHERE -#[test] -fn it_parses_partial_unique_index() { - let query = "CREATE UNIQUE INDEX idx_active_email ON users (email) WHERE active = true"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test CREATE INDEX CONCURRENTLY -#[test] -fn it_parses_create_index_concurrently() { - let query = "CREATE INDEX CONCURRENTLY idx_name ON users (name)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test TRUNCATE -#[test] -fn it_parses_truncate() { - let query = "TRUNCATE TABLE logs, audit_logs RESTART IDENTITY CASCADE"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -// ============================================================================ -// Transaction and utility statements -// ============================================================================ - -/// Test EXPLAIN -#[test] -fn it_parses_explain() { - let query = "EXPLAIN SELECT * FROM users WHERE id = 1"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test EXPLAIN ANALYZE -#[test] -fn it_parses_explain_analyze() { - let query = "EXPLAIN (ANALYZE, BUFFERS, FORMAT JSON) SELECT * FROM users"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test COPY -#[test] -fn it_parses_copy() { - let query = "COPY users (id, name, email) FROM STDIN WITH (FORMAT csv, HEADER true)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test PREPARE -#[test] -fn it_parses_prepare() { - let query = "PREPARE user_by_id (int) AS SELECT * FROM users WHERE id = $1"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test EXECUTE -#[test] -fn it_parses_execute() { - let query = "EXECUTE user_by_id(42)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test DEALLOCATE -#[test] -fn it_parses_deallocate() { - let query = "DEALLOCATE user_by_id"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -// ============================================================================ -// Parameter placeholder tests -// ============================================================================ - -/// Test positional parameters -#[test] -fn it_parses_positional_params() { - let query = "SELECT * FROM users WHERE id = $1 AND status = $2"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test parameters in INSERT -#[test] -fn it_parses_params_in_insert() { - let query = "INSERT INTO users (name, email) VALUES ($1, $2) RETURNING id"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -// ============================================================================ -// Complex real-world queries -// ============================================================================ - -/// Test analytics query with window functions -#[test] -fn it_parses_analytics_query() { - let query = " - SELECT - date_trunc('day', created_at) AS day, - COUNT(*) AS daily_orders, - SUM(amount) AS daily_revenue, - AVG(amount) OVER (ORDER BY date_trunc('day', created_at) ROWS BETWEEN 6 PRECEDING AND CURRENT ROW) AS weekly_avg - FROM orders - WHERE created_at >= NOW() - INTERVAL '30 days' - GROUP BY date_trunc('day', created_at) - ORDER BY day"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test hierarchical query with recursive CTE -#[test] -fn it_parses_hierarchy_query() { - let query = " - WITH RECURSIVE category_tree AS ( - SELECT id, name, parent_id, 0 AS level, ARRAY[id] AS path - FROM categories - WHERE parent_id IS NULL - UNION ALL - SELECT c.id, c.name, c.parent_id, ct.level + 1, ct.path || c.id - FROM categories c - JOIN category_tree ct ON c.parent_id = ct.id - ) - SELECT * FROM category_tree ORDER BY path"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test complex report query -#[test] -fn it_parses_complex_report_query() { - let query = " - WITH monthly_data AS ( - SELECT - date_trunc('month', o.created_at) AS month, - u.region, - p.category, - SUM(oi.quantity * oi.unit_price) AS revenue, - COUNT(DISTINCT o.id) AS order_count, - COUNT(DISTINCT o.user_id) AS customer_count - FROM orders o - JOIN users u ON o.user_id = u.id - JOIN order_items oi ON o.id = oi.order_id - JOIN products p ON oi.product_id = p.id - WHERE o.created_at >= '2023-01-01' AND o.status = 'completed' - GROUP BY 1, 2, 3 - ) - SELECT - month, - region, - category, - revenue, - order_count, - customer_count, - revenue / NULLIF(order_count, 0) AS avg_order_value, - SUM(revenue) OVER (PARTITION BY region ORDER BY month) AS cumulative_revenue - FROM monthly_data - ORDER BY month DESC, region, revenue DESC"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test query with multiple subqueries and CTEs -#[test] -fn it_parses_mixed_subqueries_and_ctes() { - let query = " - WITH high_value_customers AS ( - SELECT user_id FROM orders GROUP BY user_id HAVING SUM(amount) > 1000 - ) - SELECT u.*, - (SELECT COUNT(*) FROM orders o WHERE o.user_id = u.id) AS total_orders, - (SELECT MAX(created_at) FROM orders o WHERE o.user_id = u.id) AS last_order - FROM users u - WHERE u.id IN (SELECT user_id FROM high_value_customers) - AND EXISTS (SELECT 1 FROM orders o WHERE o.user_id = u.id AND o.created_at > NOW() - INTERVAL '90 days') - ORDER BY (SELECT SUM(amount) FROM orders o WHERE o.user_id = u.id) DESC - LIMIT 100"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -// ============================================================================ -// Complex INSERT tests -// ============================================================================ - -/// Test INSERT with multiple tuples -#[test] -fn it_parses_insert_multiple_rows() { - let query = "INSERT INTO users (name, email, age) VALUES ('Alice', 'alice@example.com', 25), ('Bob', 'bob@example.com', 30), ('Charlie', 'charlie@example.com', 35)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test INSERT ... SELECT -#[test] -fn it_parses_insert_select() { - let query = "INSERT INTO archived_users (id, name, email) SELECT id, name, email FROM users WHERE deleted_at IS NOT NULL"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test INSERT ... SELECT with complex query -#[test] -fn it_parses_insert_select_complex() { - let query = "INSERT INTO monthly_stats (month, user_count, order_count, total_revenue) - SELECT date_trunc('month', created_at) AS month, - COUNT(DISTINCT user_id), - COUNT(*), - SUM(amount) - FROM orders - WHERE created_at >= '2023-01-01' - GROUP BY date_trunc('month', created_at)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test INSERT with CTE -#[test] -fn it_parses_insert_with_cte() { - let query = "WITH new_data AS ( - SELECT name, email FROM temp_imports WHERE valid = true - ) - INSERT INTO users (name, email) SELECT name, email FROM new_data"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test INSERT with DEFAULT values -#[test] -fn it_parses_insert_default_values() { - let query = "INSERT INTO users (name, created_at) VALUES ('test', DEFAULT)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test INSERT with ON CONFLICT DO NOTHING -#[test] -fn it_parses_insert_on_conflict_do_nothing() { - let query = "INSERT INTO users (id, name) VALUES (1, 'test') ON CONFLICT (id) DO NOTHING"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test INSERT with ON CONFLICT with WHERE clause -#[test] -fn it_parses_insert_on_conflict_with_where() { - let query = "INSERT INTO users (id, name, updated_at) VALUES (1, 'test', NOW()) - ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, updated_at = EXCLUDED.updated_at - WHERE users.updated_at < EXCLUDED.updated_at"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test INSERT with multiple columns in ON CONFLICT -#[test] -fn it_parses_insert_on_conflict_multiple_columns() { - let query = "INSERT INTO user_settings (user_id, key, value) VALUES (1, 'theme', 'dark') - ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test INSERT with RETURNING multiple columns -#[test] -fn it_parses_insert_returning_multiple() { - let query = "INSERT INTO users (name, email) VALUES ('test', 'test@example.com') RETURNING id, created_at, name"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test INSERT with subquery in VALUES -#[test] -fn it_parses_insert_with_subquery_value() { - let query = "INSERT INTO orders (user_id, total) VALUES ((SELECT id FROM users WHERE email = 'test@example.com'), 100.00)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test INSERT with OVERRIDING -#[test] -fn it_parses_insert_overriding() { - let query = "INSERT INTO users (id, name) OVERRIDING SYSTEM VALUE VALUES (1, 'test')"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -// ============================================================================ -// Complex UPDATE tests -// ============================================================================ - -/// Test UPDATE with multiple columns -#[test] -fn it_parses_update_multiple_columns() { - let query = "UPDATE users SET name = 'new_name', email = 'new@example.com', updated_at = NOW() WHERE id = 1"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test UPDATE with subquery in SET -#[test] -fn it_parses_update_with_subquery_set() { - let query = "UPDATE orders SET total = (SELECT SUM(price * quantity) FROM order_items WHERE order_id = orders.id) WHERE id = 1"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test UPDATE with FROM clause (PostgreSQL-specific JOIN update) -#[test] -fn it_parses_update_from() { - let query = "UPDATE orders o SET status = 'shipped', shipped_at = NOW() - FROM shipments s - WHERE o.id = s.order_id AND s.status = 'delivered'"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test UPDATE with FROM and multiple tables -#[test] -fn it_parses_update_from_multiple_tables() { - let query = "UPDATE products p SET price = p.price * (1 + d.percentage / 100) - FROM discounts d - JOIN categories c ON d.category_id = c.id - WHERE p.category_id = c.id AND d.active = true"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test UPDATE with CTE -#[test] -fn it_parses_update_with_cte() { - let query = "WITH inactive_users AS ( - SELECT id FROM users WHERE last_login < NOW() - INTERVAL '1 year' - ) - UPDATE users SET status = 'inactive' WHERE id IN (SELECT id FROM inactive_users)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test UPDATE with RETURNING -#[test] -fn it_parses_update_returning() { - let query = "UPDATE users SET name = 'updated' WHERE id = 1 RETURNING id, name, updated_at"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test UPDATE with complex WHERE clause -#[test] -fn it_parses_update_complex_where() { - let query = "UPDATE orders SET status = 'cancelled' - WHERE created_at < NOW() - INTERVAL '30 days' - AND status = 'pending' - AND NOT EXISTS (SELECT 1 FROM payments WHERE payments.order_id = orders.id)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test UPDATE with row value comparison -#[test] -fn it_parses_update_row_comparison() { - let query = "UPDATE users SET (name, email) = ('new_name', 'new@example.com') WHERE id = 1"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test UPDATE with CASE expression -#[test] -fn it_parses_update_with_case() { - let query = "UPDATE products SET price = CASE - WHEN category = 'electronics' THEN price * 0.9 - WHEN category = 'clothing' THEN price * 0.8 - ELSE price * 0.95 - END - WHERE sale_active = true"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test UPDATE with array operations -#[test] -fn it_parses_update_array() { - let query = "UPDATE users SET tags = array_append(tags, 'verified') WHERE id = 1"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -// ============================================================================ -// Complex DELETE tests -// ============================================================================ - -/// Test DELETE with subquery in WHERE -#[test] -fn it_parses_delete_with_subquery() { - let query = "DELETE FROM orders WHERE user_id IN (SELECT id FROM users WHERE status = 'deleted')"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test DELETE with USING clause (PostgreSQL-specific JOIN delete) -#[test] -fn it_parses_delete_using() { - let query = "DELETE FROM order_items oi USING orders o - WHERE oi.order_id = o.id AND o.status = 'cancelled'"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test DELETE with USING and multiple tables -#[test] -fn it_parses_delete_using_multiple_tables() { - let query = "DELETE FROM notifications n - USING users u, user_settings s - WHERE n.user_id = u.id - AND u.id = s.user_id - AND s.key = 'email_notifications' - AND s.value = 'false'"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test DELETE with CTE -#[test] -fn it_parses_delete_with_cte() { - let query = "WITH old_orders AS ( - SELECT id FROM orders WHERE created_at < NOW() - INTERVAL '5 years' - ) - DELETE FROM order_items WHERE order_id IN (SELECT id FROM old_orders)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test DELETE with RETURNING -#[test] -fn it_parses_delete_returning() { - let query = "DELETE FROM users WHERE id = 1 RETURNING id, name, email"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test DELETE with EXISTS -#[test] -fn it_parses_delete_with_exists() { - let query = "DELETE FROM products p - WHERE NOT EXISTS (SELECT 1 FROM order_items oi WHERE oi.product_id = p.id) - AND p.created_at < NOW() - INTERVAL '1 year'"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test DELETE with complex boolean conditions -#[test] -fn it_parses_delete_complex_conditions() { - let query = "DELETE FROM logs - WHERE (level = 'debug' AND created_at < NOW() - INTERVAL '7 days') - OR (level = 'info' AND created_at < NOW() - INTERVAL '30 days') - OR (level IN ('warning', 'error') AND created_at < NOW() - INTERVAL '90 days')"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test DELETE with LIMIT (PostgreSQL extension) -#[test] -fn it_parses_delete_only() { - let query = "DELETE FROM ONLY parent_table WHERE id = 1"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -// ============================================================================ -// Combined DML with CTEs -// ============================================================================ - -/// Test data modification CTE (INSERT in CTE) -#[test] -fn it_parses_insert_cte_returning() { - let query = "WITH inserted AS ( - INSERT INTO users (name, email) VALUES ('test', 'test@example.com') RETURNING id, name - ) - SELECT * FROM inserted"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test UPDATE in CTE with final SELECT -#[test] -fn it_parses_update_cte_returning() { - let query = "WITH updated AS ( - UPDATE users SET last_login = NOW() WHERE id = 1 RETURNING id, name, last_login - ) - SELECT * FROM updated"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test DELETE in CTE with final SELECT -#[test] -fn it_parses_delete_cte_returning() { - let query = "WITH deleted AS ( - DELETE FROM expired_sessions WHERE expires_at < NOW() RETURNING user_id - ) - SELECT COUNT(*) FROM deleted"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test chained CTEs with multiple DML operations -#[test] -fn it_parses_chained_dml_ctes() { - let query = "WITH - to_archive AS ( - SELECT id FROM users WHERE last_login < NOW() - INTERVAL '2 years' - ), - archived AS ( - INSERT INTO archived_users SELECT * FROM users WHERE id IN (SELECT id FROM to_archive) RETURNING id - ), - deleted AS ( - DELETE FROM users WHERE id IN (SELECT id FROM archived) RETURNING id - ) - SELECT COUNT(*) as archived_count FROM deleted"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -// ============================================================================ -// Tests for previously stubbed fields -// ============================================================================ - -/// Test column with COLLATE clause -#[test] -fn it_parses_column_with_collate() { - let query = "CREATE TABLE test_collate ( - name TEXT COLLATE \"C\", - description VARCHAR(255) COLLATE \"en_US.UTF-8\" - )"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test partitioned table with PARTITION BY RANGE -#[test] -fn it_parses_partition_by_range() { - let query = "CREATE TABLE measurements ( - id SERIAL, - logdate DATE NOT NULL, - peaktemp INT - ) PARTITION BY RANGE (logdate)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test partitioned table with PARTITION BY LIST -#[test] -fn it_parses_partition_by_list() { - let query = "CREATE TABLE orders ( - id SERIAL, - region TEXT NOT NULL, - order_date DATE - ) PARTITION BY LIST (region)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test partitioned table with PARTITION BY HASH -#[test] -fn it_parses_partition_by_hash() { - let query = "CREATE TABLE users_partitioned ( - id SERIAL, - username TEXT - ) PARTITION BY HASH (id)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test partition with FOR VALUES (range) -#[test] -fn it_parses_partition_for_values_range() { - let query = "CREATE TABLE measurements_2023 PARTITION OF measurements - FOR VALUES FROM ('2023-01-01') TO ('2024-01-01')"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test partition with FOR VALUES (list) -#[test] -fn it_parses_partition_for_values_list() { - let query = "CREATE TABLE orders_west PARTITION OF orders - FOR VALUES IN ('west', 'northwest', 'southwest')"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test partition with FOR VALUES (hash) -#[test] -fn it_parses_partition_for_values_hash() { - let query = "CREATE TABLE users_part_0 PARTITION OF users_partitioned - FOR VALUES WITH (MODULUS 4, REMAINDER 0)"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test partition with DEFAULT -#[test] -fn it_parses_partition_default() { - let query = "CREATE TABLE orders_other PARTITION OF orders DEFAULT"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test recursive CTE with SEARCH BREADTH FIRST -#[test] -fn it_parses_cte_search_breadth_first() { - let query = "WITH RECURSIVE search_tree(id, parent_id, data, depth) AS ( - SELECT id, parent_id, data, 0 FROM tree WHERE parent_id IS NULL - UNION ALL - SELECT t.id, t.parent_id, t.data, st.depth + 1 - FROM tree t, search_tree st WHERE t.parent_id = st.id - ) SEARCH BREADTH FIRST BY id SET ordercol - SELECT * FROM search_tree ORDER BY ordercol"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test recursive CTE with SEARCH DEPTH FIRST -#[test] -fn it_parses_cte_search_depth_first() { - let query = "WITH RECURSIVE search_tree(id, parent_id, data) AS ( - SELECT id, parent_id, data FROM tree WHERE parent_id IS NULL - UNION ALL - SELECT t.id, t.parent_id, t.data - FROM tree t, search_tree st WHERE t.parent_id = st.id - ) SEARCH DEPTH FIRST BY id SET ordercol - SELECT * FROM search_tree ORDER BY ordercol"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} - -/// Test recursive CTE with CYCLE detection -#[test] -fn it_parses_cte_cycle() { - let query = "WITH RECURSIVE search_graph(id, link, data, depth) AS ( - SELECT g.id, g.link, g.data, 0 FROM graph g - UNION ALL - SELECT g.id, g.link, g.data, sg.depth + 1 - FROM graph g, search_graph sg WHERE g.id = sg.link - ) CYCLE id SET is_cycle USING path - SELECT * FROM search_graph"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} +#![allow(non_snake_case)] +#![cfg(test)] -/// Test recursive CTE with both SEARCH and CYCLE -#[test] -fn it_parses_cte_search_and_cycle() { - let query = "WITH RECURSIVE search_graph(id, link, data, depth) AS ( - SELECT g.id, g.link, g.data, 0 FROM graph g WHERE id = 1 - UNION ALL - SELECT g.id, g.link, g.data, sg.depth + 1 - FROM graph g, search_graph sg WHERE g.id = sg.link - ) SEARCH DEPTH FIRST BY id SET ordercol - CYCLE id SET is_cycle USING path - SELECT * FROM search_graph"; - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); +#[macro_use] +mod support; - assert_eq!(raw_result.protobuf, proto_result.protobuf); -} +mod raw_parse; -// ============================================================================ -// Benchmark -// ============================================================================ +// Re-export the benchmark test at the top level +use pg_query::{parse, parse_raw}; +use std::time::{Duration, Instant}; /// Benchmark comparing parse_raw vs parse performance #[test] fn benchmark_parse_raw_vs_parse() { - use std::time::{Duration, Instant}; - // Complex query with multiple features: CTEs, JOINs, subqueries, window functions, etc. let query = r#" WITH RECURSIVE @@ -2327,22 +81,16 @@ fn benchmark_parse_raw_vs_parse() { ORDER BY total_spent DESC NULLS LAST, u.created_at ASC LIMIT 100 OFFSET 0 - FOR UPDATE OF u SKIP LOCKED - "#; - - // Verify both produce the same result first - let raw_result = parse_raw(query).unwrap(); - let proto_result = parse(query).unwrap(); - assert_eq!(raw_result.protobuf, proto_result.protobuf); + FOR UPDATE OF u SKIP LOCKED"#; // Warm up - for _ in 0..100 { + for _ in 0..10 { let _ = parse_raw(query).unwrap(); let _ = parse(query).unwrap(); } - // Target ~2 seconds per benchmark (4 seconds total) - let target_duration = Duration::from_millis(2000); + // Run for a fixed duration to get stable measurements + let target_duration = Duration::from_secs(2); // Benchmark parse_raw let mut raw_iterations = 0u64; From 08efbca401b345595588089b822ec7e7b21e8f6e Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 31 Dec 2025 14:06:55 -0800 Subject: [PATCH 09/17] Add deparse_raw --- build.rs | 15 + libpg_query | 2 +- src/lib.rs | 2 + src/raw_deparse.rs | 781 ++++++++++++++++++++++++++++++++++++++ tests/raw_parse/basic.rs | 77 ++++ tests/raw_parse/dml.rs | 7 + tests/raw_parse/mod.rs | 18 +- tests/raw_parse/select.rs | 7 + tests/raw_parse_tests.rs | 144 ++++++- 9 files changed, 1049 insertions(+), 4 deletions(-) create mode 100644 src/raw_deparse.rs diff --git a/build.rs b/build.rs index 2accb5d..0b2e498 100644 --- a/build.rs +++ b/build.rs @@ -66,6 +66,9 @@ fn main() -> Result<(), Box> { .blocklist_function("pg_query_parse_raw_opts") .blocklist_function("pg_query_free_raw_parse_result") .blocklist_type("PgQueryRawParseResult") + // Blocklist raw deparse functions that use types from bindings_raw + .blocklist_function("pg_query_deparse_raw") + .blocklist_function("pg_query_deparse_raw_opts") .generate() .map_err(|_| "Unable to generate bindings")? .write_to_file(out_dir.join("bindings.rs"))?; @@ -243,6 +246,18 @@ fn main() -> Result<(), Box> { .allowlist_function("pg_query_parse_raw") .allowlist_function("pg_query_parse_raw_opts") .allowlist_function("pg_query_free_raw_parse_result") + // Allowlist raw deparse functions + .allowlist_function("pg_query_deparse_raw") + .allowlist_function("pg_query_deparse_raw_opts") + .allowlist_function("pg_query_free_deparse_result") + // Node building helpers for deparse_raw + .allowlist_function("pg_query_deparse_enter_context") + .allowlist_function("pg_query_deparse_exit_context") + .allowlist_function("pg_query_alloc_node") + .allowlist_function("pg_query_pstrdup") + .allowlist_function("pg_query_list_make1") + .allowlist_function("pg_query_list_append") + .allowlist_function("pg_query_deparse_nodes") .generate() .map_err(|_| "Unable to generate raw bindings")? .write_to_file(out_dir.join("bindings_raw.rs"))?; diff --git a/libpg_query b/libpg_query index 0946937..db02663 160000 --- a/libpg_query +++ b/libpg_query @@ -1 +1 @@ -Subproject commit 09469376d81131912d61374709b8331c85831837 +Subproject commit db02663b0be81a8499ee8c0cc87081effb13d54e diff --git a/src/lib.rs b/src/lib.rs index 01bfa6b..135f914 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,6 +49,7 @@ mod parse_result; #[rustfmt::skip] pub mod protobuf; mod query; +mod raw_deparse; mod raw_parse; mod summary; mod summary_result; @@ -60,6 +61,7 @@ pub use node_mut::*; pub use node_ref::*; pub use parse_result::*; pub use query::*; +pub use raw_deparse::deparse_raw; pub use raw_parse::parse_raw; pub use summary::*; pub use summary_result::*; diff --git a/src/raw_deparse.rs b/src/raw_deparse.rs new file mode 100644 index 0000000..09053cf --- /dev/null +++ b/src/raw_deparse.rs @@ -0,0 +1,781 @@ +//! Direct deparsing that bypasses protobuf serialization. +//! +//! This module converts Rust protobuf types directly to PostgreSQL's internal +//! C parse tree structures, then deparses them to SQL without going through +//! protobuf serialization. + +use crate::bindings_raw; +use crate::protobuf; +use crate::{Error, Result}; +use std::ffi::CStr; +use std::os::raw::c_char; + +/// Deparses a protobuf ParseResult directly to SQL without protobuf serialization. +/// +/// This function is faster than `deparse` because it skips the protobuf encode/decode step. +/// The protobuf types are converted directly to PostgreSQL's internal C structures. +/// +/// # Example +/// +/// ```rust +/// let result = pg_query::parse("SELECT * FROM users").unwrap(); +/// let sql = pg_query::deparse_raw(&result.protobuf).unwrap(); +/// assert_eq!(sql, "SELECT * FROM users"); +/// ``` +pub fn deparse_raw(protobuf: &protobuf::ParseResult) -> Result { + unsafe { + // Enter PostgreSQL memory context - this must stay active for the entire operation + let ctx = bindings_raw::pg_query_deparse_enter_context(); + + // Build C nodes from protobuf types (uses palloc which requires active context) + let stmts = write_stmts(&protobuf.stmts); + + // Deparse the nodes to SQL (also requires active context) + let result = bindings_raw::pg_query_deparse_nodes(stmts); + + // Exit memory context - this frees all palloc'd memory + bindings_raw::pg_query_deparse_exit_context(ctx); + + // Handle result (result.query is strdup'd, so it survives context exit) + if !result.error.is_null() { + let message = CStr::from_ptr((*result.error).message).to_string_lossy().to_string(); + bindings_raw::pg_query_free_deparse_result(result); + return Err(Error::Parse(message)); + } + + let query = CStr::from_ptr(result.query).to_string_lossy().to_string(); + bindings_raw::pg_query_free_deparse_result(result); + Ok(query) + } +} + +/// Allocates a C node of the given type. +unsafe fn alloc_node(tag: u32) -> *mut T { + bindings_raw::pg_query_alloc_node(std::mem::size_of::(), tag as i32) as *mut T +} + +/// Converts a protobuf enum value to a C enum value. +/// Protobuf enums have an extra "Undefined = 0" value, so we subtract 1. +/// If the value is 0 (Undefined), we return 0 (treating it as the first C enum value). +fn proto_enum_to_c(value: i32) -> u32 { + if value <= 0 { + 0 + } else { + (value - 1) as u32 + } +} + +/// Duplicates a string into PostgreSQL memory context. +unsafe fn pstrdup(s: &str) -> *mut c_char { + if s.is_empty() { + return std::ptr::null_mut(); + } + let cstr = std::ffi::CString::new(s).unwrap(); + bindings_raw::pg_query_pstrdup(cstr.as_ptr()) +} + +/// Writes a list of RawStmt to a C List. +fn write_stmts(stmts: &[protobuf::RawStmt]) -> *mut std::ffi::c_void { + if stmts.is_empty() { + return std::ptr::null_mut(); + } + + let mut list: *mut std::ffi::c_void = std::ptr::null_mut(); + + for stmt in stmts { + let raw_stmt = write_raw_stmt(stmt); + if list.is_null() { + list = unsafe { bindings_raw::pg_query_list_make1(raw_stmt as *mut std::ffi::c_void) }; + } else { + list = unsafe { bindings_raw::pg_query_list_append(list, raw_stmt as *mut std::ffi::c_void) }; + } + } + + list +} + +/// Writes a protobuf RawStmt to a C RawStmt. +fn write_raw_stmt(stmt: &protobuf::RawStmt) -> *mut bindings_raw::RawStmt { + unsafe { + let raw_stmt = alloc_node::(bindings_raw::NodeTag_T_RawStmt); + (*raw_stmt).stmt_location = stmt.stmt_location; + (*raw_stmt).stmt_len = stmt.stmt_len; + (*raw_stmt).stmt = write_node_boxed(&stmt.stmt); + raw_stmt + } +} + +/// Writes an Option> to a C Node pointer. +fn write_node_boxed(node: &Option>) -> *mut bindings_raw::Node { + match node { + Some(n) => write_node(n), + None => std::ptr::null_mut(), + } +} + +/// Writes a protobuf Node to a C Node. +fn write_node(node: &protobuf::Node) -> *mut bindings_raw::Node { + match &node.node { + Some(n) => write_node_inner(n), + None => std::ptr::null_mut(), + } +} + +/// Writes a protobuf node::Node enum to a C Node. +fn write_node_inner(node: &protobuf::node::Node) -> *mut bindings_raw::Node { + unsafe { + match node { + protobuf::node::Node::SelectStmt(stmt) => write_select_stmt(stmt) as *mut bindings_raw::Node, + protobuf::node::Node::InsertStmt(stmt) => write_insert_stmt(stmt) as *mut bindings_raw::Node, + protobuf::node::Node::UpdateStmt(stmt) => write_update_stmt(stmt) as *mut bindings_raw::Node, + protobuf::node::Node::DeleteStmt(stmt) => write_delete_stmt(stmt) as *mut bindings_raw::Node, + protobuf::node::Node::RangeVar(rv) => write_range_var(rv) as *mut bindings_raw::Node, + protobuf::node::Node::Alias(alias) => write_alias(alias) as *mut bindings_raw::Node, + protobuf::node::Node::ResTarget(rt) => write_res_target(rt) as *mut bindings_raw::Node, + protobuf::node::Node::ColumnRef(cr) => write_column_ref(cr) as *mut bindings_raw::Node, + protobuf::node::Node::AConst(ac) => write_a_const(ac) as *mut bindings_raw::Node, + protobuf::node::Node::AExpr(expr) => write_a_expr(expr) as *mut bindings_raw::Node, + protobuf::node::Node::FuncCall(fc) => write_func_call(fc) as *mut bindings_raw::Node, + protobuf::node::Node::String(s) => write_string(s) as *mut bindings_raw::Node, + protobuf::node::Node::Integer(i) => write_integer(i) as *mut bindings_raw::Node, + protobuf::node::Node::Float(f) => write_float(f) as *mut bindings_raw::Node, + protobuf::node::Node::Boolean(b) => write_boolean(b) as *mut bindings_raw::Node, + protobuf::node::Node::List(l) => write_list(l) as *mut bindings_raw::Node, + protobuf::node::Node::AStar(_) => write_a_star() as *mut bindings_raw::Node, + protobuf::node::Node::JoinExpr(je) => write_join_expr(je) as *mut bindings_raw::Node, + protobuf::node::Node::SortBy(sb) => write_sort_by(sb) as *mut bindings_raw::Node, + protobuf::node::Node::TypeCast(tc) => write_type_cast(tc) as *mut bindings_raw::Node, + protobuf::node::Node::TypeName(tn) => write_type_name(tn) as *mut bindings_raw::Node, + protobuf::node::Node::ParamRef(pr) => write_param_ref(pr) as *mut bindings_raw::Node, + protobuf::node::Node::NullTest(nt) => write_null_test(nt) as *mut bindings_raw::Node, + protobuf::node::Node::BoolExpr(be) => write_bool_expr(be) as *mut bindings_raw::Node, + protobuf::node::Node::SubLink(sl) => write_sub_link(sl) as *mut bindings_raw::Node, + protobuf::node::Node::RangeSubselect(rs) => write_range_subselect(rs) as *mut bindings_raw::Node, + protobuf::node::Node::CommonTableExpr(cte) => write_common_table_expr(cte) as *mut bindings_raw::Node, + protobuf::node::Node::WithClause(wc) => write_with_clause(wc) as *mut bindings_raw::Node, + protobuf::node::Node::GroupingSet(gs) => write_grouping_set(gs) as *mut bindings_raw::Node, + protobuf::node::Node::WindowDef(wd) => write_window_def(wd) as *mut bindings_raw::Node, + protobuf::node::Node::CoalesceExpr(ce) => write_coalesce_expr(ce) as *mut bindings_raw::Node, + protobuf::node::Node::CaseExpr(ce) => write_case_expr(ce) as *mut bindings_raw::Node, + protobuf::node::Node::CaseWhen(cw) => write_case_when(cw) as *mut bindings_raw::Node, + protobuf::node::Node::SetToDefault(_) => write_set_to_default() as *mut bindings_raw::Node, + protobuf::node::Node::LockingClause(lc) => write_locking_clause(lc) as *mut bindings_raw::Node, + protobuf::node::Node::RangeFunction(rf) => write_range_function(rf) as *mut bindings_raw::Node, + protobuf::node::Node::BitString(bs) => write_bit_string(bs) as *mut bindings_raw::Node, + protobuf::node::Node::IndexElem(ie) => write_index_elem(ie) as *mut bindings_raw::Node, + // TODO: Add remaining node types as needed + _ => { + // For unimplemented nodes, return null and let the deparser handle it + std::ptr::null_mut() + } + } + } +} + +/// Writes a list of protobuf Nodes to a C List. +fn write_node_list(nodes: &[protobuf::Node]) -> *mut bindings_raw::List { + if nodes.is_empty() { + return std::ptr::null_mut(); + } + + let mut list: *mut std::ffi::c_void = std::ptr::null_mut(); + + for node in nodes { + let c_node = write_node(node); + if !c_node.is_null() { + if list.is_null() { + list = unsafe { bindings_raw::pg_query_list_make1(c_node as *mut std::ffi::c_void) }; + } else { + list = unsafe { bindings_raw::pg_query_list_append(list, c_node as *mut std::ffi::c_void) }; + } + } + } + + list as *mut bindings_raw::List +} + +// ============================================================================= +// Individual node type writers +// ============================================================================= + +unsafe fn write_select_stmt(stmt: &protobuf::SelectStmt) -> *mut bindings_raw::SelectStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_SelectStmt); + (*node).distinctClause = write_node_list(&stmt.distinct_clause); + (*node).intoClause = write_into_clause_opt(&stmt.into_clause); + (*node).targetList = write_node_list(&stmt.target_list); + (*node).fromClause = write_node_list(&stmt.from_clause); + (*node).whereClause = write_node_boxed(&stmt.where_clause); + (*node).groupClause = write_node_list(&stmt.group_clause); + (*node).groupDistinct = stmt.group_distinct; + (*node).havingClause = write_node_boxed(&stmt.having_clause); + (*node).windowClause = write_node_list(&stmt.window_clause); + (*node).valuesLists = write_values_lists(&stmt.values_lists); + (*node).sortClause = write_node_list(&stmt.sort_clause); + (*node).limitOffset = write_node_boxed(&stmt.limit_offset); + (*node).limitCount = write_node_boxed(&stmt.limit_count); + (*node).limitOption = proto_enum_to_c(stmt.limit_option); + (*node).lockingClause = write_node_list(&stmt.locking_clause); + (*node).withClause = write_with_clause_ref(&stmt.with_clause); + (*node).op = proto_enum_to_c(stmt.op); + (*node).all = stmt.all; + (*node).larg = write_select_stmt_opt(&stmt.larg); + (*node).rarg = write_select_stmt_opt(&stmt.rarg); + node +} + +unsafe fn write_select_stmt_opt(stmt: &Option>) -> *mut bindings_raw::SelectStmt { + match stmt { + Some(s) => write_select_stmt(s), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_into_clause_opt(ic: &Option>) -> *mut bindings_raw::IntoClause { + match ic { + Some(into) => write_into_clause(into), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_into_clause(ic: &protobuf::IntoClause) -> *mut bindings_raw::IntoClause { + let node = alloc_node::(bindings_raw::NodeTag_T_IntoClause); + (*node).rel = write_range_var_ref(&ic.rel); + (*node).colNames = write_node_list(&ic.col_names); + (*node).accessMethod = pstrdup(&ic.access_method); + (*node).options = write_node_list(&ic.options); + (*node).onCommit = proto_enum_to_c(ic.on_commit); + (*node).tableSpaceName = pstrdup(&ic.table_space_name); + (*node).viewQuery = write_node_boxed(&ic.view_query); + (*node).skipData = ic.skip_data; + node +} + +unsafe fn write_insert_stmt(stmt: &protobuf::InsertStmt) -> *mut bindings_raw::InsertStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_InsertStmt); + (*node).relation = write_range_var_ref(&stmt.relation); + (*node).cols = write_node_list(&stmt.cols); + (*node).selectStmt = write_node_boxed(&stmt.select_stmt); + (*node).onConflictClause = write_on_conflict_clause_opt(&stmt.on_conflict_clause); + (*node).returningList = write_node_list(&stmt.returning_list); + (*node).withClause = write_with_clause_ref(&stmt.with_clause); + (*node).override_ = proto_enum_to_c(stmt.r#override); + node +} + +unsafe fn write_on_conflict_clause_opt(oc: &Option>) -> *mut bindings_raw::OnConflictClause { + match oc { + Some(clause) => write_on_conflict_clause(clause), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_on_conflict_clause(oc: &protobuf::OnConflictClause) -> *mut bindings_raw::OnConflictClause { + let node = alloc_node::(bindings_raw::NodeTag_T_OnConflictClause); + (*node).action = proto_enum_to_c(oc.action); + (*node).infer = write_infer_clause_opt(&oc.infer); + (*node).targetList = write_node_list(&oc.target_list); + (*node).whereClause = write_node_boxed(&oc.where_clause); + (*node).location = oc.location; + node +} + +unsafe fn write_infer_clause_opt(ic: &Option>) -> *mut bindings_raw::InferClause { + match ic { + Some(infer) => { + let node = alloc_node::(bindings_raw::NodeTag_T_InferClause); + (*node).indexElems = write_node_list(&infer.index_elems); + (*node).whereClause = write_node_boxed(&infer.where_clause); + (*node).conname = pstrdup(&infer.conname); + (*node).location = infer.location; + node + } + None => std::ptr::null_mut(), + } +} + +unsafe fn write_update_stmt(stmt: &protobuf::UpdateStmt) -> *mut bindings_raw::UpdateStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_UpdateStmt); + (*node).relation = write_range_var_ref(&stmt.relation); + (*node).targetList = write_node_list(&stmt.target_list); + (*node).whereClause = write_node_boxed(&stmt.where_clause); + (*node).fromClause = write_node_list(&stmt.from_clause); + (*node).returningList = write_node_list(&stmt.returning_list); + (*node).withClause = write_with_clause_ref(&stmt.with_clause); + node +} + +unsafe fn write_delete_stmt(stmt: &protobuf::DeleteStmt) -> *mut bindings_raw::DeleteStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_DeleteStmt); + (*node).relation = write_range_var_ref(&stmt.relation); + (*node).usingClause = write_node_list(&stmt.using_clause); + (*node).whereClause = write_node_boxed(&stmt.where_clause); + (*node).returningList = write_node_list(&stmt.returning_list); + (*node).withClause = write_with_clause_ref(&stmt.with_clause); + node +} + +unsafe fn write_range_var(rv: &protobuf::RangeVar) -> *mut bindings_raw::RangeVar { + let node = alloc_node::(bindings_raw::NodeTag_T_RangeVar); + (*node).catalogname = pstrdup(&rv.catalogname); + (*node).schemaname = pstrdup(&rv.schemaname); + (*node).relname = pstrdup(&rv.relname); + (*node).inh = rv.inh; + (*node).relpersistence = if rv.relpersistence.is_empty() { 'p' as i8 } else { rv.relpersistence.chars().next().unwrap() as i8 }; + (*node).alias = write_alias_ref(&rv.alias); + (*node).location = rv.location; + node +} + +unsafe fn write_range_var_opt(rv: &Option>) -> *mut bindings_raw::RangeVar { + match rv { + Some(r) => write_range_var(r), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_range_var_ref(rv: &Option) -> *mut bindings_raw::RangeVar { + match rv { + Some(r) => write_range_var(r), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_alias(alias: &protobuf::Alias) -> *mut bindings_raw::Alias { + let node = alloc_node::(bindings_raw::NodeTag_T_Alias); + (*node).aliasname = pstrdup(&alias.aliasname); + (*node).colnames = write_node_list(&alias.colnames); + node +} + +unsafe fn write_alias_opt(alias: &Option>) -> *mut bindings_raw::Alias { + match alias { + Some(a) => write_alias(a), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_alias_ref(alias: &Option) -> *mut bindings_raw::Alias { + match alias { + Some(a) => write_alias(a), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_res_target(rt: &protobuf::ResTarget) -> *mut bindings_raw::ResTarget { + let node = alloc_node::(bindings_raw::NodeTag_T_ResTarget); + (*node).name = pstrdup(&rt.name); + (*node).indirection = write_node_list(&rt.indirection); + (*node).val = write_node_boxed(&rt.val); + (*node).location = rt.location; + node +} + +unsafe fn write_column_ref(cr: &protobuf::ColumnRef) -> *mut bindings_raw::ColumnRef { + let node = alloc_node::(bindings_raw::NodeTag_T_ColumnRef); + (*node).fields = write_node_list(&cr.fields); + (*node).location = cr.location; + node +} + +unsafe fn write_a_const(ac: &protobuf::AConst) -> *mut bindings_raw::A_Const { + let node = alloc_node::(bindings_raw::NodeTag_T_A_Const); + (*node).location = ac.location; + (*node).isnull = ac.isnull; + + if let Some(val) = &ac.val { + match val { + protobuf::a_const::Val::Ival(i) => { + (*node).val.ival.type_ = bindings_raw::NodeTag_T_Integer; + (*node).val.ival.ival = i.ival; + } + protobuf::a_const::Val::Fval(f) => { + (*node).val.fval.type_ = bindings_raw::NodeTag_T_Float; + (*node).val.fval.fval = pstrdup(&f.fval); + } + protobuf::a_const::Val::Boolval(b) => { + (*node).val.boolval.type_ = bindings_raw::NodeTag_T_Boolean; + (*node).val.boolval.boolval = b.boolval; + } + protobuf::a_const::Val::Sval(s) => { + (*node).val.sval.type_ = bindings_raw::NodeTag_T_String; + (*node).val.sval.sval = pstrdup(&s.sval); + } + protobuf::a_const::Val::Bsval(bs) => { + (*node).val.bsval.type_ = bindings_raw::NodeTag_T_BitString; + (*node).val.bsval.bsval = pstrdup(&bs.bsval); + } + } + } + node +} + +unsafe fn write_a_expr(expr: &protobuf::AExpr) -> *mut bindings_raw::A_Expr { + let node = alloc_node::(bindings_raw::NodeTag_T_A_Expr); + (*node).kind = proto_enum_to_c(expr.kind); + (*node).name = write_node_list(&expr.name); + (*node).lexpr = write_node_boxed(&expr.lexpr); + (*node).rexpr = write_node_boxed(&expr.rexpr); + (*node).location = expr.location; + node +} + +unsafe fn write_func_call(fc: &protobuf::FuncCall) -> *mut bindings_raw::FuncCall { + let node = alloc_node::(bindings_raw::NodeTag_T_FuncCall); + (*node).funcname = write_node_list(&fc.funcname); + (*node).args = write_node_list(&fc.args); + (*node).agg_order = write_node_list(&fc.agg_order); + (*node).agg_filter = write_node_boxed(&fc.agg_filter); + (*node).over = write_window_def_opt(&fc.over); + (*node).agg_within_group = fc.agg_within_group; + (*node).agg_star = fc.agg_star; + (*node).agg_distinct = fc.agg_distinct; + (*node).func_variadic = fc.func_variadic; + (*node).funcformat = proto_enum_to_c(fc.funcformat); + (*node).location = fc.location; + node +} + +unsafe fn write_window_def_opt(wd: &Option>) -> *mut bindings_raw::WindowDef { + match wd { + Some(w) => write_window_def(w), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_window_def(wd: &protobuf::WindowDef) -> *mut bindings_raw::WindowDef { + let node = alloc_node::(bindings_raw::NodeTag_T_WindowDef); + (*node).name = pstrdup(&wd.name); + (*node).refname = pstrdup(&wd.refname); + (*node).partitionClause = write_node_list(&wd.partition_clause); + (*node).orderClause = write_node_list(&wd.order_clause); + (*node).frameOptions = wd.frame_options; + (*node).startOffset = write_node_boxed(&wd.start_offset); + (*node).endOffset = write_node_boxed(&wd.end_offset); + (*node).location = wd.location; + node +} + +unsafe fn write_string(s: &protobuf::String) -> *mut bindings_raw::String { + let node = alloc_node::(bindings_raw::NodeTag_T_String); + (*node).sval = pstrdup(&s.sval); + node +} + +unsafe fn write_integer(i: &protobuf::Integer) -> *mut bindings_raw::Integer { + let node = alloc_node::(bindings_raw::NodeTag_T_Integer); + (*node).ival = i.ival; + node +} + +unsafe fn write_float(f: &protobuf::Float) -> *mut bindings_raw::Float { + let node = alloc_node::(bindings_raw::NodeTag_T_Float); + (*node).fval = pstrdup(&f.fval); + node +} + +unsafe fn write_boolean(b: &protobuf::Boolean) -> *mut bindings_raw::Boolean { + let node = alloc_node::(bindings_raw::NodeTag_T_Boolean); + (*node).boolval = b.boolval; + node +} + +unsafe fn write_bit_string(bs: &protobuf::BitString) -> *mut bindings_raw::BitString { + let node = alloc_node::(bindings_raw::NodeTag_T_BitString); + (*node).bsval = pstrdup(&bs.bsval); + node +} + +unsafe fn write_null() -> *mut bindings_raw::Node { + // A_Const with isnull=true represents NULL + let node = alloc_node::(bindings_raw::NodeTag_T_A_Const); + (*node).isnull = true; + (*node).location = -1; + node as *mut bindings_raw::Node +} + +unsafe fn write_list(l: &protobuf::List) -> *mut bindings_raw::List { + write_node_list(&l.items) +} + +unsafe fn write_a_star() -> *mut bindings_raw::A_Star { + alloc_node::(bindings_raw::NodeTag_T_A_Star) +} + +unsafe fn write_join_expr(je: &protobuf::JoinExpr) -> *mut bindings_raw::JoinExpr { + let node = alloc_node::(bindings_raw::NodeTag_T_JoinExpr); + (*node).jointype = proto_enum_to_c(je.jointype); + (*node).isNatural = je.is_natural; + (*node).larg = write_node_boxed(&je.larg); + (*node).rarg = write_node_boxed(&je.rarg); + (*node).usingClause = write_node_list(&je.using_clause); + (*node).join_using_alias = write_alias_ref(&je.join_using_alias); + (*node).quals = write_node_boxed(&je.quals); + (*node).alias = write_alias_ref(&je.alias); + (*node).rtindex = je.rtindex; + node +} + +unsafe fn write_sort_by(sb: &protobuf::SortBy) -> *mut bindings_raw::SortBy { + let node = alloc_node::(bindings_raw::NodeTag_T_SortBy); + (*node).node = write_node_boxed(&sb.node); + (*node).sortby_dir = proto_enum_to_c(sb.sortby_dir); + (*node).sortby_nulls = proto_enum_to_c(sb.sortby_nulls); + (*node).useOp = write_node_list(&sb.use_op); + (*node).location = sb.location; + node +} + +unsafe fn write_type_cast(tc: &protobuf::TypeCast) -> *mut bindings_raw::TypeCast { + let node = alloc_node::(bindings_raw::NodeTag_T_TypeCast); + (*node).arg = write_node_boxed(&tc.arg); + (*node).typeName = write_type_name_ref(&tc.type_name); + (*node).location = tc.location; + node +} + +unsafe fn write_type_name_opt(tn: &Option>) -> *mut bindings_raw::TypeName { + match tn { + Some(t) => write_type_name(t), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_type_name_ref(tn: &Option) -> *mut bindings_raw::TypeName { + match tn { + Some(t) => write_type_name(t), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_type_name(tn: &protobuf::TypeName) -> *mut bindings_raw::TypeName { + let node = alloc_node::(bindings_raw::NodeTag_T_TypeName); + (*node).names = write_node_list(&tn.names); + (*node).typeOid = tn.type_oid; + (*node).setof = tn.setof; + (*node).pct_type = tn.pct_type; + (*node).typmods = write_node_list(&tn.typmods); + (*node).typemod = tn.typemod; + (*node).arrayBounds = write_node_list(&tn.array_bounds); + (*node).location = tn.location; + node +} + +unsafe fn write_param_ref(pr: &protobuf::ParamRef) -> *mut bindings_raw::ParamRef { + let node = alloc_node::(bindings_raw::NodeTag_T_ParamRef); + (*node).number = pr.number; + (*node).location = pr.location; + node +} + +unsafe fn write_null_test(nt: &protobuf::NullTest) -> *mut bindings_raw::NullTest { + let node = alloc_node::(bindings_raw::NodeTag_T_NullTest); + (*node).arg = write_node_boxed(&nt.arg) as *mut bindings_raw::Expr; + (*node).nulltesttype = proto_enum_to_c(nt.nulltesttype); + (*node).argisrow = nt.argisrow; + (*node).location = nt.location; + node +} + +unsafe fn write_bool_expr(be: &protobuf::BoolExpr) -> *mut bindings_raw::BoolExpr { + let node = alloc_node::(bindings_raw::NodeTag_T_BoolExpr); + (*node).boolop = proto_enum_to_c(be.boolop); + (*node).args = write_node_list(&be.args); + (*node).location = be.location; + node +} + +unsafe fn write_sub_link(sl: &protobuf::SubLink) -> *mut bindings_raw::SubLink { + let node = alloc_node::(bindings_raw::NodeTag_T_SubLink); + (*node).subLinkType = proto_enum_to_c(sl.sub_link_type); + (*node).subLinkId = sl.sub_link_id; + (*node).testexpr = write_node_boxed(&sl.testexpr); + (*node).operName = write_node_list(&sl.oper_name); + (*node).subselect = write_node_boxed(&sl.subselect); + (*node).location = sl.location; + node +} + +unsafe fn write_range_subselect(rs: &protobuf::RangeSubselect) -> *mut bindings_raw::RangeSubselect { + let node = alloc_node::(bindings_raw::NodeTag_T_RangeSubselect); + (*node).lateral = rs.lateral; + (*node).subquery = write_node_boxed(&rs.subquery); + (*node).alias = write_alias_ref(&rs.alias); + node +} + +unsafe fn write_common_table_expr(cte: &protobuf::CommonTableExpr) -> *mut bindings_raw::CommonTableExpr { + let node = alloc_node::(bindings_raw::NodeTag_T_CommonTableExpr); + (*node).ctename = pstrdup(&cte.ctename); + (*node).aliascolnames = write_node_list(&cte.aliascolnames); + (*node).ctematerialized = proto_enum_to_c(cte.ctematerialized); + (*node).ctequery = write_node_boxed(&cte.ctequery); + (*node).search_clause = write_cte_search_clause_opt(&cte.search_clause); + (*node).cycle_clause = write_cte_cycle_clause_opt(&cte.cycle_clause); + (*node).location = cte.location; + (*node).cterecursive = cte.cterecursive; + (*node).cterefcount = cte.cterefcount; + (*node).ctecolnames = write_node_list(&cte.ctecolnames); + // ctecoltypmods is a list of integers, handle separately if needed + node +} + +unsafe fn write_cte_search_clause_opt(sc: &Option) -> *mut bindings_raw::CTESearchClause { + match sc { + Some(s) => { + let node = alloc_node::(bindings_raw::NodeTag_T_CTESearchClause); + (*node).search_col_list = write_node_list(&s.search_col_list); + (*node).search_breadth_first = s.search_breadth_first; + (*node).search_seq_column = pstrdup(&s.search_seq_column); + (*node).location = s.location; + node + } + None => std::ptr::null_mut(), + } +} + +unsafe fn write_cte_cycle_clause_opt(cc: &Option>) -> *mut bindings_raw::CTECycleClause { + match cc { + Some(c) => { + let node = alloc_node::(bindings_raw::NodeTag_T_CTECycleClause); + (*node).cycle_col_list = write_node_list(&c.cycle_col_list); + (*node).cycle_mark_column = pstrdup(&c.cycle_mark_column); + (*node).cycle_mark_value = write_node_boxed(&c.cycle_mark_value); + (*node).cycle_mark_default = write_node_boxed(&c.cycle_mark_default); + (*node).cycle_path_column = pstrdup(&c.cycle_path_column); + (*node).location = c.location; + node + } + None => std::ptr::null_mut(), + } +} + +unsafe fn write_with_clause(wc: &protobuf::WithClause) -> *mut bindings_raw::WithClause { + let node = alloc_node::(bindings_raw::NodeTag_T_WithClause); + (*node).ctes = write_node_list(&wc.ctes); + (*node).recursive = wc.recursive; + (*node).location = wc.location; + node +} + +unsafe fn write_with_clause_opt(wc: &Option>) -> *mut bindings_raw::WithClause { + match wc { + Some(w) => write_with_clause(w), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_with_clause_ref(wc: &Option) -> *mut bindings_raw::WithClause { + match wc { + Some(w) => write_with_clause(w), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_grouping_set(gs: &protobuf::GroupingSet) -> *mut bindings_raw::GroupingSet { + let node = alloc_node::(bindings_raw::NodeTag_T_GroupingSet); + (*node).kind = proto_enum_to_c(gs.kind); + (*node).content = write_node_list(&gs.content); + (*node).location = gs.location; + node +} + +unsafe fn write_coalesce_expr(ce: &protobuf::CoalesceExpr) -> *mut bindings_raw::CoalesceExpr { + let node = alloc_node::(bindings_raw::NodeTag_T_CoalesceExpr); + (*node).coalescetype = ce.coalescetype; + (*node).coalescecollid = ce.coalescecollid; + (*node).args = write_node_list(&ce.args); + (*node).location = ce.location; + node +} + +unsafe fn write_case_expr(ce: &protobuf::CaseExpr) -> *mut bindings_raw::CaseExpr { + let node = alloc_node::(bindings_raw::NodeTag_T_CaseExpr); + (*node).casetype = ce.casetype; + (*node).casecollid = ce.casecollid; + (*node).arg = write_node_boxed(&ce.arg) as *mut bindings_raw::Expr; + (*node).args = write_node_list(&ce.args); + (*node).defresult = write_node_boxed(&ce.defresult) as *mut bindings_raw::Expr; + (*node).location = ce.location; + node +} + +unsafe fn write_case_when(cw: &protobuf::CaseWhen) -> *mut bindings_raw::CaseWhen { + let node = alloc_node::(bindings_raw::NodeTag_T_CaseWhen); + (*node).expr = write_node_boxed(&cw.expr) as *mut bindings_raw::Expr; + (*node).result = write_node_boxed(&cw.result) as *mut bindings_raw::Expr; + (*node).location = cw.location; + node +} + +unsafe fn write_set_to_default() -> *mut bindings_raw::SetToDefault { + let node = alloc_node::(bindings_raw::NodeTag_T_SetToDefault); + (*node).location = -1; + node +} + +unsafe fn write_locking_clause(lc: &protobuf::LockingClause) -> *mut bindings_raw::LockingClause { + let node = alloc_node::(bindings_raw::NodeTag_T_LockingClause); + (*node).lockedRels = write_node_list(&lc.locked_rels); + (*node).strength = proto_enum_to_c(lc.strength); + (*node).waitPolicy = proto_enum_to_c(lc.wait_policy); + node +} + +unsafe fn write_range_function(rf: &protobuf::RangeFunction) -> *mut bindings_raw::RangeFunction { + let node = alloc_node::(bindings_raw::NodeTag_T_RangeFunction); + (*node).lateral = rf.lateral; + (*node).ordinality = rf.ordinality; + (*node).is_rowsfrom = rf.is_rowsfrom; + (*node).functions = write_node_list(&rf.functions); + (*node).alias = write_alias_ref(&rf.alias); + (*node).coldeflist = write_node_list(&rf.coldeflist); + node +} + +unsafe fn write_index_elem(ie: &protobuf::IndexElem) -> *mut bindings_raw::IndexElem { + let node = alloc_node::(bindings_raw::NodeTag_T_IndexElem); + (*node).name = pstrdup(&ie.name); + (*node).expr = write_node_boxed(&ie.expr); + (*node).indexcolname = pstrdup(&ie.indexcolname); + (*node).collation = write_node_list(&ie.collation); + (*node).opclass = write_node_list(&ie.opclass); + (*node).opclassopts = write_node_list(&ie.opclassopts); + (*node).ordering = proto_enum_to_c(ie.ordering); + (*node).nulls_ordering = proto_enum_to_c(ie.nulls_ordering); + node +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_deparse_raw_empty() { + let result = protobuf::ParseResult { version: 170007, stmts: vec![] }; + let sql = deparse_raw(&result).unwrap(); + assert_eq!(sql, ""); + } +} + +/// Writes values lists (list of lists) for INSERT ... VALUES +unsafe fn write_values_lists(values: &[protobuf::Node]) -> *mut bindings_raw::List { + if values.is_empty() { + return std::ptr::null_mut(); + } + + let mut outer_list: *mut std::ffi::c_void = std::ptr::null_mut(); + + for value_node in values { + // Each value_node should be a List node containing the values for one row + if let Some(protobuf::node::Node::List(inner_list)) = &value_node.node { + let c_inner_list = write_node_list(&inner_list.items); + if outer_list.is_null() { + outer_list = bindings_raw::pg_query_list_make1(c_inner_list as *mut std::ffi::c_void); + } else { + outer_list = bindings_raw::pg_query_list_append(outer_list, c_inner_list as *mut std::ffi::c_void); + } + } + } + + outer_list as *mut bindings_raw::List +} diff --git a/tests/raw_parse/basic.rs b/tests/raw_parse/basic.rs index 94f42dd..76c52dc 100644 --- a/tests/raw_parse/basic.rs +++ b/tests/raw_parse/basic.rs @@ -27,6 +27,8 @@ fn it_parses_simple_select() { assert_eq!(raw_result.protobuf.stmts.len(), 1); assert_eq!(raw_result.protobuf, proto_result.protobuf); + // Verify deparse produces original query + assert_eq!(deparse_raw(&raw_result.protobuf).unwrap(), query); } /// Test that parse_raw handles syntax errors @@ -48,6 +50,7 @@ fn it_matches_parse_for_simple_select() { let proto_result = parse(query).unwrap(); assert_eq!(raw_result.protobuf, proto_result.protobuf); + assert_eq!(deparse_raw(&raw_result.protobuf).unwrap(), query); } /// Test that parse_raw and parse produce equivalent results for SELECT with table @@ -58,6 +61,7 @@ fn it_matches_parse_for_select_from_table() { let proto_result = parse(query).unwrap(); assert_eq!(raw_result.protobuf, proto_result.protobuf); + assert_eq!(deparse_raw(&raw_result.protobuf).unwrap(), query); let mut raw_tables = raw_result.tables(); let mut proto_tables = proto_result.tables(); @@ -76,6 +80,8 @@ fn it_handles_empty_queries() { assert_eq!(raw_result.protobuf.stmts.len(), 0); assert_eq!(raw_result.protobuf, proto_result.protobuf); + // Empty queries deparse to empty string (comments are stripped) + assert_eq!(deparse_raw(&raw_result.protobuf).unwrap(), ""); } /// Test that parse_raw parses multiple statements @@ -87,6 +93,7 @@ fn it_parses_multiple_statements() { assert_eq!(raw_result.protobuf.stmts.len(), 3); assert_eq!(raw_result.protobuf, proto_result.protobuf); + assert_eq!(deparse_raw(&raw_result.protobuf).unwrap(), query); } /// Test that tables() returns the same results for both parsers @@ -97,6 +104,7 @@ fn it_returns_tables_like_parse() { let proto_result = parse(query).unwrap(); assert_eq!(raw_result.protobuf, proto_result.protobuf); + assert_eq!(deparse_raw(&raw_result.protobuf).unwrap(), query); let mut raw_tables = raw_result.tables(); let mut proto_tables = proto_result.tables(); @@ -114,6 +122,7 @@ fn it_returns_functions_like_parse() { let proto_result = parse(query).unwrap(); assert_eq!(raw_result.protobuf, proto_result.protobuf); + assert_eq!(deparse_raw(&raw_result.protobuf).unwrap(), query); let mut raw_funcs = raw_result.functions(); let mut proto_funcs = proto_result.functions(); @@ -131,7 +140,75 @@ fn it_returns_statement_types_like_parse() { let proto_result = parse(query).unwrap(); assert_eq!(raw_result.protobuf, proto_result.protobuf); + assert_eq!(deparse_raw(&raw_result.protobuf).unwrap(), query); assert_eq!(raw_result.statement_types(), proto_result.statement_types()); assert_eq!(raw_result.statement_types(), vec!["SelectStmt", "InsertStmt", "UpdateStmt", "DeleteStmt"]); } + +// ============================================================================ +// deparse_raw tests +// ============================================================================ + +/// Test that deparse_raw successfully roundtrips a simple SELECT +#[test] +fn it_deparse_raw_simple_select() { + let query = "SELECT 1"; + let result = pg_query::parse(query).unwrap(); + let deparsed = pg_query::deparse_raw(&result.protobuf).unwrap(); + assert_eq!(deparsed, query); +} + +/// Test that deparse_raw successfully roundtrips SELECT FROM table +#[test] +fn it_deparse_raw_select_from_table() { + let query = "SELECT * FROM users"; + let result = pg_query::parse(query).unwrap(); + let deparsed = pg_query::deparse_raw(&result.protobuf).unwrap(); + assert_eq!(deparsed, query); +} + +/// Test that deparse_raw handles complex queries +#[test] +fn it_deparse_raw_complex_select() { + let query = "SELECT u.id, u.name FROM users u WHERE u.active = true ORDER BY u.name"; + let result = pg_query::parse(query).unwrap(); + let deparsed = pg_query::deparse_raw(&result.protobuf).unwrap(); + assert_eq!(deparsed, query); +} + +/// Test that deparse_raw handles INSERT statements +#[test] +fn it_deparse_raw_insert() { + let query = "INSERT INTO users (name, email) VALUES ('John', 'john@example.com')"; + let result = pg_query::parse(query).unwrap(); + let deparsed = pg_query::deparse_raw(&result.protobuf).unwrap(); + assert_eq!(deparsed, query); +} + +/// Test that deparse_raw handles UPDATE statements +#[test] +fn it_deparse_raw_update() { + let query = "UPDATE users SET name = 'Jane' WHERE id = 1"; + let result = pg_query::parse(query).unwrap(); + let deparsed = pg_query::deparse_raw(&result.protobuf).unwrap(); + assert_eq!(deparsed, query); +} + +/// Test that deparse_raw handles DELETE statements +#[test] +fn it_deparse_raw_delete() { + let query = "DELETE FROM users WHERE id = 1"; + let result = pg_query::parse(query).unwrap(); + let deparsed = pg_query::deparse_raw(&result.protobuf).unwrap(); + assert_eq!(deparsed, query); +} + +/// Test that deparse_raw handles multiple statements +#[test] +fn it_deparse_raw_multiple_statements() { + let query = "SELECT 1; SELECT 2; SELECT 3"; + let result = pg_query::parse(query).unwrap(); + let deparsed = pg_query::deparse_raw(&result.protobuf).unwrap(); + assert_eq!(deparsed, query); +} diff --git a/tests/raw_parse/dml.rs b/tests/raw_parse/dml.rs index d30eadc..299309a 100644 --- a/tests/raw_parse/dml.rs +++ b/tests/raw_parse/dml.rs @@ -16,6 +16,7 @@ fn it_parses_insert() { let proto_result = parse(query).unwrap(); assert_eq!(raw_result.protobuf, proto_result.protobuf); + assert_eq!(deparse_raw(&raw_result.protobuf).unwrap(), query); let mut raw_tables = raw_result.dml_tables(); let mut proto_tables = proto_result.dml_tables(); @@ -33,6 +34,7 @@ fn it_parses_update() { let proto_result = parse(query).unwrap(); assert_eq!(raw_result.protobuf, proto_result.protobuf); + assert_eq!(deparse_raw(&raw_result.protobuf).unwrap(), query); let mut raw_tables = raw_result.dml_tables(); let mut proto_tables = proto_result.dml_tables(); @@ -50,6 +52,7 @@ fn it_parses_delete() { let proto_result = parse(query).unwrap(); assert_eq!(raw_result.protobuf, proto_result.protobuf); + assert_eq!(deparse_raw(&raw_result.protobuf).unwrap(), query); let mut raw_tables = raw_result.dml_tables(); let mut proto_tables = proto_result.dml_tables(); @@ -71,6 +74,8 @@ fn it_parses_insert_on_conflict() { let proto_result = parse(query).unwrap(); assert_eq!(raw_result.protobuf, proto_result.protobuf); + // PostgreSQL's deparser normalizes EXCLUDED to lowercase, so compare case-insensitively + assert_eq!(deparse_raw(&raw_result.protobuf).unwrap().to_lowercase(), query.to_lowercase()); let raw_tables = raw_result.dml_tables(); let proto_tables = proto_result.dml_tables(); @@ -86,6 +91,7 @@ fn it_parses_insert_returning() { let proto_result = parse(query).unwrap(); assert_eq!(raw_result.protobuf, proto_result.protobuf); + assert_eq!(deparse_raw(&raw_result.protobuf).unwrap(), query); } /// Test INSERT with multiple tuples @@ -96,6 +102,7 @@ fn it_parses_insert_multiple_rows() { let proto_result = parse(query).unwrap(); assert_eq!(raw_result.protobuf, proto_result.protobuf); + assert_eq!(deparse_raw(&raw_result.protobuf).unwrap(), query); } /// Test INSERT ... SELECT diff --git a/tests/raw_parse/mod.rs b/tests/raw_parse/mod.rs index 45fd98c..0e0a50d 100644 --- a/tests/raw_parse/mod.rs +++ b/tests/raw_parse/mod.rs @@ -4,7 +4,7 @@ //! results to parse (protobuf-based parsing). pub use pg_query::protobuf::{a_const, node, ParseResult as ProtobufParseResult}; -pub use pg_query::{parse, parse_raw, Error}; +pub use pg_query::{deparse_raw, parse, parse_raw, Error}; /// Helper to extract AConst from a SELECT statement's first target pub fn get_first_const(result: &ProtobufParseResult) -> Option<&pg_query::protobuf::AConst> { @@ -25,9 +25,23 @@ pub fn get_first_const(result: &ProtobufParseResult) -> Option<&pg_query::protob None } -/// Helper macro for simple parse comparison tests +/// Helper macro for simple parse comparison tests with deparse verification #[macro_export] macro_rules! parse_test { + ($query:expr) => {{ + let raw_result = parse_raw($query).unwrap(); + let proto_result = parse($query).unwrap(); + assert_eq!(raw_result.protobuf, proto_result.protobuf); + // Verify that deparse_raw produces the original query + let deparsed = deparse_raw(&raw_result.protobuf).unwrap(); + assert_eq!(deparsed, $query); + }}; +} + +/// Helper macro for parse tests where the deparsed output may differ from input +/// (e.g., when PostgreSQL normalizes the SQL syntax) +#[macro_export] +macro_rules! parse_test_no_deparse_check { ($query:expr) => {{ let raw_result = parse_raw($query).unwrap(); let proto_result = parse($query).unwrap(); diff --git a/tests/raw_parse/select.rs b/tests/raw_parse/select.rs index 75ae1e8..433957a 100644 --- a/tests/raw_parse/select.rs +++ b/tests/raw_parse/select.rs @@ -17,6 +17,8 @@ fn it_parses_join() { // Full structural equality check assert_eq!(raw_result.protobuf, proto_result.protobuf); + // Verify deparse produces original query + assert_eq!(deparse_raw(&raw_result.protobuf).unwrap(), query); // Verify tables are extracted correctly let mut raw_tables = raw_result.tables(); @@ -36,6 +38,7 @@ fn it_parses_union() { // Full structural equality check assert_eq!(raw_result.protobuf, proto_result.protobuf); + assert_eq!(deparse_raw(&raw_result.protobuf).unwrap(), query); // Verify tables from both sides of UNION let mut raw_tables = raw_result.tables(); @@ -55,6 +58,7 @@ fn it_parses_cte() { // Full structural equality check assert_eq!(raw_result.protobuf, proto_result.protobuf); + assert_eq!(deparse_raw(&raw_result.protobuf).unwrap(), query); // Verify CTE names match assert_eq!(raw_result.cte_names, proto_result.cte_names); @@ -78,6 +82,7 @@ fn it_parses_subquery() { // Full structural equality check assert_eq!(raw_result.protobuf, proto_result.protobuf); + assert_eq!(deparse_raw(&raw_result.protobuf).unwrap(), query); // Verify all tables are found let mut raw_tables = raw_result.tables(); @@ -97,6 +102,7 @@ fn it_parses_aggregates() { // Full structural equality check assert_eq!(raw_result.protobuf, proto_result.protobuf); + assert_eq!(deparse_raw(&raw_result.protobuf).unwrap(), query); // Verify functions are extracted correctly let mut raw_funcs = raw_result.functions(); @@ -118,6 +124,7 @@ fn it_parses_case_expression() { // Full structural equality check assert_eq!(raw_result.protobuf, proto_result.protobuf); + assert_eq!(deparse_raw(&raw_result.protobuf).unwrap(), query); // Verify table is found let raw_tables = raw_result.tables(); diff --git a/tests/raw_parse_tests.rs b/tests/raw_parse_tests.rs index f19220b..d8e50ea 100644 --- a/tests/raw_parse_tests.rs +++ b/tests/raw_parse_tests.rs @@ -12,7 +12,7 @@ mod support; mod raw_parse; // Re-export the benchmark test at the top level -use pg_query::{parse, parse_raw}; +use pg_query::{deparse, deparse_raw, parse, parse_raw}; use std::time::{Duration, Instant}; /// Benchmark comparing parse_raw vs parse performance @@ -153,3 +153,145 @@ fn benchmark_parse_raw_vs_parse() { println!("└─────────────────────────────────────────────────────────┘"); println!(); } + +/// Benchmark comparing deparse_raw vs deparse performance +#[test] +fn benchmark_deparse_raw_vs_deparse() { + // Complex query with multiple features: CTEs, JOINs, subqueries, window functions, etc. + let query = r#" + WITH RECURSIVE + category_tree AS ( + SELECT id, name, parent_id, 0 AS depth + FROM categories + WHERE parent_id IS NULL + UNION ALL + SELECT c.id, c.name, c.parent_id, ct.depth + 1 + FROM categories c + INNER JOIN category_tree ct ON c.parent_id = ct.id + WHERE ct.depth < 10 + ), + recent_orders AS ( + SELECT + o.id, + o.user_id, + o.total_amount, + o.created_at, + ROW_NUMBER() OVER (PARTITION BY o.user_id ORDER BY o.created_at DESC) as rn + FROM orders o + WHERE o.created_at > NOW() - INTERVAL '30 days' + AND o.status IN ('completed', 'shipped', 'delivered') + ) + SELECT + u.id AS user_id, + u.email, + u.first_name || ' ' || u.last_name AS full_name, + COALESCE(ua.city, 'Unknown') AS city, + COUNT(DISTINCT ro.id) AS order_count, + SUM(ro.total_amount) AS total_spent, + AVG(ro.total_amount) AS avg_order_value, + MAX(ro.created_at) AS last_order_date, + CASE + WHEN SUM(ro.total_amount) > 10000 THEN 'platinum' + WHEN SUM(ro.total_amount) > 5000 THEN 'gold' + WHEN SUM(ro.total_amount) > 1000 THEN 'silver' + ELSE 'bronze' + END AS customer_tier, + ( + SELECT COUNT(*) + FROM user_reviews ur + WHERE ur.user_id = u.id AND ur.rating >= 4 + ) AS positive_reviews, + ARRAY_AGG(DISTINCT ct.name ORDER BY ct.name) FILTER (WHERE ct.depth = 1) AS top_categories + FROM users u + LEFT JOIN user_addresses ua ON ua.user_id = u.id AND ua.is_primary = true + LEFT JOIN recent_orders ro ON ro.user_id = u.id AND ro.rn <= 5 + LEFT JOIN order_items oi ON oi.order_id = ro.id + LEFT JOIN products p ON p.id = oi.product_id + LEFT JOIN category_tree ct ON ct.id = p.category_id + WHERE u.is_active = true + AND u.created_at < NOW() - INTERVAL '7 days' + AND EXISTS ( + SELECT 1 FROM user_logins ul + WHERE ul.user_id = u.id + AND ul.logged_in_at > NOW() - INTERVAL '90 days' + ) + GROUP BY u.id, u.email, u.first_name, u.last_name, ua.city + HAVING COUNT(DISTINCT ro.id) > 0 + ORDER BY total_spent DESC NULLS LAST, u.created_at ASC + LIMIT 100 + OFFSET 0 + FOR UPDATE OF u SKIP LOCKED"#; + + // Parse the query once to get the protobuf result + let parsed = parse(query).unwrap(); + + // Warm up + for _ in 0..10 { + let _ = deparse_raw(&parsed.protobuf).unwrap(); + let _ = deparse(&parsed.protobuf).unwrap(); + } + + // Run for a fixed duration to get stable measurements + let target_duration = Duration::from_secs(2); + + // Benchmark deparse_raw + let mut raw_iterations = 0u64; + let raw_start = Instant::now(); + while raw_start.elapsed() < target_duration { + for _ in 0..100 { + let _ = deparse_raw(&parsed.protobuf).unwrap(); + raw_iterations += 1; + } + } + let raw_elapsed = raw_start.elapsed(); + let raw_ns_per_iter = raw_elapsed.as_nanos() as f64 / raw_iterations as f64; + + // Benchmark deparse (protobuf) + let mut proto_iterations = 0u64; + let proto_start = Instant::now(); + while proto_start.elapsed() < target_duration { + for _ in 0..100 { + let _ = deparse(&parsed.protobuf).unwrap(); + proto_iterations += 1; + } + } + let proto_elapsed = proto_start.elapsed(); + let proto_ns_per_iter = proto_elapsed.as_nanos() as f64 / proto_iterations as f64; + + // Calculate speedup and time saved + let speedup = proto_ns_per_iter / raw_ns_per_iter; + let time_saved_ns = proto_ns_per_iter - raw_ns_per_iter; + let time_saved_us = time_saved_ns / 1000.0; + + // Calculate throughput (queries per second) + let raw_qps = 1_000_000_000.0 / raw_ns_per_iter; + let proto_qps = 1_000_000_000.0 / proto_ns_per_iter; + + println!("\n"); + println!("============================================================"); + println!(" deparse_raw vs deparse Benchmark "); + println!("============================================================"); + println!("Query: {} chars (CTEs + JOINs + subqueries + window functions)", query.len()); + println!(); + println!("┌─────────────────────────────────────────────────────────┐"); + println!("│ RESULTS │"); + println!("├─────────────────────────────────────────────────────────┤"); + println!("│ deparse_raw (direct C struct building): │"); + println!("│ Iterations: {:>10} │", raw_iterations); + println!("│ Total time: {:>10.2?} │", raw_elapsed); + println!("│ Per iteration: {:>10.2} μs │", raw_ns_per_iter / 1000.0); + println!("│ Throughput: {:>10.0} queries/sec │", raw_qps); + println!("├─────────────────────────────────────────────────────────┤"); + println!("│ deparse (protobuf serialization): │"); + println!("│ Iterations: {:>10} │", proto_iterations); + println!("│ Total time: {:>10.2?} │", proto_elapsed); + println!("│ Per iteration: {:>10.2} μs │", proto_ns_per_iter / 1000.0); + println!("│ Throughput: {:>10.0} queries/sec │", proto_qps); + println!("├─────────────────────────────────────────────────────────┤"); + println!("│ COMPARISON │"); + println!("│ Speedup: {:>10.2}x faster │", speedup); + println!("│ Time saved: {:>10.2} μs per deparse │", time_saved_us); + println!("│ Extra queries: {:>10.0} more queries/sec │", raw_qps - proto_qps); + println!("└─────────────────────────────────────────────────────────┘"); + println!(); +} From 9c22209543c4386747324769fa00c1d5234a0b1a Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 31 Dec 2025 14:18:01 -0800 Subject: [PATCH 10/17] Trigger CI From 76007fdbf26ee9e8349309c701f87e8a420e4595 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 31 Dec 2025 14:20:06 -0800 Subject: [PATCH 11/17] Trigger CI From a58fe0bbb5b28dcd61b96d5fb64fc4b0c51d67ff Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 31 Dec 2025 14:27:00 -0800 Subject: [PATCH 12/17] Fix stack overflow in recursion test Use a dedicated thread with 32MB stack to handle deeply nested queries without crashing the test harness. --- tests/parse_tests.rs | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/parse_tests.rs b/tests/parse_tests.rs index c029690..1a314b2 100644 --- a/tests/parse_tests.rs +++ b/tests/parse_tests.rs @@ -40,10 +40,18 @@ fn it_serializes_as_json() { #[test] fn it_handles_recursion_error() { - let query = "SELECT a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(b))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))"; - parse(query).err().unwrap(); - // TODO: unsure how to unwrap the private fields on a protobuf decode error - // assert_eq!(error, Error::Decode("recursion limit reached".into())); + // Run in a thread with a larger stack to avoid stack overflow in the test harness. + // This test verifies deeply nested queries are handled (either parsed or error returned). + let handle = std::thread::Builder::new() + .stack_size(32 * 1024 * 1024) // 32 MB stack + .spawn(|| { + let query = "SELECT a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(a(b))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))"; + // With sufficient stack, this deeply nested query should either parse successfully + // or return a recursion error - both are acceptable outcomes. + let _ = parse(query); + }) + .unwrap(); + handle.join().unwrap(); } #[test] From e055fcc0d9d7bae36cf270305697c28bcb72898c Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 31 Dec 2025 16:37:46 -0800 Subject: [PATCH 13/17] more nodes! --- build.rs | 3 + src/node_ref.rs | 7 + src/parse_result.rs | 9 + src/raw_deparse.rs | 624 ++++++++++++++++++++++++++++++++++++++- src/raw_parse.rs | 161 +++++++++- tests/parse_tests.rs | 429 +++++++++++++++++++-------- tests/raw_parse/basic.rs | 63 ++++ tests/support.rs | 34 +++ 8 files changed, 1210 insertions(+), 120 deletions(-) diff --git a/build.rs b/build.rs index 0b2e498..0d7c895 100644 --- a/build.rs +++ b/build.rs @@ -171,6 +171,9 @@ fn main() -> Result<(), Box> { .allowlist_type("CreateSeqStmt") .allowlist_type("CreateTrigStmt") .allowlist_type("RuleStmt") + .allowlist_type("CallStmt") + .allowlist_type("GrantRoleStmt") + .allowlist_type("MergeAction") .allowlist_type("CreateDomainStmt") .allowlist_type("CreateTableAsStmt") .allowlist_type("RefreshMatViewStmt") diff --git a/src/node_ref.rs b/src/node_ref.rs index 6209544..b16c2e4 100644 --- a/src/node_ref.rs +++ b/src/node_ref.rs @@ -281,6 +281,13 @@ impl<'a> NodeRef<'a> { }) } + pub fn deparse_raw(&self) -> Result { + crate::deparse_raw(&protobuf::ParseResult { + version: crate::bindings::PG_VERSION_NUM as i32, + stmts: vec![protobuf::RawStmt { stmt: Some(Box::new(Node { node: Some(self.to_enum()) })), stmt_location: 0, stmt_len: 0 }], + }) + } + pub fn to_enum(&self) -> NodeEnum { match self { NodeRef::Alias(n) => NodeEnum::Alias((*n).clone()), diff --git a/src/parse_result.rs b/src/parse_result.rs index 39e049b..9b00717 100644 --- a/src/parse_result.rs +++ b/src/parse_result.rs @@ -23,6 +23,10 @@ impl protobuf::ParseResult { crate::deparse(self) } + pub fn deparse_raw(&self) -> Result { + crate::deparse_raw(self) + } + // Note: this doesn't iterate over every possible node type, since we only care about a subset of nodes. pub fn nodes(&self) -> Vec<(NodeRef<'_>, i32, Context, bool)> { self.stmts @@ -253,6 +257,11 @@ impl ParseResult { crate::deparse(&self.protobuf) } + /// Converts the parsed query back into a SQL string (bypasses protobuf serialization) + pub fn deparse_raw(&self) -> Result { + crate::deparse_raw(&self.protobuf) + } + /// Intelligently truncates queries to a max length. /// /// # Example diff --git a/src/raw_deparse.rs b/src/raw_deparse.rs index 09053cf..b9182a6 100644 --- a/src/raw_deparse.rs +++ b/src/raw_deparse.rs @@ -163,6 +163,52 @@ fn write_node_inner(node: &protobuf::node::Node) -> *mut bindings_raw::Node { protobuf::node::Node::RangeFunction(rf) => write_range_function(rf) as *mut bindings_raw::Node, protobuf::node::Node::BitString(bs) => write_bit_string(bs) as *mut bindings_raw::Node, protobuf::node::Node::IndexElem(ie) => write_index_elem(ie) as *mut bindings_raw::Node, + protobuf::node::Node::DropStmt(ds) => write_drop_stmt(ds) as *mut bindings_raw::Node, + protobuf::node::Node::ObjectWithArgs(owa) => write_object_with_args(owa) as *mut bindings_raw::Node, + protobuf::node::Node::FunctionParameter(fp) => write_function_parameter(fp) as *mut bindings_raw::Node, + protobuf::node::Node::TruncateStmt(ts) => write_truncate_stmt(ts) as *mut bindings_raw::Node, + protobuf::node::Node::CreateStmt(cs) => write_create_stmt(cs) as *mut bindings_raw::Node, + protobuf::node::Node::AlterTableStmt(ats) => write_alter_table_stmt(ats) as *mut bindings_raw::Node, + protobuf::node::Node::AlterTableCmd(atc) => write_alter_table_cmd(atc) as *mut bindings_raw::Node, + protobuf::node::Node::ColumnDef(cd) => write_column_def(cd) as *mut bindings_raw::Node, + protobuf::node::Node::Constraint(c) => write_constraint(c) as *mut bindings_raw::Node, + protobuf::node::Node::IndexStmt(is) => write_index_stmt(is) as *mut bindings_raw::Node, + protobuf::node::Node::ViewStmt(vs) => write_view_stmt(vs) as *mut bindings_raw::Node, + protobuf::node::Node::TransactionStmt(ts) => write_transaction_stmt(ts) as *mut bindings_raw::Node, + protobuf::node::Node::CopyStmt(cs) => write_copy_stmt(cs) as *mut bindings_raw::Node, + protobuf::node::Node::ExplainStmt(es) => write_explain_stmt(es) as *mut bindings_raw::Node, + protobuf::node::Node::VacuumStmt(vs) => write_vacuum_stmt(vs) as *mut bindings_raw::Node, + protobuf::node::Node::LockStmt(ls) => write_lock_stmt(ls) as *mut bindings_raw::Node, + protobuf::node::Node::CreateSchemaStmt(css) => write_create_schema_stmt(css) as *mut bindings_raw::Node, + protobuf::node::Node::VariableSetStmt(vss) => write_variable_set_stmt(vss) as *mut bindings_raw::Node, + protobuf::node::Node::VariableShowStmt(vss) => write_variable_show_stmt(vss) as *mut bindings_raw::Node, + protobuf::node::Node::RenameStmt(rs) => write_rename_stmt(rs) as *mut bindings_raw::Node, + protobuf::node::Node::GrantStmt(gs) => write_grant_stmt(gs) as *mut bindings_raw::Node, + protobuf::node::Node::RoleSpec(rs) => write_role_spec(rs) as *mut bindings_raw::Node, + protobuf::node::Node::AccessPriv(ap) => write_access_priv(ap) as *mut bindings_raw::Node, + protobuf::node::Node::CreateFunctionStmt(cfs) => write_create_function_stmt(cfs) as *mut bindings_raw::Node, + protobuf::node::Node::DefElem(de) => write_def_elem(de) as *mut bindings_raw::Node, + protobuf::node::Node::RuleStmt(rs) => write_rule_stmt(rs) as *mut bindings_raw::Node, + protobuf::node::Node::CreateTrigStmt(cts) => write_create_trig_stmt(cts) as *mut bindings_raw::Node, + protobuf::node::Node::DoStmt(ds) => write_do_stmt(ds) as *mut bindings_raw::Node, + protobuf::node::Node::CallStmt(cs) => write_call_stmt(cs) as *mut bindings_raw::Node, + protobuf::node::Node::MergeStmt(ms) => write_merge_stmt(ms) as *mut bindings_raw::Node, + protobuf::node::Node::MergeWhenClause(mwc) => write_merge_when_clause(mwc) as *mut bindings_raw::Node, + protobuf::node::Node::GrantRoleStmt(grs) => write_grant_role_stmt(grs) as *mut bindings_raw::Node, + protobuf::node::Node::PrepareStmt(ps) => write_prepare_stmt(ps) as *mut bindings_raw::Node, + protobuf::node::Node::ExecuteStmt(es) => write_execute_stmt(es) as *mut bindings_raw::Node, + protobuf::node::Node::DeallocateStmt(ds) => write_deallocate_stmt(ds) as *mut bindings_raw::Node, + protobuf::node::Node::AIndirection(ai) => write_a_indirection(ai) as *mut bindings_raw::Node, + protobuf::node::Node::AIndices(ai) => write_a_indices(ai) as *mut bindings_raw::Node, + protobuf::node::Node::MinMaxExpr(mme) => write_min_max_expr(mme) as *mut bindings_raw::Node, + protobuf::node::Node::RowExpr(re) => write_row_expr(re) as *mut bindings_raw::Node, + protobuf::node::Node::AArrayExpr(ae) => write_a_array_expr(ae) as *mut bindings_raw::Node, + protobuf::node::Node::BooleanTest(bt) => write_boolean_test(bt) as *mut bindings_raw::Node, + protobuf::node::Node::CollateClause(cc) => write_collate_clause(cc) as *mut bindings_raw::Node, + protobuf::node::Node::CheckPointStmt(_) => alloc_node::(bindings_raw::NodeTag_T_CheckPointStmt), + protobuf::node::Node::CreateTableAsStmt(ctas) => write_create_table_as_stmt(ctas) as *mut bindings_raw::Node, + protobuf::node::Node::RefreshMatViewStmt(rmvs) => write_refresh_mat_view_stmt(rmvs) as *mut bindings_raw::Node, + protobuf::node::Node::VacuumRelation(vr) => write_vacuum_relation(vr) as *mut bindings_raw::Node, // TODO: Add remaining node types as needed _ => { // For unimplemented nodes, return null and let the deparser handle it @@ -726,12 +772,57 @@ unsafe fn write_range_function(rf: &protobuf::RangeFunction) -> *mut bindings_ra (*node).lateral = rf.lateral; (*node).ordinality = rf.ordinality; (*node).is_rowsfrom = rf.is_rowsfrom; - (*node).functions = write_node_list(&rf.functions); + // PostgreSQL expects functions to be a list of 2-element lists: [FuncExpr, coldeflist] + // The protobuf stores each function as a List node containing just the FuncCall + // We need to ensure each inner list has exactly 2 elements + (*node).functions = write_range_function_list(&rf.functions); (*node).alias = write_alias_ref(&rf.alias); (*node).coldeflist = write_node_list(&rf.coldeflist); node } +/// Writes the functions list for a RangeFunction. +/// PostgreSQL expects a list of 2-element lists: [FuncExpr, coldeflist]. +/// The protobuf may store these as List nodes with only the function expression. +fn write_range_function_list(nodes: &[protobuf::Node]) -> *mut bindings_raw::List { + if nodes.is_empty() { + return std::ptr::null_mut(); + } + + let mut list: *mut std::ffi::c_void = std::ptr::null_mut(); + + for node in nodes { + // Each node should be a List containing the function expression (and optionally coldeflist) + // We need to ensure it has exactly 2 elements + let inner_list = if let Some(protobuf::node::Node::List(l)) = &node.node { + // It's a List node - ensure it has 2 elements + let func_expr = if !l.items.is_empty() { write_node(&l.items[0]) } else { std::ptr::null_mut() }; + let coldeflist = if l.items.len() > 1 { write_node(&l.items[1]) } else { std::ptr::null_mut() }; + // Create a 2-element list + unsafe { + let inner = bindings_raw::pg_query_list_make1(func_expr as *mut std::ffi::c_void); + bindings_raw::pg_query_list_append(inner, coldeflist as *mut std::ffi::c_void) + } + } else { + // It's not a List node (shouldn't happen, but handle it) + // Wrap the node in a 2-element list + let func_expr = write_node(node); + unsafe { + let inner = bindings_raw::pg_query_list_make1(func_expr as *mut std::ffi::c_void); + bindings_raw::pg_query_list_append(inner, std::ptr::null_mut()) + } + }; + + if list.is_null() { + list = unsafe { bindings_raw::pg_query_list_make1(inner_list) }; + } else { + list = unsafe { bindings_raw::pg_query_list_append(list, inner_list) }; + } + } + + list as *mut bindings_raw::List +} + unsafe fn write_index_elem(ie: &protobuf::IndexElem) -> *mut bindings_raw::IndexElem { let node = alloc_node::(bindings_raw::NodeTag_T_IndexElem); (*node).name = pstrdup(&ie.name); @@ -779,3 +870,534 @@ unsafe fn write_values_lists(values: &[protobuf::Node]) -> *mut bindings_raw::Li outer_list as *mut bindings_raw::List } + +// ============================================================================= +// Additional Statement Writers +// ============================================================================= + +unsafe fn write_drop_stmt(stmt: &protobuf::DropStmt) -> *mut bindings_raw::DropStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_DropStmt); + (*node).objects = write_node_list(&stmt.objects); + (*node).removeType = proto_enum_to_c(stmt.remove_type); + (*node).behavior = proto_enum_to_c(stmt.behavior); + (*node).missing_ok = stmt.missing_ok; + (*node).concurrent = stmt.concurrent; + node +} + +unsafe fn write_object_with_args(owa: &protobuf::ObjectWithArgs) -> *mut bindings_raw::ObjectWithArgs { + let node = alloc_node::(bindings_raw::NodeTag_T_ObjectWithArgs); + (*node).objname = write_node_list(&owa.objname); + (*node).objargs = write_node_list(&owa.objargs); + (*node).objfuncargs = write_node_list(&owa.objfuncargs); + (*node).args_unspecified = owa.args_unspecified; + node +} + +unsafe fn write_function_parameter(fp: &protobuf::FunctionParameter) -> *mut bindings_raw::FunctionParameter { + let node = alloc_node::(bindings_raw::NodeTag_T_FunctionParameter); + (*node).name = pstrdup(&fp.name); + (*node).argType = write_type_name_ptr(&fp.arg_type); + (*node).mode = proto_function_param_mode(fp.mode); + (*node).defexpr = write_node_boxed(&fp.defexpr); + node +} + +fn proto_function_param_mode(mode: i32) -> bindings_raw::FunctionParameterMode { + match mode { + 1 => bindings_raw::FunctionParameterMode_FUNC_PARAM_IN, + 2 => bindings_raw::FunctionParameterMode_FUNC_PARAM_OUT, + 3 => bindings_raw::FunctionParameterMode_FUNC_PARAM_INOUT, + 4 => bindings_raw::FunctionParameterMode_FUNC_PARAM_VARIADIC, + 5 => bindings_raw::FunctionParameterMode_FUNC_PARAM_TABLE, + 6 => bindings_raw::FunctionParameterMode_FUNC_PARAM_DEFAULT, + _ => bindings_raw::FunctionParameterMode_FUNC_PARAM_IN, + } +} + +unsafe fn write_type_name_ptr(tn: &Option) -> *mut bindings_raw::TypeName { + match tn { + Some(tn) => write_type_name(tn), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_truncate_stmt(stmt: &protobuf::TruncateStmt) -> *mut bindings_raw::TruncateStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_TruncateStmt); + (*node).relations = write_node_list(&stmt.relations); + (*node).restart_seqs = stmt.restart_seqs; + (*node).behavior = proto_enum_to_c(stmt.behavior); + node +} + +unsafe fn write_create_stmt(stmt: &protobuf::CreateStmt) -> *mut bindings_raw::CreateStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateStmt); + (*node).relation = write_range_var_ptr(&stmt.relation); + (*node).tableElts = write_node_list(&stmt.table_elts); + (*node).inhRelations = write_node_list(&stmt.inh_relations); + (*node).partbound = std::ptr::null_mut(); // Complex type, skip for now + (*node).partspec = std::ptr::null_mut(); // Complex type, skip for now + (*node).ofTypename = write_type_name_ptr(&stmt.of_typename); + (*node).constraints = write_node_list(&stmt.constraints); + (*node).options = write_node_list(&stmt.options); + (*node).oncommit = proto_enum_to_c(stmt.oncommit); + (*node).tablespacename = pstrdup(&stmt.tablespacename); + (*node).accessMethod = pstrdup(&stmt.access_method); + (*node).if_not_exists = stmt.if_not_exists; + node +} + +unsafe fn write_range_var_ptr(rv: &Option) -> *mut bindings_raw::RangeVar { + match rv { + Some(rv) => write_range_var(rv), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_alter_table_stmt(stmt: &protobuf::AlterTableStmt) -> *mut bindings_raw::AlterTableStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterTableStmt); + (*node).relation = write_range_var_ptr(&stmt.relation); + (*node).cmds = write_node_list(&stmt.cmds); + (*node).objtype = proto_enum_to_c(stmt.objtype); + (*node).missing_ok = stmt.missing_ok; + node +} + +unsafe fn write_alter_table_cmd(cmd: &protobuf::AlterTableCmd) -> *mut bindings_raw::AlterTableCmd { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterTableCmd); + (*node).subtype = proto_enum_to_c(cmd.subtype); + (*node).name = pstrdup(&cmd.name); + (*node).num = cmd.num as i16; + (*node).newowner = std::ptr::null_mut(); // RoleSpec, complex + (*node).def = write_node_boxed(&cmd.def); + (*node).behavior = proto_enum_to_c(cmd.behavior); + (*node).missing_ok = cmd.missing_ok; + (*node).recurse = cmd.recurse; + node +} + +unsafe fn write_column_def(cd: &protobuf::ColumnDef) -> *mut bindings_raw::ColumnDef { + let node = alloc_node::(bindings_raw::NodeTag_T_ColumnDef); + (*node).colname = pstrdup(&cd.colname); + (*node).typeName = write_type_name_ptr(&cd.type_name); + (*node).compression = pstrdup(&cd.compression); + (*node).inhcount = cd.inhcount; + (*node).is_local = cd.is_local; + (*node).is_not_null = cd.is_not_null; + (*node).is_from_type = cd.is_from_type; + (*node).storage = if cd.storage.is_empty() { 0 } else { cd.storage.as_bytes()[0] as i8 }; + (*node).raw_default = write_node_boxed(&cd.raw_default); + (*node).cooked_default = write_node_boxed(&cd.cooked_default); + (*node).identity = if cd.identity.is_empty() { 0 } else { cd.identity.as_bytes()[0] as i8 }; + (*node).identitySequence = std::ptr::null_mut(); + (*node).generated = if cd.generated.is_empty() { 0 } else { cd.generated.as_bytes()[0] as i8 }; + (*node).collClause = std::ptr::null_mut(); + (*node).collOid = cd.coll_oid; + (*node).constraints = write_node_list(&cd.constraints); + (*node).fdwoptions = write_node_list(&cd.fdwoptions); + (*node).location = cd.location; + node +} + +unsafe fn write_constraint(c: &protobuf::Constraint) -> *mut bindings_raw::Constraint { + let node = alloc_node::(bindings_raw::NodeTag_T_Constraint); + (*node).contype = proto_enum_to_c(c.contype); + (*node).conname = pstrdup(&c.conname); + (*node).deferrable = c.deferrable; + (*node).initdeferred = c.initdeferred; + (*node).skip_validation = c.skip_validation; + (*node).initially_valid = c.initially_valid; + (*node).is_no_inherit = c.is_no_inherit; + (*node).raw_expr = write_node_boxed(&c.raw_expr); + (*node).cooked_expr = pstrdup(&c.cooked_expr); + (*node).generated_when = if c.generated_when.is_empty() { 0 } else { c.generated_when.as_bytes()[0] as i8 }; + (*node).nulls_not_distinct = c.nulls_not_distinct; + (*node).keys = write_node_list(&c.keys); + (*node).including = write_node_list(&c.including); + (*node).exclusions = write_node_list(&c.exclusions); + (*node).options = write_node_list(&c.options); + (*node).indexname = pstrdup(&c.indexname); + (*node).indexspace = pstrdup(&c.indexspace); + (*node).reset_default_tblspc = c.reset_default_tblspc; + (*node).access_method = pstrdup(&c.access_method); + (*node).where_clause = write_node_boxed(&c.where_clause); + (*node).pktable = write_range_var_ptr(&c.pktable); + (*node).fk_attrs = write_node_list(&c.fk_attrs); + (*node).pk_attrs = write_node_list(&c.pk_attrs); + (*node).fk_matchtype = if c.fk_matchtype.is_empty() { 0 } else { c.fk_matchtype.as_bytes()[0] as i8 }; + (*node).fk_upd_action = if c.fk_upd_action.is_empty() { 0 } else { c.fk_upd_action.as_bytes()[0] as i8 }; + (*node).fk_del_action = if c.fk_del_action.is_empty() { 0 } else { c.fk_del_action.as_bytes()[0] as i8 }; + (*node).fk_del_set_cols = write_node_list(&c.fk_del_set_cols); + (*node).old_conpfeqop = write_node_list(&c.old_conpfeqop); + (*node).old_pktable_oid = c.old_pktable_oid; + (*node).location = c.location; + node +} + +unsafe fn write_index_stmt(stmt: &protobuf::IndexStmt) -> *mut bindings_raw::IndexStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_IndexStmt); + (*node).idxname = pstrdup(&stmt.idxname); + (*node).relation = write_range_var_ptr(&stmt.relation); + (*node).accessMethod = pstrdup(&stmt.access_method); + (*node).tableSpace = pstrdup(&stmt.table_space); + (*node).indexParams = write_node_list(&stmt.index_params); + (*node).indexIncludingParams = write_node_list(&stmt.index_including_params); + (*node).options = write_node_list(&stmt.options); + (*node).whereClause = write_node_boxed(&stmt.where_clause); + (*node).excludeOpNames = write_node_list(&stmt.exclude_op_names); + (*node).idxcomment = pstrdup(&stmt.idxcomment); + (*node).indexOid = stmt.index_oid; + (*node).oldNumber = stmt.old_number; + (*node).oldCreateSubid = stmt.old_create_subid; + (*node).oldFirstRelfilelocatorSubid = stmt.old_first_relfilelocator_subid; + (*node).unique = stmt.unique; + (*node).nulls_not_distinct = stmt.nulls_not_distinct; + (*node).primary = stmt.primary; + (*node).isconstraint = stmt.isconstraint; + (*node).deferrable = stmt.deferrable; + (*node).initdeferred = stmt.initdeferred; + (*node).transformed = stmt.transformed; + (*node).concurrent = stmt.concurrent; + (*node).if_not_exists = stmt.if_not_exists; + (*node).reset_default_tblspc = stmt.reset_default_tblspc; + node +} + +unsafe fn write_view_stmt(stmt: &protobuf::ViewStmt) -> *mut bindings_raw::ViewStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_ViewStmt); + (*node).view = write_range_var_ptr(&stmt.view); + (*node).aliases = write_node_list(&stmt.aliases); + (*node).query = write_node_boxed(&stmt.query); + (*node).replace = stmt.replace; + (*node).options = write_node_list(&stmt.options); + (*node).withCheckOption = proto_enum_to_c(stmt.with_check_option); + node +} + +unsafe fn write_transaction_stmt(stmt: &protobuf::TransactionStmt) -> *mut bindings_raw::TransactionStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_TransactionStmt); + (*node).kind = proto_enum_to_c(stmt.kind); + (*node).options = write_node_list(&stmt.options); + (*node).savepoint_name = pstrdup(&stmt.savepoint_name); + (*node).gid = pstrdup(&stmt.gid); + (*node).chain = stmt.chain; + (*node).location = stmt.location; + node +} + +unsafe fn write_copy_stmt(stmt: &protobuf::CopyStmt) -> *mut bindings_raw::CopyStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CopyStmt); + (*node).relation = write_range_var_ptr(&stmt.relation); + (*node).query = write_node_boxed(&stmt.query); + (*node).attlist = write_node_list(&stmt.attlist); + (*node).is_from = stmt.is_from; + (*node).is_program = stmt.is_program; + (*node).filename = pstrdup(&stmt.filename); + (*node).options = write_node_list(&stmt.options); + (*node).whereClause = write_node_boxed(&stmt.where_clause); + node +} + +unsafe fn write_explain_stmt(stmt: &protobuf::ExplainStmt) -> *mut bindings_raw::ExplainStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_ExplainStmt); + (*node).query = write_node_boxed(&stmt.query); + (*node).options = write_node_list(&stmt.options); + node +} + +unsafe fn write_create_table_as_stmt(stmt: &protobuf::CreateTableAsStmt) -> *mut bindings_raw::CreateTableAsStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateTableAsStmt); + (*node).query = write_node_boxed(&stmt.query); + (*node).into = if let Some(ref into) = stmt.into { write_into_clause(into) } else { std::ptr::null_mut() }; + (*node).objtype = proto_enum_to_c(stmt.objtype); + (*node).is_select_into = stmt.is_select_into; + (*node).if_not_exists = stmt.if_not_exists; + node +} + +unsafe fn write_refresh_mat_view_stmt(stmt: &protobuf::RefreshMatViewStmt) -> *mut bindings_raw::RefreshMatViewStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_RefreshMatViewStmt); + (*node).concurrent = stmt.concurrent; + (*node).skipData = stmt.skip_data; + (*node).relation = write_range_var_ref(&stmt.relation); + node +} + +unsafe fn write_vacuum_relation(vr: &protobuf::VacuumRelation) -> *mut bindings_raw::VacuumRelation { + let node = alloc_node::(bindings_raw::NodeTag_T_VacuumRelation); + (*node).relation = write_range_var_ref(&vr.relation); + (*node).oid = vr.oid; + (*node).va_cols = write_node_list(&vr.va_cols); + node +} + +unsafe fn write_vacuum_stmt(stmt: &protobuf::VacuumStmt) -> *mut bindings_raw::VacuumStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_VacuumStmt); + (*node).options = write_node_list(&stmt.options); + (*node).rels = write_node_list(&stmt.rels); + (*node).is_vacuumcmd = stmt.is_vacuumcmd; + node +} + +unsafe fn write_lock_stmt(stmt: &protobuf::LockStmt) -> *mut bindings_raw::LockStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_LockStmt); + (*node).relations = write_node_list(&stmt.relations); + (*node).mode = stmt.mode; + (*node).nowait = stmt.nowait; + node +} + +unsafe fn write_create_schema_stmt(stmt: &protobuf::CreateSchemaStmt) -> *mut bindings_raw::CreateSchemaStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateSchemaStmt); + (*node).schemaname = pstrdup(&stmt.schemaname); + (*node).authrole = if let Some(ref role) = stmt.authrole { write_role_spec(role) } else { std::ptr::null_mut() }; + (*node).schemaElts = write_node_list(&stmt.schema_elts); + (*node).if_not_exists = stmt.if_not_exists; + node +} + +unsafe fn write_variable_set_stmt(stmt: &protobuf::VariableSetStmt) -> *mut bindings_raw::VariableSetStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_VariableSetStmt); + (*node).kind = proto_enum_to_c(stmt.kind); + (*node).name = pstrdup(&stmt.name); + (*node).args = write_node_list(&stmt.args); + (*node).is_local = stmt.is_local; + node +} + +unsafe fn write_variable_show_stmt(stmt: &protobuf::VariableShowStmt) -> *mut bindings_raw::VariableShowStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_VariableShowStmt); + (*node).name = pstrdup(&stmt.name); + node +} + +unsafe fn write_rename_stmt(stmt: &protobuf::RenameStmt) -> *mut bindings_raw::RenameStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_RenameStmt); + (*node).renameType = proto_enum_to_c(stmt.rename_type); + (*node).relationType = proto_enum_to_c(stmt.relation_type); + (*node).relation = write_range_var_ptr(&stmt.relation); + (*node).object = write_node_boxed(&stmt.object); + (*node).subname = pstrdup(&stmt.subname); + (*node).newname = pstrdup(&stmt.newname); + (*node).behavior = proto_enum_to_c(stmt.behavior); + (*node).missing_ok = stmt.missing_ok; + node +} + +unsafe fn write_grant_stmt(stmt: &protobuf::GrantStmt) -> *mut bindings_raw::GrantStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_GrantStmt); + (*node).is_grant = stmt.is_grant; + (*node).targtype = proto_enum_to_c(stmt.targtype); + (*node).objtype = proto_enum_to_c(stmt.objtype); + (*node).objects = write_node_list(&stmt.objects); + (*node).privileges = write_node_list(&stmt.privileges); + (*node).grantees = write_node_list(&stmt.grantees); + (*node).grant_option = stmt.grant_option; + (*node).grantor = std::ptr::null_mut(); // RoleSpec, complex + (*node).behavior = proto_enum_to_c(stmt.behavior); + node +} + +unsafe fn write_role_spec(rs: &protobuf::RoleSpec) -> *mut bindings_raw::RoleSpec { + let node = alloc_node::(bindings_raw::NodeTag_T_RoleSpec); + (*node).roletype = proto_enum_to_c(rs.roletype); + (*node).rolename = pstrdup(&rs.rolename); + (*node).location = rs.location; + node +} + +unsafe fn write_access_priv(ap: &protobuf::AccessPriv) -> *mut bindings_raw::AccessPriv { + let node = alloc_node::(bindings_raw::NodeTag_T_AccessPriv); + (*node).priv_name = pstrdup(&ap.priv_name); + (*node).cols = write_node_list(&ap.cols); + node +} + +unsafe fn write_create_function_stmt(stmt: &protobuf::CreateFunctionStmt) -> *mut bindings_raw::CreateFunctionStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateFunctionStmt); + (*node).is_procedure = stmt.is_procedure; + (*node).replace = stmt.replace; + (*node).funcname = write_node_list(&stmt.funcname); + (*node).parameters = write_node_list(&stmt.parameters); + (*node).returnType = write_type_name_ptr(&stmt.return_type); + (*node).options = write_node_list(&stmt.options); + (*node).sql_body = write_node_boxed(&stmt.sql_body); + node +} + +unsafe fn write_def_elem(de: &protobuf::DefElem) -> *mut bindings_raw::DefElem { + let node = alloc_node::(bindings_raw::NodeTag_T_DefElem); + (*node).defnamespace = pstrdup(&de.defnamespace); + (*node).defname = pstrdup(&de.defname); + (*node).arg = write_node_boxed(&de.arg); + (*node).defaction = proto_enum_to_c(de.defaction); + (*node).location = de.location; + node +} + +unsafe fn write_rule_stmt(stmt: &protobuf::RuleStmt) -> *mut bindings_raw::RuleStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_RuleStmt); + (*node).relation = write_range_var_ptr(&stmt.relation); + (*node).rulename = pstrdup(&stmt.rulename); + (*node).whereClause = write_node_boxed(&stmt.where_clause); + (*node).event = proto_enum_to_c(stmt.event); + (*node).instead = stmt.instead; + (*node).actions = write_node_list(&stmt.actions); + (*node).replace = stmt.replace; + node +} + +unsafe fn write_create_trig_stmt(stmt: &protobuf::CreateTrigStmt) -> *mut bindings_raw::CreateTrigStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateTrigStmt); + (*node).replace = stmt.replace; + (*node).isconstraint = stmt.isconstraint; + (*node).trigname = pstrdup(&stmt.trigname); + (*node).relation = write_range_var_ptr(&stmt.relation); + (*node).funcname = write_node_list(&stmt.funcname); + (*node).args = write_node_list(&stmt.args); + (*node).row = stmt.row; + (*node).timing = stmt.timing as i16; + (*node).events = stmt.events as i16; + (*node).columns = write_node_list(&stmt.columns); + (*node).whenClause = write_node_boxed(&stmt.when_clause); + (*node).transitionRels = write_node_list(&stmt.transition_rels); + (*node).deferrable = stmt.deferrable; + (*node).initdeferred = stmt.initdeferred; + (*node).constrrel = write_range_var_ptr(&stmt.constrrel); + node +} + +unsafe fn write_do_stmt(stmt: &protobuf::DoStmt) -> *mut bindings_raw::DoStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_DoStmt); + (*node).args = write_node_list(&stmt.args); + node +} + +unsafe fn write_call_stmt(stmt: &protobuf::CallStmt) -> *mut bindings_raw::CallStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CallStmt); + (*node).funccall = match &stmt.funccall { + Some(fc) => write_func_call(fc), + None => std::ptr::null_mut(), + }; + (*node).funcexpr = std::ptr::null_mut(); // Post-analysis field + (*node).outargs = write_node_list(&stmt.outargs); + node +} + +unsafe fn write_merge_stmt(stmt: &protobuf::MergeStmt) -> *mut bindings_raw::MergeStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_MergeStmt); + (*node).relation = write_range_var_ptr(&stmt.relation); + (*node).sourceRelation = write_node_boxed(&stmt.source_relation); + (*node).joinCondition = write_node_boxed(&stmt.join_condition); + (*node).mergeWhenClauses = write_node_list(&stmt.merge_when_clauses); + (*node).returningList = write_node_list(&stmt.returning_list); + (*node).withClause = match &stmt.with_clause { + Some(wc) => write_with_clause(wc), + None => std::ptr::null_mut(), + }; + node +} + +unsafe fn write_merge_when_clause(mwc: &protobuf::MergeWhenClause) -> *mut bindings_raw::MergeWhenClause { + let node = alloc_node::(bindings_raw::NodeTag_T_MergeWhenClause); + (*node).matchKind = proto_enum_to_c(mwc.match_kind); + (*node).commandType = proto_enum_to_c(mwc.command_type); + (*node).override_ = proto_enum_to_c(mwc.r#override); + (*node).condition = write_node_boxed(&mwc.condition); + (*node).targetList = write_node_list(&mwc.target_list); + (*node).values = write_node_list(&mwc.values); + node +} + +unsafe fn write_grant_role_stmt(stmt: &protobuf::GrantRoleStmt) -> *mut bindings_raw::GrantRoleStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_GrantRoleStmt); + (*node).granted_roles = write_node_list(&stmt.granted_roles); + (*node).grantee_roles = write_node_list(&stmt.grantee_roles); + (*node).is_grant = stmt.is_grant; + (*node).opt = write_node_list(&stmt.opt); + (*node).grantor = std::ptr::null_mut(); + (*node).behavior = proto_enum_to_c(stmt.behavior); + node +} + +unsafe fn write_prepare_stmt(stmt: &protobuf::PrepareStmt) -> *mut bindings_raw::PrepareStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_PrepareStmt); + (*node).name = pstrdup(&stmt.name); + (*node).argtypes = write_node_list(&stmt.argtypes); + (*node).query = write_node_boxed(&stmt.query); + node +} + +unsafe fn write_execute_stmt(stmt: &protobuf::ExecuteStmt) -> *mut bindings_raw::ExecuteStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_ExecuteStmt); + (*node).name = pstrdup(&stmt.name); + (*node).params = write_node_list(&stmt.params); + node +} + +unsafe fn write_deallocate_stmt(stmt: &protobuf::DeallocateStmt) -> *mut bindings_raw::DeallocateStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_DeallocateStmt); + (*node).name = pstrdup(&stmt.name); + (*node).isall = stmt.isall; + (*node).location = stmt.location; + node +} + +unsafe fn write_a_indirection(ai: &protobuf::AIndirection) -> *mut bindings_raw::A_Indirection { + let node = alloc_node::(bindings_raw::NodeTag_T_A_Indirection); + (*node).arg = write_node_boxed(&ai.arg); + (*node).indirection = write_node_list(&ai.indirection); + node +} + +unsafe fn write_a_indices(ai: &protobuf::AIndices) -> *mut bindings_raw::A_Indices { + let node = alloc_node::(bindings_raw::NodeTag_T_A_Indices); + (*node).is_slice = ai.is_slice; + (*node).lidx = write_node_boxed(&ai.lidx); + (*node).uidx = write_node_boxed(&ai.uidx); + node +} + +unsafe fn write_min_max_expr(mme: &protobuf::MinMaxExpr) -> *mut bindings_raw::MinMaxExpr { + let node = alloc_node::(bindings_raw::NodeTag_T_MinMaxExpr); + (*node).minmaxtype = mme.minmaxtype; + (*node).minmaxcollid = mme.minmaxcollid; + (*node).inputcollid = mme.inputcollid; + (*node).op = proto_enum_to_c(mme.op); + (*node).args = write_node_list(&mme.args); + (*node).location = mme.location; + node +} + +unsafe fn write_row_expr(re: &protobuf::RowExpr) -> *mut bindings_raw::RowExpr { + let node = alloc_node::(bindings_raw::NodeTag_T_RowExpr); + (*node).args = write_node_list(&re.args); + (*node).row_typeid = re.row_typeid; + (*node).row_format = proto_enum_to_c(re.row_format); + (*node).colnames = write_node_list(&re.colnames); + (*node).location = re.location; + node +} + +unsafe fn write_a_array_expr(ae: &protobuf::AArrayExpr) -> *mut bindings_raw::A_ArrayExpr { + let node = alloc_node::(bindings_raw::NodeTag_T_A_ArrayExpr); + (*node).elements = write_node_list(&ae.elements); + (*node).location = ae.location; + node +} + +unsafe fn write_boolean_test(bt: &protobuf::BooleanTest) -> *mut bindings_raw::BooleanTest { + let node = alloc_node::(bindings_raw::NodeTag_T_BooleanTest); + (*node).arg = write_node_boxed(&bt.arg) as *mut bindings_raw::Expr; + (*node).booltesttype = proto_enum_to_c(bt.booltesttype); + (*node).location = bt.location; + node +} + +unsafe fn write_collate_clause(cc: &protobuf::CollateClause) -> *mut bindings_raw::CollateClause { + let node = alloc_node::(bindings_raw::NodeTag_T_CollateClause); + (*node).arg = write_node_boxed(&cc.arg); + (*node).collname = write_node_list(&cc.collname); + (*node).location = cc.location; + node +} diff --git a/src/raw_parse.rs b/src/raw_parse.rs index 4fbf13e..cfb7a7a 100644 --- a/src/raw_parse.rs +++ b/src/raw_parse.rs @@ -453,6 +453,51 @@ unsafe fn convert_node(node_ptr: *mut bindings_raw::Node) -> Option Some(protobuf::node::Node::CheckPointStmt(protobuf::CheckPointStmt {})), + bindings_raw::NodeTag_T_CallStmt => { + let cs = node_ptr as *mut bindings_raw::CallStmt; + Some(protobuf::node::Node::CallStmt(Box::new(convert_call_stmt(&*cs)))) + } + bindings_raw::NodeTag_T_RuleStmt => { + let rs = node_ptr as *mut bindings_raw::RuleStmt; + Some(protobuf::node::Node::RuleStmt(Box::new(convert_rule_stmt(&*rs)))) + } + bindings_raw::NodeTag_T_GrantStmt => { + let gs = node_ptr as *mut bindings_raw::GrantStmt; + Some(protobuf::node::Node::GrantStmt(convert_grant_stmt(&*gs))) + } + bindings_raw::NodeTag_T_GrantRoleStmt => { + let grs = node_ptr as *mut bindings_raw::GrantRoleStmt; + Some(protobuf::node::Node::GrantRoleStmt(convert_grant_role_stmt(&*grs))) + } + bindings_raw::NodeTag_T_RefreshMatViewStmt => { + let rmvs = node_ptr as *mut bindings_raw::RefreshMatViewStmt; + Some(protobuf::node::Node::RefreshMatViewStmt(convert_refresh_mat_view_stmt(&*rmvs))) + } + bindings_raw::NodeTag_T_MergeStmt => { + let ms = node_ptr as *mut bindings_raw::MergeStmt; + Some(protobuf::node::Node::MergeStmt(Box::new(convert_merge_stmt(&*ms)))) + } + bindings_raw::NodeTag_T_MergeAction => { + let ma = node_ptr as *mut bindings_raw::MergeAction; + Some(protobuf::node::Node::MergeAction(Box::new(convert_merge_action(&*ma)))) + } + bindings_raw::NodeTag_T_RangeFunction => { + let rf = node_ptr as *mut bindings_raw::RangeFunction; + Some(protobuf::node::Node::RangeFunction(convert_range_function(&*rf))) + } + bindings_raw::NodeTag_T_MergeWhenClause => { + let mwc = node_ptr as *mut bindings_raw::MergeWhenClause; + Some(protobuf::node::Node::MergeWhenClause(Box::new(convert_merge_when_clause(&*mwc)))) + } + bindings_raw::NodeTag_T_AccessPriv => { + let ap = node_ptr as *mut bindings_raw::AccessPriv; + Some(protobuf::node::Node::AccessPriv(convert_access_priv(&*ap))) + } + bindings_raw::NodeTag_T_RoleSpec => { + let rs = node_ptr as *mut bindings_raw::RoleSpec; + Some(protobuf::node::Node::RoleSpec(convert_role_spec(&*rs))) + } _ => { // For unhandled node types, return None // In the future, we could add more node types here @@ -1386,11 +1431,24 @@ unsafe fn convert_function_parameter(fp: &bindings_raw::FunctionParameter) -> pr protobuf::FunctionParameter { name: convert_c_string(fp.name), arg_type: if fp.argType.is_null() { None } else { Some(convert_type_name(&*fp.argType)) }, - mode: fp.mode as i32 + 1, // Protobuf FunctionParameterMode has UNDEFINED=0 + mode: convert_function_parameter_mode(fp.mode), defexpr: convert_node_boxed(fp.defexpr), } } +/// Converts raw FunctionParameterMode (ASCII char codes) to protobuf enum values +fn convert_function_parameter_mode(mode: bindings_raw::FunctionParameterMode) -> i32 { + match mode { + bindings_raw::FunctionParameterMode_FUNC_PARAM_IN => protobuf::FunctionParameterMode::FuncParamIn as i32, + bindings_raw::FunctionParameterMode_FUNC_PARAM_OUT => protobuf::FunctionParameterMode::FuncParamOut as i32, + bindings_raw::FunctionParameterMode_FUNC_PARAM_INOUT => protobuf::FunctionParameterMode::FuncParamInout as i32, + bindings_raw::FunctionParameterMode_FUNC_PARAM_VARIADIC => protobuf::FunctionParameterMode::FuncParamVariadic as i32, + bindings_raw::FunctionParameterMode_FUNC_PARAM_TABLE => protobuf::FunctionParameterMode::FuncParamTable as i32, + bindings_raw::FunctionParameterMode_FUNC_PARAM_DEFAULT => protobuf::FunctionParameterMode::FuncParamDefault as i32, + _ => 0, // Undefined + } +} + unsafe fn convert_notify_stmt(ns: &bindings_raw::NotifyStmt) -> protobuf::NotifyStmt { protobuf::NotifyStmt { conditionname: convert_c_string(ns.conditionname), payload: convert_c_string(ns.payload) } } @@ -1519,6 +1577,107 @@ unsafe fn convert_create_trig_stmt(cts: &bindings_raw::CreateTrigStmt) -> protob } } +unsafe fn convert_call_stmt(cs: &bindings_raw::CallStmt) -> protobuf::CallStmt { + protobuf::CallStmt { + funccall: if cs.funccall.is_null() { None } else { Some(Box::new(convert_func_call(&*cs.funccall))) }, + funcexpr: None, // This is a post-analysis field, not available in raw parse tree + outargs: convert_list_to_nodes(cs.outargs), + } +} + +unsafe fn convert_rule_stmt(rs: &bindings_raw::RuleStmt) -> protobuf::RuleStmt { + protobuf::RuleStmt { + relation: if rs.relation.is_null() { None } else { Some(convert_range_var(&*rs.relation)) }, + rulename: convert_c_string(rs.rulename), + where_clause: convert_node_boxed(rs.whereClause), + event: rs.event as i32 + 1, // CmdType enum + instead: rs.instead, + actions: convert_list_to_nodes(rs.actions), + replace: rs.replace, + } +} + +unsafe fn convert_grant_stmt(gs: &bindings_raw::GrantStmt) -> protobuf::GrantStmt { + protobuf::GrantStmt { + is_grant: gs.is_grant, + targtype: gs.targtype as i32 + 1, + objtype: gs.objtype as i32 + 1, + objects: convert_list_to_nodes(gs.objects), + privileges: convert_list_to_nodes(gs.privileges), + grantees: convert_list_to_nodes(gs.grantees), + grant_option: gs.grant_option, + grantor: if gs.grantor.is_null() { None } else { Some(convert_role_spec(&*gs.grantor)) }, + behavior: gs.behavior as i32 + 1, + } +} + +unsafe fn convert_grant_role_stmt(grs: &bindings_raw::GrantRoleStmt) -> protobuf::GrantRoleStmt { + protobuf::GrantRoleStmt { + granted_roles: convert_list_to_nodes(grs.granted_roles), + grantee_roles: convert_list_to_nodes(grs.grantee_roles), + is_grant: grs.is_grant, + opt: convert_list_to_nodes(grs.opt), + grantor: if grs.grantor.is_null() { None } else { Some(convert_role_spec(&*grs.grantor)) }, + behavior: grs.behavior as i32 + 1, + } +} + +unsafe fn convert_refresh_mat_view_stmt(rmvs: &bindings_raw::RefreshMatViewStmt) -> protobuf::RefreshMatViewStmt { + protobuf::RefreshMatViewStmt { + concurrent: rmvs.concurrent, + skip_data: rmvs.skipData, + relation: if rmvs.relation.is_null() { None } else { Some(convert_range_var(&*rmvs.relation)) }, + } +} + +unsafe fn convert_merge_stmt(ms: &bindings_raw::MergeStmt) -> protobuf::MergeStmt { + protobuf::MergeStmt { + relation: if ms.relation.is_null() { None } else { Some(convert_range_var(&*ms.relation)) }, + source_relation: convert_node_boxed(ms.sourceRelation), + join_condition: convert_node_boxed(ms.joinCondition), + merge_when_clauses: convert_list_to_nodes(ms.mergeWhenClauses), + returning_list: convert_list_to_nodes(ms.returningList), + with_clause: convert_with_clause_opt(ms.withClause), + } +} + +unsafe fn convert_merge_action(ma: &bindings_raw::MergeAction) -> protobuf::MergeAction { + protobuf::MergeAction { + match_kind: ma.matchKind as i32 + 1, + command_type: ma.commandType as i32 + 1, + r#override: ma.override_ as i32 + 1, + qual: convert_node_boxed(ma.qual), + target_list: convert_list_to_nodes(ma.targetList), + update_colnos: convert_list_to_nodes(ma.updateColnos), + } +} + +unsafe fn convert_merge_when_clause(mwc: &bindings_raw::MergeWhenClause) -> protobuf::MergeWhenClause { + protobuf::MergeWhenClause { + match_kind: mwc.matchKind as i32 + 1, + command_type: mwc.commandType as i32 + 1, + r#override: mwc.override_ as i32 + 1, + condition: convert_node_boxed(mwc.condition), + target_list: convert_list_to_nodes(mwc.targetList), + values: convert_list_to_nodes(mwc.values), + } +} + +unsafe fn convert_range_function(rf: &bindings_raw::RangeFunction) -> protobuf::RangeFunction { + protobuf::RangeFunction { + lateral: rf.lateral, + ordinality: rf.ordinality, + is_rowsfrom: rf.is_rowsfrom, + functions: convert_list_to_nodes(rf.functions), + alias: if rf.alias.is_null() { None } else { Some(convert_alias(&*rf.alias)) }, + coldeflist: convert_list_to_nodes(rf.coldeflist), + } +} + +unsafe fn convert_access_priv(ap: &bindings_raw::AccessPriv) -> protobuf::AccessPriv { + protobuf::AccessPriv { priv_name: convert_c_string(ap.priv_name), cols: convert_list_to_nodes(ap.cols) } +} + // ============================================================================ // Utility Functions // ============================================================================ diff --git a/tests/parse_tests.rs b/tests/parse_tests.rs index 1a314b2..722bd44 100644 --- a/tests/parse_tests.rs +++ b/tests/parse_tests.rs @@ -12,11 +12,14 @@ use pg_query::{ #[macro_use] mod support; -use support::*; +use support::{assert_deparse_raw_roundtrip, assert_parse_raw_equals_parse, *}; #[test] fn it_parses_simple_query() { - let result = parse("SELECT 1").unwrap(); + let query = "SELECT 1"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.tables().len(), 0); assert_eq!(result.statement_types(), ["SelectStmt"]); } @@ -32,7 +35,10 @@ fn it_handles_errors() { #[test] fn it_serializes_as_json() { - let result = parse("SELECT 1 FROM pg_class").unwrap(); + let query = "SELECT 1 FROM pg_class"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); let json = serde_json::to_string(&result.protobuf); assert!(json.is_ok(), "Protobuf should be serializable: {json:?}"); @@ -56,17 +62,27 @@ fn it_handles_recursion_error() { #[test] fn it_handles_recursion_without_error() { - // The Ruby version of pg_query fails here because of Ruby protobuf limitations - let query = r#"SELECT * FROM "t0" - JOIN "t1" ON (1) JOIN "t2" ON (1) JOIN "t3" ON (1) JOIN "t4" ON (1) JOIN "t5" ON (1) - JOIN "t6" ON (1) JOIN "t7" ON (1) JOIN "t8" ON (1) JOIN "t9" ON (1) JOIN "t10" ON (1) - JOIN "t11" ON (1) JOIN "t12" ON (1) JOIN "t13" ON (1) JOIN "t14" ON (1) JOIN "t15" ON (1) - JOIN "t16" ON (1) JOIN "t17" ON (1) JOIN "t18" ON (1) JOIN "t19" ON (1) JOIN "t20" ON (1) - JOIN "t21" ON (1) JOIN "t22" ON (1) JOIN "t23" ON (1) JOIN "t24" ON (1) JOIN "t25" ON (1) - JOIN "t26" ON (1) JOIN "t27" ON (1) JOIN "t28" ON (1) JOIN "t29" ON (1)"#; - let result = parse(query).unwrap(); - assert_eq!(result.tables().len(), 30); - assert_eq!(result.statement_types(), ["SelectStmt"]); + // Run in a thread with a larger stack to avoid stack overflow in parse_raw + // when processing deeply nested JoinExpr nodes. + let handle = std::thread::Builder::new() + .stack_size(32 * 1024 * 1024) // 32 MB stack + .spawn(|| { + // The Ruby version of pg_query fails here because of Ruby protobuf limitations + let query = r#"SELECT * FROM "t0" + JOIN "t1" ON (1) JOIN "t2" ON (1) JOIN "t3" ON (1) JOIN "t4" ON (1) JOIN "t5" ON (1) + JOIN "t6" ON (1) JOIN "t7" ON (1) JOIN "t8" ON (1) JOIN "t9" ON (1) JOIN "t10" ON (1) + JOIN "t11" ON (1) JOIN "t12" ON (1) JOIN "t13" ON (1) JOIN "t14" ON (1) JOIN "t15" ON (1) + JOIN "t16" ON (1) JOIN "t17" ON (1) JOIN "t18" ON (1) JOIN "t19" ON (1) JOIN "t20" ON (1) + JOIN "t21" ON (1) JOIN "t22" ON (1) JOIN "t23" ON (1) JOIN "t24" ON (1) JOIN "t25" ON (1) + JOIN "t26" ON (1) JOIN "t27" ON (1) JOIN "t28" ON (1) JOIN "t29" ON (1)"#; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); + assert_eq!(result.tables().len(), 30); + assert_eq!(result.statement_types(), ["SelectStmt"]); + }) + .unwrap(); + handle.join().unwrap(); } #[test] @@ -77,6 +93,9 @@ fn it_parses_real_queries() { FROM snapshots s JOIN system_snapshots ON (snapshot_id = s.id) WHERE s.database_id = $0 AND s.collected_at BETWEEN $0 AND $0 ORDER BY collected_at"; + assert_parse_raw_equals_parse(query); + // Note: Skip deparse_raw roundtrip because $0 is deparsed as ? by libpg_query, + // which is not valid PostgreSQL syntax for reparsing ($0 is non-standard anyway) let result = parse(query).unwrap(); let tables: Vec = sorted(result.tables()).collect(); let select_tables: Vec = sorted(result.select_tables()).collect(); @@ -87,7 +106,10 @@ fn it_parses_real_queries() { #[test] fn it_parses_empty_queries() { - let result = parse("-- nothing").unwrap(); + let query = "-- nothing"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.protobuf.nodes().len(), 0); assert_eq!(result.tables().len(), 0); assert_eq!(result.warnings.len(), 0); @@ -96,7 +118,10 @@ fn it_parses_empty_queries() { #[test] fn it_parses_floats_with_leading_dot() { - let result = parse("SELECT .1").unwrap(); + let query = "SELECT .1"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); let select = cast!(result.protobuf.nodes()[0].0, NodeRef::SelectStmt); let target = cast!(select.target_list[0].node.as_ref().unwrap(), NodeEnum::ResTarget); let a_const = cast!(target.val.as_ref().unwrap().node.as_ref().unwrap(), NodeEnum::AConst); @@ -107,7 +132,10 @@ fn it_parses_floats_with_leading_dot() { #[test] fn it_parses_bit_strings_hex_notation() { - let result = parse("SELECT X'EFFF'").unwrap(); + let query = "SELECT X'EFFF'"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); let select = cast!(result.protobuf.nodes()[0].0, NodeRef::SelectStmt); let target = cast!(select.target_list[0].node.as_ref().unwrap(), NodeEnum::ResTarget); let a_const = cast!(target.val.as_ref().unwrap().node.as_ref().unwrap(), NodeEnum::AConst); @@ -118,7 +146,10 @@ fn it_parses_bit_strings_hex_notation() { #[test] fn it_parses_ALTER_TABLE() { - let result = parse("ALTER TABLE test ADD PRIMARY KEY (gid)").unwrap(); + let query = "ALTER TABLE test ADD PRIMARY KEY (gid)"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["test"]); assert_eq!(result.ddl_tables(), ["test"]); @@ -192,7 +223,10 @@ fn it_parses_ALTER_TABLE() { #[test] fn it_parses_SET() { - let result = parse("SET statement_timeout=1").unwrap(); + let query = "SET statement_timeout=1"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables().len(), 0); assert_eq!(result.ddl_tables().len(), 0); @@ -206,7 +240,10 @@ fn it_parses_SET() { #[test] fn it_parses_SHOW() { - let result = parse("SHOW work_mem").unwrap(); + let query = "SHOW work_mem"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables().len(), 0); assert_eq!(result.statement_types(), ["VariableShowStmt"]); @@ -216,7 +253,10 @@ fn it_parses_SHOW() { #[test] fn it_parses_COPY() { - let result = parse("COPY test (id) TO stdout").unwrap(); + let query = "COPY test (id) TO stdout"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["test"]); assert_eq!(result.statement_types(), ["CopyStmt"]); @@ -258,7 +298,10 @@ fn it_parses_COPY() { #[test] fn it_parses_DROP_TABLE() { - let result = parse("drop table abc.test123 cascade").unwrap(); + let query = "drop table abc.test123 cascade"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["abc.test123"]); assert_eq!(result.ddl_tables(), ["abc.test123"]); @@ -266,7 +309,9 @@ fn it_parses_DROP_TABLE() { let drop = cast!(result.protobuf.nodes()[0].0, NodeRef::DropStmt); assert_eq!(protobuf::DropBehavior::from_i32(drop.behavior), Some(protobuf::DropBehavior::DropCascade)); - let result = parse("drop table abc.test123, test").unwrap(); + let query2 = "drop table abc.test123, test"; + assert_parse_raw_equals_parse(query2); + let result = parse(query2).unwrap(); let tables: Vec = sorted(result.tables()).collect(); let ddl_tables: Vec = sorted(result.ddl_tables()).collect(); assert_eq!(tables, ["abc.test123", "test"]); @@ -275,7 +320,10 @@ fn it_parses_DROP_TABLE() { #[test] fn it_parses_COMMIT() { - let result = parse("COMMIT").unwrap(); + let query = "COMMIT"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.statement_types(), ["TransactionStmt"]); let stmt = cast!(result.protobuf.nodes()[0].0, NodeRef::TransactionStmt); @@ -284,7 +332,10 @@ fn it_parses_COMMIT() { #[test] fn it_parses_CHECKPOINT() { - let result = parse("CHECKPOINT").unwrap(); + let query = "CHECKPOINT"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.statement_types(), ["CheckPointStmt"]); cast!(result.protobuf.nodes()[0].0, NodeRef::CheckPointStmt); @@ -292,7 +343,10 @@ fn it_parses_CHECKPOINT() { #[test] fn it_parses_VACUUM() { - let result = parse("VACUUM my_table").unwrap(); + let query = "VACUUM my_table"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["my_table"]); assert_eq!(result.ddl_tables(), ["my_table"]); @@ -302,10 +356,10 @@ fn it_parses_VACUUM() { #[test] fn it_parses_MERGE() { - let result = parse( - "WITH cte AS (SELECT * FROM g.other_table CROSS JOIN p) MERGE INTO my_table USING cte ON (id=oid) WHEN MATCHED THEN UPDATE SET a=b WHEN NOT MATCHED THEN INSERT (id, a) VALUES (oid, b);", - ) - .unwrap(); + let query = "WITH cte AS (SELECT * FROM g.other_table CROSS JOIN p) MERGE INTO my_table USING cte ON (id=oid) WHEN MATCHED THEN UPDATE SET a=b WHEN NOT MATCHED THEN INSERT (id, a) VALUES (oid, b);"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); let select_tables: Vec = sorted(result.select_tables()).collect(); @@ -325,7 +379,10 @@ fn it_parses_MERGE() { #[test] fn it_parses_EXPLAIN() { - let result = parse("EXPLAIN DELETE FROM test").unwrap(); + let query = "EXPLAIN DELETE FROM test"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["test"]); assert_eq!(result.statement_types(), ["ExplainStmt"]); @@ -335,7 +392,10 @@ fn it_parses_EXPLAIN() { #[test] fn it_parses_SELECT_INTO() { - let result = parse("CREATE TEMP TABLE test AS SELECT 1").unwrap(); + let query = "CREATE TEMP TABLE test AS SELECT 1"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["test"]); assert_eq!(result.ddl_tables(), ["test"]); @@ -374,7 +434,10 @@ fn it_parses_SELECT_INTO() { #[test] fn it_parses_LOCK() { - let result = parse("LOCK TABLE public.schema_migrations IN ACCESS SHARE MODE").unwrap(); + let query = "LOCK TABLE public.schema_migrations IN ACCESS SHARE MODE"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["public.schema_migrations"]); assert_eq!(result.statement_types(), ["LockStmt"]); @@ -384,7 +447,10 @@ fn it_parses_LOCK() { #[test] fn it_parses_CREATE_TABLE() { - let result = parse("CREATE TABLE test (a int4)").unwrap(); + let query = "CREATE TABLE test (a int4)"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["test"]); assert_eq!(result.ddl_tables(), ["test"]); @@ -440,7 +506,10 @@ fn it_parses_CREATE_TABLE() { #[test] fn it_parses_CREATE_TABLE_AS() { - let result = parse("CREATE TABLE foo AS SELECT * FROM bar;").unwrap(); + let query = "CREATE TABLE foo AS SELECT * FROM bar;"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); assert_eq!(tables, ["bar", "foo"]); @@ -448,8 +517,9 @@ fn it_parses_CREATE_TABLE_AS() { assert_eq!(result.select_tables(), ["bar"]); assert_eq!(result.statement_types(), ["CreateTableAsStmt"]); - let sql = "CREATE TABLE foo AS SELECT id FROM bar UNION SELECT id from baz;"; - let result = parse(sql).unwrap(); + let query2 = "CREATE TABLE foo AS SELECT id FROM bar UNION SELECT id from baz;"; + assert_parse_raw_equals_parse(query2); + let result = parse(query2).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); let select_tables: Vec = sorted(result.select_tables()).collect(); @@ -467,7 +537,10 @@ fn it_fails_to_parse_CREATE_TABLE_WITH_OIDS() { #[test] fn it_parses_CREATE_INDEX() { - let result = parse("CREATE INDEX testidx ON test USING btree (a, (lower(b) || upper(c))) WHERE pow(a, 2) > 25").unwrap(); + let query = "CREATE INDEX testidx ON test USING btree (a, (lower(b) || upper(c))) WHERE pow(a, 2) > 25"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["test"]); assert_eq!(result.ddl_tables(), ["test"]); @@ -480,7 +553,10 @@ fn it_parses_CREATE_INDEX() { #[test] fn it_parses_CREATE_SCHEMA() { - let result = parse("CREATE SCHEMA IF NOT EXISTS test AUTHORIZATION joe").unwrap(); + let query = "CREATE SCHEMA IF NOT EXISTS test AUTHORIZATION joe"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables().len(), 0); assert_eq!(result.statement_types(), ["CreateSchemaStmt"]); @@ -504,7 +580,10 @@ fn it_parses_CREATE_SCHEMA() { #[test] fn it_parses_CREATE_VIEW() { - let result = parse("CREATE VIEW myview AS SELECT * FROM mytab").unwrap(); + let query = "CREATE VIEW myview AS SELECT * FROM mytab"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); assert_eq!(tables, ["mytab", "myview"]); @@ -614,7 +693,10 @@ fn it_parses_CREATE_VIEW() { #[test] fn it_parses_REFRESH_MATERIALIZED_VIEW() { - let result = parse("REFRESH MATERIALIZED VIEW myview").unwrap(); + let query = "REFRESH MATERIALIZED VIEW myview"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["myview"]); assert_eq!(result.ddl_tables(), ["myview"]); @@ -624,8 +706,10 @@ fn it_parses_REFRESH_MATERIALIZED_VIEW() { #[test] fn it_parses_CREATE_RULE() { - let sql = "CREATE RULE shoe_ins_protect AS ON INSERT TO shoe DO INSTEAD NOTHING"; - let result = parse(sql).unwrap(); + let query = "CREATE RULE shoe_ins_protect AS ON INSERT TO shoe DO INSTEAD NOTHING"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["shoe"]); assert_eq!(result.ddl_tables(), ["shoe"]); @@ -637,8 +721,10 @@ fn it_parses_CREATE_RULE() { #[test] fn it_parses_CREATE_TRIGGER() { - let sql = "CREATE TRIGGER check_update BEFORE UPDATE ON accounts FOR EACH ROW EXECUTE PROCEDURE check_account_update()"; - let result = parse(sql).unwrap(); + let query = "CREATE TRIGGER check_update BEFORE UPDATE ON accounts FOR EACH ROW EXECUTE PROCEDURE check_account_update()"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["accounts"]); assert_eq!(result.ddl_tables(), ["accounts"]); @@ -652,7 +738,10 @@ fn it_parses_CREATE_TRIGGER() { #[test] fn it_parses_DROP_SCHEMA() { - let result = parse("DROP SCHEMA myschema").unwrap(); + let query = "DROP SCHEMA myschema"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables().len(), 0); assert_eq!(result.statement_types(), ["DropStmt"]); @@ -681,7 +770,10 @@ fn it_parses_DROP_SCHEMA() { #[test] fn it_parses_DROP_VIEW() { - let result = parse("DROP VIEW myview, myview2").unwrap(); + let query = "DROP VIEW myview, myview2"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables().len(), 0); assert_eq!(result.statement_types(), ["DropStmt"]); @@ -739,7 +831,10 @@ fn it_parses_DROP_VIEW() { #[test] fn it_parses_DROP_INDEX() { - let result = parse("DROP INDEX CONCURRENTLY myindex").unwrap(); + let query = "DROP INDEX CONCURRENTLY myindex"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables().len(), 0); assert_eq!(result.statement_types(), ["DropStmt"]); @@ -778,7 +873,10 @@ fn it_parses_DROP_INDEX() { #[test] fn it_parses_DROP_RULE() { - let result = parse("DROP RULE myrule ON mytable CASCADE").unwrap(); + let query = "DROP RULE myrule ON mytable CASCADE"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["mytable"]); assert_eq!(result.ddl_tables(), ["mytable"]); @@ -827,7 +925,10 @@ fn it_parses_DROP_RULE() { #[test] fn it_parses_DROP_TRIGGER() { - let result = parse("DROP TRIGGER IF EXISTS mytrigger ON mytable RESTRICT").unwrap(); + let query = "DROP TRIGGER IF EXISTS mytrigger ON mytable RESTRICT"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["mytable"]); assert_eq!(result.ddl_tables(), ["mytable"]); @@ -876,7 +977,10 @@ fn it_parses_DROP_TRIGGER() { #[test] fn it_parses_GRANT() { - let result = parse("GRANT INSERT, UPDATE ON mytable TO myuser").unwrap(); + let query = "GRANT INSERT, UPDATE ON mytable TO myuser"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["mytable"]); assert_eq!(result.ddl_tables(), ["mytable"]); @@ -949,7 +1053,10 @@ fn it_parses_GRANT() { #[test] fn it_parses_REVOKE() { - let result = parse("REVOKE admins FROM joe").unwrap(); + let query = "REVOKE admins FROM joe"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables().len(), 0); assert_eq!(result.statement_types(), ["GrantRoleStmt"]); @@ -992,7 +1099,10 @@ fn it_parses_REVOKE() { #[test] fn it_parses_TRUNCATE() { - let result = parse(r#"TRUNCATE bigtable, "fattable" RESTART IDENTITY"#).unwrap(); + let query = r#"TRUNCATE bigtable, "fattable" RESTART IDENTITY"#; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); let ddl_tables: Vec = sorted(result.ddl_tables()).collect(); @@ -1043,7 +1153,10 @@ fn it_parses_TRUNCATE() { #[test] fn it_parses_WITH() { - let result = parse("WITH a AS (SELECT * FROM x WHERE x.y = $1 AND x.z = 1) SELECT * FROM a").unwrap(); + let query = "WITH a AS (SELECT * FROM x WHERE x.y = $1 AND x.z = 1) SELECT * FROM a"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["x"]); assert_eq!(result.cte_names, ["a"]); @@ -1052,7 +1165,7 @@ fn it_parses_WITH() { #[test] fn it_parses_multi_line_functions() { - let sql = "CREATE OR REPLACE FUNCTION thing(parameter_thing text) + let query = "CREATE OR REPLACE FUNCTION thing(parameter_thing text) RETURNS bigint AS $BODY$ DECLARE @@ -1070,7 +1183,9 @@ BEGIN END; $BODY$ LANGUAGE plpgsql STABLE"; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables().len(), 0); assert_eq!(result.functions(), ["thing"]); @@ -1248,8 +1363,10 @@ $BODY$ #[test] fn it_parses_table_functions() { - let sql = "CREATE FUNCTION getfoo(int) RETURNS TABLE (f1 int) AS 'SELECT * FROM foo WHERE fooid = $1;' LANGUAGE SQL"; - let result = parse(sql).unwrap(); + let query = "CREATE FUNCTION getfoo(int) RETURNS TABLE (f1 int) AS 'SELECT * FROM foo WHERE fooid = $1;' LANGUAGE SQL"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables().len(), 0); assert_eq!(result.functions(), ["getfoo"]); @@ -1260,7 +1377,10 @@ fn it_parses_table_functions() { #[test] fn it_finds_called_functions() { - let result = parse("SELECT testfunc(1);").unwrap(); + let query = "SELECT testfunc(1);"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables().len(), 0); assert_eq!(result.functions(), ["testfunc"]); @@ -1271,7 +1391,10 @@ fn it_finds_called_functions() { #[test] fn it_finds_functions_invoked_with_CALL() { - let result = parse("CALL testfunc(1);").unwrap(); + let query = "CALL testfunc(1);"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables().len(), 0); assert_eq!(result.functions(), ["testfunc"]); @@ -1282,7 +1405,10 @@ fn it_finds_functions_invoked_with_CALL() { #[test] fn it_finds_dropped_functions() { - let result = parse("DROP FUNCTION IF EXISTS testfunc(x integer);").unwrap(); + let query = "DROP FUNCTION IF EXISTS testfunc(x integer);"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables().len(), 0); assert_eq!(result.functions(), ["testfunc"]); @@ -1293,7 +1419,10 @@ fn it_finds_dropped_functions() { #[test] fn it_finds_renamed_functions() { - let result = parse("ALTER FUNCTION testfunc(integer) RENAME TO testfunc2;").unwrap(); + let query = "ALTER FUNCTION testfunc(integer) RENAME TO testfunc2;"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables().len(), 0); let functions: Vec = sorted(result.functions()).collect(); @@ -1307,8 +1436,10 @@ fn it_finds_renamed_functions() { // https://github.com/pganalyze/pg_query/issues/38 #[test] fn it_finds_nested_tables_in_SELECT() { - let sql = "select u.email, (select count(*) from enrollments e where e.user_id = u.id) as num_enrollments from users u"; - let result = parse(sql).unwrap(); + let query = "select u.email, (select count(*) from enrollments e where e.user_id = u.id) as num_enrollments from users u"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); let select_tables: Vec = sorted(result.select_tables()).collect(); @@ -1320,8 +1451,10 @@ fn it_finds_nested_tables_in_SELECT() { // https://github.com/pganalyze/pg_query/issues/52 #[test] fn it_separates_CTE_names_from_table_names() { - let sql = "WITH cte_name AS (SELECT 1) SELECT * FROM table_name, cte_name"; - let result = parse(sql).unwrap(); + let query = "WITH cte_name AS (SELECT 1) SELECT * FROM table_name, cte_name"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["table_name"]); assert_eq!(result.select_tables(), ["table_name"]); @@ -1331,7 +1464,10 @@ fn it_separates_CTE_names_from_table_names() { #[test] fn it_finds_nested_tables_in_FROM_clause() { - let result = parse("select u.* from (select * from users) u").unwrap(); + let query = "select u.* from (select * from users) u"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["users"]); assert_eq!(result.select_tables(), ["users"]); @@ -1340,7 +1476,10 @@ fn it_finds_nested_tables_in_FROM_clause() { #[test] fn it_finds_nested_tables_in_WHERE_clause() { - let result = parse("select users.id from users where 1 = (select count(*) from user_roles)").unwrap(); + let query = "select users.id from users where 1 = (select count(*) from user_roles)"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); let select_tables: Vec = sorted(result.select_tables()).collect(); @@ -1360,6 +1499,8 @@ fn it_finds_tables_in_SELECT_with_subselects_without_FROM() { SELECT 17663 AS oid ) vals ON c.oid = vals.oid "; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["pg_catalog.pg_class"]); @@ -1371,13 +1512,15 @@ fn it_finds_tables_in_SELECT_with_subselects_without_FROM() { #[test] fn it_finds_nested_tables_in_IN_clause() { - let sql = " + let query = " select users.* from users where users.id IN (select user_roles.user_id from user_roles) and (users.created_at between '2016-06-01' and '2016-06-30') "; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); let select_tables: Vec = sorted(result.select_tables()).collect(); @@ -1388,7 +1531,7 @@ fn it_finds_nested_tables_in_IN_clause() { #[test] fn it_finds_nested_tables_in_ORDER_BY_clause() { - let sql = " + let query = " select users.* from users order by ( @@ -1397,7 +1540,9 @@ fn it_finds_nested_tables_in_ORDER_BY_clause() { where user_roles.user_id = users.id ) "; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); let select_tables: Vec = sorted(result.select_tables()).collect(); @@ -1408,7 +1553,7 @@ fn it_finds_nested_tables_in_ORDER_BY_clause() { #[test] fn it_finds_nested_tables_in_ORDER_BY_clause_with_multiple_entries() { - let sql = " + let query = " select users.* from users order by ( @@ -1421,7 +1566,9 @@ fn it_finds_nested_tables_in_ORDER_BY_clause_with_multiple_entries() { where user_logins.user_id = users.id ) desc "; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); let select_tables: Vec = sorted(result.select_tables()).collect(); @@ -1432,7 +1579,7 @@ fn it_finds_nested_tables_in_ORDER_BY_clause_with_multiple_entries() { #[test] fn it_finds_nested_tables_in_GROUP_BY_clause() { - let sql = " + let query = " select users.* from users group by ( @@ -1441,7 +1588,9 @@ fn it_finds_nested_tables_in_GROUP_BY_clause() { where user_roles.user_id = users.id ) "; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); let select_tables: Vec = sorted(result.select_tables()).collect(); @@ -1452,7 +1601,7 @@ fn it_finds_nested_tables_in_GROUP_BY_clause() { #[test] fn it_finds_nested_tables_in_GROUP_BY_clause_with_multiple_entries() { - let sql = " + let query = " select users.* from users group by ( @@ -1465,7 +1614,9 @@ fn it_finds_nested_tables_in_GROUP_BY_clause_with_multiple_entries() { where user_logins.user_id = users.id ) "; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); let select_tables: Vec = sorted(result.select_tables()).collect(); @@ -1476,7 +1627,7 @@ fn it_finds_nested_tables_in_GROUP_BY_clause_with_multiple_entries() { #[test] fn it_finds_nested_tables_in_HAVING_clause() { - let sql = " + let query = " select users.* from users group by users.id @@ -1486,7 +1637,9 @@ fn it_finds_nested_tables_in_HAVING_clause() { where user_roles.user_id = users.id ) "; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); let select_tables: Vec = sorted(result.select_tables()).collect(); @@ -1497,7 +1650,7 @@ fn it_finds_nested_tables_in_HAVING_clause() { #[test] fn it_finds_nested_tables_in_HAVING_clause_with_boolean_expression() { - let sql = " + let query = " select users.* from users group by users.id @@ -1507,7 +1660,9 @@ fn it_finds_nested_tables_in_HAVING_clause_with_boolean_expression() { where user_roles.user_id = users.id ) "; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); let select_tables: Vec = sorted(result.select_tables()).collect(); @@ -1518,13 +1673,15 @@ fn it_finds_nested_tables_in_HAVING_clause_with_boolean_expression() { #[test] fn it_finds_nested_tables_in_a_subselect_on_a_JOIN() { - let sql = " + let query = " select foo.* from foo join ( select * from bar ) b on b.baz = foo.quux "; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); let select_tables: Vec = sorted(result.select_tables()).collect(); @@ -1535,7 +1692,7 @@ fn it_finds_nested_tables_in_a_subselect_on_a_JOIN() { #[test] fn it_finds_nested_tables_in_a_subselect_in_a_JOIN_condition() { - let sql = " + let query = " SELECT * FROM foo INNER JOIN join_a ON foo.id = join_a.id AND join_a.id IN ( @@ -1551,7 +1708,9 @@ fn it_finds_nested_tables_in_a_subselect_in_a_JOIN_condition() { SELECT id FROM sub_f ) "; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); let select_tables: Vec = sorted(result.select_tables()).collect(); @@ -1562,7 +1721,7 @@ fn it_finds_nested_tables_in_a_subselect_in_a_JOIN_condition() { #[test] fn it_correctly_categorizes_CTEs_after_UNION_SELECT() { - let sql = " + let query = " WITH cte_a AS ( SELECT * FROM table_a ), cte_b AS ( @@ -1573,7 +1732,9 @@ fn it_correctly_categorizes_CTEs_after_UNION_SELECT() { UNION SELECT * FROM cte_a "; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); let cte_names: Vec = sorted(result.cte_names.clone()).collect(); @@ -1584,7 +1745,7 @@ fn it_correctly_categorizes_CTEs_after_UNION_SELECT() { #[test] fn it_correctly_categorizes_CTEs_after_EXCEPT_SELECT() { - let sql = " + let query = " WITH cte_a AS ( SELECT * FROM table_a ), cte_b AS ( @@ -1595,7 +1756,9 @@ fn it_correctly_categorizes_CTEs_after_EXCEPT_SELECT() { EXCEPT SELECT * FROM cte_a "; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); let cte_names: Vec = sorted(result.cte_names.clone()).collect(); @@ -1606,7 +1769,7 @@ fn it_correctly_categorizes_CTEs_after_EXCEPT_SELECT() { #[test] fn it_correctly_categorizes_CTEs_after_INTERSECT_SELECT() { - let sql = " + let query = " WITH cte_a AS ( SELECT * FROM table_a ), cte_b AS ( @@ -1617,7 +1780,9 @@ fn it_correctly_categorizes_CTEs_after_INTERSECT_SELECT() { INTERSECT SELECT * FROM cte_a "; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); let cte_names: Vec = sorted(result.cte_names.clone()).collect(); @@ -1628,7 +1793,7 @@ fn it_correctly_categorizes_CTEs_after_INTERSECT_SELECT() { #[test] fn it_finds_tables_inside_subselectes_in_MIN_MAX_COALESCE() { - let sql = " + let query = " SELECT GREATEST( date_trunc($1, $2::timestamptz) + $3::interval, COALESCE( @@ -1641,7 +1806,9 @@ fn it_finds_tables_inside_subselectes_in_MIN_MAX_COALESCE() { ) ) AS first_hourly_start_ts "; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["schema_aggregate_infos"]); assert_eq!(result.select_tables(), ["schema_aggregate_infos"]); @@ -1650,7 +1817,7 @@ fn it_finds_tables_inside_subselectes_in_MIN_MAX_COALESCE() { #[test] fn it_finds_tables_inside_CASE_statements() { - let sql = " + let query = " SELECT CASE WHEN id IN (SELECT foo_id FROM when_a) THEN (SELECT MAX(id) FROM then_a) @@ -1659,7 +1826,9 @@ fn it_finds_tables_inside_CASE_statements() { END FROM foo "; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); let select_tables: Vec = sorted(result.select_tables()).collect(); @@ -1670,13 +1839,15 @@ fn it_finds_tables_inside_CASE_statements() { #[test] fn it_finds_tables_inside_casts() { - let sql = " + let query = " SELECT 1 FROM foo WHERE x = any(cast(array(SELECT a FROM bar) as bigint[])) OR x = any(array(SELECT a FROM baz)::bigint[]) "; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); assert_eq!(tables, ["bar", "baz", "foo"]); @@ -1687,8 +1858,10 @@ fn it_finds_tables_inside_casts() { #[test] fn it_finds_functions_in_FROM_clause() { - let sql = "SELECT * FROM my_custom_func()"; - let result = parse(sql).unwrap(); + let query = "SELECT * FROM my_custom_func()"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables().len(), 0); assert_eq!(result.functions(), ["my_custom_func"]); @@ -1698,7 +1871,7 @@ fn it_finds_functions_in_FROM_clause() { #[test] fn it_finds_functions_in_LATERAL_clause() { - let sql = " + let query = " SELECT * FROM unnest($1::text[]) AS a(x) LEFT OUTER JOIN LATERAL ( @@ -1713,7 +1886,9 @@ fn it_finds_functions_in_LATERAL_clause() { ) f ) AS g ON (1) "; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["public.c"]); let functions: Vec = sorted(result.functions()).collect(); @@ -1725,22 +1900,27 @@ fn it_finds_functions_in_LATERAL_clause() { #[test] fn it_parses_INSERT() { - let result = parse("insert into users(pk, name) values (1, 'bob');").unwrap(); + let query1 = "insert into users(pk, name) values (1, 'bob');"; + assert_parse_raw_equals_parse(query1); + let result = parse(query1).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["users"]); - let result = parse("insert into users(pk, name) select pk, name from other_users;").unwrap(); + let query2 = "insert into users(pk, name) select pk, name from other_users;"; + assert_parse_raw_equals_parse(query2); + let result = parse(query2).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); assert_eq!(tables, ["other_users", "users"]); - let sql = " + let query3 = " with cte as ( select pk, name from other_users ) insert into users(pk, name) select * from cte; "; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query3); + let result = parse(query3).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); assert_eq!(tables, ["other_users", "users"]); @@ -1752,24 +1932,29 @@ fn it_parses_INSERT() { #[test] fn it_parses_UPDATE() { - let result = parse("update users set name = 'bob';").unwrap(); + let query1 = "update users set name = 'bob';"; + assert_parse_raw_equals_parse(query1); + let result = parse(query1).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["users"]); assert_eq!(result.statement_types(), ["UpdateStmt"]); - let result = parse("update users set name = (select name from other_users limit 1);").unwrap(); + let query2 = "update users set name = (select name from other_users limit 1);"; + assert_parse_raw_equals_parse(query2); + let result = parse(query2).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); assert_eq!(tables, ["other_users", "users"]); assert_eq!(result.statement_types(), ["UpdateStmt"]); - let sql = " + let query3 = " with cte as ( select name from other_users limit 1 ) update users set name = (select name from cte); "; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query3); + let result = parse(query3).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); assert_eq!(tables, ["other_users", "users"]); @@ -1778,13 +1963,14 @@ fn it_parses_UPDATE() { assert_eq!(result.cte_names, ["cte"]); assert_eq!(result.statement_types(), ["UpdateStmt"]); - let sql = " + let query4 = " UPDATE users SET name = users_new.name FROM users_new INNER JOIN join_table ON join_table.user_id = new_users.id WHERE users.id = users_new.id "; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query4); + let result = parse(query4).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); let select_tables: Vec = sorted(result.select_tables()).collect(); @@ -1796,17 +1982,20 @@ fn it_parses_UPDATE() { #[test] fn it_parses_DELETE() { - let result = parse("DELETE FROM users;").unwrap(); + let query1 = "DELETE FROM users;"; + assert_parse_raw_equals_parse(query1); + let result = parse(query1).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.tables(), ["users"]); assert_eq!(result.dml_tables(), ["users"]); assert_eq!(result.statement_types(), ["DeleteStmt"]); - let sql = " + let query2 = " DELETE FROM users USING foo WHERE foo_id = foo.id AND foo.action = 'delete'; "; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query2); + let result = parse(query2).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); assert_eq!(tables, ["foo", "users"]); @@ -1814,11 +2003,12 @@ fn it_parses_DELETE() { assert_eq!(result.select_tables(), ["foo"]); assert_eq!(result.statement_types(), ["DeleteStmt"]); - let sql = " + let query3 = " DELETE FROM users WHERE foo_id IN (SELECT id FROM foo WHERE action = 'delete'); "; - let result = parse(sql).unwrap(); + assert_parse_raw_equals_parse(query3); + let result = parse(query3).unwrap(); assert_eq!(result.warnings.len(), 0); let tables: Vec = sorted(result.tables()).collect(); assert_eq!(tables, ["foo", "users"]); @@ -1829,7 +2019,10 @@ fn it_parses_DELETE() { #[test] fn it_parses_DROP_TYPE() { - let result = parse("DROP TYPE IF EXISTS repack.pk_something").unwrap(); + let query = "DROP TYPE IF EXISTS repack.pk_something"; + assert_parse_raw_equals_parse(query); + assert_deparse_raw_roundtrip(query); + let result = parse(query).unwrap(); assert_eq!(result.warnings.len(), 0); assert_eq!(result.statement_types(), ["DropStmt"]); assert_debug_eq!( diff --git a/tests/raw_parse/basic.rs b/tests/raw_parse/basic.rs index 76c52dc..f33039a 100644 --- a/tests/raw_parse/basic.rs +++ b/tests/raw_parse/basic.rs @@ -212,3 +212,66 @@ fn it_deparse_raw_multiple_statements() { let deparsed = pg_query::deparse_raw(&result.protobuf).unwrap(); assert_eq!(deparsed, query); } + +// ============================================================================ +// deparse_raw method tests (on structs) +// ============================================================================ + +/// Test that ParseResult.deparse_raw() method works +#[test] +fn it_deparse_raw_method_on_parse_result() { + let query = "SELECT * FROM users WHERE id = 1"; + let result = pg_query::parse(query).unwrap(); + // Test the new method on ParseResult + let deparsed = result.deparse_raw().unwrap(); + assert_eq!(deparsed, query); +} + +/// Test that protobuf::ParseResult.deparse_raw() method works +#[test] +fn it_deparse_raw_method_on_protobuf_parse_result() { + let query = "SELECT a, b, c FROM table1 JOIN table2 ON table1.id = table2.id"; + let result = pg_query::parse(query).unwrap(); + // Test the new method on protobuf::ParseResult + let deparsed = result.protobuf.deparse_raw().unwrap(); + assert_eq!(deparsed, query); +} + +/// Test that NodeRef.deparse_raw() method works +#[test] +fn it_deparse_raw_method_on_node_ref() { + let query = "SELECT * FROM users"; + let result = pg_query::parse(query).unwrap(); + // Get the first node (SelectStmt) + let nodes = result.protobuf.nodes(); + assert!(!nodes.is_empty()); + // Find the SelectStmt node + for (node, _depth, _context, _has_filter) in nodes { + if let pg_query::NodeRef::SelectStmt(_) = node { + let deparsed = node.deparse_raw().unwrap(); + assert_eq!(deparsed, query); + return; + } + } + panic!("SelectStmt node not found"); +} + +/// Test that deparse_raw method produces same result as deparse method +#[test] +fn it_deparse_raw_matches_deparse() { + let queries = vec![ + "SELECT 1", + "SELECT * FROM users", + "INSERT INTO t (a) VALUES (1)", + "UPDATE t SET a = 1 WHERE b = 2", + "DELETE FROM t WHERE id = 1", + "SELECT a, b FROM t1 JOIN t2 ON t1.id = t2.id WHERE t1.x > 5 ORDER BY a", + ]; + + for query in queries { + let result = pg_query::parse(query).unwrap(); + let deparse_result = result.deparse().unwrap(); + let deparse_raw_result = result.deparse_raw().unwrap(); + assert_eq!(deparse_result, deparse_raw_result); + } +} diff --git a/tests/support.rs b/tests/support.rs index 0f03869..89ef48f 100644 --- a/tests/support.rs +++ b/tests/support.rs @@ -38,6 +38,40 @@ pub fn assert_vec_matches(a: &Vec, b: &Vec) { assert!(matching == a.len() && matching == b.len()) } +/// Verifies that parse and parse_raw produce identical protobuf results +pub fn assert_parse_raw_equals_parse(query: &str) { + let parse_result = pg_query::parse(query).expect("parse failed"); + let parse_raw_result = pg_query::parse_raw(query).expect("parse_raw failed"); + assert!(parse_result.protobuf == parse_raw_result.protobuf, "parse and parse_raw produced different protobufs for query: {query}"); +} + +/// Verifies that deparse_raw produces valid SQL that can be reparsed. +/// We compare fingerprints rather than full protobuf equality because: +/// 1. Location fields will differ (character offsets change with reformatting) +/// 2. Fingerprints capture the semantic content of the query +pub fn assert_deparse_raw_roundtrip(query: &str) { + let parse_result = pg_query::parse(query).expect("parse failed"); + let deparsed = pg_query::deparse_raw(&parse_result.protobuf).expect("deparse_raw failed"); + let reparsed = pg_query::parse(&deparsed).expect(&format!("reparsing deparsed SQL failed: {}", deparsed)); + + // Compare fingerprints for semantic equality + let original_fp = pg_query::fingerprint(query).expect("fingerprint failed").hex; + let reparsed_fp = pg_query::fingerprint(&deparsed).expect("reparsed fingerprint failed").hex; + assert!( + original_fp == reparsed_fp, + "deparse_raw roundtrip produced different fingerprint for query: {query}\ndeparsed as: {deparsed}\noriginal fp: {}\nreparsed fp: {}", + original_fp, + reparsed_fp + ); + + // Also verify statement types match + std::assert_eq!( + parse_result.statement_types(), + reparsed.statement_types(), + "deparse_raw roundtrip produced different statement types for query: {query}\ndeparsed as: {deparsed}" + ); +} + macro_rules! cast { ($target: expr, $pat: path) => {{ if let $pat(a) = $target { From 26a61b3252c90f59804e874baa38a320600d6e9d Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 31 Dec 2025 19:05:21 -0800 Subject: [PATCH 14/17] finish up --- .cargo/config.toml | 3 + build.rs | 172 +++++ src/raw_deparse.rs | 1628 ++++++++++++++++++++++++++++++++++++++++++-- src/raw_parse.rs | 1385 +++++++++++++++++++++++++++++++++++++ 4 files changed, 3149 insertions(+), 39 deletions(-) create mode 100644 .cargo/config.toml diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..80c8bb9 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,3 @@ +[env] +# Increase default stack size for tests due to deep recursion in parse tree conversion +RUST_MIN_STACK = "16777216" diff --git a/build.rs b/build.rs index 0d7c895..083898c 100644 --- a/build.rs +++ b/build.rs @@ -183,6 +183,8 @@ fn main() -> Result<(), Box> { .allowlist_type("AlterOwnerStmt") .allowlist_type("AlterSeqStmt") .allowlist_type("CreateEnumStmt") + .allowlist_type("AlterEnumStmt") + .allowlist_type("CreateRangeStmt") .allowlist_type("DoStmt") .allowlist_type("RenameStmt") .allowlist_type("NotifyStmt") @@ -209,6 +211,158 @@ fn main() -> Result<(), Box> { .allowlist_type("Float") .allowlist_type("Boolean") .allowlist_type("BitString") + // Additional statement types + .allowlist_type("DeclareCursorStmt") + .allowlist_type("DefineStmt") + .allowlist_type("CommentStmt") + .allowlist_type("SecLabelStmt") + .allowlist_type("CreateStatsStmt") + .allowlist_type("AlterStatsStmt") + .allowlist_type("StatsElem") + .allowlist_type("CreateRoleStmt") + .allowlist_type("AlterRoleStmt") + .allowlist_type("AlterRoleSetStmt") + .allowlist_type("DropRoleStmt") + .allowlist_type("CreatePolicyStmt") + .allowlist_type("AlterPolicyStmt") + .allowlist_type("CreateEventTrigStmt") + .allowlist_type("AlterEventTrigStmt") + .allowlist_type("CreatePLangStmt") + .allowlist_type("CreateAmStmt") + .allowlist_type("CreateOpClassStmt") + .allowlist_type("CreateOpClassItem") + .allowlist_type("CreateOpFamilyStmt") + .allowlist_type("AlterOpFamilyStmt") + .allowlist_type("CreateFdwStmt") + .allowlist_type("AlterFdwStmt") + .allowlist_type("CreateForeignServerStmt") + .allowlist_type("AlterForeignServerStmt") + .allowlist_type("CreateForeignTableStmt") + .allowlist_type("CreateUserMappingStmt") + .allowlist_type("AlterUserMappingStmt") + .allowlist_type("DropUserMappingStmt") + .allowlist_type("ImportForeignSchemaStmt") + .allowlist_type("CreateTableSpaceStmt") + .allowlist_type("DropTableSpaceStmt") + .allowlist_type("AlterTableSpaceOptionsStmt") + .allowlist_type("AlterTableMoveAllStmt") + .allowlist_type("AlterExtensionStmt") + .allowlist_type("AlterExtensionContentsStmt") + .allowlist_type("AlterDomainStmt") + .allowlist_type("AlterFunctionStmt") + .allowlist_type("AlterOperatorStmt") + .allowlist_type("AlterTypeStmt") + .allowlist_type("AlterObjectSchemaStmt") + .allowlist_type("AlterObjectDependsStmt") + .allowlist_type("AlterCollationStmt") + .allowlist_type("AlterDefaultPrivilegesStmt") + .allowlist_type("CreateCastStmt") + .allowlist_type("CreateTransformStmt") + .allowlist_type("CreateConversionStmt") + .allowlist_type("AlterTSDictionaryStmt") + .allowlist_type("AlterTSConfigurationStmt") + .allowlist_type("CreatedbStmt") + .allowlist_type("DropdbStmt") + .allowlist_type("AlterDatabaseStmt") + .allowlist_type("AlterDatabaseSetStmt") + .allowlist_type("AlterDatabaseRefreshCollStmt") + .allowlist_type("AlterSystemStmt") + .allowlist_type("ClusterStmt") + .allowlist_type("ReindexStmt") + .allowlist_type("ConstraintsSetStmt") + .allowlist_type("LoadStmt") + .allowlist_type("DropOwnedStmt") + .allowlist_type("ReassignOwnedStmt") + .allowlist_type("DropSubscriptionStmt") + // Table-related nodes + .allowlist_type("TableFunc") + .allowlist_type("TableLikeClause") + .allowlist_type("RangeTableFunc") + .allowlist_type("RangeTableFuncCol") + .allowlist_type("RangeTableSample") + .allowlist_type("PartitionCmd") + .allowlist_type("SinglePartitionSpec") + // Expression nodes + .allowlist_type("Aggref") + .allowlist_type("Var") + .allowlist_type("Param") + .allowlist_type("WindowFunc") + .allowlist_type("GroupingFunc") + .allowlist_type("FuncExpr") + .allowlist_type("NamedArgExpr") + .allowlist_type("OpExpr") + .allowlist_type("DistinctExpr") + .allowlist_type("NullIfExpr") + .allowlist_type("ScalarArrayOpExpr") + .allowlist_type("FieldSelect") + .allowlist_type("FieldStore") + .allowlist_type("RelabelType") + .allowlist_type("CoerceViaIO") + .allowlist_type("ArrayCoerceExpr") + .allowlist_type("ConvertRowtypeExpr") + .allowlist_type("CollateExpr") + .allowlist_type("CaseTestExpr") + .allowlist_type("ArrayExpr") + .allowlist_type("RowCompareExpr") + .allowlist_type("CoerceToDomainValue") + .allowlist_type("CurrentOfExpr") + .allowlist_type("NextValueExpr") + .allowlist_type("InferenceElem") + .allowlist_type("SubscriptingRef") + .allowlist_type("SQLValueFunction") + .allowlist_type("XmlExpr") + .allowlist_type("XmlSerialize") + // Query/Plan nodes + .allowlist_type("SubPlan") + .allowlist_type("AlternativeSubPlan") + .allowlist_type("TargetEntry") + .allowlist_type("RangeTblRef") + .allowlist_type("FromExpr") + .allowlist_type("OnConflictExpr") + .allowlist_type("Query") + .allowlist_type("SetOperationStmt") + .allowlist_type("ReturnStmt") + .allowlist_type("PLAssignStmt") + .allowlist_type("WindowClause") + .allowlist_type("RowMarkClause") + .allowlist_type("WithCheckOption") + .allowlist_type("RangeTblEntry") + .allowlist_type("RangeTblFunction") + .allowlist_type("TableSampleClause") + // JSON nodes + .allowlist_type("JsonFormat") + .allowlist_type("JsonReturning") + .allowlist_type("JsonValueExpr") + .allowlist_type("JsonConstructorExpr") + .allowlist_type("JsonIsPredicate") + .allowlist_type("JsonBehavior") + .allowlist_type("JsonExpr") + .allowlist_type("JsonTablePath") + .allowlist_type("JsonTablePathScan") + .allowlist_type("JsonTableSiblingJoin") + .allowlist_type("JsonOutput") + .allowlist_type("JsonArgument") + .allowlist_type("JsonFuncExpr") + .allowlist_type("JsonTablePathSpec") + .allowlist_type("JsonTable") + .allowlist_type("JsonTableColumn") + .allowlist_type("JsonKeyValue") + .allowlist_type("JsonParseExpr") + .allowlist_type("JsonScalarExpr") + .allowlist_type("JsonSerializeExpr") + .allowlist_type("JsonObjectConstructor") + .allowlist_type("JsonArrayConstructor") + .allowlist_type("JsonArrayQueryConstructor") + .allowlist_type("JsonAggConstructor") + .allowlist_type("JsonObjectAgg") + .allowlist_type("JsonArrayAgg") + // Other nodes + .allowlist_type("TriggerTransition") + .allowlist_type("InlineCodeBlock") + .allowlist_type("CallContext") + .allowlist_type("ReplicaIdentityStmt") + .allowlist_type("WindowFuncRunCondition") + .allowlist_type("MergeSupportFunc") // Allowlist enums .allowlist_type("SetOperation") .allowlist_type("LimitOption") @@ -245,6 +399,24 @@ fn main() -> Result<(), Box> { .allowlist_type("OverridingKind") .allowlist_type("PartitionStrategy") .allowlist_type("PartitionRangeDatumKind") + .allowlist_type("ReindexObjectType") + .allowlist_type("AlterSubscriptionType") + .allowlist_type("AlterPublicationAction") + .allowlist_type("ImportForeignSchemaType") + .allowlist_type("RoleStmtType") + .allowlist_type("RowCompareType") + .allowlist_type("XmlExprOp") + .allowlist_type("XmlOptionType") + .allowlist_type("JsonFormatType") + .allowlist_type("JsonConstructorType") + .allowlist_type("JsonValueType") + .allowlist_type("JsonTableColumnType") + .allowlist_type("JsonQuotes") + .allowlist_type("JsonExprOp") + .allowlist_type("JsonEncoding") + .allowlist_type("JsonWrapper") + .allowlist_type("SQLValueFunctionOp") + .allowlist_type("TableLikeOption") // Allowlist raw parse functions .allowlist_function("pg_query_parse_raw") .allowlist_function("pg_query_parse_raw_opts") diff --git a/src/raw_deparse.rs b/src/raw_deparse.rs index b9182a6..ef76d08 100644 --- a/src/raw_deparse.rs +++ b/src/raw_deparse.rs @@ -209,11 +209,212 @@ fn write_node_inner(node: &protobuf::node::Node) -> *mut bindings_raw::Node { protobuf::node::Node::CreateTableAsStmt(ctas) => write_create_table_as_stmt(ctas) as *mut bindings_raw::Node, protobuf::node::Node::RefreshMatViewStmt(rmvs) => write_refresh_mat_view_stmt(rmvs) as *mut bindings_raw::Node, protobuf::node::Node::VacuumRelation(vr) => write_vacuum_relation(vr) as *mut bindings_raw::Node, - // TODO: Add remaining node types as needed - _ => { - // For unimplemented nodes, return null and let the deparser handle it + // Simple statement nodes + protobuf::node::Node::ListenStmt(ls) => write_listen_stmt(ls) as *mut bindings_raw::Node, + protobuf::node::Node::UnlistenStmt(us) => write_unlisten_stmt(us) as *mut bindings_raw::Node, + protobuf::node::Node::NotifyStmt(ns) => write_notify_stmt(ns) as *mut bindings_raw::Node, + protobuf::node::Node::DiscardStmt(ds) => write_discard_stmt(ds) as *mut bindings_raw::Node, + // Type definition nodes + protobuf::node::Node::CompositeTypeStmt(cts) => write_composite_type_stmt(cts) as *mut bindings_raw::Node, + protobuf::node::Node::CreateEnumStmt(ces) => write_create_enum_stmt(ces) as *mut bindings_raw::Node, + protobuf::node::Node::CreateRangeStmt(crs) => write_create_range_stmt(crs) as *mut bindings_raw::Node, + protobuf::node::Node::AlterEnumStmt(aes) => write_alter_enum_stmt(aes) as *mut bindings_raw::Node, + protobuf::node::Node::CreateDomainStmt(cds) => write_create_domain_stmt(cds) as *mut bindings_raw::Node, + // Extension nodes + protobuf::node::Node::CreateExtensionStmt(ces) => write_create_extension_stmt(ces) as *mut bindings_raw::Node, + // Publication/Subscription nodes + protobuf::node::Node::CreatePublicationStmt(cps) => write_create_publication_stmt(cps) as *mut bindings_raw::Node, + protobuf::node::Node::AlterPublicationStmt(aps) => write_alter_publication_stmt(aps) as *mut bindings_raw::Node, + protobuf::node::Node::CreateSubscriptionStmt(css) => write_create_subscription_stmt(css) as *mut bindings_raw::Node, + protobuf::node::Node::AlterSubscriptionStmt(ass) => write_alter_subscription_stmt(ass) as *mut bindings_raw::Node, + // Expression nodes + protobuf::node::Node::CoerceToDomain(ctd) => write_coerce_to_domain(ctd) as *mut bindings_raw::Node, + // Sequence nodes + protobuf::node::Node::CreateSeqStmt(css) => write_create_seq_stmt(css) as *mut bindings_raw::Node, + protobuf::node::Node::AlterSeqStmt(ass) => write_alter_seq_stmt(ass) as *mut bindings_raw::Node, + // Cursor nodes + protobuf::node::Node::ClosePortalStmt(cps) => write_close_portal_stmt(cps) as *mut bindings_raw::Node, + protobuf::node::Node::FetchStmt(fs) => write_fetch_stmt(fs) as *mut bindings_raw::Node, + protobuf::node::Node::DeclareCursorStmt(dcs) => write_declare_cursor_stmt(dcs) as *mut bindings_raw::Node, + // Additional DDL statements + protobuf::node::Node::DefineStmt(ds) => write_define_stmt(ds) as *mut bindings_raw::Node, + protobuf::node::Node::CommentStmt(cs) => write_comment_stmt(cs) as *mut bindings_raw::Node, + protobuf::node::Node::SecLabelStmt(sls) => write_sec_label_stmt(sls) as *mut bindings_raw::Node, + protobuf::node::Node::CreateRoleStmt(crs) => write_create_role_stmt(crs) as *mut bindings_raw::Node, + protobuf::node::Node::AlterRoleStmt(ars) => write_alter_role_stmt(ars) as *mut bindings_raw::Node, + protobuf::node::Node::AlterRoleSetStmt(arss) => write_alter_role_set_stmt(arss) as *mut bindings_raw::Node, + protobuf::node::Node::DropRoleStmt(drs) => write_drop_role_stmt(drs) as *mut bindings_raw::Node, + protobuf::node::Node::CreatePolicyStmt(cps) => write_create_policy_stmt(cps) as *mut bindings_raw::Node, + protobuf::node::Node::AlterPolicyStmt(aps) => write_alter_policy_stmt(aps) as *mut bindings_raw::Node, + protobuf::node::Node::CreateEventTrigStmt(cets) => write_create_event_trig_stmt(cets) as *mut bindings_raw::Node, + protobuf::node::Node::AlterEventTrigStmt(aets) => write_alter_event_trig_stmt(aets) as *mut bindings_raw::Node, + protobuf::node::Node::CreatePlangStmt(cpls) => write_create_plang_stmt(cpls) as *mut bindings_raw::Node, + protobuf::node::Node::CreateAmStmt(cas) => write_create_am_stmt(cas) as *mut bindings_raw::Node, + protobuf::node::Node::CreateOpClassStmt(cocs) => write_create_op_class_stmt(cocs) as *mut bindings_raw::Node, + protobuf::node::Node::CreateOpClassItem(coci) => write_create_op_class_item(coci) as *mut bindings_raw::Node, + protobuf::node::Node::CreateOpFamilyStmt(cofs) => write_create_op_family_stmt(cofs) as *mut bindings_raw::Node, + protobuf::node::Node::AlterOpFamilyStmt(aofs) => write_alter_op_family_stmt(aofs) as *mut bindings_raw::Node, + protobuf::node::Node::CreateFdwStmt(cfds) => write_create_fdw_stmt(cfds) as *mut bindings_raw::Node, + protobuf::node::Node::AlterFdwStmt(afds) => write_alter_fdw_stmt(afds) as *mut bindings_raw::Node, + protobuf::node::Node::CreateForeignServerStmt(cfss) => write_create_foreign_server_stmt(cfss) as *mut bindings_raw::Node, + protobuf::node::Node::AlterForeignServerStmt(afss) => write_alter_foreign_server_stmt(afss) as *mut bindings_raw::Node, + protobuf::node::Node::CreateForeignTableStmt(cfts) => write_create_foreign_table_stmt(cfts) as *mut bindings_raw::Node, + protobuf::node::Node::CreateUserMappingStmt(cums) => write_create_user_mapping_stmt(cums) as *mut bindings_raw::Node, + protobuf::node::Node::AlterUserMappingStmt(aums) => write_alter_user_mapping_stmt(aums) as *mut bindings_raw::Node, + protobuf::node::Node::DropUserMappingStmt(dums) => write_drop_user_mapping_stmt(dums) as *mut bindings_raw::Node, + protobuf::node::Node::ImportForeignSchemaStmt(ifss) => write_import_foreign_schema_stmt(ifss) as *mut bindings_raw::Node, + protobuf::node::Node::CreateTableSpaceStmt(ctss) => write_create_table_space_stmt(ctss) as *mut bindings_raw::Node, + protobuf::node::Node::DropTableSpaceStmt(dtss) => write_drop_table_space_stmt(dtss) as *mut bindings_raw::Node, + protobuf::node::Node::AlterTableSpaceOptionsStmt(atsos) => write_alter_table_space_options_stmt(atsos) as *mut bindings_raw::Node, + protobuf::node::Node::AlterTableMoveAllStmt(atmas) => write_alter_table_move_all_stmt(atmas) as *mut bindings_raw::Node, + protobuf::node::Node::AlterExtensionStmt(aes) => write_alter_extension_stmt(aes) as *mut bindings_raw::Node, + protobuf::node::Node::AlterExtensionContentsStmt(aecs) => write_alter_extension_contents_stmt(aecs) as *mut bindings_raw::Node, + protobuf::node::Node::AlterDomainStmt(ads) => write_alter_domain_stmt(ads) as *mut bindings_raw::Node, + protobuf::node::Node::AlterFunctionStmt(afs) => write_alter_function_stmt(afs) as *mut bindings_raw::Node, + protobuf::node::Node::AlterOperatorStmt(aos) => write_alter_operator_stmt(aos) as *mut bindings_raw::Node, + protobuf::node::Node::AlterTypeStmt(ats) => write_alter_type_stmt(ats) as *mut bindings_raw::Node, + protobuf::node::Node::AlterOwnerStmt(aos) => write_alter_owner_stmt(aos) as *mut bindings_raw::Node, + protobuf::node::Node::AlterObjectSchemaStmt(aoss) => write_alter_object_schema_stmt(aoss) as *mut bindings_raw::Node, + protobuf::node::Node::AlterObjectDependsStmt(aods) => write_alter_object_depends_stmt(aods) as *mut bindings_raw::Node, + protobuf::node::Node::AlterCollationStmt(acs) => write_alter_collation_stmt(acs) as *mut bindings_raw::Node, + protobuf::node::Node::AlterDefaultPrivilegesStmt(adps) => write_alter_default_privileges_stmt(adps) as *mut bindings_raw::Node, + protobuf::node::Node::CreateCastStmt(ccs) => write_create_cast_stmt(ccs) as *mut bindings_raw::Node, + protobuf::node::Node::CreateTransformStmt(cts) => write_create_transform_stmt(cts) as *mut bindings_raw::Node, + protobuf::node::Node::CreateConversionStmt(ccs) => write_create_conversion_stmt(ccs) as *mut bindings_raw::Node, + protobuf::node::Node::AlterTsdictionaryStmt(atds) => write_alter_ts_dictionary_stmt(atds) as *mut bindings_raw::Node, + protobuf::node::Node::AlterTsconfigurationStmt(atcs) => write_alter_ts_configuration_stmt(atcs) as *mut bindings_raw::Node, + // Database statements + protobuf::node::Node::CreatedbStmt(cds) => write_createdb_stmt(cds) as *mut bindings_raw::Node, + protobuf::node::Node::DropdbStmt(dds) => write_dropdb_stmt(dds) as *mut bindings_raw::Node, + protobuf::node::Node::AlterDatabaseStmt(ads) => write_alter_database_stmt(ads) as *mut bindings_raw::Node, + protobuf::node::Node::AlterDatabaseSetStmt(adss) => write_alter_database_set_stmt(adss) as *mut bindings_raw::Node, + protobuf::node::Node::AlterDatabaseRefreshCollStmt(adrcs) => write_alter_database_refresh_coll_stmt(adrcs) as *mut bindings_raw::Node, + protobuf::node::Node::AlterSystemStmt(ass) => write_alter_system_stmt(ass) as *mut bindings_raw::Node, + protobuf::node::Node::ClusterStmt(cs) => write_cluster_stmt(cs) as *mut bindings_raw::Node, + protobuf::node::Node::ReindexStmt(rs) => write_reindex_stmt(rs) as *mut bindings_raw::Node, + protobuf::node::Node::ConstraintsSetStmt(css) => write_constraints_set_stmt(css) as *mut bindings_raw::Node, + protobuf::node::Node::LoadStmt(ls) => write_load_stmt(ls) as *mut bindings_raw::Node, + protobuf::node::Node::DropOwnedStmt(dos) => write_drop_owned_stmt(dos) as *mut bindings_raw::Node, + protobuf::node::Node::ReassignOwnedStmt(ros) => write_reassign_owned_stmt(ros) as *mut bindings_raw::Node, + protobuf::node::Node::DropSubscriptionStmt(dss) => write_drop_subscription_stmt(dss) as *mut bindings_raw::Node, + // Table-related nodes + protobuf::node::Node::TableFunc(tf) => write_table_func(tf) as *mut bindings_raw::Node, + protobuf::node::Node::IntoClause(ic) => write_into_clause(ic) as *mut bindings_raw::Node, + protobuf::node::Node::TableLikeClause(tlc) => write_table_like_clause(tlc) as *mut bindings_raw::Node, + protobuf::node::Node::RangeTableFunc(rtf) => write_range_table_func(rtf) as *mut bindings_raw::Node, + protobuf::node::Node::RangeTableFuncCol(rtfc) => write_range_table_func_col(rtfc) as *mut bindings_raw::Node, + protobuf::node::Node::RangeTableSample(rts) => write_range_table_sample(rts) as *mut bindings_raw::Node, + protobuf::node::Node::PartitionSpec(ps) => write_partition_spec(ps) as *mut bindings_raw::Node, + protobuf::node::Node::PartitionBoundSpec(pbs) => write_partition_bound_spec(pbs) as *mut bindings_raw::Node, + protobuf::node::Node::PartitionRangeDatum(prd) => write_partition_range_datum(prd) as *mut bindings_raw::Node, + protobuf::node::Node::PartitionElem(pe) => write_partition_elem(pe) as *mut bindings_raw::Node, + protobuf::node::Node::PartitionCmd(pc) => write_partition_cmd(pc) as *mut bindings_raw::Node, + protobuf::node::Node::SinglePartitionSpec(sps) => write_single_partition_spec(sps) as *mut bindings_raw::Node, + protobuf::node::Node::InferClause(ic) => write_infer_clause(ic) as *mut bindings_raw::Node, + protobuf::node::Node::OnConflictClause(occ) => write_on_conflict_clause(occ) as *mut bindings_raw::Node, + protobuf::node::Node::MultiAssignRef(mar) => write_multi_assign_ref(mar) as *mut bindings_raw::Node, + protobuf::node::Node::TriggerTransition(tt) => write_trigger_transition(tt) as *mut bindings_raw::Node, + // CTE-related nodes + protobuf::node::Node::CtesearchClause(csc) => write_cte_search_clause(csc) as *mut bindings_raw::Node, + protobuf::node::Node::CtecycleClause(ccc) => write_cte_cycle_clause(ccc) as *mut bindings_raw::Node, + // Statistics nodes + protobuf::node::Node::CreateStatsStmt(css) => write_create_stats_stmt(css) as *mut bindings_raw::Node, + protobuf::node::Node::AlterStatsStmt(ass) => write_alter_stats_stmt(ass) as *mut bindings_raw::Node, + protobuf::node::Node::StatsElem(se) => write_stats_elem(se) as *mut bindings_raw::Node, + // Publication nodes + protobuf::node::Node::PublicationObjSpec(pos) => write_publication_obj_spec(pos) as *mut bindings_raw::Node, + protobuf::node::Node::PublicationTable(pt) => write_publication_table(pt) as *mut bindings_raw::Node, + // Expression nodes (internal/executor - return null as they shouldn't appear in raw parse trees) + protobuf::node::Node::Var(_) + | protobuf::node::Node::Aggref(_) + | protobuf::node::Node::WindowFunc(_) + | protobuf::node::Node::WindowFuncRunCondition(_) + | protobuf::node::Node::MergeSupportFunc(_) + | protobuf::node::Node::SubscriptingRef(_) + | protobuf::node::Node::FuncExpr(_) + | protobuf::node::Node::OpExpr(_) + | protobuf::node::Node::DistinctExpr(_) + | protobuf::node::Node::NullIfExpr(_) + | protobuf::node::Node::ScalarArrayOpExpr(_) + | protobuf::node::Node::FieldSelect(_) + | protobuf::node::Node::FieldStore(_) + | protobuf::node::Node::RelabelType(_) + | protobuf::node::Node::CoerceViaIo(_) + | protobuf::node::Node::ArrayCoerceExpr(_) + | protobuf::node::Node::ConvertRowtypeExpr(_) + | protobuf::node::Node::CollateExpr(_) + | protobuf::node::Node::CaseTestExpr(_) + | protobuf::node::Node::ArrayExpr(_) + | protobuf::node::Node::RowCompareExpr(_) + | protobuf::node::Node::CoerceToDomainValue(_) + | protobuf::node::Node::CurrentOfExpr(_) + | protobuf::node::Node::NextValueExpr(_) + | protobuf::node::Node::InferenceElem(_) + | protobuf::node::Node::SubPlan(_) + | protobuf::node::Node::AlternativeSubPlan(_) + | protobuf::node::Node::TargetEntry(_) + | protobuf::node::Node::RangeTblRef(_) + | protobuf::node::Node::FromExpr(_) + | protobuf::node::Node::OnConflictExpr(_) + | protobuf::node::Node::Query(_) + | protobuf::node::Node::MergeAction(_) + | protobuf::node::Node::SortGroupClause(_) + | protobuf::node::Node::WindowClause(_) + | protobuf::node::Node::RowMarkClause(_) + | protobuf::node::Node::WithCheckOption(_) + | protobuf::node::Node::RangeTblEntry(_) + | protobuf::node::Node::RangeTblFunction(_) + | protobuf::node::Node::TableSampleClause(_) + | protobuf::node::Node::RtepermissionInfo(_) + | protobuf::node::Node::GroupingFunc(_) + | protobuf::node::Node::Param(_) + | protobuf::node::Node::IntList(_) + | protobuf::node::Node::OidList(_) + | protobuf::node::Node::RawStmt(_) + | protobuf::node::Node::SetOperationStmt(_) + | protobuf::node::Node::ReturnStmt(_) + | protobuf::node::Node::PlassignStmt(_) + | protobuf::node::Node::ReplicaIdentityStmt(_) + | protobuf::node::Node::CallContext(_) + | protobuf::node::Node::InlineCodeBlock(_) => { + // These are internal/executor nodes that shouldn't appear in raw parse trees, + // or are handled specially elsewhere std::ptr::null_mut() } + // SQL Value function + protobuf::node::Node::SqlvalueFunction(svf) => write_sql_value_function(svf) as *mut bindings_raw::Node, + // XML nodes + protobuf::node::Node::XmlExpr(xe) => write_xml_expr(xe) as *mut bindings_raw::Node, + protobuf::node::Node::XmlSerialize(xs) => write_xml_serialize(xs) as *mut bindings_raw::Node, + // Named argument + protobuf::node::Node::NamedArgExpr(nae) => write_named_arg_expr(nae) as *mut bindings_raw::Node, + // JSON nodes + protobuf::node::Node::JsonFormat(jf) => write_json_format(jf) as *mut bindings_raw::Node, + protobuf::node::Node::JsonReturning(jr) => write_json_returning(jr) as *mut bindings_raw::Node, + protobuf::node::Node::JsonValueExpr(jve) => write_json_value_expr(jve) as *mut bindings_raw::Node, + protobuf::node::Node::JsonConstructorExpr(jce) => write_json_constructor_expr(jce) as *mut bindings_raw::Node, + protobuf::node::Node::JsonIsPredicate(jip) => write_json_is_predicate(jip) as *mut bindings_raw::Node, + protobuf::node::Node::JsonBehavior(jb) => write_json_behavior(jb) as *mut bindings_raw::Node, + protobuf::node::Node::JsonExpr(je) => write_json_expr(je) as *mut bindings_raw::Node, + protobuf::node::Node::JsonTablePath(jtp) => write_json_table_path(jtp) as *mut bindings_raw::Node, + protobuf::node::Node::JsonTablePathScan(jtps) => write_json_table_path_scan(jtps) as *mut bindings_raw::Node, + protobuf::node::Node::JsonTableSiblingJoin(jtsj) => write_json_table_sibling_join(jtsj) as *mut bindings_raw::Node, + protobuf::node::Node::JsonOutput(jo) => write_json_output(jo) as *mut bindings_raw::Node, + protobuf::node::Node::JsonArgument(ja) => write_json_argument(ja) as *mut bindings_raw::Node, + protobuf::node::Node::JsonFuncExpr(jfe) => write_json_func_expr(jfe) as *mut bindings_raw::Node, + protobuf::node::Node::JsonTablePathSpec(jtps) => write_json_table_path_spec(jtps) as *mut bindings_raw::Node, + protobuf::node::Node::JsonTable(jt) => write_json_table(jt) as *mut bindings_raw::Node, + protobuf::node::Node::JsonTableColumn(jtc) => write_json_table_column(jtc) as *mut bindings_raw::Node, + protobuf::node::Node::JsonKeyValue(jkv) => write_json_key_value(jkv) as *mut bindings_raw::Node, + protobuf::node::Node::JsonParseExpr(jpe) => write_json_parse_expr(jpe) as *mut bindings_raw::Node, + protobuf::node::Node::JsonScalarExpr(jse) => write_json_scalar_expr(jse) as *mut bindings_raw::Node, + protobuf::node::Node::JsonSerializeExpr(jse) => write_json_serialize_expr(jse) as *mut bindings_raw::Node, + protobuf::node::Node::JsonObjectConstructor(joc) => write_json_object_constructor(joc) as *mut bindings_raw::Node, + protobuf::node::Node::JsonArrayConstructor(jac) => write_json_array_constructor(jac) as *mut bindings_raw::Node, + protobuf::node::Node::JsonArrayQueryConstructor(jaqc) => write_json_array_query_constructor(jaqc) as *mut bindings_raw::Node, + protobuf::node::Node::JsonAggConstructor(jac) => write_json_agg_constructor(jac) as *mut bindings_raw::Node, + protobuf::node::Node::JsonObjectAgg(joa) => write_json_object_agg(joa) as *mut bindings_raw::Node, + protobuf::node::Node::JsonArrayAgg(jaa) => write_json_array_agg(jaa) as *mut bindings_raw::Node, } } } @@ -372,13 +573,6 @@ unsafe fn write_range_var(rv: &protobuf::RangeVar) -> *mut bindings_raw::RangeVa node } -unsafe fn write_range_var_opt(rv: &Option>) -> *mut bindings_raw::RangeVar { - match rv { - Some(r) => write_range_var(r), - None => std::ptr::null_mut(), - } -} - unsafe fn write_range_var_ref(rv: &Option) -> *mut bindings_raw::RangeVar { match rv { Some(r) => write_range_var(r), @@ -393,13 +587,6 @@ unsafe fn write_alias(alias: &protobuf::Alias) -> *mut bindings_raw::Alias { node } -unsafe fn write_alias_opt(alias: &Option>) -> *mut bindings_raw::Alias { - match alias { - Some(a) => write_alias(a), - None => std::ptr::null_mut(), - } -} - unsafe fn write_alias_ref(alias: &Option) -> *mut bindings_raw::Alias { match alias { Some(a) => write_alias(a), @@ -531,14 +718,6 @@ unsafe fn write_bit_string(bs: &protobuf::BitString) -> *mut bindings_raw::BitSt node } -unsafe fn write_null() -> *mut bindings_raw::Node { - // A_Const with isnull=true represents NULL - let node = alloc_node::(bindings_raw::NodeTag_T_A_Const); - (*node).isnull = true; - (*node).location = -1; - node as *mut bindings_raw::Node -} - unsafe fn write_list(l: &protobuf::List) -> *mut bindings_raw::List { write_node_list(&l.items) } @@ -579,13 +758,6 @@ unsafe fn write_type_cast(tc: &protobuf::TypeCast) -> *mut bindings_raw::TypeCas node } -unsafe fn write_type_name_opt(tn: &Option>) -> *mut bindings_raw::TypeName { - match tn { - Some(t) => write_type_name(t), - None => std::ptr::null_mut(), - } -} - unsafe fn write_type_name_ref(tn: &Option) -> *mut bindings_raw::TypeName { match tn { Some(t) => write_type_name(t), @@ -703,13 +875,6 @@ unsafe fn write_with_clause(wc: &protobuf::WithClause) -> *mut bindings_raw::Wit node } -unsafe fn write_with_clause_opt(wc: &Option>) -> *mut bindings_raw::WithClause { - match wc { - Some(w) => write_with_clause(w), - None => std::ptr::null_mut(), - } -} - unsafe fn write_with_clause_ref(wc: &Option) -> *mut bindings_raw::WithClause { match wc { Some(w) => write_with_clause(w), @@ -1401,3 +1566,1388 @@ unsafe fn write_collate_clause(cc: &protobuf::CollateClause) -> *mut bindings_ra (*node).location = cc.location; node } + +// ============================================================================= +// Simple statement nodes +// ============================================================================= + +unsafe fn write_listen_stmt(stmt: &protobuf::ListenStmt) -> *mut bindings_raw::ListenStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_ListenStmt); + (*node).conditionname = pstrdup(&stmt.conditionname); + node +} + +unsafe fn write_unlisten_stmt(stmt: &protobuf::UnlistenStmt) -> *mut bindings_raw::UnlistenStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_UnlistenStmt); + (*node).conditionname = pstrdup(&stmt.conditionname); + node +} + +unsafe fn write_notify_stmt(stmt: &protobuf::NotifyStmt) -> *mut bindings_raw::NotifyStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_NotifyStmt); + (*node).conditionname = pstrdup(&stmt.conditionname); + (*node).payload = pstrdup(&stmt.payload); + node +} + +unsafe fn write_discard_stmt(stmt: &protobuf::DiscardStmt) -> *mut bindings_raw::DiscardStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_DiscardStmt); + (*node).target = proto_enum_to_c(stmt.target); + node +} + +// ============================================================================= +// Type definition nodes +// ============================================================================= + +unsafe fn write_composite_type_stmt(stmt: &protobuf::CompositeTypeStmt) -> *mut bindings_raw::CompositeTypeStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CompositeTypeStmt); + (*node).typevar = write_range_var_ref(&stmt.typevar); + (*node).coldeflist = write_node_list(&stmt.coldeflist); + node +} + +unsafe fn write_create_enum_stmt(stmt: &protobuf::CreateEnumStmt) -> *mut bindings_raw::CreateEnumStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateEnumStmt); + (*node).typeName = write_node_list(&stmt.type_name); + (*node).vals = write_node_list(&stmt.vals); + node +} + +unsafe fn write_create_range_stmt(stmt: &protobuf::CreateRangeStmt) -> *mut bindings_raw::CreateRangeStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateRangeStmt); + (*node).typeName = write_node_list(&stmt.type_name); + (*node).params = write_node_list(&stmt.params); + node +} + +unsafe fn write_alter_enum_stmt(stmt: &protobuf::AlterEnumStmt) -> *mut bindings_raw::AlterEnumStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterEnumStmt); + (*node).typeName = write_node_list(&stmt.type_name); + (*node).oldVal = pstrdup(&stmt.old_val); + (*node).newVal = pstrdup(&stmt.new_val); + (*node).newValNeighbor = pstrdup(&stmt.new_val_neighbor); + (*node).newValIsAfter = stmt.new_val_is_after; + (*node).skipIfNewValExists = stmt.skip_if_new_val_exists; + node +} + +unsafe fn write_create_domain_stmt(stmt: &protobuf::CreateDomainStmt) -> *mut bindings_raw::CreateDomainStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateDomainStmt); + (*node).domainname = write_node_list(&stmt.domainname); + (*node).typeName = write_type_name_ref(&stmt.type_name); + (*node).collClause = match &stmt.coll_clause { + Some(cc) => write_collate_clause(cc) as *mut bindings_raw::CollateClause, + None => std::ptr::null_mut(), + }; + (*node).constraints = write_node_list(&stmt.constraints); + node +} + +// ============================================================================= +// Extension nodes +// ============================================================================= + +unsafe fn write_create_extension_stmt(stmt: &protobuf::CreateExtensionStmt) -> *mut bindings_raw::CreateExtensionStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateExtensionStmt); + (*node).extname = pstrdup(&stmt.extname); + (*node).if_not_exists = stmt.if_not_exists; + (*node).options = write_node_list(&stmt.options); + node +} + +// ============================================================================= +// Publication/Subscription nodes +// ============================================================================= + +unsafe fn write_create_publication_stmt(stmt: &protobuf::CreatePublicationStmt) -> *mut bindings_raw::CreatePublicationStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreatePublicationStmt); + (*node).pubname = pstrdup(&stmt.pubname); + (*node).options = write_node_list(&stmt.options); + (*node).pubobjects = write_node_list(&stmt.pubobjects); + (*node).for_all_tables = stmt.for_all_tables; + node +} + +unsafe fn write_alter_publication_stmt(stmt: &protobuf::AlterPublicationStmt) -> *mut bindings_raw::AlterPublicationStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterPublicationStmt); + (*node).pubname = pstrdup(&stmt.pubname); + (*node).options = write_node_list(&stmt.options); + (*node).pubobjects = write_node_list(&stmt.pubobjects); + (*node).for_all_tables = stmt.for_all_tables; + (*node).action = proto_enum_to_c(stmt.action); + node +} + +unsafe fn write_create_subscription_stmt(stmt: &protobuf::CreateSubscriptionStmt) -> *mut bindings_raw::CreateSubscriptionStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateSubscriptionStmt); + (*node).subname = pstrdup(&stmt.subname); + (*node).conninfo = pstrdup(&stmt.conninfo); + (*node).publication = write_node_list(&stmt.publication); + (*node).options = write_node_list(&stmt.options); + node +} + +unsafe fn write_alter_subscription_stmt(stmt: &protobuf::AlterSubscriptionStmt) -> *mut bindings_raw::AlterSubscriptionStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterSubscriptionStmt); + (*node).kind = proto_enum_to_c(stmt.kind); + (*node).subname = pstrdup(&stmt.subname); + (*node).conninfo = pstrdup(&stmt.conninfo); + (*node).publication = write_node_list(&stmt.publication); + (*node).options = write_node_list(&stmt.options); + node +} + +// ============================================================================= +// Expression nodes +// ============================================================================= + +unsafe fn write_coerce_to_domain(ctd: &protobuf::CoerceToDomain) -> *mut bindings_raw::CoerceToDomain { + let node = alloc_node::(bindings_raw::NodeTag_T_CoerceToDomain); + (*node).arg = write_node_boxed(&ctd.arg) as *mut bindings_raw::Expr; + (*node).resulttype = ctd.resulttype; + (*node).resulttypmod = ctd.resulttypmod; + (*node).resultcollid = ctd.resultcollid; + (*node).coercionformat = proto_enum_to_c(ctd.coercionformat); + (*node).location = ctd.location; + node +} + +// ============================================================================= +// Sequence nodes +// ============================================================================= + +unsafe fn write_create_seq_stmt(stmt: &protobuf::CreateSeqStmt) -> *mut bindings_raw::CreateSeqStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateSeqStmt); + (*node).sequence = write_range_var_ref(&stmt.sequence); + (*node).options = write_node_list(&stmt.options); + (*node).ownerId = stmt.owner_id; + (*node).for_identity = stmt.for_identity; + (*node).if_not_exists = stmt.if_not_exists; + node +} + +unsafe fn write_alter_seq_stmt(stmt: &protobuf::AlterSeqStmt) -> *mut bindings_raw::AlterSeqStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterSeqStmt); + (*node).sequence = write_range_var_ref(&stmt.sequence); + (*node).options = write_node_list(&stmt.options); + (*node).for_identity = stmt.for_identity; + (*node).missing_ok = stmt.missing_ok; + node +} + +// ============================================================================= +// Cursor nodes +// ============================================================================= + +unsafe fn write_close_portal_stmt(stmt: &protobuf::ClosePortalStmt) -> *mut bindings_raw::ClosePortalStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_ClosePortalStmt); + (*node).portalname = pstrdup(&stmt.portalname); + node +} + +unsafe fn write_fetch_stmt(stmt: &protobuf::FetchStmt) -> *mut bindings_raw::FetchStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_FetchStmt); + (*node).direction = proto_enum_to_c(stmt.direction); + (*node).howMany = stmt.how_many; + (*node).portalname = pstrdup(&stmt.portalname); + (*node).ismove = stmt.ismove; + node +} + +unsafe fn write_declare_cursor_stmt(stmt: &protobuf::DeclareCursorStmt) -> *mut bindings_raw::DeclareCursorStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_DeclareCursorStmt); + (*node).portalname = pstrdup(&stmt.portalname); + (*node).options = stmt.options; + (*node).query = write_node_boxed(&stmt.query); + node +} + +// ============================================================================= +// Additional DDL statements +// ============================================================================= + +unsafe fn write_define_stmt(stmt: &protobuf::DefineStmt) -> *mut bindings_raw::DefineStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_DefineStmt); + (*node).kind = proto_enum_to_c(stmt.kind); + (*node).oldstyle = stmt.oldstyle; + (*node).defnames = write_node_list(&stmt.defnames); + (*node).args = write_node_list(&stmt.args); + (*node).definition = write_node_list(&stmt.definition); + (*node).if_not_exists = stmt.if_not_exists; + (*node).replace = stmt.replace; + node +} + +unsafe fn write_comment_stmt(stmt: &protobuf::CommentStmt) -> *mut bindings_raw::CommentStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CommentStmt); + (*node).objtype = proto_enum_to_c(stmt.objtype); + (*node).object = write_node_boxed(&stmt.object); + (*node).comment = pstrdup(&stmt.comment); + node +} + +unsafe fn write_sec_label_stmt(stmt: &protobuf::SecLabelStmt) -> *mut bindings_raw::SecLabelStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_SecLabelStmt); + (*node).objtype = proto_enum_to_c(stmt.objtype); + (*node).object = write_node_boxed(&stmt.object); + (*node).provider = pstrdup(&stmt.provider); + (*node).label = pstrdup(&stmt.label); + node +} + +unsafe fn write_create_role_stmt(stmt: &protobuf::CreateRoleStmt) -> *mut bindings_raw::CreateRoleStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateRoleStmt); + (*node).stmt_type = proto_enum_to_c(stmt.stmt_type); + (*node).role = pstrdup(&stmt.role); + (*node).options = write_node_list(&stmt.options); + node +} + +unsafe fn write_alter_role_stmt(stmt: &protobuf::AlterRoleStmt) -> *mut bindings_raw::AlterRoleStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterRoleStmt); + (*node).role = write_role_spec_ref(&stmt.role); + (*node).options = write_node_list(&stmt.options); + (*node).action = stmt.action; + node +} + +unsafe fn write_alter_role_set_stmt(stmt: &protobuf::AlterRoleSetStmt) -> *mut bindings_raw::AlterRoleSetStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterRoleSetStmt); + (*node).role = write_role_spec_ref(&stmt.role); + (*node).database = pstrdup(&stmt.database); + (*node).setstmt = write_variable_set_stmt_ref(&stmt.setstmt); + node +} + +unsafe fn write_drop_role_stmt(stmt: &protobuf::DropRoleStmt) -> *mut bindings_raw::DropRoleStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_DropRoleStmt); + (*node).roles = write_node_list(&stmt.roles); + (*node).missing_ok = stmt.missing_ok; + node +} + +unsafe fn write_create_policy_stmt(stmt: &protobuf::CreatePolicyStmt) -> *mut bindings_raw::CreatePolicyStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreatePolicyStmt); + (*node).policy_name = pstrdup(&stmt.policy_name); + (*node).table = write_range_var_ref(&stmt.table); + (*node).cmd_name = pstrdup(&stmt.cmd_name); + (*node).permissive = stmt.permissive; + (*node).roles = write_node_list(&stmt.roles); + (*node).qual = write_node_boxed(&stmt.qual); + (*node).with_check = write_node_boxed(&stmt.with_check); + node +} + +unsafe fn write_alter_policy_stmt(stmt: &protobuf::AlterPolicyStmt) -> *mut bindings_raw::AlterPolicyStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterPolicyStmt); + (*node).policy_name = pstrdup(&stmt.policy_name); + (*node).table = write_range_var_ref(&stmt.table); + (*node).roles = write_node_list(&stmt.roles); + (*node).qual = write_node_boxed(&stmt.qual); + (*node).with_check = write_node_boxed(&stmt.with_check); + node +} + +unsafe fn write_create_event_trig_stmt(stmt: &protobuf::CreateEventTrigStmt) -> *mut bindings_raw::CreateEventTrigStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateEventTrigStmt); + (*node).trigname = pstrdup(&stmt.trigname); + (*node).eventname = pstrdup(&stmt.eventname); + (*node).whenclause = write_node_list(&stmt.whenclause); + (*node).funcname = write_node_list(&stmt.funcname); + node +} + +unsafe fn write_alter_event_trig_stmt(stmt: &protobuf::AlterEventTrigStmt) -> *mut bindings_raw::AlterEventTrigStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterEventTrigStmt); + (*node).trigname = pstrdup(&stmt.trigname); + (*node).tgenabled = if stmt.tgenabled.is_empty() { 0 } else { stmt.tgenabled.as_bytes()[0] as i8 }; + node +} + +unsafe fn write_create_plang_stmt(stmt: &protobuf::CreatePLangStmt) -> *mut bindings_raw::CreatePLangStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreatePLangStmt); + (*node).replace = stmt.replace; + (*node).plname = pstrdup(&stmt.plname); + (*node).plhandler = write_node_list(&stmt.plhandler); + (*node).plinline = write_node_list(&stmt.plinline); + (*node).plvalidator = write_node_list(&stmt.plvalidator); + (*node).pltrusted = stmt.pltrusted; + node +} + +unsafe fn write_create_am_stmt(stmt: &protobuf::CreateAmStmt) -> *mut bindings_raw::CreateAmStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateAmStmt); + (*node).amname = pstrdup(&stmt.amname); + (*node).handler_name = write_node_list(&stmt.handler_name); + (*node).amtype = if stmt.amtype.is_empty() { 0 } else { stmt.amtype.as_bytes()[0] as i8 }; + node +} + +unsafe fn write_create_op_class_stmt(stmt: &protobuf::CreateOpClassStmt) -> *mut bindings_raw::CreateOpClassStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateOpClassStmt); + (*node).opclassname = write_node_list(&stmt.opclassname); + (*node).opfamilyname = write_node_list(&stmt.opfamilyname); + (*node).amname = pstrdup(&stmt.amname); + (*node).datatype = write_type_name_ref(&stmt.datatype); + (*node).items = write_node_list(&stmt.items); + (*node).isDefault = stmt.is_default; + node +} + +unsafe fn write_create_op_class_item(stmt: &protobuf::CreateOpClassItem) -> *mut bindings_raw::CreateOpClassItem { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateOpClassItem); + (*node).itemtype = stmt.itemtype; + (*node).name = write_object_with_args_ref(&stmt.name); + (*node).number = stmt.number; + (*node).order_family = write_node_list(&stmt.order_family); + (*node).class_args = write_node_list(&stmt.class_args); + (*node).storedtype = write_type_name_ref(&stmt.storedtype); + node +} + +unsafe fn write_create_op_family_stmt(stmt: &protobuf::CreateOpFamilyStmt) -> *mut bindings_raw::CreateOpFamilyStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateOpFamilyStmt); + (*node).opfamilyname = write_node_list(&stmt.opfamilyname); + (*node).amname = pstrdup(&stmt.amname); + node +} + +unsafe fn write_alter_op_family_stmt(stmt: &protobuf::AlterOpFamilyStmt) -> *mut bindings_raw::AlterOpFamilyStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterOpFamilyStmt); + (*node).opfamilyname = write_node_list(&stmt.opfamilyname); + (*node).amname = pstrdup(&stmt.amname); + (*node).isDrop = stmt.is_drop; + (*node).items = write_node_list(&stmt.items); + node +} + +unsafe fn write_create_fdw_stmt(stmt: &protobuf::CreateFdwStmt) -> *mut bindings_raw::CreateFdwStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateFdwStmt); + (*node).fdwname = pstrdup(&stmt.fdwname); + (*node).func_options = write_node_list(&stmt.func_options); + (*node).options = write_node_list(&stmt.options); + node +} + +unsafe fn write_alter_fdw_stmt(stmt: &protobuf::AlterFdwStmt) -> *mut bindings_raw::AlterFdwStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterFdwStmt); + (*node).fdwname = pstrdup(&stmt.fdwname); + (*node).func_options = write_node_list(&stmt.func_options); + (*node).options = write_node_list(&stmt.options); + node +} + +unsafe fn write_create_foreign_server_stmt(stmt: &protobuf::CreateForeignServerStmt) -> *mut bindings_raw::CreateForeignServerStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateForeignServerStmt); + (*node).servername = pstrdup(&stmt.servername); + (*node).servertype = pstrdup(&stmt.servertype); + (*node).version = pstrdup(&stmt.version); + (*node).fdwname = pstrdup(&stmt.fdwname); + (*node).if_not_exists = stmt.if_not_exists; + (*node).options = write_node_list(&stmt.options); + node +} + +unsafe fn write_alter_foreign_server_stmt(stmt: &protobuf::AlterForeignServerStmt) -> *mut bindings_raw::AlterForeignServerStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterForeignServerStmt); + (*node).servername = pstrdup(&stmt.servername); + (*node).version = pstrdup(&stmt.version); + (*node).options = write_node_list(&stmt.options); + (*node).has_version = stmt.has_version; + node +} + +unsafe fn write_create_foreign_table_stmt(stmt: &protobuf::CreateForeignTableStmt) -> *mut bindings_raw::CreateForeignTableStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateForeignTableStmt); + // CreateForeignTableStmt extends CreateStmt + (*node).base.type_ = bindings_raw::NodeTag_T_CreateForeignTableStmt; + if let Some(ref base) = stmt.base_stmt { + (*node).base.relation = write_range_var_ref(&base.relation); + (*node).base.tableElts = write_node_list(&base.table_elts); + (*node).base.inhRelations = write_node_list(&base.inh_relations); + (*node).base.partbound = write_partition_bound_spec_ref(&base.partbound); + (*node).base.partspec = write_partition_spec_ref(&base.partspec); + (*node).base.ofTypename = write_type_name_ref(&base.of_typename); + (*node).base.constraints = write_node_list(&base.constraints); + (*node).base.options = write_node_list(&base.options); + (*node).base.oncommit = proto_enum_to_c(base.oncommit); + (*node).base.tablespacename = pstrdup(&base.tablespacename); + (*node).base.accessMethod = pstrdup(&base.access_method); + (*node).base.if_not_exists = base.if_not_exists; + } + (*node).servername = pstrdup(&stmt.servername); + (*node).options = write_node_list(&stmt.options); + node +} + +unsafe fn write_create_user_mapping_stmt(stmt: &protobuf::CreateUserMappingStmt) -> *mut bindings_raw::CreateUserMappingStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateUserMappingStmt); + (*node).user = write_role_spec_ref(&stmt.user); + (*node).servername = pstrdup(&stmt.servername); + (*node).if_not_exists = stmt.if_not_exists; + (*node).options = write_node_list(&stmt.options); + node +} + +unsafe fn write_alter_user_mapping_stmt(stmt: &protobuf::AlterUserMappingStmt) -> *mut bindings_raw::AlterUserMappingStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterUserMappingStmt); + (*node).user = write_role_spec_ref(&stmt.user); + (*node).servername = pstrdup(&stmt.servername); + (*node).options = write_node_list(&stmt.options); + node +} + +unsafe fn write_drop_user_mapping_stmt(stmt: &protobuf::DropUserMappingStmt) -> *mut bindings_raw::DropUserMappingStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_DropUserMappingStmt); + (*node).user = write_role_spec_ref(&stmt.user); + (*node).servername = pstrdup(&stmt.servername); + (*node).missing_ok = stmt.missing_ok; + node +} + +unsafe fn write_import_foreign_schema_stmt(stmt: &protobuf::ImportForeignSchemaStmt) -> *mut bindings_raw::ImportForeignSchemaStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_ImportForeignSchemaStmt); + (*node).server_name = pstrdup(&stmt.server_name); + (*node).remote_schema = pstrdup(&stmt.remote_schema); + (*node).local_schema = pstrdup(&stmt.local_schema); + (*node).list_type = proto_enum_to_c(stmt.list_type); + (*node).table_list = write_node_list(&stmt.table_list); + (*node).options = write_node_list(&stmt.options); + node +} + +unsafe fn write_create_table_space_stmt(stmt: &protobuf::CreateTableSpaceStmt) -> *mut bindings_raw::CreateTableSpaceStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateTableSpaceStmt); + (*node).tablespacename = pstrdup(&stmt.tablespacename); + (*node).owner = write_role_spec_ref(&stmt.owner); + (*node).location = pstrdup(&stmt.location); + (*node).options = write_node_list(&stmt.options); + node +} + +unsafe fn write_drop_table_space_stmt(stmt: &protobuf::DropTableSpaceStmt) -> *mut bindings_raw::DropTableSpaceStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_DropTableSpaceStmt); + (*node).tablespacename = pstrdup(&stmt.tablespacename); + (*node).missing_ok = stmt.missing_ok; + node +} + +unsafe fn write_alter_table_space_options_stmt(stmt: &protobuf::AlterTableSpaceOptionsStmt) -> *mut bindings_raw::AlterTableSpaceOptionsStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterTableSpaceOptionsStmt); + (*node).tablespacename = pstrdup(&stmt.tablespacename); + (*node).options = write_node_list(&stmt.options); + (*node).isReset = stmt.is_reset; + node +} + +unsafe fn write_alter_table_move_all_stmt(stmt: &protobuf::AlterTableMoveAllStmt) -> *mut bindings_raw::AlterTableMoveAllStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterTableMoveAllStmt); + (*node).orig_tablespacename = pstrdup(&stmt.orig_tablespacename); + (*node).objtype = proto_enum_to_c(stmt.objtype); + (*node).roles = write_node_list(&stmt.roles); + (*node).new_tablespacename = pstrdup(&stmt.new_tablespacename); + (*node).nowait = stmt.nowait; + node +} + +unsafe fn write_alter_extension_stmt(stmt: &protobuf::AlterExtensionStmt) -> *mut bindings_raw::AlterExtensionStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterExtensionStmt); + (*node).extname = pstrdup(&stmt.extname); + (*node).options = write_node_list(&stmt.options); + node +} + +unsafe fn write_alter_extension_contents_stmt(stmt: &protobuf::AlterExtensionContentsStmt) -> *mut bindings_raw::AlterExtensionContentsStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterExtensionContentsStmt); + (*node).extname = pstrdup(&stmt.extname); + (*node).action = stmt.action; + (*node).objtype = proto_enum_to_c(stmt.objtype); + (*node).object = write_node_boxed(&stmt.object); + node +} + +unsafe fn write_alter_domain_stmt(stmt: &protobuf::AlterDomainStmt) -> *mut bindings_raw::AlterDomainStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterDomainStmt); + (*node).subtype = if stmt.subtype.is_empty() { 0 } else { stmt.subtype.as_bytes()[0] as i8 }; + (*node).typeName = write_node_list(&stmt.type_name); + (*node).name = pstrdup(&stmt.name); + (*node).def = write_node_boxed(&stmt.def); + (*node).behavior = proto_enum_to_c(stmt.behavior); + (*node).missing_ok = stmt.missing_ok; + node +} + +unsafe fn write_alter_function_stmt(stmt: &protobuf::AlterFunctionStmt) -> *mut bindings_raw::AlterFunctionStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterFunctionStmt); + (*node).objtype = proto_enum_to_c(stmt.objtype); + (*node).func = write_object_with_args_ref(&stmt.func); + (*node).actions = write_node_list(&stmt.actions); + node +} + +unsafe fn write_alter_operator_stmt(stmt: &protobuf::AlterOperatorStmt) -> *mut bindings_raw::AlterOperatorStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterOperatorStmt); + (*node).opername = write_object_with_args_ref(&stmt.opername); + (*node).options = write_node_list(&stmt.options); + node +} + +unsafe fn write_alter_type_stmt(stmt: &protobuf::AlterTypeStmt) -> *mut bindings_raw::AlterTypeStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterTypeStmt); + (*node).typeName = write_node_list(&stmt.type_name); + (*node).options = write_node_list(&stmt.options); + node +} + +unsafe fn write_alter_owner_stmt(stmt: &protobuf::AlterOwnerStmt) -> *mut bindings_raw::AlterOwnerStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterOwnerStmt); + (*node).objectType = proto_enum_to_c(stmt.object_type); + (*node).relation = write_range_var_ref(&stmt.relation); + (*node).object = write_node_boxed(&stmt.object); + (*node).newowner = write_role_spec_ref(&stmt.newowner); + node +} + +unsafe fn write_alter_object_schema_stmt(stmt: &protobuf::AlterObjectSchemaStmt) -> *mut bindings_raw::AlterObjectSchemaStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterObjectSchemaStmt); + (*node).objectType = proto_enum_to_c(stmt.object_type); + (*node).relation = write_range_var_ref(&stmt.relation); + (*node).object = write_node_boxed(&stmt.object); + (*node).newschema = pstrdup(&stmt.newschema); + (*node).missing_ok = stmt.missing_ok; + node +} + +unsafe fn write_alter_object_depends_stmt(stmt: &protobuf::AlterObjectDependsStmt) -> *mut bindings_raw::AlterObjectDependsStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterObjectDependsStmt); + (*node).objectType = proto_enum_to_c(stmt.object_type); + (*node).relation = write_range_var_ref(&stmt.relation); + (*node).object = write_node_boxed(&stmt.object); + (*node).extname = write_string_ref(&stmt.extname); + (*node).remove = stmt.remove; + node +} + +unsafe fn write_alter_collation_stmt(stmt: &protobuf::AlterCollationStmt) -> *mut bindings_raw::AlterCollationStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterCollationStmt); + (*node).collname = write_node_list(&stmt.collname); + node +} + +unsafe fn write_alter_default_privileges_stmt(stmt: &protobuf::AlterDefaultPrivilegesStmt) -> *mut bindings_raw::AlterDefaultPrivilegesStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterDefaultPrivilegesStmt); + (*node).options = write_node_list(&stmt.options); + (*node).action = write_grant_stmt_ref(&stmt.action); + node +} + +unsafe fn write_create_cast_stmt(stmt: &protobuf::CreateCastStmt) -> *mut bindings_raw::CreateCastStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateCastStmt); + (*node).sourcetype = write_type_name_ref(&stmt.sourcetype); + (*node).targettype = write_type_name_ref(&stmt.targettype); + (*node).func = write_object_with_args_ref(&stmt.func); + (*node).context = proto_enum_to_c(stmt.context); + (*node).inout = stmt.inout; + node +} + +unsafe fn write_create_transform_stmt(stmt: &protobuf::CreateTransformStmt) -> *mut bindings_raw::CreateTransformStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateTransformStmt); + (*node).replace = stmt.replace; + (*node).type_name = write_type_name_ref(&stmt.type_name); + (*node).lang = pstrdup(&stmt.lang); + (*node).fromsql = write_object_with_args_ref(&stmt.fromsql); + (*node).tosql = write_object_with_args_ref(&stmt.tosql); + node +} + +unsafe fn write_create_conversion_stmt(stmt: &protobuf::CreateConversionStmt) -> *mut bindings_raw::CreateConversionStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateConversionStmt); + (*node).conversion_name = write_node_list(&stmt.conversion_name); + (*node).for_encoding_name = pstrdup(&stmt.for_encoding_name); + (*node).to_encoding_name = pstrdup(&stmt.to_encoding_name); + (*node).func_name = write_node_list(&stmt.func_name); + (*node).def = stmt.def; + node +} + +unsafe fn write_alter_ts_dictionary_stmt(stmt: &protobuf::AlterTsDictionaryStmt) -> *mut bindings_raw::AlterTSDictionaryStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterTSDictionaryStmt); + (*node).dictname = write_node_list(&stmt.dictname); + (*node).options = write_node_list(&stmt.options); + node +} + +unsafe fn write_alter_ts_configuration_stmt(stmt: &protobuf::AlterTsConfigurationStmt) -> *mut bindings_raw::AlterTSConfigurationStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterTSConfigurationStmt); + (*node).kind = proto_enum_to_c(stmt.kind); + (*node).cfgname = write_node_list(&stmt.cfgname); + (*node).tokentype = write_node_list(&stmt.tokentype); + (*node).dicts = write_node_list(&stmt.dicts); + (*node).override_ = stmt.r#override; + (*node).replace = stmt.replace; + (*node).missing_ok = stmt.missing_ok; + node +} + +// ============================================================================= +// Database statements +// ============================================================================= + +unsafe fn write_createdb_stmt(stmt: &protobuf::CreatedbStmt) -> *mut bindings_raw::CreatedbStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreatedbStmt); + (*node).dbname = pstrdup(&stmt.dbname); + (*node).options = write_node_list(&stmt.options); + node +} + +unsafe fn write_dropdb_stmt(stmt: &protobuf::DropdbStmt) -> *mut bindings_raw::DropdbStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_DropdbStmt); + (*node).dbname = pstrdup(&stmt.dbname); + (*node).missing_ok = stmt.missing_ok; + (*node).options = write_node_list(&stmt.options); + node +} + +unsafe fn write_alter_database_stmt(stmt: &protobuf::AlterDatabaseStmt) -> *mut bindings_raw::AlterDatabaseStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterDatabaseStmt); + (*node).dbname = pstrdup(&stmt.dbname); + (*node).options = write_node_list(&stmt.options); + node +} + +unsafe fn write_alter_database_set_stmt(stmt: &protobuf::AlterDatabaseSetStmt) -> *mut bindings_raw::AlterDatabaseSetStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterDatabaseSetStmt); + (*node).dbname = pstrdup(&stmt.dbname); + (*node).setstmt = write_variable_set_stmt_ref(&stmt.setstmt); + node +} + +unsafe fn write_alter_database_refresh_coll_stmt(stmt: &protobuf::AlterDatabaseRefreshCollStmt) -> *mut bindings_raw::AlterDatabaseRefreshCollStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterDatabaseRefreshCollStmt); + (*node).dbname = pstrdup(&stmt.dbname); + node +} + +unsafe fn write_alter_system_stmt(stmt: &protobuf::AlterSystemStmt) -> *mut bindings_raw::AlterSystemStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterSystemStmt); + (*node).setstmt = write_variable_set_stmt_ref(&stmt.setstmt); + node +} + +unsafe fn write_cluster_stmt(stmt: &protobuf::ClusterStmt) -> *mut bindings_raw::ClusterStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_ClusterStmt); + (*node).relation = write_range_var_ref(&stmt.relation); + (*node).indexname = pstrdup(&stmt.indexname); + (*node).params = write_node_list(&stmt.params); + node +} + +unsafe fn write_reindex_stmt(stmt: &protobuf::ReindexStmt) -> *mut bindings_raw::ReindexStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_ReindexStmt); + (*node).kind = proto_enum_to_c(stmt.kind); + (*node).relation = write_range_var_ref(&stmt.relation); + (*node).name = pstrdup(&stmt.name); + (*node).params = write_node_list(&stmt.params); + node +} + +unsafe fn write_constraints_set_stmt(stmt: &protobuf::ConstraintsSetStmt) -> *mut bindings_raw::ConstraintsSetStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_ConstraintsSetStmt); + (*node).constraints = write_node_list(&stmt.constraints); + (*node).deferred = stmt.deferred; + node +} + +unsafe fn write_load_stmt(stmt: &protobuf::LoadStmt) -> *mut bindings_raw::LoadStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_LoadStmt); + (*node).filename = pstrdup(&stmt.filename); + node +} + +unsafe fn write_drop_owned_stmt(stmt: &protobuf::DropOwnedStmt) -> *mut bindings_raw::DropOwnedStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_DropOwnedStmt); + (*node).roles = write_node_list(&stmt.roles); + (*node).behavior = proto_enum_to_c(stmt.behavior); + node +} + +unsafe fn write_reassign_owned_stmt(stmt: &protobuf::ReassignOwnedStmt) -> *mut bindings_raw::ReassignOwnedStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_ReassignOwnedStmt); + (*node).roles = write_node_list(&stmt.roles); + (*node).newrole = write_role_spec_ref(&stmt.newrole); + node +} + +unsafe fn write_drop_subscription_stmt(stmt: &protobuf::DropSubscriptionStmt) -> *mut bindings_raw::DropSubscriptionStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_DropSubscriptionStmt); + (*node).subname = pstrdup(&stmt.subname); + (*node).missing_ok = stmt.missing_ok; + (*node).behavior = proto_enum_to_c(stmt.behavior); + node +} + +// ============================================================================= +// Table-related nodes +// ============================================================================= + +unsafe fn write_table_func(stmt: &protobuf::TableFunc) -> *mut bindings_raw::TableFunc { + let node = alloc_node::(bindings_raw::NodeTag_T_TableFunc); + (*node).ns_uris = write_node_list(&stmt.ns_uris); + (*node).ns_names = write_node_list(&stmt.ns_names); + (*node).docexpr = write_node_boxed(&stmt.docexpr); + (*node).rowexpr = write_node_boxed(&stmt.rowexpr); + (*node).colnames = write_node_list(&stmt.colnames); + (*node).coltypes = write_node_list(&stmt.coltypes); + (*node).coltypmods = write_node_list(&stmt.coltypmods); + (*node).colcollations = write_node_list(&stmt.colcollations); + (*node).colexprs = write_node_list(&stmt.colexprs); + (*node).coldefexprs = write_node_list(&stmt.coldefexprs); + (*node).ordinalitycol = stmt.ordinalitycol; + (*node).location = stmt.location; + node +} + +unsafe fn write_table_like_clause(stmt: &protobuf::TableLikeClause) -> *mut bindings_raw::TableLikeClause { + let node = alloc_node::(bindings_raw::NodeTag_T_TableLikeClause); + (*node).relation = write_range_var_ref(&stmt.relation); + (*node).options = stmt.options; + (*node).relationOid = stmt.relation_oid; + node +} + +unsafe fn write_range_table_func(stmt: &protobuf::RangeTableFunc) -> *mut bindings_raw::RangeTableFunc { + let node = alloc_node::(bindings_raw::NodeTag_T_RangeTableFunc); + (*node).lateral = stmt.lateral; + (*node).docexpr = write_node_boxed(&stmt.docexpr); + (*node).rowexpr = write_node_boxed(&stmt.rowexpr); + (*node).namespaces = write_node_list(&stmt.namespaces); + (*node).columns = write_node_list(&stmt.columns); + (*node).alias = write_alias_ref(&stmt.alias); + (*node).location = stmt.location; + node +} + +unsafe fn write_range_table_func_col(stmt: &protobuf::RangeTableFuncCol) -> *mut bindings_raw::RangeTableFuncCol { + let node = alloc_node::(bindings_raw::NodeTag_T_RangeTableFuncCol); + (*node).colname = pstrdup(&stmt.colname); + (*node).typeName = write_type_name_ref(&stmt.type_name); + (*node).for_ordinality = stmt.for_ordinality; + (*node).is_not_null = stmt.is_not_null; + (*node).colexpr = write_node_boxed(&stmt.colexpr); + (*node).coldefexpr = write_node_boxed(&stmt.coldefexpr); + (*node).location = stmt.location; + node +} + +unsafe fn write_range_table_sample(stmt: &protobuf::RangeTableSample) -> *mut bindings_raw::RangeTableSample { + let node = alloc_node::(bindings_raw::NodeTag_T_RangeTableSample); + (*node).relation = write_node_boxed(&stmt.relation); + (*node).method = write_node_list(&stmt.method); + (*node).args = write_node_list(&stmt.args); + (*node).repeatable = write_node_boxed(&stmt.repeatable); + (*node).location = stmt.location; + node +} + +unsafe fn write_partition_spec(stmt: &protobuf::PartitionSpec) -> *mut bindings_raw::PartitionSpec { + let node = alloc_node::(bindings_raw::NodeTag_T_PartitionSpec); + (*node).strategy = proto_enum_to_c(stmt.strategy); + (*node).partParams = write_node_list(&stmt.part_params); + (*node).location = stmt.location; + node +} + +unsafe fn write_partition_bound_spec(stmt: &protobuf::PartitionBoundSpec) -> *mut bindings_raw::PartitionBoundSpec { + let node = alloc_node::(bindings_raw::NodeTag_T_PartitionBoundSpec); + (*node).strategy = if stmt.strategy.is_empty() { 0 } else { stmt.strategy.as_bytes()[0] as i8 }; + (*node).is_default = stmt.is_default; + (*node).modulus = stmt.modulus; + (*node).remainder = stmt.remainder; + (*node).listdatums = write_node_list(&stmt.listdatums); + (*node).lowerdatums = write_node_list(&stmt.lowerdatums); + (*node).upperdatums = write_node_list(&stmt.upperdatums); + (*node).location = stmt.location; + node +} + +unsafe fn write_partition_range_datum(stmt: &protobuf::PartitionRangeDatum) -> *mut bindings_raw::PartitionRangeDatum { + let node = alloc_node::(bindings_raw::NodeTag_T_PartitionRangeDatum); + (*node).kind = proto_enum_to_c(stmt.kind) as i32; + (*node).value = write_node_boxed(&stmt.value); + (*node).location = stmt.location; + node +} + +unsafe fn write_partition_elem(stmt: &protobuf::PartitionElem) -> *mut bindings_raw::PartitionElem { + let node = alloc_node::(bindings_raw::NodeTag_T_PartitionElem); + (*node).name = pstrdup(&stmt.name); + (*node).expr = write_node_boxed(&stmt.expr); + (*node).collation = write_node_list(&stmt.collation); + (*node).opclass = write_node_list(&stmt.opclass); + (*node).location = stmt.location; + node +} + +unsafe fn write_partition_cmd(stmt: &protobuf::PartitionCmd) -> *mut bindings_raw::PartitionCmd { + let node = alloc_node::(bindings_raw::NodeTag_T_PartitionCmd); + (*node).name = write_range_var_ref(&stmt.name); + (*node).bound = write_partition_bound_spec_ref(&stmt.bound); + (*node).concurrent = stmt.concurrent; + node +} + +unsafe fn write_single_partition_spec(_stmt: &protobuf::SinglePartitionSpec) -> *mut bindings_raw::SinglePartitionSpec { + // SinglePartitionSpec is an empty struct in protobuf + let node = alloc_node::(bindings_raw::NodeTag_T_SinglePartitionSpec); + node +} + +unsafe fn write_infer_clause(stmt: &protobuf::InferClause) -> *mut bindings_raw::InferClause { + let node = alloc_node::(bindings_raw::NodeTag_T_InferClause); + (*node).indexElems = write_node_list(&stmt.index_elems); + (*node).whereClause = write_node_boxed(&stmt.where_clause); + (*node).conname = pstrdup(&stmt.conname); + (*node).location = stmt.location; + node +} + +unsafe fn write_multi_assign_ref(stmt: &protobuf::MultiAssignRef) -> *mut bindings_raw::MultiAssignRef { + let node = alloc_node::(bindings_raw::NodeTag_T_MultiAssignRef); + (*node).source = write_node_boxed(&stmt.source); + (*node).colno = stmt.colno; + (*node).ncolumns = stmt.ncolumns; + node +} + +unsafe fn write_trigger_transition(stmt: &protobuf::TriggerTransition) -> *mut bindings_raw::TriggerTransition { + let node = alloc_node::(bindings_raw::NodeTag_T_TriggerTransition); + (*node).name = pstrdup(&stmt.name); + (*node).isNew = stmt.is_new; + (*node).isTable = stmt.is_table; + node +} + +// ============================================================================= +// CTE-related nodes +// ============================================================================= + +unsafe fn write_cte_search_clause(stmt: &protobuf::CteSearchClause) -> *mut bindings_raw::CTESearchClause { + let node = alloc_node::(bindings_raw::NodeTag_T_CTESearchClause); + (*node).search_col_list = write_node_list(&stmt.search_col_list); + (*node).search_breadth_first = stmt.search_breadth_first; + (*node).search_seq_column = pstrdup(&stmt.search_seq_column); + (*node).location = stmt.location; + node +} + +unsafe fn write_cte_cycle_clause(stmt: &protobuf::CteCycleClause) -> *mut bindings_raw::CTECycleClause { + let node = alloc_node::(bindings_raw::NodeTag_T_CTECycleClause); + (*node).cycle_col_list = write_node_list(&stmt.cycle_col_list); + (*node).cycle_mark_column = pstrdup(&stmt.cycle_mark_column); + (*node).cycle_mark_value = write_node_boxed(&stmt.cycle_mark_value); + (*node).cycle_mark_default = write_node_boxed(&stmt.cycle_mark_default); + (*node).cycle_path_column = pstrdup(&stmt.cycle_path_column); + (*node).location = stmt.location; + node +} + +// ============================================================================= +// Statistics nodes +// ============================================================================= + +unsafe fn write_create_stats_stmt(stmt: &protobuf::CreateStatsStmt) -> *mut bindings_raw::CreateStatsStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_CreateStatsStmt); + (*node).defnames = write_node_list(&stmt.defnames); + (*node).stat_types = write_node_list(&stmt.stat_types); + (*node).exprs = write_node_list(&stmt.exprs); + (*node).relations = write_node_list(&stmt.relations); + (*node).stxcomment = pstrdup(&stmt.stxcomment); + (*node).transformed = stmt.transformed; + (*node).if_not_exists = stmt.if_not_exists; + node +} + +unsafe fn write_alter_stats_stmt(stmt: &protobuf::AlterStatsStmt) -> *mut bindings_raw::AlterStatsStmt { + let node = alloc_node::(bindings_raw::NodeTag_T_AlterStatsStmt); + (*node).defnames = write_node_list(&stmt.defnames); + (*node).missing_ok = stmt.missing_ok; + (*node).stxstattarget = write_node_boxed(&stmt.stxstattarget); + node +} + +unsafe fn write_stats_elem(stmt: &protobuf::StatsElem) -> *mut bindings_raw::StatsElem { + let node = alloc_node::(bindings_raw::NodeTag_T_StatsElem); + (*node).name = pstrdup(&stmt.name); + (*node).expr = write_node_boxed(&stmt.expr); + node +} + +// ============================================================================= +// Publication nodes +// ============================================================================= + +unsafe fn write_publication_obj_spec(stmt: &protobuf::PublicationObjSpec) -> *mut bindings_raw::PublicationObjSpec { + let node = alloc_node::(bindings_raw::NodeTag_T_PublicationObjSpec); + (*node).pubobjtype = proto_enum_to_c(stmt.pubobjtype); + (*node).name = pstrdup(&stmt.name); + (*node).pubtable = write_publication_table_ref(&stmt.pubtable); + (*node).location = stmt.location; + node +} + +unsafe fn write_publication_table(stmt: &protobuf::PublicationTable) -> *mut bindings_raw::PublicationTable { + let node = alloc_node::(bindings_raw::NodeTag_T_PublicationTable); + (*node).relation = write_range_var_ref(&stmt.relation); + (*node).whereClause = write_node_boxed(&stmt.where_clause); + (*node).columns = write_node_list(&stmt.columns); + node +} + +// ============================================================================= +// SQL Value function +// ============================================================================= + +unsafe fn write_sql_value_function(stmt: &protobuf::SqlValueFunction) -> *mut bindings_raw::SQLValueFunction { + let node = alloc_node::(bindings_raw::NodeTag_T_SQLValueFunction); + (*node).op = proto_enum_to_c(stmt.op); + (*node).type_ = stmt.r#type; + (*node).typmod = stmt.typmod; + (*node).location = stmt.location; + node +} + +// ============================================================================= +// XML nodes +// ============================================================================= + +unsafe fn write_xml_expr(stmt: &protobuf::XmlExpr) -> *mut bindings_raw::XmlExpr { + let node = alloc_node::(bindings_raw::NodeTag_T_XmlExpr); + (*node).op = proto_enum_to_c(stmt.op); + (*node).name = pstrdup(&stmt.name); + (*node).named_args = write_node_list(&stmt.named_args); + (*node).arg_names = write_node_list(&stmt.arg_names); + (*node).args = write_node_list(&stmt.args); + (*node).xmloption = proto_enum_to_c(stmt.xmloption); + (*node).indent = stmt.indent; + (*node).type_ = stmt.r#type; + (*node).typmod = stmt.typmod; + (*node).location = stmt.location; + node +} + +unsafe fn write_xml_serialize(stmt: &protobuf::XmlSerialize) -> *mut bindings_raw::XmlSerialize { + let node = alloc_node::(bindings_raw::NodeTag_T_XmlSerialize); + (*node).xmloption = proto_enum_to_c(stmt.xmloption); + (*node).expr = write_node_boxed(&stmt.expr); + (*node).typeName = write_type_name_ref(&stmt.type_name); + (*node).indent = stmt.indent; + (*node).location = stmt.location; + node +} + +// ============================================================================= +// Named argument +// ============================================================================= + +unsafe fn write_named_arg_expr(stmt: &protobuf::NamedArgExpr) -> *mut bindings_raw::NamedArgExpr { + let node = alloc_node::(bindings_raw::NodeTag_T_NamedArgExpr); + (*node).arg = write_node_boxed(&stmt.arg) as *mut bindings_raw::Expr; + (*node).name = pstrdup(&stmt.name); + (*node).argnumber = stmt.argnumber; + (*node).location = stmt.location; + node +} + +// ============================================================================= +// JSON nodes +// ============================================================================= + +unsafe fn write_json_format(stmt: &protobuf::JsonFormat) -> *mut bindings_raw::JsonFormat { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonFormat); + (*node).format_type = proto_enum_to_c(stmt.format_type); + (*node).encoding = proto_enum_to_c(stmt.encoding); + (*node).location = stmt.location; + node +} + +unsafe fn write_json_returning(stmt: &protobuf::JsonReturning) -> *mut bindings_raw::JsonReturning { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonReturning); + (*node).format = write_json_format_ref(&stmt.format); + (*node).typid = stmt.typid; + (*node).typmod = stmt.typmod; + node +} + +unsafe fn write_json_value_expr(stmt: &protobuf::JsonValueExpr) -> *mut bindings_raw::JsonValueExpr { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonValueExpr); + (*node).raw_expr = write_node_boxed(&stmt.raw_expr) as *mut bindings_raw::Expr; + (*node).formatted_expr = write_node_boxed(&stmt.formatted_expr) as *mut bindings_raw::Expr; + (*node).format = write_json_format_ref(&stmt.format); + node +} + +unsafe fn write_json_constructor_expr(stmt: &protobuf::JsonConstructorExpr) -> *mut bindings_raw::JsonConstructorExpr { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonConstructorExpr); + (*node).type_ = proto_enum_to_c(stmt.r#type); + (*node).args = write_node_list(&stmt.args); + (*node).func = write_node_boxed(&stmt.func) as *mut bindings_raw::Expr; + (*node).coercion = write_node_boxed(&stmt.coercion) as *mut bindings_raw::Expr; + (*node).returning = write_json_returning_ref(&stmt.returning); + (*node).absent_on_null = stmt.absent_on_null; + (*node).unique = stmt.unique; + (*node).location = stmt.location; + node +} + +unsafe fn write_json_is_predicate(stmt: &protobuf::JsonIsPredicate) -> *mut bindings_raw::JsonIsPredicate { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonIsPredicate); + (*node).expr = write_node_boxed(&stmt.expr); + (*node).format = write_json_format_ref(&stmt.format); + (*node).item_type = proto_enum_to_c(stmt.item_type); + (*node).unique_keys = stmt.unique_keys; + (*node).location = stmt.location; + node +} + +unsafe fn write_json_behavior(stmt: &protobuf::JsonBehavior) -> *mut bindings_raw::JsonBehavior { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonBehavior); + (*node).btype = proto_enum_to_c(stmt.btype); + (*node).expr = write_node_boxed(&stmt.expr); + (*node).coerce = stmt.coerce; + (*node).location = stmt.location; + node +} + +unsafe fn write_json_expr(stmt: &protobuf::JsonExpr) -> *mut bindings_raw::JsonExpr { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonExpr); + (*node).op = proto_enum_to_c(stmt.op); + (*node).column_name = pstrdup(&stmt.column_name); + (*node).formatted_expr = write_node_boxed(&stmt.formatted_expr); + (*node).format = write_json_format_ref(&stmt.format); + (*node).path_spec = write_node_boxed(&stmt.path_spec); + (*node).returning = write_json_returning_ref(&stmt.returning); + (*node).passing_names = write_node_list(&stmt.passing_names); + (*node).passing_values = write_node_list(&stmt.passing_values); + (*node).on_empty = write_json_behavior_ref(&stmt.on_empty); + (*node).on_error = write_json_behavior_ref(&stmt.on_error); + (*node).use_io_coercion = stmt.use_io_coercion; + (*node).use_json_coercion = stmt.use_json_coercion; + (*node).wrapper = proto_enum_to_c(stmt.wrapper); + (*node).omit_quotes = stmt.omit_quotes; + (*node).collation = stmt.collation; + (*node).location = stmt.location; + node +} + +unsafe fn write_json_table_path(stmt: &protobuf::JsonTablePath) -> *mut bindings_raw::JsonTablePath { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonTablePath); + // value is only populated after semantic analysis, not in raw parse tree + (*node).value = std::ptr::null_mut(); + (*node).name = pstrdup(&stmt.name); + node +} + +unsafe fn write_json_table_path_scan(stmt: &protobuf::JsonTablePathScan) -> *mut bindings_raw::JsonTablePathScan { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonTablePathScan); + (*node).path = write_json_table_path_opt(&stmt.path); + (*node).errorOnError = stmt.error_on_error; + (*node).child = write_node_boxed(&stmt.child) as *mut bindings_raw::JsonTablePlan; + (*node).colMin = stmt.col_min; + (*node).colMax = stmt.col_max; + node +} + +unsafe fn write_json_table_sibling_join(stmt: &protobuf::JsonTableSiblingJoin) -> *mut bindings_raw::JsonTableSiblingJoin { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonTableSiblingJoin); + (*node).lplan = write_node_boxed(&stmt.lplan) as *mut bindings_raw::JsonTablePlan; + (*node).rplan = write_node_boxed(&stmt.rplan) as *mut bindings_raw::JsonTablePlan; + node +} + +unsafe fn write_json_output(stmt: &protobuf::JsonOutput) -> *mut bindings_raw::JsonOutput { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonOutput); + (*node).typeName = write_type_name_ref(&stmt.type_name); + (*node).returning = write_json_returning_ref(&stmt.returning); + node +} + +unsafe fn write_json_argument(stmt: &protobuf::JsonArgument) -> *mut bindings_raw::JsonArgument { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonArgument); + (*node).val = write_json_value_expr_ref(&stmt.val); + (*node).name = pstrdup(&stmt.name); + node +} + +unsafe fn write_json_func_expr(stmt: &protobuf::JsonFuncExpr) -> *mut bindings_raw::JsonFuncExpr { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonFuncExpr); + (*node).op = proto_enum_to_c(stmt.op); + (*node).column_name = pstrdup(&stmt.column_name); + (*node).context_item = write_json_value_expr_ref(&stmt.context_item); + (*node).pathspec = write_node_boxed(&stmt.pathspec); + (*node).passing = write_node_list(&stmt.passing); + (*node).output = write_json_output_ref(&stmt.output); + (*node).on_empty = write_json_behavior_ref(&stmt.on_empty); + (*node).on_error = write_json_behavior_ref(&stmt.on_error); + (*node).wrapper = proto_enum_to_c(stmt.wrapper); + (*node).quotes = proto_enum_to_c(stmt.quotes); + (*node).location = stmt.location; + node +} + +unsafe fn write_json_table_path_spec(stmt: &protobuf::JsonTablePathSpec) -> *mut bindings_raw::JsonTablePathSpec { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonTablePathSpec); + (*node).string = write_node_boxed(&stmt.string); + (*node).name = pstrdup(&stmt.name); + (*node).name_location = stmt.name_location; + (*node).location = stmt.location; + node +} + +unsafe fn write_json_table(stmt: &protobuf::JsonTable) -> *mut bindings_raw::JsonTable { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonTable); + (*node).context_item = write_json_value_expr_ref(&stmt.context_item); + (*node).pathspec = write_json_table_path_spec_ref(&stmt.pathspec); + (*node).passing = write_node_list(&stmt.passing); + (*node).columns = write_node_list(&stmt.columns); + (*node).on_error = write_json_behavior_ref(&stmt.on_error); + (*node).alias = write_alias_ref(&stmt.alias); + (*node).lateral = stmt.lateral; + (*node).location = stmt.location; + node +} + +unsafe fn write_json_table_column(stmt: &protobuf::JsonTableColumn) -> *mut bindings_raw::JsonTableColumn { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonTableColumn); + (*node).coltype = proto_enum_to_c(stmt.coltype); + (*node).name = pstrdup(&stmt.name); + (*node).typeName = write_type_name_ref(&stmt.type_name); + (*node).pathspec = write_json_table_path_spec_ref(&stmt.pathspec); + (*node).format = write_json_format_ref(&stmt.format); + (*node).wrapper = proto_enum_to_c(stmt.wrapper); + (*node).quotes = proto_enum_to_c(stmt.quotes); + (*node).columns = write_node_list(&stmt.columns); + (*node).on_empty = write_json_behavior_ref(&stmt.on_empty); + (*node).on_error = write_json_behavior_ref(&stmt.on_error); + (*node).location = stmt.location; + node +} + +unsafe fn write_json_key_value(stmt: &protobuf::JsonKeyValue) -> *mut bindings_raw::JsonKeyValue { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonKeyValue); + (*node).key = write_node_boxed(&stmt.key) as *mut bindings_raw::Expr; + (*node).value = write_json_value_expr_ref(&stmt.value); + node +} + +unsafe fn write_json_parse_expr(stmt: &protobuf::JsonParseExpr) -> *mut bindings_raw::JsonParseExpr { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonParseExpr); + (*node).expr = write_json_value_expr_ref(&stmt.expr); + (*node).output = write_json_output_ref(&stmt.output); + (*node).unique_keys = stmt.unique_keys; + (*node).location = stmt.location; + node +} + +unsafe fn write_json_scalar_expr(stmt: &protobuf::JsonScalarExpr) -> *mut bindings_raw::JsonScalarExpr { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonScalarExpr); + (*node).expr = write_node_boxed(&stmt.expr) as *mut bindings_raw::Expr; + (*node).output = write_json_output_ref(&stmt.output); + (*node).location = stmt.location; + node +} + +unsafe fn write_json_serialize_expr(stmt: &protobuf::JsonSerializeExpr) -> *mut bindings_raw::JsonSerializeExpr { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonSerializeExpr); + (*node).expr = write_json_value_expr_ref(&stmt.expr); + (*node).output = write_json_output_ref(&stmt.output); + (*node).location = stmt.location; + node +} + +unsafe fn write_json_object_constructor(stmt: &protobuf::JsonObjectConstructor) -> *mut bindings_raw::JsonObjectConstructor { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonObjectConstructor); + (*node).exprs = write_node_list(&stmt.exprs); + (*node).output = write_json_output_ref(&stmt.output); + (*node).absent_on_null = stmt.absent_on_null; + (*node).unique = stmt.unique; + (*node).location = stmt.location; + node +} + +unsafe fn write_json_array_constructor(stmt: &protobuf::JsonArrayConstructor) -> *mut bindings_raw::JsonArrayConstructor { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonArrayConstructor); + (*node).exprs = write_node_list(&stmt.exprs); + (*node).output = write_json_output_ref(&stmt.output); + (*node).absent_on_null = stmt.absent_on_null; + (*node).location = stmt.location; + node +} + +unsafe fn write_json_array_query_constructor(stmt: &protobuf::JsonArrayQueryConstructor) -> *mut bindings_raw::JsonArrayQueryConstructor { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonArrayQueryConstructor); + (*node).query = write_node_boxed(&stmt.query); + (*node).output = write_json_output_ref(&stmt.output); + (*node).format = write_json_format_ref(&stmt.format); + (*node).absent_on_null = stmt.absent_on_null; + (*node).location = stmt.location; + node +} + +unsafe fn write_json_agg_constructor(stmt: &protobuf::JsonAggConstructor) -> *mut bindings_raw::JsonAggConstructor { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonAggConstructor); + (*node).output = write_json_output_ref(&stmt.output); + (*node).agg_filter = write_node_boxed(&stmt.agg_filter); + (*node).agg_order = write_node_list(&stmt.agg_order); + (*node).over = write_window_def_boxed_ref(&stmt.over); + (*node).location = stmt.location; + node +} + +unsafe fn write_json_object_agg(stmt: &protobuf::JsonObjectAgg) -> *mut bindings_raw::JsonObjectAgg { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonObjectAgg); + (*node).constructor = write_json_agg_constructor_ref(&stmt.constructor); + (*node).arg = write_json_key_value_ref(&stmt.arg); + (*node).absent_on_null = stmt.absent_on_null; + (*node).unique = stmt.unique; + node +} + +unsafe fn write_json_array_agg(stmt: &protobuf::JsonArrayAgg) -> *mut bindings_raw::JsonArrayAgg { + let node = alloc_node::(bindings_raw::NodeTag_T_JsonArrayAgg); + (*node).constructor = write_json_agg_constructor_ref(&stmt.constructor); + (*node).arg = write_json_value_expr_ref(&stmt.arg); + (*node).absent_on_null = stmt.absent_on_null; + node +} + +// ============================================================================= +// Additional helper functions for optional refs +// ============================================================================= + +unsafe fn write_role_spec_ref(role: &Option) -> *mut bindings_raw::RoleSpec { + match role { + Some(r) => write_role_spec(r), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_variable_set_stmt_ref(stmt: &Option) -> *mut bindings_raw::VariableSetStmt { + match stmt { + Some(s) => write_variable_set_stmt(s), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_object_with_args_ref(owa: &Option) -> *mut bindings_raw::ObjectWithArgs { + match owa { + Some(o) => write_object_with_args(o), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_partition_bound_spec_ref(pbs: &Option) -> *mut bindings_raw::PartitionBoundSpec { + match pbs { + Some(p) => write_partition_bound_spec(p), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_partition_spec_ref(ps: &Option) -> *mut bindings_raw::PartitionSpec { + match ps { + Some(p) => write_partition_spec(p), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_grant_stmt_ref(gs: &Option) -> *mut bindings_raw::GrantStmt { + match gs { + Some(g) => write_grant_stmt(g), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_string_ref(s: &Option) -> *mut bindings_raw::String { + match s { + Some(str_val) => write_string(str_val), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_publication_table_ref(pt: &Option>) -> *mut bindings_raw::PublicationTable { + match pt { + Some(p) => write_publication_table(p), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_json_format_ref(jf: &Option) -> *mut bindings_raw::JsonFormat { + match jf { + Some(f) => write_json_format(f), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_json_returning_ref(jr: &Option) -> *mut bindings_raw::JsonReturning { + match jr { + Some(r) => write_json_returning(r), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_json_behavior_ref(jb: &Option>) -> *mut bindings_raw::JsonBehavior { + match jb { + Some(b) => write_json_behavior(b), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_json_output_ref(jo: &Option) -> *mut bindings_raw::JsonOutput { + match jo { + Some(o) => write_json_output(o), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_json_value_expr_ref(jve: &Option>) -> *mut bindings_raw::JsonValueExpr { + match jve { + Some(v) => write_json_value_expr(v), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_json_table_path_spec_ref(jtps: &Option>) -> *mut bindings_raw::JsonTablePathSpec { + match jtps { + Some(p) => write_json_table_path_spec(p), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_window_def_boxed_ref(wd: &Option>) -> *mut bindings_raw::WindowDef { + match wd { + Some(w) => write_window_def(w), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_json_table_path_opt(jtp: &Option) -> *mut bindings_raw::JsonTablePath { + match jtp { + Some(p) => write_json_table_path(p), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_json_agg_constructor_ref(jac: &Option>) -> *mut bindings_raw::JsonAggConstructor { + match jac { + Some(c) => write_json_agg_constructor(c), + None => std::ptr::null_mut(), + } +} + +unsafe fn write_json_key_value_ref(jkv: &Option>) -> *mut bindings_raw::JsonKeyValue { + match jkv { + Some(k) => write_json_key_value(k), + None => std::ptr::null_mut(), + } +} diff --git a/src/raw_parse.rs b/src/raw_parse.rs index cfb7a7a..45b59b6 100644 --- a/src/raw_parse.rs +++ b/src/raw_parse.rs @@ -498,6 +498,456 @@ unsafe fn convert_node(node_ptr: *mut bindings_raw::Node) -> Option { + let bs = node_ptr as *mut bindings_raw::BitString; + Some(protobuf::node::Node::BitString(convert_bit_string(&*bs))) + } + bindings_raw::NodeTag_T_BooleanTest => { + let bt = node_ptr as *mut bindings_raw::BooleanTest; + Some(protobuf::node::Node::BooleanTest(Box::new(convert_boolean_test(&*bt)))) + } + bindings_raw::NodeTag_T_CreateRangeStmt => { + let crs = node_ptr as *mut bindings_raw::CreateRangeStmt; + Some(protobuf::node::Node::CreateRangeStmt(convert_create_range_stmt(&*crs))) + } + bindings_raw::NodeTag_T_AlterEnumStmt => { + let aes = node_ptr as *mut bindings_raw::AlterEnumStmt; + Some(protobuf::node::Node::AlterEnumStmt(convert_alter_enum_stmt(&*aes))) + } + bindings_raw::NodeTag_T_ClosePortalStmt => { + let cps = node_ptr as *mut bindings_raw::ClosePortalStmt; + Some(protobuf::node::Node::ClosePortalStmt(convert_close_portal_stmt(&*cps))) + } + bindings_raw::NodeTag_T_FetchStmt => { + let fs = node_ptr as *mut bindings_raw::FetchStmt; + Some(protobuf::node::Node::FetchStmt(convert_fetch_stmt(&*fs))) + } + bindings_raw::NodeTag_T_DeclareCursorStmt => { + let dcs = node_ptr as *mut bindings_raw::DeclareCursorStmt; + Some(protobuf::node::Node::DeclareCursorStmt(Box::new(convert_declare_cursor_stmt(&*dcs)))) + } + bindings_raw::NodeTag_T_DefineStmt => { + let ds = node_ptr as *mut bindings_raw::DefineStmt; + Some(protobuf::node::Node::DefineStmt(convert_define_stmt(&*ds))) + } + bindings_raw::NodeTag_T_CommentStmt => { + let cs = node_ptr as *mut bindings_raw::CommentStmt; + Some(protobuf::node::Node::CommentStmt(Box::new(convert_comment_stmt(&*cs)))) + } + bindings_raw::NodeTag_T_SecLabelStmt => { + let sls = node_ptr as *mut bindings_raw::SecLabelStmt; + Some(protobuf::node::Node::SecLabelStmt(Box::new(convert_sec_label_stmt(&*sls)))) + } + bindings_raw::NodeTag_T_CreateRoleStmt => { + let crs = node_ptr as *mut bindings_raw::CreateRoleStmt; + Some(protobuf::node::Node::CreateRoleStmt(convert_create_role_stmt(&*crs))) + } + bindings_raw::NodeTag_T_AlterRoleStmt => { + let ars = node_ptr as *mut bindings_raw::AlterRoleStmt; + Some(protobuf::node::Node::AlterRoleStmt(convert_alter_role_stmt(&*ars))) + } + bindings_raw::NodeTag_T_AlterRoleSetStmt => { + let arss = node_ptr as *mut bindings_raw::AlterRoleSetStmt; + Some(protobuf::node::Node::AlterRoleSetStmt(convert_alter_role_set_stmt(&*arss))) + } + bindings_raw::NodeTag_T_DropRoleStmt => { + let drs = node_ptr as *mut bindings_raw::DropRoleStmt; + Some(protobuf::node::Node::DropRoleStmt(convert_drop_role_stmt(&*drs))) + } + bindings_raw::NodeTag_T_CreatePolicyStmt => { + let cps = node_ptr as *mut bindings_raw::CreatePolicyStmt; + Some(protobuf::node::Node::CreatePolicyStmt(Box::new(convert_create_policy_stmt(&*cps)))) + } + bindings_raw::NodeTag_T_AlterPolicyStmt => { + let aps = node_ptr as *mut bindings_raw::AlterPolicyStmt; + Some(protobuf::node::Node::AlterPolicyStmt(Box::new(convert_alter_policy_stmt(&*aps)))) + } + bindings_raw::NodeTag_T_CreateEventTrigStmt => { + let cets = node_ptr as *mut bindings_raw::CreateEventTrigStmt; + Some(protobuf::node::Node::CreateEventTrigStmt(convert_create_event_trig_stmt(&*cets))) + } + bindings_raw::NodeTag_T_AlterEventTrigStmt => { + let aets = node_ptr as *mut bindings_raw::AlterEventTrigStmt; + Some(protobuf::node::Node::AlterEventTrigStmt(convert_alter_event_trig_stmt(&*aets))) + } + bindings_raw::NodeTag_T_CreatePLangStmt => { + let cpls = node_ptr as *mut bindings_raw::CreatePLangStmt; + Some(protobuf::node::Node::CreatePlangStmt(convert_create_plang_stmt(&*cpls))) + } + bindings_raw::NodeTag_T_CreateAmStmt => { + let cas = node_ptr as *mut bindings_raw::CreateAmStmt; + Some(protobuf::node::Node::CreateAmStmt(convert_create_am_stmt(&*cas))) + } + bindings_raw::NodeTag_T_CreateOpClassStmt => { + let cocs = node_ptr as *mut bindings_raw::CreateOpClassStmt; + Some(protobuf::node::Node::CreateOpClassStmt(convert_create_op_class_stmt(&*cocs))) + } + bindings_raw::NodeTag_T_CreateOpClassItem => { + let coci = node_ptr as *mut bindings_raw::CreateOpClassItem; + Some(protobuf::node::Node::CreateOpClassItem(convert_create_op_class_item(&*coci))) + } + bindings_raw::NodeTag_T_CreateOpFamilyStmt => { + let cofs = node_ptr as *mut bindings_raw::CreateOpFamilyStmt; + Some(protobuf::node::Node::CreateOpFamilyStmt(convert_create_op_family_stmt(&*cofs))) + } + bindings_raw::NodeTag_T_AlterOpFamilyStmt => { + let aofs = node_ptr as *mut bindings_raw::AlterOpFamilyStmt; + Some(protobuf::node::Node::AlterOpFamilyStmt(convert_alter_op_family_stmt(&*aofs))) + } + bindings_raw::NodeTag_T_CreateFdwStmt => { + let cfds = node_ptr as *mut bindings_raw::CreateFdwStmt; + Some(protobuf::node::Node::CreateFdwStmt(convert_create_fdw_stmt(&*cfds))) + } + bindings_raw::NodeTag_T_AlterFdwStmt => { + let afds = node_ptr as *mut bindings_raw::AlterFdwStmt; + Some(protobuf::node::Node::AlterFdwStmt(convert_alter_fdw_stmt(&*afds))) + } + bindings_raw::NodeTag_T_CreateForeignServerStmt => { + let cfss = node_ptr as *mut bindings_raw::CreateForeignServerStmt; + Some(protobuf::node::Node::CreateForeignServerStmt(convert_create_foreign_server_stmt(&*cfss))) + } + bindings_raw::NodeTag_T_AlterForeignServerStmt => { + let afss = node_ptr as *mut bindings_raw::AlterForeignServerStmt; + Some(protobuf::node::Node::AlterForeignServerStmt(convert_alter_foreign_server_stmt(&*afss))) + } + bindings_raw::NodeTag_T_CreateForeignTableStmt => { + let cfts = node_ptr as *mut bindings_raw::CreateForeignTableStmt; + Some(protobuf::node::Node::CreateForeignTableStmt(convert_create_foreign_table_stmt(&*cfts))) + } + bindings_raw::NodeTag_T_CreateUserMappingStmt => { + let cums = node_ptr as *mut bindings_raw::CreateUserMappingStmt; + Some(protobuf::node::Node::CreateUserMappingStmt(convert_create_user_mapping_stmt(&*cums))) + } + bindings_raw::NodeTag_T_AlterUserMappingStmt => { + let aums = node_ptr as *mut bindings_raw::AlterUserMappingStmt; + Some(protobuf::node::Node::AlterUserMappingStmt(convert_alter_user_mapping_stmt(&*aums))) + } + bindings_raw::NodeTag_T_DropUserMappingStmt => { + let dums = node_ptr as *mut bindings_raw::DropUserMappingStmt; + Some(protobuf::node::Node::DropUserMappingStmt(convert_drop_user_mapping_stmt(&*dums))) + } + bindings_raw::NodeTag_T_ImportForeignSchemaStmt => { + let ifss = node_ptr as *mut bindings_raw::ImportForeignSchemaStmt; + Some(protobuf::node::Node::ImportForeignSchemaStmt(convert_import_foreign_schema_stmt(&*ifss))) + } + bindings_raw::NodeTag_T_CreateTableSpaceStmt => { + let ctss = node_ptr as *mut bindings_raw::CreateTableSpaceStmt; + Some(protobuf::node::Node::CreateTableSpaceStmt(convert_create_table_space_stmt(&*ctss))) + } + bindings_raw::NodeTag_T_DropTableSpaceStmt => { + let dtss = node_ptr as *mut bindings_raw::DropTableSpaceStmt; + Some(protobuf::node::Node::DropTableSpaceStmt(convert_drop_table_space_stmt(&*dtss))) + } + bindings_raw::NodeTag_T_AlterTableSpaceOptionsStmt => { + let atsos = node_ptr as *mut bindings_raw::AlterTableSpaceOptionsStmt; + Some(protobuf::node::Node::AlterTableSpaceOptionsStmt(convert_alter_table_space_options_stmt(&*atsos))) + } + bindings_raw::NodeTag_T_AlterTableMoveAllStmt => { + let atmas = node_ptr as *mut bindings_raw::AlterTableMoveAllStmt; + Some(protobuf::node::Node::AlterTableMoveAllStmt(convert_alter_table_move_all_stmt(&*atmas))) + } + bindings_raw::NodeTag_T_AlterExtensionStmt => { + let aes = node_ptr as *mut bindings_raw::AlterExtensionStmt; + Some(protobuf::node::Node::AlterExtensionStmt(convert_alter_extension_stmt(&*aes))) + } + bindings_raw::NodeTag_T_AlterExtensionContentsStmt => { + let aecs = node_ptr as *mut bindings_raw::AlterExtensionContentsStmt; + Some(protobuf::node::Node::AlterExtensionContentsStmt(Box::new(convert_alter_extension_contents_stmt(&*aecs)))) + } + bindings_raw::NodeTag_T_AlterDomainStmt => { + let ads = node_ptr as *mut bindings_raw::AlterDomainStmt; + Some(protobuf::node::Node::AlterDomainStmt(Box::new(convert_alter_domain_stmt(&*ads)))) + } + bindings_raw::NodeTag_T_AlterFunctionStmt => { + let afs = node_ptr as *mut bindings_raw::AlterFunctionStmt; + Some(protobuf::node::Node::AlterFunctionStmt(convert_alter_function_stmt(&*afs))) + } + bindings_raw::NodeTag_T_AlterOperatorStmt => { + let aos = node_ptr as *mut bindings_raw::AlterOperatorStmt; + Some(protobuf::node::Node::AlterOperatorStmt(convert_alter_operator_stmt(&*aos))) + } + bindings_raw::NodeTag_T_AlterTypeStmt => { + let ats = node_ptr as *mut bindings_raw::AlterTypeStmt; + Some(protobuf::node::Node::AlterTypeStmt(convert_alter_type_stmt(&*ats))) + } + bindings_raw::NodeTag_T_AlterObjectSchemaStmt => { + let aoss = node_ptr as *mut bindings_raw::AlterObjectSchemaStmt; + Some(protobuf::node::Node::AlterObjectSchemaStmt(Box::new(convert_alter_object_schema_stmt(&*aoss)))) + } + bindings_raw::NodeTag_T_AlterObjectDependsStmt => { + let aods = node_ptr as *mut bindings_raw::AlterObjectDependsStmt; + Some(protobuf::node::Node::AlterObjectDependsStmt(Box::new(convert_alter_object_depends_stmt(&*aods)))) + } + bindings_raw::NodeTag_T_AlterCollationStmt => { + let acs = node_ptr as *mut bindings_raw::AlterCollationStmt; + Some(protobuf::node::Node::AlterCollationStmt(convert_alter_collation_stmt(&*acs))) + } + bindings_raw::NodeTag_T_AlterDefaultPrivilegesStmt => { + let adps = node_ptr as *mut bindings_raw::AlterDefaultPrivilegesStmt; + Some(protobuf::node::Node::AlterDefaultPrivilegesStmt(convert_alter_default_privileges_stmt(&*adps))) + } + bindings_raw::NodeTag_T_CreateCastStmt => { + let ccs = node_ptr as *mut bindings_raw::CreateCastStmt; + Some(protobuf::node::Node::CreateCastStmt(convert_create_cast_stmt(&*ccs))) + } + bindings_raw::NodeTag_T_CreateTransformStmt => { + let cts = node_ptr as *mut bindings_raw::CreateTransformStmt; + Some(protobuf::node::Node::CreateTransformStmt(convert_create_transform_stmt(&*cts))) + } + bindings_raw::NodeTag_T_CreateConversionStmt => { + let ccs = node_ptr as *mut bindings_raw::CreateConversionStmt; + Some(protobuf::node::Node::CreateConversionStmt(convert_create_conversion_stmt(&*ccs))) + } + bindings_raw::NodeTag_T_AlterTSDictionaryStmt => { + let atds = node_ptr as *mut bindings_raw::AlterTSDictionaryStmt; + Some(protobuf::node::Node::AlterTsdictionaryStmt(convert_alter_ts_dictionary_stmt(&*atds))) + } + bindings_raw::NodeTag_T_AlterTSConfigurationStmt => { + let atcs = node_ptr as *mut bindings_raw::AlterTSConfigurationStmt; + Some(protobuf::node::Node::AlterTsconfigurationStmt(convert_alter_ts_configuration_stmt(&*atcs))) + } + bindings_raw::NodeTag_T_CreatedbStmt => { + let cds = node_ptr as *mut bindings_raw::CreatedbStmt; + Some(protobuf::node::Node::CreatedbStmt(convert_createdb_stmt(&*cds))) + } + bindings_raw::NodeTag_T_DropdbStmt => { + let dds = node_ptr as *mut bindings_raw::DropdbStmt; + Some(protobuf::node::Node::DropdbStmt(convert_dropdb_stmt(&*dds))) + } + bindings_raw::NodeTag_T_AlterDatabaseStmt => { + let ads = node_ptr as *mut bindings_raw::AlterDatabaseStmt; + Some(protobuf::node::Node::AlterDatabaseStmt(convert_alter_database_stmt(&*ads))) + } + bindings_raw::NodeTag_T_AlterDatabaseSetStmt => { + let adss = node_ptr as *mut bindings_raw::AlterDatabaseSetStmt; + Some(protobuf::node::Node::AlterDatabaseSetStmt(convert_alter_database_set_stmt(&*adss))) + } + bindings_raw::NodeTag_T_AlterDatabaseRefreshCollStmt => { + let adrcs = node_ptr as *mut bindings_raw::AlterDatabaseRefreshCollStmt; + Some(protobuf::node::Node::AlterDatabaseRefreshCollStmt(convert_alter_database_refresh_coll_stmt(&*adrcs))) + } + bindings_raw::NodeTag_T_AlterSystemStmt => { + let ass = node_ptr as *mut bindings_raw::AlterSystemStmt; + Some(protobuf::node::Node::AlterSystemStmt(convert_alter_system_stmt(&*ass))) + } + bindings_raw::NodeTag_T_ClusterStmt => { + let cs = node_ptr as *mut bindings_raw::ClusterStmt; + Some(protobuf::node::Node::ClusterStmt(convert_cluster_stmt(&*cs))) + } + bindings_raw::NodeTag_T_ReindexStmt => { + let rs = node_ptr as *mut bindings_raw::ReindexStmt; + Some(protobuf::node::Node::ReindexStmt(convert_reindex_stmt(&*rs))) + } + bindings_raw::NodeTag_T_ConstraintsSetStmt => { + let css = node_ptr as *mut bindings_raw::ConstraintsSetStmt; + Some(protobuf::node::Node::ConstraintsSetStmt(convert_constraints_set_stmt(&*css))) + } + bindings_raw::NodeTag_T_LoadStmt => { + let ls = node_ptr as *mut bindings_raw::LoadStmt; + Some(protobuf::node::Node::LoadStmt(convert_load_stmt(&*ls))) + } + bindings_raw::NodeTag_T_DropOwnedStmt => { + let dos = node_ptr as *mut bindings_raw::DropOwnedStmt; + Some(protobuf::node::Node::DropOwnedStmt(convert_drop_owned_stmt(&*dos))) + } + bindings_raw::NodeTag_T_ReassignOwnedStmt => { + let ros = node_ptr as *mut bindings_raw::ReassignOwnedStmt; + Some(protobuf::node::Node::ReassignOwnedStmt(convert_reassign_owned_stmt(&*ros))) + } + bindings_raw::NodeTag_T_DropSubscriptionStmt => { + let dss = node_ptr as *mut bindings_raw::DropSubscriptionStmt; + Some(protobuf::node::Node::DropSubscriptionStmt(convert_drop_subscription_stmt(&*dss))) + } + bindings_raw::NodeTag_T_TableFunc => { + let tf = node_ptr as *mut bindings_raw::TableFunc; + Some(protobuf::node::Node::TableFunc(Box::new(convert_table_func(&*tf)))) + } + bindings_raw::NodeTag_T_IntoClause => { + let ic = node_ptr as *mut bindings_raw::IntoClause; + Some(protobuf::node::Node::IntoClause(Box::new(convert_into_clause_node(&*ic)))) + } + bindings_raw::NodeTag_T_TableLikeClause => { + let tlc = node_ptr as *mut bindings_raw::TableLikeClause; + Some(protobuf::node::Node::TableLikeClause(convert_table_like_clause(&*tlc))) + } + bindings_raw::NodeTag_T_RangeTableFunc => { + let rtf = node_ptr as *mut bindings_raw::RangeTableFunc; + Some(protobuf::node::Node::RangeTableFunc(Box::new(convert_range_table_func(&*rtf)))) + } + bindings_raw::NodeTag_T_RangeTableFuncCol => { + let rtfc = node_ptr as *mut bindings_raw::RangeTableFuncCol; + Some(protobuf::node::Node::RangeTableFuncCol(Box::new(convert_range_table_func_col(&*rtfc)))) + } + bindings_raw::NodeTag_T_RangeTableSample => { + let rts = node_ptr as *mut bindings_raw::RangeTableSample; + Some(protobuf::node::Node::RangeTableSample(Box::new(convert_range_table_sample(&*rts)))) + } + bindings_raw::NodeTag_T_PartitionSpec => { + let ps = node_ptr as *mut bindings_raw::PartitionSpec; + Some(protobuf::node::Node::PartitionSpec(convert_partition_spec(&*ps))) + } + bindings_raw::NodeTag_T_PartitionBoundSpec => { + let pbs = node_ptr as *mut bindings_raw::PartitionBoundSpec; + Some(protobuf::node::Node::PartitionBoundSpec(convert_partition_bound_spec(&*pbs))) + } + bindings_raw::NodeTag_T_PartitionCmd => { + let pc = node_ptr as *mut bindings_raw::PartitionCmd; + Some(protobuf::node::Node::PartitionCmd(convert_partition_cmd(&*pc))) + } + bindings_raw::NodeTag_T_SinglePartitionSpec => Some(protobuf::node::Node::SinglePartitionSpec(protobuf::SinglePartitionSpec {})), + bindings_raw::NodeTag_T_InferClause => { + let ic = node_ptr as *mut bindings_raw::InferClause; + convert_infer_clause(ic).map(|c| protobuf::node::Node::InferClause(c)) + } + bindings_raw::NodeTag_T_OnConflictClause => { + let occ = node_ptr as *mut bindings_raw::OnConflictClause; + Some(protobuf::node::Node::OnConflictClause(Box::new(convert_on_conflict_clause_node(&*occ)))) + } + bindings_raw::NodeTag_T_TriggerTransition => { + let tt = node_ptr as *mut bindings_raw::TriggerTransition; + Some(protobuf::node::Node::TriggerTransition(convert_trigger_transition(&*tt))) + } + bindings_raw::NodeTag_T_CTESearchClause => { + let csc = node_ptr as *mut bindings_raw::CTESearchClause; + Some(protobuf::node::Node::CtesearchClause(convert_cte_search_clause(&*csc))) + } + bindings_raw::NodeTag_T_CTECycleClause => { + let ccc = node_ptr as *mut bindings_raw::CTECycleClause; + Some(protobuf::node::Node::CtecycleClause(Box::new(convert_cte_cycle_clause(&*ccc)))) + } + bindings_raw::NodeTag_T_CreateStatsStmt => { + let css = node_ptr as *mut bindings_raw::CreateStatsStmt; + Some(protobuf::node::Node::CreateStatsStmt(convert_create_stats_stmt(&*css))) + } + bindings_raw::NodeTag_T_AlterStatsStmt => { + let ass = node_ptr as *mut bindings_raw::AlterStatsStmt; + Some(protobuf::node::Node::AlterStatsStmt(Box::new(convert_alter_stats_stmt(&*ass)))) + } + bindings_raw::NodeTag_T_StatsElem => { + let se = node_ptr as *mut bindings_raw::StatsElem; + Some(protobuf::node::Node::StatsElem(Box::new(convert_stats_elem(&*se)))) + } + bindings_raw::NodeTag_T_SQLValueFunction => { + let svf = node_ptr as *mut bindings_raw::SQLValueFunction; + Some(protobuf::node::Node::SqlvalueFunction(Box::new(convert_sql_value_function(&*svf)))) + } + bindings_raw::NodeTag_T_XmlExpr => { + let xe = node_ptr as *mut bindings_raw::XmlExpr; + Some(protobuf::node::Node::XmlExpr(Box::new(convert_xml_expr(&*xe)))) + } + bindings_raw::NodeTag_T_XmlSerialize => { + let xs = node_ptr as *mut bindings_raw::XmlSerialize; + Some(protobuf::node::Node::XmlSerialize(Box::new(convert_xml_serialize(&*xs)))) + } + bindings_raw::NodeTag_T_NamedArgExpr => { + let nae = node_ptr as *mut bindings_raw::NamedArgExpr; + Some(protobuf::node::Node::NamedArgExpr(Box::new(convert_named_arg_expr(&*nae)))) + } + // JSON nodes + bindings_raw::NodeTag_T_JsonFormat => { + let jf = node_ptr as *mut bindings_raw::JsonFormat; + Some(protobuf::node::Node::JsonFormat(convert_json_format(&*jf))) + } + bindings_raw::NodeTag_T_JsonReturning => { + let jr = node_ptr as *mut bindings_raw::JsonReturning; + Some(protobuf::node::Node::JsonReturning(convert_json_returning(&*jr))) + } + bindings_raw::NodeTag_T_JsonValueExpr => { + let jve = node_ptr as *mut bindings_raw::JsonValueExpr; + Some(protobuf::node::Node::JsonValueExpr(Box::new(convert_json_value_expr(&*jve)))) + } + bindings_raw::NodeTag_T_JsonConstructorExpr => { + let jce = node_ptr as *mut bindings_raw::JsonConstructorExpr; + Some(protobuf::node::Node::JsonConstructorExpr(Box::new(convert_json_constructor_expr(&*jce)))) + } + bindings_raw::NodeTag_T_JsonIsPredicate => { + let jip = node_ptr as *mut bindings_raw::JsonIsPredicate; + Some(protobuf::node::Node::JsonIsPredicate(Box::new(convert_json_is_predicate(&*jip)))) + } + bindings_raw::NodeTag_T_JsonBehavior => { + let jb = node_ptr as *mut bindings_raw::JsonBehavior; + Some(protobuf::node::Node::JsonBehavior(Box::new(convert_json_behavior(&*jb)))) + } + bindings_raw::NodeTag_T_JsonExpr => { + let je = node_ptr as *mut bindings_raw::JsonExpr; + Some(protobuf::node::Node::JsonExpr(Box::new(convert_json_expr(&*je)))) + } + bindings_raw::NodeTag_T_JsonTablePath => { + let jtp = node_ptr as *mut bindings_raw::JsonTablePath; + Some(protobuf::node::Node::JsonTablePath(convert_json_table_path(&*jtp))) + } + bindings_raw::NodeTag_T_JsonTablePathScan => { + let jtps = node_ptr as *mut bindings_raw::JsonTablePathScan; + Some(protobuf::node::Node::JsonTablePathScan(Box::new(convert_json_table_path_scan(&*jtps)))) + } + bindings_raw::NodeTag_T_JsonTableSiblingJoin => { + let jtsj = node_ptr as *mut bindings_raw::JsonTableSiblingJoin; + Some(protobuf::node::Node::JsonTableSiblingJoin(Box::new(convert_json_table_sibling_join(&*jtsj)))) + } + bindings_raw::NodeTag_T_JsonOutput => { + let jo = node_ptr as *mut bindings_raw::JsonOutput; + Some(protobuf::node::Node::JsonOutput(convert_json_output(&*jo))) + } + bindings_raw::NodeTag_T_JsonArgument => { + let ja = node_ptr as *mut bindings_raw::JsonArgument; + Some(protobuf::node::Node::JsonArgument(Box::new(convert_json_argument(&*ja)))) + } + bindings_raw::NodeTag_T_JsonFuncExpr => { + let jfe = node_ptr as *mut bindings_raw::JsonFuncExpr; + Some(protobuf::node::Node::JsonFuncExpr(Box::new(convert_json_func_expr(&*jfe)))) + } + bindings_raw::NodeTag_T_JsonTablePathSpec => { + let jtps = node_ptr as *mut bindings_raw::JsonTablePathSpec; + Some(protobuf::node::Node::JsonTablePathSpec(Box::new(convert_json_table_path_spec(&*jtps)))) + } + bindings_raw::NodeTag_T_JsonTable => { + let jt = node_ptr as *mut bindings_raw::JsonTable; + Some(protobuf::node::Node::JsonTable(Box::new(convert_json_table(&*jt)))) + } + bindings_raw::NodeTag_T_JsonTableColumn => { + let jtc = node_ptr as *mut bindings_raw::JsonTableColumn; + Some(protobuf::node::Node::JsonTableColumn(Box::new(convert_json_table_column(&*jtc)))) + } + bindings_raw::NodeTag_T_JsonKeyValue => { + let jkv = node_ptr as *mut bindings_raw::JsonKeyValue; + Some(protobuf::node::Node::JsonKeyValue(Box::new(convert_json_key_value(&*jkv)))) + } + bindings_raw::NodeTag_T_JsonParseExpr => { + let jpe = node_ptr as *mut bindings_raw::JsonParseExpr; + Some(protobuf::node::Node::JsonParseExpr(Box::new(convert_json_parse_expr(&*jpe)))) + } + bindings_raw::NodeTag_T_JsonScalarExpr => { + let jse = node_ptr as *mut bindings_raw::JsonScalarExpr; + Some(protobuf::node::Node::JsonScalarExpr(Box::new(convert_json_scalar_expr(&*jse)))) + } + bindings_raw::NodeTag_T_JsonSerializeExpr => { + let jse = node_ptr as *mut bindings_raw::JsonSerializeExpr; + Some(protobuf::node::Node::JsonSerializeExpr(Box::new(convert_json_serialize_expr(&*jse)))) + } + bindings_raw::NodeTag_T_JsonObjectConstructor => { + let joc = node_ptr as *mut bindings_raw::JsonObjectConstructor; + Some(protobuf::node::Node::JsonObjectConstructor(convert_json_object_constructor(&*joc))) + } + bindings_raw::NodeTag_T_JsonArrayConstructor => { + let jac = node_ptr as *mut bindings_raw::JsonArrayConstructor; + Some(protobuf::node::Node::JsonArrayConstructor(convert_json_array_constructor(&*jac))) + } + bindings_raw::NodeTag_T_JsonArrayQueryConstructor => { + let jaqc = node_ptr as *mut bindings_raw::JsonArrayQueryConstructor; + Some(protobuf::node::Node::JsonArrayQueryConstructor(Box::new(convert_json_array_query_constructor(&*jaqc)))) + } + bindings_raw::NodeTag_T_JsonAggConstructor => { + let jac = node_ptr as *mut bindings_raw::JsonAggConstructor; + Some(protobuf::node::Node::JsonAggConstructor(Box::new(convert_json_agg_constructor(&*jac)))) + } + bindings_raw::NodeTag_T_JsonObjectAgg => { + let joa = node_ptr as *mut bindings_raw::JsonObjectAgg; + Some(protobuf::node::Node::JsonObjectAgg(Box::new(convert_json_object_agg(&*joa)))) + } + bindings_raw::NodeTag_T_JsonArrayAgg => { + let jaa = node_ptr as *mut bindings_raw::JsonArrayAgg; + Some(protobuf::node::Node::JsonArrayAgg(Box::new(convert_json_array_agg(&*jaa)))) + } _ => { // For unhandled node types, return None // In the future, we could add more node types here @@ -1690,3 +2140,938 @@ unsafe fn convert_c_string(ptr: *const i8) -> std::string::String { CStr::from_ptr(ptr).to_string_lossy().to_string() } } + +// ============================================================================ +// New Node Conversions (matching raw_deparse.rs) +// ============================================================================ + +unsafe fn convert_bit_string(bs: &bindings_raw::BitString) -> protobuf::BitString { + protobuf::BitString { bsval: convert_c_string(bs.bsval) } +} + +unsafe fn convert_boolean_test(bt: &bindings_raw::BooleanTest) -> protobuf::BooleanTest { + let xpr_ptr = &bt.xpr as *const bindings_raw::Expr as *mut bindings_raw::Node; + protobuf::BooleanTest { + xpr: convert_node_boxed(xpr_ptr), + arg: convert_node_boxed(bt.arg as *mut bindings_raw::Node), + booltesttype: bt.booltesttype as i32 + 1, + location: bt.location, + } +} + +unsafe fn convert_create_range_stmt(crs: &bindings_raw::CreateRangeStmt) -> protobuf::CreateRangeStmt { + protobuf::CreateRangeStmt { type_name: convert_list_to_nodes(crs.typeName), params: convert_list_to_nodes(crs.params) } +} + +unsafe fn convert_alter_enum_stmt(aes: &bindings_raw::AlterEnumStmt) -> protobuf::AlterEnumStmt { + protobuf::AlterEnumStmt { + type_name: convert_list_to_nodes(aes.typeName), + old_val: convert_c_string(aes.oldVal), + new_val: convert_c_string(aes.newVal), + new_val_neighbor: convert_c_string(aes.newValNeighbor), + new_val_is_after: aes.newValIsAfter, + skip_if_new_val_exists: aes.skipIfNewValExists, + } +} + +unsafe fn convert_close_portal_stmt(cps: &bindings_raw::ClosePortalStmt) -> protobuf::ClosePortalStmt { + protobuf::ClosePortalStmt { portalname: convert_c_string(cps.portalname) } +} + +unsafe fn convert_fetch_stmt(fs: &bindings_raw::FetchStmt) -> protobuf::FetchStmt { + protobuf::FetchStmt { direction: fs.direction as i32 + 1, how_many: fs.howMany, portalname: convert_c_string(fs.portalname), ismove: fs.ismove } +} + +unsafe fn convert_declare_cursor_stmt(dcs: &bindings_raw::DeclareCursorStmt) -> protobuf::DeclareCursorStmt { + protobuf::DeclareCursorStmt { portalname: convert_c_string(dcs.portalname), options: dcs.options, query: convert_node_boxed(dcs.query) } +} + +unsafe fn convert_define_stmt(ds: &bindings_raw::DefineStmt) -> protobuf::DefineStmt { + protobuf::DefineStmt { + kind: ds.kind as i32 + 1, + oldstyle: ds.oldstyle, + defnames: convert_list_to_nodes(ds.defnames), + args: convert_list_to_nodes(ds.args), + definition: convert_list_to_nodes(ds.definition), + if_not_exists: ds.if_not_exists, + replace: ds.replace, + } +} + +unsafe fn convert_comment_stmt(cs: &bindings_raw::CommentStmt) -> protobuf::CommentStmt { + protobuf::CommentStmt { objtype: cs.objtype as i32 + 1, object: convert_node_boxed(cs.object), comment: convert_c_string(cs.comment) } +} + +unsafe fn convert_sec_label_stmt(sls: &bindings_raw::SecLabelStmt) -> protobuf::SecLabelStmt { + protobuf::SecLabelStmt { + objtype: sls.objtype as i32 + 1, + object: convert_node_boxed(sls.object), + provider: convert_c_string(sls.provider), + label: convert_c_string(sls.label), + } +} + +unsafe fn convert_create_role_stmt(crs: &bindings_raw::CreateRoleStmt) -> protobuf::CreateRoleStmt { + protobuf::CreateRoleStmt { stmt_type: crs.stmt_type as i32 + 1, role: convert_c_string(crs.role), options: convert_list_to_nodes(crs.options) } +} + +unsafe fn convert_alter_role_stmt(ars: &bindings_raw::AlterRoleStmt) -> protobuf::AlterRoleStmt { + protobuf::AlterRoleStmt { + role: if ars.role.is_null() { None } else { Some(convert_role_spec(&*ars.role)) }, + options: convert_list_to_nodes(ars.options), + action: ars.action, + } +} + +unsafe fn convert_alter_role_set_stmt(arss: &bindings_raw::AlterRoleSetStmt) -> protobuf::AlterRoleSetStmt { + protobuf::AlterRoleSetStmt { + role: if arss.role.is_null() { None } else { Some(convert_role_spec(&*arss.role)) }, + database: convert_c_string(arss.database), + setstmt: convert_variable_set_stmt_opt(arss.setstmt), + } +} + +unsafe fn convert_drop_role_stmt(drs: &bindings_raw::DropRoleStmt) -> protobuf::DropRoleStmt { + protobuf::DropRoleStmt { roles: convert_list_to_nodes(drs.roles), missing_ok: drs.missing_ok } +} + +unsafe fn convert_create_policy_stmt(cps: &bindings_raw::CreatePolicyStmt) -> protobuf::CreatePolicyStmt { + protobuf::CreatePolicyStmt { + policy_name: convert_c_string(cps.policy_name), + table: if cps.table.is_null() { None } else { Some(convert_range_var(&*cps.table)) }, + cmd_name: convert_c_string(cps.cmd_name), + permissive: cps.permissive, + roles: convert_list_to_nodes(cps.roles), + qual: convert_node_boxed(cps.qual), + with_check: convert_node_boxed(cps.with_check), + } +} + +unsafe fn convert_alter_policy_stmt(aps: &bindings_raw::AlterPolicyStmt) -> protobuf::AlterPolicyStmt { + protobuf::AlterPolicyStmt { + policy_name: convert_c_string(aps.policy_name), + table: if aps.table.is_null() { None } else { Some(convert_range_var(&*aps.table)) }, + roles: convert_list_to_nodes(aps.roles), + qual: convert_node_boxed(aps.qual), + with_check: convert_node_boxed(aps.with_check), + } +} + +unsafe fn convert_create_event_trig_stmt(cets: &bindings_raw::CreateEventTrigStmt) -> protobuf::CreateEventTrigStmt { + protobuf::CreateEventTrigStmt { + trigname: convert_c_string(cets.trigname), + eventname: convert_c_string(cets.eventname), + whenclause: convert_list_to_nodes(cets.whenclause), + funcname: convert_list_to_nodes(cets.funcname), + } +} + +unsafe fn convert_alter_event_trig_stmt(aets: &bindings_raw::AlterEventTrigStmt) -> protobuf::AlterEventTrigStmt { + protobuf::AlterEventTrigStmt { + trigname: convert_c_string(aets.trigname), + tgenabled: String::from_utf8_lossy(&[aets.tgenabled as u8]).to_string(), + } +} + +unsafe fn convert_create_plang_stmt(cpls: &bindings_raw::CreatePLangStmt) -> protobuf::CreatePLangStmt { + protobuf::CreatePLangStmt { + replace: cpls.replace, + plname: convert_c_string(cpls.plname), + plhandler: convert_list_to_nodes(cpls.plhandler), + plinline: convert_list_to_nodes(cpls.plinline), + plvalidator: convert_list_to_nodes(cpls.plvalidator), + pltrusted: cpls.pltrusted, + } +} + +unsafe fn convert_create_am_stmt(cas: &bindings_raw::CreateAmStmt) -> protobuf::CreateAmStmt { + protobuf::CreateAmStmt { + amname: convert_c_string(cas.amname), + handler_name: convert_list_to_nodes(cas.handler_name), + amtype: String::from_utf8_lossy(&[cas.amtype as u8]).to_string(), + } +} + +unsafe fn convert_create_op_class_stmt(cocs: &bindings_raw::CreateOpClassStmt) -> protobuf::CreateOpClassStmt { + protobuf::CreateOpClassStmt { + opclassname: convert_list_to_nodes(cocs.opclassname), + opfamilyname: convert_list_to_nodes(cocs.opfamilyname), + amname: convert_c_string(cocs.amname), + datatype: if cocs.datatype.is_null() { None } else { Some(convert_type_name(&*cocs.datatype)) }, + items: convert_list_to_nodes(cocs.items), + is_default: cocs.isDefault, + } +} + +unsafe fn convert_create_op_class_item(coci: &bindings_raw::CreateOpClassItem) -> protobuf::CreateOpClassItem { + protobuf::CreateOpClassItem { + itemtype: coci.itemtype, + name: if coci.name.is_null() { None } else { Some(convert_object_with_args(&*coci.name)) }, + number: coci.number, + order_family: convert_list_to_nodes(coci.order_family), + class_args: convert_list_to_nodes(coci.class_args), + storedtype: if coci.storedtype.is_null() { None } else { Some(convert_type_name(&*coci.storedtype)) }, + } +} + +unsafe fn convert_create_op_family_stmt(cofs: &bindings_raw::CreateOpFamilyStmt) -> protobuf::CreateOpFamilyStmt { + protobuf::CreateOpFamilyStmt { opfamilyname: convert_list_to_nodes(cofs.opfamilyname), amname: convert_c_string(cofs.amname) } +} + +unsafe fn convert_alter_op_family_stmt(aofs: &bindings_raw::AlterOpFamilyStmt) -> protobuf::AlterOpFamilyStmt { + protobuf::AlterOpFamilyStmt { + opfamilyname: convert_list_to_nodes(aofs.opfamilyname), + amname: convert_c_string(aofs.amname), + is_drop: aofs.isDrop, + items: convert_list_to_nodes(aofs.items), + } +} + +unsafe fn convert_create_fdw_stmt(cfds: &bindings_raw::CreateFdwStmt) -> protobuf::CreateFdwStmt { + protobuf::CreateFdwStmt { + fdwname: convert_c_string(cfds.fdwname), + func_options: convert_list_to_nodes(cfds.func_options), + options: convert_list_to_nodes(cfds.options), + } +} + +unsafe fn convert_alter_fdw_stmt(afds: &bindings_raw::AlterFdwStmt) -> protobuf::AlterFdwStmt { + protobuf::AlterFdwStmt { + fdwname: convert_c_string(afds.fdwname), + func_options: convert_list_to_nodes(afds.func_options), + options: convert_list_to_nodes(afds.options), + } +} + +unsafe fn convert_create_foreign_server_stmt(cfss: &bindings_raw::CreateForeignServerStmt) -> protobuf::CreateForeignServerStmt { + protobuf::CreateForeignServerStmt { + servername: convert_c_string(cfss.servername), + servertype: convert_c_string(cfss.servertype), + version: convert_c_string(cfss.version), + fdwname: convert_c_string(cfss.fdwname), + if_not_exists: cfss.if_not_exists, + options: convert_list_to_nodes(cfss.options), + } +} + +unsafe fn convert_alter_foreign_server_stmt(afss: &bindings_raw::AlterForeignServerStmt) -> protobuf::AlterForeignServerStmt { + protobuf::AlterForeignServerStmt { + servername: convert_c_string(afss.servername), + version: convert_c_string(afss.version), + options: convert_list_to_nodes(afss.options), + has_version: afss.has_version, + } +} + +unsafe fn convert_create_foreign_table_stmt(cfts: &bindings_raw::CreateForeignTableStmt) -> protobuf::CreateForeignTableStmt { + protobuf::CreateForeignTableStmt { + base_stmt: Some(convert_create_stmt(&cfts.base)), + servername: convert_c_string(cfts.servername), + options: convert_list_to_nodes(cfts.options), + } +} + +unsafe fn convert_create_user_mapping_stmt(cums: &bindings_raw::CreateUserMappingStmt) -> protobuf::CreateUserMappingStmt { + protobuf::CreateUserMappingStmt { + user: if cums.user.is_null() { None } else { Some(convert_role_spec(&*cums.user)) }, + servername: convert_c_string(cums.servername), + if_not_exists: cums.if_not_exists, + options: convert_list_to_nodes(cums.options), + } +} + +unsafe fn convert_alter_user_mapping_stmt(aums: &bindings_raw::AlterUserMappingStmt) -> protobuf::AlterUserMappingStmt { + protobuf::AlterUserMappingStmt { + user: if aums.user.is_null() { None } else { Some(convert_role_spec(&*aums.user)) }, + servername: convert_c_string(aums.servername), + options: convert_list_to_nodes(aums.options), + } +} + +unsafe fn convert_drop_user_mapping_stmt(dums: &bindings_raw::DropUserMappingStmt) -> protobuf::DropUserMappingStmt { + protobuf::DropUserMappingStmt { + user: if dums.user.is_null() { None } else { Some(convert_role_spec(&*dums.user)) }, + servername: convert_c_string(dums.servername), + missing_ok: dums.missing_ok, + } +} + +unsafe fn convert_import_foreign_schema_stmt(ifss: &bindings_raw::ImportForeignSchemaStmt) -> protobuf::ImportForeignSchemaStmt { + protobuf::ImportForeignSchemaStmt { + server_name: convert_c_string(ifss.server_name), + remote_schema: convert_c_string(ifss.remote_schema), + local_schema: convert_c_string(ifss.local_schema), + list_type: ifss.list_type as i32 + 1, + table_list: convert_list_to_nodes(ifss.table_list), + options: convert_list_to_nodes(ifss.options), + } +} + +unsafe fn convert_create_table_space_stmt(ctss: &bindings_raw::CreateTableSpaceStmt) -> protobuf::CreateTableSpaceStmt { + protobuf::CreateTableSpaceStmt { + tablespacename: convert_c_string(ctss.tablespacename), + owner: if ctss.owner.is_null() { None } else { Some(convert_role_spec(&*ctss.owner)) }, + location: convert_c_string(ctss.location), + options: convert_list_to_nodes(ctss.options), + } +} + +unsafe fn convert_drop_table_space_stmt(dtss: &bindings_raw::DropTableSpaceStmt) -> protobuf::DropTableSpaceStmt { + protobuf::DropTableSpaceStmt { tablespacename: convert_c_string(dtss.tablespacename), missing_ok: dtss.missing_ok } +} + +unsafe fn convert_alter_table_space_options_stmt(atsos: &bindings_raw::AlterTableSpaceOptionsStmt) -> protobuf::AlterTableSpaceOptionsStmt { + protobuf::AlterTableSpaceOptionsStmt { + tablespacename: convert_c_string(atsos.tablespacename), + options: convert_list_to_nodes(atsos.options), + is_reset: atsos.isReset, + } +} + +unsafe fn convert_alter_table_move_all_stmt(atmas: &bindings_raw::AlterTableMoveAllStmt) -> protobuf::AlterTableMoveAllStmt { + protobuf::AlterTableMoveAllStmt { + orig_tablespacename: convert_c_string(atmas.orig_tablespacename), + objtype: atmas.objtype as i32 + 1, + roles: convert_list_to_nodes(atmas.roles), + new_tablespacename: convert_c_string(atmas.new_tablespacename), + nowait: atmas.nowait, + } +} + +unsafe fn convert_alter_extension_stmt(aes: &bindings_raw::AlterExtensionStmt) -> protobuf::AlterExtensionStmt { + protobuf::AlterExtensionStmt { extname: convert_c_string(aes.extname), options: convert_list_to_nodes(aes.options) } +} + +unsafe fn convert_alter_extension_contents_stmt(aecs: &bindings_raw::AlterExtensionContentsStmt) -> protobuf::AlterExtensionContentsStmt { + protobuf::AlterExtensionContentsStmt { + extname: convert_c_string(aecs.extname), + action: aecs.action, + objtype: aecs.objtype as i32 + 1, + object: convert_node_boxed(aecs.object), + } +} + +unsafe fn convert_alter_domain_stmt(ads: &bindings_raw::AlterDomainStmt) -> protobuf::AlterDomainStmt { + protobuf::AlterDomainStmt { + subtype: String::from_utf8_lossy(&[ads.subtype as u8]).to_string(), + type_name: convert_list_to_nodes(ads.typeName), + name: convert_c_string(ads.name), + def: convert_node_boxed(ads.def), + behavior: ads.behavior as i32 + 1, + missing_ok: ads.missing_ok, + } +} + +unsafe fn convert_alter_function_stmt(afs: &bindings_raw::AlterFunctionStmt) -> protobuf::AlterFunctionStmt { + protobuf::AlterFunctionStmt { + objtype: afs.objtype as i32 + 1, + func: if afs.func.is_null() { None } else { Some(convert_object_with_args(&*afs.func)) }, + actions: convert_list_to_nodes(afs.actions), + } +} + +unsafe fn convert_alter_operator_stmt(aos: &bindings_raw::AlterOperatorStmt) -> protobuf::AlterOperatorStmt { + protobuf::AlterOperatorStmt { + opername: if aos.opername.is_null() { None } else { Some(convert_object_with_args(&*aos.opername)) }, + options: convert_list_to_nodes(aos.options), + } +} + +unsafe fn convert_alter_type_stmt(ats: &bindings_raw::AlterTypeStmt) -> protobuf::AlterTypeStmt { + protobuf::AlterTypeStmt { type_name: convert_list_to_nodes(ats.typeName), options: convert_list_to_nodes(ats.options) } +} + +unsafe fn convert_alter_object_schema_stmt(aoss: &bindings_raw::AlterObjectSchemaStmt) -> protobuf::AlterObjectSchemaStmt { + protobuf::AlterObjectSchemaStmt { + object_type: aoss.objectType as i32 + 1, + relation: if aoss.relation.is_null() { None } else { Some(convert_range_var(&*aoss.relation)) }, + object: convert_node_boxed(aoss.object), + newschema: convert_c_string(aoss.newschema), + missing_ok: aoss.missing_ok, + } +} + +unsafe fn convert_alter_object_depends_stmt(aods: &bindings_raw::AlterObjectDependsStmt) -> protobuf::AlterObjectDependsStmt { + protobuf::AlterObjectDependsStmt { + object_type: aods.objectType as i32 + 1, + relation: if aods.relation.is_null() { None } else { Some(convert_range_var(&*aods.relation)) }, + object: convert_node_boxed(aods.object), + extname: Some(protobuf::String { sval: convert_c_string(aods.extname as *mut i8) }), + remove: aods.remove, + } +} + +unsafe fn convert_alter_collation_stmt(acs: &bindings_raw::AlterCollationStmt) -> protobuf::AlterCollationStmt { + protobuf::AlterCollationStmt { collname: convert_list_to_nodes(acs.collname) } +} + +unsafe fn convert_alter_default_privileges_stmt(adps: &bindings_raw::AlterDefaultPrivilegesStmt) -> protobuf::AlterDefaultPrivilegesStmt { + protobuf::AlterDefaultPrivilegesStmt { + options: convert_list_to_nodes(adps.options), + action: if adps.action.is_null() { None } else { Some(convert_grant_stmt(&*adps.action)) }, + } +} + +unsafe fn convert_create_cast_stmt(ccs: &bindings_raw::CreateCastStmt) -> protobuf::CreateCastStmt { + protobuf::CreateCastStmt { + sourcetype: if ccs.sourcetype.is_null() { None } else { Some(convert_type_name(&*ccs.sourcetype)) }, + targettype: if ccs.targettype.is_null() { None } else { Some(convert_type_name(&*ccs.targettype)) }, + func: if ccs.func.is_null() { None } else { Some(convert_object_with_args(&*ccs.func)) }, + context: ccs.context as i32 + 1, + inout: ccs.inout, + } +} + +unsafe fn convert_create_transform_stmt(cts: &bindings_raw::CreateTransformStmt) -> protobuf::CreateTransformStmt { + protobuf::CreateTransformStmt { + replace: cts.replace, + type_name: if cts.type_name.is_null() { None } else { Some(convert_type_name(&*cts.type_name)) }, + lang: convert_c_string(cts.lang), + fromsql: if cts.fromsql.is_null() { None } else { Some(convert_object_with_args(&*cts.fromsql)) }, + tosql: if cts.tosql.is_null() { None } else { Some(convert_object_with_args(&*cts.tosql)) }, + } +} + +unsafe fn convert_create_conversion_stmt(ccs: &bindings_raw::CreateConversionStmt) -> protobuf::CreateConversionStmt { + protobuf::CreateConversionStmt { + conversion_name: convert_list_to_nodes(ccs.conversion_name), + for_encoding_name: convert_c_string(ccs.for_encoding_name), + to_encoding_name: convert_c_string(ccs.to_encoding_name), + func_name: convert_list_to_nodes(ccs.func_name), + def: ccs.def, + } +} + +unsafe fn convert_alter_ts_dictionary_stmt(atds: &bindings_raw::AlterTSDictionaryStmt) -> protobuf::AlterTsDictionaryStmt { + protobuf::AlterTsDictionaryStmt { dictname: convert_list_to_nodes(atds.dictname), options: convert_list_to_nodes(atds.options) } +} + +unsafe fn convert_alter_ts_configuration_stmt(atcs: &bindings_raw::AlterTSConfigurationStmt) -> protobuf::AlterTsConfigurationStmt { + protobuf::AlterTsConfigurationStmt { + kind: atcs.kind as i32 + 1, + cfgname: convert_list_to_nodes(atcs.cfgname), + tokentype: convert_list_to_nodes(atcs.tokentype), + dicts: convert_list_to_nodes(atcs.dicts), + r#override: atcs.override_, + replace: atcs.replace, + missing_ok: atcs.missing_ok, + } +} + +unsafe fn convert_createdb_stmt(cds: &bindings_raw::CreatedbStmt) -> protobuf::CreatedbStmt { + protobuf::CreatedbStmt { dbname: convert_c_string(cds.dbname), options: convert_list_to_nodes(cds.options) } +} + +unsafe fn convert_dropdb_stmt(dds: &bindings_raw::DropdbStmt) -> protobuf::DropdbStmt { + protobuf::DropdbStmt { dbname: convert_c_string(dds.dbname), missing_ok: dds.missing_ok, options: convert_list_to_nodes(dds.options) } +} + +unsafe fn convert_alter_database_stmt(ads: &bindings_raw::AlterDatabaseStmt) -> protobuf::AlterDatabaseStmt { + protobuf::AlterDatabaseStmt { dbname: convert_c_string(ads.dbname), options: convert_list_to_nodes(ads.options) } +} + +unsafe fn convert_alter_database_set_stmt(adss: &bindings_raw::AlterDatabaseSetStmt) -> protobuf::AlterDatabaseSetStmt { + protobuf::AlterDatabaseSetStmt { dbname: convert_c_string(adss.dbname), setstmt: convert_variable_set_stmt_opt(adss.setstmt) } +} + +unsafe fn convert_alter_database_refresh_coll_stmt(adrcs: &bindings_raw::AlterDatabaseRefreshCollStmt) -> protobuf::AlterDatabaseRefreshCollStmt { + protobuf::AlterDatabaseRefreshCollStmt { dbname: convert_c_string(adrcs.dbname) } +} + +unsafe fn convert_alter_system_stmt(ass: &bindings_raw::AlterSystemStmt) -> protobuf::AlterSystemStmt { + protobuf::AlterSystemStmt { setstmt: convert_variable_set_stmt_opt(ass.setstmt) } +} + +unsafe fn convert_cluster_stmt(cs: &bindings_raw::ClusterStmt) -> protobuf::ClusterStmt { + protobuf::ClusterStmt { + relation: if cs.relation.is_null() { None } else { Some(convert_range_var(&*cs.relation)) }, + indexname: convert_c_string(cs.indexname), + params: convert_list_to_nodes(cs.params), + } +} + +unsafe fn convert_reindex_stmt(rs: &bindings_raw::ReindexStmt) -> protobuf::ReindexStmt { + protobuf::ReindexStmt { + kind: rs.kind as i32 + 1, + relation: if rs.relation.is_null() { None } else { Some(convert_range_var(&*rs.relation)) }, + name: convert_c_string(rs.name), + params: convert_list_to_nodes(rs.params), + } +} + +unsafe fn convert_constraints_set_stmt(css: &bindings_raw::ConstraintsSetStmt) -> protobuf::ConstraintsSetStmt { + protobuf::ConstraintsSetStmt { constraints: convert_list_to_nodes(css.constraints), deferred: css.deferred } +} + +unsafe fn convert_load_stmt(ls: &bindings_raw::LoadStmt) -> protobuf::LoadStmt { + protobuf::LoadStmt { filename: convert_c_string(ls.filename) } +} + +unsafe fn convert_drop_owned_stmt(dos: &bindings_raw::DropOwnedStmt) -> protobuf::DropOwnedStmt { + protobuf::DropOwnedStmt { roles: convert_list_to_nodes(dos.roles), behavior: dos.behavior as i32 + 1 } +} + +unsafe fn convert_reassign_owned_stmt(ros: &bindings_raw::ReassignOwnedStmt) -> protobuf::ReassignOwnedStmt { + protobuf::ReassignOwnedStmt { + roles: convert_list_to_nodes(ros.roles), + newrole: if ros.newrole.is_null() { None } else { Some(convert_role_spec(&*ros.newrole)) }, + } +} + +unsafe fn convert_drop_subscription_stmt(dss: &bindings_raw::DropSubscriptionStmt) -> protobuf::DropSubscriptionStmt { + protobuf::DropSubscriptionStmt { subname: convert_c_string(dss.subname), missing_ok: dss.missing_ok, behavior: dss.behavior as i32 + 1 } +} + +unsafe fn convert_table_func(tf: &bindings_raw::TableFunc) -> protobuf::TableFunc { + protobuf::TableFunc { + functype: tf.functype as i32, + ns_uris: convert_list_to_nodes(tf.ns_uris), + ns_names: convert_list_to_nodes(tf.ns_names), + docexpr: convert_node_boxed(tf.docexpr), + rowexpr: convert_node_boxed(tf.rowexpr), + colnames: convert_list_to_nodes(tf.colnames), + coltypes: convert_list_to_nodes(tf.coltypes), + coltypmods: convert_list_to_nodes(tf.coltypmods), + colcollations: convert_list_to_nodes(tf.colcollations), + colexprs: convert_list_to_nodes(tf.colexprs), + coldefexprs: convert_list_to_nodes(tf.coldefexprs), + colvalexprs: convert_list_to_nodes(tf.colvalexprs), + passingvalexprs: convert_list_to_nodes(tf.passingvalexprs), + notnulls: vec![], // Bitmapset conversion not yet supported + plan: convert_node_boxed(tf.plan), + ordinalitycol: tf.ordinalitycol, + location: tf.location, + } +} + +unsafe fn convert_into_clause_node(ic: &bindings_raw::IntoClause) -> protobuf::IntoClause { + protobuf::IntoClause { + rel: if ic.rel.is_null() { None } else { Some(convert_range_var(&*ic.rel)) }, + col_names: convert_list_to_nodes(ic.colNames), + access_method: convert_c_string(ic.accessMethod), + options: convert_list_to_nodes(ic.options), + on_commit: ic.onCommit as i32 + 1, + table_space_name: convert_c_string(ic.tableSpaceName), + view_query: convert_node_boxed(ic.viewQuery), + skip_data: ic.skipData, + } +} + +unsafe fn convert_table_like_clause(tlc: &bindings_raw::TableLikeClause) -> protobuf::TableLikeClause { + protobuf::TableLikeClause { + relation: if tlc.relation.is_null() { None } else { Some(convert_range_var(&*tlc.relation)) }, + options: tlc.options, + relation_oid: tlc.relationOid, + } +} + +unsafe fn convert_range_table_func(rtf: &bindings_raw::RangeTableFunc) -> protobuf::RangeTableFunc { + protobuf::RangeTableFunc { + lateral: rtf.lateral, + docexpr: convert_node_boxed(rtf.docexpr), + rowexpr: convert_node_boxed(rtf.rowexpr), + namespaces: convert_list_to_nodes(rtf.namespaces), + columns: convert_list_to_nodes(rtf.columns), + alias: if rtf.alias.is_null() { None } else { Some(convert_alias(&*rtf.alias)) }, + location: rtf.location, + } +} + +unsafe fn convert_range_table_func_col(rtfc: &bindings_raw::RangeTableFuncCol) -> protobuf::RangeTableFuncCol { + protobuf::RangeTableFuncCol { + colname: convert_c_string(rtfc.colname), + type_name: if rtfc.typeName.is_null() { None } else { Some(convert_type_name(&*rtfc.typeName)) }, + for_ordinality: rtfc.for_ordinality, + is_not_null: rtfc.is_not_null, + colexpr: convert_node_boxed(rtfc.colexpr), + coldefexpr: convert_node_boxed(rtfc.coldefexpr), + location: rtfc.location, + } +} + +unsafe fn convert_range_table_sample(rts: &bindings_raw::RangeTableSample) -> protobuf::RangeTableSample { + protobuf::RangeTableSample { + relation: convert_node_boxed(rts.relation), + method: convert_list_to_nodes(rts.method), + args: convert_list_to_nodes(rts.args), + repeatable: convert_node_boxed(rts.repeatable), + location: rts.location, + } +} + +unsafe fn convert_partition_cmd(pc: &bindings_raw::PartitionCmd) -> protobuf::PartitionCmd { + protobuf::PartitionCmd { + name: if pc.name.is_null() { None } else { Some(convert_range_var(&*pc.name)) }, + bound: convert_partition_bound_spec_opt(pc.bound), + concurrent: pc.concurrent, + } +} + +unsafe fn convert_on_conflict_clause_node(occ: &bindings_raw::OnConflictClause) -> protobuf::OnConflictClause { + protobuf::OnConflictClause { + action: occ.action as i32 + 1, + infer: convert_infer_clause_opt(occ.infer), + target_list: convert_list_to_nodes(occ.targetList), + where_clause: convert_node_boxed(occ.whereClause), + location: occ.location, + } +} + +unsafe fn convert_trigger_transition(tt: &bindings_raw::TriggerTransition) -> protobuf::TriggerTransition { + protobuf::TriggerTransition { name: convert_c_string(tt.name), is_new: tt.isNew, is_table: tt.isTable } +} + +unsafe fn convert_create_stats_stmt(css: &bindings_raw::CreateStatsStmt) -> protobuf::CreateStatsStmt { + protobuf::CreateStatsStmt { + defnames: convert_list_to_nodes(css.defnames), + stat_types: convert_list_to_nodes(css.stat_types), + exprs: convert_list_to_nodes(css.exprs), + relations: convert_list_to_nodes(css.relations), + stxcomment: convert_c_string(css.stxcomment), + transformed: css.transformed, + if_not_exists: css.if_not_exists, + } +} + +unsafe fn convert_alter_stats_stmt(ass: &bindings_raw::AlterStatsStmt) -> protobuf::AlterStatsStmt { + protobuf::AlterStatsStmt { + defnames: convert_list_to_nodes(ass.defnames), + stxstattarget: convert_node_boxed(ass.stxstattarget), + missing_ok: ass.missing_ok, + } +} + +unsafe fn convert_stats_elem(se: &bindings_raw::StatsElem) -> protobuf::StatsElem { + protobuf::StatsElem { name: convert_c_string(se.name), expr: convert_node_boxed(se.expr) } +} + +unsafe fn convert_sql_value_function(svf: &bindings_raw::SQLValueFunction) -> protobuf::SqlValueFunction { + let xpr_ptr = &svf.xpr as *const bindings_raw::Expr as *mut bindings_raw::Node; + protobuf::SqlValueFunction { + xpr: convert_node_boxed(xpr_ptr), + op: svf.op as i32 + 1, + r#type: svf.type_, + typmod: svf.typmod, + location: svf.location, + } +} + +unsafe fn convert_xml_expr(xe: &bindings_raw::XmlExpr) -> protobuf::XmlExpr { + let xpr_ptr = &xe.xpr as *const bindings_raw::Expr as *mut bindings_raw::Node; + protobuf::XmlExpr { + xpr: convert_node_boxed(xpr_ptr), + op: xe.op as i32 + 1, + name: convert_c_string(xe.name), + named_args: convert_list_to_nodes(xe.named_args), + arg_names: convert_list_to_nodes(xe.arg_names), + args: convert_list_to_nodes(xe.args), + xmloption: xe.xmloption as i32 + 1, + indent: xe.indent, + r#type: xe.type_, + typmod: xe.typmod, + location: xe.location, + } +} + +unsafe fn convert_xml_serialize(xs: &bindings_raw::XmlSerialize) -> protobuf::XmlSerialize { + protobuf::XmlSerialize { + xmloption: xs.xmloption as i32 + 1, + expr: convert_node_boxed(xs.expr), + type_name: if xs.typeName.is_null() { None } else { Some(convert_type_name(&*xs.typeName)) }, + indent: xs.indent, + location: xs.location, + } +} + +unsafe fn convert_named_arg_expr(nae: &bindings_raw::NamedArgExpr) -> protobuf::NamedArgExpr { + let xpr_ptr = &nae.xpr as *const bindings_raw::Expr as *mut bindings_raw::Node; + protobuf::NamedArgExpr { + xpr: convert_node_boxed(xpr_ptr), + arg: convert_node_boxed(nae.arg as *mut bindings_raw::Node), + name: convert_c_string(nae.name), + argnumber: nae.argnumber, + location: nae.location, + } +} + +// ============================================================================ +// JSON Node Conversions +// ============================================================================ + +unsafe fn convert_json_format(jf: &bindings_raw::JsonFormat) -> protobuf::JsonFormat { + protobuf::JsonFormat { format_type: jf.format_type as i32 + 1, encoding: jf.encoding as i32 + 1, location: jf.location } +} + +unsafe fn convert_json_returning(jr: &bindings_raw::JsonReturning) -> protobuf::JsonReturning { + protobuf::JsonReturning { + format: if jr.format.is_null() { None } else { Some(convert_json_format(&*jr.format)) }, + typid: jr.typid, + typmod: jr.typmod, + } +} + +unsafe fn convert_json_value_expr(jve: &bindings_raw::JsonValueExpr) -> protobuf::JsonValueExpr { + protobuf::JsonValueExpr { + raw_expr: convert_node_boxed(jve.raw_expr as *mut bindings_raw::Node), + formatted_expr: convert_node_boxed(jve.formatted_expr as *mut bindings_raw::Node), + format: if jve.format.is_null() { None } else { Some(convert_json_format(&*jve.format)) }, + } +} + +unsafe fn convert_json_constructor_expr(jce: &bindings_raw::JsonConstructorExpr) -> protobuf::JsonConstructorExpr { + let xpr_ptr = &jce.xpr as *const bindings_raw::Expr as *mut bindings_raw::Node; + protobuf::JsonConstructorExpr { + xpr: convert_node_boxed(xpr_ptr), + r#type: jce.type_ as i32 + 1, + args: convert_list_to_nodes(jce.args), + func: convert_node_boxed(jce.func as *mut bindings_raw::Node), + coercion: convert_node_boxed(jce.coercion as *mut bindings_raw::Node), + returning: if jce.returning.is_null() { None } else { Some(convert_json_returning(&*jce.returning)) }, + absent_on_null: jce.absent_on_null, + unique: jce.unique, + location: jce.location, + } +} + +unsafe fn convert_json_is_predicate(jip: &bindings_raw::JsonIsPredicate) -> protobuf::JsonIsPredicate { + protobuf::JsonIsPredicate { + expr: convert_node_boxed(jip.expr), + format: if jip.format.is_null() { None } else { Some(convert_json_format(&*jip.format)) }, + item_type: jip.item_type as i32 + 1, + unique_keys: jip.unique_keys, + location: jip.location, + } +} + +unsafe fn convert_json_behavior(jb: &bindings_raw::JsonBehavior) -> protobuf::JsonBehavior { + protobuf::JsonBehavior { btype: jb.btype as i32 + 1, expr: convert_node_boxed(jb.expr), coerce: jb.coerce, location: jb.location } +} + +unsafe fn convert_json_expr(je: &bindings_raw::JsonExpr) -> protobuf::JsonExpr { + let xpr_ptr = &je.xpr as *const bindings_raw::Expr as *mut bindings_raw::Node; + protobuf::JsonExpr { + xpr: convert_node_boxed(xpr_ptr), + op: je.op as i32 + 1, + column_name: convert_c_string(je.column_name), + formatted_expr: convert_node_boxed(je.formatted_expr as *mut bindings_raw::Node), + format: if je.format.is_null() { None } else { Some(convert_json_format(&*je.format)) }, + path_spec: convert_node_boxed(je.path_spec), + returning: if je.returning.is_null() { None } else { Some(convert_json_returning(&*je.returning)) }, + passing_names: convert_list_to_nodes(je.passing_names), + passing_values: convert_list_to_nodes(je.passing_values), + on_empty: if je.on_empty.is_null() { None } else { Some(Box::new(convert_json_behavior(&*je.on_empty))) }, + on_error: if je.on_error.is_null() { None } else { Some(Box::new(convert_json_behavior(&*je.on_error))) }, + use_io_coercion: je.use_io_coercion, + use_json_coercion: je.use_json_coercion, + wrapper: je.wrapper as i32 + 1, + omit_quotes: je.omit_quotes, + collation: je.collation, + location: je.location, + } +} + +unsafe fn convert_json_table_path(jtp: &bindings_raw::JsonTablePath) -> protobuf::JsonTablePath { + // In raw parse tree, value is not populated - only name + protobuf::JsonTablePath { name: convert_c_string(jtp.name) } +} + +unsafe fn convert_json_table_path_scan(jtps: &bindings_raw::JsonTablePathScan) -> protobuf::JsonTablePathScan { + protobuf::JsonTablePathScan { + plan: convert_node_boxed(&jtps.plan as *const bindings_raw::JsonTablePlan as *mut bindings_raw::Node), + path: if jtps.path.is_null() { None } else { Some(convert_json_table_path(&*jtps.path)) }, + error_on_error: jtps.errorOnError, + child: convert_node_boxed(jtps.child as *mut bindings_raw::Node), + col_min: jtps.colMin, + col_max: jtps.colMax, + } +} + +unsafe fn convert_json_table_sibling_join(jtsj: &bindings_raw::JsonTableSiblingJoin) -> protobuf::JsonTableSiblingJoin { + protobuf::JsonTableSiblingJoin { + plan: convert_node_boxed(&jtsj.plan as *const bindings_raw::JsonTablePlan as *mut bindings_raw::Node), + lplan: convert_node_boxed(jtsj.lplan as *mut bindings_raw::Node), + rplan: convert_node_boxed(jtsj.rplan as *mut bindings_raw::Node), + } +} + +unsafe fn convert_json_output(jo: &bindings_raw::JsonOutput) -> protobuf::JsonOutput { + protobuf::JsonOutput { + type_name: if jo.typeName.is_null() { None } else { Some(convert_type_name(&*jo.typeName)) }, + returning: if jo.returning.is_null() { None } else { Some(convert_json_returning(&*jo.returning)) }, + } +} + +unsafe fn convert_json_argument(ja: &bindings_raw::JsonArgument) -> protobuf::JsonArgument { + protobuf::JsonArgument { + val: if ja.val.is_null() { None } else { Some(Box::new(convert_json_value_expr(&*ja.val))) }, + name: convert_c_string(ja.name), + } +} + +unsafe fn convert_json_func_expr(jfe: &bindings_raw::JsonFuncExpr) -> protobuf::JsonFuncExpr { + protobuf::JsonFuncExpr { + op: jfe.op as i32 + 1, + column_name: convert_c_string(jfe.column_name), + context_item: if jfe.context_item.is_null() { None } else { Some(Box::new(convert_json_value_expr(&*jfe.context_item))) }, + pathspec: convert_node_boxed(jfe.pathspec), + passing: convert_list_to_nodes(jfe.passing), + output: if jfe.output.is_null() { None } else { Some(convert_json_output(&*jfe.output)) }, + on_empty: if jfe.on_empty.is_null() { None } else { Some(Box::new(convert_json_behavior(&*jfe.on_empty))) }, + on_error: if jfe.on_error.is_null() { None } else { Some(Box::new(convert_json_behavior(&*jfe.on_error))) }, + wrapper: jfe.wrapper as i32 + 1, + quotes: jfe.quotes as i32 + 1, + location: jfe.location, + } +} + +unsafe fn convert_json_table_path_spec(jtps: &bindings_raw::JsonTablePathSpec) -> protobuf::JsonTablePathSpec { + protobuf::JsonTablePathSpec { + string: convert_node_boxed(jtps.string), + name: convert_c_string(jtps.name), + name_location: jtps.name_location, + location: jtps.location, + } +} + +unsafe fn convert_json_table(jt: &bindings_raw::JsonTable) -> protobuf::JsonTable { + protobuf::JsonTable { + context_item: if jt.context_item.is_null() { None } else { Some(Box::new(convert_json_value_expr(&*jt.context_item))) }, + pathspec: if jt.pathspec.is_null() { None } else { Some(Box::new(convert_json_table_path_spec(&*jt.pathspec))) }, + passing: convert_list_to_nodes(jt.passing), + columns: convert_list_to_nodes(jt.columns), + on_error: if jt.on_error.is_null() { None } else { Some(Box::new(convert_json_behavior(&*jt.on_error))) }, + alias: if jt.alias.is_null() { None } else { Some(convert_alias(&*jt.alias)) }, + lateral: jt.lateral, + location: jt.location, + } +} + +unsafe fn convert_json_table_column(jtc: &bindings_raw::JsonTableColumn) -> protobuf::JsonTableColumn { + protobuf::JsonTableColumn { + coltype: jtc.coltype as i32 + 1, + name: convert_c_string(jtc.name), + type_name: if jtc.typeName.is_null() { None } else { Some(convert_type_name(&*jtc.typeName)) }, + pathspec: if jtc.pathspec.is_null() { None } else { Some(Box::new(convert_json_table_path_spec(&*jtc.pathspec))) }, + format: if jtc.format.is_null() { None } else { Some(convert_json_format(&*jtc.format)) }, + wrapper: jtc.wrapper as i32 + 1, + quotes: jtc.quotes as i32 + 1, + columns: convert_list_to_nodes(jtc.columns), + on_empty: if jtc.on_empty.is_null() { None } else { Some(Box::new(convert_json_behavior(&*jtc.on_empty))) }, + on_error: if jtc.on_error.is_null() { None } else { Some(Box::new(convert_json_behavior(&*jtc.on_error))) }, + location: jtc.location, + } +} + +unsafe fn convert_json_key_value(jkv: &bindings_raw::JsonKeyValue) -> protobuf::JsonKeyValue { + protobuf::JsonKeyValue { + key: convert_node_boxed(jkv.key as *mut bindings_raw::Node), + value: if jkv.value.is_null() { None } else { Some(Box::new(convert_json_value_expr(&*jkv.value))) }, + } +} + +unsafe fn convert_json_parse_expr(jpe: &bindings_raw::JsonParseExpr) -> protobuf::JsonParseExpr { + protobuf::JsonParseExpr { + expr: if jpe.expr.is_null() { None } else { Some(Box::new(convert_json_value_expr(&*jpe.expr))) }, + output: if jpe.output.is_null() { None } else { Some(convert_json_output(&*jpe.output)) }, + unique_keys: jpe.unique_keys, + location: jpe.location, + } +} + +unsafe fn convert_json_scalar_expr(jse: &bindings_raw::JsonScalarExpr) -> protobuf::JsonScalarExpr { + protobuf::JsonScalarExpr { + expr: convert_node_boxed(jse.expr as *mut bindings_raw::Node), + output: if jse.output.is_null() { None } else { Some(convert_json_output(&*jse.output)) }, + location: jse.location, + } +} + +unsafe fn convert_json_serialize_expr(jse: &bindings_raw::JsonSerializeExpr) -> protobuf::JsonSerializeExpr { + protobuf::JsonSerializeExpr { + expr: if jse.expr.is_null() { None } else { Some(Box::new(convert_json_value_expr(&*jse.expr))) }, + output: if jse.output.is_null() { None } else { Some(convert_json_output(&*jse.output)) }, + location: jse.location, + } +} + +unsafe fn convert_json_object_constructor(joc: &bindings_raw::JsonObjectConstructor) -> protobuf::JsonObjectConstructor { + protobuf::JsonObjectConstructor { + exprs: convert_list_to_nodes(joc.exprs), + output: if joc.output.is_null() { None } else { Some(convert_json_output(&*joc.output)) }, + absent_on_null: joc.absent_on_null, + unique: joc.unique, + location: joc.location, + } +} + +unsafe fn convert_json_array_constructor(jac: &bindings_raw::JsonArrayConstructor) -> protobuf::JsonArrayConstructor { + protobuf::JsonArrayConstructor { + exprs: convert_list_to_nodes(jac.exprs), + output: if jac.output.is_null() { None } else { Some(convert_json_output(&*jac.output)) }, + absent_on_null: jac.absent_on_null, + location: jac.location, + } +} + +unsafe fn convert_json_array_query_constructor(jaqc: &bindings_raw::JsonArrayQueryConstructor) -> protobuf::JsonArrayQueryConstructor { + protobuf::JsonArrayQueryConstructor { + query: convert_node_boxed(jaqc.query), + output: if jaqc.output.is_null() { None } else { Some(convert_json_output(&*jaqc.output)) }, + format: if jaqc.format.is_null() { None } else { Some(convert_json_format(&*jaqc.format)) }, + absent_on_null: jaqc.absent_on_null, + location: jaqc.location, + } +} + +unsafe fn convert_json_agg_constructor(jac: &bindings_raw::JsonAggConstructor) -> protobuf::JsonAggConstructor { + protobuf::JsonAggConstructor { + output: if jac.output.is_null() { None } else { Some(convert_json_output(&*jac.output)) }, + agg_filter: convert_node_boxed(jac.agg_filter), + agg_order: convert_list_to_nodes(jac.agg_order), + over: if jac.over.is_null() { None } else { Some(Box::new(convert_window_def(&*jac.over))) }, + location: jac.location, + } +} + +unsafe fn convert_json_object_agg(joa: &bindings_raw::JsonObjectAgg) -> protobuf::JsonObjectAgg { + protobuf::JsonObjectAgg { + constructor: if joa.constructor.is_null() { None } else { Some(Box::new(convert_json_agg_constructor(&*joa.constructor))) }, + arg: if joa.arg.is_null() { None } else { Some(Box::new(convert_json_key_value(&*joa.arg))) }, + absent_on_null: joa.absent_on_null, + unique: joa.unique, + } +} + +unsafe fn convert_json_array_agg(jaa: &bindings_raw::JsonArrayAgg) -> protobuf::JsonArrayAgg { + protobuf::JsonArrayAgg { + constructor: if jaa.constructor.is_null() { None } else { Some(Box::new(convert_json_agg_constructor(&*jaa.constructor))) }, + arg: if jaa.arg.is_null() { None } else { Some(Box::new(convert_json_value_expr(&*jaa.arg))) }, + absent_on_null: jaa.absent_on_null, + } +} + +// ============================================================================ +// Additional Helper Functions +// ============================================================================ + +unsafe fn convert_variable_set_stmt_opt(stmt: *mut bindings_raw::VariableSetStmt) -> Option { + if stmt.is_null() { + None + } else { + Some(convert_variable_set_stmt(&*stmt)) + } +} + +unsafe fn convert_infer_clause_opt(ic: *mut bindings_raw::InferClause) -> Option> { + if ic.is_null() { + None + } else { + let ic_ref = &*ic; + Some(Box::new(protobuf::InferClause { + index_elems: convert_list_to_nodes(ic_ref.indexElems), + where_clause: convert_node_boxed(ic_ref.whereClause), + conname: convert_c_string(ic_ref.conname), + location: ic_ref.location, + })) + } +} From f87319011074cc35275d1fc739fa3acb2dd201b7 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 31 Dec 2025 19:12:58 -0800 Subject: [PATCH 15/17] reduce stack on windows --- Cargo.toml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index 65f96e7..1f535d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,3 +42,9 @@ regex = "1.6.0" [[bench]] name = "parse_vs_summary" harness = false + +# Optimize build scripts even in debug mode to reduce stack usage. +# This is needed because bindgen uses deep recursion when processing +# PostgreSQL header files, which can overflow the stack on Windows. +[profile.dev.build-override] +opt-level = 2 From 59a0f77b2e5e5dba78bbe0e94defcdad53992b43 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 31 Dec 2025 19:25:20 -0800 Subject: [PATCH 16/17] save --- src/raw_deparse.rs | 184 ++++++++++++++++++++++----------------------- 1 file changed, 92 insertions(+), 92 deletions(-) diff --git a/src/raw_deparse.rs b/src/raw_deparse.rs index ef76d08..cf1b4ac 100644 --- a/src/raw_deparse.rs +++ b/src/raw_deparse.rs @@ -50,18 +50,18 @@ pub fn deparse_raw(protobuf: &protobuf::ParseResult) -> Result { } /// Allocates a C node of the given type. -unsafe fn alloc_node(tag: u32) -> *mut T { +unsafe fn alloc_node(tag: bindings_raw::NodeTag) -> *mut T { bindings_raw::pg_query_alloc_node(std::mem::size_of::(), tag as i32) as *mut T } /// Converts a protobuf enum value to a C enum value. /// Protobuf enums have an extra "Undefined = 0" value, so we subtract 1. /// If the value is 0 (Undefined), we return 0 (treating it as the first C enum value). -fn proto_enum_to_c(value: i32) -> u32 { +fn proto_enum_to_c(value: i32) -> i32 { if value <= 0 { 0 } else { - (value - 1) as u32 + value - 1 } } @@ -460,10 +460,10 @@ unsafe fn write_select_stmt(stmt: &protobuf::SelectStmt) -> *mut bindings_raw::S (*node).sortClause = write_node_list(&stmt.sort_clause); (*node).limitOffset = write_node_boxed(&stmt.limit_offset); (*node).limitCount = write_node_boxed(&stmt.limit_count); - (*node).limitOption = proto_enum_to_c(stmt.limit_option); + (*node).limitOption = proto_enum_to_c(stmt.limit_option) as _; (*node).lockingClause = write_node_list(&stmt.locking_clause); (*node).withClause = write_with_clause_ref(&stmt.with_clause); - (*node).op = proto_enum_to_c(stmt.op); + (*node).op = proto_enum_to_c(stmt.op) as _; (*node).all = stmt.all; (*node).larg = write_select_stmt_opt(&stmt.larg); (*node).rarg = write_select_stmt_opt(&stmt.rarg); @@ -490,7 +490,7 @@ unsafe fn write_into_clause(ic: &protobuf::IntoClause) -> *mut bindings_raw::Int (*node).colNames = write_node_list(&ic.col_names); (*node).accessMethod = pstrdup(&ic.access_method); (*node).options = write_node_list(&ic.options); - (*node).onCommit = proto_enum_to_c(ic.on_commit); + (*node).onCommit = proto_enum_to_c(ic.on_commit) as _; (*node).tableSpaceName = pstrdup(&ic.table_space_name); (*node).viewQuery = write_node_boxed(&ic.view_query); (*node).skipData = ic.skip_data; @@ -505,7 +505,7 @@ unsafe fn write_insert_stmt(stmt: &protobuf::InsertStmt) -> *mut bindings_raw::I (*node).onConflictClause = write_on_conflict_clause_opt(&stmt.on_conflict_clause); (*node).returningList = write_node_list(&stmt.returning_list); (*node).withClause = write_with_clause_ref(&stmt.with_clause); - (*node).override_ = proto_enum_to_c(stmt.r#override); + (*node).override_ = proto_enum_to_c(stmt.r#override) as _; node } @@ -518,7 +518,7 @@ unsafe fn write_on_conflict_clause_opt(oc: &Option *mut bindings_raw::OnConflictClause { let node = alloc_node::(bindings_raw::NodeTag_T_OnConflictClause); - (*node).action = proto_enum_to_c(oc.action); + (*node).action = proto_enum_to_c(oc.action) as _; (*node).infer = write_infer_clause_opt(&oc.infer); (*node).targetList = write_node_list(&oc.target_list); (*node).whereClause = write_node_boxed(&oc.where_clause); @@ -644,7 +644,7 @@ unsafe fn write_a_const(ac: &protobuf::AConst) -> *mut bindings_raw::A_Const { unsafe fn write_a_expr(expr: &protobuf::AExpr) -> *mut bindings_raw::A_Expr { let node = alloc_node::(bindings_raw::NodeTag_T_A_Expr); - (*node).kind = proto_enum_to_c(expr.kind); + (*node).kind = proto_enum_to_c(expr.kind) as _; (*node).name = write_node_list(&expr.name); (*node).lexpr = write_node_boxed(&expr.lexpr); (*node).rexpr = write_node_boxed(&expr.rexpr); @@ -663,7 +663,7 @@ unsafe fn write_func_call(fc: &protobuf::FuncCall) -> *mut bindings_raw::FuncCal (*node).agg_star = fc.agg_star; (*node).agg_distinct = fc.agg_distinct; (*node).func_variadic = fc.func_variadic; - (*node).funcformat = proto_enum_to_c(fc.funcformat); + (*node).funcformat = proto_enum_to_c(fc.funcformat) as _; (*node).location = fc.location; node } @@ -728,7 +728,7 @@ unsafe fn write_a_star() -> *mut bindings_raw::A_Star { unsafe fn write_join_expr(je: &protobuf::JoinExpr) -> *mut bindings_raw::JoinExpr { let node = alloc_node::(bindings_raw::NodeTag_T_JoinExpr); - (*node).jointype = proto_enum_to_c(je.jointype); + (*node).jointype = proto_enum_to_c(je.jointype) as _; (*node).isNatural = je.is_natural; (*node).larg = write_node_boxed(&je.larg); (*node).rarg = write_node_boxed(&je.rarg); @@ -743,8 +743,8 @@ unsafe fn write_join_expr(je: &protobuf::JoinExpr) -> *mut bindings_raw::JoinExp unsafe fn write_sort_by(sb: &protobuf::SortBy) -> *mut bindings_raw::SortBy { let node = alloc_node::(bindings_raw::NodeTag_T_SortBy); (*node).node = write_node_boxed(&sb.node); - (*node).sortby_dir = proto_enum_to_c(sb.sortby_dir); - (*node).sortby_nulls = proto_enum_to_c(sb.sortby_nulls); + (*node).sortby_dir = proto_enum_to_c(sb.sortby_dir) as _; + (*node).sortby_nulls = proto_enum_to_c(sb.sortby_nulls) as _; (*node).useOp = write_node_list(&sb.use_op); (*node).location = sb.location; node @@ -788,7 +788,7 @@ unsafe fn write_param_ref(pr: &protobuf::ParamRef) -> *mut bindings_raw::ParamRe unsafe fn write_null_test(nt: &protobuf::NullTest) -> *mut bindings_raw::NullTest { let node = alloc_node::(bindings_raw::NodeTag_T_NullTest); (*node).arg = write_node_boxed(&nt.arg) as *mut bindings_raw::Expr; - (*node).nulltesttype = proto_enum_to_c(nt.nulltesttype); + (*node).nulltesttype = proto_enum_to_c(nt.nulltesttype) as _; (*node).argisrow = nt.argisrow; (*node).location = nt.location; node @@ -796,7 +796,7 @@ unsafe fn write_null_test(nt: &protobuf::NullTest) -> *mut bindings_raw::NullTes unsafe fn write_bool_expr(be: &protobuf::BoolExpr) -> *mut bindings_raw::BoolExpr { let node = alloc_node::(bindings_raw::NodeTag_T_BoolExpr); - (*node).boolop = proto_enum_to_c(be.boolop); + (*node).boolop = proto_enum_to_c(be.boolop) as _; (*node).args = write_node_list(&be.args); (*node).location = be.location; node @@ -804,7 +804,7 @@ unsafe fn write_bool_expr(be: &protobuf::BoolExpr) -> *mut bindings_raw::BoolExp unsafe fn write_sub_link(sl: &protobuf::SubLink) -> *mut bindings_raw::SubLink { let node = alloc_node::(bindings_raw::NodeTag_T_SubLink); - (*node).subLinkType = proto_enum_to_c(sl.sub_link_type); + (*node).subLinkType = proto_enum_to_c(sl.sub_link_type) as _; (*node).subLinkId = sl.sub_link_id; (*node).testexpr = write_node_boxed(&sl.testexpr); (*node).operName = write_node_list(&sl.oper_name); @@ -825,7 +825,7 @@ unsafe fn write_common_table_expr(cte: &protobuf::CommonTableExpr) -> *mut bindi let node = alloc_node::(bindings_raw::NodeTag_T_CommonTableExpr); (*node).ctename = pstrdup(&cte.ctename); (*node).aliascolnames = write_node_list(&cte.aliascolnames); - (*node).ctematerialized = proto_enum_to_c(cte.ctematerialized); + (*node).ctematerialized = proto_enum_to_c(cte.ctematerialized) as _; (*node).ctequery = write_node_boxed(&cte.ctequery); (*node).search_clause = write_cte_search_clause_opt(&cte.search_clause); (*node).cycle_clause = write_cte_cycle_clause_opt(&cte.cycle_clause); @@ -884,7 +884,7 @@ unsafe fn write_with_clause_ref(wc: &Option) -> *mut bindi unsafe fn write_grouping_set(gs: &protobuf::GroupingSet) -> *mut bindings_raw::GroupingSet { let node = alloc_node::(bindings_raw::NodeTag_T_GroupingSet); - (*node).kind = proto_enum_to_c(gs.kind); + (*node).kind = proto_enum_to_c(gs.kind) as _; (*node).content = write_node_list(&gs.content); (*node).location = gs.location; node @@ -927,8 +927,8 @@ unsafe fn write_set_to_default() -> *mut bindings_raw::SetToDefault { unsafe fn write_locking_clause(lc: &protobuf::LockingClause) -> *mut bindings_raw::LockingClause { let node = alloc_node::(bindings_raw::NodeTag_T_LockingClause); (*node).lockedRels = write_node_list(&lc.locked_rels); - (*node).strength = proto_enum_to_c(lc.strength); - (*node).waitPolicy = proto_enum_to_c(lc.wait_policy); + (*node).strength = proto_enum_to_c(lc.strength) as _; + (*node).waitPolicy = proto_enum_to_c(lc.wait_policy) as _; node } @@ -996,8 +996,8 @@ unsafe fn write_index_elem(ie: &protobuf::IndexElem) -> *mut bindings_raw::Index (*node).collation = write_node_list(&ie.collation); (*node).opclass = write_node_list(&ie.opclass); (*node).opclassopts = write_node_list(&ie.opclassopts); - (*node).ordering = proto_enum_to_c(ie.ordering); - (*node).nulls_ordering = proto_enum_to_c(ie.nulls_ordering); + (*node).ordering = proto_enum_to_c(ie.ordering) as _; + (*node).nulls_ordering = proto_enum_to_c(ie.nulls_ordering) as _; node } @@ -1043,8 +1043,8 @@ unsafe fn write_values_lists(values: &[protobuf::Node]) -> *mut bindings_raw::Li unsafe fn write_drop_stmt(stmt: &protobuf::DropStmt) -> *mut bindings_raw::DropStmt { let node = alloc_node::(bindings_raw::NodeTag_T_DropStmt); (*node).objects = write_node_list(&stmt.objects); - (*node).removeType = proto_enum_to_c(stmt.remove_type); - (*node).behavior = proto_enum_to_c(stmt.behavior); + (*node).removeType = proto_enum_to_c(stmt.remove_type) as _; + (*node).behavior = proto_enum_to_c(stmt.behavior) as _; (*node).missing_ok = stmt.missing_ok; (*node).concurrent = stmt.concurrent; node @@ -1091,7 +1091,7 @@ unsafe fn write_truncate_stmt(stmt: &protobuf::TruncateStmt) -> *mut bindings_ra let node = alloc_node::(bindings_raw::NodeTag_T_TruncateStmt); (*node).relations = write_node_list(&stmt.relations); (*node).restart_seqs = stmt.restart_seqs; - (*node).behavior = proto_enum_to_c(stmt.behavior); + (*node).behavior = proto_enum_to_c(stmt.behavior) as _; node } @@ -1105,7 +1105,7 @@ unsafe fn write_create_stmt(stmt: &protobuf::CreateStmt) -> *mut bindings_raw::C (*node).ofTypename = write_type_name_ptr(&stmt.of_typename); (*node).constraints = write_node_list(&stmt.constraints); (*node).options = write_node_list(&stmt.options); - (*node).oncommit = proto_enum_to_c(stmt.oncommit); + (*node).oncommit = proto_enum_to_c(stmt.oncommit) as _; (*node).tablespacename = pstrdup(&stmt.tablespacename); (*node).accessMethod = pstrdup(&stmt.access_method); (*node).if_not_exists = stmt.if_not_exists; @@ -1123,19 +1123,19 @@ unsafe fn write_alter_table_stmt(stmt: &protobuf::AlterTableStmt) -> *mut bindin let node = alloc_node::(bindings_raw::NodeTag_T_AlterTableStmt); (*node).relation = write_range_var_ptr(&stmt.relation); (*node).cmds = write_node_list(&stmt.cmds); - (*node).objtype = proto_enum_to_c(stmt.objtype); + (*node).objtype = proto_enum_to_c(stmt.objtype) as _; (*node).missing_ok = stmt.missing_ok; node } unsafe fn write_alter_table_cmd(cmd: &protobuf::AlterTableCmd) -> *mut bindings_raw::AlterTableCmd { let node = alloc_node::(bindings_raw::NodeTag_T_AlterTableCmd); - (*node).subtype = proto_enum_to_c(cmd.subtype); + (*node).subtype = proto_enum_to_c(cmd.subtype) as _; (*node).name = pstrdup(&cmd.name); (*node).num = cmd.num as i16; (*node).newowner = std::ptr::null_mut(); // RoleSpec, complex (*node).def = write_node_boxed(&cmd.def); - (*node).behavior = proto_enum_to_c(cmd.behavior); + (*node).behavior = proto_enum_to_c(cmd.behavior) as _; (*node).missing_ok = cmd.missing_ok; (*node).recurse = cmd.recurse; node @@ -1166,7 +1166,7 @@ unsafe fn write_column_def(cd: &protobuf::ColumnDef) -> *mut bindings_raw::Colum unsafe fn write_constraint(c: &protobuf::Constraint) -> *mut bindings_raw::Constraint { let node = alloc_node::(bindings_raw::NodeTag_T_Constraint); - (*node).contype = proto_enum_to_c(c.contype); + (*node).contype = proto_enum_to_c(c.contype) as _; (*node).conname = pstrdup(&c.conname); (*node).deferrable = c.deferrable; (*node).initdeferred = c.initdeferred; @@ -1235,13 +1235,13 @@ unsafe fn write_view_stmt(stmt: &protobuf::ViewStmt) -> *mut bindings_raw::ViewS (*node).query = write_node_boxed(&stmt.query); (*node).replace = stmt.replace; (*node).options = write_node_list(&stmt.options); - (*node).withCheckOption = proto_enum_to_c(stmt.with_check_option); + (*node).withCheckOption = proto_enum_to_c(stmt.with_check_option) as _; node } unsafe fn write_transaction_stmt(stmt: &protobuf::TransactionStmt) -> *mut bindings_raw::TransactionStmt { let node = alloc_node::(bindings_raw::NodeTag_T_TransactionStmt); - (*node).kind = proto_enum_to_c(stmt.kind); + (*node).kind = proto_enum_to_c(stmt.kind) as _; (*node).options = write_node_list(&stmt.options); (*node).savepoint_name = pstrdup(&stmt.savepoint_name); (*node).gid = pstrdup(&stmt.gid); @@ -1274,7 +1274,7 @@ unsafe fn write_create_table_as_stmt(stmt: &protobuf::CreateTableAsStmt) -> *mut let node = alloc_node::(bindings_raw::NodeTag_T_CreateTableAsStmt); (*node).query = write_node_boxed(&stmt.query); (*node).into = if let Some(ref into) = stmt.into { write_into_clause(into) } else { std::ptr::null_mut() }; - (*node).objtype = proto_enum_to_c(stmt.objtype); + (*node).objtype = proto_enum_to_c(stmt.objtype) as _; (*node).is_select_into = stmt.is_select_into; (*node).if_not_exists = stmt.if_not_exists; node @@ -1323,7 +1323,7 @@ unsafe fn write_create_schema_stmt(stmt: &protobuf::CreateSchemaStmt) -> *mut bi unsafe fn write_variable_set_stmt(stmt: &protobuf::VariableSetStmt) -> *mut bindings_raw::VariableSetStmt { let node = alloc_node::(bindings_raw::NodeTag_T_VariableSetStmt); - (*node).kind = proto_enum_to_c(stmt.kind); + (*node).kind = proto_enum_to_c(stmt.kind) as _; (*node).name = pstrdup(&stmt.name); (*node).args = write_node_list(&stmt.args); (*node).is_local = stmt.is_local; @@ -1338,13 +1338,13 @@ unsafe fn write_variable_show_stmt(stmt: &protobuf::VariableShowStmt) -> *mut bi unsafe fn write_rename_stmt(stmt: &protobuf::RenameStmt) -> *mut bindings_raw::RenameStmt { let node = alloc_node::(bindings_raw::NodeTag_T_RenameStmt); - (*node).renameType = proto_enum_to_c(stmt.rename_type); - (*node).relationType = proto_enum_to_c(stmt.relation_type); + (*node).renameType = proto_enum_to_c(stmt.rename_type) as _; + (*node).relationType = proto_enum_to_c(stmt.relation_type) as _; (*node).relation = write_range_var_ptr(&stmt.relation); (*node).object = write_node_boxed(&stmt.object); (*node).subname = pstrdup(&stmt.subname); (*node).newname = pstrdup(&stmt.newname); - (*node).behavior = proto_enum_to_c(stmt.behavior); + (*node).behavior = proto_enum_to_c(stmt.behavior) as _; (*node).missing_ok = stmt.missing_ok; node } @@ -1352,20 +1352,20 @@ unsafe fn write_rename_stmt(stmt: &protobuf::RenameStmt) -> *mut bindings_raw::R unsafe fn write_grant_stmt(stmt: &protobuf::GrantStmt) -> *mut bindings_raw::GrantStmt { let node = alloc_node::(bindings_raw::NodeTag_T_GrantStmt); (*node).is_grant = stmt.is_grant; - (*node).targtype = proto_enum_to_c(stmt.targtype); - (*node).objtype = proto_enum_to_c(stmt.objtype); + (*node).targtype = proto_enum_to_c(stmt.targtype) as _; + (*node).objtype = proto_enum_to_c(stmt.objtype) as _; (*node).objects = write_node_list(&stmt.objects); (*node).privileges = write_node_list(&stmt.privileges); (*node).grantees = write_node_list(&stmt.grantees); (*node).grant_option = stmt.grant_option; (*node).grantor = std::ptr::null_mut(); // RoleSpec, complex - (*node).behavior = proto_enum_to_c(stmt.behavior); + (*node).behavior = proto_enum_to_c(stmt.behavior) as _; node } unsafe fn write_role_spec(rs: &protobuf::RoleSpec) -> *mut bindings_raw::RoleSpec { let node = alloc_node::(bindings_raw::NodeTag_T_RoleSpec); - (*node).roletype = proto_enum_to_c(rs.roletype); + (*node).roletype = proto_enum_to_c(rs.roletype) as _; (*node).rolename = pstrdup(&rs.rolename); (*node).location = rs.location; node @@ -1395,7 +1395,7 @@ unsafe fn write_def_elem(de: &protobuf::DefElem) -> *mut bindings_raw::DefElem { (*node).defnamespace = pstrdup(&de.defnamespace); (*node).defname = pstrdup(&de.defname); (*node).arg = write_node_boxed(&de.arg); - (*node).defaction = proto_enum_to_c(de.defaction); + (*node).defaction = proto_enum_to_c(de.defaction) as _; (*node).location = de.location; node } @@ -1405,7 +1405,7 @@ unsafe fn write_rule_stmt(stmt: &protobuf::RuleStmt) -> *mut bindings_raw::RuleS (*node).relation = write_range_var_ptr(&stmt.relation); (*node).rulename = pstrdup(&stmt.rulename); (*node).whereClause = write_node_boxed(&stmt.where_clause); - (*node).event = proto_enum_to_c(stmt.event); + (*node).event = proto_enum_to_c(stmt.event) as _; (*node).instead = stmt.instead; (*node).actions = write_node_list(&stmt.actions); (*node).replace = stmt.replace; @@ -1465,9 +1465,9 @@ unsafe fn write_merge_stmt(stmt: &protobuf::MergeStmt) -> *mut bindings_raw::Mer unsafe fn write_merge_when_clause(mwc: &protobuf::MergeWhenClause) -> *mut bindings_raw::MergeWhenClause { let node = alloc_node::(bindings_raw::NodeTag_T_MergeWhenClause); - (*node).matchKind = proto_enum_to_c(mwc.match_kind); - (*node).commandType = proto_enum_to_c(mwc.command_type); - (*node).override_ = proto_enum_to_c(mwc.r#override); + (*node).matchKind = proto_enum_to_c(mwc.match_kind) as _; + (*node).commandType = proto_enum_to_c(mwc.command_type) as _; + (*node).override_ = proto_enum_to_c(mwc.r#override) as _; (*node).condition = write_node_boxed(&mwc.condition); (*node).targetList = write_node_list(&mwc.target_list); (*node).values = write_node_list(&mwc.values); @@ -1481,7 +1481,7 @@ unsafe fn write_grant_role_stmt(stmt: &protobuf::GrantRoleStmt) -> *mut bindings (*node).is_grant = stmt.is_grant; (*node).opt = write_node_list(&stmt.opt); (*node).grantor = std::ptr::null_mut(); - (*node).behavior = proto_enum_to_c(stmt.behavior); + (*node).behavior = proto_enum_to_c(stmt.behavior) as _; node } @@ -1528,7 +1528,7 @@ unsafe fn write_min_max_expr(mme: &protobuf::MinMaxExpr) -> *mut bindings_raw::M (*node).minmaxtype = mme.minmaxtype; (*node).minmaxcollid = mme.minmaxcollid; (*node).inputcollid = mme.inputcollid; - (*node).op = proto_enum_to_c(mme.op); + (*node).op = proto_enum_to_c(mme.op) as _; (*node).args = write_node_list(&mme.args); (*node).location = mme.location; node @@ -1538,7 +1538,7 @@ unsafe fn write_row_expr(re: &protobuf::RowExpr) -> *mut bindings_raw::RowExpr { let node = alloc_node::(bindings_raw::NodeTag_T_RowExpr); (*node).args = write_node_list(&re.args); (*node).row_typeid = re.row_typeid; - (*node).row_format = proto_enum_to_c(re.row_format); + (*node).row_format = proto_enum_to_c(re.row_format) as _; (*node).colnames = write_node_list(&re.colnames); (*node).location = re.location; node @@ -1554,7 +1554,7 @@ unsafe fn write_a_array_expr(ae: &protobuf::AArrayExpr) -> *mut bindings_raw::A_ unsafe fn write_boolean_test(bt: &protobuf::BooleanTest) -> *mut bindings_raw::BooleanTest { let node = alloc_node::(bindings_raw::NodeTag_T_BooleanTest); (*node).arg = write_node_boxed(&bt.arg) as *mut bindings_raw::Expr; - (*node).booltesttype = proto_enum_to_c(bt.booltesttype); + (*node).booltesttype = proto_enum_to_c(bt.booltesttype) as _; (*node).location = bt.location; node } @@ -1592,7 +1592,7 @@ unsafe fn write_notify_stmt(stmt: &protobuf::NotifyStmt) -> *mut bindings_raw::N unsafe fn write_discard_stmt(stmt: &protobuf::DiscardStmt) -> *mut bindings_raw::DiscardStmt { let node = alloc_node::(bindings_raw::NodeTag_T_DiscardStmt); - (*node).target = proto_enum_to_c(stmt.target); + (*node).target = proto_enum_to_c(stmt.target) as _; node } @@ -1675,7 +1675,7 @@ unsafe fn write_alter_publication_stmt(stmt: &protobuf::AlterPublicationStmt) -> (*node).options = write_node_list(&stmt.options); (*node).pubobjects = write_node_list(&stmt.pubobjects); (*node).for_all_tables = stmt.for_all_tables; - (*node).action = proto_enum_to_c(stmt.action); + (*node).action = proto_enum_to_c(stmt.action) as _; node } @@ -1690,7 +1690,7 @@ unsafe fn write_create_subscription_stmt(stmt: &protobuf::CreateSubscriptionStmt unsafe fn write_alter_subscription_stmt(stmt: &protobuf::AlterSubscriptionStmt) -> *mut bindings_raw::AlterSubscriptionStmt { let node = alloc_node::(bindings_raw::NodeTag_T_AlterSubscriptionStmt); - (*node).kind = proto_enum_to_c(stmt.kind); + (*node).kind = proto_enum_to_c(stmt.kind) as _; (*node).subname = pstrdup(&stmt.subname); (*node).conninfo = pstrdup(&stmt.conninfo); (*node).publication = write_node_list(&stmt.publication); @@ -1708,7 +1708,7 @@ unsafe fn write_coerce_to_domain(ctd: &protobuf::CoerceToDomain) -> *mut binding (*node).resulttype = ctd.resulttype; (*node).resulttypmod = ctd.resulttypmod; (*node).resultcollid = ctd.resultcollid; - (*node).coercionformat = proto_enum_to_c(ctd.coercionformat); + (*node).coercionformat = proto_enum_to_c(ctd.coercionformat) as _; (*node).location = ctd.location; node } @@ -1748,7 +1748,7 @@ unsafe fn write_close_portal_stmt(stmt: &protobuf::ClosePortalStmt) -> *mut bind unsafe fn write_fetch_stmt(stmt: &protobuf::FetchStmt) -> *mut bindings_raw::FetchStmt { let node = alloc_node::(bindings_raw::NodeTag_T_FetchStmt); - (*node).direction = proto_enum_to_c(stmt.direction); + (*node).direction = proto_enum_to_c(stmt.direction) as _; (*node).howMany = stmt.how_many; (*node).portalname = pstrdup(&stmt.portalname); (*node).ismove = stmt.ismove; @@ -1769,7 +1769,7 @@ unsafe fn write_declare_cursor_stmt(stmt: &protobuf::DeclareCursorStmt) -> *mut unsafe fn write_define_stmt(stmt: &protobuf::DefineStmt) -> *mut bindings_raw::DefineStmt { let node = alloc_node::(bindings_raw::NodeTag_T_DefineStmt); - (*node).kind = proto_enum_to_c(stmt.kind); + (*node).kind = proto_enum_to_c(stmt.kind) as _; (*node).oldstyle = stmt.oldstyle; (*node).defnames = write_node_list(&stmt.defnames); (*node).args = write_node_list(&stmt.args); @@ -1781,7 +1781,7 @@ unsafe fn write_define_stmt(stmt: &protobuf::DefineStmt) -> *mut bindings_raw::D unsafe fn write_comment_stmt(stmt: &protobuf::CommentStmt) -> *mut bindings_raw::CommentStmt { let node = alloc_node::(bindings_raw::NodeTag_T_CommentStmt); - (*node).objtype = proto_enum_to_c(stmt.objtype); + (*node).objtype = proto_enum_to_c(stmt.objtype) as _; (*node).object = write_node_boxed(&stmt.object); (*node).comment = pstrdup(&stmt.comment); node @@ -1789,7 +1789,7 @@ unsafe fn write_comment_stmt(stmt: &protobuf::CommentStmt) -> *mut bindings_raw: unsafe fn write_sec_label_stmt(stmt: &protobuf::SecLabelStmt) -> *mut bindings_raw::SecLabelStmt { let node = alloc_node::(bindings_raw::NodeTag_T_SecLabelStmt); - (*node).objtype = proto_enum_to_c(stmt.objtype); + (*node).objtype = proto_enum_to_c(stmt.objtype) as _; (*node).object = write_node_boxed(&stmt.object); (*node).provider = pstrdup(&stmt.provider); (*node).label = pstrdup(&stmt.label); @@ -1798,7 +1798,7 @@ unsafe fn write_sec_label_stmt(stmt: &protobuf::SecLabelStmt) -> *mut bindings_r unsafe fn write_create_role_stmt(stmt: &protobuf::CreateRoleStmt) -> *mut bindings_raw::CreateRoleStmt { let node = alloc_node::(bindings_raw::NodeTag_T_CreateRoleStmt); - (*node).stmt_type = proto_enum_to_c(stmt.stmt_type); + (*node).stmt_type = proto_enum_to_c(stmt.stmt_type) as _; (*node).role = pstrdup(&stmt.role); (*node).options = write_node_list(&stmt.options); node @@ -1971,7 +1971,7 @@ unsafe fn write_create_foreign_table_stmt(stmt: &protobuf::CreateForeignTableStm (*node).base.ofTypename = write_type_name_ref(&base.of_typename); (*node).base.constraints = write_node_list(&base.constraints); (*node).base.options = write_node_list(&base.options); - (*node).base.oncommit = proto_enum_to_c(base.oncommit); + (*node).base.oncommit = proto_enum_to_c(base.oncommit) as _; (*node).base.tablespacename = pstrdup(&base.tablespacename); (*node).base.accessMethod = pstrdup(&base.access_method); (*node).base.if_not_exists = base.if_not_exists; @@ -2011,7 +2011,7 @@ unsafe fn write_import_foreign_schema_stmt(stmt: &protobuf::ImportForeignSchemaS (*node).server_name = pstrdup(&stmt.server_name); (*node).remote_schema = pstrdup(&stmt.remote_schema); (*node).local_schema = pstrdup(&stmt.local_schema); - (*node).list_type = proto_enum_to_c(stmt.list_type); + (*node).list_type = proto_enum_to_c(stmt.list_type) as _; (*node).table_list = write_node_list(&stmt.table_list); (*node).options = write_node_list(&stmt.options); node @@ -2044,7 +2044,7 @@ unsafe fn write_alter_table_space_options_stmt(stmt: &protobuf::AlterTableSpaceO unsafe fn write_alter_table_move_all_stmt(stmt: &protobuf::AlterTableMoveAllStmt) -> *mut bindings_raw::AlterTableMoveAllStmt { let node = alloc_node::(bindings_raw::NodeTag_T_AlterTableMoveAllStmt); (*node).orig_tablespacename = pstrdup(&stmt.orig_tablespacename); - (*node).objtype = proto_enum_to_c(stmt.objtype); + (*node).objtype = proto_enum_to_c(stmt.objtype) as _; (*node).roles = write_node_list(&stmt.roles); (*node).new_tablespacename = pstrdup(&stmt.new_tablespacename); (*node).nowait = stmt.nowait; @@ -2062,7 +2062,7 @@ unsafe fn write_alter_extension_contents_stmt(stmt: &protobuf::AlterExtensionCon let node = alloc_node::(bindings_raw::NodeTag_T_AlterExtensionContentsStmt); (*node).extname = pstrdup(&stmt.extname); (*node).action = stmt.action; - (*node).objtype = proto_enum_to_c(stmt.objtype); + (*node).objtype = proto_enum_to_c(stmt.objtype) as _; (*node).object = write_node_boxed(&stmt.object); node } @@ -2073,14 +2073,14 @@ unsafe fn write_alter_domain_stmt(stmt: &protobuf::AlterDomainStmt) -> *mut bind (*node).typeName = write_node_list(&stmt.type_name); (*node).name = pstrdup(&stmt.name); (*node).def = write_node_boxed(&stmt.def); - (*node).behavior = proto_enum_to_c(stmt.behavior); + (*node).behavior = proto_enum_to_c(stmt.behavior) as _; (*node).missing_ok = stmt.missing_ok; node } unsafe fn write_alter_function_stmt(stmt: &protobuf::AlterFunctionStmt) -> *mut bindings_raw::AlterFunctionStmt { let node = alloc_node::(bindings_raw::NodeTag_T_AlterFunctionStmt); - (*node).objtype = proto_enum_to_c(stmt.objtype); + (*node).objtype = proto_enum_to_c(stmt.objtype) as _; (*node).func = write_object_with_args_ref(&stmt.func); (*node).actions = write_node_list(&stmt.actions); node @@ -2102,7 +2102,7 @@ unsafe fn write_alter_type_stmt(stmt: &protobuf::AlterTypeStmt) -> *mut bindings unsafe fn write_alter_owner_stmt(stmt: &protobuf::AlterOwnerStmt) -> *mut bindings_raw::AlterOwnerStmt { let node = alloc_node::(bindings_raw::NodeTag_T_AlterOwnerStmt); - (*node).objectType = proto_enum_to_c(stmt.object_type); + (*node).objectType = proto_enum_to_c(stmt.object_type) as _; (*node).relation = write_range_var_ref(&stmt.relation); (*node).object = write_node_boxed(&stmt.object); (*node).newowner = write_role_spec_ref(&stmt.newowner); @@ -2111,7 +2111,7 @@ unsafe fn write_alter_owner_stmt(stmt: &protobuf::AlterOwnerStmt) -> *mut bindin unsafe fn write_alter_object_schema_stmt(stmt: &protobuf::AlterObjectSchemaStmt) -> *mut bindings_raw::AlterObjectSchemaStmt { let node = alloc_node::(bindings_raw::NodeTag_T_AlterObjectSchemaStmt); - (*node).objectType = proto_enum_to_c(stmt.object_type); + (*node).objectType = proto_enum_to_c(stmt.object_type) as _; (*node).relation = write_range_var_ref(&stmt.relation); (*node).object = write_node_boxed(&stmt.object); (*node).newschema = pstrdup(&stmt.newschema); @@ -2121,7 +2121,7 @@ unsafe fn write_alter_object_schema_stmt(stmt: &protobuf::AlterObjectSchemaStmt) unsafe fn write_alter_object_depends_stmt(stmt: &protobuf::AlterObjectDependsStmt) -> *mut bindings_raw::AlterObjectDependsStmt { let node = alloc_node::(bindings_raw::NodeTag_T_AlterObjectDependsStmt); - (*node).objectType = proto_enum_to_c(stmt.object_type); + (*node).objectType = proto_enum_to_c(stmt.object_type) as _; (*node).relation = write_range_var_ref(&stmt.relation); (*node).object = write_node_boxed(&stmt.object); (*node).extname = write_string_ref(&stmt.extname); @@ -2147,7 +2147,7 @@ unsafe fn write_create_cast_stmt(stmt: &protobuf::CreateCastStmt) -> *mut bindin (*node).sourcetype = write_type_name_ref(&stmt.sourcetype); (*node).targettype = write_type_name_ref(&stmt.targettype); (*node).func = write_object_with_args_ref(&stmt.func); - (*node).context = proto_enum_to_c(stmt.context); + (*node).context = proto_enum_to_c(stmt.context) as _; (*node).inout = stmt.inout; node } @@ -2181,7 +2181,7 @@ unsafe fn write_alter_ts_dictionary_stmt(stmt: &protobuf::AlterTsDictionaryStmt) unsafe fn write_alter_ts_configuration_stmt(stmt: &protobuf::AlterTsConfigurationStmt) -> *mut bindings_raw::AlterTSConfigurationStmt { let node = alloc_node::(bindings_raw::NodeTag_T_AlterTSConfigurationStmt); - (*node).kind = proto_enum_to_c(stmt.kind); + (*node).kind = proto_enum_to_c(stmt.kind) as _; (*node).cfgname = write_node_list(&stmt.cfgname); (*node).tokentype = write_node_list(&stmt.tokentype); (*node).dicts = write_node_list(&stmt.dicts); @@ -2246,7 +2246,7 @@ unsafe fn write_cluster_stmt(stmt: &protobuf::ClusterStmt) -> *mut bindings_raw: unsafe fn write_reindex_stmt(stmt: &protobuf::ReindexStmt) -> *mut bindings_raw::ReindexStmt { let node = alloc_node::(bindings_raw::NodeTag_T_ReindexStmt); - (*node).kind = proto_enum_to_c(stmt.kind); + (*node).kind = proto_enum_to_c(stmt.kind) as _; (*node).relation = write_range_var_ref(&stmt.relation); (*node).name = pstrdup(&stmt.name); (*node).params = write_node_list(&stmt.params); @@ -2269,7 +2269,7 @@ unsafe fn write_load_stmt(stmt: &protobuf::LoadStmt) -> *mut bindings_raw::LoadS unsafe fn write_drop_owned_stmt(stmt: &protobuf::DropOwnedStmt) -> *mut bindings_raw::DropOwnedStmt { let node = alloc_node::(bindings_raw::NodeTag_T_DropOwnedStmt); (*node).roles = write_node_list(&stmt.roles); - (*node).behavior = proto_enum_to_c(stmt.behavior); + (*node).behavior = proto_enum_to_c(stmt.behavior) as _; node } @@ -2284,7 +2284,7 @@ unsafe fn write_drop_subscription_stmt(stmt: &protobuf::DropSubscriptionStmt) -> let node = alloc_node::(bindings_raw::NodeTag_T_DropSubscriptionStmt); (*node).subname = pstrdup(&stmt.subname); (*node).missing_ok = stmt.missing_ok; - (*node).behavior = proto_enum_to_c(stmt.behavior); + (*node).behavior = proto_enum_to_c(stmt.behavior) as _; node } @@ -2353,7 +2353,7 @@ unsafe fn write_range_table_sample(stmt: &protobuf::RangeTableSample) -> *mut bi unsafe fn write_partition_spec(stmt: &protobuf::PartitionSpec) -> *mut bindings_raw::PartitionSpec { let node = alloc_node::(bindings_raw::NodeTag_T_PartitionSpec); - (*node).strategy = proto_enum_to_c(stmt.strategy); + (*node).strategy = proto_enum_to_c(stmt.strategy) as _; (*node).partParams = write_node_list(&stmt.part_params); (*node).location = stmt.location; node @@ -2490,7 +2490,7 @@ unsafe fn write_stats_elem(stmt: &protobuf::StatsElem) -> *mut bindings_raw::Sta unsafe fn write_publication_obj_spec(stmt: &protobuf::PublicationObjSpec) -> *mut bindings_raw::PublicationObjSpec { let node = alloc_node::(bindings_raw::NodeTag_T_PublicationObjSpec); - (*node).pubobjtype = proto_enum_to_c(stmt.pubobjtype); + (*node).pubobjtype = proto_enum_to_c(stmt.pubobjtype) as _; (*node).name = pstrdup(&stmt.name); (*node).pubtable = write_publication_table_ref(&stmt.pubtable); (*node).location = stmt.location; @@ -2511,7 +2511,7 @@ unsafe fn write_publication_table(stmt: &protobuf::PublicationTable) -> *mut bin unsafe fn write_sql_value_function(stmt: &protobuf::SqlValueFunction) -> *mut bindings_raw::SQLValueFunction { let node = alloc_node::(bindings_raw::NodeTag_T_SQLValueFunction); - (*node).op = proto_enum_to_c(stmt.op); + (*node).op = proto_enum_to_c(stmt.op) as _; (*node).type_ = stmt.r#type; (*node).typmod = stmt.typmod; (*node).location = stmt.location; @@ -2524,12 +2524,12 @@ unsafe fn write_sql_value_function(stmt: &protobuf::SqlValueFunction) -> *mut bi unsafe fn write_xml_expr(stmt: &protobuf::XmlExpr) -> *mut bindings_raw::XmlExpr { let node = alloc_node::(bindings_raw::NodeTag_T_XmlExpr); - (*node).op = proto_enum_to_c(stmt.op); + (*node).op = proto_enum_to_c(stmt.op) as _; (*node).name = pstrdup(&stmt.name); (*node).named_args = write_node_list(&stmt.named_args); (*node).arg_names = write_node_list(&stmt.arg_names); (*node).args = write_node_list(&stmt.args); - (*node).xmloption = proto_enum_to_c(stmt.xmloption); + (*node).xmloption = proto_enum_to_c(stmt.xmloption) as _; (*node).indent = stmt.indent; (*node).type_ = stmt.r#type; (*node).typmod = stmt.typmod; @@ -2539,7 +2539,7 @@ unsafe fn write_xml_expr(stmt: &protobuf::XmlExpr) -> *mut bindings_raw::XmlExpr unsafe fn write_xml_serialize(stmt: &protobuf::XmlSerialize) -> *mut bindings_raw::XmlSerialize { let node = alloc_node::(bindings_raw::NodeTag_T_XmlSerialize); - (*node).xmloption = proto_enum_to_c(stmt.xmloption); + (*node).xmloption = proto_enum_to_c(stmt.xmloption) as _; (*node).expr = write_node_boxed(&stmt.expr); (*node).typeName = write_type_name_ref(&stmt.type_name); (*node).indent = stmt.indent; @@ -2566,8 +2566,8 @@ unsafe fn write_named_arg_expr(stmt: &protobuf::NamedArgExpr) -> *mut bindings_r unsafe fn write_json_format(stmt: &protobuf::JsonFormat) -> *mut bindings_raw::JsonFormat { let node = alloc_node::(bindings_raw::NodeTag_T_JsonFormat); - (*node).format_type = proto_enum_to_c(stmt.format_type); - (*node).encoding = proto_enum_to_c(stmt.encoding); + (*node).format_type = proto_enum_to_c(stmt.format_type) as _; + (*node).encoding = proto_enum_to_c(stmt.encoding) as _; (*node).location = stmt.location; node } @@ -2590,7 +2590,7 @@ unsafe fn write_json_value_expr(stmt: &protobuf::JsonValueExpr) -> *mut bindings unsafe fn write_json_constructor_expr(stmt: &protobuf::JsonConstructorExpr) -> *mut bindings_raw::JsonConstructorExpr { let node = alloc_node::(bindings_raw::NodeTag_T_JsonConstructorExpr); - (*node).type_ = proto_enum_to_c(stmt.r#type); + (*node).type_ = proto_enum_to_c(stmt.r#type) as _; (*node).args = write_node_list(&stmt.args); (*node).func = write_node_boxed(&stmt.func) as *mut bindings_raw::Expr; (*node).coercion = write_node_boxed(&stmt.coercion) as *mut bindings_raw::Expr; @@ -2605,7 +2605,7 @@ unsafe fn write_json_is_predicate(stmt: &protobuf::JsonIsPredicate) -> *mut bind let node = alloc_node::(bindings_raw::NodeTag_T_JsonIsPredicate); (*node).expr = write_node_boxed(&stmt.expr); (*node).format = write_json_format_ref(&stmt.format); - (*node).item_type = proto_enum_to_c(stmt.item_type); + (*node).item_type = proto_enum_to_c(stmt.item_type) as _; (*node).unique_keys = stmt.unique_keys; (*node).location = stmt.location; node @@ -2613,7 +2613,7 @@ unsafe fn write_json_is_predicate(stmt: &protobuf::JsonIsPredicate) -> *mut bind unsafe fn write_json_behavior(stmt: &protobuf::JsonBehavior) -> *mut bindings_raw::JsonBehavior { let node = alloc_node::(bindings_raw::NodeTag_T_JsonBehavior); - (*node).btype = proto_enum_to_c(stmt.btype); + (*node).btype = proto_enum_to_c(stmt.btype) as _; (*node).expr = write_node_boxed(&stmt.expr); (*node).coerce = stmt.coerce; (*node).location = stmt.location; @@ -2622,7 +2622,7 @@ unsafe fn write_json_behavior(stmt: &protobuf::JsonBehavior) -> *mut bindings_ra unsafe fn write_json_expr(stmt: &protobuf::JsonExpr) -> *mut bindings_raw::JsonExpr { let node = alloc_node::(bindings_raw::NodeTag_T_JsonExpr); - (*node).op = proto_enum_to_c(stmt.op); + (*node).op = proto_enum_to_c(stmt.op) as _; (*node).column_name = pstrdup(&stmt.column_name); (*node).formatted_expr = write_node_boxed(&stmt.formatted_expr); (*node).format = write_json_format_ref(&stmt.format); @@ -2634,7 +2634,7 @@ unsafe fn write_json_expr(stmt: &protobuf::JsonExpr) -> *mut bindings_raw::JsonE (*node).on_error = write_json_behavior_ref(&stmt.on_error); (*node).use_io_coercion = stmt.use_io_coercion; (*node).use_json_coercion = stmt.use_json_coercion; - (*node).wrapper = proto_enum_to_c(stmt.wrapper); + (*node).wrapper = proto_enum_to_c(stmt.wrapper) as _; (*node).omit_quotes = stmt.omit_quotes; (*node).collation = stmt.collation; (*node).location = stmt.location; @@ -2682,7 +2682,7 @@ unsafe fn write_json_argument(stmt: &protobuf::JsonArgument) -> *mut bindings_ra unsafe fn write_json_func_expr(stmt: &protobuf::JsonFuncExpr) -> *mut bindings_raw::JsonFuncExpr { let node = alloc_node::(bindings_raw::NodeTag_T_JsonFuncExpr); - (*node).op = proto_enum_to_c(stmt.op); + (*node).op = proto_enum_to_c(stmt.op) as _; (*node).column_name = pstrdup(&stmt.column_name); (*node).context_item = write_json_value_expr_ref(&stmt.context_item); (*node).pathspec = write_node_boxed(&stmt.pathspec); @@ -2690,8 +2690,8 @@ unsafe fn write_json_func_expr(stmt: &protobuf::JsonFuncExpr) -> *mut bindings_r (*node).output = write_json_output_ref(&stmt.output); (*node).on_empty = write_json_behavior_ref(&stmt.on_empty); (*node).on_error = write_json_behavior_ref(&stmt.on_error); - (*node).wrapper = proto_enum_to_c(stmt.wrapper); - (*node).quotes = proto_enum_to_c(stmt.quotes); + (*node).wrapper = proto_enum_to_c(stmt.wrapper) as _; + (*node).quotes = proto_enum_to_c(stmt.quotes) as _; (*node).location = stmt.location; node } @@ -2720,13 +2720,13 @@ unsafe fn write_json_table(stmt: &protobuf::JsonTable) -> *mut bindings_raw::Jso unsafe fn write_json_table_column(stmt: &protobuf::JsonTableColumn) -> *mut bindings_raw::JsonTableColumn { let node = alloc_node::(bindings_raw::NodeTag_T_JsonTableColumn); - (*node).coltype = proto_enum_to_c(stmt.coltype); + (*node).coltype = proto_enum_to_c(stmt.coltype) as _; (*node).name = pstrdup(&stmt.name); (*node).typeName = write_type_name_ref(&stmt.type_name); (*node).pathspec = write_json_table_path_spec_ref(&stmt.pathspec); (*node).format = write_json_format_ref(&stmt.format); - (*node).wrapper = proto_enum_to_c(stmt.wrapper); - (*node).quotes = proto_enum_to_c(stmt.quotes); + (*node).wrapper = proto_enum_to_c(stmt.wrapper) as _; + (*node).quotes = proto_enum_to_c(stmt.quotes) as _; (*node).columns = write_node_list(&stmt.columns); (*node).on_empty = write_json_behavior_ref(&stmt.on_empty); (*node).on_error = write_json_behavior_ref(&stmt.on_error); From 134de026ad8cd45c4ae65af285d0f70f4f80e3f8 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Thu, 1 Jan 2026 13:34:23 -0800 Subject: [PATCH 17/17] windows --- src/raw_deparse.rs | 2 +- src/raw_parse.rs | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/raw_deparse.rs b/src/raw_deparse.rs index cf1b4ac..bdad7ae 100644 --- a/src/raw_deparse.rs +++ b/src/raw_deparse.rs @@ -1749,7 +1749,7 @@ unsafe fn write_close_portal_stmt(stmt: &protobuf::ClosePortalStmt) -> *mut bind unsafe fn write_fetch_stmt(stmt: &protobuf::FetchStmt) -> *mut bindings_raw::FetchStmt { let node = alloc_node::(bindings_raw::NodeTag_T_FetchStmt); (*node).direction = proto_enum_to_c(stmt.direction) as _; - (*node).howMany = stmt.how_many; + (*node).howMany = stmt.how_many as _; (*node).portalname = pstrdup(&stmt.portalname); (*node).ismove = stmt.ismove; node diff --git a/src/raw_parse.rs b/src/raw_parse.rs index 45b59b6..5ecec1a 100644 --- a/src/raw_parse.rs +++ b/src/raw_parse.rs @@ -2179,7 +2179,12 @@ unsafe fn convert_close_portal_stmt(cps: &bindings_raw::ClosePortalStmt) -> prot } unsafe fn convert_fetch_stmt(fs: &bindings_raw::FetchStmt) -> protobuf::FetchStmt { - protobuf::FetchStmt { direction: fs.direction as i32 + 1, how_many: fs.howMany, portalname: convert_c_string(fs.portalname), ismove: fs.ismove } + protobuf::FetchStmt { + direction: fs.direction as i32 + 1, + how_many: fs.howMany as i64, + portalname: convert_c_string(fs.portalname), + ismove: fs.ismove, + } } unsafe fn convert_declare_cursor_stmt(dcs: &bindings_raw::DeclareCursorStmt) -> protobuf::DeclareCursorStmt {