diff --git a/docs/src/main/sphinx/connector/hive-s3.rst b/docs/src/main/sphinx/connector/hive-s3.rst index 68bb7a8e47a1..e6fd1ff61d50 100644 --- a/docs/src/main/sphinx/connector/hive-s3.rst +++ b/docs/src/main/sphinx/connector/hive-s3.rst @@ -108,6 +108,12 @@ Property Name Description ``hive.s3.proxy.preemptive-basic-auth`` Whether to attempt to authenticate preemptively against proxy when using base authorization, defaults to ``false``. +``hive.s3.sts.endpoint`` Optional override for the sts endpoint given that IAM role based + authentication via sts is used. + +``hive.s3.sts.region`` Optional override for the sts region given that IAM role based + authentication via sts is used. + ============================================ ================================================================= .. _hive-s3-credentials: diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/s3/HiveS3Config.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/s3/HiveS3Config.java index aea5e007fedf..7b1abd3389ed 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/s3/HiveS3Config.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/s3/HiveS3Config.java @@ -79,6 +79,8 @@ public class HiveS3Config private String s3proxyUsername; private String s3proxyPassword; private boolean s3preemptiveBasicProxyAuth; + private String s3StsEndpoint; + private String s3StsRegion; public String getS3AwsAccessKey() { @@ -590,4 +592,28 @@ public HiveS3Config setS3PreemptiveBasicProxyAuth(boolean s3preemptiveBasicProxy this.s3preemptiveBasicProxyAuth = s3preemptiveBasicProxyAuth; return this; } + + public String getS3StsEndpoint() + { + return s3StsEndpoint; + } + + @Config("hive.s3.sts.endpoint") + public HiveS3Config setS3StsEndpoint(String s3StsEndpoint) + { + this.s3StsEndpoint = s3StsEndpoint; + return this; + } + + public String getS3StsRegion() + { + return s3StsRegion; + } + + @Config("hive.s3.sts.region") + public HiveS3Config setS3StsRegion(String s3StsRegion) + { + this.s3StsRegion = s3StsRegion; + return this; + } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/s3/TrinoS3ConfigurationInitializer.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/s3/TrinoS3ConfigurationInitializer.java index 49b3554ba0f9..8c3ccfc24d4b 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/s3/TrinoS3ConfigurationInitializer.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/s3/TrinoS3ConfigurationInitializer.java @@ -61,6 +61,8 @@ import static io.trino.plugin.hive.s3.TrinoS3FileSystem.S3_STORAGE_CLASS; import static io.trino.plugin.hive.s3.TrinoS3FileSystem.S3_STREAMING_UPLOAD_ENABLED; import static io.trino.plugin.hive.s3.TrinoS3FileSystem.S3_STREAMING_UPLOAD_PART_SIZE; +import static io.trino.plugin.hive.s3.TrinoS3FileSystem.S3_STS_ENDPOINT; +import static io.trino.plugin.hive.s3.TrinoS3FileSystem.S3_STS_REGION; import static io.trino.plugin.hive.s3.TrinoS3FileSystem.S3_USER_AGENT_PREFIX; import static java.util.stream.Collectors.joining; @@ -106,6 +108,8 @@ public class TrinoS3ConfigurationInitializer private final String s3proxyUsername; private final String s3proxyPassword; private final boolean s3preemptiveBasicProxyAuth; + private final String s3StsEndpoint; + private final String s3StsRegion; @Inject public TrinoS3ConfigurationInitializer(HiveS3Config config) @@ -149,6 +153,8 @@ public TrinoS3ConfigurationInitializer(HiveS3Config config) this.s3proxyUsername = config.getS3ProxyUsername(); this.s3proxyPassword = config.getS3ProxyPassword(); this.s3preemptiveBasicProxyAuth = config.getS3PreemptiveBasicProxyAuth(); + this.s3StsEndpoint = config.getS3StsEndpoint(); + this.s3StsRegion = config.getS3StsRegion(); } @Override @@ -230,5 +236,11 @@ public void initializeConfiguration(Configuration config) config.set(S3_PROXY_PASSWORD, s3proxyPassword); } config.setBoolean(S3_PREEMPTIVE_BASIC_PROXY_AUTH, s3preemptiveBasicProxyAuth); + if (s3StsEndpoint != null) { + config.set(S3_STS_ENDPOINT, s3StsEndpoint); + } + if (s3StsRegion != null) { + config.set(S3_STS_REGION, s3StsRegion); + } } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/s3/TrinoS3FileSystem.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/s3/TrinoS3FileSystem.java index 0e21ade96bb8..844713847564 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/s3/TrinoS3FileSystem.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/s3/TrinoS3FileSystem.java @@ -18,6 +18,7 @@ import com.amazonaws.AmazonServiceException; import com.amazonaws.ClientConfiguration; import com.amazonaws.Protocol; +import com.amazonaws.SdkClientException; import com.amazonaws.auth.AWSCredentials; import com.amazonaws.auth.AWSCredentialsProvider; import com.amazonaws.auth.AWSStaticCredentialsProvider; @@ -32,6 +33,7 @@ import com.amazonaws.event.ProgressEventType; import com.amazonaws.event.ProgressListener; import com.amazonaws.metrics.RequestMetricCollector; +import com.amazonaws.regions.DefaultAwsRegionProviderChain; import com.amazonaws.services.s3.AmazonS3; import com.amazonaws.services.s3.AmazonS3Builder; import com.amazonaws.services.s3.AmazonS3Client; @@ -64,6 +66,7 @@ import com.amazonaws.services.s3.transfer.TransferManager; import com.amazonaws.services.s3.transfer.TransferManagerBuilder; import com.amazonaws.services.s3.transfer.Upload; +import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Splitter; import com.google.common.collect.AbstractSequentialIterator; @@ -205,6 +208,9 @@ public class TrinoS3FileSystem public static final String S3_PROXY_PASSWORD = "trino.s3.proxy.password"; public static final String S3_PREEMPTIVE_BASIC_PROXY_AUTH = "trino.s3.proxy.preemptive-basic-auth"; + public static final String S3_STS_ENDPOINT = "trino.s3.sts.endpoint"; + public static final String S3_STS_REGION = "trino.s3.sts.region"; + private static final Logger log = Logger.get(TrinoS3FileSystem.class); private static final TrinoS3FileSystemStats STATS = new TrinoS3FileSystemStats(); private static final RequestMetricCollector METRIC_COLLECTOR = new TrinoS3FileSystemMetricCollector(STATS); @@ -973,9 +979,37 @@ private AWSCredentialsProvider createAwsCredentialsProvider(URI uri, Configurati .orElseGet(DefaultAWSCredentialsProviderChain::getInstance); if (iamRole != null) { + String stsEndpointOverride = conf.get(S3_STS_ENDPOINT); + String stsRegionOverride = conf.get(S3_STS_REGION); + + AWSSecurityTokenServiceClientBuilder stsClientBuilder = AWSSecurityTokenServiceClientBuilder.standard() + .withCredentials(provider); + + String region; + if (!isNullOrEmpty(stsRegionOverride)) { + region = stsRegionOverride; + } + else { + DefaultAwsRegionProviderChain regionProviderChain = new DefaultAwsRegionProviderChain(); + try { + region = regionProviderChain.getRegion(); + } + catch (SdkClientException ex) { + log.warn("Falling back to default AWS region " + US_EAST_1); + region = US_EAST_1.getName(); + } + } + + if (!isNullOrEmpty(stsEndpointOverride)) { + stsClientBuilder.withEndpointConfiguration(new EndpointConfiguration(stsEndpointOverride, region)); + } + else { + stsClientBuilder.withRegion(region); + } + provider = new STSAssumeRoleSessionCredentialsProvider.Builder(iamRole, s3RoleSessionName) .withExternalId(externalId) - .withLongLivedCredentialsProvider(provider) + .withStsClient(stsClientBuilder.build()) .build(); } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/s3/TestHiveS3Config.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/s3/TestHiveS3Config.java index cf07834b57c0..111c5d3185a5 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/s3/TestHiveS3Config.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/s3/TestHiveS3Config.java @@ -76,7 +76,9 @@ public void testDefaults() .setS3NonProxyHosts(ImmutableList.of()) .setS3ProxyUsername(null) .setS3ProxyPassword(null) - .setS3PreemptiveBasicProxyAuth(false)); + .setS3PreemptiveBasicProxyAuth(false) + .setS3StsEndpoint(null) + .setS3StsRegion(null)); } @Test @@ -125,6 +127,8 @@ public void testExplicitPropertyMappings() .put("hive.s3.proxy.username", "test") .put("hive.s3.proxy.password", "test") .put("hive.s3.proxy.preemptive-basic-auth", "true") + .put("hive.s3.sts.endpoint", "http://minio:9000") + .put("hive.s3.sts.region", "eu-central-1") .buildOrThrow(); HiveS3Config expected = new HiveS3Config() @@ -166,7 +170,9 @@ public void testExplicitPropertyMappings() .setS3NonProxyHosts(ImmutableList.of("test", "test2", "test3")) .setS3ProxyUsername("test") .setS3ProxyPassword("test") - .setS3PreemptiveBasicProxyAuth(true); + .setS3PreemptiveBasicProxyAuth(true) + .setS3StsEndpoint("http://minio:9000") + .setS3StsRegion("eu-central-1"); assertFullMapping(properties, expected); }