Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jans-auth-server): added reference_id to JWT tokens - both encoded into jwt and as separate attribute in persistence #8512 #8516

Merged
merged 7 commits into from
May 17, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ public abstract class AbstractAuthorizationGrant implements IAuthorizationGrant
private String codeChallengeMethod;
private String claims;
private String dpopJkt;
private String referenceId;

private String acrValues;
private String sessionDn;
Expand All @@ -98,6 +99,14 @@ protected void init(User user, AuthorizationGrantType authorizationGrantType, Cl
this.grantId = UUID.randomUUID().toString();
}

public String getReferenceId() {
return referenceId;
}

public void setReferenceId(String referenceId) {
this.referenceId = referenceId;
}

public String getDpopJkt() {
return dpopJkt;
}
Expand Down Expand Up @@ -340,6 +349,7 @@ public AccessToken createAccessToken(ExecutionContext executionContext) {

accessToken.setSessionDn(getSessionDn());
accessToken.setX5ts256(CertUtils.confirmationMethodHashS256(executionContext.getCertAsPem()));
accessToken.setReferenceId(executionContext.getTokenReferenceId());

final String dpop = executionContext.getDpop();
if (StringUtils.isNoneBlank(dpop)) {
Expand All @@ -352,6 +362,8 @@ public AccessToken createAccessToken(ExecutionContext executionContext) {

@Override
public RefreshToken createRefreshToken(ExecutionContext context) {
context.generateRandomTokenReferenceId();

int lifetime = appConfiguration.getRefreshTokenLifetime();
if (client.getRefreshTokenLifetime() != null && client.getRefreshTokenLifetime() > 0) {
lifetime = client.getRefreshTokenLifetime();
Expand All @@ -361,15 +373,20 @@ public RefreshToken createRefreshToken(ExecutionContext context) {

refreshToken.setSessionDn(getSessionDn());
refreshToken.setDpop(context.getDpop());
refreshToken.setReferenceId(context.getTokenReferenceId());

return refreshToken;
}

@Override
public RefreshToken createRefreshToken(ExecutionContext context, int lifetime) {
context.generateRandomTokenReferenceId();

RefreshToken refreshToken = new RefreshToken(lifetime);

refreshToken.setSessionDn(getSessionDn());
refreshToken.setDpop(context.getDpop());
refreshToken.setReferenceId(context.getTokenReferenceId());

return refreshToken;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import io.jans.as.model.crypto.signature.SignatureAlgorithm;
import io.jans.as.model.util.HashUtil;
import io.jans.as.model.util.Util;
import io.jans.as.server.model.token.HandleTokenFactory;
import io.jans.as.server.util.ServerUtil;
import io.jans.orm.annotation.AttributeName;
Expand Down Expand Up @@ -56,6 +55,9 @@ public abstract class AbstractToken implements Serializable, Deletable {
@AttributeName(name = "dpop")
private String dpop;

@AttributeName(name = "jansId")
private String referenceId;

@Expiration
private int ttl;

Expand Down Expand Up @@ -212,6 +214,24 @@ public synchronized void setRevoked(boolean revoked) {
this.revoked = revoked;
}

/**
* Gets reference id
*
* @return reference id
*/
public String getReferenceId() {
return referenceId;
}

/**
* Sets reference id
*
* @param referenceId reference id
*/
public void setReferenceId(String referenceId) {
this.referenceId = referenceId;
}

/**
* Return <code>true</code> if the token has expired.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ private IdToken createIdTokenInternal(AuthorizationCode authorizationCode, Acces
JsonWebResponse jwr = idTokenFactory.createJwr(this, authorizationCode, accessToken, refreshToken, executionContext);
final IdToken idToken = new IdToken(jwr.toString(), jwr.getClaims().getClaimAsDate(JwtClaimName.ISSUED_AT),
jwr.getClaims().getClaimAsDate(JwtClaimName.EXPIRATION_TIME));
idToken.setReferenceId(executionContext.getTokenReferenceId());
if (log.isTraceEnabled())
log.trace("Created id_token: {}", idToken.getCode());
return idToken;
Expand Down Expand Up @@ -202,6 +203,7 @@ private void initTokenFromGrant(TokenEntity token) {
public AccessToken createAccessToken(ExecutionContext context) {
try {
context.initFromGrantIfNeeded(this);
context.generateRandomTokenReferenceId();

final AccessToken accessToken = super.createAccessToken(context);
if (accessToken.getExpiresIn() < 0) {
Expand Down Expand Up @@ -282,6 +284,7 @@ public JwtSigner createAccessTokenAsJwt(AccessToken accessToken, ExecutionContex
jwt.getClaims().setIssuedAt(accessToken.getCreationDate());
jwt.getClaims().setSubjectIdentifier(getSub());
jwt.getClaims().setClaim("x5t#S256", accessToken.getX5ts256());
jwt.getClaims().setClaim("jti", context.getTokenReferenceId());

final AuthzDetails authzDetails = getAuthzDetails();
if (!AuthzDetails.isEmpty(authzDetails)) {
Expand Down Expand Up @@ -401,6 +404,7 @@ public IdToken createIdToken(
executionContext.setClaimsAsString(getClaims());
executionContext.setNonce(nonce);
executionContext.setState(state);
executionContext.generateRandomTokenReferenceId();

final IdToken idToken = createIdTokenInternal(authorizationCode, accessToken, refreshToken, executionContext);
final AuthorizationGrant grant = executionContext.getGrant();
Expand Down Expand Up @@ -488,6 +492,7 @@ public TokenEntity asTokenEntity(AbstractToken token) {
result.setUserId(getUserId());
result.setUserDn(getUserDn());
result.setClientId(getClientId());
result.setReferenceId(token.getReferenceId());

result.getAttributes().setX5cs256(token.getX5ts256());
result.getAttributes().setDpopJkt(getDpopJkt());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,18 @@ public AuthorizationGrant getAuthorizationGrantByIdToken(String idToken) {
return null;
}

@Override
public AuthorizationGrant getAuthorizationGrantByReferenceId(String referenceId) {
if (StringUtils.isBlank(referenceId)) {
return null;
}
final TokenEntity tokenEntity = grantService.getGrantByReferenceId(referenceId);
if (tokenEntity != null) {
return asGrant(tokenEntity);
}
return null;
}

public AuthorizationGrant asGrant(TokenEntity tokenEntity) {
if (tokenEntity != null) {
final AuthorizationGrantType grantType = AuthorizationGrantType.fromString(tokenEntity.getGrantType());
Expand Down Expand Up @@ -353,6 +365,7 @@ public AuthorizationGrant asGrant(TokenEntity tokenEntity) {
result.setX5ts256(tokenEntity.getAttributes().getX5cs256());
result.setDpopJkt(tokenEntity.getAttributes().getDpopJkt());
result.setTokenEntity(tokenEntity);
result.setReferenceId(tokenEntity.getReferenceId());
if (StringUtils.isNotBlank(grantId)) {
result.setGrantId(grantId);
}
Expand Down Expand Up @@ -381,34 +394,40 @@ public AuthorizationGrant asGrant(TokenEntity tokenEntity) {
final AuthorizationCode code = new AuthorizationCode(tokenEntity.getTokenCode(), tokenEntity.getCreationDate(), tokenEntity.getExpirationDate());
final AuthorizationCodeGrant g = (AuthorizationCodeGrant) result;
code.setX5ts256(g.getX5ts256());
code.setReferenceId(tokenEntity.getReferenceId());
g.setAuthorizationCode(code);
}
break;
case REFRESH_TOKEN:
final RefreshToken refreshToken = new RefreshToken(tokenEntity.getTokenCode(), tokenEntity.getCreationDate(), tokenEntity.getExpirationDate());
refreshToken.setX5ts256(result.getX5ts256());
refreshToken.setReferenceId(tokenEntity.getReferenceId());
result.setRefreshTokens(Collections.singletonList(refreshToken));
break;
case ACCESS_TOKEN:
final AccessToken accessToken = new AccessToken(tokenEntity.getTokenCode(), tokenEntity.getCreationDate(), tokenEntity.getExpirationDate());
accessToken.setDpop(tokenEntity.getDpop());
accessToken.setX5ts256(result.getX5ts256());
accessToken.setReferenceId(tokenEntity.getReferenceId());
result.setAccessTokens(Collections.singletonList(accessToken));
break;
case TX_TOKEN:
final TxToken txToken = new TxToken(tokenEntity.getTokenCode(), tokenEntity.getCreationDate(), tokenEntity.getExpirationDate());
txToken.setDpop(tokenEntity.getDpop());
txToken.setX5ts256(result.getX5ts256());
txToken.setReferenceId(tokenEntity.getReferenceId());
result.setTxTokens(Collections.singletonList(txToken));
break;
case ID_TOKEN:
final IdToken idToken = new IdToken(tokenEntity.getTokenCode(), tokenEntity.getCreationDate(), tokenEntity.getExpirationDate());
idToken.setX5ts256(result.getX5ts256());
idToken.setReferenceId(tokenEntity.getReferenceId());
result.setIdToken(idToken);
break;
case LONG_LIVED_ACCESS_TOKEN:
final AccessToken longLivedAccessToken = new AccessToken(tokenEntity.getTokenCode(), tokenEntity.getCreationDate(), tokenEntity.getExpirationDate());
longLivedAccessToken.setX5ts256(result.getX5ts256());
longLivedAccessToken.setReferenceId(tokenEntity.getReferenceId());
result.setLongLivedAccessToken(longLivedAccessToken);
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.jans.as.server.model.audit.OAuth2AuditLog;
import io.jans.model.custom.script.conf.CustomScriptConfiguration;
import io.jans.model.token.TokenEntity;
import io.jans.util.IdUtil;
import jakarta.faces.context.ExternalContext;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
Expand Down Expand Up @@ -65,6 +66,7 @@ public class ExecutionContext {

private String nonce;
private String state;
private String tokenReferenceId = IdUtil.randomShortUUID();

private boolean includeIdTokenClaims;

Expand Down Expand Up @@ -158,6 +160,19 @@ public static ExecutionContext of(ExecutionContext context) {
return executionContext;
}

public String generateRandomTokenReferenceId() {
tokenReferenceId = IdUtil.randomShortUUID();
return tokenReferenceId;
}

public String getTokenReferenceId() {
return tokenReferenceId;
}

public void setTokenReferenceId(String tokenReferenceId) {
this.tokenReferenceId = tokenReferenceId;
}

public ExecutionContext copy() {
return of(this);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ public interface IAuthorizationGrantList {

AuthorizationGrant getAuthorizationGrantByIdToken(String idToken);

AuthorizationGrant getAuthorizationGrantByReferenceId(String idToken);

CIBAGrant getCIBAGrant(String authReqId);

DeviceCodeGrant createDeviceGrant(DeviceAuthorizationCacheControl data, User user);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ private void fillClaims(JsonWebResponse jwr,

jwr.getClaims().setExpirationTime(expiration);
jwr.getClaims().setIssuedAt(issuedAt);
jwr.setClaim("random", UUID.randomUUID().toString()); // provided uniqueness of id_token for same RP requests, oxauth: 1493
jwr.setClaim("jti", executionContext.getTokenReferenceId()); // provided uniqueness of id_token for same RP requests, oxauth: 1493

if (executionContext.getPreProcessing() != null) {
executionContext.getPreProcessing().apply(jwr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,22 @@ public TokenEntity getGrantByCode(String code) {
}
}

public TokenEntity getGrantByReferenceId(String referenceId) {
try {
final List<TokenEntity> grants = persistenceEntryManager.findEntries(tokenBaseDn(), TokenEntity.class, Filter.createEqualityFilter("jansId", referenceId));
if (grants.size() > 1) {
log.error("Found more then one tokens by referenceId {}", referenceId);
return null;
}
if (grants.size() == 1) {
return grants.get(0);
}
} catch (Exception e) {
logException(e);
}
return null;
}

private void logException(Exception e) {
if (isTrue(appConfiguration.getLogNotFoundEntityAsError())) {
log.error(e.getMessage(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,11 @@

package io.jans.model.token;

import io.jans.orm.annotation.*;

import java.io.Serializable;
import java.util.Date;

import io.jans.orm.annotation.AttributeName;
import io.jans.orm.annotation.DN;
import io.jans.orm.annotation.DataEntry;
import io.jans.orm.annotation.Expiration;
import io.jans.orm.annotation.JsonObject;
import io.jans.orm.annotation.ObjectClass;

/**
* @author Yuriy Zabrovarnyy
* @author Javier Rojas Blum
Expand Down Expand Up @@ -68,6 +63,8 @@ public class TokenEntity implements Serializable {
private String claims;
@AttributeName(name = "tknBndCnf")
private String tokenBindingHash;
@AttributeName(name = "jansId")
private String referenceId;

@AttributeName(name = "acr")
private String authMode;
Expand All @@ -84,6 +81,14 @@ public class TokenEntity implements Serializable {
@AttributeName(name = "dpop")
private String dpop;

public String getReferenceId() {
return referenceId;
}

public void setReferenceId(String referenceId) {
this.referenceId = referenceId;
}

public TokenAttributes getAttributes() {
if (attributes == null) {
attributes = new TokenAttributes();
Expand Down
34 changes: 34 additions & 0 deletions jans-core/util/src/main/java/io/jans/util/IdUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package io.jans.util;

import java.nio.ByteBuffer;
import java.util.Base64;
import java.util.Random;
import java.util.UUID;

/**
* @author Yuriy Z
*/
public class IdUtil {

// we are ok to have not secured random
private static final Random RANDOM = new Random();

private IdUtil() {
}

public static String randomShortUUID() {
return toShortUUID(UUID.randomUUID());
}

public static String toShortUUID(UUID uuid) {
ByteBuffer byteBuffer = ByteBuffer.allocate(16);
byteBuffer.putLong(uuid.getMostSignificantBits());
byteBuffer.putLong(uuid.getLeastSignificantBits());
return Base64.getEncoder().withoutPadding().encodeToString(byteBuffer.array())
.replace("/", randomChar()).replace("+", randomChar());
}

private static String randomChar() {
return (char) (RANDOM.nextInt(26) + 'a') + "";
}
}
24 changes: 24 additions & 0 deletions jans-core/util/src/test/java/io/jans/util/IdUtilTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package io.jans.util;

import org.testng.annotations.Test;

import static org.testng.AssertJUnit.assertEquals;

/**
* @author Yuriy Z
*/
public class IdUtilTest {

@Test
public void shortUuid_lenthMustBe22() {
assertEquals(22, IdUtil.randomShortUUID().length());
}

@Test(enabled = false)
public void shortUuid_generateALotIdsAndPrintThem() {
for (int i = 0; i < 100000; i++) {
final String shortUUID = IdUtil.randomShortUUID();
System.out.println(shortUUID + " length: " + shortUUID.length());
};
}
}
Loading