Skip to content

Commit

Permalink
fix: initialize nested selectors correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
triceo committed Sep 3, 2024
1 parent cee979f commit 7149abe
Show file tree
Hide file tree
Showing 28 changed files with 304 additions and 372 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiPredicate;
import java.util.function.Consumer;
import java.util.function.Predicate;

Expand All @@ -37,7 +38,6 @@
import ai.timefold.solver.core.api.domain.variable.PlanningVariable;
import ai.timefold.solver.core.api.domain.variable.PreviousElementShadowVariable;
import ai.timefold.solver.core.api.domain.variable.ShadowVariable;
import ai.timefold.solver.core.api.score.director.ScoreDirector;
import ai.timefold.solver.core.config.heuristic.selector.common.decorator.SelectionSorterOrder;
import ai.timefold.solver.core.config.util.ConfigUtils;
import ai.timefold.solver.core.impl.domain.common.ReflectionHelper;
Expand All @@ -61,11 +61,9 @@
import ai.timefold.solver.core.impl.domain.variable.nextprev.NextElementShadowVariableDescriptor;
import ai.timefold.solver.core.impl.domain.variable.nextprev.PreviousElementShadowVariableDescriptor;
import ai.timefold.solver.core.impl.heuristic.selector.common.decorator.ComparatorSelectionSorter;
import ai.timefold.solver.core.impl.heuristic.selector.common.decorator.SelectionFilter;
import ai.timefold.solver.core.impl.heuristic.selector.common.decorator.SelectionSorter;
import ai.timefold.solver.core.impl.heuristic.selector.common.decorator.SelectionSorterWeightFactory;
import ai.timefold.solver.core.impl.heuristic.selector.common.decorator.WeightFactorySelectionSorter;
import ai.timefold.solver.core.impl.heuristic.selector.entity.decorator.PinEntityFilter;
import ai.timefold.solver.core.impl.util.CollectionUtils;
import ai.timefold.solver.core.impl.util.MutableInt;

Expand Down Expand Up @@ -102,19 +100,19 @@ public class EntityDescriptor<Solution_> {
private Predicate<Object> hasNoNullVariablesBasicVar;
private Predicate<Object> hasNoNullVariablesListVar;
// Only declared movable filter, excludes inherited and descending movable filters
private SelectionFilter<Solution_, Object> declaredMovableEntitySelectionFilter;
private MovableFilter<Solution_> declaredMovableEntityFilter;
private SelectionSorter<Solution_, Object> decreasingDifficultySorter;

// Only declared variable descriptors, excludes inherited variable descriptors
private Map<String, GenuineVariableDescriptor<Solution_>> declaredGenuineVariableDescriptorMap;
private Map<String, ShadowVariableDescriptor<Solution_>> declaredShadowVariableDescriptorMap;
private Map<String, CascadingUpdateShadowVariableDescriptor<Solution_>> declaredCascadingUpdateShadowVariableDecriptorMap;

private List<SelectionFilter<Solution_, Object>> declaredPinEntityFilterList;
private List<MovableFilter<Solution_>> declaredPinEntityFilterList;
private List<EntityDescriptor<Solution_>> inheritedEntityDescriptorList;

// Caches the inherited, declared and descending movable filters (including @PlanningPin filters) as a composite filter
private SelectionFilter<Solution_, Object> effectiveMovableEntitySelectionFilter;
private MovableFilter<Solution_> effectiveMovableEntityFilter;
private PlanningPinToIndexReader<Solution_> effectivePlanningPinToIndexReader;

// Caches the inherited and declared variable descriptors
Expand Down Expand Up @@ -198,7 +196,7 @@ public <A> Predicate<A> getHasNoNullVariablesPredicateListVar() {
// ************************************************************************

public void processAnnotations(DescriptorPolicy descriptorPolicy) {
processEntityAnnotations(descriptorPolicy);
processEntityAnnotations();
declaredGenuineVariableDescriptorMap = new LinkedHashMap<>();
declaredShadowVariableDescriptorMap = new LinkedHashMap<>();
declaredCascadingUpdateShadowVariableDecriptorMap = new HashMap<>();
Expand All @@ -219,29 +217,28 @@ public void processAnnotations(DescriptorPolicy descriptorPolicy) {
processVariableAnnotations(descriptorPolicy);
}

private void processEntityAnnotations(DescriptorPolicy descriptorPolicy) {
private void processEntityAnnotations() {
PlanningEntity entityAnnotation = entityClass.getAnnotation(PlanningEntity.class);
if (entityAnnotation == null) {
throw new IllegalStateException("The entityClass (" + entityClass
+ ") has been specified as a planning entity in the configuration," +
" but does not have a @" + PlanningEntity.class.getSimpleName() + " annotation.");
}
processMovable(descriptorPolicy, entityAnnotation);
processDifficulty(descriptorPolicy, entityAnnotation);
processMovable(entityAnnotation);
processDifficulty(entityAnnotation);
}

private void processMovable(DescriptorPolicy descriptorPolicy, PlanningEntity entityAnnotation) {
Class<? extends PinningFilter> pinningFilterClass = entityAnnotation.pinningFilter();
boolean hasPinningFilter = pinningFilterClass != PlanningEntity.NullPinningFilter.class;
private void processMovable(PlanningEntity entityAnnotation) {
var pinningFilterClass = entityAnnotation.pinningFilter();
var hasPinningFilter = pinningFilterClass != PlanningEntity.NullPinningFilter.class;
if (hasPinningFilter) {
PinningFilter<Solution_, Object> pinningFilter = ConfigUtils.newInstance(this::toString, "pinningFilterClass",
var pinningFilter = ConfigUtils.newInstance(this::toString, "pinningFilterClass",
(Class<? extends PinningFilter<Solution_, Object>>) pinningFilterClass);
declaredMovableEntitySelectionFilter =
(scoreDirector, selection) -> !pinningFilter.accept(scoreDirector.getWorkingSolution(), selection);
declaredMovableEntityFilter = (solution, selection) -> !pinningFilter.accept(solution, selection);
}
}

private void processDifficulty(DescriptorPolicy descriptorPolicy, PlanningEntity entityAnnotation) {
private void processDifficulty(PlanningEntity entityAnnotation) {
Class<? extends Comparator> difficultyComparatorClass = entityAnnotation.difficultyComparatorClass();
if (difficultyComparatorClass == PlanningEntity.NullDifficultyComparator.class) {
difficultyComparatorClass = null;
Expand Down Expand Up @@ -479,27 +476,29 @@ private void createEffectiveVariableDescriptorMaps() {
}

private void createEffectiveMovableEntitySelectionFilter() {
if (declaredMovableEntitySelectionFilter != null && !hasAnyDeclaredGenuineVariableDescriptor()) {
if (declaredMovableEntityFilter != null && !hasAnyDeclaredGenuineVariableDescriptor()) {
throw new IllegalStateException("The entityClass (" + entityClass
+ ") has a movableEntitySelectionFilterClass (" + declaredMovableEntitySelectionFilter.getClass()
+ ") has a movableEntitySelectionFilterClass (" + declaredMovableEntityFilter.getClass()
+ "), but it has no declared genuine variables, only shadow variables.");
}
List<SelectionFilter<Solution_, Object>> selectionFilterList = new ArrayList<>();
var movableFilterList = new ArrayList<MovableFilter<Solution_>>();
// TODO Also add in child entity selectors
for (EntityDescriptor<Solution_> inheritedEntityDescriptor : inheritedEntityDescriptorList) {
if (inheritedEntityDescriptor.hasEffectiveMovableEntitySelectionFilter()) {
for (var inheritedEntityDescriptor : inheritedEntityDescriptorList) {
if (inheritedEntityDescriptor.hasEffectiveMovableEntityFilter()) {
// Includes movable and pinned
selectionFilterList.add(inheritedEntityDescriptor.getEffectiveMovableEntitySelectionFilter());
movableFilterList.add(inheritedEntityDescriptor.effectiveMovableEntityFilter);
}
}
if (declaredMovableEntitySelectionFilter != null) {
selectionFilterList.add(declaredMovableEntitySelectionFilter);
if (declaredMovableEntityFilter != null) {
movableFilterList.add(declaredMovableEntityFilter);
}
selectionFilterList.addAll(declaredPinEntityFilterList);
if (selectionFilterList.isEmpty()) {
effectiveMovableEntitySelectionFilter = null;
movableFilterList.addAll(declaredPinEntityFilterList);
if (movableFilterList.isEmpty()) {
effectiveMovableEntityFilter = null;
} else {
effectiveMovableEntitySelectionFilter = SelectionFilter.compose(selectionFilterList);
effectiveMovableEntityFilter = movableFilterList.stream()
.reduce(MovableFilter::and)
.orElseThrow(() -> new IllegalStateException("Impossible state: no movable filters."));
}
}

Expand Down Expand Up @@ -562,20 +561,20 @@ public boolean matchesEntity(Object entity) {
return entityClass.isAssignableFrom(entity.getClass());
}

public boolean hasEffectiveMovableEntitySelectionFilter() {
return effectiveMovableEntitySelectionFilter != null;
public boolean hasEffectiveMovableEntityFilter() {
return effectiveMovableEntityFilter != null;
}

public boolean hasCascadingShadowVariables() {
return !declaredShadowVariableDescriptorMap.isEmpty();
}

public boolean supportsPinning() {
return hasEffectiveMovableEntitySelectionFilter() || effectivePlanningPinToIndexReader != null;
return hasEffectiveMovableEntityFilter() || effectivePlanningPinToIndexReader != null;
}

public SelectionFilter<Solution_, Object> getEffectiveMovableEntitySelectionFilter() {
return effectiveMovableEntitySelectionFilter;
public BiPredicate<Solution_, Object> getEffectiveMovableEntityFilter() {
return effectiveMovableEntityFilter;
}

public SelectionSorter<Solution_, Object> getDecreasingDifficultySorter() {
Expand Down Expand Up @@ -724,16 +723,15 @@ public long getMaximumValueCount(Solution_ solution, Object entity) {

}

public void processProblemScale(ScoreDirector<Solution_> scoreDirector, Solution_ solution, Object entity,
ProblemScaleTracker tracker) {
public void processProblemScale(Solution_ solution, Object entity, ProblemScaleTracker tracker) {
for (GenuineVariableDescriptor<Solution_> variableDescriptor : effectiveGenuineVariableDescriptorList) {
long valueCount = variableDescriptor.getValueRangeSize(solution, entity);
// TODO: When minimum Java supported is 21, this can be replaced with a sealed interface switch
if (variableDescriptor instanceof BasicVariableDescriptor<Solution_> basicVariableDescriptor) {
if (basicVariableDescriptor.isChained()) {
// An entity is a value
tracker.addListValueCount(1);
if (!isMovable(scoreDirector, entity)) {
if (!isMovable(solution, entity)) {
tracker.addPinnedListValueCount(1);
}
// Anchors are entities
Expand All @@ -759,13 +757,13 @@ The value range (%s) for variable (%s) is not countable.
CountableValueRange.class.getSimpleName(), Collection.class.getSimpleName()));
}
} else {
if (isMovable(scoreDirector, entity)) {
if (isMovable(solution, entity)) {
tracker.addBasicProblemScale(valueCount);
}
}
} else if (variableDescriptor instanceof ListVariableDescriptor<Solution_> listVariableDescriptor) {
tracker.setListTotalValueCount((int) listVariableDescriptor.getValueRangeSize(solution, entity));
if (isMovable(scoreDirector, entity)) {
if (isMovable(solution, entity)) {
tracker.incrementListEntityCount(true);
tracker.addPinnedListValueCount(listVariableDescriptor.getFirstUnpinnedIndex(entity));
} else {
Expand Down Expand Up @@ -824,10 +822,10 @@ public int countReinitializableVariables(Object entity) {
return count;
}

public boolean isMovable(ScoreDirector<Solution_> scoreDirector, Object entity) {
public boolean isMovable(Solution_ workingSolution, Object entity) {
return isGenuine() &&
(effectiveMovableEntitySelectionFilter == null
|| effectiveMovableEntitySelectionFilter.accept(scoreDirector, entity));
(effectiveMovableEntityFilter == null
|| effectiveMovableEntityFilter.test(workingSolution, entity));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package ai.timefold.solver.core.impl.domain.entity.descriptor;

import java.util.function.BiPredicate;

@FunctionalInterface
interface MovableFilter<Solution_> extends BiPredicate<Solution_, Object> {

default MovableFilter<Solution_> and(MovableFilter<Solution_> other) {
return (solution, entity) -> test(solution, entity) && other.test(solution, entity);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package ai.timefold.solver.core.impl.domain.entity.descriptor;

import ai.timefold.solver.core.api.domain.entity.PlanningPin;
import ai.timefold.solver.core.api.domain.solution.PlanningSolution;
import ai.timefold.solver.core.impl.domain.common.accessor.MemberAccessor;

/**
* Filters out entities that return true for the {@link PlanningPin} annotated boolean member.
*
* @param <Solution_> the solution type, the class with the {@link PlanningSolution} annotation
*/
record PinEntityFilter<Solution_>(MemberAccessor memberAccessor) implements MovableFilter<Solution_> {

@Override
public boolean test(Solution_ solution, Object entity) {
var pinned = (Boolean) memberAccessor.executeGetter(entity);
if (pinned == null) {
throw new IllegalStateException("The entity (" + entity + ") has a @" + PlanningPin.class.getSimpleName()
+ " annotated property (" + memberAccessor.getName() + ") that returns null.");
}
return !pinned;
}

@Override
public String toString() {
return "Non-pinned entities only";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1003,8 +1003,9 @@ public void visitAll(Solution_ solution, Consumer<Object> visitor) {
* @return {@code >= 0}
*/
public boolean hasMovableEntities(ScoreDirector<Solution_> scoreDirector) {
return extractAllEntitiesStream(scoreDirector.getWorkingSolution())
.anyMatch(entity -> findEntityDescriptorOrFail(entity.getClass()).isMovable(scoreDirector, entity));
var workingSolution = scoreDirector.getWorkingSolution();
return extractAllEntitiesStream(workingSolution)
.anyMatch(entity -> findEntityDescriptorOrFail(entity.getClass()).isMovable(workingSolution, entity));
}

/**
Expand Down Expand Up @@ -1071,13 +1072,13 @@ public long getMaximumValueRangeSize(Solution_ solution) {
* @param solution never null
* @return {@code >= 0}
*/
public double getProblemScale(ScoreDirector<Solution_> scoreDirector, Solution_ solution) {
public double getProblemScale(Solution_ solution) {
long logBase = getMaximumValueRangeSize(solution);
ProblemScaleTracker problemScaleTracker = new ProblemScaleTracker(logBase);
visitAllEntities(solution, entity -> {
var entityDescriptor = findEntityDescriptorOrFail(entity.getClass());
if (entityDescriptor.isGenuine()) {
entityDescriptor.processProblemScale(scoreDirector, solution, entity, problemScaleTracker);
entityDescriptor.processProblemScale(solution, entity, problemScaleTracker);
}
});
long result = problemScaleTracker.getBasicProblemScaleLog();
Expand All @@ -1102,12 +1103,12 @@ public double getProblemScale(ScoreDirector<Solution_> scoreDirector, Solution_
return scale;
}

public ProblemSizeStatistics getProblemSizeStatistics(ScoreDirector<Solution_> scoreDirector, Solution_ solution) {
public ProblemSizeStatistics getProblemSizeStatistics(Solution_ solution) {
return new ProblemSizeStatistics(
getGenuineEntityCount(solution),
getGenuineVariableCount(solution),
getApproximateValueCount(solution),
getProblemScale(scoreDirector, solution));
getProblemScale(solution));
}

public SolutionInitializationStatistics computeInitializationStatistics(Solution_ solution) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ The entityClass (%s) has a @%s-annotated property (%s) with chained (%s) and pro
@Override
public void linkVariableDescriptors(DescriptorPolicy descriptorPolicy) {
super.linkVariableDescriptors(descriptorPolicy);
if (chained && entityDescriptor.hasEffectiveMovableEntitySelectionFilter()) {
if (chained && entityDescriptor.hasEffectiveMovableEntityFilter()) {
movableChainedTrailingValueFilter = new MovableChainedTrailingValueFilter<>(this);
} else {
movableChainedTrailingValueFilter = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import ai.timefold.solver.core.api.domain.valuerange.ValueRangeProvider;
import ai.timefold.solver.core.api.domain.variable.InverseRelationShadowVariable;
import ai.timefold.solver.core.api.domain.variable.PlanningListVariable;
import ai.timefold.solver.core.api.score.director.ScoreDirector;
import ai.timefold.solver.core.config.util.ConfigUtils;
import ai.timefold.solver.core.impl.domain.common.accessor.MemberAccessor;
import ai.timefold.solver.core.impl.domain.entity.descriptor.EntityDescriptor;
Expand Down Expand Up @@ -189,10 +188,10 @@ public boolean supportsPinning() {
return entityDescriptor.supportsPinning();
}

public boolean isElementPinned(ScoreDirector<Solution_> scoreDirector, Object entity, int index) {
public boolean isElementPinned(Solution_ workingSolution, Object entity, int index) {
if (!supportsPinning()) {
return false;
} else if (!entityDescriptor.isMovable(scoreDirector, entity)) { // Skipping due to @PlanningPin.
} else if (!entityDescriptor.isMovable(workingSolution, entity)) { // Skipping due to @PlanningPin.
return true;
} else {
return index < getFirstUnpinnedIndex(entity);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ private EntitySelector<Solution_> buildBaseEntitySelector(EntityDescriptor<Solut
}

private boolean hasFiltering(EntityDescriptor<Solution_> entityDescriptor) {
return config.getFilterClass() != null || entityDescriptor.hasEffectiveMovableEntitySelectionFilter();
return config.getFilterClass() != null || entityDescriptor.hasEffectiveMovableEntityFilter();
}

private EntitySelector<Solution_> applyNearbySelection(HeuristicConfigPolicy<Solution_> configPolicy,
Expand All @@ -191,8 +191,9 @@ private EntitySelector<Solution_> applyFiltering(EntitySelector<Solution_> entit
filterList.add(selectionFilter);
}
// Filter out pinned entities
if (entityDescriptor.hasEffectiveMovableEntitySelectionFilter()) {
filterList.add(entityDescriptor.getEffectiveMovableEntitySelectionFilter());
if (entityDescriptor.hasEffectiveMovableEntityFilter()) {
filterList.add((scoreDirector, selection) -> entityDescriptor.getEffectiveMovableEntityFilter()
.test(scoreDirector.getWorkingSolution(), selection));
}
// Do not filter out initialized entities here for CH and ES, because they can be partially initialized
// Instead, ValueSelectorConfig.applyReinitializeVariableFiltering() does that.
Expand Down
Loading

0 comments on commit 7149abe

Please sign in to comment.