Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ public ResponseEntity<ApiResponder<Empty>> enrollSchool(

MagicLink magicLink = new MagicLink(email, userId);
try {
String token = jwtClient.encode(magicLink, Duration.ofHours(1));
String token = jwtClient.encode(magicLink, Duration.ofHours(1), serverUrlUtils.getUrl());
String verificationLink = serverUrlUtils.getUrl() + "/api/auth/school/verify?state=" + token;
emailClient.sendMessage(SendEmailOptions.builder()
.recipientEmail(email)
Expand Down Expand Up @@ -275,7 +275,7 @@ public RedirectView verifySchoolEmail(final HttpServletRequest request) {
String token = request.getParameter("state");
MagicLink magicLink;
try {
magicLink = jwtClient.decode(token, MagicLink.class);
magicLink = jwtClient.decode(token, MagicLink.class, serverUrlUtils.getUrl());
} catch (Exception e) {
return new RedirectView("/settings?success=false&message=Invalid or expired token");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,11 @@ public interface JWTClient {

/** Parse the JWT token back into a valid Object. Will throw if expired or unable to verify JWT. */
<T> T decode(String token, Class<T> clazz) throws JsonProcessingException, JWTVerificationException;

/** Create the JWT token with an audience claim, binding it to a specific server. */
<T> String encode(T obj, Duration expiresIn, String audience) throws JsonProcessingException;

/** Parse the JWT token, enforcing audience match. Will throw if expired, invalid, or audience mismatch. */
<T> T decode(String token, Class<T> clazz, String expectedAudience)
throws JsonProcessingException, JWTVerificationException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,27 @@ public <T> T decode(final String token, final Class<T> clazz)

return objectMapper.readValue(payloadString, clazz);
}

/** Create the JWT token with an audience claim, binding it to a specific server. */
public <T> String encode(final T obj, final Duration expiresIn, final String audience)
throws JsonProcessingException {
String payload = objectMapper.writeValueAsString(obj);
return JWT.create()
.withClaim("payload", payload)
.withAudience(audience)
.withExpiresAt(Instant.now().plus(expiresIn))
.sign(algorithm);
}

/** Parse the JWT token, enforcing audience match. Will throw if expired, invalid, or audience mismatch. */
public <T> T decode(final String token, final Class<T> clazz, final String expectedAudience)
throws JsonProcessingException, JWTVerificationException {
DecodedJWT decodedJWT =
JWT.require(algorithm).withAudience(expectedAudience).build().verify(token);
String payloadString = decodedJWT.getClaim("payload").asString();
if (payloadString == null) {
return null;
}
return objectMapper.readValue(payloadString, clazz);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,8 @@ void enrollSchoolEmailSendFailure() throws Exception {
EmailBody emailBody = new EmailBody("test@myhunter.cuny.edu");

when(protector.validateSession(request)).thenReturn(authObj);
when(jwtClient.encode(any(MagicLink.class), any(Duration.class))).thenReturn("mock-token");
when(jwtClient.encode(any(MagicLink.class), any(Duration.class), any(String.class)))
.thenReturn("mock-token");
when(serverUrlUtils.getUrl()).thenReturn("http://localhost:8080");
when(reactEmailTemplater.schoolEmailTemplate(any())).thenReturn("<html>Template</html>");
doThrow(new EmailException("Failed to send email")).when(emailClient).sendMessage(any(SendEmailOptions.class));
Expand All @@ -309,7 +310,8 @@ void enrollSchoolHappyPath() throws Exception {

when(protector.validateSession(request)).thenReturn(authObj);
when(simpleRedis.containsKey(user.getId())).thenReturn(false);
when(jwtClient.encode(any(MagicLink.class), any(Duration.class))).thenReturn("mock-token");
when(jwtClient.encode(any(MagicLink.class), any(Duration.class), any(String.class)))
.thenReturn("mock-token");
when(serverUrlUtils.getUrl()).thenReturn("http://localhost:8080");
when(reactEmailTemplater.schoolEmailTemplate(any())).thenReturn("<html>Template</html>");

Expand All @@ -332,6 +334,7 @@ void enrollSchoolHappyPath() throws Exception {
void verifySchoolEmailNotAuthenticated() {
HttpServletRequest request = mock(HttpServletRequest.class);

when(serverUrlUtils.getUrl()).thenReturn("http://localhost:8080");
when(protector.validateSession(request)).thenThrow(new RuntimeException("Not authenticated"));

RedirectView redirectView = authController.verifySchoolEmail(request);
Expand All @@ -351,17 +354,19 @@ void verifySchoolEmailInvalidToken() throws Exception {

HttpServletRequest request = mock(HttpServletRequest.class);

when(serverUrlUtils.getUrl()).thenReturn("http://localhost:8080");
when(protector.validateSession(request)).thenReturn(authObj);
when(request.getParameter("state")).thenReturn("invalid-token");
when(jwtClient.decode("invalid-token", MagicLink.class)).thenThrow(new RuntimeException("Invalid token"));
when(jwtClient.decode("invalid-token", MagicLink.class, "http://localhost:8080"))
.thenThrow(new RuntimeException("Invalid token"));

RedirectView redirectView = authController.verifySchoolEmail(request);

assertNotNull(redirectView);
assertEquals("/settings?success=false&message=Invalid or expired token", redirectView.getUrl());

verify(protector, times(1)).validateSession(request);
verify(jwtClient, times(1)).decode("invalid-token", MagicLink.class);
verify(jwtClient, times(1)).decode("invalid-token", MagicLink.class, "http://localhost:8080");
}

@Test
Expand All @@ -374,17 +379,19 @@ void verifySchoolEmailUserIdMismatch() throws Exception {
HttpServletRequest request = mock(HttpServletRequest.class);
MagicLink magicLink = new MagicLink("test@myhunter.cuny.edu", "different-user-id");

when(serverUrlUtils.getUrl()).thenReturn("http://localhost:8080");
when(protector.validateSession(request)).thenReturn(authObj);
when(request.getParameter("state")).thenReturn("valid-token");
when(jwtClient.decode("valid-token", MagicLink.class)).thenReturn(magicLink);
when(jwtClient.decode("valid-token", MagicLink.class, "http://localhost:8080"))
.thenReturn(magicLink);

RedirectView redirectView = authController.verifySchoolEmail(request);

assertNotNull(redirectView);
assertEquals("/settings?success=false&message=ID does not match current user", redirectView.getUrl());

verify(protector, times(1)).validateSession(request);
verify(jwtClient, times(1)).decode("valid-token", MagicLink.class);
verify(jwtClient, times(1)).decode("valid-token", MagicLink.class, "http://localhost:8080");
}

@Test
Expand All @@ -397,9 +404,11 @@ void verifySchoolEmailHappyPath() throws Exception {
HttpServletRequest request = mock(HttpServletRequest.class);
MagicLink magicLink = new MagicLink("test@myhunter.cuny.edu", user.getId());

when(serverUrlUtils.getUrl()).thenReturn("http://localhost:8080");
when(protector.validateSession(request)).thenReturn(authObj);
when(request.getParameter("state")).thenReturn("valid-token");
when(jwtClient.decode("valid-token", MagicLink.class)).thenReturn(magicLink);
when(jwtClient.decode("valid-token", MagicLink.class, "http://localhost:8080"))
.thenReturn(magicLink);
when(userRepository.updateUser(any(User.class))).thenReturn(true);

RedirectView redirectView = authController.verifySchoolEmail(request);
Expand All @@ -408,7 +417,7 @@ void verifySchoolEmailHappyPath() throws Exception {
assertEquals("/settings?success=true&message=The email has been verified!", redirectView.getUrl());

verify(protector, times(1)).validateSession(request);
verify(jwtClient, times(1)).decode("valid-token", MagicLink.class);
verify(jwtClient, times(1)).decode("valid-token", MagicLink.class, "http://localhost:8080");
verify(userRepository, times(1)).updateUser(any(User.class));
verify(userTagRepository, times(1)).createTag(any());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,55 @@ void testValidCaseExpire() {
jwtClient.decode(jwt, JWTTestObject.class);
fail("Expected TokenExpiredException was not thrown");
} catch (JWTVerificationException e) {
// Expected exception.
return;
} catch (Exception e) {
fail("Unexpected exception thrown: " + e.getClass().getName());
}
}

@Test
void testValidCaseWithAudience() {
JWTTestObject userTag = createTestableObject();
String audience = "https://codebloom.patinanetwork.org";
String jwt = null;
try {
jwt = jwtClient.encode(userTag, Duration.ofMinutes(15), audience);
} catch (JsonProcessingException e) {
e.printStackTrace();
fail("Failed to create JWT with audience");
}
assertNotNull(jwt, "JWT is null when it should not be.");

JWTTestObject reParsedJsonTag = null;
try {
reParsedJsonTag = jwtClient.decode(jwt, JWTTestObject.class, audience);
} catch (JsonProcessingException e) {
e.printStackTrace();
fail("Failed to parse JWT with audience");
}

assertEquals(userTag, reParsedJsonTag);
}

@Test
void testAudienceMismatchThrows() {
JWTTestObject userTag = createTestableObject();
String jwt = null;
try {
jwt = jwtClient.encode(userTag, Duration.ofMinutes(15), "https://stg.codebloom.patinanetwork.org");
} catch (JsonProcessingException e) {
e.printStackTrace();
fail("Failed to create JWT");
}

final String finalJwt = jwt;
try {
jwtClient.decode(finalJwt, JWTTestObject.class, "http://localhost:8080");
fail("Expected JWTVerificationException was not thrown");
} catch (JWTVerificationException e) {
// Expected — audience mismatch
} catch (Exception e) {
fail("Unexpected exception thrown: " + e.getClass().getName());
}
}
}