Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Switch from org.apache.cxf.rs.security.jose to com.nimbusds.jose.jwk. #3595

Merged
merged 2 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ dependencies {
implementation 'commons-cli:commons-cli:1.5.0'
implementation "org.bouncycastle:bcprov-jdk15to18:${versions.bouncycastle}"
implementation 'org.ldaptive:ldaptive:1.2.3'
implementation 'com.nimbusds:nimbus-jose-jwt:9.31'

//JWT
implementation "io.jsonwebtoken:jjwt-api:${jjwt_version}"
Expand All @@ -587,9 +588,6 @@ dependencies {

runtimeOnly 'net.minidev:accessors-smart:2.4.7'

implementation("org.apache.cxf:cxf-rt-rs-security-jose:${apache_cxf_version}") {
exclude(group: 'jakarta.activation', module: 'jakarta.activation-api')
}
runtimeOnly "org.apache.cxf:cxf-core:${apache_cxf_version}"
implementation "org.apache.cxf:cxf-rt-rs-json-basic:${apache_cxf_version}"
runtimeOnly "org.apache.cxf:cxf-rt-security:${apache_cxf_version}"
Expand Down
1 change: 1 addition & 0 deletions checkstyle/checkstyle.xml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
</module>
<module name="IllegalImport"> <!-- defaults to sun.* packages -->
<property name="severity" value="error"/>
<property name="illegalPkgs" value="org.apache.cxf.rs.security.jose"/>
</module>
<module name="RedundantImport">
<property name="severity" value="error"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.text.ParseException;
import java.util.Collection;
import java.util.Map;
import java.util.Optional;
import java.util.Map.Entry;
import java.util.regex.Pattern;

import com.google.common.annotations.VisibleForTesting;
import org.apache.cxf.rs.security.jose.jwt.JwtClaims;
import org.apache.cxf.rs.security.jose.jwt.JwtToken;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import org.apache.http.HttpStatus;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand Down Expand Up @@ -112,37 +113,34 @@ private AuthCredentials extractCredentials0(final SecurityRequest request) throw
return null;
}

JwtToken jwt;
SignedJWT jwt;
JWTClaimsSet claimsSet;

try {
jwt = jwtVerifier.getVerifiedJwtToken(jwtString);
claimsSet = jwt.getJWTClaimsSet();
} catch (AuthenticatorUnavailableException e) {
log.info(e.toString());
throw new OpenSearchSecurityException(e.getMessage(), RestStatus.SERVICE_UNAVAILABLE);
} catch (BadCredentialsException e) {
} catch (BadCredentialsException | ParseException e) {
log.info("Extracting JWT token from {} failed", jwtString, e);
return null;
}

JwtClaims claims = jwt.getClaims();

final String subject = extractSubject(claims);

final String subject = extractSubject(claimsSet);
if (subject == null) {
log.error("No subject found in JWT token");
return null;
}

final String[] roles = extractRoles(claims);

final String[] roles = extractRoles(claimsSet);
final AuthCredentials ac = new AuthCredentials(subject, roles).markComplete();

for (Entry<String, Object> claim : claims.asMap().entrySet()) {
for (Entry<String, Object> claim : claimsSet.getClaims().entrySet()) {
ac.addAttribute("attr.jwt." + claim.getKey(), String.valueOf(claim.getValue()));
}

return ac;

}

protected String getJwtTokenString(SecurityRequest request) {
Expand Down Expand Up @@ -174,7 +172,7 @@ protected String getJwtTokenString(SecurityRequest request) {
}

@VisibleForTesting
public String extractSubject(JwtClaims claims) {
public String extractSubject(JWTClaimsSet claims) {
String subject = claims.getSubject();

if (subjectKey != null) {
Expand Down Expand Up @@ -204,7 +202,7 @@ public String extractSubject(JwtClaims claims) {

@SuppressWarnings("unchecked")
@VisibleForTesting
public String[] extractRoles(JwtClaims claims) {
public String[] extractRoles(JWTClaimsSet claims) {
if (rolesKey == null) {
return new String[0];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,24 @@
package com.amazon.dlic.auth.http.jwt.keybyoidc;

import com.google.common.base.Strings;
import com.nimbusds.jose.Algorithm;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSVerifier;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jose.crypto.factories.DefaultJWSVerifierFactory;
import com.nimbusds.jose.proc.SimpleSecurityContext;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.jwt.proc.BadJWTException;
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier;
import org.apache.commons.lang3.StringEscapeUtils;
import org.apache.cxf.rs.security.jose.jwa.SignatureAlgorithm;
import org.apache.cxf.rs.security.jose.jwk.JsonWebKey;
import org.apache.cxf.rs.security.jose.jwk.KeyType;
import org.apache.cxf.rs.security.jose.jwk.PublicKeyUse;
import org.apache.cxf.rs.security.jose.jws.JwsJwtCompactConsumer;
import org.apache.cxf.rs.security.jose.jws.JwsSignatureVerifier;
import org.apache.cxf.rs.security.jose.jws.JwsUtils;
import org.apache.cxf.rs.security.jose.jwt.JwtClaims;
import org.apache.cxf.rs.security.jose.jwt.JwtException;
import org.apache.cxf.rs.security.jose.jwt.JwtToken;
import org.apache.cxf.rs.security.jose.jwt.JwtUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.text.ParseException;
import java.util.Collections;

public class JwtVerifier {

private final static Logger log = LogManager.getLogger(JwtVerifier.class);
Expand All @@ -43,31 +46,24 @@
this.requiredAudience = requiredAudience;
}

public JwtToken getVerifiedJwtToken(String encodedJwt) throws BadCredentialsException {
public SignedJWT getVerifiedJwtToken(String encodedJwt) throws BadCredentialsException {
try {
JwsJwtCompactConsumer jwtConsumer = new JwsJwtCompactConsumer(encodedJwt);
JwtToken jwt = jwtConsumer.getJwtToken();
SignedJWT jwt = SignedJWT.parse(encodedJwt);

String escapedKid = jwt.getJwsHeaders().getKeyId();
String escapedKid = jwt.getHeader().getKeyID();
String kid = escapedKid;
if (!Strings.isNullOrEmpty(kid)) {
kid = StringEscapeUtils.unescapeJava(escapedKid);
}
JsonWebKey key = keyProvider.getKey(kid);

// Algorithm is not mandatory for the key material, so we set it to the same as the JWT
if (key.getAlgorithm() == null && key.getPublicKeyUse() == PublicKeyUse.SIGN && key.getKeyType() == KeyType.RSA) {
key.setAlgorithm(jwt.getJwsHeaders().getAlgorithm());
}

JwsSignatureVerifier signatureVerifier = getInitializedSignatureVerifier(key, jwt);
JWK key = keyProvider.getKey(kid);

boolean signatureValid = jwtConsumer.verifySignatureWith(signatureVerifier);
JWSVerifier signatureVerifier = getInitializedSignatureVerifier(key, jwt);
boolean signatureValid = jwt.verify(signatureVerifier);

if (!signatureValid && Strings.isNullOrEmpty(kid)) {
key = keyProvider.getKeyAfterRefresh(null);
signatureVerifier = getInitializedSignatureVerifier(key, jwt);
signatureValid = jwtConsumer.verifySignatureWith(signatureVerifier);
signatureValid = jwt.verify(signatureVerifier);

Check warning on line 66 in src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/JwtVerifier.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/JwtVerifier.java#L66

Added line #L66 was not covered by tests
}

if (!signatureValid) {
Expand All @@ -77,18 +73,18 @@
validateClaims(jwt);

return jwt;
} catch (JwtException e) {
} catch (JOSEException | ParseException | BadJWTException e) {
throw new BadCredentialsException(e.getMessage(), e);
}
}

private void validateSignatureAlgorithm(JsonWebKey key, JwtToken jwt) throws BadCredentialsException {
if (Strings.isNullOrEmpty(key.getAlgorithm())) {
private void validateSignatureAlgorithm(JWK key, SignedJWT jwt) throws BadCredentialsException {
if (key.getAlgorithm() == null || jwt.getHeader().getAlgorithm() == null) {
return;
}

SignatureAlgorithm keyAlgorithm = SignatureAlgorithm.getAlgorithm(key.getAlgorithm());
SignatureAlgorithm tokenAlgorithm = SignatureAlgorithm.getAlgorithm(jwt.getJwsHeaders().getAlgorithm());
Algorithm keyAlgorithm = key.getAlgorithm();
Algorithm tokenAlgorithm = jwt.getHeader().getAlgorithm();

if (!keyAlgorithm.equals(tokenAlgorithm)) {
throw new BadCredentialsException(
Expand All @@ -97,38 +93,48 @@
}
}

private JwsSignatureVerifier getInitializedSignatureVerifier(JsonWebKey key, JwtToken jwt) throws BadCredentialsException,
JwtException {
private JWSVerifier getInitializedSignatureVerifier(JWK key, SignedJWT jwt) throws BadCredentialsException, JOSEException {

validateSignatureAlgorithm(key, jwt);
JwsSignatureVerifier result = JwsUtils.getSignatureVerifier(key, jwt.getJwsHeaders().getSignatureAlgorithm());
final JWSVerifier result;
if (key.getClass() == OctetSequenceKey.class) {
result = new DefaultJWSVerifierFactory().createJWSVerifier(jwt.getHeader(), key.toOctetSequenceKey().toSecretKey());
} else {
result = new DefaultJWSVerifierFactory().createJWSVerifier(jwt.getHeader(), key.toRSAKey().toRSAPublicKey());
}

if (result == null) {
throw new BadCredentialsException("Cannot verify JWT");
} else {
return result;
}
}

private void validateClaims(JwtToken jwt) throws JwtException {
JwtClaims claims = jwt.getClaims();
private void validateClaims(SignedJWT jwt) throws ParseException, BadJWTException {
JWTClaimsSet claims = jwt.getJWTClaimsSet();

if (claims != null) {
JwtUtils.validateJwtExpiry(claims, clockSkewToleranceSeconds, false);
JwtUtils.validateJwtNotBefore(claims, clockSkewToleranceSeconds, false);
DefaultJWTClaimsVerifier<SimpleSecurityContext> claimsVerifier = new DefaultJWTClaimsVerifier<>(
requiredAudience,
null,
Collections.emptySet()
);
claimsVerifier.setMaxClockSkew(clockSkewToleranceSeconds);
claimsVerifier.verify(claims, null);
validateRequiredAudienceAndIssuer(claims);
}
}

private void validateRequiredAudienceAndIssuer(JwtClaims claims) {
String audience = claims.getAudience();
private void validateRequiredAudienceAndIssuer(JWTClaimsSet claims) throws BadJWTException {
String audience = claims.getAudience().stream().findFirst().orElse("");
String issuer = claims.getIssuer();

if (!Strings.isNullOrEmpty(requiredAudience) && !requiredAudience.equals(audience)) {
throw new JwtException("Invalid audience");
throw new BadJWTException("Invalid audience");

Check warning on line 133 in src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/JwtVerifier.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/JwtVerifier.java#L133

Added line #L133 was not covered by tests
}

if (!Strings.isNullOrEmpty(requiredIssuer) && !requiredIssuer.equals(issuer)) {
throw new JwtException("Invalid issuer");
throw new BadJWTException("Invalid issuer");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

package com.amazon.dlic.auth.http.jwt.keybyoidc;

import org.apache.cxf.rs.security.jose.jwk.JsonWebKey;
import com.nimbusds.jose.jwk.JWK;

public interface KeyProvider {
public JsonWebKey getKey(String kid) throws AuthenticatorUnavailableException, BadCredentialsException;
JWK getKey(String kid) throws AuthenticatorUnavailableException, BadCredentialsException;

public JsonWebKey getKeyAfterRefresh(String kid) throws AuthenticatorUnavailableException, BadCredentialsException;
JWK getKeyAfterRefresh(String kid) throws AuthenticatorUnavailableException, BadCredentialsException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@

package com.amazon.dlic.auth.http.jwt.keybyoidc;

import org.apache.cxf.rs.security.jose.jwk.JsonWebKeys;
import com.nimbusds.jose.jwk.JWKSet;

@FunctionalInterface
public interface KeySetProvider {
JsonWebKeys get() throws AuthenticatorUnavailableException;
JWKSet get() throws AuthenticatorUnavailableException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
package com.amazon.dlic.auth.http.jwt.keybyoidc;

import java.io.IOException;
import java.text.ParseException;

import com.nimbusds.jose.jwk.JWKSet;
import joptsimple.internal.Strings;
import org.apache.cxf.rs.security.jose.jwk.JsonWebKeys;
import org.apache.cxf.rs.security.jose.jwk.JwkUtils;
import org.apache.http.HttpEntity;
import org.apache.http.StatusLine;
import org.apache.http.client.cache.HttpCacheContext;
Expand Down Expand Up @@ -68,7 +68,7 @@
configureCache(useCacheForOidConnectEndpoint);
}

public JsonWebKeys get() throws AuthenticatorUnavailableException {
public JWKSet get() throws AuthenticatorUnavailableException {
String uri = getJwksUri();

try (CloseableHttpClient httpClient = createHttpClient(null)) {
Expand All @@ -95,10 +95,11 @@
if (httpEntity == null) {
throw new AuthenticatorUnavailableException("Error while getting " + uri + ": Empty response entity");
}

JsonWebKeys keySet = JwkUtils.readJwkSet(httpEntity.getContent());
JWKSet keySet = JWKSet.load(httpEntity.getContent());

return keySet;
} catch (ParseException e) {
throw new RuntimeException(e);

Check warning on line 102 in src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/KeySetRetriever.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/KeySetRetriever.java#L101-L102

Added lines #L101 - L102 were not covered by tests
}
} catch (IOException e) {
throw new AuthenticatorUnavailableException("Error while getting " + uri + ": " + e, e);
Expand Down
Loading
Loading