diff --git a/README.md b/README.md index 5203ced..b051eb4 100644 --- a/README.md +++ b/README.md @@ -81,20 +81,33 @@ This means that whenever you are working with variables coming Spring, you gener One particular place to watch out for this is when using Spring's `@RequestParam` and `@PathVariable` annotations in controllers. ## Spring's ThreadLocal Context -Much of Spring's async programming model relies on ThreadLocal context. This used to be a common pattern in Java, but not one that is used in Scala. +Much of Spring's async programming model relies on ThreadLocal context, particularly when using WebMVC. This used to be a common pattern in Java, but not one that is used in Scala. This becomes particularly annoying when interfacing between things like controller entry points and services and utilities that are built around IO/Future/ZIO etc. monads. Effectively, trying to access something like Spring Security's SecurityContext from these methods -will not work. The best solution I have found is to pass the SecurityContext and any other ThreadLocal context as an argument to +will not work. Without going into too much detail WebFlux has the same basic problem, even though its not technically using ThreadLocal context. + +The best solution I have found is to pass the SecurityContext and any other ThreadLocal / pseudo global context data as an argument to these methods. This is not ideal, but it is the best solution I have found so far. ## Async Programming Spring has it's own mechanisms for async programming, and it takes some work to adapt it to be compatible with IO monads. -Even after adapting these mechanisms we are left with having to manage an additional threadpool to accommodate Spring. -The other challenge here is adapting the handling of uncaught exceptions so that Spring's conventional mechanisms will +Even after adapting these mechanisms we are left with having to manage an additional threadpool(s) to accommodate Spring. +Another challenge here is adapting the handling of uncaught exceptions so that Spring's conventional mechanisms will continue to function. -I've not gotten around to adding this to the example yet, but it is doable, and once it's done you can pretty much forget -about it. +### Async Controllers +The original version of this project used WebMVC which is built on top of Apache Tomcat and has its own async programming model. +I've since switched to using WebFlux which is built on top of Netty and is generally considered to be more performant, particularly +when it comes to servicing large numbers of requests concurrently. I would not be surprised if this changes in the future +thanks to the work being done on Project Loom. For those interested in exploring this further, check out the [webmvc tag](https://github.com/halfhp/ScalaSpringExperiment/releases/tag/webmvc) +of this repository. + +### Async Database Drivers +This project uses Doobie, which is built on top of JDBC which is synchronous. There is another library, Skunk, which is written +by the same author and offers similar functionality. It's fully asynchronous but also locks you into using Postgres. + +Another option would be to use one Spring's database facilities that supports R2DBC, which is also async. I've not tried this approach +yet but imagine it could be wrapped with cats-effect IO similarly to what was done with [Mono] in the controller layer. # Future Improvements ## Spring Security diff --git a/build.gradle b/build.gradle index 4dc951d..54170e2 100644 --- a/build.gradle +++ b/build.gradle @@ -21,7 +21,7 @@ repositories { dependencies { implementation 'org.scala-lang:scala3-library_3:3.6.4' implementation 'org.typelevel:cats-effect_3:3.6.1' - implementation('org.springframework.boot:spring-boot-starter-web') { + implementation("org.springframework.boot:spring-boot-starter-webflux") { exclude group: 'com.fasterxml.jackson.core' exclude group: 'com.fasterxml.jackson.datatype' exclude group: 'com.fasterxml.jackson.module' @@ -55,6 +55,10 @@ dependencies { annotationProcessor 'org.springframework.boot:spring-boot-configuration-processor' testImplementation 'org.springframework.boot:spring-boot-starter-test' + + testImplementation('com.github.javafaker:javafaker:1.0.2') { + exclude group: 'org.yaml' + } } test { diff --git a/docker-compose.yml b/docker-compose.yml index 614d1c6..ca7e105 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,12 +2,12 @@ version: '3.8' services: sse_app: + container_name: sse_app build: context: . dockerfile: Dockerfile ports: - "8080:8080" - container_name: sse_app environment: SPRING_DATASOURCE_URL: jdbc:postgresql://sse_postgres:5432/postgres SPRING_DATASOURCE_USERNAME: postgres @@ -16,8 +16,8 @@ services: - sse_postgres sse_postgres: - image: postgres:15 container_name: sse_postgres + image: postgres:15 environment: POSTGRES_DB: postgres POSTGRES_USER: postgres diff --git a/extras/taurus/benchmark_localhost.sh b/extras/taurus/benchmark_localhost.sh new file mode 100644 index 0000000..1596ab8 --- /dev/null +++ b/extras/taurus/benchmark_localhost.sh @@ -0,0 +1,2 @@ +docker run -it --rm -v ./tests:/bzt-configs blazemeter/taurus test.yml -o settings.check-plugins=false + diff --git a/extras/taurus/tests/test.yml b/extras/taurus/tests/test.yml new file mode 100644 index 0000000..751f64d --- /dev/null +++ b/extras/taurus/tests/test.yml @@ -0,0 +1,38 @@ +#modules: +# jmeter: +# disable-plugins: +# - aggregate-report +# - view-results-tree +# - view-results-in-table +# - summary-report +# java-opts: +# - "-Djava.awt.headless=true" +# - "-XX:-TieredCompilation" +# - "-Xmx512m" + +execution: + - concurrency: 10 + ramp-up: 30s + hold-for: 1m + scenario: simple + +scenarios: + simple: + requests: + - url: http://host.docker.internal:8080/ + method: GET + +#settings: +# artifacts-dir: /output + +reporting: + - module: console + - module: final-stats + summary: true +# - module: junit-xml +# filename: /output/results.xml + +monitoring: + - module: local + cpu: true + memory: true diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index 7454180..e644113 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 0d18421..ca025c8 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,7 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-8.8-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.14-bin.zip +networkTimeout=10000 +validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew index 744e882..b740cf1 100755 --- a/gradlew +++ b/gradlew @@ -1,7 +1,7 @@ -#!/usr/bin/env sh +#!/bin/sh # -# Copyright 2015 the original author or authors. +# Copyright © 2015-2021 the original authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,67 +17,99 @@ # ############################################################################## -## -## Gradle start up script for UN*X -## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/platforms/jvm/plugins-application/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# ############################################################################## # Attempt to set APP_HOME + # Resolve links: $0 may be a link -PRG="$0" -# Need this for relative symlinks. -while [ -h "$PRG" ] ; do - ls=`ls -ld "$PRG"` - link=`expr "$ls" : '.*-> \(.*\)$'` - if expr "$link" : '/.*' > /dev/null; then - PRG="$link" - else - PRG=`dirname "$PRG"`"/$link" - fi +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac done -SAVED="`pwd`" -cd "`dirname \"$PRG\"`/" >/dev/null -APP_HOME="`pwd -P`" -cd "$SAVED" >/dev/null -APP_NAME="Gradle" -APP_BASE_NAME=`basename "$0"` - -# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit # Use the maximum available, or set MAX_FD != -1 to use that value. -MAX_FD="maximum" +MAX_FD=maximum warn () { echo "$*" -} +} >&2 die () { echo echo "$*" echo exit 1 -} +} >&2 # OS specific support (must be 'true' or 'false'). cygwin=false msys=false darwin=false nonstop=false -case "`uname`" in - CYGWIN* ) - cygwin=true - ;; - Darwin* ) - darwin=true - ;; - MSYS* | MINGW* ) - msys=true - ;; - NONSTOP* ) - nonstop=true - ;; +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; esac CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar @@ -87,9 +119,9 @@ CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar if [ -n "$JAVA_HOME" ] ; then if [ -x "$JAVA_HOME/jre/sh/java" ] ; then # IBM's JDK on AIX uses strange locations for the executables - JAVACMD="$JAVA_HOME/jre/sh/java" + JAVACMD=$JAVA_HOME/jre/sh/java else - JAVACMD="$JAVA_HOME/bin/java" + JAVACMD=$JAVA_HOME/bin/java fi if [ ! -x "$JAVACMD" ] ; then die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME @@ -98,88 +130,120 @@ Please set the JAVA_HOME variable in your environment to match the location of your Java installation." fi else - JAVACMD="java" - which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + JAVACMD=java + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. Please set the JAVA_HOME variable in your environment to match the location of your Java installation." + fi fi # Increase the maximum file descriptors if we can. -if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then - MAX_FD_LIMIT=`ulimit -H -n` - if [ $? -eq 0 ] ; then - if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then - MAX_FD="$MAX_FD_LIMIT" - fi - ulimit -n $MAX_FD - if [ $? -ne 0 ] ; then - warn "Could not set maximum file descriptor limit: $MAX_FD" - fi - else - warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" - fi +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac fi -# For Darwin, add options to specify how the application appears in the dock -if $darwin; then - GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" -fi +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. # For Cygwin or MSYS, switch paths to Windows format before running java -if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then - APP_HOME=`cygpath --path --mixed "$APP_HOME"` - CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` - - JAVACMD=`cygpath --unix "$JAVACMD"` - - # We build the pattern for arguments to be converted via cygpath - ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` - SEP="" - for dir in $ROOTDIRSRAW ; do - ROOTDIRS="$ROOTDIRS$SEP$dir" - SEP="|" - done - OURCYGPATTERN="(^($ROOTDIRS))" - # Add a user-defined pattern to the cygpath arguments - if [ "$GRADLE_CYGPATTERN" != "" ] ; then - OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" - fi +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + # Now convert the arguments - kludge to limit ourselves to /bin/sh - i=0 - for arg in "$@" ; do - CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` - CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option - - if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition - eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` - else - eval `echo args$i`="\"$arg\"" + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) fi - i=`expr $i + 1` + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg done - case $i in - 0) set -- ;; - 1) set -- "$args0" ;; - 2) set -- "$args0" "$args1" ;; - 3) set -- "$args0" "$args1" "$args2" ;; - 4) set -- "$args0" "$args1" "$args2" "$args3" ;; - 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; - 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; - 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; - 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; - 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; - esac fi -# Escape application args -save () { - for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done - echo " " -} -APP_ARGS=`save "$@"` -# Collect all arguments for the java command, following the shell quoting and substitution rules -eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' exec "$JAVACMD" "$@" diff --git a/gradlew.bat b/gradlew.bat index 107acd3..25da30d 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -14,7 +14,7 @@ @rem limitations under the License. @rem -@if "%DEBUG%" == "" @echo off +@if "%DEBUG%"=="" @echo off @rem ########################################################################## @rem @rem Gradle startup script for Windows @@ -25,7 +25,8 @@ if "%OS%"=="Windows_NT" setlocal set DIRNAME=%~dp0 -if "%DIRNAME%" == "" set DIRNAME=. +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused set APP_BASE_NAME=%~n0 set APP_HOME=%DIRNAME% @@ -40,13 +41,13 @@ if defined JAVA_HOME goto findJavaFromJavaHome set JAVA_EXE=java.exe %JAVA_EXE% -version >NUL 2>&1 -if "%ERRORLEVEL%" == "0" goto execute +if %ERRORLEVEL% equ 0 goto execute -echo. -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 goto fail @@ -56,11 +57,11 @@ set JAVA_EXE=%JAVA_HOME%/bin/java.exe if exist "%JAVA_EXE%" goto execute -echo. -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 goto fail @@ -75,13 +76,15 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar :end @rem End local scope for the variables with windows NT shell -if "%ERRORLEVEL%"=="0" goto mainEnd +if %ERRORLEVEL% equ 0 goto mainEnd :fail rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of rem the _cmd.exe /c_ return code! -if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 -exit /b 1 +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% :mainEnd if "%OS%"=="Windows_NT" endlocal diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index dfa6827..19286f8 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -1,3 +1,4 @@ +spring.profiles.active=default server.port=8080 spring.flyway.baselineOnMigrate = false diff --git a/src/main/resources/db/migration/V1.0.sql b/src/main/resources/db/migration/V1.0.sql index 42280cb..c3c31a3 100644 --- a/src/main/resources/db/migration/V1.0.sql +++ b/src/main/resources/db/migration/V1.0.sql @@ -1,3 +1,5 @@ +create extension if not exists "postgis"; + CREATE OR REPLACE FUNCTION set_timestamp_fields() RETURNS TRIGGER AS $$ BEGIN @@ -28,13 +30,14 @@ CREATE TABLE address ( id BIGSERIAL PRIMARY KEY, date_created timestamp not null, last_updated timestamp not null, - person_id BIGINT not null + person_id BIGINT not null constraint external_association_user_id_fk references person on delete cascade, street varchar, city varchar, - state varchar + state varchar, + coordinates geometry ); CREATE TRIGGER set_address_timestamps diff --git a/src/main/scala/com/example/scalaspringexperiment/SpringConfig.scala b/src/main/scala/com/example/scalaspringexperiment/SpringConfig.scala index 37dfbd9..9b21768 100644 --- a/src/main/scala/com/example/scalaspringexperiment/SpringConfig.scala +++ b/src/main/scala/com/example/scalaspringexperiment/SpringConfig.scala @@ -2,31 +2,28 @@ package com.example.scalaspringexperiment import cats.effect.unsafe.IORuntime import cats.effect.{IO, Resource} -import com.example.scalaspringexperiment.util.CirceHttpMessageConverter +import com.example.scalaspringexperiment.util.{CirceJsonDecoder, CirceJsonEncoder} import doobie.{DataSourceTransactor, ExecutionContexts} import doobie.util.transactor.Transactor -import org.springframework.boot.autoconfigure.http.HttpMessageConverters import org.springframework.context.annotation.{Bean, Configuration} +import org.springframework.http.codec.ServerCodecConfigurer import org.springframework.security.config.Customizer -import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity -import org.springframework.security.config.annotation.web.builders.HttpSecurity -import org.springframework.security.config.http.SessionCreationPolicy -import org.springframework.security.web.SecurityFilterChain +import org.springframework.security.config.annotation.method.configuration.EnableReactiveMethodSecurity +import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity +import org.springframework.security.config.web.server.ServerHttpSecurity +import org.springframework.security.web.server.SecurityWebFilterChain +import org.springframework.security.web.server.context.NoOpServerSecurityContextRepository +import org.springframework.web.reactive.config.WebFluxConfigurer import javax.sql.DataSource @Configuration -@EnableMethodSecurity(prePostEnabled = true) +@EnableWebFluxSecurity +@EnableReactiveMethodSecurity class SpringConfig( dataSource: DataSource, ) { - @Bean - def getCustomConverters(): HttpMessageConverters = { - val circe = new CirceHttpMessageConverter() - new HttpMessageConverters(circe) - } - @Bean def catsEffectIORuntime(): IORuntime = { cats.effect.unsafe.implicits.global @@ -40,18 +37,31 @@ class SpringConfig( } @Bean - def securityFilterChain( - http: HttpSecurity - ): SecurityFilterChain = { + def securityFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain = { http .cors(Customizer.withDefaults()) - - // we'll be using stateless JWT authentication, and csrf messes with mockmvc tests - // so we're disabling csrf. alternatively, this could be disabled only for testing - // in the test config. - .csrf(csrf => csrf.disable()) - .sessionManagement(sm => sm.sessionCreationPolicy(SessionCreationPolicy.STATELESS)) + .csrf(csrf => csrf.disable()) // Stateless app using JWT + .securityContextRepository(NoOpServerSecurityContextRepository.getInstance()) // optional, disables session caching + .authorizeExchange(authz => + authz.anyExchange().permitAll() + ) +// .httpBasic().disable() // or leave enabled if using basic auth +// .formLogin().disable() .build() } } +@Configuration +class CirceWebFluxConfig extends WebFluxConfigurer { + + override def configureHttpMessageCodecs(configurer: ServerCodecConfigurer): Unit = { + // disable jackson codecs + configurer.defaultCodecs().jackson2JsonDecoder(null) + configurer.defaultCodecs().jackson2JsonEncoder(null) + + // register circe codecs + configurer.customCodecs().register(new CirceJsonDecoder()) + configurer.customCodecs().register(new CirceJsonEncoder()) + } +} + diff --git a/src/main/scala/com/example/scalaspringexperiment/controller/ControllerErrorHandler.scala b/src/main/scala/com/example/scalaspringexperiment/controller/ControllerErrorHandler.scala new file mode 100644 index 0000000..f5b43f8 --- /dev/null +++ b/src/main/scala/com/example/scalaspringexperiment/controller/ControllerErrorHandler.scala @@ -0,0 +1,28 @@ +package com.example.scalaspringexperiment.controller + +import io.circe.Json +import org.springframework.http.HttpStatus +import org.springframework.web.bind.annotation.{ExceptionHandler, RestControllerAdvice} +import org.springframework.web.server.{ResponseStatusException, ServerWebExchange} +import reactor.core.publisher.Mono + +@RestControllerAdvice +class ControllerErrorHandler { + + @ExceptionHandler(Array(classOf[Throwable])) + def handleError(exchange: ServerWebExchange, ex: Throwable): Mono[Json] = { + val path = exchange.getRequest.getPath.toString + val status = ex match { + case e: ResponseStatusException => e.getStatusCode.value() + case _ => HttpStatus.INTERNAL_SERVER_ERROR.value() + } + + val errorJson = Json.obj( + "error" -> Json.fromString(ex.getMessage), + "path" -> Json.fromString(path), + "status" -> Json.fromInt(status) + ) + + Mono.just(errorJson) + } +} diff --git a/src/main/scala/com/example/scalaspringexperiment/controller/SimpleController.scala b/src/main/scala/com/example/scalaspringexperiment/controller/SimpleController.scala index 1ceb314..7b564b1 100644 --- a/src/main/scala/com/example/scalaspringexperiment/controller/SimpleController.scala +++ b/src/main/scala/com/example/scalaspringexperiment/controller/SimpleController.scala @@ -1,6 +1,5 @@ package com.example.scalaspringexperiment.controller -import cats.data.OptionT import com.example.scalaspringexperiment.service.{AddressService, PersonService} import cats.effect.IO import cats.effect.unsafe.implicits.global @@ -9,39 +8,68 @@ import doobie.implicits.* import io.circe.* import io.circe.generic.auto.* import io.circe.syntax.* -import com.example.scalaspringexperiment.util.MyJsonCodecs.timestampCodec -import org.springframework.beans.factory.annotation.Autowired +import com.example.scalaspringexperiment.util.MyJsonCodecs.* +import org.springframework.http.ResponseEntity import org.springframework.security.access.prepost.PreAuthorize import org.springframework.web.bind.annotation.* +import reactor.core.publisher.Mono +import java.util.concurrent.CompletableFuture import scala.language.implicitConversions +import scala.util.chaining.* /** - * The simplest controller possible. Not something one would actually use in a serious project, but still useful - * as a simplified example. The AsyncController (TODO) is a more realistic example for production use. + * A simple async REST controller */ @RestController -@Autowired class SimpleController( personService: PersonService, addressService: AddressService ) { - implicit def ioToA[A](io: IO[A]): A = { - io.unsafeRunSync() + private implicit def ioToMono[A](io: IO[A]): Mono[A] = { + Mono.fromFuture(new CompletableFuture[A]().tap { cf => + io.unsafeRunAsync { + case Right(value) => cf.complete(value) + case Left(error) => cf.completeExceptionally(error) + } + }) } @PreAuthorize("permitAll()") - @GetMapping(path = Array("/person/{id}/detailed")) - def getDetailedPerson( + @GetMapping(path = Array("/")) + def index(): Mono[ResponseEntity[Json]] = { + Mono.just(ResponseEntity.ok(Json.obj( + "message" -> Json.fromString("Hello, world!"), + ))) + } + + @PreAuthorize("permitAll()") + @GetMapping(path = Array("/person/{id}")) + def getPersonById( @PathVariable id: Long, - ): Json = { - (for { - person <- OptionT(personService.findById(id)) - addresses <- OptionT.liftF(addressService.findByPersonId(id)) + ): Mono[Json] = { + for { + person <- personService.findById(id) } yield Json.obj( "person" -> person.asJson, - "addresses" -> addresses.asJson - )).value.getOrElse(Json.obj()) + ) + } + + @PreAuthorize("permitAll()") + @GetMapping(path = Array("/person/{id}/detailed")) + def getDetailedPersonById( + @PathVariable id: Long, + ): Mono[ResponseEntity[Json]] = { + for { + person <- personService.findById(id) + addresses <- addressService.findByPersonId(id) + } yield person match { + case Some(person) => ResponseEntity.ok(Json.obj( + "person" -> person.asJson, + "addresses" -> addresses.asJson + )) + case None => ResponseEntity.notFound().build() + } } } diff --git a/src/main/scala/com/example/scalaspringexperiment/entity/Address.scala b/src/main/scala/com/example/scalaspringexperiment/entity/Address.scala index ac2c93d..97d8148 100644 --- a/src/main/scala/com/example/scalaspringexperiment/entity/Address.scala +++ b/src/main/scala/com/example/scalaspringexperiment/entity/Address.scala @@ -1,6 +1,7 @@ package com.example.scalaspringexperiment.entity import com.example.scalaspringexperiment.dao.{Column, Table} +import net.postgis.jdbc.geometry.Point import java.sql.Timestamp @@ -19,11 +20,14 @@ case class Address( personId: Long, @Column("street") - street: Option[String], + street: Option[String] = None, @Column("city") - city: Option[String], + city: Option[String] = None, @Column("state") - state: Option[String] + state: Option[String] = None, + + @Column("coordinates") + coordinates: Option[Point] = None, ) extends DomainEntity diff --git a/src/main/scala/com/example/scalaspringexperiment/service/AddressService.scala b/src/main/scala/com/example/scalaspringexperiment/service/AddressService.scala index ef9b333..fa00b58 100644 --- a/src/main/scala/com/example/scalaspringexperiment/service/AddressService.scala +++ b/src/main/scala/com/example/scalaspringexperiment/service/AddressService.scala @@ -3,10 +3,12 @@ package com.example.scalaspringexperiment.service import cats.effect.{IO, Resource} import com.example.scalaspringexperiment.dao.{Dao, TableInfo} import com.example.scalaspringexperiment.entity.Address +import com.example.scalaspringexperiment.util.PointUtils import doobie.{DataSourceTransactor, Fragment} import doobie.implicits.toSqlInterpolator import doobie.util.{Read, Write} import doobie.implicits.* +import doobie.postgres.pgisimplicits.PointType import org.slf4j.LoggerFactory import org.springframework.stereotype.Service @@ -29,4 +31,17 @@ class AddressService( WHERE person_id = $personId """.query[Address].to[Seq].transact(xa) } + + def findWithinDistance( + lat: Double, + lon: Double, + distanceInMeters: Float + ): IO[Seq[Address]] = ds.use { xa => + val point = PointUtils.pointFromLatLon(lat = lat, lon = lon) + val theTableName = Fragment.const0(tableInfo.table.name) + sql""" + SELECT * FROM $theTableName + WHERE st_distancesphere(coordinates, $point) < $distanceInMeters + """.query[Address].to[Seq].transact(xa) + } } diff --git a/src/main/scala/com/example/scalaspringexperiment/util/CirceHttpMessageConverter.scala b/src/main/scala/com/example/scalaspringexperiment/util/CirceHttpMessageConverter.scala deleted file mode 100644 index 2c28c78..0000000 --- a/src/main/scala/com/example/scalaspringexperiment/util/CirceHttpMessageConverter.scala +++ /dev/null @@ -1,41 +0,0 @@ -package com.example.scalaspringexperiment.util - -import com.example.scalaspringexperiment.util.CirceHttpMessageConverter.{CirceJsonType, ObjectMapType} -import io.circe.* -import io.circe.generic.auto.* -import io.circe.parser.* -import io.circe.syntax.* -import org.springframework.http.converter.json.AbstractJsonHttpMessageConverter - -import java.io.{BufferedReader, Reader, Writer} -import java.lang.reflect.Type -import java.util -import java.util.stream.Collectors - -object CirceHttpMessageConverter { - val CirceJsonType = "io.circe.Json" - val ObjectMapType = "java.util.Map" -} - -/** - * Custom HTTP message converter for Circe JSON serialization/deserialization. - */ -class CirceHttpMessageConverter extends AbstractJsonHttpMessageConverter { - - override def readInternal(resolvedType: Type, reader: Reader): AnyRef = { - val br = new BufferedReader(reader) - val data = br.lines().collect(Collectors.joining()) // TODO: is this stripping newlines? - parse(data).getOrElse(throw RuntimeException("Invalid JSON")) - } - - override def writeInternal(obj: Object, t: Type, writer: Writer): Unit = { - t.getTypeName match { - case CirceJsonType => writer.write(obj.asInstanceOf[Json].noSpaces) - case ObjectMapType => - val message = obj.asInstanceOf[java.util.Map[String, Object]].get("error").toString - writer.write(RestError(message).asJson.noSpaces) - case _ => writer.write(RestError("Something went wrong").asJson.noSpaces) - } - } - -} diff --git a/src/main/scala/com/example/scalaspringexperiment/util/CirceJsonDecoder.scala b/src/main/scala/com/example/scalaspringexperiment/util/CirceJsonDecoder.scala new file mode 100644 index 0000000..94042f3 --- /dev/null +++ b/src/main/scala/com/example/scalaspringexperiment/util/CirceJsonDecoder.scala @@ -0,0 +1,42 @@ +package com.example.scalaspringexperiment.util + +import io.circe.Json +import io.circe.parser.decode +import org.reactivestreams.Publisher +import org.springframework.core.ResolvableType +import org.springframework.core.codec.Decoder +import org.springframework.core.io.buffer.DataBuffer +import org.springframework.core.io.buffer.DataBufferUtils +import org.springframework.http.{MediaType, ResponseEntity} +import org.springframework.util.MimeType +import reactor.core.publisher.Mono +import reactor.core.publisher.Flux + +import java.nio.charset.StandardCharsets +import java.util +import java.lang.reflect.Type + +class CirceJsonDecoder extends Decoder[io.circe.Json] { + + override def canDecode(elementType: ResolvableType, mimeType: MimeType): Boolean = { + mimeType == null || mimeType.isCompatibleWith(MediaType.APPLICATION_JSON) + } + + override def decode(input: Publisher[DataBuffer], elementType: ResolvableType, mimeType: MimeType, hints: util.Map[String, AnyRef]): Flux[io.circe.Json] = { + Flux.from(input).flatMap { buffer => + val jsonStr = StandardCharsets.UTF_8.decode(buffer.asByteBuffer()).toString + DataBufferUtils.release(buffer) + io.circe.parser.decode[io.circe.Json](jsonStr) match { + case Right(json) => Flux.just(json) + case Left(err) => Flux.error(new RuntimeException(s"JSON decoding error: ${err.getMessage}")) + } + } + } + + override def decodeToMono(input: Publisher[DataBuffer], elementType: ResolvableType, mimeType: MimeType, hints: util.Map[String, AnyRef]): Mono[io.circe.Json] = { + decode(input, elementType, mimeType, hints).single() + } + + override def getDecodableMimeTypes: util.List[MimeType] = + util.Arrays.asList(MediaType.APPLICATION_JSON) +} diff --git a/src/main/scala/com/example/scalaspringexperiment/util/CirceJsonEncoder.scala b/src/main/scala/com/example/scalaspringexperiment/util/CirceJsonEncoder.scala new file mode 100644 index 0000000..3796140 --- /dev/null +++ b/src/main/scala/com/example/scalaspringexperiment/util/CirceJsonEncoder.scala @@ -0,0 +1,35 @@ +package com.example.scalaspringexperiment.util + +import io.circe.Json +import org.springframework.core.ResolvableType +import org.springframework.core.codec.Encoder +import org.springframework.core.io.buffer.{DataBuffer, DataBufferFactory} +import org.springframework.http.MediaType +import org.springframework.util.MimeType +import reactor.core.publisher.Flux + +import java.nio.charset.StandardCharsets +import java.util +import org.reactivestreams.Publisher + +class CirceJsonEncoder extends Encoder[io.circe.Json] { + + override def canEncode(elementType: ResolvableType, mimeType: MimeType): Boolean = + mimeType == null || mimeType.isCompatibleWith(MediaType.APPLICATION_JSON) + + override def encode(inputStream: Publisher[_ <: io.circe.Json], bufferFactory: DataBufferFactory, elementType: ResolvableType, mimeType: MimeType, hints: util.Map[String, AnyRef]): Flux[DataBuffer] = { + Flux.from(inputStream).map { json => + val bytes = json.noSpaces.getBytes(StandardCharsets.UTF_8) + val buffer = bufferFactory.wrap(bytes) + buffer + } + } + + override def encodeValue(value: io.circe.Json, bufferFactory: DataBufferFactory, elementType: ResolvableType, mimeType: MimeType, hints: util.Map[String, AnyRef]): DataBuffer = { + val bytes = value.noSpaces.getBytes(StandardCharsets.UTF_8) + bufferFactory.wrap(bytes) + } + + override def getEncodableMimeTypes: util.List[MimeType] = + util.Arrays.asList(MediaType.APPLICATION_JSON) +} diff --git a/src/main/scala/com/example/scalaspringexperiment/util/MyJsonCodecs.scala b/src/main/scala/com/example/scalaspringexperiment/util/MyJsonCodecs.scala index a40ca70..72ac18c 100644 --- a/src/main/scala/com/example/scalaspringexperiment/util/MyJsonCodecs.scala +++ b/src/main/scala/com/example/scalaspringexperiment/util/MyJsonCodecs.scala @@ -1,6 +1,7 @@ package com.example.scalaspringexperiment.util -import io.circe.{Codec, Encoder, HCursor} +import io.circe.{Codec, Encoder, HCursor, Json} +import net.postgis.jdbc.geometry.Point import java.sql.Timestamp import java.time.Instant @@ -16,4 +17,21 @@ object MyJsonCodecs { }, encodeA = Encoder.encodeLong.contramap[Timestamp](_.getTime) ) + + implicit val pointCodec: Codec[Point] = Codec.from( + decodeA = (c: HCursor) => { + c.value match { + case v if v.isObject => + for { + lat <- v.hcursor.get[Double]("lat") + lon <- v.hcursor.get[Double]("lon") + } yield new Point(lat, lon) + case _ => ??? + } + }, + encodeA = Encoder.encodeJson.contramap[Point](p => Json.obj( + "lat" -> Json.fromDoubleOrNull(p.getX), + "lon" -> Json.fromDoubleOrNull(p.getY), + )) + ) } diff --git a/src/main/scala/com/example/scalaspringexperiment/util/PointUtils.scala b/src/main/scala/com/example/scalaspringexperiment/util/PointUtils.scala new file mode 100644 index 0000000..86fdb33 --- /dev/null +++ b/src/main/scala/com/example/scalaspringexperiment/util/PointUtils.scala @@ -0,0 +1,13 @@ +package com.example.scalaspringexperiment.util + +import net.postgis.jdbc.geometry.Point + +object PointUtils { + + def pointFromLatLon( + lat: Double, + lon: Double, + ): Point = { + new Point(lat, lon) + } +} diff --git a/src/test/scala/com/example/scalaspringexperiment/controller/SimpleControllerTest.scala b/src/test/scala/com/example/scalaspringexperiment/controller/SimpleControllerTest.scala index 8c5702c..4cb4474 100644 --- a/src/test/scala/com/example/scalaspringexperiment/controller/SimpleControllerTest.scala +++ b/src/test/scala/com/example/scalaspringexperiment/controller/SimpleControllerTest.scala @@ -4,18 +4,15 @@ import cats.effect.unsafe.IORuntime import com.example.scalaspringexperiment.entity.Person import com.example.scalaspringexperiment.service.PersonService import com.example.scalaspringexperiment.test.{SpringTestConfig, TestUtils} -import io.circe.generic.auto.* -import com.example.scalaspringexperiment.util.MyJsonCodecs.timestampCodec -import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.{BeforeEach, Test} -import org.springframework.boot.test.context.SpringBootTest import org.springframework.beans.factory.annotation.Autowired +import org.springframework.boot.test.context.SpringBootTest import org.springframework.context.annotation.Import +import org.springframework.test.web.reactive.server.WebTestClient import scala.compiletime.uninitialized -import scala.util.chaining.* -@SpringBootTest() +@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) @Import(Array(classOf[SpringTestConfig])) class SimpleControllerTest { @@ -25,11 +22,14 @@ class SimpleControllerTest { @Autowired var personService: PersonService = uninitialized + @Autowired + var testUtils: TestUtils = uninitialized + @Autowired implicit var runtime: IORuntime = uninitialized @Autowired - var testUtils: TestUtils = uninitialized + var webTestClient: WebTestClient = uninitialized @BeforeEach def beforeEach(): Unit = { @@ -37,17 +37,25 @@ class SimpleControllerTest { } @Test - def testGetDetailedPerson(): Unit = { + def getDetailedPersonById_returnsPerson(): Unit = { val person = personService.insert( Person(name = "John Doe", age = 30) ).unsafeRunSync() - val result = simpleController.getDetailedPerson(person.id) + webTestClient.get() + .uri(s"/person/${person.id}/detailed") + .exchange() + .expectStatus().isOk + .expectBody() + .jsonPath("$.person.name").isEqualTo("John Doe") + .jsonPath("$.person.age").isEqualTo(30) + } - result.hcursor.get[Person]("person").getOrElse(???).tap { p => - assertEquals(person.id, p.id) - assertEquals(person.name, p.name) - } + @Test + def getDetailedPersonById_returns404_ifPersonNotFound(): Unit = { + webTestClient.get() + .uri("/person/999/detailed") + .exchange() + .expectStatus().isNotFound } } - diff --git a/src/test/scala/com/example/scalaspringexperiment/controller/SimpleControllerTestWithExpectations.scala b/src/test/scala/com/example/scalaspringexperiment/controller/SimpleControllerTestWithExpectations.scala deleted file mode 100644 index e527a2a..0000000 --- a/src/test/scala/com/example/scalaspringexperiment/controller/SimpleControllerTestWithExpectations.scala +++ /dev/null @@ -1,50 +0,0 @@ -package com.example.scalaspringexperiment.controller - -import cats.effect.unsafe.IORuntime -import com.example.scalaspringexperiment.entity.Person -import com.example.scalaspringexperiment.service.{AddressService, PersonService} -import com.example.scalaspringexperiment.test.{SpringTestConfig, TestUtils} -import org.junit.jupiter.api.{BeforeEach, Test} -import org.mockito.Mockito.{times, verify} -import org.springframework.beans.factory.annotation.Autowired -import org.springframework.boot.test.context.SpringBootTest -import org.springframework.context.annotation.Import -import org.springframework.test.context.bean.`override`.mockito.MockitoSpyBean - -import scala.compiletime.uninitialized - -@SpringBootTest -@Import(Array(classOf[SpringTestConfig])) -class SimpleControllerTestWithExpectations { - - @Autowired - var simpleController: SimpleController = uninitialized - - @MockitoSpyBean - var personService: PersonService = uninitialized - - @MockitoSpyBean - var addressService: AddressService = uninitialized - - @Autowired - implicit var runtime: IORuntime = uninitialized - - @Autowired - var testUtils: TestUtils = uninitialized - - @BeforeEach - def beforeEach(): Unit = { - testUtils.truncateTables() - } - - @Test - def testGetDetailedPerson_invokesExpectedServiceMethods(): Unit = { - val person = personService.insert( - Person(name = "John Doe", age = 30) - ).unsafeRunSync() - - simpleController.getDetailedPerson(person.id) - verify(personService, times(1)).findById(person.id) - verify(addressService, times(1)).findByPersonId(person.id) - } -} diff --git a/src/test/scala/com/example/scalaspringexperiment/controller/SimpleControllerTestWithMockMvc.scala b/src/test/scala/com/example/scalaspringexperiment/controller/SimpleControllerTestWithMockMvc.scala deleted file mode 100644 index f69b397..0000000 --- a/src/test/scala/com/example/scalaspringexperiment/controller/SimpleControllerTestWithMockMvc.scala +++ /dev/null @@ -1,56 +0,0 @@ -package com.example.scalaspringexperiment.controller - -import cats.effect.unsafe.IORuntime -import com.example.scalaspringexperiment.entity.Person -import com.example.scalaspringexperiment.service.PersonService -import com.example.scalaspringexperiment.test.{SpringTestConfig, TestUtils} -import org.junit.jupiter.api.{BeforeEach, Test} -import org.springframework.beans.factory.annotation.Autowired -import org.springframework.boot.test.context.SpringBootTest -import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc -import org.springframework.test.web.servlet.MockMvc -import org.springframework.context.annotation.Import -import org.springframework.test.web.servlet.request.MockMvcRequestBuilders -import org.springframework.test.web.servlet.result.MockMvcResultMatchers - -import scala.compiletime.uninitialized - -@SpringBootTest() -@Import(Array(classOf[SpringTestConfig])) -@AutoConfigureMockMvc -class SimpleControllerTestWithMockMvc { - - @Autowired - var simpleController: SimpleController = uninitialized - - @Autowired - var personService: PersonService = uninitialized - - @Autowired - var testUtils: TestUtils = uninitialized - - @Autowired - implicit var runtime: IORuntime = uninitialized - - @Autowired - var mockMvc: MockMvc = uninitialized - - @BeforeEach - def beforeEach(): Unit = { - testUtils.truncateTables() - } - - @Test - def testGetDetailedPerson(): Unit = { - val person = personService.insert( - Person(name = "John Doe", age = 30) - ).unsafeRunSync() - - mockMvc.perform( - MockMvcRequestBuilders.get(s"/person/${person.id}/detailed") - ).andExpect(MockMvcResultMatchers.status().isOk) - .andExpect(MockMvcResultMatchers.jsonPath("$.person.name").value("John Doe")) - .andExpect(MockMvcResultMatchers.jsonPath("$.person.age").value(30)) - } -} - diff --git a/src/test/scala/com/example/scalaspringexperiment/controller/SimpleControllerTestWithMocks.scala b/src/test/scala/com/example/scalaspringexperiment/controller/SimpleControllerTestWithMocks.scala deleted file mode 100644 index 94e177f..0000000 --- a/src/test/scala/com/example/scalaspringexperiment/controller/SimpleControllerTestWithMocks.scala +++ /dev/null @@ -1,58 +0,0 @@ -package com.example.scalaspringexperiment.controller - -import cats.effect.IO -import cats.effect.unsafe.IORuntime -import com.example.scalaspringexperiment.entity.Person -import com.example.scalaspringexperiment.service.{AddressService, PersonService} -import com.example.scalaspringexperiment.test.{SpringTestConfig, TestUtils} -import io.circe.generic.auto.* -import com.example.scalaspringexperiment.util.MyJsonCodecs.timestampCodec -import org.junit.jupiter.api.Assertions.assertEquals -import org.junit.jupiter.api.{BeforeEach, Test} -import org.mockito.ArgumentMatchers.anyLong -import org.mockito.Mockito.when -import org.springframework.beans.factory.annotation.Autowired -import org.springframework.boot.test.context.SpringBootTest -import org.springframework.context.annotation.Import -import org.springframework.test.context.bean.`override`.mockito.MockitoBean - -import scala.compiletime.uninitialized -import scala.util.chaining.* - -@SpringBootTest() -@Import(Array(classOf[SpringTestConfig])) -class SimpleControllerTestWithMocks { - - @Autowired - var simpleController: SimpleController = uninitialized - - @MockitoBean - var personService: PersonService = uninitialized - - @MockitoBean - var addressService: AddressService = uninitialized - - @Autowired - implicit var runtime: IORuntime = uninitialized - - @Autowired - var testUtils: TestUtils = uninitialized - - @BeforeEach - def beforeEach(): Unit = { - testUtils.truncateTables() - } - - @Test - def getDetailedPerson_rendersDetailedPersonJson(): Unit = { - val person = Person(id = 123, name = "John Doe", age = 30) - when(personService.findById(anyLong())).thenReturn(IO.pure(Some(person))) - when(addressService.findByPersonId(anyLong())).thenReturn(IO.pure(List())) - - simpleController.getDetailedPerson(person.id).tap { json => - val decodedPerson = json.hcursor.get[Person]("person").getOrElse(???) - assertEquals(person.id, decodedPerson.id) - } - } -} - diff --git a/src/test/scala/com/example/scalaspringexperiment/service/AddressServiceTest.scala b/src/test/scala/com/example/scalaspringexperiment/service/AddressServiceTest.scala new file mode 100644 index 0000000..cf02a4c --- /dev/null +++ b/src/test/scala/com/example/scalaspringexperiment/service/AddressServiceTest.scala @@ -0,0 +1,69 @@ +package com.example.scalaspringexperiment.service + +import cats.effect.unsafe.IORuntime +import com.example.scalaspringexperiment.entity.Address +import com.example.scalaspringexperiment.test.{SpringTestConfig, TestUtils} +import com.example.scalaspringexperiment.util.PointUtils +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.{BeforeEach, Test} +import org.springframework.beans.factory.annotation.Autowired +import org.springframework.boot.test.context.SpringBootTest +import org.springframework.context.annotation.Import + +import scala.compiletime.uninitialized +import scala.util.chaining.* + +@SpringBootTest() +@Import(Array(classOf[SpringTestConfig])) +class AddressServiceTest { + + @Autowired + var addressService: AddressService = uninitialized + + @Autowired + var testUtils: TestUtils = uninitialized + + @Autowired + implicit var runtime: IORuntime = uninitialized + + @BeforeEach + def beforeEach(): Unit = { + testUtils.truncateTables() + } + + @Test + def findWithinDistance_returnsOnlyAddressesWithinDistance(): Unit = { + val personInAustin = testUtils.newRandomPerson(persist = true).tap { p => + addressService.insert( + Address( + personId = p.id, + street = Some("123 Main St"), + city = Some("Austin"), + state = Some("TX"), + coordinates = Some(PointUtils.pointFromLatLon(30.2673, 97.7432)) + ) + ).unsafeRunSync() + } + + val personInSanAntonio = testUtils.newRandomPerson(persist = true).tap { p => + addressService.insert( + Address( + personId = p.id, + street = Some("100 San Antonio Blvd"), + city = Some("San Antonio"), + state = Some("TX"), + coordinates = Some(PointUtils.pointFromLatLon(29.4252, 98.4946)) + ) + ).unsafeRunSync() + } + + val addressesNearAustin = addressService.findWithinDistance( + lat = 30.2672, + lon = 97.7431, + distanceInMeters = 500 + ).unsafeRunSync() + + assertEquals(1, addressesNearAustin.length) + assertEquals("Austin", addressesNearAustin.head.city.get) + } +} diff --git a/src/test/scala/com/example/scalaspringexperiment/test/TestUtils.scala b/src/test/scala/com/example/scalaspringexperiment/test/TestUtils.scala index 4938ee3..78a8a05 100644 --- a/src/test/scala/com/example/scalaspringexperiment/test/TestUtils.scala +++ b/src/test/scala/com/example/scalaspringexperiment/test/TestUtils.scala @@ -2,10 +2,22 @@ package com.example.scalaspringexperiment.test import cats.effect.{IO, Resource} import cats.effect.unsafe.IORuntime -import doobie.DataSourceTransactor +import com.example.scalaspringexperiment.dao.Dao +import com.example.scalaspringexperiment.entity.Person +import com.example.scalaspringexperiment.service.PersonService +import com.github.javafaker.Faker +import doobie.{DataSourceTransactor, Fragment} import doobie.implicits.toSqlInterpolator import org.springframework.stereotype.Service import doobie.syntax.all.* +import org.springframework.beans.factory.annotation.Autowired +import org.springframework.context.ApplicationContext + +import scala.compiletime.uninitialized +import scala.jdk.CollectionConverters.* +import scala.annotation.StaticAnnotation + +case class Table(name: String) extends StaticAnnotation @Service class TestUtils( @@ -13,20 +25,60 @@ class TestUtils( implicit val runtime: IORuntime ) { + @Autowired + var personService: PersonService = uninitialized + + @Autowired + var applicationContext: ApplicationContext = uninitialized + + val faker = new Faker() + + private val ignoredTables = Seq( + "flyway_schema_history", + "geometry_columns", + "geography_columns", + "spatial_ref_sys", + ) + + def getAllDaos(): Seq[Dao[?]] = { + applicationContext + .getBeansOfType(classOf[Dao[?]]) + .values() + .asScala + .toSeq + } + def truncateTables(): Unit = { + println("TRUNCATING ALL TABLES") ds.use { xa => for { - tables <- sql"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'" - .query[String].to[List].transact(xa) + tables <- IO(getAllDaos().map(_.tableInfo.table.name)) results <- IO { - // leave flyway_schema_history in place: - tables.filter(_ != "flyway_schema_history").map { table => - sql"TRUNCATE TABLE $table CASCADE".update.run.transact(xa) + tables.filter(!ignoredTables.contains(_)).map { table => + val theTableName = Fragment.const0(table) + sql"TRUNCATE TABLE $theTableName CASCADE".update.run.transact(xa).unsafeRunSync() } } } yield results }.unsafeRunSync() } + + def newRandomPerson( + persist: Boolean = false, + changes: Person => Person = {p => p} + ): Person = { + (for { + unsavedPerson <- IO(Person( + name = faker.name().fullName(), + age = faker.number().numberBetween(18, 79) + )) + changedPerson = changes(unsavedPerson) + finalizedPerson <- persist match { + case true => personService.insert(changedPerson) + case false => IO.pure(changedPerson) + } + } yield finalizedPerson).unsafeRunSync() + } }