From ac661b957a8b8674ed273df5ad8ae1c8dba990d2 Mon Sep 17 00:00:00 2001 From: radek-starburst <94364205+radek-starburst@users.noreply.github.com> Date: Thu, 15 Dec 2022 13:01:26 +0100 Subject: [PATCH] Decide number of clients basing on average request size of client Change the way how DirectExchangeClient.scheduleRequestIfNecessary calculates the number of clients to be requested on the exchange phase to use an average request size of specific client instead of aggregated average of all clients. --- .../trino/operator/DirectExchangeClient.java | 47 ++++-- .../trino/operator/HttpPageBufferClient.java | 21 +++ .../operator/PageBufferClientStatus.java | 9 ++ .../operator/TestDirectExchangeClient.java | 148 ++++++++++++++++++ .../operator/TestHttpPageBufferClient.java | 27 ++++ 5 files changed, 238 insertions(+), 14 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java index f5e461cc3e5d..5eed6d8fe2d6 100644 --- a/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java +++ b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java @@ -13,6 +13,7 @@ */ package io.trino.operator; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.http.client.HttpClient; @@ -36,9 +37,9 @@ import java.util.Deque; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.locks.Lock; @@ -68,7 +69,7 @@ public class DirectExchangeClient @GuardedBy("this") private boolean noMoreLocations; - private final ConcurrentMap allClients = new ConcurrentHashMap<>(); + private final Map allClients = new ConcurrentHashMap<>(); @GuardedBy("this") private final Deque queuedClients = new LinkedList<>(); @@ -260,31 +261,37 @@ public synchronized void close() } } - private synchronized void scheduleRequestIfNecessary() + @VisibleForTesting + synchronized int scheduleRequestIfNecessary() { if ((buffer.isFinished() || buffer.isFailed()) && completedClients.size() == allClients.size()) { - return; + return 0; } long neededBytes = buffer.getRemainingCapacityInBytes(); if (neededBytes <= 0) { - return; + return 0; } - int clientCount = (int) ((1.0 * neededBytes / averageBytesPerRequest) * concurrentRequestMultiplier); - clientCount = Math.max(clientCount, 1); - - int pendingClients = allClients.size() - queuedClients.size() - completedClients.size(); - clientCount -= pendingClients; + long reservedBytesForScheduledClients = allClients.values().stream() + .filter(client -> !queuedClients.contains(client) && !completedClients.contains(client)) + .mapToLong(HttpPageBufferClient::getAverageRequestSizeInBytes) + .sum(); + long projectedBytesToBeRequested = 0; + int clientCount = 0; + for (HttpPageBufferClient client : queuedClients) { + if (projectedBytesToBeRequested >= neededBytes * concurrentRequestMultiplier - reservedBytesForScheduledClients) { + break; + } + projectedBytesToBeRequested += client.getAverageRequestSizeInBytes(); + clientCount++; + } for (int i = 0; i < clientCount; i++) { HttpPageBufferClient client = queuedClients.poll(); - if (client == null) { - // no more clients available - return; - } client.scheduleRequest(); } + return clientCount; } public ListenableFuture isBlocked() @@ -292,6 +299,18 @@ public ListenableFuture isBlocked() return buffer.isBlocked(); } + @VisibleForTesting + Deque getQueuedClients() + { + return queuedClients; + } + + @VisibleForTesting + Map getAllClients() + { + return allClients; + } + private boolean addPages(HttpPageBufferClient client, List pages) { checkState(!completedClients.contains(client), "client is already marked as completed"); diff --git a/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java b/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java index 388883f38dee..3e9b0dcfd72c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java +++ b/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java @@ -13,6 +13,7 @@ */ package io.trino.operator; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Ticker; import com.google.common.collect.ImmutableList; import com.google.common.io.LittleEndianDataInputStream; @@ -145,6 +146,9 @@ public interface ClientCallback @GuardedBy("this") private String taskInstanceId; + // it is synchronized on `this` for update + private volatile long averageRequestSizeInBytes; + private final AtomicLong rowsReceived = new AtomicLong(); private final AtomicInteger pagesReceived = new AtomicInteger(); @@ -153,6 +157,7 @@ public interface ClientCallback private final AtomicInteger requestsScheduled = new AtomicInteger(); private final AtomicInteger requestsCompleted = new AtomicInteger(); + private final AtomicInteger requestsSucceeded = new AtomicInteger(); private final AtomicInteger requestsFailed = new AtomicInteger(); private final Executor pageBufferClientCallbackExecutor; @@ -251,6 +256,7 @@ else if (completed) { requestsScheduled.get(), requestsCompleted.get(), requestsFailed.get(), + requestsSucceeded.get(), httpRequestState); } @@ -259,6 +265,11 @@ public TaskId getRemoteTaskId() return remoteTaskId; } + public long getAverageRequestSizeInBytes() + { + return averageRequestSizeInBytes; + } + public synchronized boolean isRunning() { return future != null; @@ -434,6 +445,8 @@ public Void handle(Request request, Response response) } } requestsCompleted.incrementAndGet(); + long responseSize = pages.stream().mapToLong(Slice::length).sum(); + requestSucceeded(responseSize); synchronized (HttpPageBufferClient.this) { // client is complete, acknowledge it by sending it a delete in the next request @@ -485,6 +498,14 @@ public void onFailure(Throwable t) }, pageBufferClientCallbackExecutor); } + @VisibleForTesting + synchronized void requestSucceeded(long responseSize) + { + int successfulRequests = requestsSucceeded.incrementAndGet(); + // AVG_n = AVG_(n-1) * (n-1)/n + VALUE_n / n + averageRequestSizeInBytes = (long) ((1.0 * averageRequestSizeInBytes * (successfulRequests - 1)) + responseSize) / successfulRequests; + } + private synchronized void destroyTaskResults() { HttpResponseFuture resultFuture = httpClient.executeAsync(prepareDelete().setUri(location).build(), createStatusResponseHandler()); diff --git a/core/trino-main/src/main/java/io/trino/operator/PageBufferClientStatus.java b/core/trino-main/src/main/java/io/trino/operator/PageBufferClientStatus.java index a1a2dac3b707..f8584b97f210 100644 --- a/core/trino-main/src/main/java/io/trino/operator/PageBufferClientStatus.java +++ b/core/trino-main/src/main/java/io/trino/operator/PageBufferClientStatus.java @@ -37,6 +37,7 @@ public class PageBufferClientStatus private final int requestsScheduled; private final int requestsCompleted; private final int requestsFailed; + private final int requestsSucceeded; private final String httpRequestState; @JsonCreator @@ -50,6 +51,7 @@ public PageBufferClientStatus(@JsonProperty("uri") URI uri, @JsonProperty("requestsScheduled") int requestsScheduled, @JsonProperty("requestsCompleted") int requestsCompleted, @JsonProperty("requestsFailed") int requestsFailed, + @JsonProperty("requestsSucceeded") int requestsSucceeded, @JsonProperty("httpRequestState") String httpRequestState) { this.uri = uri; @@ -62,6 +64,7 @@ public PageBufferClientStatus(@JsonProperty("uri") URI uri, this.requestsScheduled = requestsScheduled; this.requestsCompleted = requestsCompleted; this.requestsFailed = requestsFailed; + this.requestsSucceeded = requestsSucceeded; this.httpRequestState = httpRequestState; } @@ -125,6 +128,12 @@ public int getRequestsFailed() return requestsFailed; } + @JsonProperty + public int getRequestsSucceeded() + { + return requestsSucceeded; + } + @JsonProperty public String getHttpRequestState() { diff --git a/core/trino-main/src/test/java/io/trino/operator/TestDirectExchangeClient.java b/core/trino-main/src/test/java/io/trino/operator/TestDirectExchangeClient.java index 36cc9ae989b5..540528144f3b 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestDirectExchangeClient.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestDirectExchangeClient.java @@ -70,6 +70,7 @@ import static io.airlift.concurrent.MoreFutures.tryGetFutureValue; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.testing.Assertions.assertLessThan; +import static io.trino.execution.TestSqlTaskExecution.TASK_ID; import static io.trino.execution.buffer.PagesSerdeUtil.getSerializedPagePositionCount; import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; @@ -946,6 +947,128 @@ public void testStreamingClose() assertEquals(clientStatus.getHttpRequestState(), "not scheduled", "httpRequestState"); } + @Test + public void testScheduleWhenOneClientFilledBuffer() + { + DataSize maxResponseSize = DataSize.of(8, Unit.MEGABYTE); + + URI locationOne = URI.create("http://localhost:8080"); + URI locationTwo = URI.create("http://localhost:8081"); + MockExchangeRequestProcessor processor = new MockExchangeRequestProcessor(maxResponseSize); + + HttpPageBufferClient clientToBeUsed = createHttpPageBufferClient(processor, maxResponseSize, locationOne, new MockClientCallback()); + HttpPageBufferClient clientToBeSkipped = createHttpPageBufferClient(processor, maxResponseSize, locationTwo, new MockClientCallback()); + clientToBeUsed.requestSucceeded(DataSize.of(33, Unit.MEGABYTE).toBytes()); + clientToBeSkipped.requestSucceeded(DataSize.of(1, Unit.MEGABYTE).toBytes()); + + @SuppressWarnings("resource") + DirectExchangeClient exchangeClient = new DirectExchangeClient( + "localhost", + DataIntegrityVerification.ABORT, + new StreamingDirectExchangeBuffer(scheduler, DataSize.of(32, Unit.MEGABYTE)), + maxResponseSize, + 1, + new Duration(1, TimeUnit.MINUTES), + true, + new TestingHttpClient(processor, scheduler), + scheduler, + new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), + pageBufferClientCallbackExecutor, + (taskId, failure) -> {}); + exchangeClient.getAllClients().putAll(Map.of(locationOne, clientToBeUsed, locationTwo, clientToBeSkipped)); + exchangeClient.getQueuedClients().addAll(ImmutableList.of(clientToBeUsed, clientToBeSkipped)); + + int clientCount = exchangeClient.scheduleRequestIfNecessary(); + // The first client filled the buffer. There is no place for the another one + assertEquals(clientCount, 1); + } + + @Test + public void testScheduleWhenAllClientsAreEmpty() + { + DataSize maxResponseSize = DataSize.of(8, Unit.MEGABYTE); + + URI locationOne = URI.create("http://localhost:8080"); + URI locationTwo = URI.create("http://localhost:8081"); + MockExchangeRequestProcessor processor = new MockExchangeRequestProcessor(maxResponseSize); + + HttpPageBufferClient firstClient = createHttpPageBufferClient(processor, maxResponseSize, locationOne, new MockClientCallback()); + HttpPageBufferClient secondClient = createHttpPageBufferClient(processor, maxResponseSize, locationTwo, new MockClientCallback()); + + @SuppressWarnings("resource") + DirectExchangeClient exchangeClient = new DirectExchangeClient( + "localhost", + DataIntegrityVerification.ABORT, + new StreamingDirectExchangeBuffer(scheduler, DataSize.of(32, Unit.MEGABYTE)), + maxResponseSize, + 1, + new Duration(1, TimeUnit.MINUTES), + true, + new TestingHttpClient(processor, scheduler), + scheduler, + new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), + pageBufferClientCallbackExecutor, + (taskId, failure) -> {}); + exchangeClient.getAllClients().putAll(Map.of(locationOne, firstClient, locationTwo, secondClient)); + exchangeClient.getQueuedClients().addAll(ImmutableList.of(firstClient, secondClient)); + + int clientCount = exchangeClient.scheduleRequestIfNecessary(); + assertEquals(clientCount, 2); + } + + @Test + public void testScheduleWhenThereIsPendingClient() + { + DataSize maxResponseSize = DataSize.of(8, Unit.MEGABYTE); + + URI locationOne = URI.create("http://localhost:8080"); + URI locationTwo = URI.create("http://localhost:8081"); + + MockExchangeRequestProcessor processor = new MockExchangeRequestProcessor(maxResponseSize); + + HttpPageBufferClient pendingClient = createHttpPageBufferClient(processor, maxResponseSize, locationOne, new MockClientCallback()); + HttpPageBufferClient clientToBeSkipped = createHttpPageBufferClient(processor, maxResponseSize, locationTwo, new MockClientCallback()); + + pendingClient.requestSucceeded(DataSize.of(33, Unit.MEGABYTE).toBytes()); + + @SuppressWarnings("resource") + DirectExchangeClient exchangeClient = new DirectExchangeClient( + "localhost", + DataIntegrityVerification.ABORT, + new StreamingDirectExchangeBuffer(scheduler, DataSize.of(32, Unit.MEGABYTE)), + maxResponseSize, + 1, + new Duration(1, TimeUnit.MINUTES), + true, + new TestingHttpClient(processor, scheduler), + scheduler, + new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), + pageBufferClientCallbackExecutor, + (taskId, failure) -> {}); + exchangeClient.getAllClients().putAll(Map.of(locationOne, pendingClient, locationTwo, clientToBeSkipped)); + exchangeClient.getQueuedClients().add(clientToBeSkipped); + + int clientCount = exchangeClient.scheduleRequestIfNecessary(); + // The first client is pending and it reserved the space in the buffer. There is no place for the another one + assertEquals(clientCount, 0); + } + + private HttpPageBufferClient createHttpPageBufferClient(TestingHttpClient.Processor processor, DataSize expectedMaxSize, URI location, HttpPageBufferClient.ClientCallback callback) + { + return new HttpPageBufferClient( + "localhost", + new TestingHttpClient(processor, scheduler), + DataIntegrityVerification.ABORT, + expectedMaxSize, + new Duration(1, TimeUnit.MINUTES), + true, + TASK_ID, + location, + callback, + scheduler, + pageBufferClientCallbackExecutor); + } + private static Page createPage(int size) { return new Page(BlockAssertions.createLongSequenceBlock(0, size)); @@ -985,4 +1108,29 @@ private static void assertStatus( assertEquals(clientStatus.getRequestsCompleted(), requestsCompleted, "requestsCompleted"); assertEquals(clientStatus.getHttpRequestState(), httpRequestState, "httpRequestState"); } + + private static class MockClientCallback + implements HttpPageBufferClient.ClientCallback + { + @Override + public boolean addPages(HttpPageBufferClient client, List pages) + { + return false; + } + + @Override + public void requestComplete(HttpPageBufferClient client) + { + } + + @Override + public void clientFinished(HttpPageBufferClient client) + { + } + + @Override + public void clientFailed(HttpPageBufferClient client, Throwable cause) + { + } + } } diff --git a/core/trino-main/src/test/java/io/trino/operator/TestHttpPageBufferClient.java b/core/trino-main/src/test/java/io/trino/operator/TestHttpPageBufferClient.java index 77c6293e47c2..47ab79c4a4c0 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestHttpPageBufferClient.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestHttpPageBufferClient.java @@ -438,6 +438,33 @@ public void testErrorCodes() assertEquals(new PageTransportTimeoutException(HostAddress.fromParts("127.0.0.1", 8080), "", null).getErrorCode(), PAGE_TRANSPORT_TIMEOUT.toErrorCode()); } + @Test + public void testAverageSizeOfRequest() + { + HttpPageBufferClient client = new HttpPageBufferClient( + "localhost", + new TestingHttpClient(new MockExchangeRequestProcessor(DataSize.of(10, MEGABYTE)), scheduler), + DataIntegrityVerification.ABORT, + DataSize.of(10, MEGABYTE), + new Duration(30, TimeUnit.SECONDS), + true, + TASK_ID, + URI.create("http://localhost:8080"), + new TestingClientCallback(new CyclicBarrier(1)), + scheduler, + new TestingTicker(), + pageBufferClientCallbackExecutor); + + assertEquals(client.getAverageRequestSizeInBytes(), 0); + + client.requestSucceeded(0); + assertEquals(client.getAverageRequestSizeInBytes(), 0); + + client.requestSucceeded(1000); + client.requestSucceeded(800); + assertEquals(client.getAverageRequestSizeInBytes(), 600); + } + @Test public void testMemoryExceededInAddPages() throws Exception