diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java b/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java index 45c531acf..11f62553c 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java @@ -204,6 +204,7 @@ final class TDS { static final byte ADALWORKFLOW_ACTIVEDIRECTORYMANAGEDIDENTITY = 0x03; static final byte ADALWORKFLOW_ACTIVEDIRECTORYINTERACTIVE = 0x03; static final byte ADALWORKFLOW_DEFAULTAZURECREDENTIAL = 0x03; + static final byte ADALWORKFLOW_ACCESSTOKENCALLBACK = 0x03; static final byte ADALWORKFLOW_ACTIVEDIRECTORYSERVICEPRINCIPAL = 0x01; // Using the Password byte as that is the // closest we have. static final byte FEDAUTH_INFO_ID_STSURL = 0x01; // FedAuthInfoData is token endpoint URL from which to acquire fed diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerConnection.java index eb90ab73f..743ccdc65 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerConnection.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerConnection.java @@ -411,6 +411,8 @@ CallableStatement prepareCall(String sql, int nType, int nConcur, int nHold, /** * Deprecated. Time-to-live is no longer supported for the cached Managed Identity tokens. * This method will always return 0 and is for backwards compatibility only. + * + * @return Method will always return 0. */ @Deprecated int getMsiTokenCacheTtl(); @@ -418,6 +420,9 @@ CallableStatement prepareCall(String sql, int nType, int nConcur, int nHold, /** * Deprecated. Time-to-live is no longer supported for the cached Managed Identity tokens. * This method is a no-op for backwards compatibility only. + * + * @param timeToLive + * Time-to-live is no longer supported. */ @Deprecated void setMsiTokenCacheTtl(int timeToLive); diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerDataSource.java b/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerDataSource.java index 19f986953..a9ba172d6 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerDataSource.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerDataSource.java @@ -1216,6 +1216,9 @@ public interface ISQLServerDataSource extends javax.sql.CommonDataSource { /** * Deprecated. Time-to-live is no longer supported for the cached Managed Identity tokens. * This method is a no-op for backwards compatibility only. + * + * @param timeToLive + * Time-to-live is no longer supported. */ @Deprecated void setMsiTokenCacheTtl(int timeToLive); @@ -1223,7 +1226,24 @@ public interface ISQLServerDataSource extends javax.sql.CommonDataSource { /** * Deprecated. Time-to-live is no longer supported for the cached Managed Identity tokens. * This method will always return 0 and is for backwards compatibility only. + * + * @return Method will always return 0. */ @Deprecated int getMsiTokenCacheTtl(); + + /** + * Sets the {@link SQLServerAccessTokenCallback} delegate. + * + * @param accessTokenCallback + * Access token callback delegate. + */ + void setAccessTokenCallback(SQLServerAccessTokenCallback accessTokenCallback); + + /** + * Returns a {@link SQLServerAccessTokenCallback}, the access token callback delegate. + * + * @return Access token callback delegate. + */ + SQLServerAccessTokenCallback getAccessTokenCallback(); } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerAccessTokenCallback.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerAccessTokenCallback.java new file mode 100644 index 000000000..2d3ba4515 --- /dev/null +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerAccessTokenCallback.java @@ -0,0 +1,21 @@ +package com.microsoft.sqlserver.jdbc; + +/** + * Provides SqlAuthenticationToken callback to be implemented by client code. + */ +public interface SQLServerAccessTokenCallback { + + /** + * For an example of callback usage, look under the project's code samples. + * + * Returns the access token for the authentication request + * + * @param stsurl + * - Security token service URL. + * @param spn + * - Service principal name. + * + * @return Returns a {@link SqlAuthenticationToken}. + */ + SqlAuthenticationToken getAccessToken(String stsurl, String spn); +} diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java index 9520ce40a..8df4650eb 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java @@ -162,7 +162,7 @@ public class SQLServerConnection implements ISQLServerConnection, java.io.Serial private byte[] accessTokenInByte = null; /** fedAuth token */ - private SqlFedAuthToken fedAuthToken = null; + private SqlAuthenticationToken fedAuthToken = null; /** original hostNameInCertificate */ private String originalHostNameInCertificate = null; @@ -537,6 +537,13 @@ class FederatedAuthenticationFeatureExtensionData implements Serializable { this.authentication = SqlAuthentication.ActiveDirectoryInteractive; break; default: + // If authenticationString not specified, check if access token callback was set. + // If access token callback is set, break. + if (null != activeConnectionProperties + .get(SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString())) { + this.authentication = SqlAuthentication.NotSpecified; + break; + } assert (false); MessageFormat form = new MessageFormat( SQLServerException.getErrString("R_InvalidConnectionSetting")); @@ -1669,7 +1676,7 @@ void checkClosed() throws SQLServerException { * @return true/false */ protected boolean needsReconnect() { - return (null != fedAuthToken && Util.checkIfNeedNewAccessToken(this, fedAuthToken.expiresOn)); + return (null != fedAuthToken && Util.checkIfNeedNewAccessToken(this, fedAuthToken.getExpiresOn())); } /** @@ -2423,6 +2430,17 @@ Connection connectInternal(Properties propsIn, ntlmAuthentication = true; } + SQLServerAccessTokenCallback callback = (SQLServerAccessTokenCallback) activeConnectionProperties + .get(SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString()); + + if (null != callback && (!activeConnectionProperties + .getProperty(SQLServerDriverStringProperty.USER.toString()).isEmpty() + || !activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()) + .isEmpty())) { + throw new SQLServerException( + SQLServerException.getErrString("R_AccessTokenCallbackWithUserPassword"), null); + } + sPropKey = SQLServerDriverStringProperty.AUTHENTICATION.toString(); sPropValue = activeConnectionProperties.getProperty(sPropKey); if (null == sPropValue) { @@ -2434,7 +2452,7 @@ Connection connectInternal(Properties propsIn, && (!activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()) .isEmpty())) { MessageFormat form = new MessageFormat( - SQLServerException.getErrString("R_MSIAuthenticationWithPassword")); + SQLServerException.getErrString("R_ManagedIdentityAuthenticationWithPassword")); throw new SQLServerException(form.format(new Object[] {authenticationString}), null); } @@ -2466,7 +2484,7 @@ Connection connectInternal(Properties propsIn, && (!activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()) .isEmpty())) { MessageFormat form = new MessageFormat( - SQLServerException.getErrString("R_MSIAuthenticationWithPassword")); + SQLServerException.getErrString("R_ManagedIdentityAuthenticationWithPassword")); throw new SQLServerException(form.format(new Object[] {authenticationString}), null); } @@ -3464,7 +3482,8 @@ private void executeReconnect(LogonCommand logonCommand) throws SQLServerExcepti void prelogin(String serverName, int portNumber) throws SQLServerException { // Build a TDS Pre-Login packet to send to the server. if ((!authenticationString.equalsIgnoreCase(SqlAuthentication.NotSpecified.toString())) - || (null != accessTokenInByte)) { + || (null != accessTokenInByte) || null != activeConnectionProperties + .get(SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString())) { fedAuthRequiredByUser = true; } @@ -3846,7 +3865,8 @@ void prelogin(String serverName, int portNumber) throws SQLServerException { // Or AccessToken is not null, mean token based authentication is used. if (((null != authenticationString) && (!authenticationString.equalsIgnoreCase(SqlAuthentication.NotSpecified.toString()))) - || (null != accessTokenInByte)) { + || (null != accessTokenInByte) || null != activeConnectionProperties + .get(SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString())) { fedAuthRequiredPreLoginResponse = (preloginResponse[optionOffset] == 1); } break; @@ -4790,6 +4810,12 @@ int writeFedAuthFeatureRequest(boolean write, /* if false just calculates the le workflow = TDS.ADALWORKFLOW_ACTIVEDIRECTORYSERVICEPRINCIPAL; break; default: + // If not specified, check if access token callback was set. If it is set, break. + if (null != activeConnectionProperties + .get(SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString())) { + workflow = TDS.ADALWORKFLOW_ACCESSTOKENCALLBACK; + break; + } assert (false); // Unrecognized Authentication type for fedauth ADAL request break; } @@ -5005,7 +5031,9 @@ private void logon(LogonCommand command) throws SQLServerException { .equalsIgnoreCase(SqlAuthentication.ActiveDirectoryServicePrincipal.toString()) || authenticationString .equalsIgnoreCase(SqlAuthentication.ActiveDirectoryInteractive.toString())) - && fedAuthRequiredPreLoginResponse)) { + && fedAuthRequiredPreLoginResponse) + || null != activeConnectionProperties + .get(SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString())) { federatedAuthenticationInfoRequested = true; fedAuthFeatureExtensionData = new FederatedAuthenticationFeatureExtensionData(TDS.TDS_FEDAUTH_LIBRARY_ADAL, authenticationString, fedAuthRequiredPreLoginResponse); @@ -5499,9 +5527,9 @@ final class FedAuthTokenCommand extends UninterruptableTDSCommand { // Always update serialVersionUID when prompted. private static final long serialVersionUID = 1L; TDSTokenHandler tdsTokenHandler = null; - SqlFedAuthToken sqlFedAuthToken = null; + SqlAuthenticationToken sqlFedAuthToken = null; - FedAuthTokenCommand(SqlFedAuthToken sqlFedAuthToken, TDSTokenHandler tdsTokenHandler) { + FedAuthTokenCommand(SqlAuthenticationToken sqlFedAuthToken, TDSTokenHandler tdsTokenHandler) { super("FedAuth"); this.tdsTokenHandler = tdsTokenHandler; this.sqlFedAuthToken = sqlFedAuthToken; @@ -5530,7 +5558,16 @@ void onFedAuthInfo(SqlFedAuthInfo fedAuthInfo, TDSTokenHandler tdsTokenHandler) assert null != fedAuthInfo; attemptRefreshTokenLocked = true; - fedAuthToken = getFedAuthToken(fedAuthInfo); + + SQLServerAccessTokenCallback callback = (SQLServerAccessTokenCallback) activeConnectionProperties + .get(SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString()); + + if (authenticationString.equals(SqlAuthentication.NotSpecified.toString()) && null != callback) { + fedAuthToken = callback.getAccessToken(fedAuthInfo.spn, fedAuthInfo.stsurl); + } else { + fedAuthToken = getFedAuthToken(fedAuthInfo); + } + attemptRefreshTokenLocked = false; // fedAuthToken cannot be null. @@ -5540,8 +5577,8 @@ void onFedAuthInfo(SqlFedAuthInfo fedAuthInfo, TDSTokenHandler tdsTokenHandler) fedAuthCommand.execute(tdsChannel.getWriter(), tdsChannel.getReader(fedAuthCommand)); } - private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLServerException { - SqlFedAuthToken fedAuthToken = null; + private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLServerException { + SqlAuthenticationToken fedAuthToken = null; // fedAuthInfo should not be null. assert null != fedAuthInfo; @@ -5552,7 +5589,7 @@ private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLSe int sleepInterval = 100; if (!msalContextExists() - && !authenticationString.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryInteractive.toString())) { + && !authenticationString.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryIntegrated.toString())) { MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_MSALMissing")); throw new SQLServerException(form.format(new Object[] {authenticationString}), null, 0, null); } @@ -5613,7 +5650,9 @@ private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLSe byte[] accessTokenFromDLL = dllInfo.accessTokenBytes; String accessToken = new String(accessTokenFromDLL, UTF_16LE); - fedAuthToken = new SqlFedAuthToken(accessToken, dllInfo.expiresIn); + Date now = new Date(); + now.setTime(now.getTime() + (dllInfo.expiresIn * 1000)); + fedAuthToken = new SqlAuthenticationToken(accessToken, now); // Break out of the retry loop in successful case. break; @@ -5721,10 +5760,10 @@ private boolean msalContextExists() { /** * Send the access token to the server. */ - private void sendFedAuthToken(FedAuthTokenCommand fedAuthCommand, SqlFedAuthToken fedAuthToken, + private void sendFedAuthToken(FedAuthTokenCommand fedAuthCommand, SqlAuthenticationToken fedAuthToken, TDSTokenHandler tdsTokenHandler) throws SQLServerException { assert null != fedAuthToken; - assert null != fedAuthToken.accessToken; + assert null != fedAuthToken.getAccessToken(); if (connectionlogger.isLoggable(Level.FINER)) { connectionlogger.fine(toString() + " Sending federated authentication token."); @@ -5732,7 +5771,7 @@ private void sendFedAuthToken(FedAuthTokenCommand fedAuthCommand, SqlFedAuthToke TDSWriter tdsWriter = fedAuthCommand.startRequest(TDS.PKT_FEDAUTH_TOKEN_MESSAGE); - byte[] accessToken = fedAuthToken.accessToken.getBytes(UTF_16LE); + byte[] accessToken = fedAuthToken.getAccessToken().getBytes(UTF_16LE); // Send total length (length of token plus 4 bytes for the token length field) // If we were sending a nonce, this would include that length as well diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java index 191a8785a..d2565c823 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java @@ -1224,6 +1224,30 @@ public int getMsiTokenCacheTtl() { return 0; } + /** + * Sets the {@link SQLServerAccessTokenCallback} delegate. + * + * @param accessTokenCallback + * Access token callback delegate. + */ + @Override + public void setAccessTokenCallback(SQLServerAccessTokenCallback accessTokenCallback) { + setObjectProperty(connectionProps, SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString(), + accessTokenCallback); + } + + /** + * Returns a {@link SQLServerAccessTokenCallback}, the access token callback delegate. + * + * @return Access token callback delegate. + */ + @Override + public SQLServerAccessTokenCallback getAccessTokenCallback() { + return (SQLServerAccessTokenCallback) getObjectProperty(connectionProps, + SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString(), + SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.getDefaultValue()); + } + /** * Sets a property string value. * diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java index 603df2155..a0227bf14 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java @@ -65,7 +65,6 @@ enum SqlAuthentication { SqlPassword, ActiveDirectoryPassword, ActiveDirectoryIntegrated, - ActiveDirectoryMSI, ActiveDirectoryManagedIdentity, ActiveDirectoryServicePrincipal, ActiveDirectoryInteractive, @@ -395,7 +394,8 @@ static ApplicationIntent valueOfString(String value) throws SQLServerException { enum SQLServerDriverObjectProperty { - GSS_CREDENTIAL("gsscredential", null); + GSS_CREDENTIAL("gsscredential", null), + ACCESS_TOKEN_CALLBACK("accessTokenCallback", null); private final String name; private final String defaultValue; diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java index 516200125..5ec6d09d6 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java @@ -45,7 +45,7 @@ class SQLServerMSAL4JUtils { private static final java.util.logging.Logger logger = java.util.logging.Logger .getLogger("com.microsoft.sqlserver.jdbc.SQLServerMSAL4JUtils"); - static SqlFedAuthToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String user, String password, + static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String user, String password, String authenticationString) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); @@ -64,7 +64,7 @@ static SqlFedAuthToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String use + authenticationResult.expiresOnDate()); } - return new SqlFedAuthToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate()); + return new SqlAuthenticationToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate()); } catch (MalformedURLException | InterruptedException e) { // re-interrupt thread Thread.currentThread().interrupt(); @@ -77,7 +77,7 @@ static SqlFedAuthToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String use } } - static SqlFedAuthToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuthInfo, String aadPrincipalID, + static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuthInfo, String aadPrincipalID, String aadPrincipalSecret, String authenticationString) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); try { @@ -99,7 +99,7 @@ static SqlFedAuthToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuthInfo, S + authenticationResult.expiresOnDate()); } - return new SqlFedAuthToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate()); + return new SqlAuthenticationToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate()); } catch (MalformedURLException | InterruptedException e) { // re-interrupt thread Thread.currentThread().interrupt(); @@ -112,7 +112,7 @@ static SqlFedAuthToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuthInfo, S } } - static SqlFedAuthToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo, + static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo, String authenticationString) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); @@ -142,7 +142,7 @@ static SqlFedAuthToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo, + authenticationResult.expiresOnDate()); } - return new SqlFedAuthToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate()); + return new SqlAuthenticationToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate()); } catch (InterruptedException | IOException e) { // re-interrupt thread Thread.currentThread().interrupt(); @@ -155,7 +155,7 @@ static SqlFedAuthToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo, } } - static SqlFedAuthToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAuthInfo, String user, + static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAuthInfo, String user, String authenticationString) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); @@ -208,7 +208,7 @@ static SqlFedAuthToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAuthInfo, + authenticationResult.expiresOnDate()); } - return new SqlFedAuthToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate()); + return new SqlAuthenticationToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate()); } catch (MalformedURLException | InterruptedException | URISyntaxException e) { // re-interrupt thread Thread.currentThread().interrupt(); diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java index b264d3db3..420905d8a 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java @@ -339,8 +339,9 @@ protected Object[][] getContents() { {"R_NtlmNoUserPasswordDomain", "\"User\" (or \"UserName\") and \"Password\" connection properties must be specified for NTLM authentication."}, {"R_SetAccesstokenWhenIntegratedSecurityTrue", "Cannot set the AccessToken property if the \"IntegratedSecurity\" connection string keyword has been set to \"true\"."}, {"R_IntegratedAuthenticationWithUserPassword", "Cannot use \"Authentication=ActiveDirectoryIntegrated\" with \"User\", \"UserName\" or \"Password\" connection string keywords."}, - {"R_MSIAuthenticationWithPassword", "Cannot use \"Authentication={0}\" with \"Password\" connection string keyword."}, + {"R_ManagedIdentityAuthenticationWithPassword", "Cannot use \"Authentication={0}\" with \"Password\" connection string keyword."}, {"R_AccessTokenWithUserPassword", "Cannot set the AccessToken property if \"User\", \"UserName\" or \"Password\" has been specified in the connection string."}, + {"R_AccessTokenCallbackWithUserPassword", "Cannot set the access token callback if \"User\", \"UserName\" or \"Password\" has been set."}, {"R_AccessTokenCannotBeEmpty", "AccesToken cannot be empty."}, {"R_SetBothAuthenticationAndAccessToken", "Cannot set the AccessToken property if \"Authentication\" has been specified in the connection string."}, {"R_NoUserPasswordForActivePassword", "Both \"User\" (or \"UserName\") and \"Password\" connection string keywords must be specified, if \"Authentication=ActiveDirectoryPassword\"."}, diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java index 277ddbc6e..1fde31b08 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java @@ -328,7 +328,7 @@ static void verifyColumnMasterKeyMetadata(SQLServerConnection connection, SQLSer * @return fedauth token * @throws SQLServerException */ - static SqlFedAuthToken getManagedIdentityCredAuthToken(String resource, + static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource, String managedIdentityClientId) throws SQLServerException { ManagedIdentityCredential mic = null; @@ -343,7 +343,7 @@ static SqlFedAuthToken getManagedIdentityCredAuthToken(String resource, + SQLServerMSAL4JUtils.SLASH_DEFAULT; tokenRequestContext.setScopes(Arrays.asList(scope)); - SqlFedAuthToken sqlFedAuthToken = null; + SqlAuthenticationToken sqlFedAuthToken = null; Optional accessTokenOptional = mic.getToken(tokenRequestContext).blockOptional(); @@ -352,7 +352,8 @@ static SqlFedAuthToken getManagedIdentityCredAuthToken(String resource, null); } else { AccessToken accessToken = accessTokenOptional.get(); - sqlFedAuthToken = new SqlFedAuthToken(accessToken.getToken(), accessToken.getExpiresAt().toEpochSecond()); + sqlFedAuthToken = new SqlAuthenticationToken(accessToken.getToken(), + accessToken.getExpiresAt().toEpochSecond()); } return sqlFedAuthToken; @@ -368,7 +369,7 @@ static SqlFedAuthToken getManagedIdentityCredAuthToken(String resource, * @return fedauth token * @throws SQLServerException */ - static SqlFedAuthToken getDefaultAzureCredAuthToken(String resource, + static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource, String managedIdentityClientId) throws SQLServerException { String intellijKeepassPath = System.getenv(INTELLIJ_KEEPASS_PASS); String[] additionallyAllowedTenants = getAdditonallyAllowedTenants(); @@ -395,7 +396,7 @@ static SqlFedAuthToken getDefaultAzureCredAuthToken(String resource, + SQLServerMSAL4JUtils.SLASH_DEFAULT; tokenRequestContext.setScopes(Arrays.asList(scope)); - SqlFedAuthToken sqlFedAuthToken = null; + SqlAuthenticationToken sqlFedAuthToken = null; Optional accessTokenOptional = dac.getToken(tokenRequestContext).blockOptional(); @@ -404,7 +405,8 @@ static SqlFedAuthToken getDefaultAzureCredAuthToken(String resource, null); } else { AccessToken accessToken = accessTokenOptional.get(); - sqlFedAuthToken = new SqlFedAuthToken(accessToken.getToken(), accessToken.getExpiresAt().toEpochSecond()); + sqlFedAuthToken = new SqlAuthenticationToken(accessToken.getToken(), + accessToken.getExpiresAt().toEpochSecond()); } return sqlFedAuthToken; diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SqlAuthenticationToken.java b/src/main/java/com/microsoft/sqlserver/jdbc/SqlAuthenticationToken.java new file mode 100644 index 000000000..6c4420919 --- /dev/null +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SqlAuthenticationToken.java @@ -0,0 +1,74 @@ +/* + * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made + * available under the terms of the MIT License. See the LICENSE file in the project root for more information. + */ + +package com.microsoft.sqlserver.jdbc; + +import java.io.Serializable; +import java.util.Date; + + +/** + * Provides an implementation of a SqlAuthenticationToken + */ +public class SqlAuthenticationToken implements Serializable { + + /** Always update serialVersionUID when prompted **/ + private static final long serialVersionUID = -1343105491285383937L; + + /** The token expiration date. **/ + private final Date expiresOn; + + /** The access token string. **/ + private final String accessToken; + + + /** + * Contructs a SqlAuthentication token. + * + * @param accessToken + * The access token string. + * @param expiresOn + * The expiration date in seconds since the unix epoch. + */ + public SqlAuthenticationToken(String accessToken, long expiresOn) { + this.accessToken = accessToken; + this.expiresOn = new Date(expiresOn); + } + + /** + * Contructs a SqlAuthentication token. + * + * @param accessToken + * The access token string. + * @param expiresOn + * The expiration date. + */ + public SqlAuthenticationToken(String accessToken, Date expiresOn) { + this.accessToken = accessToken; + this.expiresOn = expiresOn; + } + + /** + * Returns the expiration date of the token. + * + * @return The token expiration date. + */ + public Date getExpiresOn() { + return expiresOn; + } + + /** + * Returns the access token string. + * + * @return The access token. + */ + public String getAccessToken() { + return accessToken; + } + + public String toString() { + return "accessToken hashCode: " + accessToken.hashCode() + " expiresOn: " + expiresOn.toInstant().toString(); + } +} diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SqlFedAuthToken.java b/src/main/java/com/microsoft/sqlserver/jdbc/SqlFedAuthToken.java deleted file mode 100644 index 912f4317c..000000000 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SqlFedAuthToken.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made - * available under the terms of the MIT License. See the LICENSE file in the project root for more information. - */ - -package com.microsoft.sqlserver.jdbc; - -import java.io.Serializable; -import java.util.Date; - - -/** - * Provides an implementation of a FedAuth token - */ -class SqlFedAuthToken implements Serializable { - /** - * Always update serialVersionUID when prompted - */ - private static final long serialVersionUID = -1343105491285383937L; - - final Date expiresOn; - final String accessToken; - - SqlFedAuthToken(String accessToken, long expiresIn) { - this.accessToken = accessToken; - - Date now = new Date(); - now.setTime(now.getTime() + (expiresIn * 1000)); - this.expiresOn = now; - } - - SqlFedAuthToken(String accessToken, Date expiresOn) { - this.accessToken = accessToken; - this.expiresOn = expiresOn; - } - - public String toString() { - return "accessToken hashCode: " + accessToken.hashCode() + " expiresOn: " + expiresOn.toInstant().toString(); - } -} diff --git a/src/samples/azureactivedirectoryauthentication/pom.xml b/src/samples/azureactivedirectoryauthentication/pom.xml index d8105446a..5558a02fd 100644 --- a/src/samples/azureactivedirectoryauthentication/pom.xml +++ b/src/samples/azureactivedirectoryauthentication/pom.xml @@ -34,6 +34,22 @@ + + AzureActiveDirectoryAccessTokenCallback + + AzureActiveDirectoryAccessTokenCallback + + + org.codehaus.mojo + exec-maven-plugin + 1.2.1 + + azureactivedirectoryauthentication.src.main.java.AzureActiveDirectoryAccessTokenCallback + + + + + diff --git a/src/samples/azureactivedirectoryauthentication/src/main/java/AzureActiveDirectoryAccessTokenCallback.java b/src/samples/azureactivedirectoryauthentication/src/main/java/AzureActiveDirectoryAccessTokenCallback.java new file mode 100644 index 000000000..1c94dbdfa --- /dev/null +++ b/src/samples/azureactivedirectoryauthentication/src/main/java/AzureActiveDirectoryAccessTokenCallback.java @@ -0,0 +1,66 @@ +/* + * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made + * available under the terms of the MIT License. See the LICENSE file in the project root for more information. + */ +package azureactivedirectoryauthentication.src.main.java; + +import com.microsoft.aad.msal4j.IClientCredential; +import com.microsoft.aad.msal4j.ClientCredentialFactory; +import com.microsoft.aad.msal4j.ConfidentialClientApplication; +import com.microsoft.aad.msal4j.IAuthenticationResult; +import com.microsoft.aad.msal4j.ClientCredentialParameters; +import java.sql.Connection; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +/** + * Sample code that demonstrates how to use access token callback. + */ +public class AzureActiveDirectoryAccessTokenCallback { + + public static void main(String[] args) { + + SQLServerAccessTokenCallback callback = new SQLServerAccessTokenCallback() { + @Override + public SqlAuthenticationToken getAccessToken(String spn, String stsurl) { + + String clientSecret = ""; + String clientId = ""; + + String scope = spn + "/.default"; + Set scopes = new HashSet<>(); + scopes.add(scope); + + try { + ExecutorService executorService = Executors.newSingleThreadExecutor(); + IClientCredential credential = ClientCredentialFactory.createFromSecret(clientSecret); + ConfidentialClientApplication clientApplication = ConfidentialClientApplication + .builder(clientId, credential).executorService(executorService).authority(stsurl).build(); + CompletableFuture future = clientApplication + .acquireToken(ClientCredentialParameters.builder(scopes).build()); + + IAuthenticationResult authenticationResult = future.get(); + String accessToken = authenticationResult.accessToken(); + + return new SqlAuthenticationToken(accessToken, authenticationResult.expiresOnDate().getTime()); + } catch (Exception e) { + e.printStackTrace(); + } + return null; + } + }; + + SQLServerDataSource ds = new SQLServerDataSource(); + ds.setServerName(""); + ds.setDatabaseName(""); + ds.setAccessTokenCallback(callback); + + try (Connection conn = (SQLServerConnection) ds.getConnection()) { + System.out.println("Connected..."); + } + + } +} diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/connection/NativeMSSQLDataSourceTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/connection/NativeMSSQLDataSourceTest.java index e8c508f40..b477e52cd 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/connection/NativeMSSQLDataSourceTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/connection/NativeMSSQLDataSourceTest.java @@ -4,6 +4,7 @@ */ package com.microsoft.sqlserver.jdbc.connection; +import static org.junit.Assert.fail; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -16,8 +17,25 @@ import java.io.PrintWriter; import java.sql.Connection; import java.sql.SQLException; - -import com.microsoft.sqlserver.jdbc.*; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import com.microsoft.aad.msal4j.IClientCredential; +import com.microsoft.aad.msal4j.ClientCredentialFactory; +import com.microsoft.aad.msal4j.ConfidentialClientApplication; +import com.microsoft.aad.msal4j.IAuthenticationResult; +import com.microsoft.aad.msal4j.ClientCredentialParameters; +import com.microsoft.sqlserver.jdbc.SQLServerXADataSource; +import com.microsoft.sqlserver.jdbc.SQLServerDataSource; +import com.microsoft.sqlserver.jdbc.SQLServerAccessTokenCallback; +import com.microsoft.sqlserver.jdbc.SqlAuthenticationToken; +import com.microsoft.sqlserver.jdbc.TestResource; +import com.microsoft.sqlserver.jdbc.SQLServerConnectionPoolDataSource; +import com.microsoft.sqlserver.jdbc.SQLServerConnection; +import com.microsoft.sqlserver.jdbc.ISQLServerDataSource; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; @@ -66,6 +84,54 @@ public void testDSNormal() throws ClassNotFoundException, IOException, SQLExcept try (Connection conn = ds.getConnection()) {} } + @Tag(Constants.xSQLv11) + @Tag(Constants.xSQLv12) + @Tag(Constants.xSQLv14) + @Tag(Constants.xSQLv15) + @Tag(Constants.xAzureSQLDW) + @Tag(Constants.reqExternalSetup) + @Test + public void testDSPooledConnectionAccessTokenCallback() throws SQLException { + SQLServerAccessTokenCallback callback = new SQLServerAccessTokenCallback() { + @Override + public SqlAuthenticationToken getAccessToken(String spn, String stsurl) { + String scope = spn + "/.default"; + Set scopes = new HashSet<>(); + scopes.add(scope); + + try { + ExecutorService executorService = Executors.newSingleThreadExecutor(); + IClientCredential credential = ClientCredentialFactory.createFromSecret(applicationKey); + ConfidentialClientApplication clientApplication = ConfidentialClientApplication + .builder(applicationClientID, credential).executorService(executorService) + .authority(stsurl).build(); + CompletableFuture future = clientApplication + .acquireToken(ClientCredentialParameters.builder(scopes).build()); + + IAuthenticationResult authenticationResult = future.get(); + String accessToken = authenticationResult.accessToken(); + long expiresOn = authenticationResult.expiresOnDate().getTime(); + + return new SqlAuthenticationToken(accessToken, expiresOn); + + } catch (Exception e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + return null; + } + }; + + SQLServerConnectionPoolDataSource ds = new SQLServerConnectionPoolDataSource(); + AbstractTest.updateDataSource(connectionString, ds); + ds.setUser(""); + ds.setPassword(""); + ds.setAccessTokenCallback(callback); + + // Callback should provide valid token on connection open for all new connections + try (Connection conn1 = (SQLServerConnection) ds.getConnection()) {} + try (Connection conn2 = (SQLServerConnection) ds.getConnection()) {} + } + @Test @Tag(Constants.xAzureSQLDW) @Tag(Constants.xAzureSQLDB) diff --git a/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java b/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java index 6a91da948..4ecd403db 100644 --- a/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java +++ b/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java @@ -84,6 +84,11 @@ public abstract class AbstractTest { // properties needed for Managed Identity protected static String managedIdentityClientId = null; + + // properties for access token callback testing + protected static String accessTokenClientId = null; + protected static String accessTokenSecret = null; + protected static String keyStorePrincipalId = null; protected static String keyStoreSecret = null; @@ -138,6 +143,9 @@ public static void setup() throws Exception { applicationKey = getConfiguredProperty("applicationKey"); tenantID = getConfiguredProperty("tenantID"); + accessTokenClientId = getConfiguredProperty("accessTokenClientId"); + accessTokenSecret = getConfiguredProperty("accessTokenSecret"); + encrypt = getConfiguredProperty("encrypt", "false"); connectionString = TestUtils.addOrOverrideProperty(connectionString, "encrypt", encrypt);