Skip to content

Commit

Permalink
Allow timeout for async DoFn (#5534)
Browse files Browse the repository at this point in the history
  • Loading branch information
RustedBones authored Dec 16, 2024
1 parent cc1e584 commit 4d701c1
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeoutException;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.apache.beam.sdk.transforms.DoFn;
Expand Down Expand Up @@ -238,7 +239,7 @@ public void finishBundle(FinishBundleContext context) {
Thread.currentThread().interrupt();
LOG.error("Failed to process futures", e);
throw new RuntimeException("Failed to process futures", e);
} catch (ExecutionException e) {
} catch (ExecutionException | TimeoutException e) {
LOG.error("Failed to process futures", e);
throw new RuntimeException("Failed to process futures", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.function.Consumer;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
Expand Down Expand Up @@ -64,7 +65,7 @@ public void finishBundle(FinishBundleContext context) {
Thread.currentThread().interrupt();
LOG.error("Failed to process futures", e);
throw new RuntimeException("Failed to process futures", e);
} catch (ExecutionException e) {
} catch (ExecutionException | TimeoutException e) {
LOG.error("Failed to process futures", e);
throw new RuntimeException("Failed to process futures", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeoutException;
import java.util.function.Consumer;
import java.util.function.Supplier;
import javax.annotation.CheckForNull;
Expand Down Expand Up @@ -233,7 +234,7 @@ public void finishBundle(FinishBundleContext context) {
Thread.currentThread().interrupt();
LOG.error("Failed to process futures", e);
throw new RuntimeException("Failed to process futures", e);
} catch (ExecutionException e) {
} catch (ExecutionException | TimeoutException e) {
LOG.error("Failed to process futures", e);
throw new RuntimeException("Failed to process futures", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@
package com.spotify.scio.transforms;

import com.google.common.util.concurrent.*;
import java.util.concurrent.*;
import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;
import java.util.stream.StreamSupport;
import javax.annotation.Nullable;
Expand All @@ -33,7 +40,13 @@ public class FutureHandlers {
* @param <V> value type.
*/
public interface Base<F, V> {
void waitForFutures(Iterable<F> futures) throws InterruptedException, ExecutionException;

default Duration getTimeout() {
return Duration.ofMinutes(1);
}

void waitForFutures(Iterable<F> futures)
throws InterruptedException, ExecutionException, TimeoutException;

F addCallback(F future, Function<V, Void> onSuccess, Function<Throwable, Void> onFailure);
}
Expand All @@ -53,10 +66,16 @@ default Executor getCallbackExecutor() {

@Override
default void waitForFutures(Iterable<ListenableFuture<V>> futures)
throws InterruptedException, ExecutionException {
throws InterruptedException, ExecutionException, TimeoutException {
// use Future#successfulAsList instead of Futures#allAsList which only works if all
// futures succeed
Futures.successfulAsList(futures).get();
ListenableFuture<?> f = Futures.successfulAsList(futures);
Duration timeout = getTimeout();
if (timeout != null) {
f.get(timeout.toMillis(), TimeUnit.MILLISECONDS);
} else {
f.get();
}
}

@Override
Expand Down Expand Up @@ -116,10 +135,16 @@ public void onFailure(Throwable t) {
public interface Java<V> extends Base<CompletableFuture<V>, V> {
@Override
default void waitForFutures(Iterable<CompletableFuture<V>> futures)
throws InterruptedException, ExecutionException {
throws InterruptedException, ExecutionException, TimeoutException {
CompletableFuture[] array =
StreamSupport.stream(futures.spliterator(), false).toArray(CompletableFuture[]::new);
CompletableFuture.allOf(array).exceptionally(t -> null).get();
CompletableFuture<?> f = CompletableFuture.allOf(array).exceptionally(t -> null);
Duration timeout = getTimeout();
if (timeout != null) {
f.get(timeout.toMillis(), TimeUnit.MILLISECONDS);
} else {
f.get();
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.util.function.{Function => JFunction}

import scala.jdk.CollectionConverters._
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.concurrent.duration._
import scala.util.{Failure, Success}

/** A [[FutureHandlers.Base]] implementation for Scala [[Future]]. */
Expand All @@ -35,8 +35,10 @@ trait ScalaFutureHandlers[T] extends FutureHandlers.Base[Future[T], T] {
}

override def waitForFutures(futures: lang.Iterable[Future[T]]): Unit = {
Await.ready(Future.sequence(futures.asScala), Duration.Inf)
()
val timeout = Option(getTimeout)
.map(_.toMillis.millis)
.getOrElse(Duration.Inf)
Await.ready(Future.sequence(futures.asScala), timeout)
}

override def addCallback(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,20 @@ import com.google.common.util.concurrent.{ListenableFuture, SettableFuture}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

import java.time.{Duration => JDuration}
import java.util.concurrent.{CompletableFuture, Executor, RejectedExecutionException}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.concurrent._
import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success, Try}

class GuavaFutureHandler extends FutureHandlers.Guava[String]

class JavaFutureHandler extends FutureHandlers.Java[String]
class GuavaFutureHandler extends FutureHandlers.Guava[String] {
override def getTimeout: JDuration = JDuration.ofMillis(500)
}
class JavaFutureHandler extends FutureHandlers.Java[String] {
override def getTimeout: JDuration = JDuration.ofMillis(500)
}

class RejectFutureHandler extends FutureHandlers.Guava[String] {
override def getCallbackExecutor: Executor = _ => throw new RejectedExecutionException("Rejected")
Expand Down Expand Up @@ -185,14 +189,34 @@ class FutureHandlersTest extends AnyFlatSpec with Matchers {
}
cause.getSuppressed.headOption.map(_.getMessage) shouldBe expectedSuppressed
}

it should "wait for futures to complete" in {
import scala.concurrent.ExecutionContext.Implicits.global
val successFuture = create()
val failureFuture = create()
val cancelFuture = create()
Future {
Thread.sleep(100)
complete(successFuture)("success")
fail(failureFuture)(new Exception("failure"))
cancel(cancelFuture)
}
handler.waitForFutures(Iterable[F](successFuture, failureFuture, cancelFuture).asJava)
}

it should "throw a timeout exception " in {
val f = create()
a[TimeoutException] shouldBe thrownBy(handler.waitForFutures(Iterable[F](f).asJava))
}

}

"Guava handler" should behave like futureHandler[
ListenableFuture[String],
SettableFuture[String]
](
new GuavaFutureHandler,
SettableFuture.create[String],
() => SettableFuture.create[String](),
_.set,
_.setException,
_.cancel(true),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.spotify.scio.transforms.BaseAsyncLookupDoFn;
import com.spotify.scio.transforms.GuavaAsyncLookupDoFn;
import java.io.IOException;
import java.time.Duration;
import org.apache.beam.sdk.transforms.DoFn;

/**
Expand Down Expand Up @@ -99,6 +100,11 @@ public ResourceType getResourceType() {
return ResourceType.PER_INSTANCE;
}

@Override
public Duration getTimeout() {
return Duration.ofMillis(options.getCallOptionsConfig().getMutateRpcTimeoutMs());
}

protected BigtableSession newClient() {
try {
return new BigtableSession(options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package com.spotify.scio.bigtable

import com.google.cloud.bigtable.config.BigtableOptions

import java.util.concurrent.ConcurrentLinkedQueue
import com.google.cloud.bigtable.grpc.BigtableSession
import com.google.common.cache.{Cache, CacheBuilder}
Expand Down Expand Up @@ -66,21 +68,26 @@ object BigtableDoFnTest {
val queue: ConcurrentLinkedQueue[Int] = new ConcurrentLinkedQueue[Int]()
}

class TestBigtableDoFn extends BigtableDoFn[Int, String](null) {
class TestBigtableDoFn extends BigtableDoFn[Int, String](BigtableOptions.getDefaultOptions) {
override def newClient(): BigtableSession = null
override def asyncLookup(session: BigtableSession, input: Int): ListenableFuture[String] =
Futures.immediateFuture(input.toString)
}

class TestCachingBigtableDoFn extends BigtableDoFn[Int, String](null, 100, new TestCacheSupplier) {
class TestCachingBigtableDoFn
extends BigtableDoFn[Int, String](
BigtableOptions.getDefaultOptions,
100,
new TestCacheSupplier
) {
override def newClient(): BigtableSession = null
override def asyncLookup(session: BigtableSession, input: Int): ListenableFuture[String] = {
BigtableDoFnTest.queue.add(input)
Futures.immediateFuture(input.toString)
}
}

class TestFailingBigtableDoFn extends BigtableDoFn[Int, String](null) {
class TestFailingBigtableDoFn extends BigtableDoFn[Int, String](BigtableOptions.getDefaultOptions) {
override def newClient(): BigtableSession = null
override def asyncLookup(session: BigtableSession, input: Int): ListenableFuture[String] =
if (input % 2 == 0) {
Expand Down

0 comments on commit 4d701c1

Please sign in to comment.