Skip to content

Commit

Permalink
Set necessary headers when authenticating via Azure CLI (#136)
Browse files Browse the repository at this point in the history
## Changes
The Java SDK request authentication logic is inconsistent between the
Azure login types: for service principal auth, the SDK correctly adds
the X-Databricks-Azure-Workspace-Resource-Id when configured, but this
is missed for Azure CLI auth. Additionally, when logging in via Azure
CLI using a service principal, the service management token must also be
fetched from the CLI.

This PR fixes this by defining the logic to attach these header in a
common function that is used by all Azure-specific authentication types.

See databricks/databricks-sdk-go#584 for the
same change in the Go SDK.
See databricks/databricks-sdk-py#290 for the
same changes in the Python SDK.
## Tests
- [x] Unit tests to cover the two scenarios for Azure CLI w.r.t.
management endpoint token fetching, and one to verify that
X-Databricks-Azure-Workspace-Resource-Id is included when using Azure
CLI.
  • Loading branch information
mgyucht committed Aug 18, 2023
1 parent e0174d0 commit 7bb4fd0
Show file tree
Hide file tree
Showing 11 changed files with 176 additions and 90 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
fmt:
mvn spotless:apply

Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,24 @@ public HeaderFactory configure(DatabricksConfig config) {
ensureHostPresent(config, mapper);
String resource = config.getEffectiveAzureLoginAppId();
CliTokenSource tokenSource = tokenSourceFor(config, resource);
CliTokenSource mgmtTokenSource =
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());
tokenSource.getToken(); // We need this for checking if Azure CLI is installed.
try {
mgmtTokenSource.getToken();
} catch (Exception e) {
LOG.debug("Not including service management token in headers", e);
mgmtTokenSource = null;
}
CliTokenSource finalMgmtTokenSource = mgmtTokenSource;
return () -> {
Token token = tokenSource.getToken();
Map<String, String> headers = new HashMap<>();
headers.put("Authorization", token.getTokenType() + " " + token.getAccessToken());
return headers;
if (finalMgmtTokenSource != null) {
addSpManagementToken(finalMgmtTokenSource, headers);
}
return addWorkspaceResourceId(config, headers);
};
} catch (DatabricksException e) {
String stderr = e.getMessage();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,8 @@ public HeaderFactory configure(DatabricksConfig config) {
return () -> {
Map<String, String> headers = new HashMap<>();
headers.put("Authorization", "Bearer " + inner.getToken().getAccessToken());
headers.put("X-Databricks-Azure-SP-Management-Token", cloud.getToken().getAccessToken());
if (config.getAzureWorkspaceResourceId() != null) {
headers.put(
"X-Databricks-Azure-Workspace-Resource-Id", config.getAzureWorkspaceResourceId());
}
addWorkspaceResourceId(config, headers);
addSpManagementToken(cloud, headers);
return headers;
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,18 @@ default void ensureHostPresent(DatabricksConfig config, ObjectMapper mapper) {
throw new DatabricksException("Unable to fetch workspace URL: " + e.getMessage(), e);
}
}

default Map<String, String> addWorkspaceResourceId(
DatabricksConfig config, Map<String, String> headers) {
if (config.getAzureWorkspaceResourceId() != null) {
headers.put("X-Databricks-Azure-Workspace-Resource-Id", config.getAzureWorkspaceResourceId());
}
return headers;
}

default Map<String, String> addSpManagementToken(
RefreshableTokenSource tokenSource, Map<String, String> headers) {
headers.put("X-Databricks-Azure-SP-Management-Token", tokenSource.getToken().getAccessToken());
return headers;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package com.databricks.sdk;

import com.databricks.sdk.core.ConfigResolving;
import com.databricks.sdk.core.DatabricksConfig;
import com.databricks.sdk.core.utils.TestOSUtils;
import java.util.Map;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

public class DatabricksAuthManualTest implements ConfigResolving {
@Test
void azureCliWorkspaceHeaderPresent() {
StaticEnv env =
new StaticEnv()
.with("HOME", TestOSUtils.resource("/testdata/azure"))
.with("PATH", "testdata:/bin");
String azureWorkspaceResourceId =
"/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123";
DatabricksConfig config =
new DatabricksConfig()
.setAuthType("azure-cli")
.setHost("https://x")
.setAzureWorkspaceResourceId(azureWorkspaceResourceId);
resolveConfig(config, env);
Map<String, String> headers = config.authenticate();
Assertions.assertEquals(
azureWorkspaceResourceId, headers.get("X-Databricks-Azure-Workspace-Resource-Id"));
}

@Test
void azureCliUserWithManagementAccess() {
StaticEnv env =
new StaticEnv()
.with("HOME", TestOSUtils.resource("/testdata/azure"))
.with("PATH", "testdata:/bin");
String azureWorkspaceResourceId =
"/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123";
DatabricksConfig config =
new DatabricksConfig()
.setAuthType("azure-cli")
.setHost("https://x")
.setAzureWorkspaceResourceId(azureWorkspaceResourceId);
resolveConfig(config, env);
Map<String, String> headers = config.authenticate();
Assertions.assertEquals("...", headers.get("X-Databricks-Azure-SP-Management-Token"));
}

@Test
void azureCliUserNoManagementAccess() {
StaticEnv env =
new StaticEnv()
.with("HOME", TestOSUtils.resource("/testdata/azure"))
.with("PATH", "testdata:/bin")
.with("FAIL_IF", "https://management.core.windows.net/");
String azureWorkspaceResourceId =
"/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123";
DatabricksConfig config =
new DatabricksConfig()
.setAuthType("azure-cli")
.setHost("https://x")
.setAzureWorkspaceResourceId(azureWorkspaceResourceId);
resolveConfig(config, env);
Map<String, String> headers = config.authenticate();
Assertions.assertNull(headers.get("X-Databricks-Azure-SP-Management-Token"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,12 @@
import com.databricks.sdk.core.utils.GitHubUtils;
import com.databricks.sdk.core.utils.TestOSUtils;
import java.io.File;
import java.net.URL;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;
import org.junit.jupiter.api.Test;

public class DatabricksAuthTest implements TestOSUtils, GitHubUtils, ConfigResolving {

private static String prefixPath;
public class DatabricksAuthTest implements GitHubUtils, ConfigResolving {

public DatabricksAuthTest() {
setPermissionOnTestAz();
prefixPath = System.getProperty("user.dir") + getTestDir();
}

@Test
Expand Down Expand Up @@ -209,7 +202,7 @@ public void testTestConfigConfigFile() {
@Test
public void testTestConfigConfigFileSkipDefaultProfileIfHostSpecified() {
// Set environment variables
StaticEnv env = new StaticEnv().with("HOME", resource("/testdata"));
StaticEnv env = new StaticEnv().with("HOME", TestOSUtils.resource("/testdata"));
raises(
"default auth: cannot configure default credentials. Config: host=https://x",
() -> {
Expand All @@ -222,7 +215,7 @@ public void testTestConfigConfigFileSkipDefaultProfileIfHostSpecified() {
@Test
public void testTestConfigConfigFileWithEmptyDefaultProfileSelectDefault() {
// Set environment variables
StaticEnv env = new StaticEnv().with("HOME", resource("/testdata/empty_default"));
StaticEnv env = new StaticEnv().with("HOME", TestOSUtils.resource("/testdata/empty_default"));
raises(
"default auth: cannot configure default credentials",
() -> {
Expand All @@ -238,7 +231,7 @@ public void testTestConfigConfigFileWithEmptyDefaultProfileSelectAbc() {
StaticEnv env =
new StaticEnv()
.with("DATABRICKS_CONFIG_PROFILE", "abc")
.with("HOME", resource("/testdata/empty_default"));
.with("HOME", TestOSUtils.resource("/testdata/empty_default"));
DatabricksConfig config = new DatabricksConfig();
resolveConfig(config, env);
config.authenticate();
Expand All @@ -250,7 +243,7 @@ public void testTestConfigConfigFileWithEmptyDefaultProfileSelectAbc() {
@Test
public void testTestConfigPatFromDatabricksCfg() {
// Set environment variables
StaticEnv env = new StaticEnv().with("HOME", resource("/testdata"));
StaticEnv env = new StaticEnv().with("HOME", TestOSUtils.resource("/testdata"));
DatabricksConfig config = new DatabricksConfig();
resolveConfig(config, env);
config.authenticate();
Expand All @@ -265,7 +258,7 @@ public void testTestConfigPatFromDatabricksCfgDotProfile() {
StaticEnv env =
new StaticEnv()
.with("DATABRICKS_CONFIG_PROFILE", "pat.with.dot")
.with("HOME", resource("/testdata"));
.with("HOME", TestOSUtils.resource("/testdata"));
DatabricksConfig config = new DatabricksConfig();
resolveConfig(config, env);
config.authenticate();
Expand All @@ -280,7 +273,7 @@ public void testTestConfigPatFromDatabricksCfgNohostProfile() {
StaticEnv env =
new StaticEnv()
.with("DATABRICKS_CONFIG_PROFILE", "nohost")
.with("HOME", resource("/testdata"));
.with("HOME", TestOSUtils.resource("/testdata"));
raises(
"default auth: cannot configure default credentials. Config: token=***, profile=nohost. Env: DATABRICKS_CONFIG_PROFILE",
() -> {
Expand All @@ -297,7 +290,7 @@ public void testTestConfigConfigProfileAndToken() {
new StaticEnv()
.with("DATABRICKS_CONFIG_PROFILE", "nohost")
.with("DATABRICKS_TOKEN", "x")
.with("HOME", resource("/testdata"));
.with("HOME", TestOSUtils.resource("/testdata"));
raises(
"default auth: cannot configure default credentials. Config: token=***, profile=nohost. Env: DATABRICKS_TOKEN, DATABRICKS_CONFIG_PROFILE",
() -> {
Expand All @@ -314,7 +307,7 @@ public void testTestConfigConfigProfileAndPassword() {
new StaticEnv()
.with("DATABRICKS_CONFIG_PROFILE", "nohost")
.with("DATABRICKS_USERNAME", "x")
.with("HOME", resource("/testdata"));
.with("HOME", TestOSUtils.resource("/testdata"));
raises(
"validate: more than one authorization method configured: basic and pat. Config: token=***, username=x, profile=nohost. Env: DATABRICKS_USERNAME, DATABRICKS_CONFIG_PROFILE",
() -> {
Expand All @@ -341,7 +334,9 @@ public void testTestConfigAzurePat() {
public void testTestConfigAzureCliHost() {
// Set environment variables
StaticEnv env =
new StaticEnv().with("HOME", resource("/testdata/azure")).with("PATH", "testdata:/bin");
new StaticEnv()
.with("HOME", TestOSUtils.resource("/testdata/azure"))
.with("PATH", "testdata:/bin");
DatabricksConfig config =
new DatabricksConfig().setHost("x").setAzureWorkspaceResourceId("/sub/rg/ws");
resolveConfig(config, env);
Expand All @@ -358,7 +353,7 @@ public void testTestConfigAzureCliHostFail() {
StaticEnv env =
new StaticEnv()
.with("FAIL", "yes")
.with("HOME", resource("/testdata/azure"))
.with("HOME", TestOSUtils.resource("/testdata/azure"))
.with("PATH", "testdata:/bin");
raises(
"default auth: azure-cli: cannot get access token: This is just a failing script.\n. Config: azure_workspace_resource_id=/sub/rg/ws",
Expand All @@ -374,7 +369,9 @@ public void testTestConfigAzureCliHostFail() {
public void testTestConfigAzureCliHostAzNotInstalled() {
// Set environment variables
StaticEnv env =
new StaticEnv().with("HOME", resource("/testdata/azure")).with("PATH", "whatever");
new StaticEnv()
.with("HOME", TestOSUtils.resource("/testdata/azure"))
.with("PATH", "whatever");
raises(
"default auth: cannot configure default credentials. Config: azure_workspace_resource_id=/sub/rg/ws",
() -> {
Expand All @@ -389,7 +386,9 @@ public void testTestConfigAzureCliHostAzNotInstalled() {
public void testTestConfigAzureCliHostPatConflictWithConfigFilePresentWithoutDefaultProfile() {
// Set environment variables
StaticEnv env =
new StaticEnv().with("HOME", resource("/testdata/azure")).with("PATH", "testdata:/bin");
new StaticEnv()
.with("HOME", TestOSUtils.resource("/testdata/azure"))
.with("PATH", "testdata:/bin");
raises(
"validate: more than one authorization method configured: azure and pat. Config: token=***, azure_workspace_resource_id=/sub/rg/ws",
() -> {
Expand All @@ -404,7 +403,9 @@ public void testTestConfigAzureCliHostPatConflictWithConfigFilePresentWithoutDef
public void testTestConfigAzureCliHostAndResourceId() {
// Set environment variables
StaticEnv env =
new StaticEnv().with("HOME", resource("/testdata")).with("PATH", "testdata:/bin");
new StaticEnv()
.with("HOME", TestOSUtils.resource("/testdata"))
.with("PATH", "testdata:/bin");
DatabricksConfig config =
new DatabricksConfig().setHost("x").setAzureWorkspaceResourceId("/sub/rg/ws");
resolveConfig(config, env);
Expand All @@ -421,7 +422,7 @@ public void testTestConfigAzureCliHostAndResourceIDConfigurationPrecedence() {
StaticEnv env =
new StaticEnv()
.with("DATABRICKS_CONFIG_PROFILE", "justhost")
.with("HOME", resource("/testdata/azure"))
.with("HOME", TestOSUtils.resource("/testdata/azure"))
.with("PATH", "testdata:/bin");
DatabricksConfig config =
new DatabricksConfig().setHost("x").setAzureWorkspaceResourceId("/sub/rg/ws");
Expand All @@ -439,7 +440,7 @@ public void testTestConfigAzureAndPasswordConflict() {
StaticEnv env =
new StaticEnv()
.with("DATABRICKS_USERNAME", "x")
.with("HOME", resource("/testdata/azure"))
.with("HOME", TestOSUtils.resource("/testdata/azure"))
.with("PATH", "testdata:/bin");
raises(
"validate: more than one authorization method configured: azure and basic. Config: host=x, username=x, azure_workspace_resource_id=/sub/rg/ws. Env: DATABRICKS_USERNAME",
Expand All @@ -457,7 +458,7 @@ public void testTestConfigCorruptConfig() {
StaticEnv env =
new StaticEnv()
.with("DATABRICKS_CONFIG_PROFILE", "DEFAULT")
.with("HOME", resource("/testdata/corrupt"));
.with("HOME", TestOSUtils.resource("/testdata/corrupt"));
raises(
"resolve: testdata/corrupt/.databrickscfg has no DEFAULT profile configured. Config: profile=DEFAULT. Env: DATABRICKS_CONFIG_PROFILE",
() -> {
Expand All @@ -484,31 +485,6 @@ public void testTestConfigAuthTypeFromEnv() {
assertEquals("https://x", config.getHost());
}

private String resource(String file) {
URL resource = getClass().getResource(file);
if (resource == null) {
fail("Asset not found: " + file);
}
return resource.getFile();
}

static class StaticEnv implements Supplier<Map<String, String>> {
private final Map<String, String> env = new HashMap<>();

public StaticEnv with(String key, String value) {
if (key.equals("PATH")) {
value = prefixPath + value;
}
env.put(key, value);
return this;
}

@Override
public Map<String, String> get() {
return env;
}
}

private void raises(String contains, Runnable cb) {
boolean raised = false;
try {
Expand All @@ -521,7 +497,7 @@ private void raises(String contains, Runnable cb) {
File.separator,
"/"); // We would need to do this upstream also for making paths compatible with
// windows
message = message.replace(prefixPath, "");
message = message.replace(StaticEnv.getPrefixPath(), "");
if (!message.contains(contains)) {
fail(String.format("Expected exception to contain '%s'", contains), e);
}
Expand Down
Loading

0 comments on commit 7bb4fd0

Please sign in to comment.