diff --git a/engine/src/main/java/net/jqwik/engine/properties/state/ShrinkableChainIteration.java b/engine/src/main/java/net/jqwik/engine/properties/state/ShrinkableChainIteration.java index 4567d964d..0481dde9d 100644 --- a/engine/src/main/java/net/jqwik/engine/properties/state/ShrinkableChainIteration.java +++ b/engine/src/main/java/net/jqwik/engine/properties/state/ShrinkableChainIteration.java @@ -2,11 +2,61 @@ import java.util.*; import java.util.function.*; +import java.util.stream.*; import net.jqwik.api.*; import net.jqwik.api.state.*; class ShrinkableChainIteration { + // Larger values might improve shrink quality, however, they increase the shrink space, so it might increase shrink duration + private final static int NUM_SAMPLES_IN_EAGER_CHAIN_SHRINK = Integer.getInteger("jqwik.eagerChainShrinkSamples", 2); + + static class ShrinkableWithEagerValue implements Shrinkable { + protected final Shrinkable base; + private final ShrinkingDistance distance; + final T value; + + ShrinkableWithEagerValue(Shrinkable base) { + this.base = base; + this.distance = base.distance(); + this.value = base.value(); + } + + @Override + public T value() { + return value; + } + + @Override + public Stream> shrink() { + return base.shrink(); + } + + @Override + public ShrinkingDistance distance() { + return distance; + } + } + + static class EagerShrinkable extends ShrinkableWithEagerValue { + private final List> shrinkResults; + + EagerShrinkable(Shrinkable base, int numSamples) { + super(base); + this.shrinkResults = + base.shrink() + .sorted(Comparator.comparing(Shrinkable::distance)) + .limit(numSamples) + .map(ShrinkableWithEagerValue::new) + .collect(Collectors.toList()); + } + + @Override + public Stream> shrink() { + return shrinkResults.stream(); + } + } + final Shrinkable> shrinkable; private final Predicate precondition; final boolean accessState; @@ -30,19 +80,29 @@ private ShrinkableChainIteration( this.precondition = precondition; this.accessState = accessState; this.changeState = changeState; - this.shrinkable = shrinkable; + // When the shrinkable does not access state, we could just use it as is for ".value()", and ".shrink()" + // If we get LazyShrinkable here, it means we are in a shrinking phase, so we know ".shrink()" will be called only + // in case the subsequent execution fails. So we can just keep LazyShrinkable as is + // Otherwise, we need to eagerly evaluate the shrinkables to since the state might change by appyling subsequent transformers, + // so we won't be able to access the state anymore. + // See https://github.com/jlink/jqwik/issues/428 + if (!accessState || shrinkable instanceof ShrinkableChainIteration.ShrinkableWithEagerValue) { + this.shrinkable = shrinkable; + } else { + this.shrinkable = new EagerShrinkable<>(shrinkable, NUM_SAMPLES_IN_EAGER_CHAIN_SHRINK); + } } @Override public String toString() { return String.format( "Iteration[accessState=%s, changeState=%s, transformation=%s]", - accessState, changeState, shrinkable.value().transformation() + accessState, changeState, transformer().transformation() ); } boolean isEndOfChain() { - return shrinkable.value().equals(Transformer.END_OF_CHAIN); + return transformer().equals(Transformer.END_OF_CHAIN); } Optional> precondition() { diff --git a/engine/src/test/java/net/jqwik/engine/properties/state/ActionChainArbitraryTests.java b/engine/src/test/java/net/jqwik/engine/properties/state/ActionChainArbitraryTests.java index bda69b935..5ff10bb1c 100644 --- a/engine/src/test/java/net/jqwik/engine/properties/state/ActionChainArbitraryTests.java +++ b/engine/src/test/java/net/jqwik/engine/properties/state/ActionChainArbitraryTests.java @@ -119,6 +119,102 @@ ActionChainArbitrary xOrFailing() { .withMaxTransformations(30); } + static class SetMutatingChainState { + final List actualOps = new ArrayList<>(); + boolean hasPrints; + final Set set = new HashSet<>(); + + @Override + public String toString() { + return "set=" + set + ", actualOps=" + actualOps; + } + } + + @Property + void chainActionsAreProperlyDescribedEvenAfterChainExecution(@ForAll("setMutatingChain") ActionChain chain) { + chain = chain.withInvariant( + state -> { + if (state.hasPrints) { + assertThat(state.actualOps).hasSizeLessThan(5); + } + } + ); + + SetMutatingChainState finalState = chain.run(); + + assertThat(chain.transformations()) + .describedAs("chain.transformations() should be the same as the list of operations in finalState.actualOps, final state is %s", finalState.set) + .isEqualTo(finalState.actualOps); + } + + @Provide + public ActionChainArbitrary setMutatingChain() { + return + ActionChain + .startWith(SetMutatingChainState::new) + // This is an action that does not depend on the state to produce the transformation + .addAction( + 1, + Action.just("clear anyway", state -> { + state.actualOps.add("clear anyway"); + state.set.clear(); + return state; + }) + ) + // Below actions depend on the state to derive the transformations + .addAction( + 1, + (Action.Dependent) + state -> + Arbitraries + .just( + state.set.isEmpty() + ? Transformer.noop() + : Transformer.mutate("clear " + state.set, set -> { + state.actualOps.add("clear " + set.set); + state.set.clear(); + }) + ) + ) + .addAction( + 4, + (Action.Dependent) + state -> + Arbitraries + .integers() + .between(1, 10) + .map(i -> { + if (state.set.contains(i)) { + return Transformer.noop(); + } + return Transformer.mutate("add " + i + " to " + state.set, newState -> { + newState.actualOps.add("add " + i + " to " + newState.set); + newState.set.add(i); + }); + } + ) + ) + .addAction( + 2, + (Action.Dependent) + state -> + state.set.isEmpty() ? Arbitraries.just(Transformer.noop()) : + Arbitraries + .of(state.set) + .map(i -> { + if (!state.set.contains(i)) { + throw new IllegalStateException("The set does not contain " + i + ", current state is " + state); + } + return Transformer.mutate("print " + i + " from " + state.set, newState -> { + newState.actualOps.add("print " + i + " from " + newState.set); + newState.hasPrints = true; + }); + } + ) + ) + .withMaxTransformations(7); + } + @Property void chainChoosesBetweenTwoActions(@ForAll("xOrY") ActionChain chain) { String result = chain.run();