diff --git a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java
index 545d27fcfb..f6ee182c10 100644
--- a/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java
+++ b/core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java
@@ -28,19 +28,69 @@
import com.predic8.membrane.core.transport.http.client.HttpClientConfiguration;
import com.predic8.membrane.core.util.ConfigurationException;
import com.predic8.membrane.core.util.text.TextUtil;
+import org.jetbrains.annotations.NotNull;
+import org.jose4j.jwk.RsaJsonWebKey;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.io.InputStream;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
+import static com.predic8.membrane.core.interceptor.jwt.JwtSignInterceptor.DEFAULT_PKEY;
+import static java.util.Collections.emptyList;
+
+/**
+ * @description
+ * JSON Web Key Set, configured either by an explicit list of JWK or by a list of JWK URIs that will be refreshed periodically.
+ */
@MCElement(name="jwks")
public class Jwks {
- List jwks = new ArrayList<>();
+ public static final String DEFAULT_JWK_WARNING = """
+ \n------------------------------------ DEFAULT JWK IN USE! ------------------------------------
+ This key is for demonstration purposes only and UNSAFE for production use. \s
+ ---------------------------------------------------------------------------------------------""";
+ private static final Logger log = LoggerFactory.getLogger(Jwks.class);
+ final ObjectMapper mapper = new ObjectMapper();
+
+ private volatile List jwks = new ArrayList<>(); // this is basically a write-only field, contents are converted to keysByKid ASAP
+ private volatile HashMap keysByKid = new HashMap<>();
+
String jwksUris;
AuthorizationService authorizationService;
+ private Router router;
+
+ public void init(Router router) {
+ this.router = router;
+ if(jwksUris == null || jwksUris.isEmpty()) {
+ if (jwks.isEmpty())
+ throw new ConfigurationException("JWKs need to be configured either via JwksUris or Jwks.");
+ this.keysByKid = buildKeyMap(jwks);
+ return;
+ }
+ if (!jwks.isEmpty())
+ throw new ConfigurationException("JWKs cannot be set both via JwksUris and Jwks elements.");
+ setJwks(loadJwks(false));
+ if (authorizationService != null && authorizationService.getJwksRefreshInterval() > 0) {
+ router.getTimerManager().schedulePeriodicTask(new TimerTask() {
+ @Override
+ public void run() {
+ try {
+ List loaded = loadJwks(true);
+ if (!loaded.isEmpty()) {
+ setJwks(loaded);
+ } else {
+ log.warn("JWKS refresh returned no keys — keeping previous key set.");
+ }
+ } catch (Exception e) {
+ log.error("JWKS refresh failed, will retry on next interval.", e);
+ }
+ }
+ }, authorizationService.getJwksRefreshInterval() * 1_000L, "JWKS Refresh"
+ );
+ }
+ }
public List getJwks() {
return jwks;
@@ -48,7 +98,9 @@ public List getJwks() {
@MCChildElement
public Jwks setJwks(List jwks) {
- this.jwks = jwks;
+ this.jwks = jwks; // unnecessary, mainly for consistency when debugging
+ if (router != null) // set in init, so we can't update prior to that call
+ this.keysByKid = buildKeyMap(jwks);
return this;
}
@@ -62,33 +114,95 @@ public Jwks setJwksUris(String jwksUris) {
return this;
}
- public void init(Router router) {
- if(jwksUris == null || jwksUris.isEmpty())
- return;
+ public Optional getKeyByKid(String kid) {
+ return Optional.ofNullable(keysByKid.get(kid));
+ }
+
+ private HashMap buildKeyMap(List jwks) {
+ var keyMap = jwks.stream()
+ .map(this::extractRsaJsonWebKey)
+ .collect(
+ () -> new HashMap(),
+ (m,e) -> m.put(e.getKeyId(),e),
+ HashMap::putAll
+ );
+ if (keyMap.isEmpty())
+ throw new RuntimeException("No JWKs given or none resolvable - please specify at least one resolvable JWK");
+ return keyMap;
+ }
- ObjectMapper mapper = new ObjectMapper();
- for (String uri : jwksUris.split(" ")) {
- try {
- for (Object jwkRaw : parseJwksUriIntoList(router.getResolverMap(), router.getConfiguration().getBaseLocation(), mapper, uri)) {
- Jwk jwk = new Jwk();
- jwk.setContent(mapper.writeValueAsString(jwkRaw));
- this.jwks.add(jwk);
+ private @NotNull RsaJsonWebKey extractRsaJsonWebKey(Jwk jwk) {
+ try {
+ var params = mapper.readValue(jwk.getJwk(router, mapper), new TypeReference