Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support token credential cache for azure-identity-extension #43659

Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.azure.identity.extensions.implementation.cache;

public interface IdentityCache<K, V> {
moarychan marked this conversation as resolved.
Show resolved Hide resolved

void put(K key, V value);

V get(K key);

void remove(K key);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.identity.extensions.implementation.cache;

import com.azure.core.credential.AccessToken;
import com.azure.core.credential.TokenCredential;
import com.azure.identity.extensions.implementation.utils.ClassUtil;

import static com.azure.identity.extensions.implementation.utils.ClassUtil.instantiateClass;

public final class IdentityCacheHelper {

private static Class<? extends IdentityCache<String, TokenCredential>> DEFAULT_TOKEN_CREDENTIAL_CACHE_CLASS = InMemoryTokenCredentialCache.class;
private static Class<? extends IdentityCache<String, AccessToken>> DEFAULT_ACCESS_TOKEN_CACHE_CLASS = InMemoryAccessTokenCache.class;

private IdentityCacheHelper() {

}

public static IdentityCache<String, TokenCredential> createTokenCredentialCacheInstance() {
return createTokenCredentialCacheInstance(null);
}

public static IdentityCache<String, TokenCredential> createTokenCredentialCacheInstance(String cacheClassName) {
Class<? extends IdentityCache<String, TokenCredential>> clazz = ClassUtil.getClass(cacheClassName, IdentityCache.class);
if (clazz == null) {
clazz = DEFAULT_TOKEN_CREDENTIAL_CACHE_CLASS;
}

return instantiateClass(clazz);
}

public static IdentityCache<String, AccessToken> createAccessTokenCacheInstance() {
return createAccessTokenCacheInstance(null);
}

public static IdentityCache<String, AccessToken> createAccessTokenCacheInstance(String cacheClassName) {
Class<? extends IdentityCache<String, AccessToken>> clazz = ClassUtil.getClass(cacheClassName, IdentityCache.class);
if (clazz == null) {
clazz = DEFAULT_ACCESS_TOKEN_CACHE_CLASS;
}

return instantiateClass(clazz);
}

public static void setDefaultTokenCredentialCacheClass(Class<? extends IdentityCache<String, TokenCredential>> clazz) {
DEFAULT_TOKEN_CREDENTIAL_CACHE_CLASS = clazz;
}

public static void setDefaultAccessTokenCacheClass(Class<? extends IdentityCache<String, AccessToken>> clazz) {
DEFAULT_ACCESS_TOKEN_CACHE_CLASS = clazz;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.azure.identity.extensions.implementation.cache;

import com.azure.core.credential.AccessToken;

import java.util.concurrent.ConcurrentHashMap;

public class InMemoryAccessTokenCache implements IdentityCache<String, AccessToken> {

private static final ConcurrentHashMap<String, AccessToken> CACHE = new ConcurrentHashMap<>();

@Override
public synchronized void put(String key, AccessToken value) {
CACHE.putIfAbsent(key, value);
}

@Override
public AccessToken get(String key) {
return CACHE.get(key);
}

@Override
public synchronized void remove(String key) {
CACHE.remove(key);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.azure.identity.extensions.implementation.cache;

import com.azure.core.credential.TokenCredential;

import java.util.concurrent.ConcurrentHashMap;

public class InMemoryTokenCredentialCache implements IdentityCache<String, TokenCredential> {

private static final ConcurrentHashMap<String, TokenCredential> CACHE = new ConcurrentHashMap<>();

@Override
public synchronized void put(String key, TokenCredential value) {
CACHE.putIfAbsent(key, value);
}

@Override
public TokenCredential get(String key) {
return CACHE.get(key);
}

@Override
public synchronized void remove(String key) {
CACHE.remove(key);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ public class TokenCredentialProviderOptions {
private String tokenCredentialProviderClassName;
private String tokenCredentialBeanName;
private String accessTokenTimeoutInSeconds;
private String tokenCredentialCacheClassName;
private String accessTokenCacheClassName;

public TokenCredentialProviderOptions() {

Expand All @@ -42,6 +44,8 @@ public TokenCredentialProviderOptions(Properties properties) {
this.managedIdentityEnabled = Boolean.TRUE.equals(AuthProperty.MANAGED_IDENTITY_ENABLED.getBoolean(properties));
this.tokenCredentialProviderClassName = AuthProperty.TOKEN_CREDENTIAL_PROVIDER_CLASS_NAME.get(properties);
this.tokenCredentialBeanName = AuthProperty.TOKEN_CREDENTIAL_BEAN_NAME.get(properties);
this.tokenCredentialCacheClassName = AuthProperty.TOKEN_CREDENTIAL_CACHE_CLASS_NAME.get(properties);
moarychan marked this conversation as resolved.
Show resolved Hide resolved
this.accessTokenCacheClassName = AuthProperty.ACCESS_TOKEN_CACHE_CLASS_NAME.get(properties);
this.accessTokenTimeoutInSeconds = AuthProperty.GET_TOKEN_TIMEOUT.get(properties);
this.authorityHost = AuthProperty.AUTHORITY_HOST.get(properties);
}
Expand Down Expand Up @@ -141,4 +145,20 @@ public String getAccessTokenTimeoutInSeconds() {
public void setAccessTokenTimeoutInSeconds(String accessTokenTimeoutInSeconds) {
this.accessTokenTimeoutInSeconds = accessTokenTimeoutInSeconds;
}

public String getTokenCredentialCacheClassName() {
return tokenCredentialCacheClassName;
}

public void setTokenCredentialCacheClassName(String tokenCredentialCacheClassName) {
this.tokenCredentialCacheClassName = tokenCredentialCacheClassName;
}

public String getAccessTokenCacheClassName() {
return accessTokenCacheClassName;
}

public void setAccessTokenCacheClassName(String accessTokenCacheClassName) {
this.accessTokenCacheClassName = accessTokenCacheClassName;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.identity.extensions.implementation.credential.provider;

import com.azure.core.credential.TokenCredential;
import com.azure.core.util.logging.ClientLogger;
import com.azure.identity.extensions.implementation.cache.IdentityCache;
import com.azure.identity.extensions.implementation.cache.IdentityCacheHelper;
import com.azure.identity.extensions.implementation.credential.TokenCredentialProviderOptions;

import static com.azure.identity.extensions.implementation.utils.StringUtils.getTokenCredentialCacheKey;

/**
* Default cache tokenCredentialProvider implementation that provides tokenCredential instance.
*/
public class DefaultCacheTokenCredentialProvider implements TokenCredentialProvider {

private static final ClientLogger LOGGER = new ClientLogger(DefaultCacheTokenCredentialProvider.class);

private final TokenCredentialProviderOptions options;

private final IdentityCache<String, TokenCredential> tokenCredentialCache;

private final DefaultTokenCredentialProvider defaultProvider;

DefaultCacheTokenCredentialProvider() {
this(new TokenCredentialProviderOptions());
}

DefaultCacheTokenCredentialProvider(TokenCredentialProviderOptions options) {
this(options, null);
}

DefaultCacheTokenCredentialProvider(TokenCredentialProviderOptions options, IdentityCache<String, TokenCredential> tokenCredentialCache) {
this.options = options;
if (tokenCredentialCache == null) {
this.tokenCredentialCache = IdentityCacheHelper.createTokenCredentialCacheInstance(options.getTokenCredentialCacheClassName());
} else {
this.tokenCredentialCache = tokenCredentialCache;
}
this.defaultProvider = new DefaultTokenCredentialProvider(this.options);
}

@Override
public TokenCredential get() {
String tokenCredentialCacheKey = getTokenCredentialCacheKey(options);
TokenCredential cachedTokenCredential = tokenCredentialCache.get(tokenCredentialCacheKey);
if (cachedTokenCredential != null) {
LOGGER.verbose("Returning token credential from cache.");
return cachedTokenCredential;
}

TokenCredential tokenCredential = defaultProvider.get();
tokenCredentialCache.put(tokenCredentialCacheKey, tokenCredential);
LOGGER.verbose("The token credential cached.");
return tokenCredential;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,14 @@
*/
public class DefaultTokenCredentialProvider implements TokenCredentialProvider {

private final TokenCredentialProviderOptions options;

private final TokenCredential tokenCredential;

DefaultTokenCredentialProvider() {
this.options = new TokenCredentialProviderOptions();
this.tokenCredential = get(this.options);
this(new TokenCredentialProviderOptions());
}

DefaultTokenCredentialProvider(TokenCredentialProviderOptions options) {
this.options = options;
this.tokenCredential = get(this.options);
this.tokenCredential = get(options);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,13 @@ public enum AuthProperty {
* The given bean name of a TokenCredential bean in the Spring context.
*/
TOKEN_CREDENTIAL_BEAN_NAME("azure.tokenCredentialBeanName", "springCloudAzureDefaultCredential",
"The given bean name of a TokenCredential bean in the Spring context.", false);
"The given bean name of a TokenCredential bean in the Spring context.", false),
TOKEN_CREDENTIAL_CACHE_CLASS_NAME("azure.tokenCredentialCacheClassName",
"The given class name of a TokenCredential cache.", false),
ACCESS_TOKEN_CACHE_ENABLED("azure.accessTokenCacheEnabled", "true",
"Whether to enable the token cache.", false),
ACCESS_TOKEN_CACHE_CLASS_NAME("azure.accessTokenCacheClassName",
"The given class name of a AccessToken cache.", false);

String propertyKey;
String defaultValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,21 @@

import com.azure.core.credential.AccessToken;
import com.azure.core.util.logging.ClientLogger;
import com.azure.identity.extensions.implementation.credential.provider.TokenCredentialProvider;
import com.azure.identity.extensions.implementation.cache.IdentityCache;
import com.azure.identity.extensions.implementation.credential.TokenCredentialProviderOptions;
import com.azure.identity.extensions.implementation.credential.provider.TokenCredentialProvider;
import com.azure.identity.extensions.implementation.enums.AuthProperty;
import com.azure.identity.extensions.implementation.token.AccessTokenResolver;
import com.azure.identity.extensions.implementation.token.AccessTokenResolverOptions;
import reactor.core.publisher.Mono;

import java.time.Duration;
import java.util.Properties;
import java.util.concurrent.atomic.AtomicBoolean;
import reactor.core.publisher.Mono;

import static com.azure.identity.extensions.implementation.cache.IdentityCacheHelper.createAccessTokenCacheInstance;
import static com.azure.identity.extensions.implementation.enums.AuthProperty.GET_TOKEN_TIMEOUT;
import static com.azure.identity.extensions.implementation.utils.StringUtils.getAccessTokenCacheKey;

/**
* Template class can be extended to get password from access token.
Expand All @@ -28,6 +34,10 @@ public class AzureAuthenticationTemplate {

private AccessTokenResolver accessTokenResolver;

private IdentityCache<String, AccessToken> accessTokenCache;
moarychan marked this conversation as resolved.
Show resolved Hide resolved

private AccessTokenResolverOptions resolverOptions;

private long accessTokenTimeoutInSeconds;

/**
Expand Down Expand Up @@ -59,14 +69,20 @@ public void init(Properties properties) {
if (isInitialized.compareAndSet(false, true)) {
LOGGER.verbose("Initializing AzureAuthenticationTemplate.");

TokenCredentialProviderOptions providerOptions = new TokenCredentialProviderOptions(properties);
if (getTokenCredentialProvider() == null) {
this.tokenCredentialProvider
= TokenCredentialProvider.createDefault(new TokenCredentialProviderOptions(properties));
= TokenCredentialProvider.createDefault(providerOptions);
}

this.resolverOptions = new AccessTokenResolverOptions(properties);
if (getAccessTokenResolver() == null) {
this.accessTokenResolver
= AccessTokenResolver.createDefault(new AccessTokenResolverOptions(properties));
= AccessTokenResolver.createDefault(resolverOptions);
}

if (AuthProperty.ACCESS_TOKEN_CACHE_ENABLED.getBoolean(properties)) {
this.accessTokenCache = createAccessTokenCacheInstance(providerOptions.getAccessTokenCacheClassName());
}

if (properties.containsKey(GET_TOKEN_TIMEOUT.getPropertyKey())) {
Expand All @@ -90,10 +106,31 @@ public Mono<String> getTokenAsPasswordAsync() {
if (!isInitialized.get()) {
throw LOGGER.logExceptionAsError(new IllegalStateException("must call init() first"));
}

if (accessTokenCache != null) {
String accessTokenCacheKey = getAccessTokenCacheKey(this.resolverOptions);
AccessToken accessToken = accessTokenCache.get(accessTokenCacheKey);
if (accessToken != null) {
if (!accessToken.isExpired()) {
LOGGER.verbose("Returning access token from cache.");
return Mono.just(accessToken.getToken());
} else {
accessTokenCache.remove(accessTokenCacheKey);
}

}
}

return Mono.fromSupplier(getTokenCredentialProvider())
.flatMap(getAccessTokenResolver())
.filter(token -> !token.isExpired())
.map(AccessToken::getToken);
.flatMap(getAccessTokenResolver())
.doOnSuccess(accessToken -> {
if (accessTokenCache != null) {
accessTokenCache.put(getAccessTokenCacheKey(this.resolverOptions), accessToken);
LOGGER.verbose("The access token cached.");
}
})
.filter(token -> !token.isExpired())
.map(AccessToken::getToken);
}

/**
Expand Down Expand Up @@ -125,5 +162,4 @@ Duration getBlockTimeout() {
AtomicBoolean getIsInitialized() {
return isInitialized;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package com.azure.identity.extensions.implementation.utils;

import com.azure.identity.extensions.implementation.credential.TokenCredentialProviderOptions;
import com.azure.identity.extensions.implementation.token.AccessTokenResolverOptions;

import java.util.Arrays;
import java.util.stream.Collectors;

public class StringUtils {

private StringUtils() {

}

public static String getTokenCredentialCacheKey(TokenCredentialProviderOptions options) {
return joinOptions(options.getTenantId(), options.getClientId(), options.getClientCertificatePath(),
options.getUsername(), String.valueOf(options.isManagedIdentityEnabled()),
options.getTokenCredentialProviderClassName(), options.getTokenCredentialBeanName(),
options.getTokenCredentialCacheClassName());
}

public static String getAccessTokenCacheKey(AccessTokenResolverOptions options) {
return joinOptions(options.getTenantId(), options.getClaims(), String.join("-", options.getScopes()));
}

private static String joinOptions(String... options) {
return Arrays.stream(options).map(StringUtils::nonNullOption).collect(Collectors.joining(","));
}

private static String nonNullOption(String option) {
return option == null ? "" : option;
}
}
Loading