Skip to content

Commit

Permalink
Refactoring gated and ref-counted interfaces and their implementations (
Browse files Browse the repository at this point in the history
#2396)

* Reducing duplication in plugins around ref-counted releasable classes

Both AmazonEc2Reference and AmazonS3Reference duplicate the same logic - a subclass of AbstractRefCounted that also implements Releasable. This change centralizes this paradigm into a AbstractRefCountedReleasable class and supports both clients via generics. It also updates all fetching implementations to use the get() method instead of client()

Signed-off-by: Kartik Ganesh <gkart@amazon.com>

* Introduce Reference classes for the Closeable and AutoCloseable interfaces

These classes allow you to wrap a reference instance with an onClose runnable that is executed when close() is invoked. Two separate classes are needed because the close() signatures for the two interfaces are different. This change takes the first step to have implementing classes extend from these generic superclasses, before attempting to remove the subclasses entirely. The get() method is also replaced throughout the code base.

Note that there is also a separate Releasable interface that has a similar access pattern, but is implemented separately. This is used in AbstractRefCountedReleasable introduced in a prior commit

Signed-off-by: Kartik Ganesh <gkart@amazon.com>

* More improvements and refactoring

* Functionality around one-way gating is now moved to a dedicated class - OneWayGate. This replaces duplicate functionality throughout the code.
* The two *Reference classes have been renamed to Gated* since that better represents their functionality
* The AbstractRefCountedReleasable has been improved to no longer be abstract by accepting the shutdown hook. This removes the need for the inner class in ReleasableBytesReference, and further simplifies the plugin subclasses (these could probably be removed entirely).
* Finally, unit tests have been added for some classes

Signed-off-by: Kartik Ganesh <gkart@amazon.com>

* Added tests for GatedCloseable

Also updated the license information in GatedAutoCloseableTests

Signed-off-by: Kartik Ganesh <gkart@amazon.com>

* Fixing license information in new files

Signed-off-by: Kartik Ganesh <gkart@amazon.com>

* Added unit tests for RefCountedReleasable

Signed-off-by: Kartik Ganesh <gkart@amazon.com>
  • Loading branch information
kartg authored Mar 9, 2022
1 parent 5a9a114 commit fb9e150
Show file tree
Hide file tree
Showing 33 changed files with 493 additions and 197 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@

package org.opensearch.common.util.concurrent;

import org.opensearch.common.concurrent.OneWayGate;
import org.opensearch.test.OpenSearchTestCase;
import org.hamcrest.Matchers;

import java.io.IOException;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
Expand Down Expand Up @@ -138,19 +138,19 @@ public void run() {

private final class MyRefCounted extends AbstractRefCounted {

private final AtomicBoolean closed = new AtomicBoolean(false);
private final OneWayGate gate = new OneWayGate();

MyRefCounted() {
super("test");
}

@Override
protected void closeInternal() {
this.closed.set(true);
gate.close();
}

public void ensureOpen() {
if (closed.get()) {
if (gate.isClosed()) {
assert this.refCount() == 0;
throw new IllegalStateException("closed");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,42 +33,15 @@
package org.opensearch.discovery.ec2;

import com.amazonaws.services.ec2.AmazonEC2;

import org.opensearch.common.lease.Releasable;
import org.opensearch.common.util.concurrent.AbstractRefCounted;
import org.opensearch.common.concurrent.RefCountedReleasable;

/**
* Handles the shutdown of the wrapped {@link AmazonEC2} using reference
* counting.
*/
public class AmazonEc2Reference extends AbstractRefCounted implements Releasable {

private final AmazonEC2 client;
public class AmazonEc2Reference extends RefCountedReleasable<AmazonEC2> {

AmazonEc2Reference(AmazonEC2 client) {
super("AWS_EC2_CLIENT");
this.client = client;
super("AWS_EC2_CLIENT", client, client::shutdown);
}

/**
* Call when the client is not needed anymore.
*/
@Override
public void close() {
decRef();
}

/**
* Returns the underlying `AmazonEC2` client. All method calls are permitted BUT
* NOT shutdown. Shutdown is called when reference count reaches 0.
*/
public AmazonEC2 client() {
return client;
}

@Override
protected void closeInternal() {
client.shutdown();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ protected List<TransportAddress> fetchDynamicNodes() {
// NOTE: we don't filter by security group during the describe instances request for two reasons:
// 1. differences in VPCs require different parameters during query (ID vs Name)
// 2. We want to use two different strategies: (all security groups vs. any security groups)
descInstances = SocketAccess.doPrivileged(() -> clientReference.client().describeInstances(buildDescribeInstancesRequest()));
descInstances = SocketAccess.doPrivileged(() -> clientReference.get().describeInstances(buildDescribeInstancesRequest()));
} catch (final AmazonClientException e) {
logger.info("Exception while retrieving instance list from AWS API: {}", e.getMessage());
logger.debug("Full exception:", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,15 @@ public void testNodeAttributesErrorLenient() throws Exception {

public void testDefaultEndpoint() throws IOException {
try (Ec2DiscoveryPluginMock plugin = new Ec2DiscoveryPluginMock(Settings.EMPTY)) {
final String endpoint = ((AmazonEC2Mock) plugin.ec2Service.client().client()).endpoint;
final String endpoint = ((AmazonEC2Mock) plugin.ec2Service.client().get()).endpoint;
assertThat(endpoint, is(""));
}
}

public void testSpecificEndpoint() throws IOException {
final Settings settings = Settings.builder().put(Ec2ClientSettings.ENDPOINT_SETTING.getKey(), "ec2.endpoint").build();
try (Ec2DiscoveryPluginMock plugin = new Ec2DiscoveryPluginMock(settings)) {
final String endpoint = ((AmazonEC2Mock) plugin.ec2Service.client().client()).endpoint;
final String endpoint = ((AmazonEC2Mock) plugin.ec2Service.client().get()).endpoint;
assertThat(endpoint, is("ec2.endpoint"));
}
}
Expand Down Expand Up @@ -150,7 +150,7 @@ public void testClientSettingsReInit() throws IOException {
try (Ec2DiscoveryPluginMock plugin = new Ec2DiscoveryPluginMock(settings1)) {
try (AmazonEc2Reference clientReference = plugin.ec2Service.client()) {
{
final AWSCredentials credentials = ((AmazonEC2Mock) clientReference.client()).credentials.getCredentials();
final AWSCredentials credentials = ((AmazonEC2Mock) clientReference.get()).credentials.getCredentials();
assertThat(credentials.getAWSAccessKeyId(), is("ec2_access_1"));
assertThat(credentials.getAWSSecretKey(), is("ec2_secret_1"));
if (mockSecure1HasSessionToken) {
Expand All @@ -159,32 +159,32 @@ public void testClientSettingsReInit() throws IOException {
} else {
assertThat(credentials, instanceOf(BasicAWSCredentials.class));
}
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyUsername(), is("proxy_username_1"));
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyPassword(), is("proxy_password_1"));
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyHost(), is("proxy_host_1"));
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyPort(), is(881));
assertThat(((AmazonEC2Mock) clientReference.client()).endpoint, is("ec2_endpoint_1"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyUsername(), is("proxy_username_1"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyPassword(), is("proxy_password_1"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyHost(), is("proxy_host_1"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyPort(), is(881));
assertThat(((AmazonEC2Mock) clientReference.get()).endpoint, is("ec2_endpoint_1"));
}
// reload secure settings2
plugin.reload(settings2);
// client is not released, it is still using the old settings
{
final AWSCredentials credentials = ((AmazonEC2Mock) clientReference.client()).credentials.getCredentials();
final AWSCredentials credentials = ((AmazonEC2Mock) clientReference.get()).credentials.getCredentials();
if (mockSecure1HasSessionToken) {
assertThat(credentials, instanceOf(BasicSessionCredentials.class));
assertThat(((BasicSessionCredentials) credentials).getSessionToken(), is("ec2_session_token_1"));
} else {
assertThat(credentials, instanceOf(BasicAWSCredentials.class));
}
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyUsername(), is("proxy_username_1"));
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyPassword(), is("proxy_password_1"));
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyHost(), is("proxy_host_1"));
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyPort(), is(881));
assertThat(((AmazonEC2Mock) clientReference.client()).endpoint, is("ec2_endpoint_1"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyUsername(), is("proxy_username_1"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyPassword(), is("proxy_password_1"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyHost(), is("proxy_host_1"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyPort(), is(881));
assertThat(((AmazonEC2Mock) clientReference.get()).endpoint, is("ec2_endpoint_1"));
}
}
try (AmazonEc2Reference clientReference = plugin.ec2Service.client()) {
final AWSCredentials credentials = ((AmazonEC2Mock) clientReference.client()).credentials.getCredentials();
final AWSCredentials credentials = ((AmazonEC2Mock) clientReference.get()).credentials.getCredentials();
assertThat(credentials.getAWSAccessKeyId(), is("ec2_access_2"));
assertThat(credentials.getAWSSecretKey(), is("ec2_secret_2"));
if (mockSecure2HasSessionToken) {
Expand All @@ -193,11 +193,11 @@ public void testClientSettingsReInit() throws IOException {
} else {
assertThat(credentials, instanceOf(BasicAWSCredentials.class));
}
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyUsername(), is("proxy_username_2"));
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyPassword(), is("proxy_password_2"));
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyHost(), is("proxy_host_2"));
assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyPort(), is(882));
assertThat(((AmazonEC2Mock) clientReference.client()).endpoint, is("ec2_endpoint_2"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyUsername(), is("proxy_username_2"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyPassword(), is("proxy_password_2"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyHost(), is("proxy_host_2"));
assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyPort(), is(882));
assertThat(((AmazonEC2Mock) clientReference.get()).endpoint, is("ec2_endpoint_2"));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,45 +32,17 @@

package org.opensearch.repositories.s3;

import org.opensearch.common.util.concurrent.AbstractRefCounted;

import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3Client;

import org.opensearch.common.lease.Releasable;
import org.opensearch.common.concurrent.RefCountedReleasable;

/**
* Handles the shutdown of the wrapped {@link AmazonS3Client} using reference
* counting.
*/
public class AmazonS3Reference extends AbstractRefCounted implements Releasable {

private final AmazonS3 client;
public class AmazonS3Reference extends RefCountedReleasable<AmazonS3> {

AmazonS3Reference(AmazonS3 client) {
super("AWS_S3_CLIENT");
this.client = client;
}

/**
* Call when the client is not needed anymore.
*/
@Override
public void close() {
decRef();
super("AWS_S3_CLIENT", client, client::shutdown);
}

/**
* Returns the underlying `AmazonS3` client. All method calls are permitted BUT
* NOT shutdown. Shutdown is called when reference count reaches 0.
*/
public AmazonS3 client() {
return client;
}

@Override
protected void closeInternal() {
client.shutdown();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class S3BlobContainer extends AbstractBlobContainer {
@Override
public boolean blobExists(String blobName) {
try (AmazonS3Reference clientReference = blobStore.clientReference()) {
return SocketAccess.doPrivileged(() -> clientReference.client().doesObjectExist(blobStore.bucket(), buildKey(blobName)));
return SocketAccess.doPrivileged(() -> clientReference.get().doesObjectExist(blobStore.bucket(), buildKey(blobName)));
} catch (final Exception e) {
throw new BlobStoreException("Failed to check if blob [" + blobName + "] exists", e);
}
Expand Down Expand Up @@ -169,13 +169,13 @@ public DeleteResult delete() throws IOException {
ObjectListing list;
if (prevListing != null) {
final ObjectListing finalPrevListing = prevListing;
list = SocketAccess.doPrivileged(() -> clientReference.client().listNextBatchOfObjects(finalPrevListing));
list = SocketAccess.doPrivileged(() -> clientReference.get().listNextBatchOfObjects(finalPrevListing));
} else {
final ListObjectsRequest listObjectsRequest = new ListObjectsRequest();
listObjectsRequest.setBucketName(blobStore.bucket());
listObjectsRequest.setPrefix(keyPath);
listObjectsRequest.setRequestMetricCollector(blobStore.listMetricCollector);
list = SocketAccess.doPrivileged(() -> clientReference.client().listObjects(listObjectsRequest));
list = SocketAccess.doPrivileged(() -> clientReference.get().listObjects(listObjectsRequest));
}
final List<String> blobsToDelete = new ArrayList<>();
list.getObjectSummaries().forEach(s3ObjectSummary -> {
Expand Down Expand Up @@ -236,7 +236,7 @@ private void doDeleteBlobs(List<String> blobNames, boolean relative) throws IOEx
.map(DeleteObjectsRequest.KeyVersion::getKey)
.collect(Collectors.toList());
try {
clientReference.client().deleteObjects(deleteRequest);
clientReference.get().deleteObjects(deleteRequest);
outstanding.removeAll(keysInRequest);
} catch (MultiObjectDeleteException e) {
// We are sending quiet mode requests so we can't use the deleted keys entry on the exception and instead
Expand Down Expand Up @@ -324,9 +324,9 @@ private static List<ObjectListing> executeListing(AmazonS3Reference clientRefere
ObjectListing list;
if (prevListing != null) {
final ObjectListing finalPrevListing = prevListing;
list = SocketAccess.doPrivileged(() -> clientReference.client().listNextBatchOfObjects(finalPrevListing));
list = SocketAccess.doPrivileged(() -> clientReference.get().listNextBatchOfObjects(finalPrevListing));
} else {
list = SocketAccess.doPrivileged(() -> clientReference.client().listObjects(listObjectsRequest));
list = SocketAccess.doPrivileged(() -> clientReference.get().listObjects(listObjectsRequest));
}
results.add(list);
if (list.isTruncated()) {
Expand Down Expand Up @@ -374,7 +374,7 @@ void executeSingleUpload(final S3BlobStore blobStore, final String blobName, fin
putRequest.setRequestMetricCollector(blobStore.putMetricCollector);

try (AmazonS3Reference clientReference = blobStore.clientReference()) {
SocketAccess.doPrivilegedVoid(() -> { clientReference.client().putObject(putRequest); });
SocketAccess.doPrivilegedVoid(() -> { clientReference.get().putObject(putRequest); });
} catch (final AmazonClientException e) {
throw new IOException("Unable to upload object [" + blobName + "] using a single upload", e);
}
Expand Down Expand Up @@ -413,7 +413,7 @@ void executeMultipartUpload(final S3BlobStore blobStore, final String blobName,
}
try (AmazonS3Reference clientReference = blobStore.clientReference()) {

uploadId.set(SocketAccess.doPrivileged(() -> clientReference.client().initiateMultipartUpload(initRequest).getUploadId()));
uploadId.set(SocketAccess.doPrivileged(() -> clientReference.get().initiateMultipartUpload(initRequest).getUploadId()));
if (Strings.isEmpty(uploadId.get())) {
throw new IOException("Failed to initialize multipart upload " + blobName);
}
Expand All @@ -439,7 +439,7 @@ void executeMultipartUpload(final S3BlobStore blobStore, final String blobName,
}
bytesCount += uploadRequest.getPartSize();

final UploadPartResult uploadResponse = SocketAccess.doPrivileged(() -> clientReference.client().uploadPart(uploadRequest));
final UploadPartResult uploadResponse = SocketAccess.doPrivileged(() -> clientReference.get().uploadPart(uploadRequest));
parts.add(uploadResponse.getPartETag());
}

Expand All @@ -456,7 +456,7 @@ void executeMultipartUpload(final S3BlobStore blobStore, final String blobName,
parts
);
complRequest.setRequestMetricCollector(blobStore.multiPartUploadMetricCollector);
SocketAccess.doPrivilegedVoid(() -> clientReference.client().completeMultipartUpload(complRequest));
SocketAccess.doPrivilegedVoid(() -> clientReference.get().completeMultipartUpload(complRequest));
success = true;

} catch (final AmazonClientException e) {
Expand All @@ -465,7 +465,7 @@ void executeMultipartUpload(final S3BlobStore blobStore, final String blobName,
if ((success == false) && Strings.hasLength(uploadId.get())) {
final AbortMultipartUploadRequest abortRequest = new AbortMultipartUploadRequest(bucketName, blobName, uploadId.get());
try (AmazonS3Reference clientReference = blobStore.clientReference()) {
SocketAccess.doPrivilegedVoid(() -> clientReference.client().abortMultipartUpload(abortRequest));
SocketAccess.doPrivilegedVoid(() -> clientReference.get().abortMultipartUpload(abortRequest));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ private void openStream() throws IOException {
+ end;
getObjectRequest.setRange(Math.addExact(start, currentOffset), end);
}
final S3Object s3Object = SocketAccess.doPrivileged(() -> clientReference.client().getObject(getObjectRequest));
final S3Object s3Object = SocketAccess.doPrivileged(() -> clientReference.get().getObject(getObjectRequest));
this.currentStreamLastOffset = Math.addExact(Math.addExact(start, currentOffset), getStreamLength(s3Object));
this.currentStream = s3Object.getObjectContent();
} catch (final AmazonClientException e) {
Expand Down
Loading

0 comments on commit fb9e150

Please sign in to comment.