From 07e2f958c7c7dcf6fc1497a873906f954f79cf10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Br=C3=BCckel?= Date: Mon, 16 Feb 2026 16:54:04 +0100 Subject: [PATCH 01/11] Add JWKS refresh interval support and periodic refresh handling --- .../membrane/core/interceptor/jwt/Jwks.java | 103 +++++++++++++----- .../AuthorizationService.java | 12 ++ 2 files changed, 90 insertions(+), 25 deletions(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java index 545d27fcfb..2d925128b1 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java @@ -28,17 +28,24 @@ import com.predic8.membrane.core.transport.http.client.HttpClientConfiguration; import com.predic8.membrane.core.util.ConfigurationException; import com.predic8.membrane.core.util.text.TextUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.IOException; import java.io.InputStream; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; +import java.util.*; +import static java.util.Collections.emptyList; + +/** + * @description + * JSON Web Key Set, configured either by an explicit list of JWK or by a list of JWK URIs that will be refreshed periodically. + */ @MCElement(name="jwks") public class Jwks { - List jwks = new ArrayList<>(); + private static final Logger log = LoggerFactory.getLogger(Jwks.class); + volatile List jwks = new ArrayList<>(); String jwksUris; AuthorizationService authorizationService; @@ -65,30 +72,76 @@ public Jwks setJwksUris(String jwksUris) { public void init(Router router) { if(jwksUris == null || jwksUris.isEmpty()) return; + if (!jwks.isEmpty()) + throw new ConfigurationException("JWKs cannot be set both via JwksUris and Jwks elements."); + this.jwks = loadJwks(router, false); + if (authorizationService.getJwksRefreshInterval() > 0) { + router.getTimerManager().schedulePeriodicTask(new TimerTask() { + @Override + public void run() { + setJwks(loadJwks(router, true)); + } + }, authorizationService.getJwksRefreshInterval() * 1_000L, "JWKS Refresh" + ); + } + } + private List loadJwks(Router router, boolean suppressExceptions) { ObjectMapper mapper = new ObjectMapper(); - for (String uri : jwksUris.split(" ")) { - try { - for (Object jwkRaw : parseJwksUriIntoList(router.getResolverMap(), router.getConfiguration().getBaseLocation(), mapper, uri)) { - Jwk jwk = new Jwk(); - jwk.setContent(mapper.writeValueAsString(jwkRaw)); - this.jwks.add(jwk); - } - } catch (JsonProcessingException e) { - throw new ConfigurationException("Could not parse JWK keys retrieved from %s.".formatted(uri), e); - } catch (ResourceRetrievalException e) { - throw new ConfigurationException("Could not retrieve JWK keys from %s.".formatted(uri), e); - } catch (Exception e) { - throw new RuntimeException(e); + return Arrays.stream(jwksUris.split(" ")) + .map(uri -> parseJwksUriIntoList(router.getResolverMap(), router.getConfiguration().getBaseLocation(), mapper, uri, suppressExceptions)) + .flatMap(l -> l.jwks().stream().map(jwkRaw -> convertToJwk(jwkRaw, mapper, l.uri(), suppressExceptions))) + .filter(Objects::nonNull) + .toList(); + } + + private static Jwk convertToJwk(Object jwkRaw, ObjectMapper mapper, String uri, boolean suppressExceptions) { + try { + Jwk jwk = new Jwk(); + jwk.setContent(mapper.writeValueAsString(jwkRaw)); + return jwk; + } catch (JsonProcessingException e) { + String message = "Could not parse JWK keys retrieved from %s.".formatted(uri); + if (suppressExceptions) { + log.error(message); + return null; + } else { + throw new ConfigurationException(message, e); } } } - private List parseJwksUriIntoList(ResolverMap resolverMap, String baseLocation, ObjectMapper mapper, String uri) throws Exception { - InputStream resolve = authorizationService != null ? - authorizationService.resolve(resolverMap, baseLocation, uri) : - resolverMap.resolve(ResolverMap.combine(baseLocation, uri)); - return (List) mapper.readValue(resolve, Map.class).get("keys"); + private record JwkListByUri(String uri, List jwks) {} + + private JwkListByUri parseJwksUriIntoList(ResolverMap resolverMap, String baseLocation, ObjectMapper mapper, String uri, boolean suppressExceptions) { + try { + InputStream resolve = authorizationService != null ? + authorizationService.resolve(resolverMap, baseLocation, uri) : + resolverMap.resolve(ResolverMap.combine(baseLocation, uri)); + return new JwkListByUri(uri, ((List) mapper.readValue(resolve, Map.class).get("keys"))); + } catch (JsonProcessingException e) { + String message = "Could not parse JWK keys retrieved from %s.".formatted(uri); + if (suppressExceptions) { + log.error(message); + } else { + throw new ConfigurationException(message, e); + } + } catch (ResourceRetrievalException e) { + String message = "Could not retrieve JWK keys from %s.".formatted(uri); + if (suppressExceptions) { + log.error(message); + } else { + throw new ConfigurationException(message, e); + } + } catch (Exception e) { + if (suppressExceptions) { + log.error(e.toString()); + log.error(e.getMessage()); + } else { + throw new RuntimeException(e); + } + } + return new JwkListByUri(uri, emptyList()); } public AuthorizationService getAuthorizationService() { @@ -146,14 +199,14 @@ public String getJwk(Router router, ObjectMapper mapper) throws IOException { Map mapped = mapper.readValue(maybeJwk, new TypeReference<>() {}); - if(mapped.containsKey("keys")) + if (mapped.containsKey("keys")) return handleJwks(mapper, mapped); return maybeJwk; } private String handleJwks(ObjectMapper mapper, Map mapped) { - return ((List)mapped.get("keys")).stream() + return ((List) mapped.get("keys")).stream() .filter(m -> m.get("kid").toString().equals(kid)) .map(m -> { try { @@ -162,7 +215,7 @@ private String handleJwks(ObjectMapper mapper, Map mapped) { throw new RuntimeException(e); } }) - .findFirst().get(); + .findFirst().orElseThrow(); } } diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/authorizationservice/AuthorizationService.java b/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/authorizationservice/AuthorizationService.java index 67b82158ee..b21138b17b 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/authorizationservice/AuthorizationService.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/authorizationservice/AuthorizationService.java @@ -62,6 +62,7 @@ public abstract class AuthorizationService { private String clientId; @GuardedBy("lock") private String clientSecret; + private Integer jwksRefreshInterval = 24 * 60 * 60; private JWSSigner JWSSigner; protected String scope; private SSLParser sslParser; @@ -175,6 +176,17 @@ protected void setClientIdAndSecret(String clientId, String clientSecret) { } } + public Integer getJwksRefreshInterval() { + return jwksRefreshInterval; + } + + @MCAttribute + public AuthorizationService setJwksRefreshInterval(Integer jwksRefreshInterval) { + this.jwksRefreshInterval = jwksRefreshInterval; + return this; + } + + public String getScope() { return scope; } From 0bc75bda2d4caf3f1e0f4f1c859ab35128ffa5ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Br=C3=BCckel?= Date: Tue, 17 Feb 2026 18:02:34 +0100 Subject: [PATCH 02/11] Created JwksRefreshTest and added observer logic to JwtAuthInterceptor --- .../membrane/core/interceptor/jwt/Jwks.java | 16 ++ .../interceptor/jwt/JwtAuthInterceptor.java | 7 + .../core/interceptor/jwt/JwksRefreshTest.java | 187 ++++++++++++++++++ 3 files changed, 210 insertions(+) create mode 100644 core/src/test/java/com/predic8/membrane/core/interceptor/jwt/JwksRefreshTest.java diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java index 2d925128b1..e3baa2147f 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java @@ -48,6 +48,7 @@ public class Jwks { volatile List jwks = new ArrayList<>(); String jwksUris; AuthorizationService authorizationService; + private final List observers = new ArrayList<>(); public List getJwks() { return jwks; @@ -56,6 +57,7 @@ public List getJwks() { @MCChildElement public Jwks setJwks(List jwks) { this.jwks = jwks; + notifyObservers(); return this; } @@ -153,6 +155,20 @@ public void setAuthorizationService(AuthorizationService authService) { authorizationService = authService; } + public void addObserver(Runnable observer) { + observers.add(observer); + } + + private void notifyObservers() { + for (Runnable observer : observers) { + try { + observer.run(); + } catch (Exception e) { + log.error("Error notifying observer", e); + } + } + } + @MCElement(name="jwk", mixed = true, component = false, id="jwks-jwk") public static class Jwk extends Blob { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptor.java index aad69ccbc2..6ce8df16ad 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptor.java @@ -84,7 +84,11 @@ public void init() { jwtRetriever = new HeaderJwtRetriever("Authorization","Bearer"); jwks.init(router); + jwks.addObserver(this::updateKeys); + updateKeys(); + } + private void updateKeys() { kidToKey = jwks.getJwks().stream() .map(jwk -> { try { @@ -108,6 +112,8 @@ public void init() { if (kidToKey.isEmpty()) throw new RuntimeException("No JWKs given or none resolvable - please specify at least one resolvable JWK"); + + log.info("Loaded keys: " + kidToKey.keySet()); } @Override @@ -157,6 +163,7 @@ public Outcome handleJwt(Exchange exc, String jwt) throws JWTException, JsonProc var kid = decodedJwt.getHeader().kid(); if (!kidToKey.containsKey(kid)) { + log.warn("Unknown key: " + kid + ". Available keys: " + kidToKey.keySet()); throw new JWTException(ERROR_UNKNOWN_KEY, ERROR_UNKNOWN_KEY_ID); } diff --git a/core/src/test/java/com/predic8/membrane/core/interceptor/jwt/JwksRefreshTest.java b/core/src/test/java/com/predic8/membrane/core/interceptor/jwt/JwksRefreshTest.java new file mode 100644 index 0000000000..e2a4b5e301 --- /dev/null +++ b/core/src/test/java/com/predic8/membrane/core/interceptor/jwt/JwksRefreshTest.java @@ -0,0 +1,187 @@ +package com.predic8.membrane.core.interceptor.jwt; + +import com.predic8.membrane.core.router.DefaultRouter; +import com.predic8.membrane.core.exchange.Exchange; +import com.predic8.membrane.core.http.Request; +import com.predic8.membrane.core.http.Response; +import com.predic8.membrane.core.interceptor.AbstractInterceptor; +import com.predic8.membrane.core.interceptor.Outcome; +import com.predic8.membrane.core.interceptor.oauth2.authorizationservice.AuthorizationService; +import com.predic8.membrane.core.proxies.ServiceProxy; +import com.predic8.membrane.core.proxies.ServiceProxyKey; +import com.predic8.membrane.core.transport.http.HttpClient; +import org.jetbrains.annotations.NotNull; +import org.jose4j.jwk.JsonWebKeySet; +import org.jose4j.jwk.RsaJsonWebKey; +import org.jose4j.jwk.RsaJwkGenerator; +import org.jose4j.jws.AlgorithmIdentifiers; +import org.jose4j.jws.JsonWebSignature; +import org.jose4j.jwt.JwtClaims; +import org.jose4j.lang.JoseException; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +import static com.predic8.membrane.core.interceptor.jwt.JwtAuthInterceptorTest.KID; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class JwksRefreshTest { + + public static final int PROVIDER_PORT = 3000; + public static final int AUTH_INTERCEPTOR_PORT = 3001; + static DefaultRouter jwksProvider; + static DefaultRouter jwtValidator; + + static RsaJsonWebKey privateKey1; + static RsaJsonWebKey publicKey1; + static RsaJsonWebKey privateKey2; + static RsaJsonWebKey publicKey2; + + static final AtomicReference currentJwkSet = new AtomicReference<>(); + + @BeforeAll + public static void setup() throws Exception { + privateKey1 = RsaJwkGenerator.generateJwk(2048); + privateKey1.setKeyId(KID); + publicKey1 = new RsaJsonWebKey(privateKey1.getRsaPublicKey()); + publicKey1.setKeyId(KID); + + privateKey2 = RsaJwkGenerator.generateJwk(2048); + privateKey2.setKeyId(KID + "2"); + publicKey2 = new RsaJsonWebKey(privateKey2.getRsaPublicKey()); + publicKey2.setKeyId(KID + "2"); + + currentJwkSet.set(new JsonWebKeySet(publicKey1)); + + jwksProvider = new DefaultRouter(); + jwksProvider.add(proxyWithInterceptors(PROVIDER_PORT, jwkServingInterceptor(currentJwkSet::get))); + jwksProvider.start(); + + // Wait for jwksProvider to start + Thread.sleep(1000); + + jwtValidator = new DefaultRouter(); + jwtValidator.add(proxyWithInterceptors( + AUTH_INTERCEPTOR_PORT, + jwtAuthInterceptor(), + new AbstractInterceptor() { + @Override + public Outcome handleRequest(Exchange exc) { + exc.setResponse(Response.ok().build()); + return Outcome.RETURN; + } + }) + ); + jwtValidator.start(); + } + + private static @NotNull ServiceProxy proxyWithInterceptors(int port, @NotNull AbstractInterceptor... interceptors) { + var proxy = new ServiceProxy(new ServiceProxyKey(port), null, 0); + Arrays.stream(interceptors).forEach(proxy.getFlow()::add); + return proxy; + } + + private static @NotNull JwtAuthInterceptor jwtAuthInterceptor() { + Jwks jwks = new Jwks(); + jwks.setJwksUris("http://localhost:%d/jwks".formatted(PROVIDER_PORT)); + jwks.setAuthorizationService(buildAuthorizationService(1)); + + JwtAuthInterceptor jwtAuth = new JwtAuthInterceptor(); + jwtAuth.setExpectedAud("some-audience"); + jwtAuth.setJwks(jwks); + + return jwtAuth; + } + + private static @NotNull AbstractInterceptor jwkServingInterceptor(final Supplier jwkSupplier) { + return new AbstractInterceptor() { + @Override + public Outcome handleRequest(Exchange exc) { + exc.setResponse(Response.ok(jwkSupplier.get().toJson()).contentType("application/json").build()); + return Outcome.RETURN; + } + }; + } + + private static @NotNull AuthorizationService buildAuthorizationService(int jwksRefreshInterval) { + AuthorizationService authService = new AuthorizationService() { + @Override public void init() {} + @Override public String getIssuer() { return null; } + @Override public String getJwksEndpoint() { return null; } + @Override public String getEndSessionEndpoint() { return null; } + @Override public String getLoginURL(String callbackURL) { return null; } + @Override public String getUserInfoEndpoint() { return null; } + @Override public String getSubject() { return null; } + @Override protected String getTokenEndpoint() { return null; } + @Override public String getRevocationEndpoint() { return null; } + }; + authService.setJwksRefreshInterval(jwksRefreshInterval); + authService.setHttpClient(new HttpClient()); + return authService; + } + + @AfterAll + public static void teardown() throws IOException { + jwksProvider.stop(); + jwtValidator.stop(); + } + + @Test + public void testRefresh() throws Exception { + String authInterceptorUrl = "http://localhost:%d/".formatted(AUTH_INTERCEPTOR_PORT); + + try (HttpClient hc = new HttpClient()) { + + // 1. initial key works + Exchange exc1 = new Request.Builder() + .get(authInterceptorUrl) + .header("Authorization", "Bearer " + createJwt(privateKey1)) + .buildExchange(); + hc.call(exc1); + assertEquals(200, exc1.getResponse().getStatusCode()); + + // 2. switch keys + currentJwkSet.set(new JsonWebKeySet(publicKey2)); + Thread.sleep(2000); // wait for refresh + + // 3. new key works + Exchange exc3 = new Request.Builder() + .get(authInterceptorUrl) + .header("Authorization", "Bearer " + createJwt(privateKey2)) + .buildExchange(); + hc.call(exc3); + assertEquals(200, exc3.getResponse().getStatusCode()); + + // 4. old key does not work anymore + Exchange exc2 = new Request.Builder() + .get(authInterceptorUrl) + .header("Authorization", "Bearer " + createJwt(privateKey1)) + .buildExchange(); + hc.call(exc2); + assertEquals(400, exc2.getResponse().getStatusCode()); + } + } + + private static String createJwt(RsaJsonWebKey privateKey) throws JoseException { + JwtClaims claims = new JwtClaims(); + claims.setSubject("user"); + claims.setIssuer("some-issuer"); + claims.setAudience("some-audience"); + claims.setExpirationTimeMinutesInTheFuture(10); + claims.setIssuedAtToNow(); + claims.setNotBeforeMinutesInThePast(2); + + JsonWebSignature jws = new JsonWebSignature(); + jws.setPayload(claims.toJson()); + jws.setKey(privateKey.getPrivateKey()); + jws.setKeyIdHeaderValue(privateKey.getKeyId()); + jws.setAlgorithmHeaderValue(AlgorithmIdentifiers.RSA_USING_SHA256); + + return jws.getCompactSerialization(); + } +} From f5bd43996909a78568bf558553047e069cdad2d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Br=C3=BCckel?= Date: Wed, 18 Feb 2026 08:23:45 +0100 Subject: [PATCH 03/11] Remove usages of Jwks.getJwks() in Test initialization --- .../membrane/core/interceptor/jwt/JwtAuthInterceptorTest.java | 3 +-- .../core/interceptor/jwt/JwtAuthInterceptorUnitTests.java | 3 ++- .../security/JWTInterceptorAndSecurityValidatorTest.java | 3 +-- .../predic8/membrane/core/security/JWTSecuritySchemeTest.java | 3 +-- 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/core/src/test/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptorTest.java b/core/src/test/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptorTest.java index d9745e415f..c847b54a38 100644 --- a/core/src/test/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptorTest.java +++ b/core/src/test/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptorTest.java @@ -265,8 +265,7 @@ private JwtAuthInterceptor createInterceptor(RsaJsonWebKey publicOnly) { JwtAuthInterceptor interceptor = new JwtAuthInterceptor(); Jwks jwks = new Jwks(); - jwks.setJwks(new ArrayList<>()); - jwks.getJwks().add(jwk); + jwks.setJwks(List.of(jwk)); interceptor.setJwks(jwks); interceptor.setExpectedAud(AUDIENCE); interceptor.setExpectedTid(TENANT_ID); diff --git a/core/src/test/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptorUnitTests.java b/core/src/test/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptorUnitTests.java index 86561118fa..6302bb1747 100644 --- a/core/src/test/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptorUnitTests.java +++ b/core/src/test/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptorUnitTests.java @@ -21,6 +21,7 @@ import com.predic8.membrane.core.http.Request; import org.junit.jupiter.api.*; +import java.util.List; import java.util.Map; import static com.predic8.membrane.core.interceptor.jwt.JwtAuthInterceptor.*; @@ -60,7 +61,7 @@ void noJwtInHeader() { jwk.setContent("{\"kty\":\"RSA\", \"n\":\""+ "B".repeat(1024 * 8 / 6) +"\", \"e\":\"BB\"}"); - jwks.getJwks().add(jwk); + jwks.setJwks(List.of(jwk)); interceptor.setJwks(jwks); interceptor.init(new DummyTestRouter()); interceptor.handleRequest(exchange); diff --git a/core/src/test/java/com/predic8/membrane/core/openapi/validators/security/JWTInterceptorAndSecurityValidatorTest.java b/core/src/test/java/com/predic8/membrane/core/openapi/validators/security/JWTInterceptorAndSecurityValidatorTest.java index 0af18fd664..2f6260b8c2 100644 --- a/core/src/test/java/com/predic8/membrane/core/openapi/validators/security/JWTInterceptorAndSecurityValidatorTest.java +++ b/core/src/test/java/com/predic8/membrane/core/openapi/validators/security/JWTInterceptorAndSecurityValidatorTest.java @@ -123,8 +123,7 @@ private static Jwks getJwks(RsaJsonWebKey publicOnly) { Jwks.Jwk jwk = new Jwks.Jwk(); jwk.setContent(publicOnly.toJson()); Jwks jwks = new Jwks(); - jwks.setJwks(new ArrayList<>()); - jwks.getJwks().add(jwk); + jwks.setJwks(List.of(jwk)); return jwks; } diff --git a/core/src/test/java/com/predic8/membrane/core/security/JWTSecuritySchemeTest.java b/core/src/test/java/com/predic8/membrane/core/security/JWTSecuritySchemeTest.java index 453d5e5952..1d8d10f613 100644 --- a/core/src/test/java/com/predic8/membrane/core/security/JWTSecuritySchemeTest.java +++ b/core/src/test/java/com/predic8/membrane/core/security/JWTSecuritySchemeTest.java @@ -61,8 +61,7 @@ private JwtAuthInterceptor createInterceptor(RsaJsonWebKey publicOnly) { JwtAuthInterceptor interceptor = new JwtAuthInterceptor(); Jwks jwks = new Jwks(); - jwks.setJwks(new ArrayList<>()); - jwks.getJwks().add(jwk); + jwks.setJwks(List.of(jwk)); interceptor.setJwks(jwks); interceptor.setExpectedAud(AUDIENCE); return interceptor; From c247400efc6878bd64801e98a19aaf355b1c1dc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Br=C3=BCckel?= Date: Wed, 18 Feb 2026 09:18:47 +0100 Subject: [PATCH 04/11] Move key HashMap into Jwks and remove observer logic --- .../membrane/core/interceptor/jwt/Jwks.java | 103 ++++++++++++------ .../interceptor/jwt/JwtAuthInterceptor.java | 40 +------ 2 files changed, 72 insertions(+), 71 deletions(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java index e3baa2147f..65e14b0d64 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java @@ -28,6 +28,8 @@ import com.predic8.membrane.core.transport.http.client.HttpClientConfiguration; import com.predic8.membrane.core.util.ConfigurationException; import com.predic8.membrane.core.util.text.TextUtil; +import org.jetbrains.annotations.NotNull; +import org.jose4j.jwk.RsaJsonWebKey; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,6 +37,7 @@ import java.io.InputStream; import java.util.*; +import static com.predic8.membrane.core.interceptor.jwt.JwtSignInterceptor.DEFAULT_PKEY; import static java.util.Collections.emptyList; /** @@ -44,11 +47,41 @@ @MCElement(name="jwks") public class Jwks { + public static final String DEFAULT_JWK_WARNING = """ + \n------------------------------------ DEFAULT JWK IN USE! ------------------------------------ + This key is for demonstration purposes only and UNSAFE for production use. \s + ---------------------------------------------------------------------------------------------"""; private static final Logger log = LoggerFactory.getLogger(Jwks.class); - volatile List jwks = new ArrayList<>(); + final ObjectMapper mapper = new ObjectMapper(); + + private volatile List jwks = new ArrayList<>(); // this is basically a write-only field, contents are converted to keysByKid ASAP + private volatile HashMap keysByKid = new HashMap<>(); + String jwksUris; AuthorizationService authorizationService; - private final List observers = new ArrayList<>(); + private Router router; + + public void init(Router router) { + this.router = router; + if(jwksUris == null || jwksUris.isEmpty()) { + if (jwks.isEmpty()) + throw new ConfigurationException("JWKs need to be configured either via JwksUris or Jwks."); + this.keysByKid = buildKeyMap(jwks); + return; + } + if (!jwks.isEmpty()) + throw new ConfigurationException("JWKs cannot be set both via JwksUris and Jwks elements."); + setJwks(loadJwks(false)); + if (authorizationService.getJwksRefreshInterval() > 0) { + router.getTimerManager().schedulePeriodicTask(new TimerTask() { + @Override + public void run() { + setJwks(loadJwks(true)); + } + }, authorizationService.getJwksRefreshInterval() * 1_000L, "JWKS Refresh" + ); + } + } public List getJwks() { return jwks; @@ -56,8 +89,9 @@ public List getJwks() { @MCChildElement public Jwks setJwks(List jwks) { - this.jwks = jwks; - notifyObservers(); + this.jwks = jwks; // unnecessary, mainly for consistency when debugging + if (router != null) // set in init, so we can't update prior to that call + this.keysByKid = buildKeyMap(jwks); return this; } @@ -71,24 +105,41 @@ public Jwks setJwksUris(String jwksUris) { return this; } - public void init(Router router) { - if(jwksUris == null || jwksUris.isEmpty()) - return; - if (!jwks.isEmpty()) - throw new ConfigurationException("JWKs cannot be set both via JwksUris and Jwks elements."); - this.jwks = loadJwks(router, false); - if (authorizationService.getJwksRefreshInterval() > 0) { - router.getTimerManager().schedulePeriodicTask(new TimerTask() { - @Override - public void run() { - setJwks(loadJwks(router, true)); - } - }, authorizationService.getJwksRefreshInterval() * 1_000L, "JWKS Refresh" - ); + public HashMap getKeysByKid() { + return keysByKid; + } + + private HashMap buildKeyMap(List jwks) { + var keyMap = jwks.stream() + .map(this::extractRsaJsonWebKey) + .collect( + () -> new HashMap(), + (m,e) -> m.put(e.getKeyId(),e), + HashMap::putAll + ); + if (keyMap.isEmpty()) + throw new RuntimeException("No JWKs given or none resolvable - please specify at least one resolvable JWK"); + return keyMap; + } + + private @NotNull RsaJsonWebKey extractRsaJsonWebKey(Jwk jwk) { + try { + var params = mapper.readValue(jwk.getJwk(router, mapper), new TypeReference>() {}); + if (Objects.equals(params.get("p"), DEFAULT_PKEY)) { + log.warn(DEFAULT_JWK_WARNING); + if (router.getConfiguration().isProduction()) { + throw new RuntimeException("Default JWK detected in production environment. Please use a secure key."); + } + } + + return new RsaJsonWebKey(params); + } catch (Exception e) { + throw new RuntimeException(e); } } - private List loadJwks(Router router, boolean suppressExceptions) { + + private List loadJwks(boolean suppressExceptions) { ObjectMapper mapper = new ObjectMapper(); return Arrays.stream(jwksUris.split(" ")) .map(uri -> parseJwksUriIntoList(router.getResolverMap(), router.getConfiguration().getBaseLocation(), mapper, uri, suppressExceptions)) @@ -155,20 +206,6 @@ public void setAuthorizationService(AuthorizationService authService) { authorizationService = authService; } - public void addObserver(Runnable observer) { - observers.add(observer); - } - - private void notifyObservers() { - for (Runnable observer : observers) { - try { - observer.run(); - } catch (Exception e) { - log.error("Error notifying observer", e); - } - } - } - @MCElement(name="jwk", mixed = true, component = false, id="jwks-jwk") public static class Jwk extends Blob { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptor.java index 6ce8df16ad..23a092f369 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptor.java @@ -28,7 +28,6 @@ import static com.predic8.membrane.core.interceptor.Interceptor.Flow.*; import static com.predic8.membrane.core.interceptor.Outcome.*; -import static com.predic8.membrane.core.interceptor.jwt.JwtSignInterceptor.DEFAULT_PKEY; import static java.util.EnumSet.*; import static org.apache.commons.text.StringEscapeUtils.*; @@ -62,15 +61,11 @@ public static String ERROR_JWT_VALUE_NOT_PRESENT(String key) { public static final String ERROR_JWT_VALUE_NOT_PRESENT_ID = "jwt-payload-entry-missing"; private static final Logger log = LoggerFactory.getLogger(JwtAuthInterceptor.class); - final ObjectMapper mapper = new ObjectMapper(); JwtRetriever jwtRetriever; Jwks jwks; String expectedAud; String expectedTid; - // should be used read only after init - // Hashmap done on purpose as only here the read only thread safety is guaranteed - volatile HashMap kidToKey; public JwtAuthInterceptor() { name = "jwt checker."; @@ -84,36 +79,6 @@ public void init() { jwtRetriever = new HeaderJwtRetriever("Authorization","Bearer"); jwks.init(router); - jwks.addObserver(this::updateKeys); - updateKeys(); - } - - private void updateKeys() { - kidToKey = jwks.getJwks().stream() - .map(jwk -> { - try { - Map params = mapper.readValue(jwk.getJwk(router, mapper), Map.class); - if (Objects.equals(params.get("p"), DEFAULT_PKEY)) { - log.warn(""" - \n------------------------------------ DEFAULT JWK IN USE! ------------------------------------ - This key is for demonstration purposes only and UNSAFE for production use. \s - ---------------------------------------------------------------------------------------------"""); - if (router.getConfiguration().isProduction()) { - throw new RuntimeException("Default JWK detected in production environment. Please use a secure key."); - } - } - - return new RsaJsonWebKey(params); - } catch (Exception e) { - throw new RuntimeException(e); - } - }) - .collect(HashMap::new, (m,e) -> m.put(e.getKeyId(),e), HashMap::putAll); - - if (kidToKey.isEmpty()) - throw new RuntimeException("No JWKs given or none resolvable - please specify at least one resolvable JWK"); - - log.info("Loaded keys: " + kidToKey.keySet()); } @Override @@ -162,14 +127,13 @@ public Outcome handleJwt(Exchange exc, String jwt) throws JWTException, JsonProc var decodedJwt = new JsonWebToken(jwt); var kid = decodedJwt.getHeader().kid(); - if (!kidToKey.containsKey(kid)) { - log.warn("Unknown key: " + kid + ". Available keys: " + kidToKey.keySet()); + if (!jwks.getKeysByKid().containsKey(kid)) { throw new JWTException(ERROR_UNKNOWN_KEY, ERROR_UNKNOWN_KEY_ID); } // we could make it possible that every key is checked instead of having the "kid" field mandatory // this would then need up to n checks per incoming JWT - could be a performance problem - RsaJsonWebKey key = kidToKey.get(kid); + RsaJsonWebKey key = jwks.getKeysByKid().get(kid); Map jwtClaims = createValidator(key).processToClaims(jwt).getClaimsMap(); From a0900692a756ef4ef97d7d01f005b6ca231f90d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Br=C3=BCckel?= Date: Wed, 18 Feb 2026 09:37:13 +0100 Subject: [PATCH 05/11] make `jwksRefreshInterval` primitive --- .../oauth2/authorizationservice/AuthorizationService.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/authorizationservice/AuthorizationService.java b/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/authorizationservice/AuthorizationService.java index b21138b17b..3dd0fbe998 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/authorizationservice/AuthorizationService.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/oauth2/authorizationservice/AuthorizationService.java @@ -62,7 +62,7 @@ public abstract class AuthorizationService { private String clientId; @GuardedBy("lock") private String clientSecret; - private Integer jwksRefreshInterval = 24 * 60 * 60; + private int jwksRefreshInterval = 24 * 60 * 60; private JWSSigner JWSSigner; protected String scope; private SSLParser sslParser; @@ -181,9 +181,8 @@ public Integer getJwksRefreshInterval() { } @MCAttribute - public AuthorizationService setJwksRefreshInterval(Integer jwksRefreshInterval) { + public void setJwksRefreshInterval(int jwksRefreshInterval) { this.jwksRefreshInterval = jwksRefreshInterval; - return this; } From 53062be8b954126a2ab277ba88c258fc099532c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Br=C3=BCckel?= Date: Wed, 18 Feb 2026 09:39:56 +0100 Subject: [PATCH 06/11] Fix potential TOCTOU issue --- .../membrane/core/interceptor/jwt/JwtAuthInterceptor.java | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptor.java index 23a092f369..315e02fc15 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptor.java @@ -127,13 +127,10 @@ public Outcome handleJwt(Exchange exc, String jwt) throws JWTException, JsonProc var decodedJwt = new JsonWebToken(jwt); var kid = decodedJwt.getHeader().kid(); - if (!jwks.getKeysByKid().containsKey(kid)) { - throw new JWTException(ERROR_UNKNOWN_KEY, ERROR_UNKNOWN_KEY_ID); - } - // we could make it possible that every key is checked instead of having the "kid" field mandatory // this would then need up to n checks per incoming JWT - could be a performance problem - RsaJsonWebKey key = jwks.getKeysByKid().get(kid); + RsaJsonWebKey key = Optional.ofNullable(jwks.getKeysByKid().get(kid)) + .orElseThrow(() -> new JWTException(ERROR_UNKNOWN_KEY, ERROR_UNKNOWN_KEY_ID)); Map jwtClaims = createValidator(key).processToClaims(jwt).getClaimsMap(); From a8299c9475665340e8a64cf0a523983a88bec3ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Br=C3=BCckel?= Date: Wed, 18 Feb 2026 09:48:26 +0100 Subject: [PATCH 07/11] JSON parsing improvements --- .../java/com/predic8/membrane/core/interceptor/jwt/Jwks.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java index 65e14b0d64..3724e569ae 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java @@ -171,7 +171,7 @@ private JwkListByUri parseJwksUriIntoList(ResolverMap resolverMap, String baseLo InputStream resolve = authorizationService != null ? authorizationService.resolve(resolverMap, baseLocation, uri) : resolverMap.resolve(ResolverMap.combine(baseLocation, uri)); - return new JwkListByUri(uri, ((List) mapper.readValue(resolve, Map.class).get("keys"))); + return new JwkListByUri(uri, mapper.convertValue(mapper.readTree(resolve).path("keys"), List.class)); } catch (JsonProcessingException e) { String message = "Could not parse JWK keys retrieved from %s.".formatted(uri); if (suppressExceptions) { From 456428d6f53f426abe064085da7a2ac389149d9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Br=C3=BCckel?= Date: Wed, 18 Feb 2026 09:51:22 +0100 Subject: [PATCH 08/11] catch potential error conditions --- .../predic8/membrane/core/interceptor/jwt/Jwks.java | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java index 3724e569ae..f3c8e0a25e 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java @@ -72,11 +72,20 @@ public void init(Router router) { if (!jwks.isEmpty()) throw new ConfigurationException("JWKs cannot be set both via JwksUris and Jwks elements."); setJwks(loadJwks(false)); - if (authorizationService.getJwksRefreshInterval() > 0) { + if (authorizationService != null && authorizationService.getJwksRefreshInterval() > 0) { router.getTimerManager().schedulePeriodicTask(new TimerTask() { @Override public void run() { - setJwks(loadJwks(true)); + try { + List loaded = loadJwks(true); + if (!loaded.isEmpty()) { + setJwks(loaded); + } else { + log.warn("JWKS refresh returned no keys — keeping previous key set."); + } + } catch (Exception e) { + log.error("JWKS refresh failed, will retry on next interval.", e); + } } }, authorizationService.getJwksRefreshInterval() * 1_000L, "JWKS Refresh" ); From 5f9b7d66d648c1b23c2f908b952b2a78697db457 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Br=C3=BCckel?= Date: Wed, 18 Feb 2026 09:56:01 +0100 Subject: [PATCH 09/11] mapper is already a field --- .../java/com/predic8/membrane/core/interceptor/jwt/Jwks.java | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java index f3c8e0a25e..ebf7c39d18 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java @@ -149,7 +149,6 @@ private HashMap buildKeyMap(List jwks) { private List loadJwks(boolean suppressExceptions) { - ObjectMapper mapper = new ObjectMapper(); return Arrays.stream(jwksUris.split(" ")) .map(uri -> parseJwksUriIntoList(router.getResolverMap(), router.getConfiguration().getBaseLocation(), mapper, uri, suppressExceptions)) .flatMap(l -> l.jwks().stream().map(jwkRaw -> convertToJwk(jwkRaw, mapper, l.uri(), suppressExceptions))) From 95395fc4ab62d52c8eee0688baa3dc6eac4a7f87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Br=C3=BCckel?= Date: Wed, 18 Feb 2026 10:12:04 +0100 Subject: [PATCH 10/11] provide single-key accessor instead of whole Map --- .../java/com/predic8/membrane/core/interceptor/jwt/Jwks.java | 4 ++-- .../membrane/core/interceptor/jwt/JwtAuthInterceptor.java | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java index ebf7c39d18..f6ee182c10 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java @@ -114,8 +114,8 @@ public Jwks setJwksUris(String jwksUris) { return this; } - public HashMap getKeysByKid() { - return keysByKid; + public Optional getKeyByKid(String kid) { + return Optional.ofNullable(keysByKid.get(kid)); } private HashMap buildKeyMap(List jwks) { diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptor.java b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptor.java index 315e02fc15..5998755c67 100644 --- a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptor.java +++ b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/JwtAuthInterceptor.java @@ -13,7 +13,6 @@ package com.predic8.membrane.core.interceptor.jwt; import com.fasterxml.jackson.core.*; -import com.fasterxml.jackson.databind.*; import com.predic8.membrane.annot.*; import com.predic8.membrane.core.exceptions.ProblemDetails; import com.predic8.membrane.core.exchange.*; @@ -129,8 +128,7 @@ public Outcome handleJwt(Exchange exc, String jwt) throws JWTException, JsonProc // we could make it possible that every key is checked instead of having the "kid" field mandatory // this would then need up to n checks per incoming JWT - could be a performance problem - RsaJsonWebKey key = Optional.ofNullable(jwks.getKeysByKid().get(kid)) - .orElseThrow(() -> new JWTException(ERROR_UNKNOWN_KEY, ERROR_UNKNOWN_KEY_ID)); + RsaJsonWebKey key = jwks.getKeyByKid(kid).orElseThrow(() -> new JWTException(ERROR_UNKNOWN_KEY, ERROR_UNKNOWN_KEY_ID)); Map jwtClaims = createValidator(key).processToClaims(jwt).getClaimsMap(); From bfb09aa9ce8ea5390e76417462cdd169d3fef369 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Br=C3=BCckel?= Date: Wed, 18 Feb 2026 10:14:01 +0100 Subject: [PATCH 11/11] make sleep dependent on configured refresh interval --- .../membrane/core/interceptor/jwt/JwksRefreshTest.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/test/java/com/predic8/membrane/core/interceptor/jwt/JwksRefreshTest.java b/core/src/test/java/com/predic8/membrane/core/interceptor/jwt/JwksRefreshTest.java index e2a4b5e301..283397a5e3 100644 --- a/core/src/test/java/com/predic8/membrane/core/interceptor/jwt/JwksRefreshTest.java +++ b/core/src/test/java/com/predic8/membrane/core/interceptor/jwt/JwksRefreshTest.java @@ -34,6 +34,7 @@ public class JwksRefreshTest { public static final int PROVIDER_PORT = 3000; public static final int AUTH_INTERCEPTOR_PORT = 3001; + public static final int JWKS_REFRESH_INTERVAL = 1; static DefaultRouter jwksProvider; static DefaultRouter jwtValidator; @@ -89,7 +90,7 @@ public Outcome handleRequest(Exchange exc) { private static @NotNull JwtAuthInterceptor jwtAuthInterceptor() { Jwks jwks = new Jwks(); jwks.setJwksUris("http://localhost:%d/jwks".formatted(PROVIDER_PORT)); - jwks.setAuthorizationService(buildAuthorizationService(1)); + jwks.setAuthorizationService(buildAuthorizationService(JWKS_REFRESH_INTERVAL)); JwtAuthInterceptor jwtAuth = new JwtAuthInterceptor(); jwtAuth.setExpectedAud("some-audience"); @@ -147,7 +148,7 @@ public void testRefresh() throws Exception { // 2. switch keys currentJwkSet.set(new JsonWebKeySet(publicKey2)); - Thread.sleep(2000); // wait for refresh + Thread.sleep(JWKS_REFRESH_INTERVAL * 1_000 * 2); // wait for refresh // 3. new key works Exchange exc3 = new Request.Builder()