diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
deleted file mode 100644
index ed98ff5..0000000
--- a/.git-blame-ignore-revs
+++ /dev/null
@@ -1,2 +0,0 @@
-# Scala Steward: Reformat with scalafmt 3.7.4
-3f8323c8a559d0739ebd00b92ad14f373eaf177d
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index ba9b643..137c69b 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -8,55 +8,60 @@ concurrency:
   group: ${{ github.workflow }} @ ${{ github.ref }}
   cancel-in-progress: true
 jobs:
+  build:
+    strategy:
+      matrix:
+        os: [ubuntu-latest, macos-15]
+    runs-on: ${{ matrix.os }}
+    steps:
+      - uses: actions/checkout@v4
+      - uses: coursier/cache-action@v6
+      - uses: VirtusLab/scala-cli-setup@main
+        with:
+          jvm: temurin:21
+      - name: Install libpq (macOS)
+        if: runner.os == 'macOS'
+        run: brew install libpq && brew link --force libpq
+      - run: |
+          OUT=out/skunk-codegen-$(uname -m)-$(uname | tr '[:upper:]' '[:lower:]')
+          echo "Compiling to $OUT"
+          scala-cli --power package \
+            --native \
+            --native-mode release-fast PgCodeGen.scala \
+            -o $OUT -f && \
+          zip -j "${OUT}.zip" $OUT
+      - name: Upload command line binaries
+        uses: actions/upload-artifact@v4
+        with:
+          name: codegen-bin-${{ matrix.os }}
+          path: out/*
   test:
     runs-on: ubuntu-latest
     steps:
-      - name: Checkout
-        uses: actions/checkout@v4
-      - name: Setup JDK
-        uses: actions/setup-java@v4
-        with:
-          distribution: temurin
-          java-version: 21
-          cache: sbt
-      - name: Start up Postgres
-        run: docker run --rm -d -e POSTGRES_PASSWORD=postgres -p 5432:5432 postgres:16-alpine
-      - name: Generate code
-        run: sbt '+core/Test/runMain com.anymindgroup.RunPgCodeGen'
-      - name: Test generated code
-        run: sbt '++2.13 Test/runMain com.anymindgroup.GeneratedCodeTest; ++3.3 Test/runMain com.anymindgroup.GeneratedCodeTest'
-      - name: Test sbt plugin
-        # for sbt < v2 which only supports scala 2.12
-        run: sbt ++2.12 scripted
-  release:
-    name: Release
-    runs-on: ubuntu-latest
-    continue-on-error: false
-    needs:
-      - test
-    if: ${{ startsWith(github.ref, 'refs/tags/v') }}
+      - uses: actions/checkout@v4
+      - uses: coursier/cache-action@v6
+      - uses: VirtusLab/scala-cli-setup@main
+        with:
+          jvm: temurin:21
+      - run: ./test.sh
+
+  publish-bin:
+    name: Publish command line binaries
+    needs: [build]
+    if: startsWith(github.ref, 'refs/tags/')
+    strategy:
+      matrix:
+        os: [ubuntu-latest]
+    runs-on: ${{ matrix.os }}
     steps:
-      - name: Git Checkout
-        uses: actions/checkout@v4
-      - name: Setup JDK
-        uses: actions/setup-java@v4
-        with:
-          distribution: temurin
-          java-version: 21
-          cache: sbt
-      - name: Import signing key and strip passphrase
-        if: env.PGP_SECRET != '' && env.PGP_PASSPHRASE != ''
-        env:
-          PGP_SECRET: ${{ secrets.PGP_SECRET }}
-          PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }}
-        run: |
-          echo "$PGP_SECRET" | base64 -d -i - > /tmp/signing-key.gpg
-          echo "$PGP_PASSPHRASE" | gpg --pinentry-mode loopback --passphrase-fd 0 --import /tmp/signing-key.gpg
-          (echo "$PGP_PASSPHRASE"; echo; echo) | gpg --command-fd 0 --pinentry-mode loopback --change-passphrase $(gpg --list-secret-keys --with-colons 2> /dev/null | grep '^sec:' | cut --delimiter ':' --fields 5 | tail -n 1)
-      - name: Release
-        run: sbt '++2.12 publishSigned; sonatypeCentralRelease'
-        env:
-          PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }}
-          PGP_SECRET: ${{ secrets.PGP_SECRET }}
-          SONATYPE_PASSWORD: ${{ secrets.SONATYPE_PASSWORD }}
-          SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }}
\ No newline at end of file
+      - name: Download command line binaries
+        uses: actions/download-artifact@v4
+        with:
+          
pattern: codegen-bin-* + path: out + merge-multiple: true + + - name: Upload release binaries + uses: softprops/action-gh-release@v1 + with: + files: out/* \ No newline at end of file diff --git a/.gitignore b/.gitignore index 4762d81..88251bc 100644 --- a/.gitignore +++ b/.gitignore @@ -15,5 +15,6 @@ target/ .vscode/ .bloop/ metals.sbt -modules/core/src/test/scala-*/com/anymindgroup/generated/ -modules/core/src/test/scala-*/com/anymindgroup/GeneratedCodeTest.scala +test-generated +.scala-build/ +out \ No newline at end of file diff --git a/.scalafmt.conf b/.scalafmt.conf index 672e7dd..e7b57e7 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,26 +1,7 @@ -version = "3.7.7" +version = "3.9.7" maxColumn = 120 -align.preset = most -align.multiline = false -continuationIndent.defnSite = 2 -assumeStandardLibraryStripMargin = true -docstrings.style = Asterisk -docstrings.wrapMaxColumn = 80 -lineEndings = preserve -trailingCommas = multiple -includeCurlyBraceInSelectChains = false -danglingParentheses.preset = true -optIn.annotationNewlines = true -newlines.alwaysBeforeMultilineDef = false runner.dialect = scala3 -rewrite.rules = [RedundantBraces] -indentOperator.exemptScope = aloneArgOrBody -indentOperator.excludeRegex = "^(&&|\\|\\|)$" -project.excludeFilters = [] - -rewrite.redundantBraces.generalExpressions = false -rewriteTokens = { - "⇒": "=>" - "→": "->" - "←": "<-" -} \ No newline at end of file +rewrite.rules = [Imports] +rewrite.imports.sort = scalastyle +rewrite.imports.expand = false +rewrite.imports.groups = [["scala\\..*"], ["java\\..*"]] \ No newline at end of file diff --git a/PgCodeGen.scala b/PgCodeGen.scala new file mode 100644 index 0000000..7b3acc9 --- /dev/null +++ b/PgCodeGen.scala @@ -0,0 +1,1066 @@ +//> using scala 3.7.1 +//> using dep com.indoorvivants.roach::core::0.1.0 +//> using dep com.github.lolgab::scala-native-crypto::0.2.1 +//> using platform native +//> using nativeVersion 0.5.8 + +package com.anymindgroup + +import scala.annotation.tailrec +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration.* +import scala.jdk.CollectionConverters.* +import scala.scalanative.unsafe.Zone +import scala.sys.process.* +import scala.util.{Failure, Random, Success, Using} + +import java.io.File +import java.net.ServerSocket +import java.net.URI +import java.nio.charset.Charset +import java.nio.file.{Files, Path, Paths} +import java.security.MessageDigest + +import roach.* +import roach.codecs.* + +@main +def run(args: String*) = + (for + argsMap <- + try + Right( + args + .map(arg => + val (k, v) = arg.splitAt(arg.indexOf("=")) + (k, v.stripPrefix("=").stripPrefix("\"").stripPrefix("'").stripSuffix("\"").stripSuffix("'")) + ) + .toMap + ) + catch case e: Throwable => Left(e.getMessage()) + outputDir <- argsMap.get("-output-dir").map(File(_)).toRight("outputDir not set") + pkgName <- argsMap.get("-pkg-name").toRight("pkgName not set") + sourceDir <- argsMap.get("-source-dir").map(File(_)).toRight("sourceDir not set") + useDockerImage = argsMap.get("-use-docker-image").getOrElse("postgres:17-alpine") + useConnectionUri <- argsMap.get("-use-connection") match + case Some(u) => + try Right(Some(java.net.URI(u))) + catch case e: Throwable => Left(s"Invalid connection URI: ${e.getMessage()}") + case _ => Right(None) + excludeTables = argsMap.get("-exclude-tables").toList.flatMap(_.split(",")) + scalaVersion = argsMap.get("-scala-version").getOrElse("3.7.1") + debug = argsMap.get("-debug") match + case Some("1" | "true") => true + case _ => 
false + forceRegeneration = argsMap.get("-force") match + case Some("1" | "true") => true + case _ => false + _ = if debug then println(s"Running code generator with arguments: ${args.mkString(", ")}") + yield PgCodeGen.run( + useDockerImage = useDockerImage, + outputDir = outputDir, + pkgName = pkgName, + sourceDir = sourceDir, + excludeTables = excludeTables, + scalaVersion = scalaVersion, + forceRegeneration = forceRegeneration, + useConnectionUri = useConnectionUri, + debug = debug + )(using ExecutionContext.global)) match + case Right(task) => + try + Await.result(task, 30.seconds) + sys.exit(0) + catch + case err: Throwable => + Console.err.println(s"Failure: ${err.getMessage()}") + sys.exit(1) + case Left(err) => + Console.err.println(s"Failure: $err") + sys.exit(1) + +extension (p: Path) def /(s: String): Path = Paths.get(p.toString(), s) + +case class Type(name: String, componentTypes: List[Type] = Nil) + +class PgCodeGen private ( + pkgName: String, + sourceFiles: List[Path], + excludeTables: List[String], + debug: Boolean, + user: String, + password: String, + host: String, + port: Int, + database: String, + schemaHistoryTableName: String, + pkgDir: Path, + outDir: Path +)(using ExecutionContext) { + import PgCodeGen.* + + private val connectionString = s"postgresql://$user:$password@$host:$port/$database" + + private def getConstraints = + pgSessionRun: + val q = + sql"""SELECT c.table_name, c.constraint_name, c.constraint_type, cu.column_name, cu.table_name, kcu.column_name + FROM information_schema.table_constraints AS c + JOIN information_schema.key_column_usage as kcu ON kcu.constraint_name = c.constraint_name + JOIN information_schema.constraint_column_usage AS cu ON cu.constraint_name = c.constraint_name + WHERE c.table_schema='public'""".all(name ~ name ~ varchar ~ name ~ name ~ name) + + q.map { (a, b, c, d, e, f) => + ConstraintRow(tableName = a, name = b, typ = c, refCol = d, refTable = e, fromCol = f) + }.groupBy(_.tableName) + .map { case (tName, constraints) => + ( + tName, + constraints.groupBy(c => (c.name, c.typ)).toVector.map { + case ((cName, "PRIMARY KEY"), cItems) => + Constraint.PrimaryKey(name = cName, columnNames = cItems.map(_.fromCol)) + case ((cName, "UNIQUE"), cItems) => Constraint.Unique(name = cName, columnNames = cItems.map(_.fromCol)) + case ((cName, "FOREIGN KEY"), cItems) => + Constraint.ForeignKey( + name = cName, + refs = cItems.map { cr => + ColumnRef(fromColName = cr.fromCol, toColName = cr.refCol, toTableName = cr.refTable) + } + ) + case ((cName, _), _) => Constraint.Unknown(cName) + } + ) + } + + private def toType( + udt: String, + maxCharLength: Option[Int], + numPrecision: Option[Int], + numScale: Option[Int] + ): Type = + (udt, maxCharLength, numPrecision, numScale) match { + case (u @ ("bpchar" | "varchar"), Some(l), _, _) => Type(s"$u($l)") + case ("numeric", _, Some(p), Some(s)) => Type(s"numeric($p${if (s > 0) ", " + s.toString else ""})") + case _ => + val componentTypes = if (udt.startsWith("_")) List(Type(udt.stripPrefix("_"))) else Nil + Type(udt, componentTypes) + } + + private def pgSessionRun[A](f: (Zone, Database) ?=> A): Future[A] = + Future: + Zone: + Pool.single(connectionString): pool => + pool.withLease(f) + + private def getColumns(enums: Enums) = + pgSessionRun: + val filterFragment = + s" AND table_name NOT IN (${(schemaHistoryTableName :: excludeTables).mkString("'", "','", "'")})" + + val q = + sql"""SELECT 
table_name,column_name,udt_name,character_maximum_length,numeric_precision,numeric_scale,is_nullable,column_default,is_generated + FROM information_schema.COLUMNS WHERE table_schema = 'public'$filterFragment UNION + (SELECT + cls.relname AS table_name, + attr.attname AS column_name, + tp.typname AS udt_name, + information_schema._pg_char_max_length(information_schema._pg_truetypid(attr.*, tp.*), information_schema._pg_truetypmod( + attr.*, tp.*))::information_schema.cardinal_number AS character_maximum_length, + information_schema._pg_numeric_precision(information_schema._pg_truetypid(attr.*, tp.*), information_schema._pg_truetypmod( + attr.*, tp.*))::information_schema.cardinal_number AS numeric_precision, + information_schema._pg_numeric_scale(information_schema._pg_truetypid(attr.*, tp.*), information_schema._pg_truetypmod( + attr.*, tp.*))::information_schema.cardinal_number AS numeric_scale, + CASE + WHEN attr.attnotnull OR tp.typtype = 'd'::"char" AND tp.typnotnull THEN 'NO'::text + ELSE 'YES'::text + END::information_schema.yes_or_no AS is_nullable, + NULL AS column_default, + 'NEVER' AS is_generated + FROM pg_catalog.pg_attribute as attr + JOIN pg_catalog.pg_class as cls on cls.oid = attr.attrelid + JOIN pg_catalog.pg_namespace as ns on ns.oid = cls.relnamespace + JOIN pg_catalog.pg_type as tp on tp.oid = attr.atttypid + WHERE cls.relkind = 'm' and attr.attnum >= 1 AND ns.nspname = 'public' + ORDER by attr.attnum) + """.all(name ~ name ~ name ~ int4.opt ~ int4.opt ~ int4.opt ~ varchar ~ varchar.opt ~ varchar) + + q.map { (tName, colName, udt, maxCharLength, numPrecision, numScale, nullable, default, is_generated) => + ( + tName, + colName, + toType(udt, maxCharLength, numPrecision, numScale), + nullable == "YES", + default.flatMap(ColumnDefault.fromString), + is_generated == "ALWAYS" + ) + }.map { (tName, colName, udt, isNullable, default, isAlwaysGenerated) => + toScalaType(udt, isNullable, enums).map { st => + ( + tName, + Column( + columnName = colName, + pgType = udt, + isEnum = enums.exists(_.name == udt.name), + scalaType = st, + isNullable = isNullable, + default = default, + isAlwaysGenerated = isAlwaysGenerated + ) + ) + } match { + case Left(err) => throw Throwable(err) + case Right(value) => value + } + }.groupBy(_._1) + .map { case (k, v) => (k, v.map(_._2)) } + end getColumns + + private def getIndexes = + pgSessionRun: + val q = + sql"""SELECT indexname,indexdef,tablename FROM pg_indexes WHERE schemaname='public'""".all(name ~ text ~ name) + + q.map { (name, indexDef, tableName) => + (tableName, Index(name, indexDef)) + }.groupBy(_._1) + .map((tName, v) => (tName, v.map(_._2))) + + private def getEnums = + pgSessionRun: + val q = + sql"""SELECT pt.typname,pe.enumlabel FROM pg_enum AS pe JOIN pg_type AS pt ON pt.oid = pe.enumtypid""" + .all( + name ~ name + ) + + q.groupBy(_._1).toVector.map { (name, values) => + Enum(name, values.map(_._2).map(EnumValue(_))) + } + + private def getViews: Future[Set[TableName]] = + pgSessionRun: + sql"""SELECT table_name FROM information_schema.VIEWS WHERE table_schema = 'public' + UNION + SELECT matviewname FROM pg_matviews WHERE schemaname = 'public';""".all(name).toSet + + def run(): Future[List[File]] = + for + _ <- Future { + if debug then println("Running migrations...") + + val sortedFiles = sourceFiles + .map(p => + MigrationVersion.fromFileName(p.getFileName().toString()) match + case Right(v) => p -> v + case Left(err) => throw Throwable(s"Invalid migration file name: $err") + ) + .sortBy((_, version) => version) + 
.map((path, _) => path) + + Zone: + Using( + Database(connectionString).either match + case Left(err) => throw err + case Right(db) => db + )(db => + sortedFiles.foreach: path => + if debug then println(s"Running migration for $path") + db.execute(Files.readString(path)).either match + case Left(err) => throw err + case _ => () + ) + } + enums <- getEnums + tables <- getColumns(enums) + .zip(getIndexes) + .zip(getConstraints) + .zip(getViews) + .map: + case (((columns, indexes), constraints), views) => toTables(columns, indexes, constraints, views) + filesToWrite = pkgFiles(tables, enums) ::: tables.flatMap { table => + rowFileContent(table) match { + case None => Nil + case Some(rowContent) => + List( + outDir / s"${table.tableClassName}.scala" -> tableFileContent(table), + outDir / s"${table.rowClassName}.scala" -> rowContent + ) + } + } + _ <- + if Files.exists(pkgDir) then + Future: + listFilesRec(pkgDir) + .sortBy(_.getNameCount)(using Ordering[Int].reverse) + .foreach(Files.delete(_)) + Files.createDirectories(outDir) + else Future(Files.createDirectories(outDir)) + files <- Future.traverse(filesToWrite): (path, content) => + Future: + Files.writeString(path, content) + println(s"Created ${path.toString()}") + File(path.toString()) + yield files + end run + + private def toTables( + columns: TableMap[Column], + indexes: TableMap[Index], + constraints: TableMap[Constraint], + views: Set[TableName] + ): List[Table] = { + + def findAutoIncColumns(tableName: TableName) = + columns + .getOrElse(tableName, Vector.empty) + .filter(_.default.contains(ColumnDefault.AutoInc)) + + def findAutoPk(tableName: TableName): Option[Column] = findAutoIncColumns(tableName) + .find(col => + constraints + .getOrElse(tableName, Nil) + .collect { case c: Constraint.PrimaryKey => c } + .exists(_.columnNames.contains(col.columnName)) + ) + + columns.toList.map { case (tname, tableCols) => + val tableConstraints = constraints.getOrElse(tname, Vector.empty) + val generatedCols = findAutoIncColumns(tname) ++ tableCols.filter(_.isAlwaysGenerated) + val autoIncFk = tableConstraints.collect { case c: Constraint.ForeignKey => c }.flatMap { + _.refs.flatMap { ref => + tableCols.find(c => c.columnName == ref.fromColName).filter { _ => + findAutoPk(ref.toTableName).exists(_.columnName == ref.toColName) + } + } + } + + Table( + name = tname, + columns = tableCols.filterNot((generatedCols ++ autoIncFk).contains).toList, + generatedColumns = generatedCols.toList, + constraints = tableConstraints.toList, + indexes = indexes.getOrElse(tname, Vector.empty).toList, + autoIncFk = autoIncFk.toList, + isView = views.contains(tname) + ) + } + } + + private def scalaEnums(enums: Enums): Vector[(Path, String)] = + enums.map { e => + ( + outDir / s"${e.scalaName}.scala", + s"""|package $pkgName + | + |import skunk.Codec + |import skunk.data.Type + | + |enum ${e.scalaName}(val value: String): + | ${e.values.map(v => s"""case ${v.scalaName} extends ${e.scalaName}("${v.name}")""").mkString("\n ")} + | + |object ${e.scalaName}: + | given codec: Codec[${e.scalaName}] = + | Codec.simple[${e.scalaName}]( + | a => a.value, + | s =>${e.scalaName}.values.find(_.value == s).toRight(s"Invalid ${e.name} type: $$s"), + | Type("${e.name}"), + | )""".stripMargin + ) + } + + private def pkgFiles(tables: List[Table], enums: Enums): List[(Path, String)] = { + val indexes = tables.flatMap { table => + table.indexes.map(i => + s"""val ${toScalaName(i.name)} = Index(name = "${i.name}", createSql = \"\"\"${i.createSql}\"\"\")""" + ) + } + + val 
constraints = tables.flatMap { table => + table.constraints.map(c => s"""val ${toScalaName(c.name)} = Constraint(name = "${c.name}")""") + } + + val arrayCodec = + s"""|extension [A](arrCodec: skunk.Codec[skunk.data.Arr[A]]) + | def _list(using factory: scala.collection.Factory[A, List[A]]): skunk.Codec[List[A]] = + | arrCodec.imap(arr => arr.flattenTo(factory))(xs => skunk.data.Arr.fromFoldable(xs))""".stripMargin + + val pkgLastPart = pkgName.split('.').last + List( + ( + outDir / "package.scala", + List( + s"package $pkgName", + "", + arrayCodec, + if indexes.nonEmpty then indexes.mkString("\nobject indexes:\n ", "\n ", "\n") else "", + if constraints.nonEmpty then constraints.mkString("\nobject constraints:\n ", "\n ", "\n") else "" + ).mkString("\n") + ), + ( + outDir / "Index.scala", + s"""|package $pkgName + | + |final case class Index(name: String, createSql: String) + """.stripMargin + ), + ( + outDir / "Constraint.scala", + s"""|package $pkgName + | + |final case class Constraint(name: String) + """.stripMargin + ), + ( + outDir / "Cols.scala", + s"""|package $pkgName + |import skunk.* + |import cats.data.NonEmptyList + |import cats.implicits.* + | + |final case class Cols[A] private[$pkgLastPart] (names: NonEmptyList[String], codec: Codec[A], tableAlias: String) + | extends (A => AppliedCol[A]) { + | def name: String = names.intercalate(",") + | def fullName: String = names.map(n => s"$${tableAlias}.$$n").intercalate(",") + | def aliasedName: String = names.map(name => s"$${tableAlias}.$${name} $${tableAlias}__$$name").intercalate(",") + | def ~[B](that: Cols[B]): Cols[(A, B)] = Cols(this.names ::: that.names, this.codec ~ that.codec, this.tableAlias) + | def apply(a: A): AppliedCol[A] = AppliedCol(this, a) + |} + | + |final case class AppliedCol[A] (cols: Cols[A], value: A) { + | def name = cols.name + | def fullName = cols.fullName + | def codec = cols.codec + | + | def ~[B] (that: AppliedCol[B]): AppliedCol[(A, B)] = AppliedCol(this.cols ~ that.cols, (this.value, that.value)) + |} + |""".stripMargin + ) + ) ++ scalaEnums(enums) + } + + private def toScalaType(t: Type, isNullable: Boolean, enums: Enums): Result[ScalaType] = + t.componentTypes match { + case Nil => + Map[String, List[String]]( + "Boolean" -> List("bool"), + "String" -> List("text", "varchar", "bpchar", "name"), + "java.util.UUID" -> List("uuid"), + "Short" -> List("int2"), + "Int" -> List("int4"), + "Long" -> List("int8"), + "BigDecimal" -> List("numeric"), + "Float" -> List("float4"), + "Double" -> List("float8"), + "java.time.LocalDate" -> List("date"), + "java.time.LocalTime" -> List("time"), + "java.time.OffsetTime" -> List("timetz"), + "java.time.LocalDateTime" -> List("timestamp"), + "java.time.OffsetDateTime" -> List("timestamptz"), + "java.time.Duration" -> List("interval") + ).collectFirst { + // check by type name without a max length parameter if set, e.g. 
varchar instead of varchar(3)
+        case (scalaType, pgTypes) if pgTypes.contains(t.name.takeWhile(_ != '(')) =>
+          if (isNullable) s"Option[$scalaType]" else scalaType
+      }.orElse {
+        enums.find(_.name == t.name).map(e => if (isNullable) s"Option[${e.scalaName}]" else e.scalaName)
+      }.toRight(s"No scala type found for type ${t.name}")
+      case x :: Nil =>
+        toScalaType(x, isNullable = false, enums).map(t => if (isNullable) s"Option[List[$t]]" else s"List[$t]")
+      case x :: xs =>
+        Left(s"Unsupported type of multiple components: ${x :: xs}")
+    }
+
+  private def rowFileContent(table: Table): Option[String] = {
+    import table.*
+
+    def toClassPropsStr(cols: Seq[Column]) = cols
+      .map(c => s"  ${c.scalaName}: ${c.scalaType}")
+      .mkString("", ",\n", "")
+
+    def toUpdateClassPropsStr(cols: Seq[Column]) = cols
+      .map(c => s"  ${c.scalaName}: Option[${c.scalaType}]")
+      .mkString("", ",\n", "")
+
+    def toCodecFieldsStr(cols: Seq[Column]) = s"${cols.map(_.codecName).mkString(" *: ")}"
+
+    def toUpdateFragment(cols: Seq[Column]) = {
+      def toOptFrStr(c: Column) = s"""${c.scalaName}.map(sql"${c.columnName}=$${${c.codecName}}".apply(_))"""
+      s"""
+  def fragment: AppliedFragment = List(
+    ${cols.map(toOptFrStr(_)).mkString(",\n")}
+  ).flatten.intercalate(void",")
+  """
+    }
+
+    val rowUpdateClassData =
+      if (table.isView) (Nil, Nil)
+      else
+        primaryUniqueConstraint match {
+          case Some(cstr) =>
+            columns.filterNot(cstr.containsColumn).toList match {
+              case Nil => (Nil, Nil)
+              case updateCols =>
+                val colsData     = toUpdateClassPropsStr(updateCols)
+                val fragmentData = toUpdateFragment(updateCols)
+                (
+                  updateCols,
+                  List(
+                    "",
+                    s"final case class $rowUpdateClassName(",
+                    s"$colsData",
+                    ") {",
+                    fragmentData,
+                    "}"
+                  )
+                )
+            }
+
+          case None => (Nil, Nil)
+        }
+
+    def withImportsStr = rowUpdateClassData match {
+      case (Nil, Nil) => ""
+      case (_, _)     => List("import skunk.implicits.*", "import cats.implicits.*").mkString("\n")
+    }
+
+    def withUpdateStr = rowUpdateClassData match {
+      case (Nil, Nil) => ""
+      case (cols, _) =>
+        val updateProps = cols.map(_.scalaName).map(n => s"$n = Some($n)").mkString("    ", ",\n    ", "")
+        List(
+          " {",
+          s"  def asUpdate: $rowUpdateClassName = $rowUpdateClassName(",
+          s"$updateProps",
+          "  )",
+          "",
+          s"  def withUpdateAll: ($rowClassName, AppliedFragment) = (this, asUpdate.fragment)",
+          s"  def withUpdate(f: $rowUpdateClassName => $rowUpdateClassName): ($rowClassName, AppliedFragment) = (this, f(asUpdate).fragment)",
+          "",
+          "}"
+        ).mkString("\n")
+    }
+
+    columns.headOption.map { _ =>
+      val colsData  = toClassPropsStr(columns)
+      val codecData = toCodecFieldsStr(columns)
+      List(
+        s"package $pkgName",
+        "",
+        "import skunk.*",
+        withImportsStr,
+        "",
+        s"final case class $rowClassName(",
+        s"$colsData",
+        s")$withUpdateStr",
+        "",
+        s"object $rowClassName {",
+        s"  given codec: Codec[$rowClassName] = ($codecData).to[$rowClassName]",
+        "}",
+        s"${rowUpdateClassData._2.mkString("\n")}"
+      ).mkString("\n")
+    }
+  }
+
+  private def tableFileContent(table: Table): String = {
+    val (maybeAllCol, cols) = tableColumns(table)
+    (
+      List(
+        s"package $pkgName\n",
+        "import skunk.*",
+        "import skunk.implicits.*",
+        "import cats.data.NonEmptyList"
+      ) :::
+        List(
+          "",
+          s"class ${table.tableClassName}(val tableName: String) {",
+          s"  def withPrefix(prefix: String): ${table.tableClassName} = new ${table.tableClassName}(prefix + tableName)",
+          s"  def withAlias(alias: String): ${table.tableClassName} = new ${table.tableClassName}(alias)",
+          "",
+          maybeAllCol.getOrElse(""),
+          "  object column {",
+          s"$cols",
+          "  
}", + writeStatements(table), + selectAllStatement(table), + "}", + "", + s"""object ${table.tableClassName} extends ${table.tableClassName}("${table.name}")""" + ) + ).mkString("\n") + } + + private def queryTypesStr(table: Table): (String, String) = { + import table.* + + if (autoIncFk.isEmpty) { + (rowClassName, s"${rowClassName}.codec") + } else { + val autoIncFkCodecs = autoIncFk.map(_.codecName).mkString(" *: ") + val autoIncFkScalaTypes = autoIncFk.map(_.scalaType).mkString(" *: ") + (s"($autoIncFkScalaTypes ~ $rowClassName)", s"$autoIncFkCodecs ~ ${rowClassName}.codec") + } + } + + private def writeStatements(table: Table): String = + if (table.isView) "" + else { + import table.* + + val allCols = autoIncFk ++ columns + val allColNames = allCols.map(_.columnName).mkString(",") + val (insertScalaType, insertCodec) = queryTypesStr(table) + + val returningStatement = generatedColumns match { + case Nil => "" + case _ => generatedColumns.map(_.columnName).mkString(" RETURNING ", ",", "") + } + val returningType = generatedColumns + .map(_.scalaType) + .mkString("", " *: ", if (generatedColumns.length > 1) " *: EmptyTuple" else "") + val fragmentType = generatedColumns match { + case Nil => "command" + case _ => s"query(${generatedColumns.map(_.codecName).mkString(" *: ")})" + } + + val upsertQ = primaryUniqueConstraint.map { cstr => + val queryType = generatedColumns match { + case Nil => s"Command[$insertScalaType *: updateFr.A *: EmptyTuple]" + case _ => s"Query[$insertScalaType *: updateFr.A *: EmptyTuple, $returningType]" + } + + s"""| def upsertQuery(updateFr: AppliedFragment, constraint: Constraint = Constraint("${cstr.name}")): $queryType = + | sql\"\"\"INSERT INTO #$$tableName ($allColNames) VALUES ($${$insertCodec}) + | ON CONFLICT ON CONSTRAINT #$${constraint.name} + | DO UPDATE SET $${updateFr.fragment}$returningStatement\"\"\".$fragmentType""".stripMargin + } + + val queryType = generatedColumns match { + case Nil => s"Command[$insertScalaType]" + case _ => s"Query[$insertScalaType, $returningType]" + } + val insertQ = + s"""| def insertQuery(ignoreConflict: Boolean = true): $queryType = { + | val onConflictFr = if (ignoreConflict) const" ON CONFLICT DO NOTHING" else const"" + | sql\"INSERT INTO #$$tableName ($allColNames) VALUES ($${$insertCodec})$$onConflictFr$returningStatement\".$fragmentType + | }""".stripMargin + + val insertCol = + s"""| + | def insert[A](cols: Cols[A]): Command[A] = + | sql\"INSERT INTO #$$tableName (#$${cols.name}) VALUES ($${cols.codec})\".command + | + | def insert0[A, B](cols: Cols[A], rest: Fragment[B] = sql"ON CONFLICT DO NOTHING")(implicit + | ev: Void =:= B + | ): Command[A] = + | (sql\"INSERT INTO #$$tableName (#$${cols.name}) VALUES ($${cols.codec}) " ~ rest).command.contramap[A](a => (a, ev.apply(Void))) + | + | def insert[A, B](cols: Cols[A], rest: Fragment[B] = sql"ON CONFLICT DO NOTHING"): Command[(A, B)] = + | (sql\"INSERT INTO #$$tableName (#$${cols.name}) VALUES ($${cols.codec})" ~ rest).command + |""".stripMargin + List( + upsertQ.getOrElse(""), + insertQ, + insertCol + ).mkString("\n\n") + } + + private def tableColumns(table: Table): (Option[String], String) = { + val allCols = table.generatedColumns ++ table.autoIncFk ++ table.columns + val cols = + allCols.map(column => + s""" val ${column.snakeCaseScalaName} = Cols(NonEmptyList.of("${column.columnName}"), ${column.codecName}, tableName)""" + ) + + val allCol = + if table.columns.nonEmpty then + Some { + val s = table.columns.map(_.columnName).map(x => 
s""""$x"""").mkString(",") + s"""| + | val all = Cols(NonEmptyList.of($s), ${table.rowClassName}.codec, tableName) + |""".stripMargin + } + else None + + allCol -> cols.mkString("\n") + } + + private def selectAllStatement(table: Table): String = { + import table.* + + val generatedColStm = if (generatedColumns.nonEmpty) { + val types = generatedColumns.map(_.codecName).mkString(" *: ") + val sTypes = generatedColumns.map(_.scalaType).mkString(" *: ") + val colNamesStr = (generatedColumns ++ columns).map(_.columnName).mkString(", ") + + s""" + | def selectAllWithGenerated[A](addClause: Fragment[A] = Fragment.empty): Query[A, $sTypes *: $rowClassName *: EmptyTuple] = + | sql"SELECT $colNamesStr FROM #$$tableName $$addClause".query($types *: ${rowClassName}.codec) + | + """.stripMargin + } else { + "" + } + + val colNamesStr = (autoIncFk ++ columns).map(_.columnName).mkString(",") + val (queryReturnType, queryCodec) = queryTypesStr(table) + + val defaultStm = s""" + | def selectAll[A](addClause: Fragment[A] = Fragment.empty): Query[A, $queryReturnType] = + | sql"SELECT $colNamesStr FROM #$$tableName $$addClause".query($queryCodec) + | + |""".stripMargin + + val selectCol = s"""| def select[A, B](cols: Cols[A], rest: Fragment[B] = Fragment.empty): Query[B, A] = + | sql"SELECT #$${cols.name} FROM #$$tableName $$rest".query(cols.codec) + |""".stripMargin + generatedColStm ++ defaultStm ++ selectCol + } +} + +object PgCodeGen { + case class UseDocker(dockerImage: String, dockerName: String) + type UseConnection = URI | UseDocker + + object defaults: + val dumboDockerMigrationCmd = List( + """docker run --rm --net="host"""", + "-v %sourcePath:/migration", + "rolang/dumbo:latest-alpine", + "-user=%user", + "-password=%password", + "-url=postgresql://%host:%port/%database", + "-table=%schemaHistoryTableName", + "-location=/migration", + "migrate" + ).mkString(" ") + + def run( + useDockerImage: String, + outputDir: File, + pkgName: String, + sourceDir: File, + excludeTables: List[String], + scalaVersion: String, + useConnectionUri: Option[URI], + forceRegeneration: Boolean, + debug: Boolean + )(using ExecutionContext): Future[List[File]] = + val pkgDir = Paths.get(outputDir.getPath(), pkgName.replace('.', File.separatorChar)) + def outDir(sha1: String) = pkgDir / sha1 + val schemaHistoryTableName = "dumbo_history" + val useConnection: UseConnection = useConnectionUri match + case None => UseDocker(dockerImage = useDockerImage, dockerName = s"codegen_${pkgName.replace(".", "_")}") + case Some(u) => u + + def cleanup = useConnection match + case d: UseDocker => Future(s"docker rm -f ${d.dockerName}" ! ProcessLogger(_ => ())) + case _ => Future.unit + + def listMigrationFiles: Future[(List[Path], String)] = Future: + val digest = MessageDigest.getInstance("SHA-1") + val files = listFilesRec(sourceDir.toPath) + .map(path => + digest.update(path.toString.getBytes("UTF-8")) + if !Files.isDirectory(path) then digest.update(Files.readAllBytes(path)) + path + ) + .filter(!Files.isDirectory(_)) + + (files, digest.digest().map("%02x".format(_)).mkString) + + if !scalaVersion.startsWith("3") then + Future.failed( + UnsupportedOperationException(s"Scala version smaller than 3 is not supported. 
Used version: $scalaVersion") + ) + else + listMigrationFiles.flatMap: + case (sourceFiles, sha1) => + val isDivergent = !Files.exists(outDir(sha1)) + if forceRegeneration || isDivergent then + for + _ <- + if sourceFiles.isEmpty then + Future.failed(Exception(s"Cannot find any .sql files in ${sourceDir.toPath()}")) + else Future.unit + _ = println("Generating Postgres models") + db <- initGeneratorDatabase(useConnection) + codegen = PgCodeGen( + pkgName = pkgName, + sourceFiles = sourceFiles, + excludeTables = excludeTables, + debug = debug, + user = db.user, + password = db.password, + host = db.host, + port = db.port, + database = db.databaseName, + schemaHistoryTableName = schemaHistoryTableName, + pkgDir = pkgDir, + outDir = outDir(sha1) + ) + files <- codegen + .run() + .transformWith: + case Success(files) => + cleanup.map: _ => + println(s"Generated ${files.length} files") + files + case Failure(err) => cleanup.flatMap(_ => Future.failed(err)) + yield files + else + Future: + println(s"Generated code already exists in ${outDir(sha1)}. Skipping code generation.") + listFilesRec(outputDir.toPath).map(_.toFile) + + private def initGeneratorDatabase(useConnection: UseConnection)(using + ExecutionContext + ): Future[ + ( + host: String, + user: String, + port: Int, + password: String, + databaseName: String + ) + ] = + val useConnectionUri = useConnection match + case u: URI => Some(u) + case _ => None + val useConnectionUserInfo = useConnectionUri.flatMap: + _.getRawUserInfo().split(':') match + case Array(user, password, _*) => Some((user = user, password = password)) + case _ => None + val host = useConnectionUri.map(_.getHost()).getOrElse("localhost") + val user = useConnectionUserInfo.map(_.user).getOrElse("postgres") + val port = useConnectionUri.map(u => u.getPort()).getOrElse(findFreePort()) + val password = useConnectionUserInfo.map(_.password).getOrElse("postgres") + val databaseName = s"codegen_db_${Random.alphanumeric.take(10).mkString.toLowerCase()}" + + def awaitReadiness(connectionString: String) = + @tailrec + def check(attempt: Int): Unit = + Thread.sleep(500) + try { + Zone: + Pool.single(connectionString): pool => + pool.withLease: + val res = sql"SELECT true".one(bool).contains(true) + } catch { + case e: Throwable => + if attempt <= 10 then check(attempt + 1) + else + Console.err.println(s"Could not connect to docker on $host:$port ${e.getMessage()}") + throw e + } + + Future: + check(0) + + ( + host = host, + user = user, + port = port, + password = password, + databaseName = databaseName + ) + + useConnection match + case u: URI => + awaitReadiness(u.toString()).map: res => + Zone: + Pool.single(u.toString())(_.withLease(sql"CREATE DATABASE $databaseName".exec())) + res + case d: UseDocker => + Future( + List( + s"docker run", + s"-p $port:5432", + s"-h $host", + s"-e POSTGRES_USER=$user", + s"-e POSTGRES_PASSWORD=$password", + s"-e POSTGRES_DB=$databaseName", + s"--name ${d.dockerName}", + s"-d ${d.dockerImage}" + ).mkString(" ").!! 
+ ).flatMap(_ => awaitReadiness(s"postgresql://$user:$password@$host:$port/$databaseName")) + end initGeneratorDatabase + + enum MigrationVersion: + def compare(that: MigrationVersion): Int = { + @tailrec + def cmprVersioned(a: List[Int], b: List[Int]): Int = + (a, b) match { + case (xa :: xsa, xb :: xsb) if xa == xb => cmprVersioned(xsa, xsb) + case (xa :: _, xb :: _) => xa.compare(xb) + case (xa :: _, Nil) => xa.compare(0) + case (Nil, xb :: _) => xb.compare(0) + case (Nil, Nil) => 0 + } + + (this, that) match { + case (_: Repeatable, _: Versioned) => 1 + case (_: Versioned, _: Repeatable) => -1 + case (Repeatable(descThis), Repeatable(descThat)) => descThis.compare(descThat) + case (Versioned(thisParts), Versioned(thatParts)) => cmprVersioned(thisParts, thatParts) + } + } + case Versioned(parts: List[Int]) + case Repeatable(name: String) + + object MigrationVersion: + private val versioned = "^V([^_]+)__(.+)\\.sql$".r + private val repeatable = "^R__(.+)\\.sql$".r + + given Ordering[MigrationVersion] with + def compare(x: MigrationVersion, y: MigrationVersion): Int = x.compare(y) + + def fromFileName(name: String): Either[String, MigrationVersion] = name match + case versioned(version, name) => + try Right(MigrationVersion.Versioned(version.split('.').map(_.toInt).toList)) + catch case e: Throwable => Left(s"Invalid version $version: ${e.getMessage()}") + case repeatable(n) => Right(MigrationVersion.Repeatable(n)) + case other => Left(s"Invalid file name $other") + + private def listFilesRec(path: Path): List[Path] = + import scala.jdk.CollectionConverters.* + Files + .walk(path) + .iterator() + .asScala + .toList + + @tailrec + private def findFreePort(): Int = + try + val portCandidate = 1024 + Random.nextInt(65535 - 1024) + val socket = new ServerSocket(portCandidate) + val port = socket.getLocalPort + socket.close() + port + catch case e: Throwable => findFreePort() + + type TableName = String + type TableMap[T] = Map[TableName, Vector[T]] + type Enums = Vector[Enum] + type ScalaType = String + type Result[T] = Either[String, T] + + final case class Enum(name: String, values: Vector[EnumValue]) { + val scalaName: String = toScalaName(name).capitalize + } + final case class EnumValue(name: String) { + val scalaName: String = toScalaName(name.toLowerCase).capitalize + } + + final case class Column( + columnName: String, + scalaType: ScalaType, + pgType: Type, + isEnum: Boolean, + isNullable: Boolean, + default: Option[ColumnDefault], + isAlwaysGenerated: Boolean + ) { + val scalaName: String = toScalaName(columnName) + val snakeCaseScalaName: String = escapeScalaKeywords(columnName) + + def isArr = pgType.componentTypes.nonEmpty + + val codecName: String = + ( + (if (isEnum) s"${toScalaName(pgType.name).capitalize}.codec" else s"skunk.codec.all.${pgType.name}") + + (if (isArr) "._list" else "") + + (if (isNullable) ".opt" else "") + ) + } + + final case class ColumnRef(fromColName: String, toColName: String, toTableName: String) + + sealed trait ColumnDefault + object ColumnDefault { + case object AutoInc extends ColumnDefault + + def fromString(value: String): Option[ColumnDefault] = + if (value.contains("nextval")) Some(AutoInc) else None + } + sealed trait Constraint { + def name: String + } + sealed trait UniqueConstraint extends Constraint { + def columnNames: Vector[String] + + def containsColumn(c: Column): Boolean = columnNames.contains(c.columnName) + } + object Constraint { + final case class PrimaryKey(name: String, columnNames: Vector[String]) extends UniqueConstraint + 
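+    // Note: a Unique constraint can serve as the upsert conflict target when no
+    // primary key exists (see Table.primaryUniqueConstraint below).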
final case class Unique(name: String, columnNames: Vector[String]) extends UniqueConstraint + final case class ForeignKey(name: String, refs: Vector[ColumnRef]) extends Constraint + final case class Unknown(name: String) extends Constraint + } + + final case class Index(name: String, createSql: String) + + final case class Table( + name: String, + columns: List[Column], + generatedColumns: List[Column], + constraints: List[Constraint], + indexes: List[Index], + autoIncFk: List[Column], + isView: Boolean + ) { + val tableClassName: String = toTableClassName(name) + val rowClassName: String = toRowClassName(name) + val rowUpdateClassName: String = toRowUpdateClassName(name) + + val primaryUniqueConstraint: Option[UniqueConstraint] = constraints + .collectFirst { case c: Constraint.PrimaryKey => + c + } + .orElse { + constraints.collectFirst { case c: Constraint.Unique => + c + } + } + + def isInPrimaryConstraint(c: Column): Boolean = primaryUniqueConstraint.exists(_.containsColumn(c)) + } + + final case class ConstraintRow( + tableName: String, + name: String, + typ: String, + refCol: String, + refTable: String, + fromCol: String + ) + + def toScalaName(s: String): String = + escapeScalaKeywords(toCamelCase(s)) + + def escapeScalaKeywords(v: String): String = + v match + case "type" => "`type`" + case "import" => "`import`" + case "val" => "`val`" // add more as required + case v if !v.head.isLetter => s"`$v`" + case v => v + + def toCamelCase(s: String, capitalize: Boolean = false): String = + s.split("_") + .zipWithIndex + .map { + case (t, 0) if !capitalize => t + case (t, _) => t.capitalize + } + .mkString + + private def toRowClassName(s: String): String = + toCamelCase(s, capitalize = true) + "Row" + + private def toRowUpdateClassName(s: String): String = + toCamelCase(s, capitalize = true) + "Update" + + private def toTableClassName(s: String): String = + toCamelCase(s, capitalize = true) + "Table" +} diff --git a/PgCodeGenTest.scala b/PgCodeGenTest.scala new file mode 100644 index 0000000..cd8651d --- /dev/null +++ b/PgCodeGenTest.scala @@ -0,0 +1,262 @@ +//> using scala 3.7.1 +//> using dep dev.rolang::dumbo:0.5.5 +//> using platform jvm +//> using jvm system +//> using file test-generated/generated + +import scala.annotation.tailrec +import scala.concurrent.duration.* +import scala.util.Random + +import java.net.ServerSocket +import java.time.{OffsetDateTime, ZoneOffset} +import java.time.temporal.ChronoUnit + +import cats.effect.{ExitCode, IO, IOApp} +import cats.effect.std.Console +import cats.implicits.* +import dumbo.ConnectionConfig +import fs2.io.file.Path +import generated.* +import org.typelevel.otel4s.trace.Tracer.Implicits.noop +import skunk.* +import skunk.codec.all.* +import skunk.implicits.* +import skunk.util.{Origin, Typer} +import sys.process.* + +object GeneratedCodeTest extends IOApp { + override def run(args: List[String]): IO[ExitCode] = + (for + testDbPort <- IO(findFreePort()) + _ <- IO( + List( + "docker run", + s"-p $testDbPort:5432", + "-h localhost", + "-e POSTGRES_USER=postgres", + "-e POSTGRES_PASSWORD=postgres", + "--name codegen-test", + "-d", + "postgres:17-alpine" + ).mkString(" ").!! 
+      )
+      _ <- awaitReadiness(testDbPort)
+      _ <- migrate(testDbPort)
+      _ <- Session
+             .single[IO](
+               host = "localhost",
+               port = testDbPort,
+               user = "postgres",
+               database = "postgres",
+               password = Some("postgres"),
+               strategy = Typer.Strategy.SearchPath // to include custom types like enums
+             )
+             .use { s =>
+               val (testRow, testUpdateFr) = TestRow(
+                 number = Some(1),
+                 createdAt = OffsetDateTime.now(ZoneOffset.UTC).truncatedTo(ChronoUnit.MILLIS),
+                 template = Some(TestEnumType.T1One),
+                 name = Some("name"),
+                 name2 = "name2",
+                 `type` = Some("type"),
+                 tla = "abc",
+                 tlaVar = "abc",
+                 numericDefault = BigDecimal(1),
+                 numeric24p = BigDecimal(2),
+                 numeric16p2s = BigDecimal(3)
+               ).withUpdateAll
+
+               val testBRow = TestBRow(
+                 keyA = "keyA",
+                 keyB = "keyB",
+                 val1 = "val1",
+                 val2 = "val2",
+                 val3 = "val3",
+                 val4 = "val4",
+                 val5 = "val5",
+                 val6 = "val6",
+                 val7 = "val7",
+                 val8 = "val8",
+                 val9 = "val9",
+                 val10 = "val10",
+                 val11 = "val11",
+                 val12 = "val12",
+                 val13 = "val13",
+                 val14 = "val14",
+                 val15 = "val15",
+                 val16 = "val16",
+                 val17 = "val17",
+                 val18 = "val18",
+                 val19 = "val19",
+                 val20 = "val20",
+                 val21 = "val21",
+                 val22 = "val22",
+                 val23 = "val23",
+                 val24 = "val24",
+                 val25 = "val25",
+                 val26 = List("val26"),
+                 val27 = Some(List(1, 2)),
+                 date = None
+               )
+               for {
+                 // Test table
+                 p   <- s.prepare(TestTable.upsertQuery(testUpdateFr))
+                 _   <- s.prepare(TestTable.insertQuery(ignoreConflict = true))
+                 res <- p.option((testRow, testUpdateFr.argument))
+                 _   <- IO.raiseWhen(res.isEmpty)(new Throwable("test A did not return generated columns"))
+                 id = res.get._1
+                 _ <- IO.raiseWhen(res.get._2 != 2 || res.get._3 != Some(2))(
+                        new Throwable("unexpected result for generated columns")
+                      )
+                 all        <- s.execute(TestTable.selectAll())
+                 allWithGen <- s.execute(TestTable.selectAllWithGenerated())
+                 _          <- IO.raiseWhen(all != List(testRow))(new Throwable("test A result not equal"))
+                 _          <- IO.raiseWhen(allWithGen.map(_._4) != List(testRow))(new Throwable("test A result with id not equal"))
+                 aliasedTestTable = TestTable.withAlias("t")
+                 idAndName2       = aliasedTestTable.column.id ~ aliasedTestTable.column.name_2
+                 xs <-
+                   s.execute(
+                     sql"""SELECT #${idAndName2.aliasedName},#${aliasedTestTable.column.name.fullName} FROM #${TestTable.tableName} #${aliasedTestTable.tableName}"""
+                       .query(idAndName2.codec ~ TestTable.column.name.codec)
+                   )
+                 _ <- IO.raiseWhen(xs != List((id, testRow.name2) -> testRow.name))(
+                        new Throwable("test A select fields not equal")
+                      )
+                 all2 <- s.execute(TestTable.select(TestTable.all))
+                 _    <- IO.raiseWhen(all2 != List(testRow))(new Throwable("test A select all fields not equal"))
+                 // TestB table
+                 testBUpdateAllFr = testBRow.withUpdateAll._2
+
+                 upsertCmd   <- s.prepare(TestBTable.upsertQuery(testBUpdateAllFr))
+                 _           <- upsertCmd.execute((testBRow, testBUpdateAllFr.argument))
+                 allLoaded   <- s.execute(TestBTable.selectAll())
+                 _           <- IO.raiseWhen(List(testBRow) != allLoaded)(new Throwable("test B result not equal"))
+                 loadByIdQ   <- s.prepare(TestBTable.selectAll(sql"WHERE key_a = ${varchar} AND key_b = ${varchar}"))
+                 loadedById  <- loadByIdQ.option((testBRow.keyA, testBRow.keyB))
+                 _           <- IO.raiseWhen(Some(testBRow) != loadedById)(new Throwable("test B result by id not equal"))
+                 notFoundRes <- s.execute(TestBTable.selectAll(sql"WHERE key_a = 'not_existing'"))
+                 _           <- IO.raiseWhen(notFoundRes.nonEmpty)(new Throwable("test B query result should be empty"))
+
+                 testBRowUpdate = testBRow.copy(val1 = "val1_update", val2 = "val2_update")
+                 testBUpdateFr  = testBRowUpdate.withUpdate(_.copy(val2 = 
None))._2 // exclude update of val2
+
+                 _           <- s.execute(TestBTable.upsertQuery(testBUpdateFr))((testBRowUpdate, testBUpdateFr.argument))
+                 afterUpdate <- s.execute(TestBTable.selectAll())
+                 _           <- IO.raiseWhen(afterUpdate.length != 1)(new Throwable("test B result unexpected length"))
+                 _ <-
+                   IO.raiseWhen(
+                     afterUpdate.headOption.map(_.val1) != Some("val1_update") ||
+                       afterUpdate.headOption.map(_.val2) == Some("val2_update") // should not be updated
+                   )(
+                     new Throwable("test B result unexpected update")
+                   )
+                 // Check Enum variable format
+                 _ = Seq(
+                       TestEnumType.T1One,
+                       TestEnumType.T2Two,
+                       TestEnumType.T3Three,
+                       TestEnumType.T4Four,
+                       TestEnumType.T5Five,
+                       TestEnumType.T6six,
+                       TestEnumType.MultipleWordEnum
+                     )
+                 _ <- s.execute(sql"TRUNCATE TABLE #${TestBTable.tableName}".command)
+
+                 _         <- s.execute(TestBTable.insert0(TestBTable.all))(testBRow)
+                 allBTable <- s.execute(TestBTable.select(TestBTable.all))
+                 _         <- IO.raiseWhen(allBTable != List(testBRow))(new Throwable("test B not equal"))
+                 loadedById <-
+                   s.option(
+                     TestBTable.select(
+                       TestBTable.all,
+                       sql"WHERE #${TestBTable.column.key_a.name} = ${TestBTable.column.key_a.codec} AND #${TestBTable.column.key_b.name} = ${TestBTable.column.key_b.codec}"
+                     )
+                   )((testBRow.keyA, testBRow.keyB))
+
+                 _ <- IO.raiseWhen(Some(testBRow) != loadedById)(new Throwable("test B result by id not equal"))
+
+                 _ <-
+                   s.execute(
+                     TestBTable.select(
+                       TestBTable.column.key_a ~ TestBTable.column.key_b ~ TestBTable.column.val_1,
+                       sql"WHERE key_a = 'not_existing'"
+                     )
+                   ).flatMap(notFoundRes =>
+                     IO.raiseWhen(notFoundRes.nonEmpty)(new Throwable("test B query result should be empty"))
+                   )
+
+                 updatingFields =
+                   TestBTable.column.val_27(None) ~ TestBTable.column.val_2("updated_val_2") ~ TestBTable.column.val_14(
+                     "updated_val_14"
+                   )
+                 updateQ = sql"""
+                   ON CONFLICT ON CONSTRAINT #${generated.constraints.testBPkey.name} DO UPDATE SET
+                   (#${updatingFields.name}) = (${updatingFields.codec})
+                 """
+                 _ <- s.execute(TestBTable.insert(TestBTable.all, updateQ))(testBRow *: updatingFields.value *: EmptyTuple)
+                 loadedById <-
+                   s.option(
+                     TestBTable.select(
+                       TestBTable.all,
+                       sql"WHERE #${TestBTable.column.key_a.name} = ${TestBTable.column.key_a.codec} AND #${TestBTable.column.key_b.name} = ${TestBTable.column.key_b.codec}"
+                     )
+                   )((testBRow.keyA, testBRow.keyB))
+                 _ <- IO.raiseWhen(
+                        Some(testBRow.copy(val27 = None, val2 = "updated_val_2", val14 = "updated_val_14")) != loadedById
+                      )(new Throwable("test B result missing update"))
+                 _      <- s.execute(sql"REFRESH MATERIALIZED VIEW test_materialized_view".command)
+                 result <- s.execute(TestMaterializedViewTable.selectAll())
+                 _      <- IO.raiseWhen(result.isEmpty)(new Throwable(s"materialized view doesn't have correct value: ${result}"))
+                 _      <- IO.println("Test successful!")
+               } yield ()
+             }
+    yield ()).attempt.flatMap:
+      case Right(_)  => IO("docker rm -f codegen-test".!!).as(ExitCode.Success)
+      case Left(err) => Console[IO].printStackTrace(err) >> IO("docker rm -f codegen-test".!!).as(ExitCode.Error)
+
+  private def migrate(port: Int) = dumbo.Dumbo
+    .withFilesIn[IO](Path("test/migrations"))
+    .apply(
+      connection = ConnectionConfig(
+        host = "localhost",
+        port = port,
+        user = "postgres",
+        database = "postgres",
+        password = Some("postgres")
+      )
+    )
+    .runMigration
+
+  @tailrec
+  private def findFreePort(): Int =
+    try
+      val portCandidate = 1024 + Random.nextInt(65535 - 1024)
+      val socket        = new ServerSocket(portCandidate)
+      val port          = socket.getLocalPort
+      socket.close()
+      port
+    catch case e: Throwable => findFreePort()
+
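+  // Polls the fresh container every 500ms until a simple SELECT succeeds:
+  // failed connection attempts are mapped to Some(error) and retried, while the
+  // first successful one emits None and terminates the stream via
+  // unNoneTerminate (or the 10s timeout aborts the wait).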
+  private def awaitReadiness(port: Int) =
+    fs2.Stream
+      .repeatEval(
+        Session
+          .single[IO](
+            host = "localhost",
+            port = port,
+            user = "postgres",
+            database = "postgres",
+            password = Some("postgres")
+          )
+          .use(_.unique(sql"SELECT 1".query(int4)).void)
+          .attempt
+          .map(_.swap.toOption)
+      )
+      .metered(500.millis)
+      .timeout(10.seconds)
+      .unNoneTerminate
+      .compile
+      .drain
+      .onError(e => IO.println(s"Could not connect to docker on localhost:$port ${e.getMessage()}"))
+}
diff --git a/README.md b/README.md
index 9f9eec3..b4fe546 100644
--- a/README.md
+++ b/README.md
@@ -1,55 +1,105 @@
-# Sbt plugin for generating source code from Postgres database schema
+# skunk-codegen
-![Maven Central Version](https://img.shields.io/maven-central/v/com.anymindgroup/skunk-codegen_2.12)
+`skunk-codegen` is a Scala 3 code generator for PostgreSQL database schemas. It introspects your database and generates type-safe Scala code for use with the [skunk](https://tpolecat.github.io/skunk/) functional Postgres library.
+
+The code generator is based on [roach](https://github.com/indoorvivants/roach), an experimental Scala Native library for Postgres access via libpq, and ships as a native command-line binary (~6MB).
+
+## Features
+
+- **Schema Introspection:** Reads tables, columns, constraints (primary, unique, foreign keys), indexes, and enums from a PostgreSQL database.
+- **Code Generation:** Produces Scala case classes, codecs, and table definitions for each table and enum in your schema (sketched below).
+- **Migration Support:** Runs Flyway-compatible database migrations.
+- **Docker Integration:** Can spin up a PostgreSQL Docker container for isolated code generation.
+- **Customizable:** Supports excluding tables, specifying output/source directories, and customizing package names.
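+As an illustration (a hypothetical sketch, not verbatim output — the exact shapes
+follow the templates in [`PgCodeGen.scala`](PgCodeGen.scala)), a migration containing
+`CREATE TABLE author (id bigserial PRIMARY KEY, name varchar NOT NULL)` yields roughly:
+
+```scala
+// AuthorRow.scala: row case class with a skunk codec; the auto-incremented id
+// is treated as a generated column and returned by the queries instead of
+// living on the row
+final case class AuthorRow(
+  name: String
+)
+
+object AuthorRow {
+  given codec: Codec[AuthorRow] = (skunk.codec.all.varchar).to[AuthorRow]
+}
+
+// AuthorTable.scala: typed columns plus insert/upsert/select helpers
+class AuthorTable(val tableName: String) { /* column, insertQuery, selectAll, ... */ }
+object AuthorTable extends AuthorTable("author")
+```
+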
 ## Usage
-Add plugin to `project/plugins.sbt`
-```scala
-addSbtPlugin("com.anymindgroup" % "sbt-skunk-codegen" % "x.y.z")
+Run the generator via the command line:
+(_Ensure `libpq` and `docker` are installed on your system_)
+
+```shell
+# download executable (for Linux / x86_64)
+curl -fL https://github.com/AnyMindGroup/skunk-codegen/releases/latest/download/skunk-codegen-x86_64-linux > skunk_codegen && chmod +x skunk_codegen
+
+# run code generator
+./skunk_codegen \
+  -use-docker-image="postgres:17-alpine" \
+  -output-dir=my/out/dir \
+  -pkg-name=my.package \
+  -exclude-tables=table_name_a,table_name_b \
+  -source-dir=path/to/db/migrations
 ```
-Enable plugin for the project and configure:
+**Command line arguments:**
+- `-output-dir` (required): Output directory for generated Scala files
+- `-pkg-name` (required): Scala package name for generated code
+- `-source-dir` (required): Directory containing migration SQL files
+- `-use-docker-image`: Docker image for Postgres (default: postgres:17-alpine)
+- `-use-connection`: Use a custom Postgres connection URI (no new Postgres Docker container is started if set)
+- `-exclude-tables`: Comma-separated list of tables to exclude
+- `-scala-version`: Scala version (default: 3.7.1)
+- `-debug`: Enable debug output (`true`/`1` to enable)
+- `-force`: Force code generation, ignoring the cache (`true`/`1` to enable)
+
+#### Example usage of the command line as a source generator in [sbt](https://www.scala-sbt.org):
+
 ```scala
 lazy val myProject = (project in file("."))
-  .enablePlugins(PgCodeGenPlugin)
   .settings(
-    name := "test",
-    Compile / scalacOptions ++= Seq("-Xsource:3", "-release:17"),
-    // Generator settings
-    pgCodeGenOutputPackage := "com.example", // output package
-
-    // postgres connection settings
-    pgCodeGenHost := "localhost", // default: "localhost"
-    pgCodeGenPort := 5432,
-    pgCodeGenUser := "postgres", // default: "postgres"
-    pgCodeGenDb := "postgres", // default: "postgres"
-    pgCodeGenPassword := Some("postgres"), // default: None
-    // pgCodeGenOperateDB value will create new database with specified
-    // name if not exist for pgCodeGen migration process. Recommend to be configure differently
-    // with multiple module in the same project
-    pgCodeGenOperateDB := Some("postgres_b") // default: None
-
-    // whether to start a postgres docker container and what image to use on running the task (default: Some("postgres:16-alpine"))
-    pgCodeGenUseDockerImage := Some("postgres:16-alpine"),
-
-    // path to directory with sql migration script
-    pgCodeGenSqlSourceDir := file("src") / "main" / "resources" / "db" / "migration",
-
-    // list of tables to exclude from generator
-    pgCodeGenExcludedTables := List("to_exclude_table_name")
-  )
-```
-See all available settings under [PgCodeGenPlugin.scala](modules/sbt/src/main/scala/com/anymindgroup/sbt/PgCodeGenPlugin.scala).
-See example setup under [sbt test](modules/sbt/src/sbt-test/test/basic).
+    Compile / sourceGenerators += skunkCodeGenTask(
+      pkgName = "my.package",
+      migrationsDir = file("src") / "main" / "resources" / "db" / "migration"
+    )
+  )
-Generator will run on changes to sql migration scripts.
-Watch and re-compile on changes by e.g.: -```shell -sbt ~compile -``` +def skunkCodeGenTask( + pkgName: String, + migrationsDir: File, + excludeTables: List[String] = Nil, +) = Def.task { + import sys.process.* + import scala.jdk.CollectionConverters.* + import java.nio.file.Files -To force code re-generation, execute the task -```shell -sbt pgCodeGen + val logger = streams.value.log + val outDir = (Compile / sourceManaged).value + val outPkgDir = outDir / pkgName.split('.').mkString(java.io.File.separator) + + val cmd = List( + "./path/to/skunk_codegen", + s"-output-dir=${outDir.getPath()}", + s"-pkg-name=$pkgName", + s"-source-dir=${migrationsDir.getPath()}", + s"-exclude-tables=${excludeTables.mkString(",")}", + "-force=false", + ).mkString(" ") + + logger.debug(s"Running skunk code generator with: $cmd") + + val errs = scala.collection.mutable.ListBuffer.empty[String] + cmd ! ProcessLogger(i => logger.info(s"[Skunk codegen] $i"), e => errs += e) match { + case 0 => () + case c => throw new InterruptedException(s"Failure on code generation:\n${errs.mkString("\n")}") + } + + Files + .walk(outPkgDir.toPath) + .iterator() + .asScala + .collect { + case p if !Files.isDirectory(p) => p.toFile + } + .toList +} ``` + + +## Output + +- Scala files for each table and enum in the specified package directory. +- Type-safe codecs and helper methods for querying and updating tables. +- Support for array types, nullable columns, and enum mappings. + +--- + +For more details, see the code in [`PgCodeGen.scala`](PgCodeGen.scala). \ No newline at end of file diff --git a/build.sbt b/build.sbt deleted file mode 100644 index fe396ce..0000000 --- a/build.sbt +++ /dev/null @@ -1,95 +0,0 @@ -lazy val scala212 = "2.12.19" -lazy val scala213 = "2.13.15" -lazy val scala3 = "3.3.4" -lazy val allScala = Seq(scala212, scala213, scala3) - -ThisBuild / organization := "com.anymindgroup" -ThisBuild / organizationName := "AnyMind Group" -ThisBuild / organizationHomepage := Some(url("https://anymindgroup.com")) -ThisBuild / licenses := Seq(License.Apache2) -ThisBuild / homepage := Some(url("https://github.com/AnyMindGroup/sbt-skunk-codegen")) -ThisBuild / scmInfo := Some( - ScmInfo( - url("https://github.com/AnyMindGroup/sbt-skunk-codegen"), - "scm:git@github.com:AnyMindGroup/sbt-skunk-codegen.git", - ) -) -ThisBuild / description := "SBT plugin for generating source code from Postgres database schema." 
-ThisBuild / developers := List( - Developer("rolang", "Roman Langolf", "@rolang", url("https://github.com/rolang")), - Developer("dutch3883", "Panuwach Boonyasup", "@dutch3883", url("https://github.com/dutch3883")), - Developer("qhquanghuy", "Huy Nguyen", "@qhquanghuy", url("https://github.com/qhquanghuy")), - Developer("alialiusefi", "Ali Al-Yousefi", "@alialiusefi", url("https://github.com/alialiusefi")) -) -ThisBuild / sonatypeCredentialHost := xerial.sbt.Sonatype.sonatypeCentralHost - -lazy val betterFilesVersion = "3.9.2" -lazy val commonSettings = List( - libraryDependencies ++= { - if (scalaVersion.value == scala3) - Seq() - else - Seq(compilerPlugin("com.olegpy" %% "better-monadic-for" % "0.3.1")) - }, - version ~= { v => if (v.contains('+')) s"${v.replace('+', '-')}-SNAPSHOT" else v }, - Test / scalacOptions --= Seq("-Xfatal-warnings"), -) - -lazy val sbtSkunkCodegen = (project in file(".")) - .dependsOn(core, sbtPlugin) - .aggregate(core, sbtPlugin) - .settings(noPublishSettings) - -val noPublishSettings = List( - publish := {}, - publishLocal := {}, - publishArtifact := false, - publish / skip := true, -) - -val releaseSettings = List( - publishTo := sonatypePublishToBundle.value -) - -lazy val core = (project in file("modules/core")) - .settings( - name := "skunk-codegen", - scalaVersion := scala213, - crossScalaVersions := allScala, - javacOptions ++= Seq("-source", "17"), - Compile / scalacOptions ++= { - if (scalaVersion.value == scala3) - Seq("-source:future") - else if (scalaVersion.value == scala213) - Seq( - "-Ymacro-annotations", - "-Xsource:3", - "-Wconf:cat=scala3-migration:s", - ) // https://github.com/scala/scala/pull/10439 - else - Seq("-Xsource:3") - }, - libraryDependencies ++= Seq( - "dev.rolang" %% "dumbo" % "0.0.9", - "com.github.pathikrit" %% "better-files" % betterFilesVersion, - ), - ) - .settings(commonSettings) - .settings(releaseSettings) - -lazy val sbtPlugin = (project in file("modules/sbt")) - .enablePlugins(SbtPlugin) - .dependsOn(core) - .aggregate(core) - .settings(commonSettings) - .settings(releaseSettings) - .settings( - name := "sbt-skunk-codegen", - sbtPluginPublishLegacyMavenStyle := false, - scalaVersion := scala212, - scriptedLaunchOpts := { - scriptedLaunchOpts.value ++ - Seq("-Xmx1024M", "-Dplugin.version=" + version.value) - }, - scriptedBufferLog := false, - ) diff --git a/modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala b/modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala deleted file mode 100644 index 91c7768..0000000 --- a/modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala +++ /dev/null @@ -1,917 +0,0 @@ -package com.anymindgroup - -import better.files.* -import cats.Show -import cats.data.{NonEmptyList, Validated} -import cats.effect.* -import cats.implicits.* -import com.anymindgroup.PgCodeGen.Constraint.PrimaryKey -import dumbo.{ConnectionConfig, Dumbo, ResourceFile} -import fs2.io.file.Files -import natchez.Trace.Implicits.noop -import skunk.* -import skunk.codec.all.* -import skunk.data.Type -import skunk.implicits.* - -import java.io.File as JFile -import java.nio.charset.Charset -import scala.concurrent.duration.* -import scala.sys.process.* - -class PgCodeGen( - host: String, - user: String, - database: String, - operateDatabase: Option[String], - port: Int, - password: Option[String], - useDockerImage: Option[String], - outputDir: JFile, - pkgName: String, - sourceDir: JFile, - excludeTables: List[String], - scalaVersion: String, -) { - import PgCodeGen.* - - private val pkgDir = 
File(outputDir.toPath(), pkgName.replace('.', JFile.separatorChar)) - private val schemaHistoryTableName = "dumbo_history" - implicit val consoleErrLevel: std.Console[IO] = new std.Console[IO] { - override def readLineWithCharset(charset: Charset): IO[String] = IO.consoleForIO.readLineWithCharset(charset) - override def print[A](a: A)(implicit S: Show[A]): IO[Unit] = IO.unit - override def println[A](a: A)(implicit S: Show[A]): IO[Unit] = IO.unit - override def error[A](a: A)(implicit S: Show[A]): IO[Unit] = IO.consoleForIO.error(a) - override def errorln[A](a: A)(implicit S: Show[A]): IO[Unit] = IO.consoleForIO.errorln(a) - } - - private def getConstraints(s: Session[IO]): IO[TableMap[Constraint]] = { - val q: Query[Void, String ~ String ~ String ~ String ~ String ~ String] = - sql""" - SELECT c.table_name, c.constraint_name, c.constraint_type, cu.column_name, cu.table_name, kcu.column_name - FROM information_schema.table_constraints AS c - JOIN information_schema.key_column_usage as kcu ON kcu.constraint_name = c.constraint_name - JOIN information_schema.constraint_column_usage AS cu ON cu.constraint_name = c.constraint_name - WHERE c.table_schema='public' - """.query(name ~ name ~ varchar ~ name ~ name ~ name) - - s.execute(q.map { case a ~ b ~ c ~ d ~ e ~ f => - ConstraintRow(tableName = a, name = b, typ = c, refCol = d, refTable = e, fromCol = f) - }).map { - _.groupBy(_.tableName).map { case (tName, constraints) => - ( - tName, - constraints.groupBy(c => (c.name, c.typ)).toList.map { - case ((cName, "PRIMARY KEY"), cItems) => - Constraint.PrimaryKey(name = cName, columnNames = cItems.map(_.fromCol)) - case ((cName, "UNIQUE"), cItems) => Constraint.Unique(name = cName, columnNames = cItems.map(_.fromCol)) - case ((cName, "FOREIGN KEY"), cItems) => - Constraint.ForeignKey( - name = cName, - refs = cItems.map { cr => - ColumnRef(fromColName = cr.fromCol, toColName = cr.refCol, toTableName = cr.refTable) - }, - ) - case ((cName, _), _) => Constraint.Unknown(cName) - }, - ) - } - } - } - - private def toType( - udt: String, - maxCharLength: Option[Int], - numPrecision: Option[Int], - numScale: Option[Int], - ): Type = - (udt, maxCharLength, numPrecision, numScale) match { - case (u @ ("bpchar" | "varchar"), Some(l), _, _) => Type(s"$u($l)") - case ("numeric", _, Some(p), Some(s)) => Type(s"numeric($p${if (s > 0) ", " + s.toString else ""})") - case _ => - val componentTypes = if (udt.startsWith("_")) List(Type(udt.stripPrefix("_"))) else Nil - Type(udt, componentTypes) - } - - private def getColumns(s: Session[IO], enums: Enums): IO[TableMap[Column]] = { - val filterFragment: Fragment[Void] = - sql" AND table_name NOT IN (#${(schemaHistoryTableName :: excludeTables).mkString("'", "','", "'")})" - - val q = - sql"""SELECT table_name,column_name,udt_name,character_maximum_length,numeric_precision,numeric_scale,is_nullable,column_default,is_generated - FROM information_schema.COLUMNS WHERE table_schema = 'public'$filterFragment UNION - (SELECT - cls.relname AS table_name, - attr.attname AS column_name, - tp.typname AS udt_name, - information_schema._pg_char_max_length(information_schema._pg_truetypid(attr.*, tp.*), information_schema._pg_truetypmod( - attr.*, tp.*))::information_schema.cardinal_number AS character_maximum_length, - information_schema._pg_numeric_precision(information_schema._pg_truetypid(attr.*, tp.*), information_schema._pg_truetypmod( - attr.*, tp.*))::information_schema.cardinal_number AS numeric_precision, - 
information_schema._pg_numeric_scale(information_schema._pg_truetypid(attr.*, tp.*), information_schema._pg_truetypmod( - attr.*, tp.*))::information_schema.cardinal_number AS numeric_scale, - CASE - WHEN attr.attnotnull OR tp.typtype = 'd'::"char" AND tp.typnotnull THEN 'NO'::text - ELSE 'YES'::text - END::information_schema.yes_or_no AS is_nullable, - NULL AS column_default, - 'NEVER' AS is_generated - FROM pg_catalog.pg_attribute as attr - JOIN pg_catalog.pg_class as cls on cls.oid = attr.attrelid - JOIN pg_catalog.pg_namespace as ns on ns.oid = cls.relnamespace - JOIN pg_catalog.pg_type as tp on tp.oid = attr.atttypid - WHERE cls.relkind = 'm' and attr.attnum >= 1 AND ns.nspname = 'public' - ORDER by attr.attnum) - """.query(name ~ name ~ name ~ int4.opt ~ int4.opt ~ int4.opt ~ varchar(3) ~ varchar.opt ~ varchar) - - s.execute(q.map { - case tName ~ colName ~ udt ~ maxCharLength ~ numPrecision ~ numScale ~ nullable ~ default ~ is_generated => - ( - tName, - colName, - toType(udt, maxCharLength, numPrecision, numScale), - nullable == "YES", - default.flatMap(ColumnDefault.fromString), - is_generated == "ALWAYS", - ) - }).map(_.map { case (tName, colName, udt, isNullable, default, isAlwaysGenerated) => - toScalaType(udt, isNullable, enums).map { st => - ( - tName, - Column( - columnName = colName, - pgType = udt, - isEnum = enums.exists(_.name == udt.name), - scalaType = st, - isNullable = isNullable, - default = default, - isAlwaysGenerated = isAlwaysGenerated, - ), - ) - }.leftMap(new Exception(_)) - }).flatMap { - _.traverse(IO.fromEither(_)).map { - _.groupBy(_._1).map { case (k, v) => (k, v.map(_._2)) } - } - } - } - - private def getIndexes(s: Session[IO]): IO[TableMap[Index]] = { - val q: Query[Void, String ~ String ~ String] = - sql"""SELECT indexname,indexdef,tablename FROM pg_indexes WHERE schemaname='public'""".query(name ~ text ~ name) - - s.execute(q.map { case name ~ indexDef ~ tableName => - (tableName, Index(name, indexDef)) - }).map { - _.groupBy(_._1).map { case (tName, v) => (tName, v.map(_._2)) } - } - } - - private def getEnums(s: Session[IO]): IO[Enums] = { - val q: Query[Void, String ~ String] = - sql"""SELECT pt.typname,pe.enumlabel FROM pg_enum AS pe JOIN pg_type AS pt ON pt.oid = pe.enumtypid""".query( - name ~ name - ) - - s.execute(q.map { case name ~ value => - (name, value) - }).map { - _.groupBy(_._1).toList.map { case (name, values) => - Enum(name, values.map(_._2).map(EnumValue(_))) - } - } - } - - private def getViews(s: Session[IO]): IO[Set[TableName]] = { - val q: Query[Void, String] = - sql"""SELECT table_name FROM information_schema.VIEWS WHERE table_schema = 'public' - UNION - SELECT matviewname FROM pg_matviews WHERE schemaname = 'public';""".query(name) - - s.execute(q).map(_.toSet) - } - - private val postgresDBSingleSession = Session - .single[IO]( - host = host, - port = port, - user = user, - database = database, - password = password, - ) - - private val singleSession = Session - .single[IO]( - host = host, - port = port, - user = user, - database = operateDatabase.getOrElse(database), - password = password, - ) - - private val dumboWithFiles = Dumbo.withFilesIn[IO](fs2.io.file.Path.fromNioPath(sourceDir.toPath())) - - private val dumbo = dumboWithFiles.apply( - connection = ConnectionConfig( - host = host, - user = user, - database = operateDatabase.getOrElse(database), - port = port, - password = password, - ), - defaultSchema = "public", - schemaHistoryTable = schemaHistoryTableName, - ) - - private def listMigrationFiles: 
IO[List[ResourceFile]] = dumboWithFiles.listMigrationFiles.flatMap { - case Validated.Invalid(errs) => - IO.raiseError( - new Throwable( - s"Failed reading source files:\n${errs.toList.map(_.getMessage()).mkString("\n")}" - ) - ) - case Validated.Valid(files) => IO.pure(files) - } - - private def generatorTask: IO[List[File]] = - postgresDBSingleSession.use { s => - operateDatabase match { - case Some(opDBName) => - for { - result <- s.execute(sql"SELECT true FROM pg_database WHERE datname = ${varchar}".query(bool))(opDBName) - _ <- IO.whenA(result.isEmpty)(s.execute(sql"CREATE DATABASE #${opDBName};".command).as(())) - } yield () - case None => IO.unit - } - } >> singleSession.use { s => - for { - _ <- s.execute(sql"DROP SCHEMA public CASCADE;".command) - _ <- s.execute(sql"CREATE SCHEMA public;".command) - _ <- dumbo.runMigration.void - enums <- getEnums(s) - (((columns, indexes), constraints), views) <- - getColumns(s, enums).parProduct(getIndexes(s)).parProduct(getConstraints(s)).parProduct(getViews(s)) - tables = toTables(columns, indexes, constraints, views) - files = pkgFiles(tables, enums) ::: tables.flatMap { table => - rowFileContent(table) match { - case None => Nil - case Some(rowContent) => - List( - pkgDir / s"${table.tableClassName}.scala" -> tableFileContent(table), - pkgDir / s"${table.rowClassName}.scala" -> rowContent, - ) - } - } - _ <- IO(pkgDir.delete(true).createDirectoryIfNotExists()) - res <- files.parTraverse { case (file, content) => - for { - _ <- IO(file.writeText(content)) - _ <- IO.println(s"Created ${file.pathAsString}") - } yield file - } - } yield res - } - - private val pgServiceName = pkgName.replace('.', '-') - - private def awaitReadiness: IO[Unit] = - fs2.Stream - .repeatEval(postgresDBSingleSession.use(_.unique(sql"SELECT 1".query(int4)).void).attempt.map(_.swap.toOption)) - .metered(500.millis) - .timeout(10.seconds) - .unNoneTerminate - .compile - .drain - .onError(e => IO.println(s"Could not connect to docker on $host:$port ${e.getMessage()}")) - - private def startDocker: IO[Unit] = - useDockerImage match { - case None => IO.unit - case Some(image) => - val cmd = - s"docker run -p $port:5432 -h $host -e POSTGRES_USER=$user ${password.fold("")(p => s"-e POSTGRES_PASSWORD=$p ")}" + - s"--name $pgServiceName -d $image" - IO(cmd.!!) >> awaitReadiness - } - - private def rmDocker: IO[Unit] = if (useDockerImage.nonEmpty) { - IO(s"docker rm -f $pgServiceName" ! 
ProcessLogger(_ => ())).void - } else IO.unit - - private def toTables( - columns: TableMap[Column], - indexes: TableMap[Index], - constraints: TableMap[Constraint], - views: Set[TableName], - ): List[Table] = { - - def findAutoIncColumns(tableName: TableName): List[Column] = - columns - .getOrElse(tableName, Nil) - .filter(_.default.contains(ColumnDefault.AutoInc)) - - def findAutoPk(tableName: TableName): Option[Column] = findAutoIncColumns(tableName) - .find(col => - constraints - .getOrElse(tableName, Nil) - .collect { case c: Constraint.PrimaryKey => c } - .exists(_.columnNames.contains(col.columnName)) - ) - - columns.toList.map { case (tname, tableCols) => - val tableConstraints = constraints.getOrElse(tname, Nil) - val generatedCols = findAutoIncColumns(tname) ::: tableCols.filter(_.isAlwaysGenerated) - val autoIncFk = tableConstraints.collect { case c: Constraint.ForeignKey => c }.flatMap { - _.refs.flatMap { ref => - tableCols.find(c => c.columnName == ref.fromColName).filter { _ => - findAutoPk(ref.toTableName).exists(_.columnName == ref.toColName) - } - } - } - - Table( - name = tname, - columns = tableCols.filterNot((generatedCols ::: autoIncFk).contains), - generatedColumns = generatedCols, - constraints = tableConstraints, - indexes = indexes.getOrElse(tname, Nil), - autoIncFk = autoIncFk, - isView = views.contains(tname), - ) - } - } - - private def scalaEnums(enums: Enums): List[(File, String)] = - enums.map { e => - ( - pkgDir / s"${e.scalaName}.scala", - if (!scalaVersion.startsWith("3")) { - s"""|package $pkgName - | - |import skunk.Codec - |import enumeratum.values.{StringEnumEntry, StringEnum} - |import skunk.data.Type - | - |sealed abstract class ${e.scalaName}(val value: String) extends StringEnumEntry - |object ${e.scalaName} extends StringEnum[${e.scalaName}] { - | ${e.values - .map(v => s"""case object ${v.scalaName} extends ${e.scalaName}("${v.name}")""") - .mkString("\n ")} - | - | val values: IndexedSeq[${e.scalaName}] = findValues - | - | implicit val codec: Codec[${e.scalaName}] = - | Codec.simple[${e.scalaName}]( - | a => a.value, - | s => withValueEither(s).left.map(_.getMessage()), - | Type("${e.name}"), - | ) - |}""".stripMargin - } else { - s"""|package $pkgName - | - |import skunk.Codec - |import skunk.data.Type - | - |enum ${e.scalaName}(val value: String): - | ${e.values.map(v => s"""case ${v.scalaName} extends ${e.scalaName}("${v.name}")""").mkString("\n ")} - | - |object ${e.scalaName}: - | implicit val codec: Codec[${e.scalaName}] = - | Codec.simple[${e.scalaName}]( - | a => a.value, - | s =>${e.scalaName}.values.find(_.value == s).toRight(s"Invalid ${e.name} type: $$s"), - | Type("${e.name}"), - | )""".stripMargin - }, - ) - } - - private def pkgFiles(tables: List[Table], enums: Enums): List[(File, String)] = { - val indexes = tables.flatMap { table => - table.indexes.map(i => - s"""val ${toScalaName(i.name)} = Index(name = "${i.name}", createSql = \"\"\"${i.createSql}\"\"\")""" - ) - }.mkString(" object indexes {\n ", "\n ", "\n }") - - val constraints = tables.flatMap { table => - table.constraints.map(c => s"""val ${toScalaName(c.name)} = Constraint(name = "${c.name}")""") - }.mkString(" object constraints {\n ", "\n ", "\n }") - - val arrayCodec = - s"""| implicit class ListCodec[A](arrCodec: skunk.Codec[skunk.data.Arr[A]]) { - | def _list(implicit factory: scala.collection.compat.Factory[A, List[A]]): skunk.Codec[List[A]] = { - | arrCodec.imap(arr => arr.flattenTo(factory))(xs => skunk.data.Arr.fromFoldable(xs)) - | } - | }""".stripMargin - 
val pkgLastPart = pkgName.split('.').last - List( - ( - pkgDir / "package.scala", - s"""|package ${pkgName.split('.').dropRight(1).mkString(".")} - | - |package object ${pkgLastPart} { - | - |$arrayCodec - | - |$indexes - | - |$constraints - | - |}""".stripMargin, - ), - // ( - // pkgDir / "Column.scala", - // s"""|package $pkgName - // | - // |abstract class Column[T](val name: String) { - // | type Type = T - // | override def toString: String = name - // |} - // |""".stripMargin, - // ), - ( - pkgDir / "Index.scala", - s"""|package $pkgName - | - |final case class Index(name: String, createSql: String) - """.stripMargin, - ), - ( - pkgDir / "Constraint.scala", - s"""|package $pkgName - | - |final case class Constraint(name: String) - """.stripMargin, - ), - ( - pkgDir / "Cols.scala", - s"""|package $pkgName - |import skunk.* - |import cats.data.NonEmptyList - |import cats.implicits.* - | - |final case class Cols[A] private[$pkgLastPart] (names: NonEmptyList[String], codec: Codec[A], tableAlias: String) - | extends (A => AppliedCol[A]) { - | def name: String = names.intercalate(",") - | def fullName: String = names.map(n => s"$${tableAlias}.$$n").intercalate(",") - | def aliasedName: String = names.map(name => s"$${tableAlias}.$${name} $${tableAlias}__$$name").intercalate(",") - | def ~[B](that: Cols[B]): Cols[(A, B)] = Cols(this.names ::: that.names, this.codec ~ that.codec, this.tableAlias) - | def apply(a: A): AppliedCol[A] = AppliedCol(this, a) - |} - | - |final case class AppliedCol[A] (cols: Cols[A], value: A) { - | def name = cols.name - | def fullName = cols.fullName - | def codec = cols.codec - | - | def ~[B] (that: AppliedCol[B]): AppliedCol[(A, B)] = AppliedCol(this.cols ~ that.cols, (this.value, that.value)) - |} - |""".stripMargin, - ), - ) ::: scalaEnums(enums) - } - - private def toScalaType(t: Type, isNullable: Boolean, enums: Enums): Result[ScalaType] = - t.componentTypes match { - case Nil => - Map[String, List[Type]]( - "Boolean" -> bool.types, - "String" -> (text.types ::: varchar.types ::: bpchar.types ::: name.types), - "java.util.UUID" -> uuid.types, - "Short" -> int2.types, - "Int" -> int4.types, - "Long" -> int8.types, - "BigDecimal" -> numeric.types, - "Float" -> float4.types, - "Double" -> float8.types, - "java.time.LocalDate" -> date.types, - "java.time.LocalTime" -> time.types, - "java.time.OffsetTime" -> timetz.types, - "java.time.LocalDateTime" -> timestamp.types, - "java.time.OffsetDateTime" -> timestamptz.types, - "java.time.Duration" -> List(Type.interval), - ).collectFirst { - // check by type name without a max length parameter if set, e.g. 
vacrhar instead of varchar(3) - case (scalaType, pgTypes) if pgTypes.map(_.name).contains(t.name.takeWhile(_ != '(')) => - if (isNullable) s"Option[$scalaType]" else scalaType - }.orElse { - enums.find(_.name == t.name).map(e => if (isNullable) s"Option[${e.scalaName}]" else e.scalaName) - }.toRight(s"No scala type found for type ${t.name}") - case x :: Nil => - toScalaType(x, isNullable = false, enums).map(t => if (isNullable) s"Option[List[$t]]" else s"List[$t]") - case x :: xs => - Left(s"Unsupported type of multiple components: ${x :: xs}") - } - - private def rowFileContent(table: Table): Option[String] = { - import table.* - - def toClassPropsStr(cols: List[Column]) = cols - .map(c => s" ${c.scalaName}: ${c.scalaType}") - .mkString("", ",\n", "") - - def toUpdateClassPropsStr(cols: List[Column]) = cols - .map(c => s" ${c.scalaName}: Option[${c.scalaType}]") - .mkString("", ",\n", "") - - def toCodecFieldsStr(cols: List[Column]) = s"${cols.map(_.codecName).mkString(" *: ")}" - - def toUpdateFragment(cols: List[Column]) = { - def toOptFrStr(c: Column) = s"""${c.scalaName}.map(sql"${c.columnName}=$${${c.codecName}}".apply(_))""" - s""" - def fragment: AppliedFragment = List( - ${cols.map(toOptFrStr(_)).mkString(",\n")} - ).flatten.intercalate(void",") - """ - } - - val rowUpdateClassData = - if (table.isView) (Nil, Nil) - else - primaryUniqueConstraint match { - case Some(cstr) => - columns.filterNot(cstr.containsColumn) match { - case Nil => (Nil, Nil) - case updateCols => - val colsData = toUpdateClassPropsStr(updateCols) - val fragmentData = toUpdateFragment(updateCols) - ( - updateCols, - List( - "", - s"final case class $rowUpdateClassName(", - s"$colsData", - ") {", - fragmentData, - "}", - ), - ) - } - - case None => (Nil, Nil) - } - - def withImportsStr = rowUpdateClassData match { - case (Nil, Nil) => "" - case (_, _) => List("import skunk.implicits.*", "import cats.implicits.*").mkString("\n") - } - - def withUpdateStr = rowUpdateClassData match { - case (Nil, Nil) => "" - case (cols, _) => - val updateProps = cols.map(_.scalaName).map(n => s"$n = Some($n)").mkString(" ", ",\n ", "") - List( - " {", - s" def asUpdate: $rowUpdateClassName = $rowUpdateClassName(", - s"$updateProps", - " )", - "", - s" def withUpdateAll: ($rowClassName, AppliedFragment) = (this, asUpdate.fragment)", - s" def withUpdate(f: $rowUpdateClassName => $rowUpdateClassName): ($rowClassName, AppliedFragment) = (this, f(asUpdate).fragment)", - "", - "}", - ).mkString("\n") - } - - columns.headOption.map { _ => - val colsData = toClassPropsStr(columns) - val codecData = toCodecFieldsStr(columns) - List( - s"package $pkgName", - "", - "import skunk.*", - withImportsStr, - "", - s"final case class $rowClassName(", - s"$colsData", - s")$withUpdateStr", - "", - s"object $rowClassName {", - s" implicit val codec: Codec[$rowClassName] = ($codecData).to[$rowClassName]", - "}", - s"${rowUpdateClassData._2.mkString("\n")}", - ).mkString("\n") - } - } - - private def tableFileContent(table: Table): String = { - val (maybeAllCol, cols) = tableColumns(table) - ( - List( - s"package $pkgName\n", - "import skunk.*", - "import skunk.implicits.*", - "import cats.data.NonEmptyList", - ) ::: - List( - "", - s"class ${table.tableClassName}(val tableName: String) {", - s" def withPrefix(prefix: String): ${table.tableClassName} = new ${table.tableClassName}(prefix + tableName)", - s" def withAlias(alias: String): ${table.tableClassName} = new ${table.tableClassName}(alias)", - "", - maybeAllCol.getOrElse(""), - " object column 
{", - s"$cols", - " }", - writeStatements(table), - selectAllStatement(table), - "}", - "", - s"""object ${table.tableClassName} extends ${table.tableClassName}("${table.name}")""", - ) - ).mkString("\n") - } - - private def queryTypesStr(table: Table): (String, String) = { - import table.* - - if (autoIncFk.isEmpty) { - (rowClassName, s"${rowClassName}.codec") - } else { - val autoIncFkCodecs = autoIncFk.map(_.codecName).mkString(" *: ") - val autoIncFkScalaTypes = autoIncFk.map(_.scalaType).mkString(" *: ") - (s"($autoIncFkScalaTypes ~ $rowClassName)", s"$autoIncFkCodecs ~ ${rowClassName}.codec") - } - } - - private def writeStatements(table: Table): String = - if (table.isView) "" - else { - import table.* - - val allCols = autoIncFk ::: columns - val allColNames = allCols.map(_.columnName).mkString(",") - val (insertScalaType, insertCodec) = queryTypesStr(table) - - val returningStatement = generatedColumns match { - case Nil => "" - case _ => generatedColumns.map(_.columnName).mkString(" RETURNING ", ",", "") - } - val returningType = generatedColumns - .map(_.scalaType) - .mkString("", " *: ", if (generatedColumns.length > 1) " *: EmptyTuple" else "") - val fragmentType = generatedColumns match { - case Nil => "command" - case _ => s"query(${generatedColumns.map(_.codecName).mkString(" *: ")})" - } - - val upsertQ = primaryUniqueConstraint.map { cstr => - val queryType = generatedColumns match { - case Nil => s"Command[$insertScalaType *: updateFr.A *: EmptyTuple]" - case _ => s"Query[$insertScalaType *: updateFr.A *: EmptyTuple, $returningType]" - } - - s"""| def upsertQuery(updateFr: AppliedFragment, constraint: Constraint = Constraint("${cstr.name}")): $queryType = - | sql\"\"\"INSERT INTO #$$tableName ($allColNames) VALUES ($${$insertCodec}) - | ON CONFLICT ON CONSTRAINT #$${constraint.name} - | DO UPDATE SET $${updateFr.fragment}$returningStatement\"\"\".$fragmentType""".stripMargin - } - - val queryType = generatedColumns match { - case Nil => s"Command[$insertScalaType]" - case _ => s"Query[$insertScalaType, $returningType]" - } - val insertQ = - s"""| def insertQuery(ignoreConflict: Boolean = true): $queryType = { - | val onConflictFr = if (ignoreConflict) const" ON CONFLICT DO NOTHING" else const"" - | sql\"INSERT INTO #$$tableName ($allColNames) VALUES ($${$insertCodec})$$onConflictFr$returningStatement\".$fragmentType - | }""".stripMargin - - val insertCol = - s"""| - | def insert[A](cols: Cols[A]): Command[A] = - | sql\"INSERT INTO #$$tableName (#$${cols.name}) VALUES ($${cols.codec})\".command - | - | def insert0[A, B](cols: Cols[A], rest: Fragment[B] = sql"ON CONFLICT DO NOTHING")(implicit - | ev: Void =:= B - | ): Command[A] = - | (sql\"INSERT INTO #$$tableName (#$${cols.name}) VALUES ($${cols.codec}) " ~ rest).command.contramap[A](a => (a, ev.apply(Void))) - | - | def insert[A, B](cols: Cols[A], rest: Fragment[B] = sql"ON CONFLICT DO NOTHING"): Command[(A, B)] = - | (sql\"INSERT INTO #$$tableName (#$${cols.name}) VALUES ($${cols.codec})" ~ rest).command - |""".stripMargin - List( - upsertQ.getOrElse(""), - insertQ, - insertCol, - ).mkString("\n\n") - } - - private def tableColumns(table: Table): (Option[String], String) = { - val allCols = table.generatedColumns ::: table.autoIncFk ::: table.columns - val cols = - allCols.map(column => - s""" val ${column.snakeCaseScalaName} = Cols(NonEmptyList.of("${column.columnName}"), ${column.codecName}, tableName)""" - ) - - val allCol = NonEmptyList - .fromList(table.columns.map(_.columnName)) - .map { xs => - val s = xs.map(x 
=> s""""$x"""").intercalate(",") - s"""| - | val all = Cols(NonEmptyList.of($s), ${table.rowClassName}.codec, tableName) - |""".stripMargin - } - - allCol -> cols.mkString("\n") - } - - private def selectAllStatement(table: Table): String = { - import table.* - - val generatedColStm = if (generatedColumns.nonEmpty) { - val types = generatedColumns.map(_.codecName).mkString(" *: ") - val sTypes = generatedColumns.map(_.scalaType).mkString(" *: ") - val colNamesStr = (generatedColumns ::: columns).map(_.columnName).mkString(", ") - - s""" - | def selectAllWithGenerated[A](addClause: Fragment[A] = Fragment.empty): Query[A, $sTypes *: $rowClassName *: EmptyTuple] = - | sql"SELECT $colNamesStr FROM #$$tableName $$addClause".query($types *: ${rowClassName}.codec) - | - """.stripMargin - } else { - "" - } - - val colNamesStr = (autoIncFk ::: columns).map(_.columnName).mkString(",") - val (queryReturnType, queryCodec) = queryTypesStr(table) - - val defaultStm = s""" - | def selectAll[A](addClause: Fragment[A] = Fragment.empty): Query[A, $queryReturnType] = - | sql"SELECT $colNamesStr FROM #$$tableName $$addClause".query($queryCodec) - | - |""".stripMargin - - val selectCol = s"""| def select[A, B](cols: Cols[A], rest: Fragment[B] = Fragment.empty): Query[B, A] = - | sql"SELECT #$${cols.name} FROM #$$tableName $$rest".query(cols.codec) - |""".stripMargin - generatedColStm ++ defaultStm ++ selectCol - } - - private def lastModified(modified: List[Long]): Option[Long] = - modified.sorted(Ordering[Long].reverse).headOption - - private def outputFilesOutdated(sourcesModified: List[Long]): Boolean = (for { - s <- lastModified(sourcesModified) - outFiles = if (pkgDir.exists) pkgDir.list.toList else Nil - o <- lastModified(outFiles.map(_.lastModifiedTime.getEpochSecond())) - // can't rely on timestamps when running in CI - isNotCI = sys.env.get("CI").isEmpty - } yield isNotCI && o < s).getOrElse(true) - - def run(forceRegeneration: Boolean = false): IO[List[JFile]] = - listMigrationFiles.flatMap { sourceFiles => - sourceFiles - .map(_.path) - .traverse(Files[IO].getLastModifiedTime(_)) - .map(l => outputFilesOutdated(l.map(_.toSeconds))) - .map((sourceFiles, _)) - }.flatMap { case (sourceFiles, isOutdated) => - (if ((forceRegeneration || (!pkgDir.exists() || isOutdated))) { - (for { - _ <- IO.raiseWhen(sourceFiles.isEmpty)(new Exception(s"Cannot find any .sql files in ${sourceDir.toPath()}")) - _ <- IO.whenA(!pkgDir.exists())(IO.println("Generated source not found")) - _ <- IO.whenA(pkgDir.exists() && isOutdated)(IO.println("Generated source is outdated")) - _ <- IO.println("Generating Postgres models") - _ <- rmDocker - _ <- startDocker - files <- generatorTask - _ <- rmDocker - } yield files).onError(_ => rmDocker) - } else { - IO(pkgDir.list.toList) - }) - .map(_.map(_.toJava)) - } - - def unsafeRunSync(forceRegeneration: Boolean = false): Seq[JFile] = { - import cats.effect.unsafe.implicits.global - run(forceRegeneration).unsafeRunSync() - } -} - -object PgCodeGen { - type TableName = String - type TableMap[T] = Map[TableName, List[T]] - type Enums = List[Enum] - type ScalaType = String - type Result[T] = Either[String, T] - - final case class Enum(name: String, values: List[EnumValue]) { - val scalaName: String = toScalaName(name).capitalize - } - final case class EnumValue(name: String) { - val scalaName: String = toScalaName(name.toLowerCase).capitalize - } - - final case class Column( - columnName: String, - scalaType: ScalaType, - pgType: Type, - isEnum: Boolean, - isNullable: Boolean, - default: 
Option[ColumnDefault], - isAlwaysGenerated: Boolean, - ) { - val scalaName: String = toScalaName(columnName) - val snakeCaseScalaName: String = escapeScalaKeywords(columnName) - - def isArr = pgType.componentTypes.nonEmpty - - val codecName: String = - ( - (if (isEnum) s"${toScalaName(pgType.name).capitalize}.codec" else s"skunk.codec.all.${pgType.name}") + - (if (isArr) "._list" else "") + - (if (isNullable) ".opt" else "") - ) - } - - final case class ColumnRef(fromColName: String, toColName: String, toTableName: String) - - sealed trait ColumnDefault - object ColumnDefault { - case object AutoInc extends ColumnDefault - - def fromString(value: String): Option[ColumnDefault] = - if (value.contains("nextval")) Some(AutoInc) else None - } - sealed trait Constraint { - def name: String - } - sealed trait UniqueConstraint extends Constraint { - def columnNames: List[String] - - def containsColumn(c: Column): Boolean = columnNames.contains(c.columnName) - } - object Constraint { - final case class PrimaryKey(name: String, columnNames: List[String]) extends UniqueConstraint - final case class Unique(name: String, columnNames: List[String]) extends UniqueConstraint - final case class ForeignKey(name: String, refs: List[ColumnRef]) extends Constraint - final case class Unknown(name: String) extends Constraint - } - - final case class Index(name: String, createSql: String) - - final case class Table( - name: String, - columns: List[Column], - generatedColumns: List[Column], - constraints: List[Constraint], - indexes: List[Index], - autoIncFk: List[Column], - isView: Boolean, - ) { - val tableClassName: String = toTableClassName(name) - val rowClassName: String = toRowClassName(name) - val rowUpdateClassName: String = toRowUpdateClassName(name) - - val primaryUniqueConstraint: Option[UniqueConstraint] = constraints.collectFirst { case c: PrimaryKey => - c - }.orElse { - constraints.collectFirst { case c: Constraint.Unique => - c - } - } - - def isInPrimaryConstraint(c: Column): Boolean = primaryUniqueConstraint.exists(_.containsColumn(c)) - } - - final case class ConstraintRow( - tableName: String, - name: String, - typ: String, - refCol: String, - refTable: String, - fromCol: String, - ) - - def toScalaName(s: String): String = - escapeScalaKeywords(toCamelCase(s)) - - def escapeScalaKeywords(v: String): String = - v match { - case "type" => "`type`" - case "import" => "`import`" - case "val" => "`val`" // add more as required - case v => v - } - - def toCamelCase(s: String, capitalize: Boolean = false): String = - s.split("_") - .zipWithIndex - .map { - case (t, 0) if !capitalize => t - case (t, _) => t.capitalize - } - .mkString - - private def toRowClassName(s: String): String = - toCamelCase(s, capitalize = true) + "Row" - - private def toRowUpdateClassName(s: String): String = - toCamelCase(s, capitalize = true) + "Update" - - private def toTableClassName(s: String): String = - toCamelCase(s, capitalize = true) + "Table" -} diff --git a/modules/core/src/test/scala-2.12/com/anymindgroup/testsupport.scala b/modules/core/src/test/scala-2.12/com/anymindgroup/testsupport.scala deleted file mode 100644 index 9970a1a..0000000 --- a/modules/core/src/test/scala-2.12/com/anymindgroup/testsupport.scala +++ /dev/null @@ -1,5 +0,0 @@ -package com.anymindgroup - -object testsupport { - val scalaVersion: String = "2.12.x" -} diff --git a/modules/core/src/test/scala-2.13/com/anymindgroup/testsupport.scala b/modules/core/src/test/scala-2.13/com/anymindgroup/testsupport.scala deleted file mode 100644 index 
12c7298..0000000 --- a/modules/core/src/test/scala-2.13/com/anymindgroup/testsupport.scala +++ /dev/null @@ -1,5 +0,0 @@ -package com.anymindgroup - -object testsupport { - val scalaVersion: String = "2.13.x" -} diff --git a/modules/core/src/test/scala-3/com/anymindgroup/testsupport.scala b/modules/core/src/test/scala-3/com/anymindgroup/testsupport.scala deleted file mode 100644 index 52591b2..0000000 --- a/modules/core/src/test/scala-3/com/anymindgroup/testsupport.scala +++ /dev/null @@ -1,5 +0,0 @@ -package com.anymindgroup - -object testsupport { - val scalaVersion: String = "3.3.x" -} diff --git a/modules/core/src/test/scala/com/anymindgroup/GeneratedCodeTest._scala b/modules/core/src/test/scala/com/anymindgroup/GeneratedCodeTest._scala deleted file mode 100644 index f5ad61a..0000000 --- a/modules/core/src/test/scala/com/anymindgroup/GeneratedCodeTest._scala +++ /dev/null @@ -1,223 +0,0 @@ -package com.anymindgroup - -import cats.effect.IOApp -import cats.effect.{ExitCode, IO} -import natchez.Trace.Implicits.noop -import com.anymindgroup.generated.TestBTable -import com.anymindgroup.generated.TestMaterializedViewTable -import com.anymindgroup.generated.TestBRow -import skunk.implicits.* -import com.anymindgroup.generated.TestTable -import com.anymindgroup.generated.TestRow -import com.anymindgroup.generated.TestEnumType -import skunk.util.Typer -import java.time.OffsetDateTime -import better.files.File -import skunk.* -import skunk.codec.all.* -import skunk.util.Origin -import skunk.Command as SqlCommand -import cats.implicits.* -import java.time.temporal.ChronoUnit -import java.time.ZoneOffset - -object GeneratedCodeTest extends IOApp { - - override def run(args: List[String]): IO[ExitCode] = migrate >> Session - .single[IO]( - host = "localhost", - port = 5432, - user = "postgres", - database = "postgres", - password = Some("postgres"), - strategy = Typer.Strategy.SearchPath, // to include custom types like enums, - ) - .use { s => - val (testRow, testUpdateFr) = TestRow( - number = Some(1), - createdAt = OffsetDateTime.now(ZoneOffset.UTC).truncatedTo(ChronoUnit.MILLIS), - template = Some(TestEnumType.T1One), - name = Some("name"), - name2 = "name2", - `type` = Some("type"), - tla = "abc", - tlaVar = "abc", - numericDefault = BigDecimal(1), - numeric24p = BigDecimal(2), - numeric16p2s = BigDecimal(3), - ).withUpdateAll - - val testBRow = TestBRow( - keyA = "keyA", - keyB = "keyB", - val1 = "val1", - val2 = "val2", - val3 = "val3", - val4 = "val4", - val5 = "val5", - val6 = "val6", - val7 = "val7", - val8 = "val8", - val9 = "val9", - val10 = "val10", - val11 = "val11", - val12 = "val12", - val13 = "val13", - val14 = "val14", - val15 = "val15", - val16 = "val16", - val17 = "val17", - val18 = "val18", - val19 = "val19", - val20 = "val20", - val21 = "val21", - val22 = "val22", - val23 = "val23", - val24 = "val24", - val25 = "val25", - val26 = List("val26"), - val27 = Some(List(1, 2)), - date = None, - ) - for { - // Test table - p <- s.prepare(TestTable.upsertQuery(testUpdateFr)) - _ <- s.prepare(TestTable.insertQuery(ignoreConflict = true)) - res <- p.option((testRow, testUpdateFr.argument)) - _ <- IO.raiseWhen(res.isEmpty)(new Throwable("test A did not return generated columns")) - id = res.get._1 - _ <- IO.raiseWhen(res.get._2 != 2 && res.get._3 != Some(2))( - new Throwable("unexpected result for generated columns") - ) - all <- s.execute(TestTable.selectAll()) - allWithGen <- s.execute(TestTable.selectAllWithGenerated()) - _ <- IO.raiseWhen(all != List(testRow))(new Throwable("test A 
result not equal")) - _ <- IO.raiseWhen(allWithGen.map(_._4) != List(testRow))(new Throwable("test A result with id not equal")) - aliasedTestTable = TestTable.withAlias("t") - idAndName2 = aliasedTestTable.column.id ~ aliasedTestTable.column.name_2 - xs <- - s.execute( - sql"""SELECT #${idAndName2.aliasedName},#${aliasedTestTable.column.name.fullName} FROM #${TestTable.tableName} #${aliasedTestTable.tableName}""" - .query(idAndName2.codec ~ TestTable.column.name.codec) - ) - _ <- IO.raiseWhen(xs != List((id, testRow.name2) -> testRow.name))( - new Throwable("test A select fields not equal") - ) - all2 <- s.execute(TestTable.select(TestTable.all)) - _ <- IO.raiseWhen(all2 != List(testRow))(new Throwable("test A select all fields not equal")) - // TestB table - testBUpdateAllFr = testBRow.withUpdateAll._2 - - upsertCmd <- s.prepare(TestBTable.upsertQuery(testBUpdateAllFr)) - _ <- upsertCmd.execute((testBRow, testBUpdateAllFr.argument)) - allLoaded <- s.execute(TestBTable.selectAll()) - _ <- IO.raiseWhen(List(testBRow) != allLoaded)(new Throwable("test B result not equal")) - loadByIdQ <- s.prepare(TestBTable.selectAll(sql"WHERE key_a = ${varchar} AND key_b = ${varchar}")) - loadedById <- loadByIdQ.option((testBRow.keyA, testBRow.keyB)) - _ <- IO.raiseWhen(Some(testBRow) != loadedById)(new Throwable("test B result by id not equal")) - notFoundRes <- s.execute(TestBTable.selectAll(sql"WHERE key_a = 'not_existing'")) - _ <- IO.raiseWhen(notFoundRes.nonEmpty)(new Throwable("test B query result is empty")) - - testBRowUpdate = testBRow.copy(val1 = "val1_update", val2 = "val2_update") - testBUpdateFr = testBRowUpdate.withUpdate(_.copy(val2 = None))._2 // exclude update of val2 - - _ <- s.execute(TestBTable.upsertQuery(testBUpdateFr))((testBRowUpdate, testBUpdateFr.argument)) - afterUpdate <- s.execute(TestBTable.selectAll()) - _ <- IO.raiseWhen(afterUpdate.length != 1)(new Throwable("test B result unexpected length")) - _ <- - IO.raiseWhen( - afterUpdate.headOption.map(_.val1) != Some("val1_update") || - afterUpdate.headOption.map(_.val2) == Some("val2_update") // should not be updated - )( - new Throwable("test B result unexpected update") - ) - // Check Enum variable format - _ = Seq( - TestEnumType.T1One, - TestEnumType.T2Two, - TestEnumType.T3Three, - TestEnumType.T4Four, - TestEnumType.T5Five, - TestEnumType.T6six, - TestEnumType.MultipleWordEnum, - ) - _ <- s.execute(sql"TRUNCATE TABLE #${TestBTable.tableName}".command) - - _ <- s.execute(TestBTable.insert0(TestBTable.all))(testBRow) - allBTable <- s.execute(TestBTable.select(TestBTable.all)) - _ <- IO.raiseWhen(allBTable != List(testBRow))(new Throwable("test B not equal")) - loadedById <- - s.option( - TestBTable.select( - TestBTable.all, - sql"WHERE #${TestBTable.column.key_a.name} = ${TestBTable.column.key_a.codec} AND #${TestBTable.column.key_b.name} = ${TestBTable.column.key_b.codec}", - ) - )((testBRow.keyA, testBRow.keyB)) - - _ <- IO.raiseWhen(Some(testBRow) != loadedById)(new Throwable("test B result by id not equal")) - - _ <- - s.execute( - TestBTable.select( - TestBTable.column.key_a ~ TestBTable.column.key_b ~ TestBTable.column.val_1, - sql"WHERE key_a = 'not_existing'", - ) - ).flatMap(notFoundRes => IO.raiseWhen(notFoundRes.nonEmpty)(new Throwable("test B query result is empty"))) - - updatingFields = - TestBTable.column.val_27(None) ~ TestBTable.column.val_2("updated_val_2") ~ TestBTable.column.val_14( - "updated_val_14" - ) - updateQ = sql""" - ON CONFLICT ON CONSTRAINT #${generated.constraints.testBPkey.name} DO UPDATE 
SET - (#${updatingFields.name}) = (${updatingFields.codec}) - """ - _ <- s.execute(TestBTable.insert(TestBTable.all, updateQ))(testBRow *: updatingFields.value *: EmptyTuple) - loadedById <- - s.option( - TestBTable.select( - TestBTable.all, - sql"WHERE #${TestBTable.column.key_a.name} = ${TestBTable.column.key_a.codec} AND #${TestBTable.column.key_b.name} = ${TestBTable.column.key_b.codec}", - ) - )((testBRow.keyA, testBRow.keyB)) - _ <- IO.raiseWhen( - Some(testBRow.copy(val27 = None, val2 = "updated_val_2", val14 = "updated_val_14")) != loadedById - )(new Throwable("test B result missing update")) - _ <- s.execute(sql"REFRESH MATERIALIZED VIEW test_materialized_view".command) - result <- s.execute(TestMaterializedViewTable.selectAll()) - _ <- IO.raiseWhen(result.isEmpty)(new Throwable(s"materialized view doesn't have correct value: ${result}")) - _ <- IO.println("Test successful!") - } yield () - - } - .as(ExitCode.Success) - - private def migrate = Session - .single[IO]( - host = "localhost", - port = 5432, - user = "postgres", - database = "postgres", - password = Some("postgres"), - strategy = Typer.Strategy.BuiltinsOnly, - ) - .use { session => - for { - _ <- session.execute(sql"DROP SCHEMA IF EXISTS public CASCADE".command) - _ <- session.execute(sql"CREATE SCHEMA public".command) - _ <- File("modules/core/src/test/resources/db/migration").children.toList - .sortBy(f => f.name.drop(1).takeWhile(_ != '_').toInt) - .flatMap { f => - f.lineIterator - .filterNot(_.trim.startsWith("--")) - .mkString("\n") - .split(";") - .filterNot(l => l.trim.isEmpty || l.contains("ANALYZE VERBOSE")) - .map { sql => - SqlCommand(s"$sql;", Origin(file = f.pathAsString, line = 0), Void.codec) - } - } - .traverse_(session.execute) - } yield () - } -} diff --git a/modules/core/src/test/scala/com/anymindgroup/RunPgCodeGen.scala b/modules/core/src/test/scala/com/anymindgroup/RunPgCodeGen.scala deleted file mode 100644 index a81bc20..0000000 --- a/modules/core/src/test/scala/com/anymindgroup/RunPgCodeGen.scala +++ /dev/null @@ -1,39 +0,0 @@ -package com.anymindgroup - -import cats.effect.{ExitCode, IO, IOApp} -import better.files.* -import com.anymindgroup.testsupport.scalaVersion - -object RunPgCodeGen extends IOApp { - - override def run(args: List[String]): IO[ExitCode] = for { - baseSrcDir <- IO(File("modules") / "core" / "src" / "test") - _ <- IO.println(s"Running test for scala version $scalaVersion") - scalaOutDir = scalaVersion.split('.') match { - case Array("3", _*) => "scala-3" - case Array("2", "12", _*) => "scala-2.12" - case Array("2", "13", _*) => "scala-2.13" - case _ => "scala" - } - scalaOutPkgDir = baseSrcDir / scalaOutDir / "com" / "anymindgroup" - scalaTestOutPkgDir = scalaOutPkgDir / "generated" - _ <- IO(scalaTestOutPkgDir.delete(true)) - _ <- new PgCodeGen( - host = "localhost", - user = "postgres", - database = "postgres", - operateDatabase = Some("new_db"), - port = sys.env.get("CI").fold(5434)(_ => 5432), - password = Some("postgres"), - useDockerImage = sys.env.get("CI").fold(Option("postgres:14-alpine"))(_ => None), - outputDir = (baseSrcDir / scalaOutDir).toJava, - pkgName = "com.anymindgroup.generated", - sourceDir = (baseSrcDir / "resources" / "db" / "migration").toJava, - excludeTables = List("unsupported_yet"), - scalaVersion = scalaVersion, - ).run() - testRunFile = baseSrcDir / "scala" / "com" / "anymindgroup" / "GeneratedCodeTest._scala" - _ <- IO.whenA(testRunFile.exists)(IO(testRunFile.copyTo(scalaOutPkgDir / "GeneratedCodeTest.scala", true)).void) - } yield 
ExitCode.Success - -} diff --git a/modules/sbt/src/main/scala/com/anymindgroup/sbt/PgCodeGenPlugin.scala b/modules/sbt/src/main/scala/com/anymindgroup/sbt/PgCodeGenPlugin.scala deleted file mode 100644 index 8fd1d4f..0000000 --- a/modules/sbt/src/main/scala/com/anymindgroup/sbt/PgCodeGenPlugin.scala +++ /dev/null @@ -1,98 +0,0 @@ -package com.anymindgroup.sbt - -import sbt._ -import sbt.Keys._ -import java.io.File -import com.anymindgroup.PgCodeGen - -object PgCodeGenPlugin extends AutoPlugin { - - object autoImport { - lazy val pgCodeGen = taskKey[Seq[File]]("Generate models") - - lazy val pgCodeGenHost: SettingKey[String] = - settingKey[String]("Postgres host") - - lazy val pgCodeGenPort: SettingKey[Int] = - settingKey[Int]("Postgres port") - - lazy val pgCodeGenUser: SettingKey[String] = - settingKey[String]("Postgres user") - - lazy val pgCodeGenPassword: SettingKey[Option[String]] = - settingKey[Option[String]]("Postgres user password") - - lazy val pgCodeGenDb: SettingKey[String] = - settingKey[String]("Postgres database name for create connection `postgres` is default value.") - - lazy val pgCodeGenOperateDB: SettingKey[Option[String]] = - settingKey[Option[String]]( - """Giving value will create new database with specified - | name if not exist for pgCodeGen migration process. Recommend to be configure differently - | with multiple module in the same project""".stripMargin - ) - - lazy val pgCodeGenUseDockerImage: SettingKey[Option[String]] = - settingKey[Option[String]]("Whether to use docker and what image") - - lazy val pgCodeGenSqlSourceDir: SettingKey[File] = - settingKey[File]("Directory of sql scripts") - - lazy val pgCodeGenOutputPackage: SettingKey[String] = - settingKey[String]("Package of generated code") - - lazy val pgCodeGenOutputDir: SettingKey[File] = - settingKey[File]("Output directory of generated code") - - lazy val pgCodeGenExcludedTables: SettingKey[List[String]] = - settingKey[List[String]]("Tables that should be excluded") - } - - import autoImport._ - - override def projectSettings: Seq[Def.Setting[_]] = - Seq( - pgCodeGenHost := "localhost", - pgCodeGenUser := "postgres", - pgCodeGenDb := "postgres", - pgCodeGenOperateDB := None, - pgCodeGenPassword := None, - pgCodeGenSqlSourceDir := file("src") / "main" / "resources" / "db" / "migration", - pgCodeGenOutputPackage := "anychat.chat.db", - pgCodeGenOutputDir := (Compile / sourceManaged).value / "main", - pgCodeGenExcludedTables := Nil, - pgCodeGenUseDockerImage := Some("postgres:16-alpine"), - pgCodeGen := { - new PgCodeGen( - host = pgCodeGenHost.value, - port = pgCodeGenPort.value, - user = pgCodeGenUser.value, - password = pgCodeGenPassword.value, - database = pgCodeGenDb.value, - operateDatabase = pgCodeGenOperateDB.value, - outputDir = pgCodeGenOutputDir.value, - pkgName = pgCodeGenOutputPackage.value, - sourceDir = pgCodeGenSqlSourceDir.value, - useDockerImage = pgCodeGenUseDockerImage.value, - excludeTables = pgCodeGenExcludedTables.value, - scalaVersion = scalaVersion.value, - ).unsafeRunSync(true) - }, - Compile / sourceGenerators += Def.task { - new PgCodeGen( - host = pgCodeGenHost.value, - port = pgCodeGenPort.value, - user = pgCodeGenUser.value, - password = pgCodeGenPassword.value, - database = pgCodeGenDb.value, - operateDatabase = pgCodeGenOperateDB.value, - outputDir = pgCodeGenOutputDir.value, - pkgName = pgCodeGenOutputPackage.value, - sourceDir = pgCodeGenSqlSourceDir.value, - useDockerImage = pgCodeGenUseDockerImage.value, - excludeTables = pgCodeGenExcludedTables.value, - scalaVersion 
= scalaVersion.value, - ).unsafeRunSync() - }.taskValue, - ) -} diff --git a/modules/sbt/src/sbt-test/test/basic/build.sbt b/modules/sbt/src/sbt-test/test/basic/build.sbt deleted file mode 100644 index 5a2475e..0000000 --- a/modules/sbt/src/sbt-test/test/basic/build.sbt +++ /dev/null @@ -1,24 +0,0 @@ -crossScalaVersions := Seq("3.3.4", "2.13.15") - -val skunkVersion = "0.6.4" - -lazy val testRoot = (project in file(".")) - .enablePlugins(PgCodeGenPlugin) - .settings( - name := "test", - Compile / scalacOptions ++= { - if (scalaVersion.value.startsWith("3")) - Seq("-source:future") - else - Seq("-Xsource:3", "-Wconf:cat=scala3-migration:s") - }, - pgCodeGenOutputPackage := "com.example", - pgCodeGenPassword := Some("postgres"), - pgCodeGenPort := sys.env.get("CI").fold(5434)(_ => 5432), - pgCodeGenUseDockerImage := sys.env.get("CI").fold(Option("postgres:14-alpine"))(_ => None), - pgCodeGenSqlSourceDir := file("resources") / "db" / "migration", - pgCodeGenExcludedTables := List("unsupported_yet"), - libraryDependencies ++= Seq( - "org.tpolecat" %% "skunk-core" % skunkVersion - ), - ) diff --git a/modules/sbt/src/sbt-test/test/basic/project/build.properties b/modules/sbt/src/sbt-test/test/basic/project/build.properties deleted file mode 100644 index db1723b..0000000 --- a/modules/sbt/src/sbt-test/test/basic/project/build.properties +++ /dev/null @@ -1 +0,0 @@ -sbt.version=1.10.5 diff --git a/modules/sbt/src/sbt-test/test/basic/project/plugins.sbt b/modules/sbt/src/sbt-test/test/basic/project/plugins.sbt deleted file mode 100644 index a8d39a7..0000000 --- a/modules/sbt/src/sbt-test/test/basic/project/plugins.sbt +++ /dev/null @@ -1 +0,0 @@ -addSbtPlugin("com.anymindgroup" % "sbt-skunk-codegen" % System.getProperty("plugin.version")) diff --git a/modules/sbt/src/sbt-test/test/basic/resources/db/migration/V1__test.sql b/modules/sbt/src/sbt-test/test/basic/resources/db/migration/V1__test.sql deleted file mode 100644 index c752b18..0000000 --- a/modules/sbt/src/sbt-test/test/basic/resources/db/migration/V1__test.sql +++ /dev/null @@ -1,33 +0,0 @@ -CREATE TYPE test_enum_type AS ENUM ('T1', 'T2', 'T3', 'T4', 'T5', 'T6'); - --- some comment -CREATE TABLE test ( - -- ignore this... 
- id SERIAL PRIMARY KEY, - created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - name text, - name_2 varchar NOT NULL, - number int, - template test_enum_type -); - -CREATE TABLE test_ref_only ( - test_id INT NOT NULL REFERENCES test(id) ON DELETE CASCADE -); - -CREATE TABLE test_ref ( - test_id INT NOT NULL REFERENCES test(id) ON DELETE CASCADE, - ref_name VARCHAR NOT NULL -); - -CREATE TABLE test_ref_auto_pk ( - id SERIAL PRIMARY KEY, - test_id INT NOT NULL REFERENCES test(id) ON DELETE CASCADE, - ref_name VARCHAR NOT NULL -); - -CREATE TABLE test_ref_pk ( - id VARCHAR PRIMARY KEY, - test_id INT NOT NULL REFERENCES test(id) ON DELETE CASCADE, - ref_name VARCHAR NOT NULL -); \ No newline at end of file diff --git a/modules/sbt/src/sbt-test/test/basic/resources/db/migration/V2__test_b.sql b/modules/sbt/src/sbt-test/test/basic/resources/db/migration/V2__test_b.sql deleted file mode 100644 index d533a2e..0000000 --- a/modules/sbt/src/sbt-test/test/basic/resources/db/migration/V2__test_b.sql +++ /dev/null @@ -1,30 +0,0 @@ -CREATE TABLE test_b ( - key_a VARCHAR NOT NULL, - key_b VARCHAR NOT NULL, - val_1 VARCHAR NOT NULL, - val_2 VARCHAR NOT NULL, - val_3 VARCHAR NOT NULL, - val_4 VARCHAR NOT NULL, - val_5 VARCHAR NOT NULL, - val_6 VARCHAR NOT NULL, - val_7 VARCHAR NOT NULL, - val_8 VARCHAR NOT NULL, - val_9 VARCHAR NOT NULL, - val_10 VARCHAR NOT NULL, - val_11 VARCHAR NOT NULL, - val_12 VARCHAR NOT NULL, - val_13 VARCHAR NOT NULL, - val_14 VARCHAR NOT NULL, - val_15 VARCHAR NOT NULL, - val_16 VARCHAR NOT NULL, - val_17 VARCHAR NOT NULL, - val_18 VARCHAR NOT NULL, - val_19 VARCHAR NOT NULL, - val_20 VARCHAR NOT NULL, - val_21 VARCHAR NOT NULL, - val_22 VARCHAR NOT NULL, - val_23 VARCHAR NOT NULL, - val_24 VARCHAR NOT NULL, - val_25 VARCHAR NOT NULL, - PRIMARY KEY (key_a, key_b) -); diff --git a/modules/sbt/src/sbt-test/test/basic/resources/db/migration/V3__exclude.sql b/modules/sbt/src/sbt-test/test/basic/resources/db/migration/V3__exclude.sql deleted file mode 100644 index dcc69f9..0000000 --- a/modules/sbt/src/sbt-test/test/basic/resources/db/migration/V3__exclude.sql +++ /dev/null @@ -1,5 +0,0 @@ -CREATE TABLE unsupported_yet ( - field_a JSON, - field_b JSONB, - field_c INT[] -); diff --git a/modules/sbt/src/sbt-test/test/basic/test b/modules/sbt/src/sbt-test/test/basic/test deleted file mode 100644 index 6c309c5..0000000 --- a/modules/sbt/src/sbt-test/test/basic/test +++ /dev/null @@ -1,6 +0,0 @@ -> +pgCodeGen -$ exists target/scala-2.13/src_managed/main/com/example/TestTable.scala -$ exists target/scala-2.13/src_managed/main/com/example/TestRow.scala -$ exists target/scala-3.3.4/src_managed/main/com/example/TestTable.scala -$ exists target/scala-3.3.4/src_managed/main/com/example/TestRow.scala -> +compile diff --git a/project/build.properties b/project/build.properties deleted file mode 100644 index db1723b..0000000 --- a/project/build.properties +++ /dev/null @@ -1 +0,0 @@ -sbt.version=1.10.5 diff --git a/project/plugins.sbt b/project/plugins.sbt deleted file mode 100644 index fa18e86..0000000 --- a/project/plugins.sbt +++ /dev/null @@ -1,5 +0,0 @@ -addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.2") -addSbtPlugin("org.typelevel" % "sbt-tpolecat" % "0.5.2") -addSbtPlugin("com.github.sbt" % "sbt-dynver" % "5.0.1") -addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.11.2") -addSbtPlugin("com.github.sbt" % "sbt-pgp" % "2.3.0") diff --git a/scripts/tag.sh b/scripts/tag.sh deleted file mode 100755 index e9f6ce7..0000000 --- a/scripts/tag.sh +++ /dev/null @@ -1,38 +0,0 @@ 
-#!/usr/bin/env bash
-
-set -e
-
-echo 'Fetching tag from remote...'
-git tag -l | xargs git tag -d
-git fetch --tags
-
-if ! git describe --exact-match 2>/dev/null; then
-  echo 'Not tag found...'
-
-  last_tag=`git describe --abbrev=0 --tags`
-  current_version=${last_tag#'v'}
-
-  echo "Current version ${current_version}"
-
-  #replace . with space so can split into an array
-  current_version_parts=(${current_version//./ })
-
-  #get number parts and increase last one by 1
-  current_version_major=${current_version_parts[0]}
-  current_version_minor=${current_version_parts[1]}
-  current_version_build=${current_version_parts[2]}
-
-  next_version_build=$((current_version_build+1))
-  next_version="$current_version_major.$current_version_minor.$next_version_build"
-  next_tag="v${next_version}"
-
-  echo "Tagging the current commit with ${next_tag}"
-
-  git tag -a ${next_tag} -m "add tag "${next_version}
-
-  echo "Pushing tag ${next_tag} to origin"
-  git push origin ${next_tag}
-
-else
-  echo 'Tag found, no tag will be add'
-fi
\ No newline at end of file
diff --git a/test.sh b/test.sh
new file mode 100755
index 0000000..0b1353d
--- /dev/null
+++ b/test.sh
@@ -0,0 +1,74 @@
+#!/usr/bin/env bash
+set -e
+
+# Build the code generator binary
+CODEGEN_BIN=out/skunk-codegen-$(uname -m)-$(uname | tr '[:upper:]' '[:lower:]')
+scala-cli --power package \
+  --native \
+  --native-mode release-fast PgCodeGen.scala \
+  -o $CODEGEN_BIN -f
+
+echo "⏳ Generating code from test migrations"
+$CODEGEN_BIN \
+  -use-docker-image=postgres:17-alpine \
+  -output-dir=test-generated \
+  -pkg-name=generated \
+  -exclude-tables=unsupported_yet \
+  -source-dir=test/migrations \
+  -force=true
+
+TIMESTAMP_A=$(stat test-generated | grep Modify)
+
+# Run the test against the generated code
+scala-cli run PgCodeGenTest.scala
+echo "✅ Test of generated code successful"
+
+echo "⏳ Running generator again: -force=true should re-run code generation"
+$CODEGEN_BIN \
+  -use-docker-image="postgres:17-alpine" \
+  -output-dir=test-generated \
+  -pkg-name=generated \
+  -exclude-tables=unsupported_yet \
+  -source-dir=test/migrations \
+  -force=true
+
+TIMESTAMP_B=$(stat test-generated | grep Modify)
+
+if [ "$TIMESTAMP_A" != "$TIMESTAMP_B" ]; then
+  echo "✅ Code generation with -force=true as expected (timestamps differ)"
+else
+  echo "❌ Error: Code generation did not re-run (timestamps are the same)"
+  exit 1
+fi
+
+echo "⏳ Running generator again: -force=false should skip code generation"
+$CODEGEN_BIN \
+  -use-docker-image="postgres:17-alpine" \
+  -output-dir=test-generated \
+  -pkg-name=generated \
+  -exclude-tables=unsupported_yet \
+  -source-dir=test/migrations \
+  -force=false
+
+TIMESTAMP_C=$(stat test-generated | grep Modify)
+
+if [ "$TIMESTAMP_B" == "$TIMESTAMP_C" ]; then
+  echo "✅ Code generation with -force=false as expected (timestamps are the same)"
+else
+  echo "❌ Error: Code generation with -force=false not as expected (timestamps differ)"
+  exit 1
+fi
+
+echo "⏳ Running code generator with a provided connection"
+docker run --rm --name codegentest -e POSTGRES_PASSWORD=postgres -p 5555:5432 -d postgres:17-alpine
+
+($CODEGEN_BIN \
+  -use-docker-image="postgres:17-alpine" \
+  -output-dir=test-generated \
+  -pkg-name=generated \
+  -exclude-tables=unsupported_yet \
+  -source-dir=test/migrations \
+  -use-connection=postgresql://postgres:postgres@localhost:5555/postgres \
+  -force=true && echo "✅ Code generation for provided connection ok.") || (docker rm -f codegentest; exit 1)
+
+docker rm -f codegentest
diff --git a/test/dumbo b/test/dumbo
new file mode 100755
index 0000000..95c1db0
Binary files /dev/null and b/test/dumbo differ
diff --git a/test/migrations/R__a_repeatable.sql b/test/migrations/R__a_repeatable.sql
new file mode 100644
index 0000000..f865419
--- /dev/null
+++ b/test/migrations/R__a_repeatable.sql
@@ -0,0 +1 @@
+CREATE VIEW r_test_view_a AS SELECT 1;
diff --git a/test/migrations/R__b_repeatable.sql b/test/migrations/R__b_repeatable.sql
new file mode 100644
index 0000000..6ac1c64
--- /dev/null
+++ b/test/migrations/R__b_repeatable.sql
@@ -0,0 +1 @@
+CREATE VIEW r_test_view_b AS SELECT * FROM r_test_view_a;
diff --git a/modules/core/src/test/resources/db/migration/V1__test.sql b/test/migrations/V1__test.sql
similarity index 92%
rename from modules/core/src/test/resources/db/migration/V1__test.sql
rename to test/migrations/V1__test.sql
index 72c2bfd..600c5a7 100644
--- a/modules/core/src/test/resources/db/migration/V1__test.sql
+++ b/test/migrations/V1__test.sql
@@ -40,7 +40,7 @@ CREATE TABLE test_ref_pk (
   ref_name VARCHAR NOT NULL
 );
 
-CREATE MATERIALIZED VIEW public.test_materialized_view AS SELECT id,
+CREATE MATERIALIZED VIEW test_materialized_view AS SELECT id,
   created_at,
   name,
   name_2,
@@ -48,7 +48,7 @@ CREATE MATERIALIZED VIEW public.test_materialized_view AS SELECT id,
   FROM test
   WITH DATA;
 
-CREATE VIEW public.test_view AS SELECT id,
+CREATE VIEW test_view AS SELECT id,
   created_at,
   name,
   name_2,
diff --git a/modules/core/src/test/resources/db/migration/V2__test_b.sql b/test/migrations/V2__test_b.sql
similarity index 100%
rename from modules/core/src/test/resources/db/migration/V2__test_b.sql
rename to test/migrations/V2__test_b.sql
diff --git a/modules/core/src/test/resources/db/migration/V3__exclude.sql b/test/migrations/V3__exclude.sql
similarity index 100%
rename from modules/core/src/test/resources/db/migration/V3__exclude.sql
rename to test/migrations/V3__exclude.sql
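With `GeneratedCodeTest` and the sbt scripted test removed above, the shape of the generated API is easy to lose track of. Below is a compact usage sketch distilled from the removed test, not a verbatim excerpt: the `generated.TestTable`/`TestRow` names come from the test migrations and the `-pkg-name=generated` flag in `test.sh`, and an already-open skunk `Session[IO]` against a migrated database is assumed.

```scala
// A sketch of exercising the generated API, following the removed GeneratedCodeTest.
import cats.effect.IO
import skunk.Session
import generated.{TestRow, TestTable}

def roundTrip(s: Session[IO], row: TestRow): IO[Unit] =
  // withUpdateAll pairs the row with an update fragment over all columns,
  // suitable for the ON CONFLICT branch of the generated upsert.
  val (r, updateFr) = row.withUpdateAll
  for
    upsert <- s.prepare(TestTable.upsertQuery(updateFr))
    _      <- upsert.option((r, updateFr.argument)) // yields generated columns, if any
    all    <- s.execute(TestTable.selectAll())      // read every row back
    _      <- IO.println(s"loaded ${all.length} row(s)")
  yield ()
```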