diff --git a/build.gradle b/build.gradle index 91cc2e77d..deefc611b 100644 --- a/build.gradle +++ b/build.gradle @@ -68,6 +68,9 @@ dependencies { implementation 'org.hibernate.validator:hibernate-validator' implementation 'com.amazonaws:aws-java-sdk-s3:1.12.782' implementation 'org.springframework.boot:spring-boot-starter-websocket' + + // Database Proxy + implementation 'net.ttddyy.observation:datasource-micrometer:1.2.0' } tasks.named('test', Test) { diff --git a/src/main/java/com/example/solidconnection/common/config/datasource/DataSourceProxyConfig.java b/src/main/java/com/example/solidconnection/common/config/datasource/DataSourceProxyConfig.java new file mode 100644 index 000000000..b7bf0b008 --- /dev/null +++ b/src/main/java/com/example/solidconnection/common/config/datasource/DataSourceProxyConfig.java @@ -0,0 +1,29 @@ +package com.example.solidconnection.common.config.datasource; + +import com.example.solidconnection.common.listener.QueryMetricsListener; +import javax.sql.DataSource; +import lombok.RequiredArgsConstructor; +import net.ttddyy.dsproxy.support.ProxyDataSourceBuilder; +import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Primary; + +@RequiredArgsConstructor +@Configuration +public class DataSourceProxyConfig { + + private final QueryMetricsListener queryMetricsListener; + + @Bean + @Primary + public DataSource proxyDataSource(DataSourceProperties props) { + DataSource dataSource = props.initializeDataSourceBuilder().build(); + + return ProxyDataSourceBuilder + .create(dataSource) + .listener(queryMetricsListener) + .name("main") + .build(); + } +} diff --git a/src/main/java/com/example/solidconnection/common/config/web/WebMvcConfig.java b/src/main/java/com/example/solidconnection/common/config/web/WebMvcConfig.java index 56bb288e8..7e2f199ca 100644 --- a/src/main/java/com/example/solidconnection/common/config/web/WebMvcConfig.java +++ b/src/main/java/com/example/solidconnection/common/config/web/WebMvcConfig.java @@ -1,11 +1,18 @@ package com.example.solidconnection.common.config.web; +import com.example.solidconnection.common.filter.HttpLoggingFilter; +import com.example.solidconnection.common.interceptor.ApiPerformanceInterceptor; +import com.example.solidconnection.common.interceptor.RequestContextInterceptor; import com.example.solidconnection.common.resolver.AuthorizedUserResolver; import com.example.solidconnection.common.resolver.CustomPageableHandlerMethodArgumentResolver; import java.util.List; import lombok.RequiredArgsConstructor; +import org.springframework.boot.web.servlet.FilterRegistrationBean; +import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.core.Ordered; import org.springframework.web.method.support.HandlerMethodArgumentResolver; +import org.springframework.web.servlet.config.annotation.InterceptorRegistry; import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; @Configuration @@ -14,6 +21,9 @@ public class WebMvcConfig implements WebMvcConfigurer { private final AuthorizedUserResolver authorizedUserResolver; private final CustomPageableHandlerMethodArgumentResolver customPageableHandlerMethodArgumentResolver; + private final HttpLoggingFilter httpLoggingFilter; + private final ApiPerformanceInterceptor apiPerformanceInterceptor; + private final RequestContextInterceptor requestContextInterceptor; @Override public void addArgumentResolvers(List resolvers) { @@ -22,4 +32,23 @@ public void addArgumentResolvers(List resolvers) customPageableHandlerMethodArgumentResolver )); } + + @Override + public void addInterceptors(InterceptorRegistry registry){ + registry.addInterceptor(apiPerformanceInterceptor) + .addPathPatterns("/**") + .excludePathPatterns("/actuator/**"); + + registry.addInterceptor(requestContextInterceptor) + .addPathPatterns("/**") + .excludePathPatterns("/actuator/**"); + } + + @Bean + public FilterRegistrationBean customHttpLoggingFilter() { + FilterRegistrationBean filterBean = new FilterRegistrationBean<>(); + filterBean.setFilter(httpLoggingFilter); + filterBean.setOrder(Ordered.HIGHEST_PRECEDENCE); + return filterBean; + } } diff --git a/src/main/java/com/example/solidconnection/common/filter/HttpLoggingFilter.java b/src/main/java/com/example/solidconnection/common/filter/HttpLoggingFilter.java new file mode 100644 index 000000000..74f2dfa6c --- /dev/null +++ b/src/main/java/com/example/solidconnection/common/filter/HttpLoggingFilter.java @@ -0,0 +1,156 @@ +package com.example.solidconnection.common.filter; + +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.net.URLDecoder; +import java.nio.charset.StandardCharsets; +import java.util.List; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.slf4j.MDC; +import org.springframework.http.HttpStatus; +import org.springframework.stereotype.Component; +import org.springframework.util.AntPathMatcher; +import org.springframework.web.filter.OncePerRequestFilter; + +@Slf4j +@RequiredArgsConstructor +@Component +public class HttpLoggingFilter extends OncePerRequestFilter { + + private static final AntPathMatcher PATH_MATCHER = new AntPathMatcher(); + private static final List EXCLUDE_PATTERNS = List.of("/actuator/**"); + private static final List EXCLUDE_QUERIES = List.of("token"); + private static final String MASK_VALUE = "****"; + + @Override + protected void doFilterInternal( + HttpServletRequest request, + HttpServletResponse response, + FilterChain filterChain + ) throws ServletException, IOException { + + // 1) traceId 부여 + String traceId = generateTraceId(); + MDC.put("traceId", traceId); + + boolean excluded = isExcluded(request); + + // 2) 로깅 제외 대상이면 그냥 통과 (traceId는 유지: 추후 하위 레이어 로그에도 붙음) + if (excluded) { + try { + filterChain.doFilter(request, response); + } finally { + MDC.clear(); + } + return; + } + + printRequestUri(request); + + try { + filterChain.doFilter(request, response); + printResponse(request, response); + } finally { + MDC.clear(); + } + } + + private boolean isExcluded(HttpServletRequest req) { + String path = req.getRequestURI(); + for (String p : EXCLUDE_PATTERNS) { + if (PATH_MATCHER.match(p, path)) { + return true; + } + } + return false; + } + + private String generateTraceId() { + return java.util.UUID.randomUUID().toString().replace("-", "").substring(0, 16); + } + + private void printRequestUri(HttpServletRequest request) { + String methodType = request.getMethod(); + String uri = buildDecodedRequestUri(request); + log.info("[REQUEST] {} {}", methodType, uri); + } + + private void printResponse( + HttpServletRequest request, + HttpServletResponse response + ) { + Long userId = (Long) request.getAttribute("userId"); + String uri = buildDecodedRequestUri(request); + HttpStatus status = HttpStatus.valueOf(response.getStatus()); + + log.info("[RESPONSE] {} userId = {}, ({})", uri, userId, status); + } + + private String buildDecodedRequestUri(HttpServletRequest request) { + String path = request.getRequestURI(); + String query = request.getQueryString(); + + if(query == null || query.isBlank()){ + return path; + } + + String decodedQuery = decodeQuery(query); + String maskedQuery = maskSensitiveParams(decodedQuery); + + return path + "?" + maskedQuery; + } + + private String decodeQuery(String rawQuery) { + if(rawQuery == null || rawQuery.isBlank()){ + return rawQuery; + } + + try { + return URLDecoder.decode(rawQuery, StandardCharsets.UTF_8); + } catch (IllegalArgumentException e) { + log.warn("Query 디코딩 실패 parameter: {}, msg: {}", rawQuery, e.getMessage()); + return rawQuery; + } + } + + private String maskSensitiveParams(String decodedQuery) { + String[] params = decodedQuery.split("&"); + StringBuilder maskedQuery = new StringBuilder(); + + for(int i = 0; i < params.length; i++){ + String param = params[i]; + + if(!param.contains("=")){ + maskedQuery.append(param); + }else{ + int equalIndex = param.indexOf("="); + String key = param.substring(0, equalIndex); + + if(isSensitiveParam(key)){ + maskedQuery.append(key).append("=").append(MASK_VALUE); + }else{ + maskedQuery.append(param); + } + } + + if(i < params.length - 1){ + maskedQuery.append("&"); + } + } + + return maskedQuery.toString(); + } + + private boolean isSensitiveParam(String paramKey) { + for (String sensitiveParam : EXCLUDE_QUERIES){ + if(sensitiveParam.equalsIgnoreCase(paramKey)){ + return true; + } + } + return false; + } +} diff --git a/src/main/java/com/example/solidconnection/common/interceptor/ApiPerformanceInterceptor.java b/src/main/java/com/example/solidconnection/common/interceptor/ApiPerformanceInterceptor.java new file mode 100644 index 000000000..50a95f937 --- /dev/null +++ b/src/main/java/com/example/solidconnection/common/interceptor/ApiPerformanceInterceptor.java @@ -0,0 +1,67 @@ +package com.example.solidconnection.common.interceptor; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Component; +import org.springframework.web.servlet.HandlerInterceptor; + +@Slf4j +@RequiredArgsConstructor +@Component +public class ApiPerformanceInterceptor implements HandlerInterceptor { + private static final String START_TIME_ATTRIBUTE = "startTime"; + private static final String REQUEST_URI_ATTRIBUTE = "requestUri"; + private static final int RESPONSE_TIME_THRESHOLD = 3_000; + private static final Logger API_PERF = LoggerFactory.getLogger("API_PERF"); + + @Override + public boolean preHandle( + HttpServletRequest request, + HttpServletResponse response, + Object handler + ) throws Exception { + + long startTime = System.currentTimeMillis(); + + request.setAttribute(START_TIME_ATTRIBUTE, startTime); + request.setAttribute(REQUEST_URI_ATTRIBUTE, request.getRequestURI()); + + return true; + } + + @Override + public void afterCompletion( + HttpServletRequest request, + HttpServletResponse response, + Object handler, + Exception ex + ) throws Exception { + Long startTime = (Long) request.getAttribute(START_TIME_ATTRIBUTE); + if(startTime == null) { + return; + } + + long responseTime = System.currentTimeMillis() - startTime; + + String uri = request.getRequestURI(); + String method = request.getMethod(); + int status = response.getStatus(); + + if (responseTime > RESPONSE_TIME_THRESHOLD) { + API_PERF.warn( + "type=API_Performance method_type={} uri={} response_time={} status={}", + method, uri, responseTime, status + ); + } + else { + API_PERF.info( + "type=API_Performance method_type={} uri={} response_time={} status={}", + method, uri, responseTime, status + ); + } + } +} diff --git a/src/main/java/com/example/solidconnection/common/interceptor/RequestContext.java b/src/main/java/com/example/solidconnection/common/interceptor/RequestContext.java new file mode 100644 index 000000000..1f4d2790c --- /dev/null +++ b/src/main/java/com/example/solidconnection/common/interceptor/RequestContext.java @@ -0,0 +1,14 @@ +package com.example.solidconnection.common.interceptor; + +import lombok.Getter; + +@Getter +public class RequestContext { + private final String httpMethod; + private final String bestMatchPath; + + public RequestContext(String httpMethod, String bestMatchPath) { + this.httpMethod = httpMethod; + this.bestMatchPath = bestMatchPath; + } +} diff --git a/src/main/java/com/example/solidconnection/common/interceptor/RequestContextHolder.java b/src/main/java/com/example/solidconnection/common/interceptor/RequestContextHolder.java new file mode 100644 index 000000000..0c786bf10 --- /dev/null +++ b/src/main/java/com/example/solidconnection/common/interceptor/RequestContextHolder.java @@ -0,0 +1,18 @@ +package com.example.solidconnection.common.interceptor; + +public class RequestContextHolder { + private static final ThreadLocal CONTEXT = new ThreadLocal<>(); + + public static void initContext(RequestContext requestContext) { + CONTEXT.remove(); + CONTEXT.set(requestContext); + } + + public static RequestContext getContext() { + return CONTEXT.get(); + } + + public static void clear(){ + CONTEXT.remove(); + } +} diff --git a/src/main/java/com/example/solidconnection/common/interceptor/RequestContextInterceptor.java b/src/main/java/com/example/solidconnection/common/interceptor/RequestContextInterceptor.java new file mode 100644 index 000000000..e42b14e11 --- /dev/null +++ b/src/main/java/com/example/solidconnection/common/interceptor/RequestContextInterceptor.java @@ -0,0 +1,36 @@ +package com.example.solidconnection.common.interceptor; + +import static org.springframework.web.servlet.HandlerMapping.BEST_MATCHING_PATTERN_ATTRIBUTE; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.springframework.stereotype.Component; +import org.springframework.web.servlet.HandlerInterceptor; + +@Component +public class RequestContextInterceptor implements HandlerInterceptor { + + @Override + public boolean preHandle( + HttpServletRequest request, + HttpServletResponse response, + Object handler + ) { + String httpMethod = request.getMethod(); + String bestMatchPath = (String) request.getAttribute(BEST_MATCHING_PATTERN_ATTRIBUTE); + + RequestContext context = new RequestContext(httpMethod, bestMatchPath); + RequestContextHolder.initContext(context); + + return true; + } + + @Override + public void afterCompletion( + HttpServletRequest request, + HttpServletResponse response, + Object handler, Exception ex + ) { + RequestContextHolder.clear(); + } +} diff --git a/src/main/java/com/example/solidconnection/common/listener/QueryMetricsListener.java b/src/main/java/com/example/solidconnection/common/listener/QueryMetricsListener.java new file mode 100644 index 000000000..8f3258b6b --- /dev/null +++ b/src/main/java/com/example/solidconnection/common/listener/QueryMetricsListener.java @@ -0,0 +1,53 @@ +package com.example.solidconnection.common.listener; + +import com.example.solidconnection.common.interceptor.RequestContext; +import com.example.solidconnection.common.interceptor.RequestContextHolder; +import io.micrometer.core.instrument.MeterRegistry; +import java.util.List; +import java.util.concurrent.TimeUnit; +import lombok.RequiredArgsConstructor; +import net.ttddyy.dsproxy.ExecutionInfo; +import net.ttddyy.dsproxy.QueryInfo; +import net.ttddyy.dsproxy.listener.QueryExecutionListener; +import org.springframework.stereotype.Component; + + +@RequiredArgsConstructor +@Component +public class QueryMetricsListener implements QueryExecutionListener { + + private final MeterRegistry meterRegistry; + + @Override + public void beforeQuery(ExecutionInfo executionInfo, List list) { + + } + + @Override + public void afterQuery(ExecutionInfo exec, List queries) { + long elapsedMs = exec.getElapsedTime(); + String sql = queries.isEmpty() ? "" : queries.get(0).getQuery(); + String type = guessType(sql); + + RequestContext rc = RequestContextHolder.getContext(); + String httpMethod = (rc != null && rc.getHttpMethod() != null) ? rc.getHttpMethod() : "-"; + String httpPath = (rc != null && rc.getBestMatchPath() != null) ? rc.getBestMatchPath() : "-"; + + meterRegistry.timer( + "db.query", + "sql_type", type, + "http_method", httpMethod, + "http_path", httpPath + ).record(elapsedMs, TimeUnit.MILLISECONDS); + } + + private String guessType(String sql) { + if (sql == null) return "OTHER"; + String s = sql.trim().toUpperCase(); + if (s.startsWith("SELECT")) return "SELECT"; + if (s.startsWith("INSERT")) return "INSERT"; + if (s.startsWith("UPDATE")) return "UPDATE"; + if (s.startsWith("DELETE")) return "DELETE"; + return "UNKNOWN"; + } +} diff --git a/src/main/java/com/example/solidconnection/security/filter/TokenAuthenticationFilter.java b/src/main/java/com/example/solidconnection/security/filter/TokenAuthenticationFilter.java index 8c8dc8f30..6e1899dd3 100644 --- a/src/main/java/com/example/solidconnection/security/filter/TokenAuthenticationFilter.java +++ b/src/main/java/com/example/solidconnection/security/filter/TokenAuthenticationFilter.java @@ -2,6 +2,7 @@ import com.example.solidconnection.security.authentication.TokenAuthentication; import com.example.solidconnection.security.infrastructure.AuthorizationHeaderParser; +import com.example.solidconnection.security.userdetails.SiteUserDetails; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServletRequest; @@ -34,6 +35,7 @@ public void doFilterInternal(@NonNull HttpServletRequest request, TokenAuthentication authToken = new TokenAuthentication(token); Authentication auth = authenticationManager.authenticate(authToken); SecurityContextHolder.getContext().setAuthentication(auth); + extractIdFromAuthentication(request, auth); }); filterChain.doFilter(request, response); @@ -45,4 +47,10 @@ private Optional resolveToken(HttpServletRequest request) { } return authorizationHeaderParser.parseToken(request); } + + private void extractIdFromAuthentication(HttpServletRequest request, Authentication auth) { + SiteUserDetails principal = (SiteUserDetails) auth.getPrincipal(); + Long id = principal.getSiteUser().getId(); + request.setAttribute("userId", id); + } } diff --git a/src/main/resources/logback-spring.xml b/src/main/resources/logback-spring.xml index e179be0fb..52d0bb4e8 100644 --- a/src/main/resources/logback-spring.xml +++ b/src/main/resources/logback-spring.xml @@ -2,34 +2,96 @@ - - + - - /var/log/spring/solid-connection-server.log + + + - - - /var/log/spring/solid-connection-server.%d{yyyy-MM-dd}.log - 30 - + + + ${LOG_PATH}/info/info.log + + ${LOG_PATH}/info/info.%d{yyyy-MM-dd}.log + 7 + + + ${LOG_PATTERN} + + + INFO + ACCEPT + DENY + + - - - timestamp=%d{yyyy-MM-dd'T'HH:mm:ss.SSS} level=%-5level thread=%thread logger=%logger{36} - message=%msg%n - - - + + + ${LOG_PATH}/warn/warn.log + + ${LOG_PATH}/warn/warn.%d{yyyy-MM-dd}.log + 7 + + + ${LOG_PATTERN} + + + WARN + ACCEPT + DENY + + - - - + + + ${LOG_PATH}/error/error.log + + ${LOG_PATH}/error/error.%d{yyyy-MM-dd}.log + 7 + + + ${LOG_PATTERN} + + + ERROR + ACCEPT + DENY + + - + + + ${LOG_PATH}/api-perf/api-perf.log + + ${LOG_PATH}/api-perf/api-perf.%d{yyyy-MM-dd}.log + 7 + + + ${LOG_PATTERN} + + + + + + + + + + + + + + + + + + + + - + - + \ No newline at end of file diff --git a/src/test/java/com/example/solidconnection/common/filter/HttpLoggingFilterTest.java b/src/test/java/com/example/solidconnection/common/filter/HttpLoggingFilterTest.java new file mode 100644 index 000000000..cfd8ee681 --- /dev/null +++ b/src/test/java/com/example/solidconnection/common/filter/HttpLoggingFilterTest.java @@ -0,0 +1,243 @@ +package com.example.solidconnection.common.filter; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import ch.qos.logback.classic.Logger; +import ch.qos.logback.classic.spi.ILoggingEvent; +import ch.qos.logback.core.read.ListAppender; +import com.example.solidconnection.support.TestContainerSpringBootTest; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.slf4j.LoggerFactory; +import org.slf4j.MDC; + +@TestContainerSpringBootTest +@DisplayName("HttpLoggingFilter 테스트") +class HttpLoggingFilterTest { + + private HttpLoggingFilter filter; + private HttpServletRequest request; + private HttpServletResponse response; + private FilterChain filterChain; + + private ListAppender listAppender; + private Logger logger; + + @BeforeEach + void setUp() { + filter = new HttpLoggingFilter(); + request = mock(HttpServletRequest.class); + response = mock(HttpServletResponse.class); + filterChain = mock(FilterChain.class); + + logger = (Logger) LoggerFactory.getLogger(HttpLoggingFilter.class); + listAppender = new ListAppender<>(); + listAppender.start(); + logger.addAppender(listAppender); + } + + @AfterEach + void tearDown() { + MDC.clear(); + logger.detachAppender(listAppender); + listAppender.stop(); + } + + @Nested + class TraceId_생성 { + + @Test + void 요청마다_traceId를_생성한다() throws ServletException, IOException { + // given + when(request.getRequestURI()).thenReturn("/api/test"); + when(request.getMethod()).thenReturn("GET"); + when(response.getStatus()).thenReturn(200); + + AtomicReference capturedTraceId = new AtomicReference<>(); + + doAnswer(invocation ->{ + capturedTraceId.set(MDC.get("traceId")); + return null; + }).when(filterChain).doFilter(request, response); + + // when + filter.doFilterInternal(request, response, filterChain); + + // then + String traceId = capturedTraceId.get(); + assertAll( + () -> assertThat(traceId).isNotNull(), + () -> assertThat(traceId).hasSize(16), + () -> assertThat(traceId).matches("[a-f0-9]{16}") + ); + verify(filterChain).doFilter(request, response); + } + } + + @Nested + class 로깅_제외_패턴 { + + @Test + void actuator_경로는_로깅에서_제외된다() throws ServletException, IOException { + // given + when(request.getRequestURI()).thenReturn("/actuator/health"); + when(request.getMethod()).thenReturn("GET"); + + // when + filter.doFilterInternal(request, response, filterChain); + + // then + assertAll( + () -> assertThat(listAppender.list).noneMatch(event -> event.getFormattedMessage().contains("[REQUEST]")), + () -> assertThat(listAppender.list).noneMatch(event -> event.getFormattedMessage().contains("[RESPONSE]")) + ); + verify(filterChain).doFilter(request, response); + } + + @Test + void 일반_경로는_로깅된다() throws ServletException, IOException { + // given + when(request.getRequestURI()).thenReturn("/api/users"); + when(request.getMethod()).thenReturn("GET"); + when(response.getStatus()).thenReturn(200); + String expectedRequestLog = "[REQUEST] GET /api/users"; + String expectedResponseLog = "[RESPONSE] /api/users userId = null, (200 OK)"; + + + // when + filter.doFilterInternal(request, response, filterChain); + + // then + assertAll( + () -> assertThat(listAppender.list).anyMatch(event -> event.getFormattedMessage().contains(expectedRequestLog)), + () -> assertThat(listAppender.list).anyMatch(event -> event.getFormattedMessage().contains(expectedResponseLog)) + ); + verify(filterChain).doFilter(request, response); + } + } + + @Nested + class 민감한_쿼리_파라미터_마스킹 { + + @Test + void token_파라미터는_마스킹된다() throws ServletException, IOException { + // given + when(request.getRequestURI()).thenReturn("/api/auth"); + when(request.getQueryString()).thenReturn("token=secret123&userId=1"); + when(request.getMethod()).thenReturn("GET"); + when(response.getStatus()).thenReturn(200); + String expectedRequestLog = "[REQUEST] GET /api/auth?token=****&userId=1"; + String expectedResponseLog = "[RESPONSE] /api/auth?token=****&userId=1 userId = null, (200 OK)"; + + // when + filter.doFilterInternal(request, response, filterChain); + + // then + assertAll( + () -> assertThat(listAppender.list).anyMatch(event -> event.getFormattedMessage().contains(expectedRequestLog)), + () -> assertThat(listAppender.list).anyMatch(event -> event.getFormattedMessage().contains(expectedResponseLog)) + ); + verify(filterChain).doFilter(request, response); + } + + @Test + void 일반_파라미터는_마스킹되지_않는다() throws ServletException, IOException { + // given + when(request.getRequestURI()).thenReturn("/api/users"); + when(request.getQueryString()).thenReturn("name=홍길동&age=20"); + when(request.getMethod()).thenReturn("GET"); + when(response.getStatus()).thenReturn(200); + String expectedRequestLog = "[REQUEST] GET /api/users?name=홍길동&age=20"; + String expectedResponseLog = "[RESPONSE] /api/users?name=홍길동&age=20 userId = null, (200 OK)"; + + // when + filter.doFilterInternal(request, response, filterChain); + + // then + assertAll( + () -> assertThat(listAppender.list).anyMatch(event -> event.getFormattedMessage().contains(expectedRequestLog)), + () -> assertThat(listAppender.list).anyMatch(event -> event.getFormattedMessage().contains(expectedResponseLog)) + ); + verify(filterChain).doFilter(request, response); + } + } + + @Nested + class 쿼리_파라미터_디코딩 { + + @Test + void URL_인코딩된_파라미터를_디코딩한다() throws ServletException, IOException { + // given + when(request.getRequestURI()).thenReturn("/api/search"); + when(request.getQueryString()).thenReturn("keyword=%ED%99%8D%EA%B8%B8%EB%8F%99"); + when(request.getMethod()).thenReturn("GET"); + when(response.getStatus()).thenReturn(200); + String expectedParameter = "홍길동"; + String expectedRequestLog = "[REQUEST] GET /api/search?keyword=" + expectedParameter; + String expectedResponseLog = "[RESPONSE] /api/search?keyword=" + expectedParameter + " userId = null, (200 OK)"; + + // when + filter.doFilterInternal(request, response, filterChain); + + // then + assertAll( + () -> assertThat(listAppender.list).anyMatch(event -> event.getFormattedMessage().contains(expectedRequestLog)), + () -> assertThat(listAppender.list).anyMatch(event -> event.getFormattedMessage().contains(expectedResponseLog)) + ); + verify(filterChain).doFilter(request, response); + } + + @Test + void 디코딩_실패_시_원본_쿼리를_사용한다() throws ServletException, IOException { + // given + when(request.getRequestURI()).thenReturn("/api/search"); + when(request.getQueryString()).thenReturn("invalid=%"); + when(request.getMethod()).thenReturn("GET"); + when(response.getStatus()).thenReturn(200); + String expectedRequestLog = "[REQUEST] GET /api/search?invalid=%"; + String expectedResponseLog = "[RESPONSE] /api/search?invalid=% userId = null, (200 OK)"; + + // when + filter.doFilterInternal(request, response, filterChain); + + // then + assertAll( + () -> assertThat(listAppender.list).anyMatch(event -> event.getFormattedMessage().contains(expectedRequestLog)), + () -> assertThat(listAppender.list).anyMatch(event -> event.getFormattedMessage().contains(expectedResponseLog)) + ); + verify(filterChain).doFilter(request, response); + } + } + + @Nested + class MDC_정리 { + + @Test + void 요청_완료_후_MDC가_정리된다() throws ServletException, IOException { + // given + when(request.getRequestURI()).thenReturn("/api/test"); + when(request.getMethod()).thenReturn("GET"); + when(response.getStatus()).thenReturn(200); + + // when + filter.doFilterInternal(request, response, filterChain); + + // then + assertThat(MDC.get("traceId")).isNull(); + } + } +} diff --git a/src/test/java/com/example/solidconnection/common/interceptor/ApiPerformanceInterceptorTest.java b/src/test/java/com/example/solidconnection/common/interceptor/ApiPerformanceInterceptorTest.java new file mode 100644 index 000000000..ad836ec29 --- /dev/null +++ b/src/test/java/com/example/solidconnection/common/interceptor/ApiPerformanceInterceptorTest.java @@ -0,0 +1,201 @@ +package com.example.solidconnection.common.interceptor; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import ch.qos.logback.classic.Logger; +import ch.qos.logback.classic.spi.ILoggingEvent; +import ch.qos.logback.core.read.ListAppender; +import com.example.solidconnection.support.TestContainerSpringBootTest; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import java.util.List; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.slf4j.LoggerFactory; + +@TestContainerSpringBootTest +@DisplayName("ApiPerformanceInterceptor 테스트") +class ApiPerformanceInterceptorTest { + + private ApiPerformanceInterceptor interceptor; + private HttpServletRequest request; + private HttpServletResponse response; + private Object handler; + + private ListAppender listAppender; + private Logger logger; + + @BeforeEach + void setUp() { + interceptor = new ApiPerformanceInterceptor(); + request = mock(HttpServletRequest.class); + response = mock(HttpServletResponse.class); + handler = new Object(); + + logger = (Logger) LoggerFactory.getLogger("API_PERF"); + listAppender = new ListAppender<>(); + listAppender.start(); + logger.addAppender(listAppender); + } + + @AfterEach + void tearDown() { + logger.detachAppender(listAppender); + listAppender.stop(); + } + + @Nested + class PreHandle_메서드 { + + @Test + void 시작_시간을_request에_저장한다() throws Exception { + // given + when(request.getRequestURI()).thenReturn("/api/test"); + long beforeTime = System.currentTimeMillis(); + + // when + interceptor.preHandle(request, response, handler); + + // then + ArgumentCaptor keyCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor valueCaptor = ArgumentCaptor.forClass(Object.class); + + verify(request, times(2)).setAttribute(keyCaptor.capture(), valueCaptor.capture()); + + List capturedKeys = keyCaptor.getAllValues(); + List capturedValues = valueCaptor.getAllValues(); + + assertThat(capturedKeys).contains("startTime"); + Long startTime = (Long) capturedValues.get(capturedKeys.indexOf("startTime")); + assertThat(startTime) + .isGreaterThanOrEqualTo(beforeTime); + + assertThat(capturedKeys).contains("requestUri"); + String uri = (String) capturedValues.get(capturedKeys.indexOf("requestUri")); + assertThat(uri).isEqualTo("/api/test"); + } + + @Test + void preHandle_항상_true를_반환한다() throws Exception { + // given + when(request.getRequestURI()).thenReturn("/api/test"); + + // when + boolean result = interceptor.preHandle(request, response, handler); + + // then + assertThat(result).isTrue(); + } + } + + @Nested + class AfterCompletion_메서드 { + + @Test + void 응답_시간을_계산하고_로그를_남긴다() throws Exception { + // given + long startTime = System.currentTimeMillis(); + when(request.getAttribute("startTime")).thenReturn(startTime); + when(request.getRequestURI()).thenReturn("/api/test"); + when(request.getMethod()).thenReturn("GET"); + when(response.getStatus()).thenReturn(200); + String expectedApiPerfLog = "type=API_Performance"; + + // when + interceptor.afterCompletion(request, response, handler, null); + + // then + ILoggingEvent logEvent = listAppender.list.stream() + .filter(event -> event.getFormattedMessage().contains(expectedApiPerfLog)) + .findFirst() + .orElseThrow(); + assertAll( + () -> assertThat(logEvent.getLevel().toString()).isEqualTo("INFO"), + () -> assertThat(logEvent.getFormattedMessage()).contains("uri=/api/test"), + () -> assertThat(logEvent.getFormattedMessage()).contains("method_type=GET"), + () -> assertThat(logEvent.getFormattedMessage()).contains("status=200") + ); + } + + @Test + void 응답_시간이_3초를_초과하면_WARN_로그를_남긴다() throws Exception { + // given + long startTime = System.currentTimeMillis() - 4000; // 4초 전 + when(request.getAttribute("startTime")).thenReturn(startTime); + when(request.getRequestURI()).thenReturn("/api/slow"); + when(request.getMethod()).thenReturn("GET"); + when(response.getStatus()).thenReturn(200); + String expectedApiPerfLog = "type=API_Performance"; + + // when + interceptor.afterCompletion(request, response, handler, null); + + // then + ILoggingEvent logEvent = listAppender.list.stream() + .filter(event -> event.getFormattedMessage().contains(expectedApiPerfLog)) + .findFirst() + .orElseThrow(); + assertAll( + () -> assertThat(logEvent.getLevel().toString()).isEqualTo("WARN"), + () -> assertThat(logEvent.getFormattedMessage()).contains("uri=/api/slow"), + () -> assertThat(logEvent.getFormattedMessage()).contains("method_type=GET"), + () -> assertThat(logEvent.getFormattedMessage()).contains("status=200") + ); + } + + @Test + void startTime이_없으면_로그를_남기지_않는다() throws Exception { + // given + when(request.getAttribute("startTime")).thenReturn(null); + String noExpectedApiPerfLog = "type=API_Performance"; + + // when + interceptor.afterCompletion(request, response, handler, null); + + // then + assertThat(listAppender.list).noneMatch(event -> event.getFormattedMessage().contains(noExpectedApiPerfLog)); + } + } + + @Nested + class 예외_발생_시 { + + @Test + void 예외가_발생해도_로그를_정상_기록한다() throws Exception { + // given + long startTime = System.currentTimeMillis(); + when(request.getAttribute("startTime")).thenReturn(startTime); + when(request.getRequestURI()).thenReturn("/api/error"); + when(request.getMethod()).thenReturn("GET"); + when(response.getStatus()).thenReturn(500); + + Exception ex = new RuntimeException("Test exception"); + + String expectedApiPerfLog = "type=API_Performance"; + + // when + interceptor.afterCompletion(request, response, handler, ex); + + // then + ILoggingEvent logEvent = listAppender.list.stream() + .filter(event -> event.getFormattedMessage().contains(expectedApiPerfLog)) + .findFirst() + .orElseThrow(); + assertAll( + () -> assertThat(logEvent.getLevel().toString()).isEqualTo("INFO"), + () -> assertThat(logEvent.getFormattedMessage()).contains("uri=/api/error"), + () -> assertThat(logEvent.getFormattedMessage()).contains("method_type=GET"), + () -> assertThat(logEvent.getFormattedMessage()).contains("status=500") + ); + } + } +} diff --git a/src/test/java/com/example/solidconnection/common/interceptor/RequestContextInterceptorTest.java b/src/test/java/com/example/solidconnection/common/interceptor/RequestContextInterceptorTest.java new file mode 100644 index 000000000..6d463e958 --- /dev/null +++ b/src/test/java/com/example/solidconnection/common/interceptor/RequestContextInterceptorTest.java @@ -0,0 +1,112 @@ +package com.example.solidconnection.common.interceptor; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.springframework.web.servlet.HandlerMapping.BEST_MATCHING_PATTERN_ATTRIBUTE; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +@DisplayName("RequestContextInterceptor 테스트") +class RequestContextInterceptorTest { + + private RequestContextInterceptor interceptor; + private HttpServletRequest request; + private HttpServletResponse response; + private Object handler; + + @BeforeEach + void setUp() { + interceptor = new RequestContextInterceptor(); + request = mock(HttpServletRequest.class); + response = mock(HttpServletResponse.class); + handler = new Object(); + } + + @AfterEach + void tearDown() { + RequestContextHolder.clear(); + } + + @Nested + class PreHandle_메서드 { + + @Test + void RequestContext를_초기화_한_후_true를_리턴한다() { + // given + when(request.getMethod()).thenReturn("GET"); + when(request.getAttribute(BEST_MATCHING_PATTERN_ATTRIBUTE)).thenReturn("/api/users/{id}"); + + // when + boolean result = interceptor.preHandle(request, response, handler); + + // then + assertThat(result).isTrue(); + + RequestContext context = RequestContextHolder.getContext(); + assertThat(context).isNotNull(); + assertThat(context.getHttpMethod()).isEqualTo("GET"); + assertThat(context.getBestMatchPath()).isEqualTo("/api/users/{id}"); + } + + @Test + void best_matching_pattern이_null이면_null을_저장한다() { + // given + when(request.getMethod()).thenReturn("GET"); + when(request.getAttribute(BEST_MATCHING_PATTERN_ATTRIBUTE)).thenReturn(null); + + // when + boolean result = interceptor.preHandle(request, response, handler); + + // then + assertThat(result).isTrue(); + + RequestContext context = RequestContextHolder.getContext(); + assertThat(context.getBestMatchPath()).isNull(); + } + } + + @Nested + class AfterCompletion_메서드 { + + @Test + void RequestContext를_정리한다() { + // given + when(request.getMethod()).thenReturn("GET"); + when(request.getAttribute(BEST_MATCHING_PATTERN_ATTRIBUTE)).thenReturn("/api/users"); + + interceptor.preHandle(request, response, handler); + assertThat(RequestContextHolder.getContext()).isNotNull(); + + // when + interceptor.afterCompletion(request, response, handler, null); + + // then + assertThat(RequestContextHolder.getContext()).isNull(); + } + + @Test + void 예외가_발생해도_RequestContext를_정리한다() { + // given + when(request.getMethod()).thenReturn("POST"); + when(request.getAttribute(BEST_MATCHING_PATTERN_ATTRIBUTE)).thenReturn("/api/users"); + + interceptor.preHandle(request, response, handler); + assertThat(RequestContextHolder.getContext()).isNotNull(); + + Exception ex = new RuntimeException("Test exception"); + + // when + interceptor.afterCompletion(request, response, handler, ex); + + // then + assertThat(RequestContextHolder.getContext()).isNull(); + } + } +} diff --git a/src/test/java/com/example/solidconnection/common/listener/QueryMetricsListenerTest.java b/src/test/java/com/example/solidconnection/common/listener/QueryMetricsListenerTest.java new file mode 100644 index 000000000..e0ca19a4c --- /dev/null +++ b/src/test/java/com/example/solidconnection/common/listener/QueryMetricsListenerTest.java @@ -0,0 +1,289 @@ +package com.example.solidconnection.common.listener; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.example.solidconnection.common.interceptor.RequestContext; +import com.example.solidconnection.common.interceptor.RequestContextHolder; +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.Timer; +import java.util.List; +import java.util.concurrent.TimeUnit; +import net.ttddyy.dsproxy.ExecutionInfo; +import net.ttddyy.dsproxy.QueryInfo; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +@DisplayName("QueryMetricsListener 테스트") +class QueryMetricsListenerTest { + + private QueryMetricsListener listener; + private MeterRegistry meterRegistry; + private ExecutionInfo executionInfo; + + @BeforeEach + void setUp() { + meterRegistry = mock(MeterRegistry.class); + listener = new QueryMetricsListener(meterRegistry); + executionInfo = mock(ExecutionInfo.class); + } + + @AfterEach + void tearDown() { + RequestContextHolder.clear(); + } + + @Nested + class 쿼리_메트릭_수집 { + + @Test + void SELECT_쿼리의_실행_시간을_기록한다() { + // given + String sql = "SELECT * FROM users WHERE id = ?"; + QueryInfo queryInfo = new QueryInfo(); + queryInfo.setQuery(sql); + + when(executionInfo.getElapsedTime()).thenReturn(100L); + + Timer timer = mock(Timer.class); + when(meterRegistry.timer( + eq("db.query"), + eq("sql_type"), any(String.class), + eq("http_method"), any(String.class), + eq("http_path"), any(String.class) + )).thenReturn(timer); + + // when + listener.afterQuery(executionInfo, List.of(queryInfo)); + + // then + verify(meterRegistry).timer( + eq("db.query"), + eq("sql_type"), eq("SELECT"), + eq("http_method"), any(String.class), + eq("http_path"), any(String.class) + ); + verify(timer).record(100L, TimeUnit.MILLISECONDS); + } + + @Test + void INSERT_쿼리의_실행_시간을_기록한다() { + // given + String sql = "INSERT INTO users (name) VALUES (?)"; + QueryInfo queryInfo = new QueryInfo(); + queryInfo.setQuery(sql); + + when(executionInfo.getElapsedTime()).thenReturn(100L); + + Timer timer = mock(Timer.class); + when(meterRegistry.timer( + eq("db.query"), + eq("sql_type"), any(String.class), + eq("http_method"), any(String.class), + eq("http_path"), any(String.class) + )).thenReturn(timer); + + // when + listener.afterQuery(executionInfo, List.of(queryInfo)); + + // then + verify(meterRegistry).timer( + eq("db.query"), + eq("sql_type"), eq("INSERT"), + eq("http_method"), any(String.class), + eq("http_path"), any(String.class) + ); + verify(timer).record(100L, TimeUnit.MILLISECONDS); + } + + @Test + void UPDATE_쿼리의_실행_시간을_기록한다() { + // given + String sql = "UPDATE users SET name = ? WHERE id = ?"; + QueryInfo queryInfo = new QueryInfo(); + queryInfo.setQuery(sql); + + when(executionInfo.getElapsedTime()).thenReturn(100L); + + Timer timer = mock(Timer.class); + when(meterRegistry.timer( + eq("db.query"), + eq("sql_type"), eq("UPDATE"), + eq("http_method"), any(String.class), + eq("http_path"), any(String.class) + )).thenReturn(timer); + + // when + listener.afterQuery(executionInfo, List.of(queryInfo)); + + // then + verify(meterRegistry).timer( + eq("db.query"), + eq("sql_type"), eq("UPDATE"), + eq("http_method"), any(String.class), + eq("http_path"), any(String.class) + ); + verify(timer).record(100L, TimeUnit.MILLISECONDS); + } + + @Test + void DELETE_쿼리의_실행_시간을_기록한다() { + // given + String sql = "DELETE FROM users WHERE id = ?"; + QueryInfo queryInfo = new QueryInfo(); + queryInfo.setQuery(sql); + + when(executionInfo.getElapsedTime()).thenReturn(100L); + + Timer timer = mock(Timer.class); + when(meterRegistry.timer( + eq("db.query"), + eq("sql_type"), eq("DELETE"), + eq("http_method"), any(String.class), + eq("http_path"), any(String.class) + )).thenReturn(timer); + + // when + listener.afterQuery(executionInfo, List.of(queryInfo)); + + // then + verify(meterRegistry).timer( + eq("db.query"), + eq("sql_type"), eq("DELETE"), + eq("http_method"), any(String.class), + eq("http_path"), any(String.class) + ); + verify(timer).record(100L, TimeUnit.MILLISECONDS); + } + + @Test + void 알수없는_쿼리는_UNKNOWN으로_기록한다() { + // given + String sql = "SHOW TABLES"; + QueryInfo queryInfo = new QueryInfo(); + queryInfo.setQuery(sql); + + when(executionInfo.getElapsedTime()).thenReturn(100L); + + Timer timer = mock(Timer.class); + when(meterRegistry.timer( + eq("db.query"), + eq("sql_type"), any(String.class), + eq("http_method"), any(String.class), + eq("http_path"), any(String.class) + )).thenReturn(timer); + + // when + listener.afterQuery(executionInfo, List.of(queryInfo)); + + // then + verify(meterRegistry).timer( + eq("db.query"), + eq("sql_type"), eq("UNKNOWN"), + eq("http_method"), any(String.class), + eq("http_path"), any(String.class) + ); + verify(timer).record(100L, TimeUnit.MILLISECONDS); + } + + @Test + void null_쿼리는_OTHER로_기록한다() { + // given + QueryInfo queryInfo = new QueryInfo(); + when(executionInfo.getElapsedTime()).thenReturn(100L); + + Timer timer = mock(Timer.class); + when(meterRegistry.timer( + eq("db.query"), + eq("sql_type"), any(String.class), + eq("http_method"), any(String.class), + eq("http_path"), any(String.class) + )).thenReturn(timer); + + // when + listener.afterQuery(executionInfo, List.of(queryInfo)); + + // then + verify(meterRegistry).timer( + eq("db.query"), + eq("sql_type"), eq("OTHER"), + eq("http_method"), any(String.class), + eq("http_path"), any(String.class) + ); + verify(timer).record(100L, TimeUnit.MILLISECONDS); + } + } + + @Nested + class RequestContext_연동 { + + @Test + void RequestContext가_있으면_HTTP_정보를_포함한다() { + // given + RequestContext context = new RequestContext("GET", "/api/users"); + RequestContextHolder.initContext(context); + + String sql = "SELECT * FROM users"; + QueryInfo queryInfo = new QueryInfo(); + queryInfo.setQuery(sql); + + when(executionInfo.getElapsedTime()).thenReturn(100L); + + Timer timer = mock(Timer.class); + when(meterRegistry.timer( + eq("db.query"), + eq("sql_type"), any(String.class), + eq("http_method"), any(String.class), + eq("http_path"), any(String.class) + )).thenReturn(timer); + + // when + listener.afterQuery(executionInfo, List.of(queryInfo)); + + // then + verify(meterRegistry).timer( + eq("db.query"), + eq("sql_type"), eq("SELECT"), + eq("http_method"), eq("GET"), + eq("http_path"), eq("/api/users") + ); + verify(timer).record(100L, TimeUnit.MILLISECONDS); + } + + @Test + void RequestContext가_없으면_기본값을_사용한다() { + // given + String sql = "SELECT * FROM users"; + QueryInfo queryInfo = new QueryInfo(); + queryInfo.setQuery(sql); + + when(executionInfo.getElapsedTime()).thenReturn(100L); + + Timer timer = mock(Timer.class); + when(meterRegistry.timer( + eq("db.query"), + eq("sql_type"), any(String.class), + eq("http_method"), any(String.class), + eq("http_path"), any(String.class) + )).thenReturn(timer); + + // when + listener.afterQuery(executionInfo, List.of(queryInfo)); + + // then + verify(meterRegistry).timer( + eq("db.query"), + eq("sql_type"), eq("SELECT"), + eq("http_method"), eq("-"), + eq("http_path"), eq("-") + ); + verify(timer).record(100L, TimeUnit.MILLISECONDS); + } + } +} diff --git a/src/test/java/com/example/solidconnection/security/filter/TokenAuthenticationFilterTest.java b/src/test/java/com/example/solidconnection/security/filter/TokenAuthenticationFilterTest.java index 36d8c3dd8..d0b7d8963 100644 --- a/src/test/java/com/example/solidconnection/security/filter/TokenAuthenticationFilterTest.java +++ b/src/test/java/com/example/solidconnection/security/filter/TokenAuthenticationFilterTest.java @@ -1,12 +1,17 @@ package com.example.solidconnection.security.filter; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.then; import static org.mockito.Mockito.spy; import com.example.solidconnection.auth.token.config.JwtProperties; import com.example.solidconnection.security.authentication.TokenAuthentication; +import com.example.solidconnection.security.userdetails.SiteUserDetails; import com.example.solidconnection.security.userdetails.SiteUserDetailsService; +import com.example.solidconnection.siteuser.domain.SiteUser; +import com.example.solidconnection.siteuser.fixture.SiteUserFixture; import com.example.solidconnection.support.TestContainerSpringBootTest; import io.jsonwebtoken.Jwts; import io.jsonwebtoken.SignatureAlgorithm; @@ -33,6 +38,9 @@ class TokenAuthenticationFilterTest { @Autowired private JwtProperties jwtProperties; + @Autowired + private SiteUserFixture siteUserFixture; + @MockBean // 이 테스트코드에서 사용자를 조회할 필요는 없으므로 MockBean 으로 대체 private SiteUserDetailsService siteUserDetailsService; @@ -45,6 +53,11 @@ void setUp() { response = new MockHttpServletResponse(); filterChain = spy(FilterChain.class); SecurityContextHolder.clearContext(); + + SiteUser siteUser = siteUserFixture.사용자(1, "test"); + SiteUserDetails userDetails = new SiteUserDetails(siteUser); + given(siteUserDetailsService.loadUserByUsername(anyString())) + .willReturn(userDetails); } @Test @@ -61,8 +74,9 @@ void setUp() { } @Test - void 토큰이_있으면_컨텍스트에_저장한다() throws Exception { + void 토큰이_있으면_컨텍스트에_저장하고_userId를_request에_설정한다() throws Exception { // given + Long expectedUserId = 1L; Date validExpiration = new Date(System.currentTimeMillis() + 1000); String token = createTokenWithExpiration(validExpiration); request = createRequestWithToken(token); @@ -73,6 +87,7 @@ void setUp() { // then assertThat(SecurityContextHolder.getContext().getAuthentication()) .isExactlyInstanceOf(TokenAuthentication.class); + assertThat(request.getAttribute("userId")).isEqualTo(expectedUserId); then(filterChain).should().doFilter(request, response); }