diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java index b57f4a9ff5734..89d6fea291105 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java @@ -275,8 +275,13 @@ public interface SingletonAssert { /** * Constructs an {@link IterableAssert} for the elements of the provided {@link PCollection}. */ + public static IterableAssert that(String message, PCollection actual) { + return new PCollectionContentsAssert<>(PAssertionSite.capture(message), actual); + } + + /** @see #that(String, PCollection) */ public static IterableAssert that(PCollection actual) { - return new PCollectionContentsAssert<>(actual); + return that("", actual); } /** @@ -284,7 +289,7 @@ public static IterableAssert that(PCollection actual) { * must contain a single {@code Iterable} value. */ public static IterableAssert thatSingletonIterable( - PCollection> actual) { + String message, PCollection> actual) { try { } catch (NoSuchElementException | IllegalArgumentException exc) { @@ -297,15 +302,29 @@ public static IterableAssert thatSingletonIterable( @SuppressWarnings("unchecked") // Safe covariant cast PCollection> actualIterables = (PCollection>) actual; - return new PCollectionSingletonIterableAssert<>(actualIterables); + return new PCollectionSingletonIterableAssert<>( + PAssertionSite.capture(message), actualIterables); } - /** - * Constructs a {@link SingletonAssert} for the value of the provided - * {@code PCollection PCollection}, which must be a singleton. - */ + /** @see #thatSingletonIterable(String, PCollection) */ + public static IterableAssert thatSingletonIterable( + PCollection> actual) { + return thatSingletonIterable("", actual); + } + + /** + * Constructs a {@link SingletonAssert} for the value of the provided + * {@code PCollection PCollection}, which must be a singleton. + */ + public static SingletonAssert thatSingleton(String message, PCollection actual) { + return new PCollectionViewAssert<>( + PAssertionSite.capture(message), + actual, View.asSingleton(), actual.getCoder()); + } + + /** @see #thatSingleton(String, PCollection) */ public static SingletonAssert thatSingleton(PCollection actual) { - return new PCollectionViewAssert<>(actual, View.asSingleton(), actual.getCoder()); + return thatSingleton("", actual); } /** @@ -315,48 +334,90 @@ public static SingletonAssert thatSingleton(PCollection actual) { * {@code Coder}. */ public static SingletonAssert>> thatMultimap( + String message, PCollection> actual) { @SuppressWarnings("unchecked") KvCoder kvCoder = (KvCoder) actual.getCoder(); return new PCollectionViewAssert<>( + PAssertionSite.capture(message), actual, View.asMultimap(), MapCoder.of(kvCoder.getKeyCoder(), IterableCoder.of(kvCoder.getValueCoder()))); } + /** @see #thatMultimap(String, PCollection) */ + public static SingletonAssert>> thatMultimap( + PCollection> actual) { + return thatMultimap("", actual); + } + /** * Constructs a {@link SingletonAssert} for the value of the provided {@link PCollection}, which * must have at most one value per key. * - *

Note that the actual value must be coded by a {@link KvCoder}, not just any - * {@code Coder}. + *

Note that the actual value must be coded by a {@link KvCoder}, not just any {@code Coder}. */ - public static SingletonAssert> thatMap(PCollection> actual) { + public static SingletonAssert> thatMap( + String message, PCollection> actual) { @SuppressWarnings("unchecked") KvCoder kvCoder = (KvCoder) actual.getCoder(); return new PCollectionViewAssert<>( - actual, View.asMap(), MapCoder.of(kvCoder.getKeyCoder(), kvCoder.getValueCoder())); + PAssertionSite.capture(message), + actual, + View.asMap(), + MapCoder.of(kvCoder.getKeyCoder(), kvCoder.getValueCoder())); + } + + /** @see #thatMap(String, PCollection) */ + public static SingletonAssert> thatMap(PCollection> actual) { + return thatMap("", actual); } //////////////////////////////////////////////////////////// + private static class PAssertionSite implements Serializable { + private final String message; + private final StackTraceElement[] creationStackTrace; + + static PAssertionSite capture(String message) { + return new PAssertionSite(message, new Throwable().getStackTrace()); + } + + PAssertionSite(String message, StackTraceElement[] creationStackTrace) { + this.message = message; + this.creationStackTrace = creationStackTrace; + } + + public AssertionError wrap(Throwable t) { + AssertionError res = + new AssertionError( + message.isEmpty() ? t.getMessage() : (message + ": " + t.getMessage()), t); + res.setStackTrace(creationStackTrace); + return res; + } + } + /** * An {@link IterableAssert} about the contents of a {@link PCollection}. This does not require * the runner to support side inputs. */ private static class PCollectionContentsAssert implements IterableAssert { + private final PAssertionSite site; private final PCollection actual; private final AssertionWindows rewindowingStrategy; private final SimpleFunction>, Iterable> paneExtractor; - public PCollectionContentsAssert(PCollection actual) { - this(actual, IntoGlobalWindow.of(), PaneExtractors.allPanes()); + public PCollectionContentsAssert(PAssertionSite site, PCollection actual) { + this(site, actual, IntoGlobalWindow.of(), PaneExtractors.allPanes()); } public PCollectionContentsAssert( + PAssertionSite site, PCollection actual, AssertionWindows rewindowingStrategy, SimpleFunction>, Iterable> paneExtractor) { + this.site = site; this.actual = actual; this.rewindowingStrategy = rewindowingStrategy; this.paneExtractor = paneExtractor; @@ -394,7 +455,7 @@ private PCollectionContentsAssert withPane( Coder windowCoder = (Coder) actual.getWindowingStrategy().getWindowFn().windowCoder(); return new PCollectionContentsAssert<>( - actual, IntoStaticWindows.of(windowCoder, window), paneExtractor); + site, actual, IntoStaticWindows.of(windowCoder, window), paneExtractor); } /** @@ -429,7 +490,7 @@ public PCollectionContentsAssert satisfies( SerializableFunction, Void> checkerFn) { actual.apply( nextAssertionName(), - new GroupThenAssert<>(checkerFn, rewindowingStrategy, paneExtractor)); + new GroupThenAssert<>(site, checkerFn, rewindowingStrategy, paneExtractor)); return this; } @@ -471,7 +532,7 @@ PCollectionContentsAssert satisfies( (SerializableFunction) new MatcherCheckerFn<>(matcher); actual.apply( "PAssert$" + (assertCount++), - new GroupThenAssert<>(checkerFn, rewindowingStrategy, paneExtractor)); + new GroupThenAssert<>(site, checkerFn, rewindowingStrategy, paneExtractor)); return this; } @@ -518,21 +579,26 @@ public int hashCode() { * This does not require the runner to support side inputs. */ private static class PCollectionSingletonIterableAssert implements IterableAssert { + private final PAssertionSite site; private final PCollection> actual; private final Coder elementCoder; private final AssertionWindows rewindowingStrategy; private final SimpleFunction>>, Iterable>> paneExtractor; - public PCollectionSingletonIterableAssert(PCollection> actual) { - this(actual, IntoGlobalWindow.>of(), PaneExtractors.>onlyPane()); + public PCollectionSingletonIterableAssert( + PAssertionSite site, PCollection> actual) { + this( + site, actual, IntoGlobalWindow.>of(), PaneExtractors.>onlyPane()); } public PCollectionSingletonIterableAssert( + PAssertionSite site, PCollection> actual, AssertionWindows rewindowingStrategy, SimpleFunction>>, Iterable>> paneExtractor) { + this.site = site; this.actual = actual; @SuppressWarnings("unchecked") @@ -576,7 +642,7 @@ private PCollectionSingletonIterableAssert withPanes( Coder windowCoder = (Coder) actual.getWindowingStrategy().getWindowFn().windowCoder(); return new PCollectionSingletonIterableAssert<>( - actual, IntoStaticWindows.>of(windowCoder, window), paneExtractor); + site, actual, IntoStaticWindows.>of(windowCoder, window), paneExtractor); } @Override @@ -600,7 +666,7 @@ public PCollectionSingletonIterableAssert satisfies( SerializableFunction, Void> checkerFn) { actual.apply( "PAssert$" + (assertCount++), - new GroupThenAssertForSingleton<>(checkerFn, rewindowingStrategy, paneExtractor)); + new GroupThenAssertForSingleton<>(site, checkerFn, rewindowingStrategy, paneExtractor)); return this; } @@ -617,6 +683,7 @@ private PCollectionSingletonIterableAssert satisfies( * of type {@code ViewT}. This requires side input support from the runner. */ private static class PCollectionViewAssert implements SingletonAssert { + private final PAssertionSite site; private final PCollection actual; private final PTransform, PCollectionView> view; private final AssertionWindows rewindowActuals; @@ -625,18 +692,27 @@ private static class PCollectionViewAssert implements SingletonAss private final Coder coder; protected PCollectionViewAssert( + PAssertionSite site, PCollection actual, PTransform, PCollectionView> view, Coder coder) { - this(actual, view, IntoGlobalWindow.of(), PaneExtractors.onlyPane(), coder); + this( + site, + actual, + view, + IntoGlobalWindow.of(), + PaneExtractors.onlyPane(), + coder); } private PCollectionViewAssert( + PAssertionSite site, PCollection actual, PTransform, PCollectionView> view, AssertionWindows rewindowActuals, SimpleFunction>, Iterable> paneExtractor, Coder coder) { + this.site = site; this.actual = actual; this.view = view; this.rewindowActuals = rewindowActuals; @@ -663,6 +739,7 @@ private PCollectionViewAssert inPane( BoundedWindow window, SimpleFunction>, Iterable> paneExtractor) { return new PCollectionViewAssert<>( + site, actual, view, IntoStaticWindows.of( @@ -689,6 +766,7 @@ public PCollectionViewAssert satisfies( .apply( "PAssert$" + (assertCount++), new OneSideInputAssert( + site, CreateActual.from(actual, rewindowActuals, paneExtractor, view), rewindowActuals.windowDummy(), checkerFn)); @@ -911,14 +989,17 @@ public void processElement(ProcessContext c) throws Exception { */ public static class GroupThenAssert extends PTransform, PDone> implements Serializable { + private final PAssertionSite site; private final SerializableFunction, Void> checkerFn; private final AssertionWindows rewindowingStrategy; private final SimpleFunction>, Iterable> paneExtractor; private GroupThenAssert( + PAssertionSite site, SerializableFunction, Void> checkerFn, AssertionWindows rewindowingStrategy, SimpleFunction>, Iterable> paneExtractor) { + this.site = site; this.checkerFn = checkerFn; this.rewindowingStrategy = rewindowingStrategy; this.paneExtractor = paneExtractor; @@ -930,7 +1011,7 @@ public PDone expand(PCollection input) { .apply("GroupGlobally", new GroupGlobally(rewindowingStrategy)) .apply("GetPane", MapElements.via(paneExtractor)) .setCoder(IterableCoder.of(input.getCoder())) - .apply("RunChecks", ParDo.of(new GroupedValuesCheckerDoFn<>(checkerFn))); + .apply("RunChecks", ParDo.of(new GroupedValuesCheckerDoFn<>(site, checkerFn))); return PDone.in(input.getPipeline()); } @@ -942,16 +1023,19 @@ public PDone expand(PCollection input) { */ public static class GroupThenAssertForSingleton extends PTransform>, PDone> implements Serializable { + private final PAssertionSite site; private final SerializableFunction, Void> checkerFn; private final AssertionWindows rewindowingStrategy; private final SimpleFunction>>, Iterable>> paneExtractor; private GroupThenAssertForSingleton( + PAssertionSite site, SerializableFunction, Void> checkerFn, AssertionWindows rewindowingStrategy, SimpleFunction>>, Iterable>> paneExtractor) { + this.site = site; this.checkerFn = checkerFn; this.rewindowingStrategy = rewindowingStrategy; this.paneExtractor = paneExtractor; @@ -963,7 +1047,7 @@ public PDone expand(PCollection> input) { .apply("GroupGlobally", new GroupGlobally>(rewindowingStrategy)) .apply("GetPane", MapElements.via(paneExtractor)) .setCoder(IterableCoder.of(input.getCoder())) - .apply("RunChecks", ParDo.of(new SingletonCheckerDoFn<>(checkerFn))); + .apply("RunChecks", ParDo.of(new SingletonCheckerDoFn<>(site, checkerFn))); return PDone.in(input.getPipeline()); } @@ -981,14 +1065,17 @@ public PDone expand(PCollection> input) { */ public static class OneSideInputAssert extends PTransform implements Serializable { + private final PAssertionSite site; private final transient PTransform> createActual; private final transient PTransform, PCollection> windowToken; private final SerializableFunction checkerFn; private OneSideInputAssert( + PAssertionSite site, PTransform> createActual, PTransform, PCollection> windowToken, SerializableFunction checkerFn) { + this.site = site; this.createActual = createActual; this.windowToken = windowToken; this.checkerFn = checkerFn; @@ -1003,7 +1090,7 @@ public PDone expand(PBegin input) { .apply("WindowToken", windowToken) .apply( "RunChecks", - ParDo.withSideInputs(actual).of(new SideInputCheckerDoFn<>(checkerFn, actual))); + ParDo.withSideInputs(actual).of(new SideInputCheckerDoFn<>(site, checkerFn, actual))); return PDone.in(input.getPipeline()); } @@ -1017,6 +1104,7 @@ public PDone expand(PBegin input) { * null values. */ private static class SideInputCheckerDoFn extends DoFn { + private final PAssertionSite site; private final SerializableFunction checkerFn; private final Aggregator success = createAggregator(SUCCESS_COUNTER, Sum.ofIntegers()); @@ -1025,7 +1113,10 @@ private static class SideInputCheckerDoFn extends DoFn { private final PCollectionView actual; private SideInputCheckerDoFn( - SerializableFunction checkerFn, PCollectionView actual) { + PAssertionSite site, + SerializableFunction checkerFn, + PCollectionView actual) { + this.site = site; this.checkerFn = checkerFn; this.actual = actual; } @@ -1034,7 +1125,7 @@ private SideInputCheckerDoFn( public void processElement(ProcessContext c) { try { ActualT actualContents = c.sideInput(actual); - doChecks(actualContents, checkerFn, success, failure); + doChecks(site, actualContents, checkerFn, success, failure); } catch (Throwable t) { // Suppress exception in streaming if (!c.getPipelineOptions().as(StreamingOptions.class).isStreaming()) { @@ -1052,19 +1143,22 @@ public void processElement(ProcessContext c) { *

The singleton property is presumed, not enforced. */ private static class GroupedValuesCheckerDoFn extends DoFn { + private final PAssertionSite site; private final SerializableFunction checkerFn; private final Aggregator success = createAggregator(SUCCESS_COUNTER, Sum.ofIntegers()); private final Aggregator failure = createAggregator(FAILURE_COUNTER, Sum.ofIntegers()); - private GroupedValuesCheckerDoFn(SerializableFunction checkerFn) { + private GroupedValuesCheckerDoFn( + PAssertionSite site, SerializableFunction checkerFn) { + this.site = site; this.checkerFn = checkerFn; } @ProcessElement public void processElement(ProcessContext c) { - doChecks(c.element(), checkerFn, success, failure); + doChecks(site, c.element(), checkerFn, success, failure); } } @@ -1077,24 +1171,28 @@ public void processElement(ProcessContext c) { * each input element must be a singleton iterable, or this will fail. */ private static class SingletonCheckerDoFn extends DoFn, Void> { + private final PAssertionSite site; private final SerializableFunction checkerFn; private final Aggregator success = createAggregator(SUCCESS_COUNTER, Sum.ofIntegers()); private final Aggregator failure = createAggregator(FAILURE_COUNTER, Sum.ofIntegers()); - private SingletonCheckerDoFn(SerializableFunction checkerFn) { + private SingletonCheckerDoFn( + PAssertionSite site, SerializableFunction checkerFn) { + this.site = site; this.checkerFn = checkerFn; } @ProcessElement public void processElement(ProcessContext c) { ActualT actualContents = Iterables.getOnlyElement(c.element()); - doChecks(actualContents, checkerFn, success, failure); + doChecks(site, actualContents, checkerFn, success, failure); } } private static void doChecks( + PAssertionSite site, ActualT actualContents, SerializableFunction checkerFn, Aggregator successAggregator, @@ -1103,9 +1201,8 @@ private static void doChecks( checkerFn.apply(actualContents); successAggregator.addValue(1); } catch (Throwable t) { - LOG.error("PAssert failed expectations.", t); failureAggregator.addValue(1); - throw t; + throw site.wrap(t); } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java index 1997bbeef337e..e09f54b3b6731 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java @@ -24,6 +24,7 @@ import static org.junit.Assert.fail; import com.fasterxml.jackson.annotation.JsonCreator; +import com.google.common.base.Throwables; import com.google.common.collect.Iterables; import java.io.IOException; import java.io.InputStream; @@ -392,6 +393,49 @@ public void testEmptyFalse() throws Exception { assertThat(thrown.getMessage(), containsString("Expected: iterable over [] in any order")); } + @Test + @Category(RunnableOnService.class) + public void testAssertionSiteIsCapturedWithMessage() throws Exception { + PCollection vals = pipeline.apply(CountingInput.upTo(5L)); + assertThatCollectionIsEmptyWithMessage(vals); + + Throwable thrown = runExpectingAssertionFailure(pipeline); + + assertThat( + thrown.getMessage(), + containsString("Should be empty")); + assertThat( + thrown.getMessage(), + containsString("Expected: iterable over [] in any order")); + String stacktrace = Throwables.getStackTraceAsString(thrown); + assertThat(stacktrace, containsString("testAssertionSiteIsCapturedWithMessage")); + assertThat(stacktrace, containsString("assertThatCollectionIsEmptyWithMessage")); + } + + @Test + @Category(RunnableOnService.class) + public void testAssertionSiteIsCapturedWithoutMessage() throws Exception { + PCollection vals = pipeline.apply(CountingInput.upTo(5L)); + assertThatCollectionIsEmptyWithoutMessage(vals); + + Throwable thrown = runExpectingAssertionFailure(pipeline); + + assertThat( + thrown.getMessage(), + containsString("Expected: iterable over [] in any order")); + String stacktrace = Throwables.getStackTraceAsString(thrown); + assertThat(stacktrace, containsString("testAssertionSiteIsCapturedWithoutMessage")); + assertThat(stacktrace, containsString("assertThatCollectionIsEmptyWithoutMessage")); + } + + private static void assertThatCollectionIsEmptyWithMessage(PCollection vals) { + PAssert.that("Should be empty", vals).empty(); + } + + private static void assertThatCollectionIsEmptyWithoutMessage(PCollection vals) { + PAssert.that(vals).empty(); + } + private static Throwable runExpectingAssertionFailure(Pipeline pipeline) { // We cannot use thrown.expect(AssertionError.class) because the AssertionError // is first caught by JUnit and causes a test failure.