Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Access token callback #1940

Merged
merged 12 commits into from
Nov 3, 2022
12 changes: 8 additions & 4 deletions src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ final class TDS {
static final byte ADALWORKFLOW_ACTIVEDIRECTORYMANAGEDIDENTITY = 0x03;
static final byte ADALWORKFLOW_ACTIVEDIRECTORYINTERACTIVE = 0x03;
static final byte ADALWORKFLOW_DEFAULTAZURECREDENTIAL = 0x03;
static final byte ADALWORKFLOW_ACCESSTOKENCALLBACK = 0x03;
static final byte ADALWORKFLOW_ACTIVEDIRECTORYSERVICEPRINCIPAL = 0x01; // Using the Password byte as that is the
// closest we have.
static final byte FEDAUTH_INFO_ID_STSURL = 0x01; // FedAuthInfoData is token endpoint URL from which to acquire fed
Expand Down Expand Up @@ -6862,12 +6863,15 @@ final boolean readPacket() throws SQLServerException {
TDS.PACKET_HEADER_SIZE - headerBytesRead);
if (bytesRead < 0) {
if (logger.isLoggable(Level.FINER))
logger.finer(toString() + " Premature EOS in response. packetNum:" + packetNum + " headerBytesRead:"
+ headerBytesRead);
logger.finer(toString() + " Premature EOS in response. packetNum:" + packetNum
+ " headerBytesRead:" + headerBytesRead);

con.terminate(SQLServerException.DRIVER_ERROR_IO_FAILED,
((0 == packetNum && 0 == headerBytesRead) ? SQLServerException.getErrString(
"R_noServerResponse") : SQLServerException.getErrString("R_truncatedServerResponse")));
((0 == packetNum && 0 == headerBytesRead)
? SQLServerException
.getErrString("R_noServerResponse")
: SQLServerException.getErrString(
"R_truncatedServerResponse")));
}

headerBytesRead += bytesRead;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1226,4 +1226,8 @@ public interface ISQLServerDataSource extends javax.sql.CommonDataSource {
*/
@Deprecated
int getMsiTokenCacheTtl();

tkyc marked this conversation as resolved.
Show resolved Hide resolved
void setAccessTokenCallback(SQLServerAccessTokenCallback accessTokenCallback);

SQLServerAccessTokenCallback getAccessTokenCallback();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.microsoft.sqlserver.jdbc;

/**
* Provides SqlAuthenticationToken callback to be implemented by client code.
*/
public interface SQLServerAccessTokenCallback {

/**
* For an example of callback usage, look under the project's code samples.
*
* Returns the access token for the authentication request
*
* @param stsurl
* - Security token service URL.
* @param spn
* - Service principal name.
*/
SqlAuthenticationToken getAccessToken(String stsurl, String spn);
}
164 changes: 105 additions & 59 deletions src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -1224,6 +1224,19 @@ public int getMsiTokenCacheTtl() {
return 0;
}

@Override
public void setAccessTokenCallback(SQLServerAccessTokenCallback accessTokenCallback) {
setObjectProperty(connectionProps, SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString(),
accessTokenCallback);
}

@Override
public SQLServerAccessTokenCallback getAccessTokenCallback() {
return (SQLServerAccessTokenCallback) getObjectProperty(connectionProps,
SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString(),
SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.getDefaultValue());
}

/**
* Sets a property string value.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ enum SqlAuthentication {
SqlPassword,
ActiveDirectoryPassword,
ActiveDirectoryIntegrated,
ActiveDirectoryMSI,
ActiveDirectoryManagedIdentity,
ActiveDirectoryServicePrincipal,
ActiveDirectoryInteractive,
Expand Down Expand Up @@ -395,7 +394,8 @@ static ApplicationIntent valueOfString(String value) throws SQLServerException {


enum SQLServerDriverObjectProperty {
GSS_CREDENTIAL("gsscredential", null);
GSS_CREDENTIAL("gsscredential", null),
ACCESS_TOKEN_CALLBACK("accessTokenCallback", null);

private final String name;
private final String defaultValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class SQLServerMSAL4JUtils {
private static final java.util.logging.Logger logger = java.util.logging.Logger
.getLogger("com.microsoft.sqlserver.jdbc.SQLServerMSAL4JUtils");

static SqlFedAuthToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String user, String password,
static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String user, String password,
String authenticationString) throws SQLServerException {
ExecutorService executorService = Executors.newSingleThreadExecutor();

Expand All @@ -64,7 +64,7 @@ static SqlFedAuthToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String use
+ authenticationResult.expiresOnDate());
}

return new SqlFedAuthToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate());
return new SqlAuthenticationToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate());
} catch (MalformedURLException | InterruptedException e) {
// re-interrupt thread
Thread.currentThread().interrupt();
Expand All @@ -77,7 +77,7 @@ static SqlFedAuthToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String use
}
}

static SqlFedAuthToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuthInfo, String aadPrincipalID,
static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuthInfo, String aadPrincipalID,
String aadPrincipalSecret, String authenticationString) throws SQLServerException {
ExecutorService executorService = Executors.newSingleThreadExecutor();
try {
Expand All @@ -99,7 +99,7 @@ static SqlFedAuthToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuthInfo, S
+ authenticationResult.expiresOnDate());
}

return new SqlFedAuthToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate());
return new SqlAuthenticationToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate());
} catch (MalformedURLException | InterruptedException e) {
// re-interrupt thread
Thread.currentThread().interrupt();
Expand All @@ -112,7 +112,7 @@ static SqlFedAuthToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuthInfo, S
}
}

static SqlFedAuthToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo,
static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo,
String authenticationString) throws SQLServerException {
ExecutorService executorService = Executors.newSingleThreadExecutor();

Expand Down Expand Up @@ -142,7 +142,7 @@ static SqlFedAuthToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo,
+ authenticationResult.expiresOnDate());
}

return new SqlFedAuthToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate());
return new SqlAuthenticationToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate());
} catch (InterruptedException | IOException e) {
// re-interrupt thread
Thread.currentThread().interrupt();
Expand All @@ -155,7 +155,7 @@ static SqlFedAuthToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo,
}
}

static SqlFedAuthToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAuthInfo, String user,
static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAuthInfo, String user,
String authenticationString) throws SQLServerException {
ExecutorService executorService = Executors.newSingleThreadExecutor();

Expand Down Expand Up @@ -208,7 +208,7 @@ static SqlFedAuthToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAuthInfo,
+ authenticationResult.expiresOnDate());
}

return new SqlFedAuthToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate());
return new SqlAuthenticationToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate());
} catch (MalformedURLException | InterruptedException | URISyntaxException e) {
// re-interrupt thread
Thread.currentThread().interrupt();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Level;


/**
* Represents a physical database connection in a connection pool. If provides methods for the connection pool manager
* to manage the connection pool. Applications typically do not instantiate these connections directly.
Expand Down Expand Up @@ -225,7 +226,7 @@ public void close() throws SQLException {
try {
// First close the last proxy
if (null != lastProxyConnection)
// use internal close so there wont be an event due to us closing the connection, if not closed already.
// use internal close so there wont be an event due to us closing the connection, if not closed already.
lastProxyConnection.internalClose();
if (null != physicalConnection) {
physicalConnection.DetachFromPool();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,9 @@ protected Object[][] getContents() {
{"R_NtlmNoUserPasswordDomain", "\"User\" (or \"UserName\") and \"Password\" connection properties must be specified for NTLM authentication."},
{"R_SetAccesstokenWhenIntegratedSecurityTrue", "Cannot set the AccessToken property if the \"IntegratedSecurity\" connection string keyword has been set to \"true\"."},
{"R_IntegratedAuthenticationWithUserPassword", "Cannot use \"Authentication=ActiveDirectoryIntegrated\" with \"User\", \"UserName\" or \"Password\" connection string keywords."},
{"R_MSIAuthenticationWithPassword", "Cannot use \"Authentication={0}\" with \"Password\" connection string keyword."},
{"R_ManagedIdentityAuthenticationWithPassword", "Cannot use \"Authentication={0}\" with \"Password\" connection string keyword."},
{"R_AccessTokenWithUserPassword", "Cannot set the AccessToken property if \"User\", \"UserName\" or \"Password\" has been specified in the connection string."},
{"R_AccessTokenCallbackWithUserPassword", "Cannot set the access token callback if \"User\", \"UserName\" or \"Password\" has been set."},
{"R_AccessTokenCannotBeEmpty", "AccesToken cannot be empty."},
{"R_SetBothAuthenticationAndAccessToken", "Cannot set the AccessToken property if \"Authentication\" has been specified in the connection string."},
{"R_NoUserPasswordForActivePassword", "Both \"User\" (or \"UserName\") and \"Password\" connection string keywords must be specified, if \"Authentication=ActiveDirectoryPassword\"."},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ static void verifyColumnMasterKeyMetadata(SQLServerConnection connection, SQLSer
* @return fedauth token
* @throws SQLServerException
*/
static SqlFedAuthToken getManagedIdentityCredAuthToken(String resource,
static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource,
String managedIdentityClientId) throws SQLServerException {
ManagedIdentityCredential mic = null;

Expand All @@ -343,7 +343,7 @@ static SqlFedAuthToken getManagedIdentityCredAuthToken(String resource,
+ SQLServerMSAL4JUtils.SLASH_DEFAULT;
tokenRequestContext.setScopes(Arrays.asList(scope));

SqlFedAuthToken sqlFedAuthToken = null;
SqlAuthenticationToken sqlFedAuthToken = null;

Optional<AccessToken> accessTokenOptional = mic.getToken(tokenRequestContext).blockOptional();

Expand All @@ -352,7 +352,8 @@ static SqlFedAuthToken getManagedIdentityCredAuthToken(String resource,
null);
} else {
AccessToken accessToken = accessTokenOptional.get();
sqlFedAuthToken = new SqlFedAuthToken(accessToken.getToken(), accessToken.getExpiresAt().toEpochSecond());
sqlFedAuthToken = new SqlAuthenticationToken(accessToken.getToken(),
accessToken.getExpiresAt().toEpochSecond());
}

return sqlFedAuthToken;
Expand All @@ -368,7 +369,7 @@ static SqlFedAuthToken getManagedIdentityCredAuthToken(String resource,
* @return fedauth token
* @throws SQLServerException
*/
static SqlFedAuthToken getDefaultAzureCredAuthToken(String resource,
static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource,
String managedIdentityClientId) throws SQLServerException {
String intellijKeepassPath = System.getenv(INTELLIJ_KEEPASS_PASS);
String[] additionallyAllowedTenants = getAdditonallyAllowedTenants();
Expand All @@ -395,7 +396,7 @@ static SqlFedAuthToken getDefaultAzureCredAuthToken(String resource,
+ SQLServerMSAL4JUtils.SLASH_DEFAULT;
tokenRequestContext.setScopes(Arrays.asList(scope));

SqlFedAuthToken sqlFedAuthToken = null;
SqlAuthenticationToken sqlFedAuthToken = null;

Optional<AccessToken> accessTokenOptional = dac.getToken(tokenRequestContext).blockOptional();

Expand All @@ -404,7 +405,8 @@ static SqlFedAuthToken getDefaultAzureCredAuthToken(String resource,
null);
} else {
AccessToken accessToken = accessTokenOptional.get();
sqlFedAuthToken = new SqlFedAuthToken(accessToken.getToken(), accessToken.getExpiresAt().toEpochSecond());
sqlFedAuthToken = new SqlAuthenticationToken(accessToken.getToken(),
accessToken.getExpiresAt().toEpochSecond());
}

return sqlFedAuthToken;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,39 @@


/**
* Provides an implementation of a FedAuth token
* Provides an implementation of a SqlAuthenticationToken
*/
class SqlFedAuthToken implements Serializable {
public class SqlAuthenticationToken implements Serializable {

/**
* Always update serialVersionUID when prompted
*/
private static final long serialVersionUID = -1343105491285383937L;

final Date expiresOn;
final String accessToken;
private final Date expiresOn;
private final String accessToken;

SqlFedAuthToken(String accessToken, long expiresIn) {
public SqlAuthenticationToken(String accessToken, long expiresIn) {
tkyc marked this conversation as resolved.
Show resolved Hide resolved
this.accessToken = accessToken;

Date now = new Date();
now.setTime(now.getTime() + (expiresIn * 1000));
this.expiresOn = now;
}

SqlFedAuthToken(String accessToken, Date expiresOn) {
public SqlAuthenticationToken(String accessToken, Date expiresOn) {
this.accessToken = accessToken;
this.expiresOn = expiresOn;
}

public Date getExpiresOn() {
return expiresOn;
}

public String getAccessToken() {
return accessToken;
}

public String toString() {
return "accessToken hashCode: " + accessToken.hashCode() + " expiresOn: " + expiresOn.toInstant().toString();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made
* available under the terms of the MIT License. See the LICENSE file in the project root for more information.
*/
package azureactivedirectoryauthentication.src.main.java;

import com.microsoft.aad.msal4j.IClientCredential;
import com.microsoft.aad.msal4j.ClientCredentialFactory;
import com.microsoft.aad.msal4j.ConfidentialClientApplication;
import com.microsoft.aad.msal4j.IAuthenticationResult;
import com.microsoft.aad.msal4j.ClientCredentialParameters;
import java.sql.Connection;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
* Sample code that demonstrates how to use access token callback.
*/
public class AzureActiveDirectoryAccessTokenCallback {

public static void main(String[] args) {

SQLServerAccessTokenCallback callback = new SQLServerAccessTokenCallback() {
@Override
public SqlAuthenticationToken getAccessToken(String spn, String stsurl) {

String clientSecret = "<your-client-secret>";
String clientId = "<your-client-id>";

String scope = spn + "/.default";
Set<String> scopes = new HashSet<>();
scopes.add(scope);

try {
ExecutorService executorService = Executors.newSingleThreadExecutor();
IClientCredential credential = ClientCredentialFactory.createFromSecret(clientSecret);
ConfidentialClientApplication clientApplication = ConfidentialClientApplication
.builder(clientId, credential).executorService(executorService).authority(stsurl).build();
CompletableFuture<IAuthenticationResult> future = clientApplication
.acquireToken(ClientCredentialParameters.builder(scopes).build());

IAuthenticationResult authenticationResult = future.get();
String accessToken = authenticationResult.accessToken();

return new SqlAuthenticationToken(accessToken, authenticationResult.expiresOnDate().getTime());
} catch (Exception e) {
e.printStackTrace();
}
return null;
}
};

SQLServerDataSource ds = new SQLServerDataSource();
ds.setServerName("<your-server-name>");
ds.setDatabaseName("<your-database-name>");
ds.setAccessTokenCallback(callback);

try (Connection conn = (SQLServerConnection) ds.getConnection()) {
System.out.println("Connected...");
}

}
}
tkyc marked this conversation as resolved.
Show resolved Hide resolved
Loading