diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index c6e40324d60..3ebf74d4a54 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -19,7 +19,8 @@ Improvements Optimizations --------------------- -(No changes) + +* GITHUB#14052: Speed up DisjunctionDISIApproximation#advance. (Adrien Grand) Bug Fixes --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractMultiTermQueryConstantScoreWrapper.java b/lucene/core/src/java/org/apache/lucene/search/AbstractMultiTermQueryConstantScoreWrapper.java index fc18b0476a7..b6ee4db540a 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractMultiTermQueryConstantScoreWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractMultiTermQueryConstantScoreWrapper.java @@ -28,7 +28,6 @@ import org.apache.lucene.index.TermsEnum; import org.apache.lucene.util.Accountable; import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.IOSupplier; import org.apache.lucene.util.RamUsageEstimator; /** @@ -151,7 +150,8 @@ protected abstract WeightOrDocIdSetIterator rewriteInner( int fieldDocCount, Terms terms, TermsEnum termsEnum, - List collectedTerms) + List collectedTerms, + long leadCost) throws IOException; private WeightOrDocIdSetIterator rewriteAsBooleanQuery( @@ -247,21 +247,22 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti cost = estimateCost(terms, q.getTermsCount()); } - IOSupplier weightOrIteratorSupplier = - () -> { + IOLongFunction weightOrIteratorSupplier = + leadCost -> { if (collectResult) { return rewriteAsBooleanQuery(context, collectedTerms); } else { // Too many terms to rewrite as a simple bq. // Invoke rewriteInner logic to handle rewriting: - return rewriteInner(context, fieldDocCount, terms, termsEnum, collectedTerms); + return rewriteInner( + context, fieldDocCount, terms, termsEnum, collectedTerms, leadCost); } }; return new ScorerSupplier() { @Override public Scorer get(long leadCost) throws IOException { - WeightOrDocIdSetIterator weightOrIterator = weightOrIteratorSupplier.get(); + WeightOrDocIdSetIterator weightOrIterator = weightOrIteratorSupplier.apply(leadCost); final Scorer scorer; if (weightOrIterator == null) { scorer = null; @@ -281,7 +282,8 @@ public Scorer get(long leadCost) throws IOException { @Override public BulkScorer bulkScorer() throws IOException { - WeightOrDocIdSetIterator weightOrIterator = weightOrIteratorSupplier.get(); + WeightOrDocIdSetIterator weightOrIterator = + weightOrIteratorSupplier.apply(Long.MAX_VALUE); final BulkScorer bulkScorer; if (weightOrIterator == null) { bulkScorer = null; @@ -311,6 +313,10 @@ public long cost() { }; } + private static interface IOLongFunction { + T apply(long arg) throws IOException; + } + private static long estimateCost(Terms terms, long queryTermsCount) throws IOException { // Estimate the cost. If the MTQ can provide its term count, we can do a better job // estimating. diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/search/BooleanScorerSupplier.java index 7732445e8cd..b7c613d06a7 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BooleanScorerSupplier.java +++ b/lucene/core/src/java/org/apache/lucene/search/BooleanScorerSupplier.java @@ -237,7 +237,8 @@ BulkScorer booleanScorer() throws IOException { Scorer prohibitedScorer = prohibited.size() == 1 ? prohibited.get(0) - : new DisjunctionSumScorer(prohibited, ScoreMode.COMPLETE_NO_SCORES); + : new DisjunctionSumScorer( + prohibited, ScoreMode.COMPLETE_NO_SCORES, positiveScorerCost); return new ReqExclBulkScorer(positiveScorer, prohibitedScorer); } } @@ -509,7 +510,7 @@ private Scorer opt( if ((scoreMode == ScoreMode.TOP_SCORES && topLevelScoringClause) || minShouldMatch > 1) { return new WANDScorer(optionalScorers, minShouldMatch, scoreMode, leadCost); } else { - return new DisjunctionSumScorer(optionalScorers, scoreMode); + return new DisjunctionSumScorer(optionalScorers, scoreMode, leadCost); } } } diff --git a/lucene/core/src/java/org/apache/lucene/search/DisjunctionDISIApproximation.java b/lucene/core/src/java/org/apache/lucene/search/DisjunctionDISIApproximation.java index 805eadff25a..3b7e2b1014c 100644 --- a/lucene/core/src/java/org/apache/lucene/search/DisjunctionDISIApproximation.java +++ b/lucene/core/src/java/org/apache/lucene/search/DisjunctionDISIApproximation.java @@ -17,6 +17,10 @@ package org.apache.lucene.search; import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; +import java.util.List; /** * A {@link DocIdSetIterator} which is a disjunction of the approximations of the provided @@ -24,18 +28,75 @@ * * @lucene.internal */ -public class DisjunctionDISIApproximation extends DocIdSetIterator { +public final class DisjunctionDISIApproximation extends DocIdSetIterator { - final DisiPriorityQueue subIterators; - final long cost; + public static DisjunctionDISIApproximation of( + Collection subIterators, long leadCost) { + + return new DisjunctionDISIApproximation(subIterators, leadCost); + } + + // Heap of iterators that lead iteration. + private final DisiPriorityQueue leadIterators; + // List of iterators that will likely advance on every call to nextDoc() / advance() + private final DisiWrapper[] otherIterators; + private final long cost; + private DisiWrapper leadTop; + private int minOtherDoc; + + public DisjunctionDISIApproximation(Collection subIterators, long leadCost) { + // Using a heap to store disjunctive clauses is great for exhaustive evaluation, when a single + // clause needs to move through the heap on every iteration on average. However, when + // intersecting with a selective filter, it is possible that all clauses need advancing, which + // makes the reordering cost scale in O(N * log(N)) per advance() call when checking clauses + // linearly would scale in O(N). + // To protect against this reordering overhead, we try to have 1.5 clauses or less that advance + // on every advance() call by only putting clauses into the heap as long as Σ min(1, cost / + // leadCost) <= 1.5, or Σ min(leadCost, cost) <= 1.5 * leadCost. Other clauses are checked + // linearly. + + List wrappers = new ArrayList<>(subIterators); + // Sort by descending cost. + wrappers.sort(Comparator.comparingLong(w -> w.cost).reversed()); + + leadIterators = new DisiPriorityQueue(subIterators.size()); + + long reorderThreshold = leadCost + (leadCost >> 1); + if (reorderThreshold < 0) { // overflow + reorderThreshold = Long.MAX_VALUE; + } + long reorderCost = 0; + while (wrappers.isEmpty() == false) { + DisiWrapper last = wrappers.getLast(); + long inc = Math.min(last.cost, leadCost); + if (reorderCost + inc < 0 || reorderCost + inc > reorderThreshold) { + break; + } + leadIterators.add(wrappers.removeLast()); + reorderCost += inc; + } + + // Make leadIterators not empty. This helps save conditionals in the implementation which are + // rarely tested. + if (leadIterators.size() == 0) { + leadIterators.add(wrappers.removeLast()); + } + + otherIterators = wrappers.toArray(DisiWrapper[]::new); - public DisjunctionDISIApproximation(DisiPriorityQueue subIterators) { - this.subIterators = subIterators; long cost = 0; - for (DisiWrapper w : subIterators) { + for (DisiWrapper w : leadIterators) { + cost += w.cost; + } + for (DisiWrapper w : otherIterators) { cost += w.cost; } this.cost = cost; + minOtherDoc = Integer.MAX_VALUE; + for (DisiWrapper w : otherIterators) { + minOtherDoc = Math.min(minOtherDoc, w.doc); + } + leadTop = leadIterators.top(); } @Override @@ -45,29 +106,62 @@ public long cost() { @Override public int docID() { - return subIterators.top().doc; + return Math.min(minOtherDoc, leadTop.doc); } @Override public int nextDoc() throws IOException { - DisiWrapper top = subIterators.top(); - final int doc = top.doc; - do { - top.doc = top.approximation.nextDoc(); - top = subIterators.updateTop(); - } while (top.doc == doc); - - return top.doc; + if (leadTop.doc < minOtherDoc) { + int curDoc = leadTop.doc; + do { + leadTop.doc = leadTop.approximation.nextDoc(); + leadTop = leadIterators.updateTop(); + } while (leadTop.doc == curDoc); + return Math.min(leadTop.doc, minOtherDoc); + } else { + return advance(minOtherDoc + 1); + } } @Override public int advance(int target) throws IOException { - DisiWrapper top = subIterators.top(); - do { - top.doc = top.approximation.advance(target); - top = subIterators.updateTop(); - } while (top.doc < target); + while (leadTop.doc < target) { + leadTop.doc = leadTop.approximation.advance(target); + leadTop = leadIterators.updateTop(); + } - return top.doc; + minOtherDoc = Integer.MAX_VALUE; + for (DisiWrapper w : otherIterators) { + if (w.doc < target) { + w.doc = w.approximation.advance(target); + } + minOtherDoc = Math.min(minOtherDoc, w.doc); + } + + return Math.min(leadTop.doc, minOtherDoc); + } + + /** Return the linked list of iterators positioned on the current doc. */ + public DisiWrapper topList() { + if (leadTop.doc < minOtherDoc) { + return leadIterators.topList(); + } else { + return computeTopList(); + } + } + + private DisiWrapper computeTopList() { + assert leadTop.doc >= minOtherDoc; + DisiWrapper topList = null; + if (leadTop.doc == minOtherDoc) { + topList = leadIterators.topList(); + } + for (DisiWrapper w : otherIterators) { + if (w.doc == minOtherDoc) { + w.next = topList; + topList = w; + } + } + return topList; } } diff --git a/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java b/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java index a98dc91ca4e..7255a67419b 100644 --- a/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java @@ -155,7 +155,7 @@ public Scorer get(long leadCost) throws IOException { for (ScorerSupplier ss : scorerSuppliers) { scorers.add(ss.get(leadCost)); } - return new DisjunctionMaxScorer(tieBreakerMultiplier, scorers, scoreMode); + return new DisjunctionMaxScorer(tieBreakerMultiplier, scorers, scoreMode, leadCost); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxScorer.java b/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxScorer.java index 58965a5e58f..de8c88fd5af 100644 --- a/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxScorer.java @@ -40,9 +40,10 @@ final class DisjunctionMaxScorer extends DisjunctionScorer { * as they are summed into the result. * @param subScorers The sub scorers this Scorer should iterate on */ - DisjunctionMaxScorer(float tieBreakerMultiplier, List subScorers, ScoreMode scoreMode) + DisjunctionMaxScorer( + float tieBreakerMultiplier, List subScorers, ScoreMode scoreMode, long leadCost) throws IOException { - super(subScorers, scoreMode); + super(subScorers, scoreMode, leadCost); this.subScorers = subScorers; this.tieBreakerMultiplier = tieBreakerMultiplier; if (tieBreakerMultiplier < 0 || tieBreakerMultiplier > 1) { diff --git a/lucene/core/src/java/org/apache/lucene/search/DisjunctionScorer.java b/lucene/core/src/java/org/apache/lucene/search/DisjunctionScorer.java index 54681f4a8ac..69237ec4e97 100644 --- a/lucene/core/src/java/org/apache/lucene/search/DisjunctionScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/DisjunctionScorer.java @@ -25,37 +25,34 @@ /** Base class for Scorers that score disjunctions. */ abstract class DisjunctionScorer extends Scorer { + private final int numClauses; private final boolean needsScores; - private final DisiPriorityQueue subScorers; - private final DocIdSetIterator approximation; + private final DisjunctionDISIApproximation approximation; private final TwoPhase twoPhase; - protected DisjunctionScorer(List subScorers, ScoreMode scoreMode) throws IOException { + protected DisjunctionScorer(List subScorers, ScoreMode scoreMode, long leadCost) + throws IOException { if (subScorers.size() <= 1) { throw new IllegalArgumentException("There must be at least 2 subScorers"); } - this.subScorers = new DisiPriorityQueue(subScorers.size()); - for (Scorer scorer : subScorers) { - final DisiWrapper w = new DisiWrapper(scorer, false); - this.subScorers.add(w); - } + this.numClauses = subScorers.size(); this.needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES; - this.approximation = new DisjunctionDISIApproximation(this.subScorers); - boolean hasApproximation = false; float sumMatchCost = 0; long sumApproxCost = 0; - // Compute matchCost as the average over the matchCost of the subScorers. - // This is weighted by the cost, which is an expected number of matching documents. - for (DisiWrapper w : this.subScorers) { + List wrappers = new ArrayList<>(); + for (Scorer scorer : subScorers) { + DisiWrapper w = new DisiWrapper(scorer, false); long costWeight = (w.cost <= 1) ? 1 : w.cost; sumApproxCost += costWeight; if (w.twoPhaseView != null) { hasApproximation = true; sumMatchCost += w.matchCost * costWeight; } + wrappers.add(w); } + this.approximation = new DisjunctionDISIApproximation(wrappers, leadCost); if (hasApproximation == false) { // no sub scorer supports approximations twoPhase = null; @@ -91,7 +88,7 @@ private TwoPhase(DocIdSetIterator approximation, float matchCost) { super(approximation); this.matchCost = matchCost; unverifiedMatches = - new PriorityQueue(DisjunctionScorer.this.subScorers.size()) { + new PriorityQueue(numClauses) { @Override protected boolean lessThan(DisiWrapper a, DisiWrapper b) { return a.matchCost < b.matchCost; @@ -116,7 +113,7 @@ public boolean matches() throws IOException { verifiedMatches = null; unverifiedMatches.clear(); - for (DisiWrapper w = subScorers.topList(); w != null; ) { + for (DisiWrapper w = DisjunctionScorer.this.approximation.topList(); w != null; ) { DisiWrapper next = w.next; if (w.twoPhaseView == null) { @@ -160,12 +157,12 @@ public float matchCost() { @Override public final int docID() { - return subScorers.top().doc; + return approximation.docID(); } DisiWrapper getSubMatches() throws IOException { if (twoPhase == null) { - return subScorers.topList(); + return approximation.topList(); } else { return twoPhase.getSubMatches(); } diff --git a/lucene/core/src/java/org/apache/lucene/search/DisjunctionSumScorer.java b/lucene/core/src/java/org/apache/lucene/search/DisjunctionSumScorer.java index fb9648e725f..1a7b7f497e7 100644 --- a/lucene/core/src/java/org/apache/lucene/search/DisjunctionSumScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/DisjunctionSumScorer.java @@ -30,8 +30,9 @@ final class DisjunctionSumScorer extends DisjunctionScorer { * * @param subScorers Array of at least two subscorers. */ - DisjunctionSumScorer(List subScorers, ScoreMode scoreMode) throws IOException { - super(subScorers, scoreMode); + DisjunctionSumScorer(List subScorers, ScoreMode scoreMode, long leadCost) + throws IOException { + super(subScorers, scoreMode, leadCost); this.scorers = subScorers; } diff --git a/lucene/core/src/java/org/apache/lucene/search/IndriDisjunctionScorer.java b/lucene/core/src/java/org/apache/lucene/search/IndriDisjunctionScorer.java index 6836269189d..36e613e47ff 100644 --- a/lucene/core/src/java/org/apache/lucene/search/IndriDisjunctionScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/IndriDisjunctionScorer.java @@ -17,6 +17,7 @@ package org.apache.lucene.search; import java.io.IOException; +import java.util.ArrayList; import java.util.List; /** @@ -27,18 +28,17 @@ public abstract class IndriDisjunctionScorer extends IndriScorer { private final List subScorersList; - private final DisiPriorityQueue subScorers; private final DocIdSetIterator approximation; protected IndriDisjunctionScorer(List subScorersList, ScoreMode scoreMode, float boost) { super(boost); this.subScorersList = subScorersList; - this.subScorers = new DisiPriorityQueue(subScorersList.size()); + List wrappers = new ArrayList<>(); for (Scorer scorer : subScorersList) { final DisiWrapper w = new DisiWrapper(scorer, false); - this.subScorers.add(w); + wrappers.add(w); } - this.approximation = new DisjunctionDISIApproximation(this.subScorers); + this.approximation = new DisjunctionDISIApproximation(wrappers, Long.MAX_VALUE); } @Override @@ -71,6 +71,6 @@ public float smoothingScore(int docId) throws IOException { @Override public int docID() { - return subScorers.top().doc; + return approximation.docID(); } } diff --git a/lucene/core/src/java/org/apache/lucene/search/MultiTermQueryConstantScoreBlendedWrapper.java b/lucene/core/src/java/org/apache/lucene/search/MultiTermQueryConstantScoreBlendedWrapper.java index 78a75fd2b3f..b4b81f66e6c 100644 --- a/lucene/core/src/java/org/apache/lucene/search/MultiTermQueryConstantScoreBlendedWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/search/MultiTermQueryConstantScoreBlendedWrapper.java @@ -17,6 +17,7 @@ package org.apache.lucene.search; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PostingsEnum; @@ -52,7 +53,8 @@ protected WeightOrDocIdSetIterator rewriteInner( int fieldDocCount, Terms terms, TermsEnum termsEnum, - List collectedTerms) + List collectedTerms, + long leadCost) throws IOException { DocIdSetBuilder otherTerms = new DocIdSetBuilder(context.reader().maxDoc(), terms); PriorityQueue highFrequencyTerms = @@ -110,7 +112,7 @@ protected boolean lessThan(PostingsEnum a, PostingsEnum b) { } } while (termsEnum.next() != null); - DisiPriorityQueue subs = new DisiPriorityQueue(highFrequencyTerms.size() + 1); + List subs = new ArrayList<>(highFrequencyTerms.size() + 1); for (DocIdSetIterator disi : highFrequencyTerms) { Scorer s = wrapWithDummyScorer(this, disi); subs.add(new DisiWrapper(s, false)); @@ -118,7 +120,7 @@ protected boolean lessThan(PostingsEnum a, PostingsEnum b) { Scorer s = wrapWithDummyScorer(this, otherTerms.build().iterator()); subs.add(new DisiWrapper(s, false)); - return new WeightOrDocIdSetIterator(new DisjunctionDISIApproximation(subs)); + return new WeightOrDocIdSetIterator(new DisjunctionDISIApproximation(subs, leadCost)); } }; } diff --git a/lucene/core/src/java/org/apache/lucene/search/MultiTermQueryConstantScoreWrapper.java b/lucene/core/src/java/org/apache/lucene/search/MultiTermQueryConstantScoreWrapper.java index 4a23ddaa006..ff31cbab21c 100644 --- a/lucene/core/src/java/org/apache/lucene/search/MultiTermQueryConstantScoreWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/search/MultiTermQueryConstantScoreWrapper.java @@ -49,7 +49,8 @@ protected WeightOrDocIdSetIterator rewriteInner( int fieldDocCount, Terms terms, TermsEnum termsEnum, - List collectedTerms) + List collectedTerms, + long leadCost) throws IOException { DocIdSetBuilder builder = new DocIdSetBuilder(context.reader().maxDoc(), terms); PostingsEnum docs = null; diff --git a/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java b/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java index ef1073116d7..834deede909 100644 --- a/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java @@ -357,17 +357,19 @@ public Scorer get(long leadCost) throws IOException { } else { // we use termscorers + disjunction as an impl detail - DisiPriorityQueue queue = new DisiPriorityQueue(iterators.size()); + List wrappers = new ArrayList<>(); for (int i = 0; i < iterators.size(); i++) { PostingsEnum postings = iterators.get(i); final TermScorer termScorer = new TermScorer(postings, simWeight, norms); float boost = termBoosts.get(i); final DisiWrapperFreq wrapper = new DisiWrapperFreq(termScorer, boost); - queue.add(wrapper); + wrappers.add(wrapper); } // Even though it is called approximation, it is accurate since none of // the sub iterators are two-phase iterators. - DocIdSetIterator iterator = new DisjunctionDISIApproximation(queue); + DisjunctionDISIApproximation disjunctionIterator = + new DisjunctionDISIApproximation(wrappers, leadCost); + DocIdSetIterator iterator = disjunctionIterator; float[] boosts = new float[impacts.size()]; for (int i = 0; i < boosts.length; i++) { @@ -384,7 +386,7 @@ public Scorer get(long leadCost) throws IOException { iterator = impactsDisi; } - return new SynonymScorer(queue, iterator, impactsDisi, simWeight, norms); + return new SynonymScorer(iterator, disjunctionIterator, impactsDisi, simWeight, norms); } } @@ -576,21 +578,21 @@ public void advanceShallow(int target) throws IOException { private static class SynonymScorer extends Scorer { - private final DisiPriorityQueue queue; private final DocIdSetIterator iterator; + private final DisjunctionDISIApproximation disjunctionDisi; private final MaxScoreCache maxScoreCache; private final ImpactsDISI impactsDisi; private final SimScorer scorer; private final NumericDocValues norms; SynonymScorer( - DisiPriorityQueue queue, DocIdSetIterator iterator, + DisjunctionDISIApproximation disjunctionDisi, ImpactsDISI impactsDisi, SimScorer scorer, NumericDocValues norms) { - this.queue = queue; this.iterator = iterator; + this.disjunctionDisi = disjunctionDisi; this.maxScoreCache = impactsDisi.getMaxScoreCache(); this.impactsDisi = impactsDisi; this.scorer = scorer; @@ -603,7 +605,7 @@ public int docID() { } float freq() throws IOException { - DisiWrapperFreq w = (DisiWrapperFreq) queue.topList(); + DisiWrapperFreq w = (DisiWrapperFreq) disjunctionDisi.topList(); float freq = w.freq(); for (w = (DisiWrapperFreq) w.next; w != null; w = (DisiWrapperFreq) w.next) { freq += w.freq(); diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/CombinedFieldQuery.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/CombinedFieldQuery.java index 0fbcd9e2a9a..1ca3f790f43 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/CombinedFieldQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/CombinedFieldQuery.java @@ -39,7 +39,6 @@ import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.CollectionStatistics; -import org.apache.lucene.search.DisiPriorityQueue; import org.apache.lucene.search.DisiWrapper; import org.apache.lucene.search.DisjunctionDISIApproximation; import org.apache.lucene.search.DocIdSetIterator; @@ -383,6 +382,7 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { List iterators = new ArrayList<>(); List fields = new ArrayList<>(); + long cost = 0; for (int i = 0; i < fieldTerms.length; i++) { IOSupplier supplier = termStates[i].get(context); TermState state = supplier == null ? null : supplier.get(); @@ -392,6 +392,7 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti PostingsEnum postingsEnum = termsEnum.postings(null, PostingsEnum.FREQS); iterators.add(postingsEnum); fields.add(fieldAndWeights.get(fieldTerms[i].field())); + cost += postingsEnum.cost(); } } @@ -401,18 +402,31 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti MultiNormsLeafSimScorer scoringSimScorer = new MultiNormsLeafSimScorer(simWeight, context.reader(), fieldAndWeights.values(), true); - // we use termscorers + disjunction as an impl detail - DisiPriorityQueue queue = new DisiPriorityQueue(iterators.size()); - for (int i = 0; i < iterators.size(); i++) { - float weight = fields.get(i).weight; - queue.add( - new WeightedDisiWrapper(new TermScorer(iterators.get(i), simWeight, null), weight)); - } - // Even though it is called approximation, it is accurate since none of - // the sub iterators are two-phase iterators. - DocIdSetIterator iterator = new DisjunctionDISIApproximation(queue); - final var scorer = new CombinedFieldScorer(queue, iterator, scoringSimScorer); - return new DefaultScorerSupplier(scorer); + + final long finalCost = cost; + return new ScorerSupplier() { + + @Override + public Scorer get(long leadCost) throws IOException { + // we use termscorers + disjunction as an impl detail + List wrappers = new ArrayList<>(iterators.size()); + for (int i = 0; i < iterators.size(); i++) { + float weight = fields.get(i).weight; + wrappers.add( + new WeightedDisiWrapper(new TermScorer(iterators.get(i), simWeight, null), weight)); + } + // Even though it is called approximation, it is accurate since none of + // the sub iterators are two-phase iterators. + DisjunctionDISIApproximation iterator = + new DisjunctionDISIApproximation(wrappers, leadCost); + return new CombinedFieldScorer(iterator, scoringSimScorer); + } + + @Override + public long cost() { + return finalCost; + } + }; } @Override @@ -437,14 +451,11 @@ float freq() throws IOException { } private static class CombinedFieldScorer extends Scorer { - private final DisiPriorityQueue queue; - private final DocIdSetIterator iterator; + private final DisjunctionDISIApproximation iterator; private final MultiNormsLeafSimScorer simScorer; private final float maxScore; - CombinedFieldScorer( - DisiPriorityQueue queue, DocIdSetIterator iterator, MultiNormsLeafSimScorer simScorer) { - this.queue = queue; + CombinedFieldScorer(DisjunctionDISIApproximation iterator, MultiNormsLeafSimScorer simScorer) { this.iterator = iterator; this.simScorer = simScorer; this.maxScore = simScorer.getSimScorer().score(Float.POSITIVE_INFINITY, 1L); @@ -456,7 +467,7 @@ public int docID() { } float freq() throws IOException { - DisiWrapper w = queue.topList(); + DisiWrapper w = iterator.topList(); float freq = ((WeightedDisiWrapper) w).freq(); for (w = w.next; w != null; w = w.next) { freq += ((WeightedDisiWrapper) w).freq();