diff --git a/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java index f5e461cc3e5d..5eed6d8fe2d6 100644 --- a/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java +++ b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java @@ -13,6 +13,7 @@ */ package io.trino.operator; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.http.client.HttpClient; @@ -36,9 +37,9 @@ import java.util.Deque; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.locks.Lock; @@ -68,7 +69,7 @@ public class DirectExchangeClient @GuardedBy("this") private boolean noMoreLocations; - private final ConcurrentMap allClients = new ConcurrentHashMap<>(); + private final Map allClients = new ConcurrentHashMap<>(); @GuardedBy("this") private final Deque queuedClients = new LinkedList<>(); @@ -260,31 +261,37 @@ public synchronized void close() } } - private synchronized void scheduleRequestIfNecessary() + @VisibleForTesting + synchronized int scheduleRequestIfNecessary() { if ((buffer.isFinished() || buffer.isFailed()) && completedClients.size() == allClients.size()) { - return; + return 0; } long neededBytes = buffer.getRemainingCapacityInBytes(); if (neededBytes <= 0) { - return; + return 0; } - int clientCount = (int) ((1.0 * neededBytes / averageBytesPerRequest) * concurrentRequestMultiplier); - clientCount = Math.max(clientCount, 1); - - int pendingClients = allClients.size() - queuedClients.size() - completedClients.size(); - clientCount -= pendingClients; + long reservedBytesForScheduledClients = allClients.values().stream() + .filter(client -> !queuedClients.contains(client) && !completedClients.contains(client)) + .mapToLong(HttpPageBufferClient::getAverageRequestSizeInBytes) + .sum(); + long projectedBytesToBeRequested = 0; + int clientCount = 0; + for (HttpPageBufferClient client : queuedClients) { + if (projectedBytesToBeRequested >= neededBytes * concurrentRequestMultiplier - reservedBytesForScheduledClients) { + break; + } + projectedBytesToBeRequested += client.getAverageRequestSizeInBytes(); + clientCount++; + } for (int i = 0; i < clientCount; i++) { HttpPageBufferClient client = queuedClients.poll(); - if (client == null) { - // no more clients available - return; - } client.scheduleRequest(); } + return clientCount; } public ListenableFuture isBlocked() @@ -292,6 +299,18 @@ public ListenableFuture isBlocked() return buffer.isBlocked(); } + @VisibleForTesting + Deque getQueuedClients() + { + return queuedClients; + } + + @VisibleForTesting + Map getAllClients() + { + return allClients; + } + private boolean addPages(HttpPageBufferClient client, List pages) { checkState(!completedClients.contains(client), "client is already marked as completed"); diff --git a/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java b/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java index 388883f38dee..3e9b0dcfd72c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java +++ b/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java @@ -13,6 +13,7 @@ */ package io.trino.operator; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Ticker; import com.google.common.collect.ImmutableList; import com.google.common.io.LittleEndianDataInputStream; @@ -145,6 +146,9 @@ public interface ClientCallback @GuardedBy("this") private String taskInstanceId; + // it is synchronized on `this` for update + private volatile long averageRequestSizeInBytes; + private final AtomicLong rowsReceived = new AtomicLong(); private final AtomicInteger pagesReceived = new AtomicInteger(); @@ -153,6 +157,7 @@ public interface ClientCallback private final AtomicInteger requestsScheduled = new AtomicInteger(); private final AtomicInteger requestsCompleted = new AtomicInteger(); + private final AtomicInteger requestsSucceeded = new AtomicInteger(); private final AtomicInteger requestsFailed = new AtomicInteger(); private final Executor pageBufferClientCallbackExecutor; @@ -251,6 +256,7 @@ else if (completed) { requestsScheduled.get(), requestsCompleted.get(), requestsFailed.get(), + requestsSucceeded.get(), httpRequestState); } @@ -259,6 +265,11 @@ public TaskId getRemoteTaskId() return remoteTaskId; } + public long getAverageRequestSizeInBytes() + { + return averageRequestSizeInBytes; + } + public synchronized boolean isRunning() { return future != null; @@ -434,6 +445,8 @@ public Void handle(Request request, Response response) } } requestsCompleted.incrementAndGet(); + long responseSize = pages.stream().mapToLong(Slice::length).sum(); + requestSucceeded(responseSize); synchronized (HttpPageBufferClient.this) { // client is complete, acknowledge it by sending it a delete in the next request @@ -485,6 +498,14 @@ public void onFailure(Throwable t) }, pageBufferClientCallbackExecutor); } + @VisibleForTesting + synchronized void requestSucceeded(long responseSize) + { + int successfulRequests = requestsSucceeded.incrementAndGet(); + // AVG_n = AVG_(n-1) * (n-1)/n + VALUE_n / n + averageRequestSizeInBytes = (long) ((1.0 * averageRequestSizeInBytes * (successfulRequests - 1)) + responseSize) / successfulRequests; + } + private synchronized void destroyTaskResults() { HttpResponseFuture resultFuture = httpClient.executeAsync(prepareDelete().setUri(location).build(), createStatusResponseHandler()); diff --git a/core/trino-main/src/main/java/io/trino/operator/PageBufferClientStatus.java b/core/trino-main/src/main/java/io/trino/operator/PageBufferClientStatus.java index a1a2dac3b707..f8584b97f210 100644 --- a/core/trino-main/src/main/java/io/trino/operator/PageBufferClientStatus.java +++ b/core/trino-main/src/main/java/io/trino/operator/PageBufferClientStatus.java @@ -37,6 +37,7 @@ public class PageBufferClientStatus private final int requestsScheduled; private final int requestsCompleted; private final int requestsFailed; + private final int requestsSucceeded; private final String httpRequestState; @JsonCreator @@ -50,6 +51,7 @@ public PageBufferClientStatus(@JsonProperty("uri") URI uri, @JsonProperty("requestsScheduled") int requestsScheduled, @JsonProperty("requestsCompleted") int requestsCompleted, @JsonProperty("requestsFailed") int requestsFailed, + @JsonProperty("requestsSucceeded") int requestsSucceeded, @JsonProperty("httpRequestState") String httpRequestState) { this.uri = uri; @@ -62,6 +64,7 @@ public PageBufferClientStatus(@JsonProperty("uri") URI uri, this.requestsScheduled = requestsScheduled; this.requestsCompleted = requestsCompleted; this.requestsFailed = requestsFailed; + this.requestsSucceeded = requestsSucceeded; this.httpRequestState = httpRequestState; } @@ -125,6 +128,12 @@ public int getRequestsFailed() return requestsFailed; } + @JsonProperty + public int getRequestsSucceeded() + { + return requestsSucceeded; + } + @JsonProperty public String getHttpRequestState() { diff --git a/core/trino-main/src/test/java/io/trino/operator/TestDirectExchangeClient.java b/core/trino-main/src/test/java/io/trino/operator/TestDirectExchangeClient.java index 36cc9ae989b5..540528144f3b 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestDirectExchangeClient.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestDirectExchangeClient.java @@ -70,6 +70,7 @@ import static io.airlift.concurrent.MoreFutures.tryGetFutureValue; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.testing.Assertions.assertLessThan; +import static io.trino.execution.TestSqlTaskExecution.TASK_ID; import static io.trino.execution.buffer.PagesSerdeUtil.getSerializedPagePositionCount; import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; @@ -946,6 +947,128 @@ public void testStreamingClose() assertEquals(clientStatus.getHttpRequestState(), "not scheduled", "httpRequestState"); } + @Test + public void testScheduleWhenOneClientFilledBuffer() + { + DataSize maxResponseSize = DataSize.of(8, Unit.MEGABYTE); + + URI locationOne = URI.create("http://localhost:8080"); + URI locationTwo = URI.create("http://localhost:8081"); + MockExchangeRequestProcessor processor = new MockExchangeRequestProcessor(maxResponseSize); + + HttpPageBufferClient clientToBeUsed = createHttpPageBufferClient(processor, maxResponseSize, locationOne, new MockClientCallback()); + HttpPageBufferClient clientToBeSkipped = createHttpPageBufferClient(processor, maxResponseSize, locationTwo, new MockClientCallback()); + clientToBeUsed.requestSucceeded(DataSize.of(33, Unit.MEGABYTE).toBytes()); + clientToBeSkipped.requestSucceeded(DataSize.of(1, Unit.MEGABYTE).toBytes()); + + @SuppressWarnings("resource") + DirectExchangeClient exchangeClient = new DirectExchangeClient( + "localhost", + DataIntegrityVerification.ABORT, + new StreamingDirectExchangeBuffer(scheduler, DataSize.of(32, Unit.MEGABYTE)), + maxResponseSize, + 1, + new Duration(1, TimeUnit.MINUTES), + true, + new TestingHttpClient(processor, scheduler), + scheduler, + new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), + pageBufferClientCallbackExecutor, + (taskId, failure) -> {}); + exchangeClient.getAllClients().putAll(Map.of(locationOne, clientToBeUsed, locationTwo, clientToBeSkipped)); + exchangeClient.getQueuedClients().addAll(ImmutableList.of(clientToBeUsed, clientToBeSkipped)); + + int clientCount = exchangeClient.scheduleRequestIfNecessary(); + // The first client filled the buffer. There is no place for the another one + assertEquals(clientCount, 1); + } + + @Test + public void testScheduleWhenAllClientsAreEmpty() + { + DataSize maxResponseSize = DataSize.of(8, Unit.MEGABYTE); + + URI locationOne = URI.create("http://localhost:8080"); + URI locationTwo = URI.create("http://localhost:8081"); + MockExchangeRequestProcessor processor = new MockExchangeRequestProcessor(maxResponseSize); + + HttpPageBufferClient firstClient = createHttpPageBufferClient(processor, maxResponseSize, locationOne, new MockClientCallback()); + HttpPageBufferClient secondClient = createHttpPageBufferClient(processor, maxResponseSize, locationTwo, new MockClientCallback()); + + @SuppressWarnings("resource") + DirectExchangeClient exchangeClient = new DirectExchangeClient( + "localhost", + DataIntegrityVerification.ABORT, + new StreamingDirectExchangeBuffer(scheduler, DataSize.of(32, Unit.MEGABYTE)), + maxResponseSize, + 1, + new Duration(1, TimeUnit.MINUTES), + true, + new TestingHttpClient(processor, scheduler), + scheduler, + new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), + pageBufferClientCallbackExecutor, + (taskId, failure) -> {}); + exchangeClient.getAllClients().putAll(Map.of(locationOne, firstClient, locationTwo, secondClient)); + exchangeClient.getQueuedClients().addAll(ImmutableList.of(firstClient, secondClient)); + + int clientCount = exchangeClient.scheduleRequestIfNecessary(); + assertEquals(clientCount, 2); + } + + @Test + public void testScheduleWhenThereIsPendingClient() + { + DataSize maxResponseSize = DataSize.of(8, Unit.MEGABYTE); + + URI locationOne = URI.create("http://localhost:8080"); + URI locationTwo = URI.create("http://localhost:8081"); + + MockExchangeRequestProcessor processor = new MockExchangeRequestProcessor(maxResponseSize); + + HttpPageBufferClient pendingClient = createHttpPageBufferClient(processor, maxResponseSize, locationOne, new MockClientCallback()); + HttpPageBufferClient clientToBeSkipped = createHttpPageBufferClient(processor, maxResponseSize, locationTwo, new MockClientCallback()); + + pendingClient.requestSucceeded(DataSize.of(33, Unit.MEGABYTE).toBytes()); + + @SuppressWarnings("resource") + DirectExchangeClient exchangeClient = new DirectExchangeClient( + "localhost", + DataIntegrityVerification.ABORT, + new StreamingDirectExchangeBuffer(scheduler, DataSize.of(32, Unit.MEGABYTE)), + maxResponseSize, + 1, + new Duration(1, TimeUnit.MINUTES), + true, + new TestingHttpClient(processor, scheduler), + scheduler, + new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), + pageBufferClientCallbackExecutor, + (taskId, failure) -> {}); + exchangeClient.getAllClients().putAll(Map.of(locationOne, pendingClient, locationTwo, clientToBeSkipped)); + exchangeClient.getQueuedClients().add(clientToBeSkipped); + + int clientCount = exchangeClient.scheduleRequestIfNecessary(); + // The first client is pending and it reserved the space in the buffer. There is no place for the another one + assertEquals(clientCount, 0); + } + + private HttpPageBufferClient createHttpPageBufferClient(TestingHttpClient.Processor processor, DataSize expectedMaxSize, URI location, HttpPageBufferClient.ClientCallback callback) + { + return new HttpPageBufferClient( + "localhost", + new TestingHttpClient(processor, scheduler), + DataIntegrityVerification.ABORT, + expectedMaxSize, + new Duration(1, TimeUnit.MINUTES), + true, + TASK_ID, + location, + callback, + scheduler, + pageBufferClientCallbackExecutor); + } + private static Page createPage(int size) { return new Page(BlockAssertions.createLongSequenceBlock(0, size)); @@ -985,4 +1108,29 @@ private static void assertStatus( assertEquals(clientStatus.getRequestsCompleted(), requestsCompleted, "requestsCompleted"); assertEquals(clientStatus.getHttpRequestState(), httpRequestState, "httpRequestState"); } + + private static class MockClientCallback + implements HttpPageBufferClient.ClientCallback + { + @Override + public boolean addPages(HttpPageBufferClient client, List pages) + { + return false; + } + + @Override + public void requestComplete(HttpPageBufferClient client) + { + } + + @Override + public void clientFinished(HttpPageBufferClient client) + { + } + + @Override + public void clientFailed(HttpPageBufferClient client, Throwable cause) + { + } + } } diff --git a/core/trino-main/src/test/java/io/trino/operator/TestHttpPageBufferClient.java b/core/trino-main/src/test/java/io/trino/operator/TestHttpPageBufferClient.java index 77c6293e47c2..47ab79c4a4c0 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestHttpPageBufferClient.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestHttpPageBufferClient.java @@ -438,6 +438,33 @@ public void testErrorCodes() assertEquals(new PageTransportTimeoutException(HostAddress.fromParts("127.0.0.1", 8080), "", null).getErrorCode(), PAGE_TRANSPORT_TIMEOUT.toErrorCode()); } + @Test + public void testAverageSizeOfRequest() + { + HttpPageBufferClient client = new HttpPageBufferClient( + "localhost", + new TestingHttpClient(new MockExchangeRequestProcessor(DataSize.of(10, MEGABYTE)), scheduler), + DataIntegrityVerification.ABORT, + DataSize.of(10, MEGABYTE), + new Duration(30, TimeUnit.SECONDS), + true, + TASK_ID, + URI.create("http://localhost:8080"), + new TestingClientCallback(new CyclicBarrier(1)), + scheduler, + new TestingTicker(), + pageBufferClientCallbackExecutor); + + assertEquals(client.getAverageRequestSizeInBytes(), 0); + + client.requestSucceeded(0); + assertEquals(client.getAverageRequestSizeInBytes(), 0); + + client.requestSucceeded(1000); + client.requestSucceeded(800); + assertEquals(client.getAverageRequestSizeInBytes(), 600); + } + @Test public void testMemoryExceededInAddPages() throws Exception