Skip to content

Commit

Permalink
Adding leak detection to BulkShardRequest
Browse files Browse the repository at this point in the history
  • Loading branch information
masseyke committed Dec 7, 2023
1 parent 5a9f08f commit 1fb746b
Show file tree
Hide file tree
Showing 10 changed files with 1,147 additions and 1,013 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.AbstractRefCounted;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.transport.LeakTracker;
import org.elasticsearch.transport.RawIndexingDataTransportRequest;

import java.io.IOException;
Expand All @@ -27,20 +31,24 @@
public final class BulkShardRequest extends ReplicatedWriteRequest<BulkShardRequest>
implements
Accountable,
RawIndexingDataTransportRequest {
RawIndexingDataTransportRequest,
Releasable {

private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(BulkShardRequest.class);

private final BulkItemRequest[] items;
private final RefCounted refCounted;

/**
 * Deserialization constructor: reads this request, including its per-item
 * sub-requests, from the wire.
 *
 * @param in stream positioned at a serialized {@code BulkShardRequest}
 * @throws IOException if reading from the stream fails
 */
public BulkShardRequest(StreamInput in) throws IOException {
    super(in);
    // Items are optional writeables: a slot may deserialize to null.
    items = in.readArray(i -> i.readOptionalWriteable(inpt -> new BulkItemRequest(shardId, inpt)), BulkItemRequest[]::new);
    // Wrap the ref-count in LeakTracker so a request that is never fully
    // released is reported as a leak.
    this.refCounted = LeakTracker.wrap(new BulkRequestRefCounted());
}

/**
 * Creates a request targeting {@code shardId} with the given items.
 *
 * @param shardId       shard the bulk items are routed to
 * @param refreshPolicy refresh behavior applied after the write
 * @param items         the individual bulk item sub-requests
 */
public BulkShardRequest(ShardId shardId, RefreshPolicy refreshPolicy, BulkItemRequest[] items) {
    super(shardId);
    this.items = items;
    // Wrap the ref-count in LeakTracker so a request that is never fully
    // released is reported as a leak.
    this.refCounted = LeakTracker.wrap(new BulkRequestRefCounted());
    setRefreshPolicy(refreshPolicy);
}

Expand Down Expand Up @@ -154,4 +162,36 @@ public long ramBytesUsed() {
}
return sum;
}

@Override
public void incRef() {
    // Delegate to the leak-tracked counter so all references share one count.
    refCounted.incRef();
}

@Override
public boolean tryIncRef() {
    // Delegates to the leak-tracked counter; returns false once closed.
    return refCounted.tryIncRef();
}

@Override
public boolean decRef() {
    // Delegates to the leak-tracked counter; returns true when this call
    // released the final reference.
    return refCounted.decRef();
}

@Override
public boolean hasReferences() {
    // Delegates to the leak-tracked counter.
    return refCounted.hasReferences();
}

/**
 * Releases the caller's reference. Implementing close() this way lets the
 * request be managed with try-with-resources.
 */
@Override
public void close() {
    decRef();
}

/**
 * Ref-count implementation for this request. There is no underlying resource
 * to free when the count hits zero — the counting exists so the LeakTracker
 * wrapper can flag requests that are never released.
 */
private static class BulkRequestRefCounted extends AbstractRefCounted {
    @Override
    protected void closeInternal() {
        // nothing to close
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -685,26 +685,30 @@ protected void doRun() {
client.executeLocally(TransportShardBulkAction.TYPE, bulkShardRequest, new ActionListener<>() {
@Override
public void onResponse(BulkShardResponse bulkShardResponse) {
for (BulkItemResponse bulkItemResponse : bulkShardResponse.getResponses()) {
// we may have no response if item failed
if (bulkItemResponse.getResponse() != null) {
bulkItemResponse.getResponse().setShardInfo(bulkShardResponse.getShardInfo());
try (bulkShardRequest) {
for (BulkItemResponse bulkItemResponse : bulkShardResponse.getResponses()) {
// we may have no response if item failed
if (bulkItemResponse.getResponse() != null) {
bulkItemResponse.getResponse().setShardInfo(bulkShardResponse.getShardInfo());
}
responses.set(bulkItemResponse.getItemId(), bulkItemResponse);
}
responses.set(bulkItemResponse.getItemId(), bulkItemResponse);
maybeFinishHim();
}
maybeFinishHim();
}

@Override
public void onFailure(Exception e) {
// create failures for all relevant requests
for (BulkItemRequest request : requests) {
final String indexName = request.index();
DocWriteRequest<?> docWriteRequest = request.request();
BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e);
responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure));
try (bulkShardRequest) {
// create failures for all relevant requests
for (BulkItemRequest request : requests) {
final String indexName = request.index();
DocWriteRequest<?> docWriteRequest = request.request();
BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(), e);
responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure));
}
maybeFinishHim();
}
maybeFinishHim();
}

private void maybeFinishHim() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,8 @@ final class ReroutePhase extends AbstractRunnable {
if (task != null) {
this.request.setParentTask(clusterService.localNode().getId(), task.getId());
}
this.listener = listener;
request.incRef();
this.listener = ActionListener.runAfter(listener, request::decRef);
this.task = task;
this.observer = new ClusterStateObserver(clusterService, request.timeout(), logger, threadPool.getThreadContext());
}
Expand Down Expand Up @@ -1418,6 +1419,27 @@ public boolean isRawIndexingData() {
/**
 * Human-readable summary of the wrapped request, its target allocation id and
 * primary term (used in logs/diagnostics).
 * NOTE(review): the enclosing class header is outside this view — presumably
 * ConcreteShardRequest; confirm against the full file.
 */
public String toString() {
    return "request: " + request + ", target allocation id: " + targetAllocationID + ", primary term: " + primaryTerm;
}

@Override
public void incRef() {
    // Forward ref-counting to the wrapped inner request so the wrapper and
    // the wrapped request share a single reference count.
    request.incRef();
}

@Override
public boolean tryIncRef() {
    // Forwarded to the wrapped inner request; false once it is closed.
    return request.tryIncRef();
}

@Override
public boolean decRef() {
    // Forwarded to the wrapped inner request; true when this call released
    // the final reference.
    return request.decRef();
}

@Override
public boolean hasReferences() {
    // Forwarded to the wrapped inner request.
    return request.hasReferences();
}

}

protected static final class ConcreteReplicaRequest<R extends TransportRequest> extends ConcreteShardRequest<R> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ protected void shardOperationOnPrimary(
IndexShard primary,
ActionListener<PrimaryResult<ReplicaRequest, Response>> listener
) {
request.incRef();
listener = ActionListener.runAfter(listener, request::decRef);
threadPool.executor(executorFunction.apply(executorSelector, primary)).execute(new ActionRunnable<>(listener) {
@Override
protected void doRun() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,30 +31,31 @@
public class BulkPrimaryExecutionContextTests extends ESTestCase {

public void testAbortedSkipped() {
BulkShardRequest shardRequest = generateRandomRequest();

ArrayList<DocWriteRequest<?>> nonAbortedRequests = new ArrayList<>();
for (BulkItemRequest request : shardRequest.items()) {
if (randomBoolean()) {
request.abort("index", new ElasticsearchException("bla"));
} else {
nonAbortedRequests.add(request.request());
try (BulkShardRequest shardRequest = generateRandomRequest()) {

ArrayList<DocWriteRequest<?>> nonAbortedRequests = new ArrayList<>();
for (BulkItemRequest request : shardRequest.items()) {
if (randomBoolean()) {
request.abort("index", new ElasticsearchException("bla"));
} else {
nonAbortedRequests.add(request.request());
}
}
}

ArrayList<DocWriteRequest<?>> visitedRequests = new ArrayList<>();
for (BulkPrimaryExecutionContext context = new BulkPrimaryExecutionContext(shardRequest, null); context
.hasMoreOperationsToExecute();) {
visitedRequests.add(context.getCurrent());
context.setRequestToExecute(context.getCurrent());
// using failures prevents caring about types
context.markOperationAsExecuted(
new Engine.IndexResult(new ElasticsearchException("bla"), 1, context.getRequestToExecute().id())
);
context.markAsCompleted(context.getExecutionResult());
}
ArrayList<DocWriteRequest<?>> visitedRequests = new ArrayList<>();
for (BulkPrimaryExecutionContext context = new BulkPrimaryExecutionContext(shardRequest, null); context
.hasMoreOperationsToExecute();) {
visitedRequests.add(context.getCurrent());
context.setRequestToExecute(context.getCurrent());
// using failures prevents caring about types
context.markOperationAsExecuted(
new Engine.IndexResult(new ElasticsearchException("bla"), 1, context.getRequestToExecute().id())
);
context.markAsCompleted(context.getExecutionResult());
}

assertThat(visitedRequests, equalTo(nonAbortedRequests));
assertThat(visitedRequests, equalTo(nonAbortedRequests));
}
}

private BulkShardRequest generateRandomRequest() {
Expand All @@ -73,62 +74,63 @@ private BulkShardRequest generateRandomRequest() {

public void testTranslogLocation() {

BulkShardRequest shardRequest = generateRandomRequest();

Translog.Location expectedLocation = null;
final IndexShard primary = mock(IndexShard.class);
when(primary.shardId()).thenReturn(shardRequest.shardId());

long translogGen = 0;
long translogOffset = 0;

BulkPrimaryExecutionContext context = new BulkPrimaryExecutionContext(shardRequest, primary);
while (context.hasMoreOperationsToExecute()) {
final Engine.Result result;
final DocWriteRequest<?> current = context.getCurrent();
final boolean failure = rarely();
if (frequently()) {
translogGen += randomIntBetween(1, 4);
translogOffset = 0;
} else {
translogOffset += randomIntBetween(200, 400);
}
try (BulkShardRequest shardRequest = generateRandomRequest()) {

Translog.Location expectedLocation = null;
final IndexShard primary = mock(IndexShard.class);
when(primary.shardId()).thenReturn(shardRequest.shardId());

long translogGen = 0;
long translogOffset = 0;

BulkPrimaryExecutionContext context = new BulkPrimaryExecutionContext(shardRequest, primary);
while (context.hasMoreOperationsToExecute()) {
final Engine.Result result;
final DocWriteRequest<?> current = context.getCurrent();
final boolean failure = rarely();
if (frequently()) {
translogGen += randomIntBetween(1, 4);
translogOffset = 0;
} else {
translogOffset += randomIntBetween(200, 400);
}

Translog.Location location = new Translog.Location(translogGen, translogOffset, randomInt(200));
switch (current.opType()) {
case INDEX, CREATE -> {
context.setRequestToExecute(current);
if (failure) {
result = new Engine.IndexResult(new ElasticsearchException("bla"), 1, current.id());
} else {
result = new FakeIndexResult(1, 1, randomLongBetween(0, 200), randomBoolean(), location, "id");
Translog.Location location = new Translog.Location(translogGen, translogOffset, randomInt(200));
switch (current.opType()) {
case INDEX, CREATE -> {
context.setRequestToExecute(current);
if (failure) {
result = new Engine.IndexResult(new ElasticsearchException("bla"), 1, current.id());
} else {
result = new FakeIndexResult(1, 1, randomLongBetween(0, 200), randomBoolean(), location, "id");
}
}
}
case UPDATE -> {
context.setRequestToExecute(new IndexRequest(current.index()).id(current.id()));
if (failure) {
result = new Engine.IndexResult(new ElasticsearchException("bla"), 1, 1, 1, current.id());
} else {
result = new FakeIndexResult(1, 1, randomLongBetween(0, 200), randomBoolean(), location, "id");
case UPDATE -> {
context.setRequestToExecute(new IndexRequest(current.index()).id(current.id()));
if (failure) {
result = new Engine.IndexResult(new ElasticsearchException("bla"), 1, 1, 1, current.id());
} else {
result = new FakeIndexResult(1, 1, randomLongBetween(0, 200), randomBoolean(), location, "id");
}
}
}
case DELETE -> {
context.setRequestToExecute(current);
if (failure) {
result = new Engine.DeleteResult(new ElasticsearchException("bla"), 1, 1, current.id());
} else {
result = new FakeDeleteResult(1, 1, randomLongBetween(0, 200), randomBoolean(), location, current.id());
case DELETE -> {
context.setRequestToExecute(current);
if (failure) {
result = new Engine.DeleteResult(new ElasticsearchException("bla"), 1, 1, current.id());
} else {
result = new FakeDeleteResult(1, 1, randomLongBetween(0, 200), randomBoolean(), location, current.id());
}
}
default -> throw new AssertionError("unknown type:" + current.opType());
}
default -> throw new AssertionError("unknown type:" + current.opType());
}
if (failure == false) {
expectedLocation = location;
if (failure == false) {
expectedLocation = location;
}
context.markOperationAsExecuted(result);
context.markAsCompleted(context.getExecutionResult());
}
context.markOperationAsExecuted(result);
context.markAsCompleted(context.getExecutionResult());
}

assertThat(context.getLocationToSync(), equalTo(expectedLocation));
assertThat(context.getLocationToSync(), equalTo(expectedLocation));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,19 @@ public void testToString() {
String index = randomSimpleString(random(), 10);
int count = between(2, 100);
final ShardId shardId = new ShardId(index, "ignored", 0);
BulkShardRequest r = new BulkShardRequest(shardId, RefreshPolicy.NONE, new BulkItemRequest[count]);
assertEquals("BulkShardRequest [" + shardId + "] containing [" + count + "] requests", r.toString());
assertEquals("requests[" + count + "], index[" + index + "][0]", r.getDescription());
try (BulkShardRequest r = new BulkShardRequest(shardId, RefreshPolicy.NONE, new BulkItemRequest[count])) {
assertEquals("BulkShardRequest [" + shardId + "] containing [" + count + "] requests", r.toString());
assertEquals("requests[" + count + "], index[" + index + "][0]", r.getDescription());
}

r = new BulkShardRequest(shardId, RefreshPolicy.IMMEDIATE, new BulkItemRequest[count]);
assertEquals("BulkShardRequest [" + shardId + "] containing [" + count + "] requests and a refresh", r.toString());
assertEquals("requests[" + count + "], index[" + index + "][0], refresh[IMMEDIATE]", r.getDescription());
try (BulkShardRequest r = new BulkShardRequest(shardId, RefreshPolicy.IMMEDIATE, new BulkItemRequest[count])) {
assertEquals("BulkShardRequest [" + shardId + "] containing [" + count + "] requests and a refresh", r.toString());
assertEquals("requests[" + count + "], index[" + index + "][0], refresh[IMMEDIATE]", r.getDescription());
}

r = new BulkShardRequest(shardId, RefreshPolicy.WAIT_UNTIL, new BulkItemRequest[count]);
assertEquals("BulkShardRequest [" + shardId + "] containing [" + count + "] requests blocking until refresh", r.toString());
assertEquals("requests[" + count + "], index[" + index + "][0], refresh[WAIT_UNTIL]", r.getDescription());
try (BulkShardRequest r = new BulkShardRequest(shardId, RefreshPolicy.WAIT_UNTIL, new BulkItemRequest[count])) {
assertEquals("BulkShardRequest [" + shardId + "] containing [" + count + "] requests blocking until refresh", r.toString());
assertEquals("requests[" + count + "], index[" + index + "][0], refresh[WAIT_UNTIL]", r.getDescription());
}
}
}
Loading

0 comments on commit 1fb746b

Please sign in to comment.