Skip to content

Commit

Permalink
chore: use ImpersonatedCredentials for service account impersonation …
Browse files Browse the repository at this point in the history
…for 3pi (#501)

* chore: use ImpersonatedCredentials for service account impersonation in ExternalAccountCredentials

* chore: add test for invalid service account impersonation url
  • Loading branch information
lsirac authored Oct 29, 2020
1 parent 124c77c commit 17e849e
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 131 deletions.
3 changes: 1 addition & 2 deletions oauth2_http/java/com/google/auth/oauth2/AwsCredentials.java
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ public AccessToken refreshAccessToken() throws IOException {
stsTokenExchangeRequest.setScopes(new ArrayList<>(scopes));
}

AccessToken accessToken = exchange3PICredentialForAccessToken(stsTokenExchangeRequest.build());
return attemptServiceAccountImpersonation(accessToken);
return exchange3PICredentialForAccessToken(stsTokenExchangeRequest.build());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,27 +34,18 @@
import static com.google.api.client.util.Preconditions.checkNotNull;
import static com.google.common.base.MoreObjects.firstNonNull;

import com.google.api.client.http.GenericUrl;
import com.google.api.client.http.HttpHeaders;
import com.google.api.client.http.HttpRequest;
import com.google.api.client.http.HttpResponse;
import com.google.api.client.http.UrlEncodedContent;
import com.google.api.client.json.GenericJson;
import com.google.api.client.json.JsonObjectParser;
import com.google.api.client.util.GenericData;
import com.google.auth.http.AuthHttpConstants;
import com.google.auth.http.HttpTransportFactory;
import com.google.auth.oauth2.AwsCredentials.AwsCredentialSource;
import com.google.auth.oauth2.IdentityPoolCredentials.IdentityPoolCredentialSource;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Date;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
Expand All @@ -77,7 +68,6 @@ protected CredentialSource(Map<String, Object> credentialSourceMap) {
}
}

private static final String RFC3339 = "yyyy-MM-dd'T'HH:mm:ss'Z'";
private static final String CLOUD_PLATFORM_SCOPE =
"https://www.googleapis.com/auth/cloud-platform";

Expand All @@ -88,16 +78,18 @@ protected CredentialSource(Map<String, Object> credentialSourceMap) {
protected final String subjectTokenType;
protected final String tokenUrl;
protected final String tokenInfoUrl;
protected final String serviceAccountImpersonationUrl;
protected final CredentialSource credentialSource;
protected final Collection<String> scopes;

@Nullable protected final String serviceAccountImpersonationUrl;
@Nullable protected final String quotaProjectId;
@Nullable protected final String clientId;
@Nullable protected final String clientSecret;

protected transient HttpTransportFactory transportFactory;

@Nullable protected final ImpersonatedCredentials impersonatedCredentials;

/**
* Constructor with minimum identifying information and custom HTTP transport.
*
Expand Down Expand Up @@ -148,6 +140,35 @@ protected ExternalAccountCredentials(
this.clientSecret = clientSecret;
this.scopes =
(scopes == null || scopes.isEmpty()) ? Arrays.asList(CLOUD_PLATFORM_SCOPE) : scopes;
this.impersonatedCredentials = initializeImpersonatedCredentials();
}

private ImpersonatedCredentials initializeImpersonatedCredentials() {
if (serviceAccountImpersonationUrl == null) {
return null;
}
// Create a copy of this instance without service account impersonation.
ExternalAccountCredentials sourceCredentials;
if (this instanceof AwsCredentials) {
sourceCredentials =
AwsCredentials.newBuilder((AwsCredentials) this)
.setServiceAccountImpersonationUrl(null)
.build();
} else {
sourceCredentials =
IdentityPoolCredentials.newBuilder((IdentityPoolCredentials) this)
.setServiceAccountImpersonationUrl(null)
.build();
}

String targetPrincipal = extractTargetPrincipal(serviceAccountImpersonationUrl);
return ImpersonatedCredentials.newBuilder()
.setSourceCredentials(sourceCredentials)
.setHttpTransportFactory(transportFactory)
.setTargetPrincipal(targetPrincipal)
.setScopes(new ArrayList<>(scopes))
.setLifetime(3600) // 1 hour in seconds
.build();
}

@Override
Expand Down Expand Up @@ -262,6 +283,10 @@ private static boolean isAwsCredential(Map<String, Object> credentialSource) {
*/
protected AccessToken exchange3PICredentialForAccessToken(
StsTokenExchangeRequest stsTokenExchangeRequest) throws IOException {
// Handle service account impersonation if necessary.
if (impersonatedCredentials != null) {
return impersonatedCredentials.refreshAccessToken();
}

StsRequestHandler requestHandler =
StsRequestHandler.newBuilder(
Expand All @@ -273,52 +298,16 @@ protected AccessToken exchange3PICredentialForAccessToken(
return response.getAccessToken();
}

/**
* Attempts service account impersonation.
*
* @param accessToken the access token to be included in the request.
* @return the access token returned by the generateAccessToken call.
* @throws IOException if the service account impersonation call fails.
*/
protected AccessToken attemptServiceAccountImpersonation(AccessToken accessToken)
throws IOException {
if (serviceAccountImpersonationUrl == null) {
return accessToken;
}

HttpRequest request =
transportFactory
.create()
.createRequestFactory()
.buildPostRequest(
new GenericUrl(serviceAccountImpersonationUrl),
new UrlEncodedContent(new GenericData().set("scope", scopes.toArray())));
request.setParser(new JsonObjectParser(OAuth2Utils.JSON_FACTORY));
request.setHeaders(
new HttpHeaders()
.setAuthorization(
String.format("%s %s", AuthHttpConstants.BEARER, accessToken.getTokenValue())));

HttpResponse response;
try {
response = request.execute();
} catch (IOException e) {
throw new IOException(
String.format("Error getting access token for service account: %s", e.getMessage()), e);
}
private static String extractTargetPrincipal(String serviceAccountImpersonationUrl) {
// Extract the target principle.
int startIndex = serviceAccountImpersonationUrl.lastIndexOf('/');
int endIndex = serviceAccountImpersonationUrl.indexOf(":generateAccessToken");

GenericData responseData = response.parseAs(GenericData.class);
String token =
OAuth2Utils.validateString(responseData, "accessToken", "Expected to find an accessToken");

DateFormat format = new SimpleDateFormat(RFC3339);
String expireTime =
OAuth2Utils.validateString(responseData, "expireTime", "Expected to find an expireTime");
try {
Date date = format.parse(expireTime);
return new AccessToken(token, date);
} catch (ParseException e) {
throw new IOException("Error parsing expireTime: " + e.getMessage());
if (startIndex != -1 && endIndex != -1 && startIndex < endIndex) {
return serviceAccountImpersonationUrl.substring(startIndex + 1, endIndex);
} else {
throw new IllegalArgumentException(
"Unable to determine target principal from service account impersonation URL.");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,7 @@ public AccessToken refreshAccessToken() throws IOException {
stsTokenExchangeRequest.setScopes(new ArrayList<>(scopes));
}

AccessToken accessToken = exchange3PICredentialForAccessToken(stsTokenExchangeRequest.build());
return attemptServiceAccountImpersonation(accessToken);
return exchange3PICredentialForAccessToken(stsTokenExchangeRequest.build());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ public class AwsCredentialsTest {
private static final String GET_CALLER_IDENTITY_URL =
"https://sts.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15";

private static final String SERVICE_ACCOUNT_IMPERSONATION_URL =
"https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/testn@test.iam.gserviceaccount.com:generateAccessToken";

private static final Map<String, Object> AWS_CREDENTIAL_SOURCE_MAP =
new HashMap<String, Object>() {
{
Expand Down Expand Up @@ -125,7 +128,8 @@ public void refreshAccessToken_withServiceAccountImpersonation() throws IOExcept

AccessToken accessToken = awsCredential.refreshAccessToken();

assertEquals(transportFactory.transport.getAccessToken(), accessToken.getTokenValue());
assertEquals(
transportFactory.transport.getServiceAccountAccessToken(), accessToken.getTokenValue());
}

@Test
Expand Down Expand Up @@ -293,7 +297,7 @@ public void createdScoped_clonedCredentialWithAddedScopes() {
AwsCredentials credentials =
(AwsCredentials)
AwsCredentials.newBuilder(AWS_CREDENTIAL)
.setServiceAccountImpersonationUrl("tokenInfoUrl")
.setServiceAccountImpersonationUrl(SERVICE_ACCOUNT_IMPERSONATION_URL)
.setQuotaProjectId("quotaProjectId")
.setClientId("clientId")
.setClientSecret("clientSecret")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,12 @@

import static com.google.auth.TestUtils.getDefaultExpireTime;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;

import com.google.api.client.http.HttpTransport;
import com.google.api.client.json.GenericJson;
import com.google.api.client.testing.http.MockLowLevelHttpRequest;
import com.google.auth.TestUtils;
import com.google.auth.http.HttpTransportFactory;
import com.google.auth.oauth2.ExternalAccountCredentialsTest.TestExternalAccountCredentials.TestCredentialSource;
Expand All @@ -64,9 +62,6 @@
public class ExternalAccountCredentialsTest {

private static final String STS_URL = "https://www.sts.google.com";
private static final String ACCESS_TOKEN = "eya23tfgdfga2123as";
private static final String CLOUD_PLATFORM_SCOPE =
"https://www.googleapis.com/auth/cloud-platform";

static class MockExternalAccountCredentialsTransportFactory implements HttpTransportFactory {

Expand Down Expand Up @@ -175,6 +170,25 @@ public void run() {
});
}

@Test
public void fromJson_invalidServiceAccountImpersonationUrl_throws() {
final GenericJson json = buildJsonIdentityPoolCredential();
json.put("service_account_impersonation_url", "invalid_url");

IllegalArgumentException e =
assertThrows(
IllegalArgumentException.class,
new ThrowingRunnable() {
@Override
public void run() {
ExternalAccountCredentials.fromJson(json, OAuth2Utils.HTTP_TRANSPORT_FACTORY);
}
});
assertEquals(
"Unable to determine target principal from service account impersonation URL.",
e.getMessage());
}

@Test
public void fromJson_nullTransport_throws() {
assertThrows(
Expand Down Expand Up @@ -207,6 +221,29 @@ public void exchange3PICredentialForAccessToken() throws IOException {
assertEquals("application/x-www-form-urlencoded", headers.get("content-type").get(0));
}

@Test
public void exchange3PICredentialForAccessToken_withServiceAccountImpersonation()
throws IOException {
transportFactory.transport.setExpireTime(getDefaultExpireTime());

ExternalAccountCredentials credential =
ExternalAccountCredentials.fromStream(
IdentityPoolCredentialsTest.writeIdentityPoolCredentialsStream(
transportFactory.transport.getStsUrl(),
transportFactory.transport.getMetadataUrl(),
transportFactory.transport.getServiceAccountImpersonationUrl()),
transportFactory);

StsTokenExchangeRequest stsTokenExchangeRequest =
StsTokenExchangeRequest.newBuilder("credential", "subjectTokenType").build();

AccessToken returnedToken =
credential.exchange3PICredentialForAccessToken(stsTokenExchangeRequest);

assertEquals(
transportFactory.transport.getServiceAccountAccessToken(), returnedToken.getTokenValue());
}

@Test
public void exchange3PICredentialForAccessToken_throws() throws IOException {
final ExternalAccountCredentials credential =
Expand Down Expand Up @@ -236,44 +273,6 @@ public void run() throws Throwable {
assertEquals(errorUri, e.getErrorUri());
}

@Test
public void attemptServiceAccountImpersonation() throws IOException {
GenericJson defaultCredential = buildJsonIdentityPoolCredential();
defaultCredential.put(
"service_account_impersonation_url",
transportFactory.transport.getServiceAccountImpersonationUrl());

ExternalAccountCredentials credential =
ExternalAccountCredentials.fromJson(defaultCredential, transportFactory);

transportFactory.transport.setExpireTime(getDefaultExpireTime());
AccessToken accessToken = new AccessToken(ACCESS_TOKEN, new Date());

AccessToken returnedToken = credential.attemptServiceAccountImpersonation(accessToken);

assertEquals(transportFactory.transport.getAccessToken(), returnedToken.getTokenValue());
assertNotEquals(accessToken.getTokenValue(), returnedToken.getTokenValue());

// Validate request content.
MockLowLevelHttpRequest request = transportFactory.transport.getRequest();
Map<String, String> actualRequestContent = TestUtils.parseQuery(request.getContentAsString());

Map<String, String> expectedRequestContent = new HashMap<>();
expectedRequestContent.put("scope", CLOUD_PLATFORM_SCOPE);
assertEquals(expectedRequestContent, actualRequestContent);
}

@Test
public void attemptServiceAccountImpersonation_noUrl() throws IOException {
ExternalAccountCredentials credential =
ExternalAccountCredentials.fromJson(buildJsonIdentityPoolCredential(), transportFactory);

AccessToken accessToken = new AccessToken(ACCESS_TOKEN, new Date());
AccessToken returnedToken = credential.attemptServiceAccountImpersonation(accessToken);

assertEquals(accessToken, returnedToken);
}

@Test
public void getRequestMetadata_withQuotaProjectId() throws IOException {
TestExternalAccountCredentials testCredentials =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,9 @@ public void fromStream_identityPoolCredentials_providesToken() throws IOExceptio
new MockExternalAccountCredentialsTransportFactory();
InputStream identityPoolCredentialStream =
IdentityPoolCredentialsTest.writeIdentityPoolCredentialsStream(
transportFactory.transport.getStsUrl(), transportFactory.transport.getMetadataUrl());
transportFactory.transport.getStsUrl(),
transportFactory.transport.getMetadataUrl(),
/* serviceAccountImpersonationUrl= */ null);

GoogleCredentials credentials =
GoogleCredentials.fromStream(identityPoolCredentialStream, transportFactory);
Expand Down
Loading

0 comments on commit 17e849e

Please sign in to comment.