Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BREAKING] Fix for OOM error when reading from S3 #5613

Merged
merged 11 commits into from
Jun 21, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,6 @@ final class S3ChannelContext extends BaseSeekableChannelContext implements Seeka
*/
final S3RequestCache sharedCache;

/**
* Used to cache recently fetched fragments as well as the ownership token for the request. This cache is local to
* the context and is used to keep the requests alive as long as the context is alive.
*/
private final S3Request.AcquiredRequest[] localCache;

/**
* The size of the object in bytes, stored in context to avoid fetching multiple times
*/
Expand All @@ -72,7 +66,6 @@ final class S3ChannelContext extends BaseSeekableChannelContext implements Seeka
this.provider = Objects.requireNonNull(provider);
this.client = Objects.requireNonNull(client);
this.instructions = Objects.requireNonNull(instructions);
this.localCache = new S3Request.AcquiredRequest[instructions.maxCacheSize()];
this.sharedCache = sharedCache;
if (sharedCache.getFragmentSize() != instructions.fragmentSize()) {
throw new IllegalArgumentException("Fragment size mismatch between shared cache and instructions, "
Expand Down Expand Up @@ -121,22 +114,27 @@ int fill(final long position, final ByteBuffer dest) throws IOException {
final int impliedReadAhead = (int) (lastFragmentIx - firstFragmentIx);
final int desiredReadAhead = instructions.readAheadCount();
final long totalRemainingFragments = numFragments - firstFragmentIx - 1;
final int maxReadAhead = instructions.maxCacheSize() - 1;
readAhead = Math.min(
Math.max(impliedReadAhead, desiredReadAhead),
(int) Math.min(maxReadAhead, totalRemainingFragments));
readAhead = Math.min(Math.max(impliedReadAhead, desiredReadAhead), totalRemainingFragments);
}
final S3Request firstRequest = getOrCreateRequest(firstFragmentIx);
// Hold a reference to the first request to ensure it is not evicted from the cache
S3Request.AcquiredRequest acquiredFirstRequest = getOrCreateRequest(firstFragmentIx);
for (int i = 0; i < readAhead; ++i) {
// Do not hold references to the read ahead requests
getOrCreateRequest(firstFragmentIx + i + 1);
}
// blocking
int filled = firstRequest.fill(position, dest);
int filled = acquiredFirstRequest.request.fill(position, dest);
acquiredFirstRequest.release();
acquiredFirstRequest = null;
malhotrashivam marked this conversation as resolved.
Show resolved Hide resolved

for (int i = 0; dest.hasRemaining(); ++i) {
// Since we have already created requests for read ahead fragments, we can retrieve them from the local
// cache
final S3Request request = getRequestFromLocalCache(firstFragmentIx + i + 1);
if (request == null || !request.isDone()) {
final S3Request.AcquiredRequest acquiredReadAheadRequest =
getRequestFromSharedCache(firstFragmentIx + i + 1);
if (acquiredReadAheadRequest == null) {
break;
}
final S3Request request = acquiredReadAheadRequest.request;
if (!request.isDone()) {
malhotrashivam marked this conversation as resolved.
Show resolved Hide resolved
break;
}
// non-blocking since we know isDone
Expand All @@ -146,7 +144,6 @@ int fill(final long position, final ByteBuffer dest) throws IOException {
}

private void reset() {
releaseOutstanding();
// Reset the internal state
uri = null;
size = UNINITIALIZED_SIZE;
Expand All @@ -162,49 +159,27 @@ public void close() {
if (log.isDebugEnabled()) {
log.debug().append("Closing context: ").append(ctxStr()).endl();
}
releaseOutstanding();
}

/**
* Release all outstanding requests associated with this context. Eventually, the request will be canceled when the
* objects are garbage collected.
*/
private void releaseOutstanding() {
Arrays.fill(localCache, null);
}

// --------------------------------------------------------------------------------------------------

@Nullable
private S3Request getRequestFromLocalCache(final long fragmentIndex) {
return getRequestFromLocalCache(fragmentIndex, cacheIndex(fragmentIndex));
}

@Nullable
private S3Request getRequestFromLocalCache(final long fragmentIndex, final int cacheIdx) {
if (localCache[cacheIdx] != null && localCache[cacheIdx].request.isFragment(fragmentIndex)) {
return localCache[cacheIdx].request;
private S3Request.AcquiredRequest getRequestFromSharedCache(final long fragmentIndex) {
final S3Request.AcquiredRequest cachedRequest = sharedCache.getRequest(uri, fragmentIndex);
if (cachedRequest == null) {
return null;
}
return null;
// Send the request, if not sent already. The following method is idempotent, so we always call it.
cachedRequest.request.sendRequest();
malhotrashivam marked this conversation as resolved.
Show resolved Hide resolved
return cachedRequest;
}

@NotNull
private S3Request getOrCreateRequest(final long fragmentIndex) {
final int cacheIdx = cacheIndex(fragmentIndex);
final S3Request locallyCached = getRequestFromLocalCache(fragmentIndex, cacheIdx);
if (locallyCached != null) {
return locallyCached;
}
final S3Request.AcquiredRequest sharedCacheRequest = sharedCache.getOrCreateRequest(uri, fragmentIndex, this);
// Cache the request and the ownership token locally
localCache[cacheIdx] = sharedCacheRequest;
private S3Request.AcquiredRequest getOrCreateRequest(final long fragmentIndex) {
final S3Request.AcquiredRequest cachedRequest = sharedCache.getOrCreateRequest(uri, fragmentIndex, this);
// Send the request, if not sent already. The following method is idempotent, so we always call it.
sharedCacheRequest.request.sendRequest();
return sharedCacheRequest.request;
}

private int cacheIndex(final long fragmentIndex) {
return (int) (fragmentIndex % instructions.maxCacheSize());
cachedRequest.request.sendRequest();
return cachedRequest;
}

private long fragmentIndex(final long pos) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ public abstract class S3Instructions implements LogOutputAppendable {
private final static int DEFAULT_READ_AHEAD_COUNT = 32;
private final static int DEFAULT_FRAGMENT_SIZE = 1 << 16; // 64 KiB
private final static int MIN_FRAGMENT_SIZE = 8 << 10; // 8 KiB
private final static int DEFAULT_MAX_CACHE_SIZE = 256;
private final static Duration DEFAULT_CONNECTION_TIMEOUT = Duration.ofSeconds(2);
private final static Duration DEFAULT_READ_TIMEOUT = Duration.ofSeconds(2);

Expand Down Expand Up @@ -73,17 +72,6 @@ public int fragmentSize() {
return DEFAULT_FRAGMENT_SIZE;
}

/**
* The maximum number of fragments to cache in memory, defaults to
* {@code Math.max(1 + readAheadCount(), DEFAULT_MAX_CACHE_SIZE)}, which is at least
* {@value #DEFAULT_MAX_CACHE_SIZE}. This caching is done at the deephaven layer for faster access to recently read
* fragments. Must be greater than or equal to {@code 1 + readAheadCount()}.
*/
@Default
public int maxCacheSize() {
return Math.max(1 + readAheadCount(), DEFAULT_MAX_CACHE_SIZE);
}

/**
* The amount of time to wait when initially establishing a connection before giving up and timing out, defaults to
* 2 seconds.
Expand Down Expand Up @@ -133,8 +121,6 @@ public interface Builder {

Builder fragmentSize(int fragmentSize);

Builder maxCacheSize(int maxCacheSize);

Builder connectionTimeout(Duration connectionTimeout);

Builder readTimeout(Duration connectionTimeout);
Expand All @@ -152,13 +138,10 @@ default Builder endpointOverride(String endpointOverride) {

abstract S3Instructions withReadAheadCount(int readAheadCount);

abstract S3Instructions withMaxCacheSize(int maxCacheSize);

@Lazy
S3Instructions singleUse() {
final int readAheadCount = Math.min(DEFAULT_READ_AHEAD_COUNT, readAheadCount());
return withReadAheadCount(readAheadCount)
.withMaxCacheSize(readAheadCount + 1);
return withReadAheadCount(readAheadCount);
}

@Check
Expand All @@ -183,14 +166,6 @@ final void boundsCheckMinFragmentSize() {
}
}

@Check
final void boundsCheckMaxCacheSize() {
if (maxCacheSize() < readAheadCount() + 1) {
throw new IllegalArgumentException("maxCacheSize(=" + maxCacheSize() + ") must be >= 1 + " +
"readAheadCount(=" + readAheadCount() + ")");
}
}

@Check
final void awsSdkV2Credentials() {
if (!(credentials() instanceof AwsSdkV2Credentials)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,16 @@ static class AcquiredRequest {
* The ownership token keeps the request alive. When the ownership token is GC'd, the request is no longer
* usable and will be cleaned up.
*/
final Object ownershipToken;
Object ownershipToken;

/**
 * Pair the request with its ownership token. The caller holds this object for as long as it needs the
 * request kept alive.
 */
AcquiredRequest(final S3Request request, final Object ownershipToken) {
this.request = request;
this.ownershipToken = ownershipToken;
}

/**
 * Release this acquisition by dropping the reference to the ownership token. Per the token's contract
 * (documented on the {@code ownershipToken} field), once the token is no longer reachable the request is no
 * longer usable and will be cleaned up.
 */
void release() {
// Null out the token so this holder no longer keeps the request alive; cleanup happens via GC of the token.
ownershipToken = null;
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import io.deephaven.internal.log.LoggerFactory;
import io.deephaven.io.logger.Logger;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import software.amazon.awssdk.services.s3.S3Uri;

/**
Expand Down Expand Up @@ -46,6 +47,27 @@ int getFragmentSize() {
return fragmentSize;
}

/**
 * Acquire a request for the given URI and fragment index if it already exists in the cache.
 *
 * @param uri the URI
 * @param fragmentIndex the fragment index
 * @return the acquired request, or {@code null} if no request exists in the cache for this key or the cached
 *         request could not be acquired (in which case the stale entry is evicted from the cache)
 */
@Nullable
S3Request.AcquiredRequest getRequest(@NotNull final S3Uri uri, final long fragmentIndex) {
final S3Request.ID key = new S3Request.ID(uri, fragmentIndex);
final S3Request existingRequest = requests.get(key);
if (existingRequest != null) {
final S3Request.AcquiredRequest acquired = existingRequest.tryAcquire();
if (acquired != null) {
return acquired;
}
// tryAcquire failed, so the cached request can no longer be used; evict it so a fresh request can be
// created for this key later.
remove(existingRequest);
}
return null;
}

/**
* Acquire a request for the given URI and fragment index, creating and sending a new request if it does not
* exist in the cache.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ void defaults() {
assertThat(instructions.maxConcurrentRequests()).isEqualTo(256);
assertThat(instructions.readAheadCount()).isEqualTo(32);
assertThat(instructions.fragmentSize()).isEqualTo(65536);
assertThat(instructions.maxCacheSize()).isEqualTo(256);
assertThat(instructions.connectionTimeout()).isEqualTo(Duration.ofSeconds(2));
assertThat(instructions.readTimeout()).isEqualTo(Duration.ofSeconds(2));
assertThat(instructions.credentials()).isEqualTo(Credentials.defaultCredentials());
Expand Down Expand Up @@ -102,30 +101,6 @@ void tooSmallFragmentSize() {
}
}

@Test
void minMaxCacheSize() {
assertThat(S3Instructions.builder()
.regionName("some-region")
.readAheadCount(99)
.maxCacheSize(100)
.build()
.maxCacheSize())
.isEqualTo(100);
}

@Test
void tooSmallCacheSize() {
try {
S3Instructions.builder()
.regionName("some-region")
.readAheadCount(99)
.maxCacheSize(99)
.build();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageContaining("maxCacheSize");
}
}

@Test
void basicCredentials() {
assertThat(S3Instructions.builder()
Expand Down
6 changes: 0 additions & 6 deletions py/server/deephaven/experimental/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def __init__(self,
max_concurrent_requests: Optional[int] = None,
read_ahead_count: Optional[int] = None,
fragment_size: Optional[int] = None,
max_cache_size: Optional[int] = None,
connection_timeout: Union[
Duration, int, str, datetime.timedelta, np.timedelta64, pd.Timedelta, None] = None,
read_timeout: Union[
Expand All @@ -61,8 +60,6 @@ def __init__(self,
fragment. Defaults to 32, which means fetch the next 32 fragments in advance when reading the current fragment.
fragment_size (int): the maximum size of each fragment to read, defaults to 64 KiB. If there are fewer bytes
remaining in the file, the fetched fragment can be smaller.
max_cache_size (int): the maximum number of fragments to cache in memory while reading, defaults to 256. This
caching is done at the Deephaven layer for faster access to recently read fragments.
connection_timeout (Union[Duration, int, str, datetime.timedelta, np.timedelta64, pd.Timedelta]):
the amount of time to wait when initially establishing a connection before giving up and timing out, can
be expressed as an integer in nanoseconds, a time interval string, e.g. "PT00:00:00.001" or "PT1s", or
Expand Down Expand Up @@ -103,9 +100,6 @@ def __init__(self,
if fragment_size is not None:
builder.fragmentSize(fragment_size)

if max_cache_size is not None:
builder.maxCacheSize(max_cache_size)

if connection_timeout is not None:
builder.connectionTimeout(time.to_j_duration(connection_timeout))

Expand Down
Loading