Skip to content

Commit

Permalink
Allow to specify the sts endpoint for hive connector on s3
Browse files Browse the repository at this point in the history
  • Loading branch information
clemensvonschwerin authored and losipiuk committed Mar 15, 2022
1 parent ccb3d30 commit 0d66aa4
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 3 deletions.
6 changes: 6 additions & 0 deletions docs/src/main/sphinx/connector/hive-s3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ Property Name Description
``hive.s3.proxy.preemptive-basic-auth`` Whether to attempt to authenticate preemptively against proxy
when using base authorization, defaults to ``false``.

``hive.s3.sts.endpoint`` Optional override for the sts endpoint given that IAM role based
authentication via sts is used.

``hive.s3.sts.region`` Optional override for the sts region given that IAM role based
authentication via sts is used.

============================================ =================================================================

.. _hive-s3-credentials:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ public class HiveS3Config
private String s3proxyUsername;
private String s3proxyPassword;
private boolean s3preemptiveBasicProxyAuth;
private String s3StsEndpoint;
private String s3StsRegion;

public String getS3AwsAccessKey()
{
Expand Down Expand Up @@ -590,4 +592,28 @@ public HiveS3Config setS3PreemptiveBasicProxyAuth(boolean s3preemptiveBasicProxy
this.s3preemptiveBasicProxyAuth = s3preemptiveBasicProxyAuth;
return this;
}

public String getS3StsEndpoint()
{
return s3StsEndpoint;
}

@Config("hive.s3.sts.endpoint")
public HiveS3Config setS3StsEndpoint(String s3StsEndpoint)
{
this.s3StsEndpoint = s3StsEndpoint;
return this;
}

public String getS3StsRegion()
{
return s3StsRegion;
}

@Config("hive.s3.sts.region")
public HiveS3Config setS3StsRegion(String s3StsRegion)
{
this.s3StsRegion = s3StsRegion;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
import static io.trino.plugin.hive.s3.TrinoS3FileSystem.S3_STORAGE_CLASS;
import static io.trino.plugin.hive.s3.TrinoS3FileSystem.S3_STREAMING_UPLOAD_ENABLED;
import static io.trino.plugin.hive.s3.TrinoS3FileSystem.S3_STREAMING_UPLOAD_PART_SIZE;
import static io.trino.plugin.hive.s3.TrinoS3FileSystem.S3_STS_ENDPOINT;
import static io.trino.plugin.hive.s3.TrinoS3FileSystem.S3_STS_REGION;
import static io.trino.plugin.hive.s3.TrinoS3FileSystem.S3_USER_AGENT_PREFIX;
import static java.util.stream.Collectors.joining;

Expand Down Expand Up @@ -106,6 +108,8 @@ public class TrinoS3ConfigurationInitializer
private final String s3proxyUsername;
private final String s3proxyPassword;
private final boolean s3preemptiveBasicProxyAuth;
private final String s3StsEndpoint;
private final String s3StsRegion;

@Inject
public TrinoS3ConfigurationInitializer(HiveS3Config config)
Expand Down Expand Up @@ -149,6 +153,8 @@ public TrinoS3ConfigurationInitializer(HiveS3Config config)
this.s3proxyUsername = config.getS3ProxyUsername();
this.s3proxyPassword = config.getS3ProxyPassword();
this.s3preemptiveBasicProxyAuth = config.getS3PreemptiveBasicProxyAuth();
this.s3StsEndpoint = config.getS3StsEndpoint();
this.s3StsRegion = config.getS3StsRegion();
}

@Override
Expand Down Expand Up @@ -230,5 +236,11 @@ public void initializeConfiguration(Configuration config)
config.set(S3_PROXY_PASSWORD, s3proxyPassword);
}
config.setBoolean(S3_PREEMPTIVE_BASIC_PROXY_AUTH, s3preemptiveBasicProxyAuth);
if (s3StsEndpoint != null) {
config.set(S3_STS_ENDPOINT, s3StsEndpoint);
}
if (s3StsRegion != null) {
config.set(S3_STS_REGION, s3StsRegion);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.amazonaws.AmazonServiceException;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.Protocol;
import com.amazonaws.SdkClientException;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
Expand All @@ -32,6 +33,7 @@
import com.amazonaws.event.ProgressEventType;
import com.amazonaws.event.ProgressListener;
import com.amazonaws.metrics.RequestMetricCollector;
import com.amazonaws.regions.DefaultAwsRegionProviderChain;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3Builder;
import com.amazonaws.services.s3.AmazonS3Client;
Expand Down Expand Up @@ -64,6 +66,7 @@
import com.amazonaws.services.s3.transfer.TransferManager;
import com.amazonaws.services.s3.transfer.TransferManagerBuilder;
import com.amazonaws.services.s3.transfer.Upload;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Splitter;
import com.google.common.collect.AbstractSequentialIterator;
Expand Down Expand Up @@ -205,6 +208,9 @@ public class TrinoS3FileSystem
public static final String S3_PROXY_PASSWORD = "trino.s3.proxy.password";
public static final String S3_PREEMPTIVE_BASIC_PROXY_AUTH = "trino.s3.proxy.preemptive-basic-auth";

public static final String S3_STS_ENDPOINT = "trino.s3.sts.endpoint";
public static final String S3_STS_REGION = "trino.s3.sts.region";

private static final Logger log = Logger.get(TrinoS3FileSystem.class);
private static final TrinoS3FileSystemStats STATS = new TrinoS3FileSystemStats();
private static final RequestMetricCollector METRIC_COLLECTOR = new TrinoS3FileSystemMetricCollector(STATS);
Expand Down Expand Up @@ -973,9 +979,37 @@ private AWSCredentialsProvider createAwsCredentialsProvider(URI uri, Configurati
.orElseGet(DefaultAWSCredentialsProviderChain::getInstance);

if (iamRole != null) {
String stsEndpointOverride = conf.get(S3_STS_ENDPOINT);
String stsRegionOverride = conf.get(S3_STS_REGION);

AWSSecurityTokenServiceClientBuilder stsClientBuilder = AWSSecurityTokenServiceClientBuilder.standard()
.withCredentials(provider);

String region;
if (!isNullOrEmpty(stsRegionOverride)) {
region = stsRegionOverride;
}
else {
DefaultAwsRegionProviderChain regionProviderChain = new DefaultAwsRegionProviderChain();
try {
region = regionProviderChain.getRegion();
}
catch (SdkClientException ex) {
log.warn("Falling back to default AWS region " + US_EAST_1);
region = US_EAST_1.getName();
}
}

if (!isNullOrEmpty(stsEndpointOverride)) {
stsClientBuilder.withEndpointConfiguration(new EndpointConfiguration(stsEndpointOverride, region));
}
else {
stsClientBuilder.withRegion(region);
}

provider = new STSAssumeRoleSessionCredentialsProvider.Builder(iamRole, s3RoleSessionName)
.withExternalId(externalId)
.withLongLivedCredentialsProvider(provider)
.withStsClient(stsClientBuilder.build())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ public void testDefaults()
.setS3NonProxyHosts(ImmutableList.of())
.setS3ProxyUsername(null)
.setS3ProxyPassword(null)
.setS3PreemptiveBasicProxyAuth(false));
.setS3PreemptiveBasicProxyAuth(false)
.setS3StsEndpoint(null)
.setS3StsRegion(null));
}

@Test
Expand Down Expand Up @@ -125,6 +127,8 @@ public void testExplicitPropertyMappings()
.put("hive.s3.proxy.username", "test")
.put("hive.s3.proxy.password", "test")
.put("hive.s3.proxy.preemptive-basic-auth", "true")
.put("hive.s3.sts.endpoint", "http://minio:9000")
.put("hive.s3.sts.region", "eu-central-1")
.buildOrThrow();

HiveS3Config expected = new HiveS3Config()
Expand Down Expand Up @@ -166,7 +170,9 @@ public void testExplicitPropertyMappings()
.setS3NonProxyHosts(ImmutableList.of("test", "test2", "test3"))
.setS3ProxyUsername("test")
.setS3ProxyPassword("test")
.setS3PreemptiveBasicProxyAuth(true);
.setS3PreemptiveBasicProxyAuth(true)
.setS3StsEndpoint("http://minio:9000")
.setS3StsRegion("eu-central-1");

assertFullMapping(properties, expected);
}
Expand Down

0 comments on commit 0d66aa4

Please sign in to comment.