Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added credential caching for Managed Identity Credential and Default Azure Credential #2415

Merged
merged 5 commits into from
May 21, 2024
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 105 additions & 16 deletions src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,22 @@
package com.microsoft.sqlserver.jdbc;

import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.text.MessageFormat;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Optional;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

import com.azure.core.credential.AccessToken;
import com.azure.core.credential.TokenCredential;
import com.azure.core.credential.TokenRequestContext;
import com.azure.identity.ManagedIdentityCredential;
import com.azure.identity.ManagedIdentityCredentialBuilder;
Expand Down Expand Up @@ -46,6 +51,11 @@ class SQLServerSecurityUtility {
// Environment variable for additionally allowed tenants. The tenantIds are comma delimited
private static final String ADDITIONALLY_ALLOWED_TENANTS = "ADDITIONALLY_ALLOWED_TENANTS";

// Credential Cache for ManagedIdentityCredential and DefaultAzureCredential
private static final HashMap<String, Credential> CREDENTIAL_CACHE = new HashMap<>();

private static final Lock LOCK = new ReentrantLock();
tkyc marked this conversation as resolved.
Show resolved Hide resolved

private SQLServerSecurityUtility() {
throw new UnsupportedOperationException(SQLServerException.getErrString("R_notSupported"));
}
Expand Down Expand Up @@ -331,16 +341,33 @@ static void verifyColumnMasterKeyMetadata(SQLServerConnection connection, SQLSer
*/
static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource,
String managedIdentityClientId) throws SQLServerException {
ManagedIdentityCredential mic = null;

if (logger.isLoggable(java.util.logging.Level.FINEST)) {
logger.finest("Getting Managed Identity authentication token for: " + managedIdentityClientId);
}

if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
mic = new ManagedIdentityCredentialBuilder().clientId(managedIdentityClientId).build();
} else {
mic = new ManagedIdentityCredentialBuilder().build();
String key = getHashedSecret(new String[]{managedIdentityClientId, ManagedIdentityCredential.class.getSimpleName()});
ManagedIdentityCredential mic = (ManagedIdentityCredential) getCredentialFromCache(key);

if (null == mic) {
LOCK.lock();

try {
mic = (ManagedIdentityCredential) getCredentialFromCache(key);
if (null == mic) {
ManagedIdentityCredentialBuilder micBuilder = new ManagedIdentityCredentialBuilder();

if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
mic = micBuilder.clientId(managedIdentityClientId).build();
addCredentialToCache(key, mic);
} else {
mic = micBuilder.build();
addCredentialToCache(key, mic);
}
}
tkyc marked this conversation as resolved.
Show resolved Hide resolved
} finally {
LOCK.unlock();
}
}

TokenRequestContext tokenRequestContext = new TokenRequestContext();
Expand Down Expand Up @@ -383,22 +410,47 @@ static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource,
String intellijKeepassPath = System.getenv(INTELLIJ_KEEPASS_PASS);
String[] additionallyAllowedTenants = getAdditonallyAllowedTenants();

DefaultAzureCredentialBuilder dacBuilder = new DefaultAzureCredentialBuilder();
DefaultAzureCredential dac = null;
int secretsLength = null == additionallyAllowedTenants ? 3 : additionallyAllowedTenants.length + 3;
String[] secrets = new String[secretsLength];

if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
dacBuilder.managedIdentityClientId(managedIdentityClientId);
if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) {
System.arraycopy(additionallyAllowedTenants, 0, secrets, 3, additionallyAllowedTenants.length);
}

if (null != intellijKeepassPath && !intellijKeepassPath.isEmpty()) {
dacBuilder.intelliJKeePassDatabasePath(intellijKeepassPath);
}
secrets[0] = DefaultAzureCredential.class.getSimpleName();
secrets[1] = managedIdentityClientId;
secrets[2] = intellijKeepassPath;

if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) {
dacBuilder.additionallyAllowedTenants(additionallyAllowedTenants);
}
String key = getHashedSecret(secrets);
DefaultAzureCredential dac = (DefaultAzureCredential) getCredentialFromCache(key);

if (null == dac) {
LOCK.lock();

try {
dac = (DefaultAzureCredential) getCredentialFromCache(key);
if (null == dac) {
DefaultAzureCredentialBuilder dacBuilder = new DefaultAzureCredentialBuilder();

dac = dacBuilder.build();
if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
dacBuilder.managedIdentityClientId(managedIdentityClientId);
}

if (null != intellijKeepassPath && !intellijKeepassPath.isEmpty()) {
dacBuilder.intelliJKeePassDatabasePath(intellijKeepassPath);
}

if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) {
dacBuilder.additionallyAllowedTenants(additionallyAllowedTenants);
}

dac = dacBuilder.build();
addCredentialToCache(key, dac);
}
} finally {
LOCK.unlock();
}
}

TokenRequestContext tokenRequestContext = new TokenRequestContext();
String scope = resource.endsWith(SQLServerMSAL4JUtils.SLASH_DEFAULT) ? resource : resource
Expand Down Expand Up @@ -430,4 +482,41 @@ private static String[] getAdditonallyAllowedTenants() {

return null;
}

private static void addCredentialToCache(String key, TokenCredential tokenCredential) {
Credential credential = new Credential(tokenCredential);
CREDENTIAL_CACHE.put(key, credential);
}

private static TokenCredential getCredentialFromCache(String key) {
Credential credential = CREDENTIAL_CACHE.get(key);

if (null != credential) {
return credential.tokenCredential;
}

return null;
}

private static class Credential {
TokenCredential tokenCredential;

public Credential(TokenCredential tokenCredential) {
this.tokenCredential = tokenCredential;
}
}

private static String getHashedSecret(String[] secrets) throws SQLServerException {
try {
MessageDigest md = MessageDigest.getInstance("SHA-256");
for (String secret : secrets) {
if (null != secret) {
md.update(secret.getBytes(java.nio.charset.StandardCharsets.UTF_16LE));
}
}
return new String(md.digest());
} catch (NoSuchAlgorithmException e) {
throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), e);
}
}
}
Loading