Skip to content

Commit

Permalink
Access token callback
Browse files Browse the repository at this point in the history
  • Loading branch information
tkyc committed Oct 26, 2022
1 parent 095c7ee commit c3c3fd6
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1226,4 +1226,8 @@ public interface ISQLServerDataSource extends javax.sql.CommonDataSource {
*/
@Deprecated
int getMsiTokenCacheTtl();

void setAccessTokenCallback(SQLServerAccessTokenCallback accessTokenCallback);

SQLServerAccessTokenCallback getAccessTokenCallback();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.microsoft.sqlserver.jdbc;

/**
* Provides an access token callback to be implemented by client code.
*/
public interface SQLServerAccessTokenCallback {
String getAccessToken();
}
Original file line number Diff line number Diff line change
Expand Up @@ -2503,6 +2503,8 @@ Connection connectInternal(Properties propsIn,
sPropValue = activeConnectionProperties.getProperty(sPropKey);
if (null != sPropValue) {
accessTokenInByte = sPropValue.getBytes(UTF_16LE);
} else if (null != SQLServerDataSource.accessTokenCallback) {
accessTokenInByte = SQLServerDataSource.accessTokenCallback.getAccessToken().getBytes(UTF_16LE);
}

if ((null != accessTokenInByte) && 0 == accessTokenInByte.length) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ public class SQLServerDataSource
*/
final private String traceID;

/**
* Callback method for returning an access token
*/
static SQLServerAccessTokenCallback accessTokenCallback = null;

/**
* Constructs a SQLServerDataSource.
*/
Expand Down Expand Up @@ -1224,6 +1229,16 @@ public int getMsiTokenCacheTtl() {
return 0;
}

@Override
public void setAccessTokenCallback(SQLServerAccessTokenCallback accessTokenCallback) {
this.accessTokenCallback = accessTokenCallback;
}

@Override
public SQLServerAccessTokenCallback getAccessTokenCallback() {
return accessTokenCallback;
}

/**
* Sets a property string value.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package com.microsoft.sqlserver.jdbc.connection;

import static org.junit.Assert.fail;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

Expand All @@ -16,7 +17,13 @@
import java.io.PrintWriter;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import com.microsoft.aad.msal4j.*;
import com.microsoft.sqlserver.jdbc.*;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Tag;
Expand Down Expand Up @@ -66,6 +73,50 @@ public void testDSNormal() throws ClassNotFoundException, IOException, SQLExcept
try (Connection conn = ds.getConnection()) {}
}

@Tag(Constants.xSQLv11)
@Tag(Constants.xSQLv12)
@Tag(Constants.xSQLv14)
@Tag(Constants.xSQLv15)
@Test
public void testPooledConnectionAccessTokenCallback() throws SQLException {
SQLServerAccessTokenCallback callback = new SQLServerAccessTokenCallback() {
@Override
public String getAccessToken() {
String scope = spn + "/.default";
Set<String> scopes = new HashSet<>();
scopes.add(scope);

try {
ExecutorService executorService = Executors.newSingleThreadExecutor();
IClientCredential credential = ClientCredentialFactory.createFromSecret(accessTokenSecret);
ConfidentialClientApplication clientApplication = ConfidentialClientApplication
.builder(accessTokenClientId, credential).executorService(executorService).authority(accessTokenStsUrl).build();
CompletableFuture<IAuthenticationResult> future = clientApplication
.acquireToken(ClientCredentialParameters.builder(scopes).build());

IAuthenticationResult authenticationResult = future.get();
String accessToken = authenticationResult.accessToken();

return accessToken;

} catch (Exception e) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
return null;
}
};

SQLServerConnectionPoolDataSource ds = new SQLServerConnectionPoolDataSource();
AbstractTest.updateDataSource(connectionString, ds);
ds.setUser("");
ds.setPassword("");
ds.setAccessTokenCallback(callback);

// Callback should provide valid token on connection open for all new connections
try (Connection conn1 = (SQLServerConnection) ds.getConnection()) {}
try (Connection conn2 = (SQLServerConnection) ds.getConnection()) {}
}

@Test
@Tag(Constants.xAzureSQLDW)
@Tag(Constants.xAzureSQLDB)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ public abstract class AbstractTest {

// properties needed for Managed Identity
protected static String managedIdentityClientId = null;

// properties for access token callback testing
protected static String accessTokenClientId = null;
protected static String accessTokenStsUrl = null;
protected static String accessTokenSecret = null;
protected static String spn = null;

protected static String keyStorePrincipalId = null;
protected static String keyStoreSecret = null;

Expand Down Expand Up @@ -138,6 +145,11 @@ public static void setup() throws Exception {
applicationKey = getConfiguredProperty("applicationKey");
tenantID = getConfiguredProperty("tenantID");

accessTokenClientId = getConfiguredProperty("accessTokenClientId");
accessTokenSecret = getConfiguredProperty("accessTokenSecret");
accessTokenStsUrl = getConfiguredProperty("accessTokenStsUrl");
spn = getConfiguredProperty("spn");

encrypt = getConfiguredProperty("encrypt", "false");
connectionString = TestUtils.addOrOverrideProperty(connectionString, "encrypt", encrypt);

Expand Down

0 comments on commit c3c3fd6

Please sign in to comment.