diff --git a/crates/squawk_linter/src/identifier.rs b/crates/squawk_linter/src/identifier.rs new file mode 100644 index 00000000..873a9d68 --- /dev/null +++ b/crates/squawk_linter/src/identifier.rs @@ -0,0 +1,34 @@ +/// Postgres Identifiers are case insensitive unless they're quoted. +/// +/// This type handles the casing rules for us to make comparisions easier. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) struct Identifier(String); + +impl Identifier { + // TODO: we need to handle more advanced identifiers like: + // U&"d!0061t!+000061" UESCAPE '!' + pub fn new(s: &str) -> Self { + let normalized = if s.starts_with('"') && s.ends_with('"') { + s[1..s.len() - 1].to_string() + } else { + s.to_lowercase() + }; + Identifier(normalized) + } +} + +#[cfg(test)] +mod test { + use crate::identifier::Identifier; + + #[test] + fn case_folds_correctly() { + // https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS + // For example, the identifiers FOO, foo, and "foo" are considered the + // same by PostgreSQL, but "Foo" and "FOO" are different from these + // three and each other. + assert_eq!(Identifier::new("FOO"), Identifier::new("foo")); + assert_eq!(Identifier::new(r#""foo""#), Identifier::new("foo")); + assert_eq!(Identifier::new(r#""foo""#), Identifier::new("FOO")); + } +} diff --git a/crates/squawk_linter/src/lib.rs b/crates/squawk_linter/src/lib.rs index 2da41aff..afbdf08a 100644 --- a/crates/squawk_linter/src/lib.rs +++ b/crates/squawk_linter/src/lib.rs @@ -19,8 +19,8 @@ mod ignore_index; mod version; mod visitors; +mod identifier; mod rules; -mod text; use rules::adding_field_with_default; use rules::adding_foreign_key_constraint; use rules::adding_not_null_field; diff --git a/crates/squawk_linter/src/rules/adding_field_with_default.rs b/crates/squawk_linter/src/rules/adding_field_with_default.rs index 55a29a86..98c7a6f0 100644 --- a/crates/squawk_linter/src/rules/adding_field_with_default.rs +++ b/crates/squawk_linter/src/rules/adding_field_with_default.rs @@ -5,6 +5,7 @@ use squawk_syntax::ast; use squawk_syntax::ast::AstNode; use squawk_syntax::{Parse, SourceFile}; +use crate::identifier::Identifier; use crate::{Linter, Rule, Violation}; fn is_const_expr(expr: &ast::Expr) -> bool { @@ -16,11 +17,12 @@ fn is_const_expr(expr: &ast::Expr) -> bool { } lazy_static! { - static ref NON_VOLATILE_FUNCS: HashSet = { + static ref NON_VOLATILE_FUNCS: HashSet = { NON_VOLATILE_BUILT_IN_FUNCTIONS .split('\n') - .map(|x| x.trim().to_lowercase()) + .map(|x| x.trim()) .filter(|x| !x.is_empty()) + .map(|x| Identifier::new(x)) .collect() }; } @@ -36,7 +38,8 @@ fn is_non_volatile(expr: &ast::Expr) -> bool { return false; }; - let non_volatile_name = NON_VOLATILE_FUNCS.contains(name_ref.text().as_str()); + let non_volatile_name = + NON_VOLATILE_FUNCS.contains(&Identifier::new(name_ref.text().as_str())); no_args && non_volatile_name } else { diff --git a/crates/squawk_linter/src/rules/constraint_missing_not_valid.rs b/crates/squawk_linter/src/rules/constraint_missing_not_valid.rs index 9e3744ca..5d824f60 100644 --- a/crates/squawk_linter/src/rules/constraint_missing_not_valid.rs +++ b/crates/squawk_linter/src/rules/constraint_missing_not_valid.rs @@ -5,12 +5,15 @@ use squawk_syntax::{ ast::{self, AstNode}, }; -use crate::{Linter, Rule, Violation, text::trim_quotes}; +use crate::{ + Linter, Rule, Violation, + identifier::Identifier, +}; pub fn tables_created_in_transaction( assume_in_transaction: bool, file: &ast::SourceFile, -) -> HashSet { +) -> HashSet { let mut created_table_names = HashSet::new(); let mut inside_transaction = assume_in_transaction; for stmt in file.stmts() { @@ -29,7 +32,7 @@ pub fn tables_created_in_transaction( else { continue; }; - created_table_names.insert(trim_quotes(&table_name.text()).to_string()); + created_table_names.insert(Identifier::new(&table_name.text())); } _ => (), } @@ -43,7 +46,7 @@ fn not_valid_validate_in_transaction( file: &ast::SourceFile, ) { let mut inside_transaction = assume_in_transaction; - let mut not_valid_names: HashSet = HashSet::new(); + let mut not_valid_names: HashSet = HashSet::new(); for stmt in file.stmts() { match stmt { ast::Stmt::AlterTable(alter_table) => { @@ -54,7 +57,7 @@ fn not_valid_validate_in_transaction( validate_constraint.name_ref().map(|x| x.text().to_string()) { if inside_transaction - && not_valid_names.contains(trim_quotes(&constraint_name)) + && not_valid_names.contains(&Identifier::new(&constraint_name)) { ctx.report( Violation::new( @@ -70,9 +73,7 @@ fn not_valid_validate_in_transaction( if add_constraint.not_valid().is_some() { if let Some(constraint) = add_constraint.constraint() { if let Some(constraint_name) = constraint.name() { - not_valid_names.insert( - trim_quotes(&constraint_name.text()).to_string(), - ); + not_valid_names.insert(Identifier::new(&constraint_name.text())); } } } @@ -117,7 +118,7 @@ pub(crate) fn constraint_missing_not_valid(ctx: &mut Linter, parse: &Parse = - HashSet::from(["integer", "int4", "serial", "serial4",]); + static ref INT_TYPES: HashSet = HashSet::from([ + Identifier::new("integer"), + Identifier::new("int4"), + Identifier::new("serial"), + Identifier::new("serial4"), + ]); } fn check_ty_for_big_int(ctx: &mut Linter, ty: Option) { diff --git a/crates/squawk_linter/src/rules/prefer_bigint_over_smallint.rs b/crates/squawk_linter/src/rules/prefer_bigint_over_smallint.rs index 61e817ca..530f8247 100644 --- a/crates/squawk_linter/src/rules/prefer_bigint_over_smallint.rs +++ b/crates/squawk_linter/src/rules/prefer_bigint_over_smallint.rs @@ -3,6 +3,7 @@ use std::collections::HashSet; use squawk_syntax::ast::AstNode; use squawk_syntax::{Parse, SourceFile, ast}; +use crate::identifier::Identifier; use crate::{Linter, Rule, Violation}; use crate::visitors::check_not_allowed_types; @@ -11,8 +12,12 @@ use crate::visitors::is_not_valid_int_type; use lazy_static::lazy_static; lazy_static! { - static ref SMALL_INT_TYPES: HashSet<&'static str> = - HashSet::from(["smallint", "int2", "smallserial", "serial2",]); + static ref SMALL_INT_TYPES: HashSet = HashSet::from([ + Identifier::new("smallint"), + Identifier::new("int2"), + Identifier::new("smallserial"), + Identifier::new("serial2"), + ]); } fn check_ty_for_small_int(ctx: &mut Linter, ty: Option) { diff --git a/crates/squawk_linter/src/rules/prefer_identity.rs b/crates/squawk_linter/src/rules/prefer_identity.rs index 1d74932e..1cb0fda3 100644 --- a/crates/squawk_linter/src/rules/prefer_identity.rs +++ b/crates/squawk_linter/src/rules/prefer_identity.rs @@ -5,20 +5,20 @@ use squawk_syntax::{ ast::{self, AstNode}, }; -use crate::{Linter, Rule, Violation}; +use crate::{Linter, Rule, Violation, identifier::Identifier}; use lazy_static::lazy_static; use crate::visitors::{check_not_allowed_types, is_not_valid_int_type}; lazy_static! { - static ref SERIAL_TYPES: HashSet<&'static str> = HashSet::from([ - "serial", - "serial2", - "serial4", - "serial8", - "smallserial", - "bigserial", + static ref SERIAL_TYPES: HashSet = HashSet::from([ + Identifier::new("serial"), + Identifier::new("serial2"), + Identifier::new("serial4"), + Identifier::new("serial8"), + Identifier::new("smallserial"), + Identifier::new("bigserial"), ]); } @@ -86,6 +86,23 @@ create table users ( assert_debug_snapshot!(errors); } + #[test] + fn ok_when_quoted() { + let sql = r#" +create table users ( + id "serial" +); +create table users ( + id "bigserial" +); + "#; + let file = squawk_syntax::SourceFile::parse(sql); + let mut linter = Linter::from([Rule::PreferIdentity]); + let errors = linter.lint(file, sql); + assert_eq!(errors.len(), 2); + assert_debug_snapshot!(errors); + } + #[test] fn ok() { let sql = r#" diff --git a/crates/squawk_linter/src/rules/prefer_robust_stmts.rs b/crates/squawk_linter/src/rules/prefer_robust_stmts.rs index d3acc68c..3b427408 100644 --- a/crates/squawk_linter/src/rules/prefer_robust_stmts.rs +++ b/crates/squawk_linter/src/rules/prefer_robust_stmts.rs @@ -5,7 +5,10 @@ use squawk_syntax::{ ast::{self, AstNode}, }; -use crate::{Linter, Rule, Violation, text::trim_quotes}; +use crate::{ + Linter, Rule, Violation, + identifier::Identifier, +}; #[derive(PartialEq)] enum Constraint { @@ -16,7 +19,7 @@ enum Constraint { pub(crate) fn prefer_robust_stmts(ctx: &mut Linter, parse: &Parse) { let file = parse.tree(); let mut inside_transaction = ctx.settings.assume_in_transaction; - let mut constraint_names: HashMap = HashMap::new(); + let mut constraint_names: HashMap = HashMap::new(); let mut total_stmts = 0; for _ in file.stmts() { @@ -50,7 +53,7 @@ pub(crate) fn prefer_robust_stmts(ctx: &mut Linter, parse: &Parse) { ast::AlterTableAction::DropConstraint(drop_constraint) => { if let Some(constraint_name) = drop_constraint.name_ref() { constraint_names.insert( - trim_quotes(constraint_name.text().as_str()).to_string(), + Identifier::new(constraint_name.text().as_str()), Constraint::Dropped, ); } @@ -68,7 +71,7 @@ pub(crate) fn prefer_robust_stmts(ctx: &mut Linter, parse: &Parse) { ast::AlterTableAction::ValidateConstraint(validate_constraint) => { if let Some(constraint_name) = validate_constraint.name_ref() { if constraint_names - .contains_key(trim_quotes(constraint_name.text().as_str())) + .contains_key(&Identifier::new(constraint_name.text().as_str())) { continue; } @@ -79,8 +82,8 @@ pub(crate) fn prefer_robust_stmts(ctx: &mut Linter, parse: &Parse) { let constraint = add_constraint.constraint(); if let Some(constraint_name) = constraint.and_then(|x| x.name()) { let name_text = constraint_name.text(); - let name = trim_quotes(name_text.as_str()); - if let Some(constraint) = constraint_names.get_mut(name) { + let name = Identifier::new(name_text.as_str()); + if let Some(constraint) = constraint_names.get_mut(&name) { if *constraint == Constraint::Dropped { *constraint = Constraint::Added; continue; diff --git a/crates/squawk_linter/src/rules/prefer_text_field.rs b/crates/squawk_linter/src/rules/prefer_text_field.rs index c173dddc..45a0ade6 100644 --- a/crates/squawk_linter/src/rules/prefer_text_field.rs +++ b/crates/squawk_linter/src/rules/prefer_text_field.rs @@ -5,7 +5,7 @@ use squawk_syntax::{ ast::{self, AstNode}, }; -use crate::{Linter, Rule, Violation, text::trim_quotes}; +use crate::{Linter, Rule, Violation, identifier::Identifier}; use crate::visitors::check_not_allowed_types; @@ -35,10 +35,10 @@ fn is_not_allowed_varchar(ty: &ast::Type) -> bool { return false; }; // if we don't have any args, then it's the same as `text` - trim_quotes(ty_name.as_str()) == "varchar" && path_type.arg_list().is_some() + Identifier::new(ty_name.as_str()) == Identifier::new("varchar") && path_type.arg_list().is_some() } ast::Type::CharType(char_type) => { - trim_quotes(&char_type.text()) == "varchar" && char_type.arg_list().is_some() + Identifier::new(&char_type.text()) == Identifier::new("varchar") && char_type.arg_list().is_some() } ast::Type::BitType(_) => false, ast::Type::DoubleType(_) => false, diff --git a/crates/squawk_linter/src/rules/prefer_timestamptz.rs b/crates/squawk_linter/src/rules/prefer_timestamptz.rs index 6d60d22a..b5587db1 100644 --- a/crates/squawk_linter/src/rules/prefer_timestamptz.rs +++ b/crates/squawk_linter/src/rules/prefer_timestamptz.rs @@ -4,7 +4,7 @@ use squawk_syntax::{ }; use crate::{Linter, Rule, Violation}; -use crate::{text::trim_quotes, visitors::check_not_allowed_types}; +use crate::{identifier::Identifier, visitors::check_not_allowed_types}; pub fn is_not_allowed_timestamp(ty: &ast::Type) -> bool { match ty { @@ -26,7 +26,8 @@ pub fn is_not_allowed_timestamp(ty: &ast::Type) -> bool { return false; }; // if we don't have any args, then it's the same as `text` - trim_quotes(ty_name.as_str()) == "varchar" && path_type.arg_list().is_some() + Identifier::new(ty_name.as_str()) == Identifier::new("varchar") + && path_type.arg_list().is_some() } ast::Type::CharType(_) => false, ast::Type::BitType(_) => false, diff --git a/crates/squawk_linter/src/rules/require_concurrent_index_creation.rs b/crates/squawk_linter/src/rules/require_concurrent_index_creation.rs index ad8796f3..c76b325e 100644 --- a/crates/squawk_linter/src/rules/require_concurrent_index_creation.rs +++ b/crates/squawk_linter/src/rules/require_concurrent_index_creation.rs @@ -3,7 +3,7 @@ use squawk_syntax::{ ast::{self, AstNode}, }; -use crate::{Linter, Rule, Violation, text::trim_quotes}; +use crate::{Linter, Rule, Violation, identifier::Identifier}; use super::constraint_missing_not_valid::tables_created_in_transaction; @@ -19,7 +19,7 @@ pub(crate) fn require_concurrent_index_creation(ctx: &mut Linter, parse: &Parse< .and_then(|x| x.name_ref()) { if create_index.concurrently_token().is_none() - && !tables_created.contains(trim_quotes(table_name.text().as_str())) + && !tables_created.contains(&Identifier::new(&table_name.text())) { ctx.report(Violation::new( Rule::RequireConcurrentIndexCreation, diff --git a/crates/squawk_linter/src/rules/snapshots/squawk_linter__rules__prefer_identity__test__ok_when_quoted.snap b/crates/squawk_linter/src/rules/snapshots/squawk_linter__rules__prefer_identity__test__ok_when_quoted.snap new file mode 100644 index 00000000..c6c66aa5 --- /dev/null +++ b/crates/squawk_linter/src/rules/snapshots/squawk_linter__rules__prefer_identity__test__ok_when_quoted.snap @@ -0,0 +1,22 @@ +--- +source: crates/squawk_linter/src/rules/prefer_identity.rs +expression: errors +--- +[ + Violation { + code: PreferIdentity, + message: "Serial types make schema, dependency, and permission management difficult.", + text_range: 29..37, + help: Some( + "Use an `IDENTITY` column instead.", + ), + }, + Violation { + code: PreferIdentity, + message: "Serial types make schema, dependency, and permission management difficult.", + text_range: 69..80, + help: Some( + "Use an `IDENTITY` column instead.", + ), + }, +] diff --git a/crates/squawk_linter/src/text.rs b/crates/squawk_linter/src/text.rs deleted file mode 100644 index b9a6c5fb..00000000 --- a/crates/squawk_linter/src/text.rs +++ /dev/null @@ -1,8 +0,0 @@ -// TODO: figure out a better way to handle quoted and unquoted idents -pub(crate) fn trim_quotes(s: &str) -> &str { - if s.starts_with('"') && s.ends_with('"') { - &s[1..s.len() - 1] - } else { - s - } -} diff --git a/crates/squawk_linter/src/visitors.rs b/crates/squawk_linter/src/visitors.rs index f7b9d6db..cd6676dc 100644 --- a/crates/squawk_linter/src/visitors.rs +++ b/crates/squawk_linter/src/visitors.rs @@ -2,9 +2,12 @@ use std::collections::HashSet; use squawk_syntax::ast; -use crate::{Linter, text::trim_quotes}; +use crate::{Linter, identifier::Identifier}; -pub(crate) fn is_not_valid_int_type(ty: &ast::Type, invalid_type_names: &HashSet<&str>) -> bool { +pub(crate) fn is_not_valid_int_type( + ty: &ast::Type, + invalid_type_names: &HashSet, +) -> bool { match ty { ast::Type::ArrayType(array_type) => { if let Some(ty) = array_type.ty() { @@ -23,8 +26,8 @@ pub(crate) fn is_not_valid_int_type(ty: &ast::Type, invalid_type_names: &HashSet else { return false; }; - let name = trim_quotes(ty_name.as_str()); - invalid_type_names.contains(name.to_lowercase().as_str()) + let name = Identifier::new(ty_name.as_str()); + invalid_type_names.contains(&name) } ast::Type::CharType(_) => false, ast::Type::BitType(_) => false,