Skip to content
172 changes: 143 additions & 29 deletions core/src/main/java/com/predic8/membrane/core/interceptor/jwt/Jwks.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,79 @@
import com.predic8.membrane.core.transport.http.client.HttpClientConfiguration;
import com.predic8.membrane.core.util.ConfigurationException;
import com.predic8.membrane.core.util.text.TextUtil;
import org.jetbrains.annotations.NotNull;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use Resolver instead of HTTPClient.
Resolver uses proxy configuration

import org.jose4j.jwk.RsaJsonWebKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.*;

import static com.predic8.membrane.core.interceptor.jwt.JwtSignInterceptor.DEFAULT_PKEY;
import static java.util.Collections.emptyList;

/**
* @description
* JSON Web Key Set, configured <b>either</b> by an explicit list of JWK <b>or</b> by a list of JWK URIs that will be refreshed periodically.
*/
@MCElement(name="jwks")
public class Jwks {

List<Jwk> jwks = new ArrayList<>();
public static final String DEFAULT_JWK_WARNING = """
\n------------------------------------ DEFAULT JWK IN USE! ------------------------------------
This key is for demonstration purposes only and UNSAFE for production use. \s
---------------------------------------------------------------------------------------------""";
private static final Logger log = LoggerFactory.getLogger(Jwks.class);
final ObjectMapper mapper = new ObjectMapper();

private volatile List<Jwk> jwks = new ArrayList<>(); // this is basically a write-only field, contents are converted to keysByKid ASAP
private volatile HashMap<String, RsaJsonWebKey> keysByKid = new HashMap<>();

String jwksUris;
AuthorizationService authorizationService;
private Router router;

public void init(Router router) {
this.router = router;
if(jwksUris == null || jwksUris.isEmpty()) {
if (jwks.isEmpty())
throw new ConfigurationException("JWKs need to be configured either via JwksUris or Jwks.");
this.keysByKid = buildKeyMap(jwks);
return;
}
if (!jwks.isEmpty())
throw new ConfigurationException("JWKs cannot be set both via JwksUris and Jwks elements.");
setJwks(loadJwks(false));
if (authorizationService != null && authorizationService.getJwksRefreshInterval() > 0) {
router.getTimerManager().schedulePeriodicTask(new TimerTask() {
@Override
public void run() {
try {
List<Jwk> loaded = loadJwks(true);
if (!loaded.isEmpty()) {
setJwks(loaded);
} else {
log.warn("JWKS refresh returned no keys — keeping previous key set.");
}
} catch (Exception e) {
log.error("JWKS refresh failed, will retry on next interval.", e);
}
}
}, authorizationService.getJwksRefreshInterval() * 1_000L, "JWKS Refresh"
);
}
}

public List<Jwk> getJwks() {
return jwks;
}

@MCChildElement
public Jwks setJwks(List<Jwk> jwks) {
this.jwks = jwks;
this.jwks = jwks; // unnecessary, mainly for consistency when debugging
if (router != null) // set in init, so we can't update prior to that call
this.keysByKid = buildKeyMap(jwks);
return this;
}

Expand All @@ -62,33 +114,95 @@ public Jwks setJwksUris(String jwksUris) {
return this;
}

public void init(Router router) {
if(jwksUris == null || jwksUris.isEmpty())
return;
public Optional<RsaJsonWebKey> getKeyByKid(String kid) {
return Optional.ofNullable(keysByKid.get(kid));
}

private HashMap<String, RsaJsonWebKey> buildKeyMap(List<Jwk> jwks) {
var keyMap = jwks.stream()
.map(this::extractRsaJsonWebKey)
.collect(
() -> new HashMap<String, RsaJsonWebKey>(),
(m,e) -> m.put(e.getKeyId(),e),
HashMap::putAll
);
if (keyMap.isEmpty())
throw new RuntimeException("No JWKs given or none resolvable - please specify at least one resolvable JWK");
return keyMap;
}

ObjectMapper mapper = new ObjectMapper();
for (String uri : jwksUris.split(" ")) {
try {
for (Object jwkRaw : parseJwksUriIntoList(router.getResolverMap(), router.getConfiguration().getBaseLocation(), mapper, uri)) {
Jwk jwk = new Jwk();
jwk.setContent(mapper.writeValueAsString(jwkRaw));
this.jwks.add(jwk);
private @NotNull RsaJsonWebKey extractRsaJsonWebKey(Jwk jwk) {
try {
var params = mapper.readValue(jwk.getJwk(router, mapper), new TypeReference<Map<String, Object>>() {});
if (Objects.equals(params.get("p"), DEFAULT_PKEY)) {
log.warn(DEFAULT_JWK_WARNING);
if (router.getConfiguration().isProduction()) {
throw new RuntimeException("Default JWK detected in production environment. Please use a secure key.");
}
} catch (JsonProcessingException e) {
throw new ConfigurationException("Could not parse JWK keys retrieved from %s.".formatted(uri), e);
} catch (ResourceRetrievalException e) {
throw new ConfigurationException("Could not retrieve JWK keys from %s.".formatted(uri), e);
} catch (Exception e) {
throw new RuntimeException(e);
}

return new RsaJsonWebKey(params);
} catch (Exception e) {
throw new RuntimeException(e);
}
}

private List parseJwksUriIntoList(ResolverMap resolverMap, String baseLocation, ObjectMapper mapper, String uri) throws Exception {
InputStream resolve = authorizationService != null ?
authorizationService.resolve(resolverMap, baseLocation, uri) :
resolverMap.resolve(ResolverMap.combine(baseLocation, uri));
return (List) mapper.readValue(resolve, Map.class).get("keys");

private List<Jwk> loadJwks(boolean suppressExceptions) {
return Arrays.stream(jwksUris.split(" "))
.map(uri -> parseJwksUriIntoList(router.getResolverMap(), router.getConfiguration().getBaseLocation(), mapper, uri, suppressExceptions))
.flatMap(l -> l.jwks().stream().map(jwkRaw -> convertToJwk(jwkRaw, mapper, l.uri(), suppressExceptions)))
.filter(Objects::nonNull)
.toList();
}

private static Jwk convertToJwk(Object jwkRaw, ObjectMapper mapper, String uri, boolean suppressExceptions) {
try {
Jwk jwk = new Jwk();
jwk.setContent(mapper.writeValueAsString(jwkRaw));
return jwk;
} catch (JsonProcessingException e) {
String message = "Could not parse JWK keys retrieved from %s.".formatted(uri);
if (suppressExceptions) {
log.error(message);
return null;
} else {
throw new ConfigurationException(message, e);
}
}
}

private record JwkListByUri(String uri, List<?> jwks) {}

private JwkListByUri parseJwksUriIntoList(ResolverMap resolverMap, String baseLocation, ObjectMapper mapper, String uri, boolean suppressExceptions) {
try {
InputStream resolve = authorizationService != null ?
authorizationService.resolve(resolverMap, baseLocation, uri) :
resolverMap.resolve(ResolverMap.combine(baseLocation, uri));
return new JwkListByUri(uri, mapper.convertValue(mapper.readTree(resolve).path("keys"), List.class));
} catch (JsonProcessingException e) {
String message = "Could not parse JWK keys retrieved from %s.".formatted(uri);
if (suppressExceptions) {
log.error(message);
} else {
throw new ConfigurationException(message, e);
}
} catch (ResourceRetrievalException e) {
String message = "Could not retrieve JWK keys from %s.".formatted(uri);
if (suppressExceptions) {
log.error(message);
} else {
throw new ConfigurationException(message, e);
}
} catch (Exception e) {
if (suppressExceptions) {
log.error(e.toString());
log.error(e.getMessage());
} else {
throw new RuntimeException(e);
}
}
return new JwkListByUri(uri, emptyList());
}

public AuthorizationService getAuthorizationService() {
Expand Down Expand Up @@ -146,14 +260,14 @@ public String getJwk(Router router, ObjectMapper mapper) throws IOException {

Map<String,Object> mapped = mapper.readValue(maybeJwk, new TypeReference<>() {});

if(mapped.containsKey("keys"))
if (mapped.containsKey("keys"))
return handleJwks(mapper, mapped);

return maybeJwk;
}

private String handleJwks(ObjectMapper mapper, Map<String, Object> mapped) {
return ((List<Map>)mapped.get("keys")).stream()
return ((List<Map>) mapped.get("keys")).stream()
.filter(m -> m.get("kid").toString().equals(kid))
.map(m -> {
try {
Expand All @@ -162,7 +276,7 @@ private String handleJwks(ObjectMapper mapper, Map<String, Object> mapped) {
throw new RuntimeException(e);
}
})
.findFirst().get();
.findFirst().orElseThrow();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
package com.predic8.membrane.core.interceptor.jwt;

import com.fasterxml.jackson.core.*;
import com.fasterxml.jackson.databind.*;
import com.predic8.membrane.annot.*;
import com.predic8.membrane.core.exceptions.ProblemDetails;
import com.predic8.membrane.core.exchange.*;
Expand All @@ -28,7 +27,6 @@

import static com.predic8.membrane.core.interceptor.Interceptor.Flow.*;
import static com.predic8.membrane.core.interceptor.Outcome.*;
import static com.predic8.membrane.core.interceptor.jwt.JwtSignInterceptor.DEFAULT_PKEY;
import static java.util.EnumSet.*;
import static org.apache.commons.text.StringEscapeUtils.*;

Expand Down Expand Up @@ -62,15 +60,11 @@ public static String ERROR_JWT_VALUE_NOT_PRESENT(String key) {
public static final String ERROR_JWT_VALUE_NOT_PRESENT_ID = "jwt-payload-entry-missing";
private static final Logger log = LoggerFactory.getLogger(JwtAuthInterceptor.class);

final ObjectMapper mapper = new ObjectMapper();
JwtRetriever jwtRetriever;
Jwks jwks;
String expectedAud;
String expectedTid;

// should be used read only after init
// Hashmap done on purpose as only here the read only thread safety is guaranteed
volatile HashMap<String, RsaJsonWebKey> kidToKey;

public JwtAuthInterceptor() {
name = "jwt checker.";
Expand All @@ -84,30 +78,6 @@ public void init() {
jwtRetriever = new HeaderJwtRetriever("Authorization","Bearer");

jwks.init(router);

kidToKey = jwks.getJwks().stream()
.map(jwk -> {
try {
Map params = mapper.readValue(jwk.getJwk(router, mapper), Map.class);
if (Objects.equals(params.get("p"), DEFAULT_PKEY)) {
log.warn("""
\n------------------------------------ DEFAULT JWK IN USE! ------------------------------------
This key is for demonstration purposes only and UNSAFE for production use. \s
---------------------------------------------------------------------------------------------""");
if (router.getConfiguration().isProduction()) {
throw new RuntimeException("Default JWK detected in production environment. Please use a secure key.");
}
}

return new RsaJsonWebKey(params);
} catch (Exception e) {
throw new RuntimeException(e);
}
})
.collect(HashMap::new, (m,e) -> m.put(e.getKeyId(),e), HashMap::putAll);

if (kidToKey.isEmpty())
throw new RuntimeException("No JWKs given or none resolvable - please specify at least one resolvable JWK");
}

@Override
Expand Down Expand Up @@ -156,13 +126,9 @@ public Outcome handleJwt(Exchange exc, String jwt) throws JWTException, JsonProc
var decodedJwt = new JsonWebToken(jwt);
var kid = decodedJwt.getHeader().kid();

if (!kidToKey.containsKey(kid)) {
throw new JWTException(ERROR_UNKNOWN_KEY, ERROR_UNKNOWN_KEY_ID);
}

// we could make it possible that every key is checked instead of having the "kid" field mandatory
// this would then need up to n checks per incoming JWT - could be a performance problem
RsaJsonWebKey key = kidToKey.get(kid);
RsaJsonWebKey key = jwks.getKeyByKid(kid).orElseThrow(() -> new JWTException(ERROR_UNKNOWN_KEY, ERROR_UNKNOWN_KEY_ID));

Map<String, Object> jwtClaims = createValidator(key).processToClaims(jwt).getClaimsMap();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public abstract class AuthorizationService {
private String clientId;
@GuardedBy("lock")
private String clientSecret;
private int jwksRefreshInterval = 24 * 60 * 60;
private JWSSigner JWSSigner;
protected String scope;
private SSLParser sslParser;
Expand Down Expand Up @@ -175,6 +176,16 @@ protected void setClientIdAndSecret(String clientId, String clientSecret) {
}
}

public Integer getJwksRefreshInterval() {
return jwksRefreshInterval;
}

@MCAttribute
public void setJwksRefreshInterval(int jwksRefreshInterval) {
this.jwksRefreshInterval = jwksRefreshInterval;
}


public String getScope() {
return scope;
}
Expand Down
Loading
Loading