diff --git a/src/main/java/org/patinanetwork/codebloom/api/auth/AuthController.java b/src/main/java/org/patinanetwork/codebloom/api/auth/AuthController.java index f3073b942..aff6da699 100644 --- a/src/main/java/org/patinanetwork/codebloom/api/auth/AuthController.java +++ b/src/main/java/org/patinanetwork/codebloom/api/auth/AuthController.java @@ -233,7 +233,7 @@ public ResponseEntity> enrollSchool( MagicLink magicLink = new MagicLink(email, userId); try { - String token = jwtClient.encode(magicLink, Duration.ofHours(1)); + String token = jwtClient.encode(magicLink, Duration.ofHours(1), serverUrlUtils.getUrl()); String verificationLink = serverUrlUtils.getUrl() + "/api/auth/school/verify?state=" + token; emailClient.sendMessage(SendEmailOptions.builder() .recipientEmail(email) @@ -275,7 +275,7 @@ public RedirectView verifySchoolEmail(final HttpServletRequest request) { String token = request.getParameter("state"); MagicLink magicLink; try { - magicLink = jwtClient.decode(token, MagicLink.class); + magicLink = jwtClient.decode(token, MagicLink.class, serverUrlUtils.getUrl()); } catch (Exception e) { return new RedirectView("/settings?success=false&message=Invalid or expired token"); } diff --git a/src/main/java/org/patinanetwork/codebloom/common/jwt/JWTClient.java b/src/main/java/org/patinanetwork/codebloom/common/jwt/JWTClient.java index 5655284d3..3ed4c6ebd 100644 --- a/src/main/java/org/patinanetwork/codebloom/common/jwt/JWTClient.java +++ b/src/main/java/org/patinanetwork/codebloom/common/jwt/JWTClient.java @@ -14,4 +14,11 @@ public interface JWTClient { /** Parse the JWT token back into a valid Object. Will throw if expired or unable to verify JWT. */ T decode(String token, Class clazz) throws JsonProcessingException, JWTVerificationException; + + /** Create the JWT token with an audience claim, binding it to a specific server. */ + String encode(T obj, Duration expiresIn, String audience) throws JsonProcessingException; + + /** Parse the JWT token, enforcing audience match. Will throw if expired, invalid, or audience mismatch. */ + T decode(String token, Class clazz, String expectedAudience) + throws JsonProcessingException, JWTVerificationException; } diff --git a/src/main/java/org/patinanetwork/codebloom/common/jwt/impl/JWTClientImpl.java b/src/main/java/org/patinanetwork/codebloom/common/jwt/impl/JWTClientImpl.java index d53807e00..72aa79ef6 100644 --- a/src/main/java/org/patinanetwork/codebloom/common/jwt/impl/JWTClientImpl.java +++ b/src/main/java/org/patinanetwork/codebloom/common/jwt/impl/JWTClientImpl.java @@ -67,4 +67,27 @@ public T decode(final String token, final Class clazz) return objectMapper.readValue(payloadString, clazz); } + + /** Create the JWT token with an audience claim, binding it to a specific server. */ + public String encode(final T obj, final Duration expiresIn, final String audience) + throws JsonProcessingException { + String payload = objectMapper.writeValueAsString(obj); + return JWT.create() + .withClaim("payload", payload) + .withAudience(audience) + .withExpiresAt(Instant.now().plus(expiresIn)) + .sign(algorithm); + } + + /** Parse the JWT token, enforcing audience match. Will throw if expired, invalid, or audience mismatch. */ + public T decode(final String token, final Class clazz, final String expectedAudience) + throws JsonProcessingException, JWTVerificationException { + DecodedJWT decodedJWT = + JWT.require(algorithm).withAudience(expectedAudience).build().verify(token); + String payloadString = decodedJWT.getClaim("payload").asString(); + if (payloadString == null) { + return null; + } + return objectMapper.readValue(payloadString, clazz); + } } diff --git a/src/test/java/org/patinanetwork/codebloom/api/auth/AuthControllerTest.java b/src/test/java/org/patinanetwork/codebloom/api/auth/AuthControllerTest.java index 34efcb7e2..32a03cd22 100644 --- a/src/test/java/org/patinanetwork/codebloom/api/auth/AuthControllerTest.java +++ b/src/test/java/org/patinanetwork/codebloom/api/auth/AuthControllerTest.java @@ -282,7 +282,8 @@ void enrollSchoolEmailSendFailure() throws Exception { EmailBody emailBody = new EmailBody("test@myhunter.cuny.edu"); when(protector.validateSession(request)).thenReturn(authObj); - when(jwtClient.encode(any(MagicLink.class), any(Duration.class))).thenReturn("mock-token"); + when(jwtClient.encode(any(MagicLink.class), any(Duration.class), any(String.class))) + .thenReturn("mock-token"); when(serverUrlUtils.getUrl()).thenReturn("http://localhost:8080"); when(reactEmailTemplater.schoolEmailTemplate(any())).thenReturn("Template"); doThrow(new EmailException("Failed to send email")).when(emailClient).sendMessage(any(SendEmailOptions.class)); @@ -309,7 +310,8 @@ void enrollSchoolHappyPath() throws Exception { when(protector.validateSession(request)).thenReturn(authObj); when(simpleRedis.containsKey(user.getId())).thenReturn(false); - when(jwtClient.encode(any(MagicLink.class), any(Duration.class))).thenReturn("mock-token"); + when(jwtClient.encode(any(MagicLink.class), any(Duration.class), any(String.class))) + .thenReturn("mock-token"); when(serverUrlUtils.getUrl()).thenReturn("http://localhost:8080"); when(reactEmailTemplater.schoolEmailTemplate(any())).thenReturn("Template"); @@ -332,6 +334,7 @@ void enrollSchoolHappyPath() throws Exception { void verifySchoolEmailNotAuthenticated() { HttpServletRequest request = mock(HttpServletRequest.class); + when(serverUrlUtils.getUrl()).thenReturn("http://localhost:8080"); when(protector.validateSession(request)).thenThrow(new RuntimeException("Not authenticated")); RedirectView redirectView = authController.verifySchoolEmail(request); @@ -351,9 +354,11 @@ void verifySchoolEmailInvalidToken() throws Exception { HttpServletRequest request = mock(HttpServletRequest.class); + when(serverUrlUtils.getUrl()).thenReturn("http://localhost:8080"); when(protector.validateSession(request)).thenReturn(authObj); when(request.getParameter("state")).thenReturn("invalid-token"); - when(jwtClient.decode("invalid-token", MagicLink.class)).thenThrow(new RuntimeException("Invalid token")); + when(jwtClient.decode("invalid-token", MagicLink.class, "http://localhost:8080")) + .thenThrow(new RuntimeException("Invalid token")); RedirectView redirectView = authController.verifySchoolEmail(request); @@ -361,7 +366,7 @@ void verifySchoolEmailInvalidToken() throws Exception { assertEquals("/settings?success=false&message=Invalid or expired token", redirectView.getUrl()); verify(protector, times(1)).validateSession(request); - verify(jwtClient, times(1)).decode("invalid-token", MagicLink.class); + verify(jwtClient, times(1)).decode("invalid-token", MagicLink.class, "http://localhost:8080"); } @Test @@ -374,9 +379,11 @@ void verifySchoolEmailUserIdMismatch() throws Exception { HttpServletRequest request = mock(HttpServletRequest.class); MagicLink magicLink = new MagicLink("test@myhunter.cuny.edu", "different-user-id"); + when(serverUrlUtils.getUrl()).thenReturn("http://localhost:8080"); when(protector.validateSession(request)).thenReturn(authObj); when(request.getParameter("state")).thenReturn("valid-token"); - when(jwtClient.decode("valid-token", MagicLink.class)).thenReturn(magicLink); + when(jwtClient.decode("valid-token", MagicLink.class, "http://localhost:8080")) + .thenReturn(magicLink); RedirectView redirectView = authController.verifySchoolEmail(request); @@ -384,7 +391,7 @@ void verifySchoolEmailUserIdMismatch() throws Exception { assertEquals("/settings?success=false&message=ID does not match current user", redirectView.getUrl()); verify(protector, times(1)).validateSession(request); - verify(jwtClient, times(1)).decode("valid-token", MagicLink.class); + verify(jwtClient, times(1)).decode("valid-token", MagicLink.class, "http://localhost:8080"); } @Test @@ -397,9 +404,11 @@ void verifySchoolEmailHappyPath() throws Exception { HttpServletRequest request = mock(HttpServletRequest.class); MagicLink magicLink = new MagicLink("test@myhunter.cuny.edu", user.getId()); + when(serverUrlUtils.getUrl()).thenReturn("http://localhost:8080"); when(protector.validateSession(request)).thenReturn(authObj); when(request.getParameter("state")).thenReturn("valid-token"); - when(jwtClient.decode("valid-token", MagicLink.class)).thenReturn(magicLink); + when(jwtClient.decode("valid-token", MagicLink.class, "http://localhost:8080")) + .thenReturn(magicLink); when(userRepository.updateUser(any(User.class))).thenReturn(true); RedirectView redirectView = authController.verifySchoolEmail(request); @@ -408,7 +417,7 @@ void verifySchoolEmailHappyPath() throws Exception { assertEquals("/settings?success=true&message=The email has been verified!", redirectView.getUrl()); verify(protector, times(1)).validateSession(request); - verify(jwtClient, times(1)).decode("valid-token", MagicLink.class); + verify(jwtClient, times(1)).decode("valid-token", MagicLink.class, "http://localhost:8080"); verify(userRepository, times(1)).updateUser(any(User.class)); verify(userTagRepository, times(1)).createTag(any()); } diff --git a/src/test/java/org/patinanetwork/codebloom/common/jwt/JWTTest.java b/src/test/java/org/patinanetwork/codebloom/common/jwt/JWTTest.java index 87118c28d..e9b5643e4 100644 --- a/src/test/java/org/patinanetwork/codebloom/common/jwt/JWTTest.java +++ b/src/test/java/org/patinanetwork/codebloom/common/jwt/JWTTest.java @@ -77,10 +77,55 @@ void testValidCaseExpire() { jwtClient.decode(jwt, JWTTestObject.class); fail("Expected TokenExpiredException was not thrown"); } catch (JWTVerificationException e) { - // Expected exception. return; } catch (Exception e) { fail("Unexpected exception thrown: " + e.getClass().getName()); } } + + @Test + void testValidCaseWithAudience() { + JWTTestObject userTag = createTestableObject(); + String audience = "https://codebloom.patinanetwork.org"; + String jwt = null; + try { + jwt = jwtClient.encode(userTag, Duration.ofMinutes(15), audience); + } catch (JsonProcessingException e) { + e.printStackTrace(); + fail("Failed to create JWT with audience"); + } + assertNotNull(jwt, "JWT is null when it should not be."); + + JWTTestObject reParsedJsonTag = null; + try { + reParsedJsonTag = jwtClient.decode(jwt, JWTTestObject.class, audience); + } catch (JsonProcessingException e) { + e.printStackTrace(); + fail("Failed to parse JWT with audience"); + } + + assertEquals(userTag, reParsedJsonTag); + } + + @Test + void testAudienceMismatchThrows() { + JWTTestObject userTag = createTestableObject(); + String jwt = null; + try { + jwt = jwtClient.encode(userTag, Duration.ofMinutes(15), "https://stg.codebloom.patinanetwork.org"); + } catch (JsonProcessingException e) { + e.printStackTrace(); + fail("Failed to create JWT"); + } + + final String finalJwt = jwt; + try { + jwtClient.decode(finalJwt, JWTTestObject.class, "http://localhost:8080"); + fail("Expected JWTVerificationException was not thrown"); + } catch (JWTVerificationException e) { + // Expected — audience mismatch + } catch (Exception e) { + fail("Unexpected exception thrown: " + e.getClass().getName()); + } + } }