Skip to content

Commit

Permalink
Merge pull request #811 from AzureAD/avdunn/claims-refresh-fix
Browse files Browse the repository at this point in the history
Refresh tokens when request contains claims
  • Loading branch information
Avery-Dunn committed Aug 23, 2024
2 parents 07d67b7 + b581622 commit 260656e
Show file tree
Hide file tree
Showing 7 changed files with 270 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ void acquireTokenSilent_ConfidentialClient_acquireTokenSilent(String environment
}

@Test
public void acquireTokenSilent_ConfidentialClient_acquireTokenSilentDifferentScopeThrowsException()
void acquireTokenSilent_ConfidentialClient_acquireTokenSilentDifferentScopeThrowsException()
throws Exception {
cfg = new Config(AzureEnvironment.AZURE);

Expand Down Expand Up @@ -344,6 +344,48 @@ void acquireTokenSilent_emptyScopeSet(String environment) throws Exception {
assertEquals(result.accessToken(), silentResult.accessToken());
}

@Test
public void acquireTokenSilent_ClaimsForceRefresh() throws Exception {
cfg = new Config(AzureEnvironment.AZURE);
User user = labUserProvider.getDefaultUser(AzureEnvironment.AZURE);

Set<String> scopes = new HashSet<>();
PublicClientApplication pca = PublicClientApplication.builder(
user.getAppId()).
authority(cfg.organizationsAuthority()).
build();

IAuthenticationResult result = pca.acquireToken(UserNamePasswordParameters.
builder(scopes,
user.getUpn(),
user.getPassword().toCharArray())
.build())
.get();

assertResultNotNull(result);

IAuthenticationResult silentResultWithoutClaims = pca.acquireTokenSilently(SilentParameters.
builder(scopes, result.account())
.build())
.get();

assertResultNotNull(silentResultWithoutClaims);
assertEquals(result.accessToken(), silentResultWithoutClaims.accessToken());

//If claims are added to a silent request, it should trigger the refresh flow and return a new token
ClaimsRequest cr = new ClaimsRequest();
cr.requestClaimInAccessToken("email", null);

IAuthenticationResult silentResultWithClaims = pca.acquireTokenSilently(SilentParameters.
builder(scopes, result.account())
.claims(cr)
.build())
.get();

assertResultNotNull(silentResultWithClaims);
assertNotEquals(result.accessToken(), silentResultWithClaims.accessToken());
}

private IConfidentialClientApplication getConfidentialClientApplications() throws Exception {
String clientId = cfg.appProvider.getOboAppId();
String password = cfg.appProvider.getOboAppPassword();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
class AcquireTokenSilentSupplier extends AuthenticationResultSupplier {

private SilentRequest silentRequest;
protected static final int ACCESS_TOKEN_EXPIRE_BUFFER_IN_SEC = 5 * 60;

AcquireTokenSilentSupplier(AbstractApplicationBase clientApplication, SilentRequest silentRequest) {
super(clientApplication, silentRequest);
Expand All @@ -22,6 +23,7 @@ class AcquireTokenSilentSupplier extends AuthenticationResultSupplier {

@Override
AuthenticationResult execute() throws Exception {
boolean shouldRefresh;
Authority requestAuthority = silentRequest.requestAuthority();
if (requestAuthority.authorityType != AuthorityType.B2C) {
requestAuthority =
Expand Down Expand Up @@ -53,29 +55,9 @@ AuthenticationResult execute() throws Exception {
clientApplication.serviceBundle().getServerSideTelemetry().incrementSilentSuccessfulCount();
}

//Determine if the current token needs to be refreshed according to the refresh_in value
long currTimeStampSec = new Date().getTime() / 1000;
boolean afterRefreshOn = res.refreshOn() != null && res.refreshOn() > 0 &&
res.refreshOn() < currTimeStampSec && res.expiresOn() >= currTimeStampSec;

if (silentRequest.parameters().forceRefresh() || afterRefreshOn || StringHelper.isBlank(res.accessToken())) {

//As of version 3 of the telemetry schema, there is a field for collecting data about why a token was refreshed,
// so here we set the telemetry value based on the cause of the refresh
if (silentRequest.parameters().forceRefresh()) {
clientApplication.serviceBundle().getServerSideTelemetry().getCurrentRequest().cacheInfo(
CacheTelemetry.REFRESH_FORCE_REFRESH.telemetryValue);
} else if (afterRefreshOn) {
clientApplication.serviceBundle().getServerSideTelemetry().getCurrentRequest().cacheInfo(
CacheTelemetry.REFRESH_REFRESH_IN.telemetryValue);
} else if (res.expiresOn() < currTimeStampSec) {
clientApplication.serviceBundle().getServerSideTelemetry().getCurrentRequest().cacheInfo(
CacheTelemetry.REFRESH_ACCESS_TOKEN_EXPIRED.telemetryValue);
} else if (StringHelper.isBlank(res.accessToken())) {
clientApplication.serviceBundle().getServerSideTelemetry().getCurrentRequest().cacheInfo(
CacheTelemetry.REFRESH_NO_ACCESS_TOKEN.telemetryValue);
}
shouldRefresh = shouldRefresh(silentRequest.parameters(), res);

if (shouldRefresh || clientApplication.serviceBundle().getServerSideTelemetry().getCurrentRequest().cacheInfo() == CacheTelemetry.REFRESH_REFRESH_IN.telemetryValue) {
if (!StringHelper.isBlank(res.refreshToken())) {
//There are certain scenarios where the cached authority may differ from the client app's authority,
// such as when a request is instance aware. Unless overridden by SilentParameters.authorityUrl, the
Expand All @@ -84,29 +66,7 @@ AuthenticationResult execute() throws Exception {
requestAuthority = Authority.createAuthority(new URL(requestAuthority.authority().replace(requestAuthority.host(),
res.account().environment())));
}

RefreshTokenRequest refreshTokenRequest = new RefreshTokenRequest(
RefreshTokenParameters.builder(silentRequest.parameters().scopes(), res.refreshToken()).build(),
silentRequest.application(),
silentRequest.requestContext(),
silentRequest);

AcquireTokenByAuthorizationGrantSupplier acquireTokenByAuthorisationGrantSupplier =
new AcquireTokenByAuthorizationGrantSupplier(clientApplication, refreshTokenRequest, requestAuthority);

try {
res = acquireTokenByAuthorisationGrantSupplier.execute();

res.metadata().tokenSource(TokenSource.IDENTITY_PROVIDER);

log.info("Access token refreshed successfully.");
} catch (MsalServiceException ex) {
//If the token refresh attempt threw a MsalServiceException but the refresh attempt was done
// only because of refreshOn, then simply return the existing cached token
if (afterRefreshOn && !(silentRequest.parameters().forceRefresh() || StringHelper.isBlank(res.accessToken()))) {
return res;
} else throw ex;
}
res = makeRefreshRequest(res, requestAuthority);
} else {
res = null;
}
Expand All @@ -120,4 +80,81 @@ AuthenticationResult execute() throws Exception {

return res;
}

private AuthenticationResult makeRefreshRequest(AuthenticationResult cachedResult, Authority requestAuthority) throws Exception {
RefreshTokenRequest refreshTokenRequest = new RefreshTokenRequest(
RefreshTokenParameters.builder(silentRequest.parameters().scopes(), cachedResult.refreshToken()).build(),
silentRequest.application(),
silentRequest.requestContext(),
silentRequest);

AcquireTokenByAuthorizationGrantSupplier acquireTokenByAuthorisationGrantSupplier =
new AcquireTokenByAuthorizationGrantSupplier(clientApplication, refreshTokenRequest, requestAuthority);

try {
AuthenticationResult refreshedResult = acquireTokenByAuthorisationGrantSupplier.execute();

refreshedResult.metadata().tokenSource(TokenSource.IDENTITY_PROVIDER);

log.info("Access token refreshed successfully.");
return refreshedResult;
} catch (MsalServiceException ex) {
//If the token refresh attempt threw a MsalServiceException but the refresh attempt was done
// only because of refreshOn, then simply return the existing cached token rather than throw an exception
if (clientApplication.serviceBundle().getServerSideTelemetry().getCurrentRequest().cacheInfo() == CacheTelemetry.REFRESH_REFRESH_IN.telemetryValue) {
return cachedResult;
}
throw ex;
}
}

//Handles any logic to determine if a token should be refreshed, based on the request parameters and the status of cached tokens
private boolean shouldRefresh(SilentParameters parameters, AuthenticationResult cachedResult) {

//If forceRefresh is true, no reason to check any other option
if (parameters.forceRefresh()) {
setCacheTelemetry(CacheTelemetry.REFRESH_FORCE_REFRESH.telemetryValue);
log.debug("Refreshing access token because forceRefresh parameter is true.");
return true;
}

//If the request contains claims then the token should be refreshed, to ensure that the returned token has the correct claims
// Note: these are the types of claims found in (for example) a claims challenge, and do not include client capabilities
if (parameters.claims() != null) {
setCacheTelemetry(CacheTelemetry.REFRESH_FORCE_REFRESH.telemetryValue);
log.debug("Refreshing access token because the claims parameter is not null.");
return true;
}

long currTimeStampSec = new Date().getTime() / 1000;

//If the access token is expired or within 5 minutes of becoming expired, refresh it
if (!StringHelper.isBlank(cachedResult.accessToken()) && cachedResult.expiresOn() < (currTimeStampSec - ACCESS_TOKEN_EXPIRE_BUFFER_IN_SEC)) {
setCacheTelemetry(CacheTelemetry.REFRESH_ACCESS_TOKEN_EXPIRED.telemetryValue);
log.debug("Refreshing access token because it is expired.");
return true;
}

//Certain long-lived tokens will have a 'refresh on' time that indicates a refresh should be attempted long before the token would expire
if (!StringHelper.isBlank(cachedResult.accessToken()) &&
cachedResult.refreshOn() != null && cachedResult.refreshOn() > 0 &&
cachedResult.refreshOn() < currTimeStampSec && cachedResult.expiresOn() >= (currTimeStampSec + ACCESS_TOKEN_EXPIRE_BUFFER_IN_SEC)){
setCacheTelemetry(CacheTelemetry.REFRESH_REFRESH_IN.telemetryValue);
log.debug("Attempting to refresh access token because it is after the refreshOn time.");
return true;
}

//If there is a refresh token but no access token, we should use the refresh token to get the access token
if (StringHelper.isBlank(cachedResult.accessToken()) && !StringHelper.isBlank(cachedResult.refreshToken())) {
setCacheTelemetry(CacheTelemetry.REFRESH_NO_ACCESS_TOKEN.telemetryValue);
log.debug("Refreshing access token because it was missing from the cache.");
return true;
}

return false;
}

private void setCacheTelemetry(int cacheInfoValue){
clientApplication.serviceBundle().getServerSideTelemetry().getCurrentRequest().cacheInfo(cacheInfoValue);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -474,13 +474,11 @@ private Optional<AccessTokenCacheEntity> getAccessTokenCacheEntity(
Set<String> scopes,
String clientId,
Set<String> environmentAliases) {
long currTimeStampSec = new Date().getTime() / 1000;

return accessTokens.values().stream().filter(
accessToken ->
accessToken.homeAccountId.equals(account.homeAccountId()) &&
environmentAliases.contains(accessToken.environment) &&
Long.parseLong(accessToken.expiresOn()) > currTimeStampSec + MIN_ACCESS_TOKEN_EXPIRE_IN_SEC &&
accessToken.realm.equals(authority.tenant()) &&
accessToken.clientId.equals(clientId) &&
isMatchingScopes(accessToken, scopes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,17 @@

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.assertThrows;

import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;

import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Collections;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
Expand All @@ -16,6 +24,9 @@
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class AcquireTokenSilentlyTest {

Account basicAccount = new Account("home_account_id", "login.windows.net", "username", null);
String cache = readResource("/AAD_cache_data/full_cache.json");

@Test
void publicAppAcquireTokenSilently_emptyCache_MsalClientException() throws Throwable {

Expand All @@ -29,7 +40,7 @@ void publicAppAcquireTokenSilently_emptyCache_MsalClientException() throws Throw

ExecutionException ex = assertThrows(ExecutionException.class, future::get);

assertTrue(ex.getCause() instanceof MsalClientException);
assertInstanceOf(MsalClientException.class, ex.getCause());
assertTrue(ex.getMessage().contains(AuthenticationErrorMessage.NO_TOKEN_IN_CACHE));
}

Expand All @@ -45,7 +56,71 @@ void confidentialAppAcquireTokenSilently_emptyCache_MsalClientException() throws

ExecutionException ex = assertThrows(ExecutionException.class, future::get);

assertTrue(ex.getCause() instanceof MsalClientException);
assertInstanceOf(MsalClientException.class, ex.getCause());
assertTrue(ex.getMessage().contains(AuthenticationErrorMessage.NO_TOKEN_IN_CACHE));
}

@Test
void publicAppAcquireTokenSilently_claimsSkipCache() throws Throwable {

PublicClientApplication application = PublicClientApplication.builder("client_id")
.instanceDiscovery(false)
.authority("https://some.authority.com/realm")
.build();

application.tokenCache.deserialize(cache);

SilentParameters parameters = SilentParameters.builder(Collections.singleton("scopes"), basicAccount).build();

IAuthenticationResult result = application.acquireTokenSilently(parameters).get();

//Confirm cached dummy token returned from silent request
assertNotNull(result);
assertEquals("token", result.accessToken());

ClaimsRequest cr = new ClaimsRequest();
cr.requestClaimInAccessToken("something", null);

parameters = SilentParameters.builder(Collections.singleton("scopes"), basicAccount).claims(cr).build();
CompletableFuture<IAuthenticationResult> future = application.acquireTokenSilently(parameters);

//Confirm cached dummy token ignored when claims are part of request
ExecutionException ex = assertThrows(ExecutionException.class, future::get);
assertInstanceOf(MsalInteractionRequiredException.class, ex.getCause());
}

@Test
void confidentialAppAcquireTokenSilently_claimsSkipCache() throws Throwable {

ConfidentialClientApplication application = ConfidentialClientApplication
.builder("client_id", ClientCredentialFactory.createFromSecret(TestConfiguration.AAD_CLIENT_DUMMYSECRET))
.instanceDiscovery(false)
.authority("https://some.authority.com/realm").build();

application.tokenCache.deserialize(cache);

SilentParameters parameters = SilentParameters.builder(Collections.singleton("scopes"), basicAccount).build();

IAuthenticationResult result = application.acquireTokenSilently(parameters).get();

assertNotNull(result);
assertEquals("token", result.accessToken());

ClaimsRequest cr = new ClaimsRequest();
cr.requestClaimInAccessToken("something", null);

parameters = SilentParameters.builder(Collections.singleton("scopes"), basicAccount).claims(cr).build();
CompletableFuture<IAuthenticationResult> future = application.acquireTokenSilently(parameters);

ExecutionException ex = assertThrows(ExecutionException.class, future::get);
assertInstanceOf(MsalInteractionRequiredException.class, ex.getCause());
}

String readResource(String resource) {
try {
return new String(Files.readAllBytes(Paths.get(getClass().getResource(resource).toURI())));
} catch (IOException | URISyntaxException e) {
throw new RuntimeException(e);
}
}
}
Loading

0 comments on commit 260656e

Please sign in to comment.