Skip to content

Commit

Permalink
Decide number of clients basing on average request size of client
Browse files Browse the repository at this point in the history
Change the way how DirectExchangeClient.scheduleRequestIfNecessary calculates
the number of clients to be requested on the exchange phase to use an average
request size of specific client instead of aggregated average of all clients.
  • Loading branch information
radek-kondziolka authored Dec 15, 2022
1 parent 4a65f20 commit ac661b9
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.operator;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.http.client.HttpClient;
Expand All @@ -36,9 +37,9 @@
import java.util.Deque;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.locks.Lock;
Expand Down Expand Up @@ -68,7 +69,7 @@ public class DirectExchangeClient
@GuardedBy("this")
private boolean noMoreLocations;

private final ConcurrentMap<URI, HttpPageBufferClient> allClients = new ConcurrentHashMap<>();
private final Map<URI, HttpPageBufferClient> allClients = new ConcurrentHashMap<>();

@GuardedBy("this")
private final Deque<HttpPageBufferClient> queuedClients = new LinkedList<>();
Expand Down Expand Up @@ -260,38 +261,56 @@ public synchronized void close()
}
}

private synchronized void scheduleRequestIfNecessary()
@VisibleForTesting
synchronized int scheduleRequestIfNecessary()
{
if ((buffer.isFinished() || buffer.isFailed()) && completedClients.size() == allClients.size()) {
return;
return 0;
}

long neededBytes = buffer.getRemainingCapacityInBytes();
if (neededBytes <= 0) {
return;
return 0;
}

int clientCount = (int) ((1.0 * neededBytes / averageBytesPerRequest) * concurrentRequestMultiplier);
clientCount = Math.max(clientCount, 1);

int pendingClients = allClients.size() - queuedClients.size() - completedClients.size();
clientCount -= pendingClients;
long reservedBytesForScheduledClients = allClients.values().stream()
.filter(client -> !queuedClients.contains(client) && !completedClients.contains(client))
.mapToLong(HttpPageBufferClient::getAverageRequestSizeInBytes)
.sum();
long projectedBytesToBeRequested = 0;
int clientCount = 0;

for (HttpPageBufferClient client : queuedClients) {
if (projectedBytesToBeRequested >= neededBytes * concurrentRequestMultiplier - reservedBytesForScheduledClients) {
break;
}
projectedBytesToBeRequested += client.getAverageRequestSizeInBytes();
clientCount++;
}
for (int i = 0; i < clientCount; i++) {
HttpPageBufferClient client = queuedClients.poll();
if (client == null) {
// no more clients available
return;
}
client.scheduleRequest();
}
return clientCount;
}

public ListenableFuture<Void> isBlocked()
{
return buffer.isBlocked();
}

@VisibleForTesting
Deque<HttpPageBufferClient> getQueuedClients()
{
return queuedClients;
}

@VisibleForTesting
Map<URI, HttpPageBufferClient> getAllClients()
{
return allClients;
}

private boolean addPages(HttpPageBufferClient client, List<Slice> pages)
{
checkState(!completedClients.contains(client), "client is already marked as completed");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.operator;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Ticker;
import com.google.common.collect.ImmutableList;
import com.google.common.io.LittleEndianDataInputStream;
Expand Down Expand Up @@ -145,6 +146,9 @@ public interface ClientCallback
@GuardedBy("this")
private String taskInstanceId;

// it is synchronized on `this` for update
private volatile long averageRequestSizeInBytes;

private final AtomicLong rowsReceived = new AtomicLong();
private final AtomicInteger pagesReceived = new AtomicInteger();

Expand All @@ -153,6 +157,7 @@ public interface ClientCallback

private final AtomicInteger requestsScheduled = new AtomicInteger();
private final AtomicInteger requestsCompleted = new AtomicInteger();
private final AtomicInteger requestsSucceeded = new AtomicInteger();
private final AtomicInteger requestsFailed = new AtomicInteger();

private final Executor pageBufferClientCallbackExecutor;
Expand Down Expand Up @@ -251,6 +256,7 @@ else if (completed) {
requestsScheduled.get(),
requestsCompleted.get(),
requestsFailed.get(),
requestsSucceeded.get(),
httpRequestState);
}

Expand All @@ -259,6 +265,11 @@ public TaskId getRemoteTaskId()
return remoteTaskId;
}

public long getAverageRequestSizeInBytes()
{
return averageRequestSizeInBytes;
}

public synchronized boolean isRunning()
{
return future != null;
Expand Down Expand Up @@ -434,6 +445,8 @@ public Void handle(Request request, Response response)
}
}
requestsCompleted.incrementAndGet();
long responseSize = pages.stream().mapToLong(Slice::length).sum();
requestSucceeded(responseSize);

synchronized (HttpPageBufferClient.this) {
// client is complete, acknowledge it by sending it a delete in the next request
Expand Down Expand Up @@ -485,6 +498,14 @@ public void onFailure(Throwable t)
}, pageBufferClientCallbackExecutor);
}

@VisibleForTesting
synchronized void requestSucceeded(long responseSize)
{
int successfulRequests = requestsSucceeded.incrementAndGet();
// AVG_n = AVG_(n-1) * (n-1)/n + VALUE_n / n
averageRequestSizeInBytes = (long) ((1.0 * averageRequestSizeInBytes * (successfulRequests - 1)) + responseSize) / successfulRequests;
}

private synchronized void destroyTaskResults()
{
HttpResponseFuture<StatusResponse> resultFuture = httpClient.executeAsync(prepareDelete().setUri(location).build(), createStatusResponseHandler());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public class PageBufferClientStatus
private final int requestsScheduled;
private final int requestsCompleted;
private final int requestsFailed;
private final int requestsSucceeded;
private final String httpRequestState;

@JsonCreator
Expand All @@ -50,6 +51,7 @@ public PageBufferClientStatus(@JsonProperty("uri") URI uri,
@JsonProperty("requestsScheduled") int requestsScheduled,
@JsonProperty("requestsCompleted") int requestsCompleted,
@JsonProperty("requestsFailed") int requestsFailed,
@JsonProperty("requestsSucceeded") int requestsSucceeded,
@JsonProperty("httpRequestState") String httpRequestState)
{
this.uri = uri;
Expand All @@ -62,6 +64,7 @@ public PageBufferClientStatus(@JsonProperty("uri") URI uri,
this.requestsScheduled = requestsScheduled;
this.requestsCompleted = requestsCompleted;
this.requestsFailed = requestsFailed;
this.requestsSucceeded = requestsSucceeded;
this.httpRequestState = httpRequestState;
}

Expand Down Expand Up @@ -125,6 +128,12 @@ public int getRequestsFailed()
return requestsFailed;
}

@JsonProperty
public int getRequestsSucceeded()
{
return requestsSucceeded;
}

@JsonProperty
public String getHttpRequestState()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
import static io.airlift.concurrent.MoreFutures.tryGetFutureValue;
import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static io.airlift.testing.Assertions.assertLessThan;
import static io.trino.execution.TestSqlTaskExecution.TASK_ID;
import static io.trino.execution.buffer.PagesSerdeUtil.getSerializedPagePositionCount;
import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
Expand Down Expand Up @@ -946,6 +947,128 @@ public void testStreamingClose()
assertEquals(clientStatus.getHttpRequestState(), "not scheduled", "httpRequestState");
}

@Test
public void testScheduleWhenOneClientFilledBuffer()
{
DataSize maxResponseSize = DataSize.of(8, Unit.MEGABYTE);

URI locationOne = URI.create("http://localhost:8080");
URI locationTwo = URI.create("http://localhost:8081");
MockExchangeRequestProcessor processor = new MockExchangeRequestProcessor(maxResponseSize);

HttpPageBufferClient clientToBeUsed = createHttpPageBufferClient(processor, maxResponseSize, locationOne, new MockClientCallback());
HttpPageBufferClient clientToBeSkipped = createHttpPageBufferClient(processor, maxResponseSize, locationTwo, new MockClientCallback());
clientToBeUsed.requestSucceeded(DataSize.of(33, Unit.MEGABYTE).toBytes());
clientToBeSkipped.requestSucceeded(DataSize.of(1, Unit.MEGABYTE).toBytes());

@SuppressWarnings("resource")
DirectExchangeClient exchangeClient = new DirectExchangeClient(
"localhost",
DataIntegrityVerification.ABORT,
new StreamingDirectExchangeBuffer(scheduler, DataSize.of(32, Unit.MEGABYTE)),
maxResponseSize,
1,
new Duration(1, TimeUnit.MINUTES),
true,
new TestingHttpClient(processor, scheduler),
scheduler,
new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"),
pageBufferClientCallbackExecutor,
(taskId, failure) -> {});
exchangeClient.getAllClients().putAll(Map.of(locationOne, clientToBeUsed, locationTwo, clientToBeSkipped));
exchangeClient.getQueuedClients().addAll(ImmutableList.of(clientToBeUsed, clientToBeSkipped));

int clientCount = exchangeClient.scheduleRequestIfNecessary();
// The first client filled the buffer. There is no place for the another one
assertEquals(clientCount, 1);
}

@Test
public void testScheduleWhenAllClientsAreEmpty()
{
DataSize maxResponseSize = DataSize.of(8, Unit.MEGABYTE);

URI locationOne = URI.create("http://localhost:8080");
URI locationTwo = URI.create("http://localhost:8081");
MockExchangeRequestProcessor processor = new MockExchangeRequestProcessor(maxResponseSize);

HttpPageBufferClient firstClient = createHttpPageBufferClient(processor, maxResponseSize, locationOne, new MockClientCallback());
HttpPageBufferClient secondClient = createHttpPageBufferClient(processor, maxResponseSize, locationTwo, new MockClientCallback());

@SuppressWarnings("resource")
DirectExchangeClient exchangeClient = new DirectExchangeClient(
"localhost",
DataIntegrityVerification.ABORT,
new StreamingDirectExchangeBuffer(scheduler, DataSize.of(32, Unit.MEGABYTE)),
maxResponseSize,
1,
new Duration(1, TimeUnit.MINUTES),
true,
new TestingHttpClient(processor, scheduler),
scheduler,
new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"),
pageBufferClientCallbackExecutor,
(taskId, failure) -> {});
exchangeClient.getAllClients().putAll(Map.of(locationOne, firstClient, locationTwo, secondClient));
exchangeClient.getQueuedClients().addAll(ImmutableList.of(firstClient, secondClient));

int clientCount = exchangeClient.scheduleRequestIfNecessary();
assertEquals(clientCount, 2);
}

@Test
public void testScheduleWhenThereIsPendingClient()
{
DataSize maxResponseSize = DataSize.of(8, Unit.MEGABYTE);

URI locationOne = URI.create("http://localhost:8080");
URI locationTwo = URI.create("http://localhost:8081");

MockExchangeRequestProcessor processor = new MockExchangeRequestProcessor(maxResponseSize);

HttpPageBufferClient pendingClient = createHttpPageBufferClient(processor, maxResponseSize, locationOne, new MockClientCallback());
HttpPageBufferClient clientToBeSkipped = createHttpPageBufferClient(processor, maxResponseSize, locationTwo, new MockClientCallback());

pendingClient.requestSucceeded(DataSize.of(33, Unit.MEGABYTE).toBytes());

@SuppressWarnings("resource")
DirectExchangeClient exchangeClient = new DirectExchangeClient(
"localhost",
DataIntegrityVerification.ABORT,
new StreamingDirectExchangeBuffer(scheduler, DataSize.of(32, Unit.MEGABYTE)),
maxResponseSize,
1,
new Duration(1, TimeUnit.MINUTES),
true,
new TestingHttpClient(processor, scheduler),
scheduler,
new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"),
pageBufferClientCallbackExecutor,
(taskId, failure) -> {});
exchangeClient.getAllClients().putAll(Map.of(locationOne, pendingClient, locationTwo, clientToBeSkipped));
exchangeClient.getQueuedClients().add(clientToBeSkipped);

int clientCount = exchangeClient.scheduleRequestIfNecessary();
// The first client is pending and it reserved the space in the buffer. There is no place for the another one
assertEquals(clientCount, 0);
}

private HttpPageBufferClient createHttpPageBufferClient(TestingHttpClient.Processor processor, DataSize expectedMaxSize, URI location, HttpPageBufferClient.ClientCallback callback)
{
return new HttpPageBufferClient(
"localhost",
new TestingHttpClient(processor, scheduler),
DataIntegrityVerification.ABORT,
expectedMaxSize,
new Duration(1, TimeUnit.MINUTES),
true,
TASK_ID,
location,
callback,
scheduler,
pageBufferClientCallbackExecutor);
}

private static Page createPage(int size)
{
return new Page(BlockAssertions.createLongSequenceBlock(0, size));
Expand Down Expand Up @@ -985,4 +1108,29 @@ private static void assertStatus(
assertEquals(clientStatus.getRequestsCompleted(), requestsCompleted, "requestsCompleted");
assertEquals(clientStatus.getHttpRequestState(), httpRequestState, "httpRequestState");
}

private static class MockClientCallback
implements HttpPageBufferClient.ClientCallback
{
@Override
public boolean addPages(HttpPageBufferClient client, List<Slice> pages)
{
return false;
}

@Override
public void requestComplete(HttpPageBufferClient client)
{
}

@Override
public void clientFinished(HttpPageBufferClient client)
{
}

@Override
public void clientFailed(HttpPageBufferClient client, Throwable cause)
{
}
}
}
Loading

0 comments on commit ac661b9

Please sign in to comment.