Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Access token callback #1940

Merged
merged 12 commits into from
Nov 3, 2022
1 change: 1 addition & 0 deletions src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ final class TDS {
static final byte ADALWORKFLOW_ACTIVEDIRECTORYMANAGEDIDENTITY = 0x03;
static final byte ADALWORKFLOW_ACTIVEDIRECTORYINTERACTIVE = 0x03;
static final byte ADALWORKFLOW_DEFAULTAZURECREDENTIAL = 0x03;
static final byte ADALWORKFLOW_ACCESSTOKENCALLBACK = 0x03;
static final byte ADALWORKFLOW_ACTIVEDIRECTORYSERVICEPRINCIPAL = 0x01; // Using the Password byte as that is the
// closest we have.
static final byte FEDAUTH_INFO_ID_STSURL = 0x01; // FedAuthInfoData is token endpoint URL from which to acquire fed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -411,13 +411,18 @@ CallableStatement prepareCall(String sql, int nType, int nConcur, int nHold,
/**
* Deprecated. Time-to-live is no longer supported for the cached Managed Identity tokens.
* This method will always return 0 and is for backwards compatibility only.
*
* @return Method will always return 0.
*/
@Deprecated
int getMsiTokenCacheTtl();

/**
* Deprecated. Time-to-live is no longer supported for the cached Managed Identity tokens.
* This method is a no-op for backwards compatibility only.
*
* @param timeToLive
* Time-to-live is no longer supported.
*/
@Deprecated
void setMsiTokenCacheTtl(int timeToLive);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1216,14 +1216,34 @@ public interface ISQLServerDataSource extends javax.sql.CommonDataSource {
/**
* Deprecated. Time-to-live is no longer supported for the cached Managed Identity tokens.
* This method is a no-op for backwards compatibility only.
*
* @param timeToLive
* Time-to-live is no longer supported.
*/
@Deprecated
void setMsiTokenCacheTtl(int timeToLive);

/**
* Deprecated. Time-to-live is no longer supported for the cached Managed Identity tokens.
* This method will always return 0 and is for backwards compatibility only.
*
* @return Method will always return 0.
*/
@Deprecated
int getMsiTokenCacheTtl();

tkyc marked this conversation as resolved.
Show resolved Hide resolved
/**
* Sets the {@link SQLServerAccessTokenCallback} delegate.
*
* @param accessTokenCallback
* Access token callback delegate.
*/
void setAccessTokenCallback(SQLServerAccessTokenCallback accessTokenCallback);

/**
* Returns a {@link SQLServerAccessTokenCallback}, the access token callback delegate.
*
* @return Access token callback delegate.
*/
SQLServerAccessTokenCallback getAccessTokenCallback();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.microsoft.sqlserver.jdbc;

/**
* Provides SqlAuthenticationToken callback to be implemented by client code.
*/
public interface SQLServerAccessTokenCallback {

/**
* For an example of callback usage, look under the project's code samples.
*
* Returns the access token for the authentication request
*
* @param stsurl
* - Security token service URL.
* @param spn
* - Service principal name.
*
* @return Returns a {@link SqlAuthenticationToken}.
*/
SqlAuthenticationToken getAccessToken(String stsurl, String spn);
}
73 changes: 56 additions & 17 deletions src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ public class SQLServerConnection implements ISQLServerConnection, java.io.Serial
private byte[] accessTokenInByte = null;

/** fedAuth token */
private SqlFedAuthToken fedAuthToken = null;
private SqlAuthenticationToken fedAuthToken = null;

/** original hostNameInCertificate */
private String originalHostNameInCertificate = null;
Expand Down Expand Up @@ -537,6 +537,13 @@ class FederatedAuthenticationFeatureExtensionData implements Serializable {
this.authentication = SqlAuthentication.ActiveDirectoryInteractive;
break;
default:
// If authenticationString not specified, check if access token callback was set.
// If access token callback is set, break.
if (null != activeConnectionProperties
.get(SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString())) {
this.authentication = SqlAuthentication.NotSpecified;
break;
}
assert (false);
MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_InvalidConnectionSetting"));
Expand Down Expand Up @@ -1669,7 +1676,7 @@ void checkClosed() throws SQLServerException {
* @return true/false
*/
protected boolean needsReconnect() {
return (null != fedAuthToken && Util.checkIfNeedNewAccessToken(this, fedAuthToken.expiresOn));
return (null != fedAuthToken && Util.checkIfNeedNewAccessToken(this, fedAuthToken.getExpiresOn()));
}

/**
Expand Down Expand Up @@ -2423,6 +2430,17 @@ Connection connectInternal(Properties propsIn,
ntlmAuthentication = true;
}

SQLServerAccessTokenCallback callback = (SQLServerAccessTokenCallback) activeConnectionProperties
.get(SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString());

if (null != callback && (!activeConnectionProperties
.getProperty(SQLServerDriverStringProperty.USER.toString()).isEmpty()
|| !activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString())
.isEmpty())) {
throw new SQLServerException(
SQLServerException.getErrString("R_AccessTokenCallbackWithUserPassword"), null);
}

sPropKey = SQLServerDriverStringProperty.AUTHENTICATION.toString();
sPropValue = activeConnectionProperties.getProperty(sPropKey);
if (null == sPropValue) {
Expand All @@ -2434,7 +2452,7 @@ Connection connectInternal(Properties propsIn,
&& (!activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString())
.isEmpty())) {
MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_MSIAuthenticationWithPassword"));
SQLServerException.getErrString("R_ManagedIdentityAuthenticationWithPassword"));
throw new SQLServerException(form.format(new Object[] {authenticationString}), null);
}

Expand Down Expand Up @@ -2466,7 +2484,7 @@ Connection connectInternal(Properties propsIn,
&& (!activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString())
.isEmpty())) {
MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_MSIAuthenticationWithPassword"));
SQLServerException.getErrString("R_ManagedIdentityAuthenticationWithPassword"));
throw new SQLServerException(form.format(new Object[] {authenticationString}), null);
}

Expand Down Expand Up @@ -3464,7 +3482,8 @@ private void executeReconnect(LogonCommand logonCommand) throws SQLServerExcepti
void prelogin(String serverName, int portNumber) throws SQLServerException {
// Build a TDS Pre-Login packet to send to the server.
if ((!authenticationString.equalsIgnoreCase(SqlAuthentication.NotSpecified.toString()))
|| (null != accessTokenInByte)) {
|| (null != accessTokenInByte) || null != activeConnectionProperties
.get(SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString())) {
fedAuthRequiredByUser = true;
}

Expand Down Expand Up @@ -3846,7 +3865,8 @@ void prelogin(String serverName, int portNumber) throws SQLServerException {
// Or AccessToken is not null, mean token based authentication is used.
if (((null != authenticationString)
&& (!authenticationString.equalsIgnoreCase(SqlAuthentication.NotSpecified.toString())))
|| (null != accessTokenInByte)) {
|| (null != accessTokenInByte) || null != activeConnectionProperties
.get(SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString())) {
fedAuthRequiredPreLoginResponse = (preloginResponse[optionOffset] == 1);
}
break;
Expand Down Expand Up @@ -4790,6 +4810,12 @@ int writeFedAuthFeatureRequest(boolean write, /* if false just calculates the le
workflow = TDS.ADALWORKFLOW_ACTIVEDIRECTORYSERVICEPRINCIPAL;
break;
default:
// If not specified, check if access token callback was set. If it is set, break.
if (null != activeConnectionProperties
.get(SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString())) {
workflow = TDS.ADALWORKFLOW_ACCESSTOKENCALLBACK;
break;
}
assert (false); // Unrecognized Authentication type for fedauth ADAL request
break;
}
Expand Down Expand Up @@ -5005,7 +5031,9 @@ private void logon(LogonCommand command) throws SQLServerException {
.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryServicePrincipal.toString())
|| authenticationString
.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryInteractive.toString()))
&& fedAuthRequiredPreLoginResponse)) {
&& fedAuthRequiredPreLoginResponse)
|| null != activeConnectionProperties
.get(SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString())) {
federatedAuthenticationInfoRequested = true;
fedAuthFeatureExtensionData = new FederatedAuthenticationFeatureExtensionData(TDS.TDS_FEDAUTH_LIBRARY_ADAL,
authenticationString, fedAuthRequiredPreLoginResponse);
Expand Down Expand Up @@ -5499,9 +5527,9 @@ final class FedAuthTokenCommand extends UninterruptableTDSCommand {
// Always update serialVersionUID when prompted.
private static final long serialVersionUID = 1L;
TDSTokenHandler tdsTokenHandler = null;
SqlFedAuthToken sqlFedAuthToken = null;
SqlAuthenticationToken sqlFedAuthToken = null;

FedAuthTokenCommand(SqlFedAuthToken sqlFedAuthToken, TDSTokenHandler tdsTokenHandler) {
FedAuthTokenCommand(SqlAuthenticationToken sqlFedAuthToken, TDSTokenHandler tdsTokenHandler) {
super("FedAuth");
this.tdsTokenHandler = tdsTokenHandler;
this.sqlFedAuthToken = sqlFedAuthToken;
Expand Down Expand Up @@ -5530,7 +5558,16 @@ void onFedAuthInfo(SqlFedAuthInfo fedAuthInfo, TDSTokenHandler tdsTokenHandler)
assert null != fedAuthInfo;

attemptRefreshTokenLocked = true;
fedAuthToken = getFedAuthToken(fedAuthInfo);

SQLServerAccessTokenCallback callback = (SQLServerAccessTokenCallback) activeConnectionProperties
.get(SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString());

if (authenticationString.equals(SqlAuthentication.NotSpecified.toString()) && null != callback) {
fedAuthToken = callback.getAccessToken(fedAuthInfo.spn, fedAuthInfo.stsurl);
} else {
fedAuthToken = getFedAuthToken(fedAuthInfo);
}

attemptRefreshTokenLocked = false;

// fedAuthToken cannot be null.
Expand All @@ -5540,8 +5577,8 @@ void onFedAuthInfo(SqlFedAuthInfo fedAuthInfo, TDSTokenHandler tdsTokenHandler)
fedAuthCommand.execute(tdsChannel.getWriter(), tdsChannel.getReader(fedAuthCommand));
}

private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLServerException {
SqlFedAuthToken fedAuthToken = null;
private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLServerException {
SqlAuthenticationToken fedAuthToken = null;

// fedAuthInfo should not be null.
assert null != fedAuthInfo;
Expand All @@ -5552,7 +5589,7 @@ private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLSe
int sleepInterval = 100;

if (!msalContextExists()
&& !authenticationString.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryInteractive.toString())) {
&& !authenticationString.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryIntegrated.toString())) {
tkyc marked this conversation as resolved.
Show resolved Hide resolved
MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_MSALMissing"));
throw new SQLServerException(form.format(new Object[] {authenticationString}), null, 0, null);
}
Expand Down Expand Up @@ -5613,7 +5650,9 @@ private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLSe
byte[] accessTokenFromDLL = dllInfo.accessTokenBytes;

String accessToken = new String(accessTokenFromDLL, UTF_16LE);
fedAuthToken = new SqlFedAuthToken(accessToken, dllInfo.expiresIn);
Date now = new Date();
now.setTime(now.getTime() + (dllInfo.expiresIn * 1000));
fedAuthToken = new SqlAuthenticationToken(accessToken, now);

// Break out of the retry loop in successful case.
break;
Expand Down Expand Up @@ -5721,18 +5760,18 @@ private boolean msalContextExists() {
/**
* Send the access token to the server.
*/
private void sendFedAuthToken(FedAuthTokenCommand fedAuthCommand, SqlFedAuthToken fedAuthToken,
private void sendFedAuthToken(FedAuthTokenCommand fedAuthCommand, SqlAuthenticationToken fedAuthToken,
TDSTokenHandler tdsTokenHandler) throws SQLServerException {
assert null != fedAuthToken;
assert null != fedAuthToken.accessToken;
assert null != fedAuthToken.getAccessToken();

if (connectionlogger.isLoggable(Level.FINER)) {
connectionlogger.fine(toString() + " Sending federated authentication token.");
}

TDSWriter tdsWriter = fedAuthCommand.startRequest(TDS.PKT_FEDAUTH_TOKEN_MESSAGE);

byte[] accessToken = fedAuthToken.accessToken.getBytes(UTF_16LE);
byte[] accessToken = fedAuthToken.getAccessToken().getBytes(UTF_16LE);

// Send total length (length of token plus 4 bytes for the token length field)
// If we were sending a nonce, this would include that length as well
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1224,6 +1224,30 @@ public int getMsiTokenCacheTtl() {
return 0;
}

/**
* Sets the {@link SQLServerAccessTokenCallback} delegate.
*
* @param accessTokenCallback
* Access token callback delegate.
*/
@Override
public void setAccessTokenCallback(SQLServerAccessTokenCallback accessTokenCallback) {
setObjectProperty(connectionProps, SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString(),
accessTokenCallback);
}

/**
* Returns a {@link SQLServerAccessTokenCallback}, the access token callback delegate.
*
* @return Access token callback delegate.
*/
@Override
public SQLServerAccessTokenCallback getAccessTokenCallback() {
return (SQLServerAccessTokenCallback) getObjectProperty(connectionProps,
SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString(),
SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.getDefaultValue());
}

/**
* Sets a property string value.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ enum SqlAuthentication {
SqlPassword,
ActiveDirectoryPassword,
ActiveDirectoryIntegrated,
ActiveDirectoryMSI,
ActiveDirectoryManagedIdentity,
ActiveDirectoryServicePrincipal,
ActiveDirectoryInteractive,
Expand Down Expand Up @@ -395,7 +394,8 @@ static ApplicationIntent valueOfString(String value) throws SQLServerException {


enum SQLServerDriverObjectProperty {
GSS_CREDENTIAL("gsscredential", null);
GSS_CREDENTIAL("gsscredential", null),
ACCESS_TOKEN_CALLBACK("accessTokenCallback", null);

private final String name;
private final String defaultValue;
Expand Down
Loading