diff --git a/server/src/main/java/org/elasticsearch/index/similarity/SimilarityService.java b/server/src/main/java/org/elasticsearch/index/similarity/SimilarityService.java index e5ca87cecdf84..aeaae386fc17b 100644 --- a/server/src/main/java/org/elasticsearch/index/similarity/SimilarityService.java +++ b/server/src/main/java/org/elasticsearch/index/similarity/SimilarityService.java @@ -19,15 +19,35 @@ package org.elasticsearch.index.similarity; +import org.apache.logging.log4j.LogManager; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FieldInvertState; +import org.apache.lucene.index.Fields; +import org.apache.lucene.index.LeafMetaData; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.NumericDocValues; +import org.apache.lucene.index.PointValues; +import org.apache.lucene.index.SortedDocValues; +import org.apache.lucene.index.SortedNumericDocValues; +import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.index.StoredFieldVisitor; +import org.apache.lucene.index.Terms; +import org.apache.lucene.search.CollectionStatistics; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.TermStatistics; import org.apache.lucene.search.similarities.BM25Similarity; import org.apache.lucene.search.similarities.BooleanSimilarity; import org.apache.lucene.search.similarities.ClassicSimilarity; import org.apache.lucene.search.similarities.PerFieldSimilarityWrapper; import org.apache.lucene.search.similarities.Similarity; +import org.apache.lucene.search.similarities.Similarity.SimScorer; +import org.apache.lucene.search.similarities.Similarity.SimWeight; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.BytesRef; import org.elasticsearch.Version; import org.elasticsearch.common.TriFunction; import org.elasticsearch.common.logging.DeprecationLogger; -import org.elasticsearch.common.logging.Loggers; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.AbstractIndexComponent; import org.elasticsearch.index.IndexModule; @@ -36,6 +56,8 @@ import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.script.ScriptService; +import java.io.IOException; +import java.io.UncheckedIOException; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -44,7 +66,7 @@ public final class SimilarityService extends AbstractIndexComponent { - private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(Loggers.getLogger(SimilarityService.class)); + private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(LogManager.getLogger(SimilarityService.class)); public static final String DEFAULT_SIMILARITY = "BM25"; private static final String CLASSIC_SIMILARITY = "classic"; private static final Map>> DEFAULTS; @@ -120,7 +142,8 @@ public SimilarityService(IndexSettings indexSettings, ScriptService scriptServic } TriFunction defaultFactory = BUILT_IN.get(typeName); TriFunction factory = similarities.getOrDefault(typeName, defaultFactory); - final Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService); + Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService); + validateSimilarity(indexSettings.getIndexVersionCreated(), similarity); providers.put(name, () -> similarity); } for (Map.Entry>> entry : DEFAULTS.entrySet()) { @@ -140,7 +163,7 @@ public Similarity similarity(MapperService mapperService) { defaultSimilarity; } - + public SimilarityProvider getSimilarity(String name) { Supplier sim = similarities.get(name); if (sim == null) { @@ -171,4 +194,231 @@ public Similarity get(String name) { return (fieldType != null && fieldType.similarity() != null) ? fieldType.similarity().get() : defaultSimilarity; } } + + static void validateSimilarity(Version indexCreatedVersion, Similarity similarity) { + try { + validateScoresArePositive(indexCreatedVersion, similarity); + validateScoresDoNotDecreaseWithFreq(indexCreatedVersion, similarity); + validateScoresDoNotIncreaseWithNorm(indexCreatedVersion, similarity); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private static class SingleNormLeafReader extends LeafReader { + + private final long norm; + + SingleNormLeafReader(long norm) { + this.norm = norm; + } + + @Override + public CacheHelper getCoreCacheHelper() { + return null; + } + + @Override + public Terms terms(String field) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public NumericDocValues getNumericDocValues(String field) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public BinaryDocValues getBinaryDocValues(String field) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public SortedDocValues getSortedDocValues(String field) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public SortedNumericDocValues getSortedNumericDocValues(String field) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public SortedSetDocValues getSortedSetDocValues(String field) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public NumericDocValues getNormValues(String field) throws IOException { + return new NumericDocValues() { + + int doc = -1; + + @Override + public long longValue() throws IOException { + return norm; + } + + @Override + public boolean advanceExact(int target) throws IOException { + doc = target; + return true; + } + + @Override + public int docID() { + return doc; + } + + @Override + public int nextDoc() throws IOException { + return advance(doc + 1); + } + + @Override + public int advance(int target) throws IOException { + if (target == 0) { + return doc = 0; + } else { + return doc = NO_MORE_DOCS; + } + } + + @Override + public long cost() { + return 1; + } + + }; + } + + @Override + public FieldInfos getFieldInfos() { + throw new UnsupportedOperationException(); + } + + @Override + public Bits getLiveDocs() { + return null; + } + + @Override + public PointValues getPointValues(String field) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public void checkIntegrity() throws IOException {} + + @Override + public LeafMetaData getMetaData() { + return new LeafMetaData( + org.apache.lucene.util.Version.LATEST.major, + org.apache.lucene.util.Version.LATEST, + null); + } + + @Override + public Fields getTermVectors(int docID) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public int numDocs() { + return 1; + } + + @Override + public int maxDoc() { + return 1; + } + + @Override + public void document(int docID, StoredFieldVisitor visitor) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + protected void doClose() throws IOException { + } + + @Override + public CacheHelper getReaderCacheHelper() { + throw new UnsupportedOperationException(); + } + + } + + private static void validateScoresArePositive(Version indexCreatedVersion, Similarity similarity) throws IOException { + CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000); + TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130); + SimWeight simWeight = similarity.computeWeight(2f, collectionStats, termStats); + FieldInvertState state = new FieldInvertState(indexCreatedVersion.luceneVersion.major, + "some_field", 20, 20, 0, 50); // length = 20, no overlap + final long norm = similarity.computeNorm(state); + LeafReader reader = new SingleNormLeafReader(norm); + SimScorer scorer = similarity.simScorer(simWeight, reader.getContext()); + for (int freq = 1; freq <= 10; ++freq) { + float score = scorer.score(0, freq); + if (score < 0) { + DEPRECATION_LOGGER.deprecated("Similarities should not return negative scores:\n" + + scorer.explain(0, Explanation.match(freq, "term freq"))); + break; + } + } + } + + private static void validateScoresDoNotDecreaseWithFreq(Version indexCreatedVersion, Similarity similarity) throws IOException { + CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000); + TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130); + SimWeight simWeight = similarity.computeWeight(2f, collectionStats, termStats); + FieldInvertState state = new FieldInvertState(indexCreatedVersion.luceneVersion.major, + "some_field", 20, 20, 0, 50); // length = 20, no overlap + final long norm = similarity.computeNorm(state); + LeafReader reader = new SingleNormLeafReader(norm); + SimScorer scorer = similarity.simScorer(simWeight, reader.getContext()); + float previousScore = Float.NEGATIVE_INFINITY; + for (int freq = 1; freq <= 10; ++freq) { + float score = scorer.score(0, freq); + if (score < previousScore) { + DEPRECATION_LOGGER.deprecated("Similarity scores should not decrease when term frequency increases:\n" + + scorer.explain(0, Explanation.match(freq - 1, "term freq")) + "\n" + + scorer.explain(0, Explanation.match(freq, "term freq"))); + break; + } + previousScore = score; + } + } + + private static void validateScoresDoNotIncreaseWithNorm(Version indexCreatedVersion, Similarity similarity) throws IOException { + CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000); + TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130); + SimWeight simWeight = similarity.computeWeight(2f, collectionStats, termStats); + + SimScorer previousScorer = null; + long previousNorm = 0; + float previousScore = Float.POSITIVE_INFINITY; + for (int length = 1; length <= 10; ++length) { + FieldInvertState state = new FieldInvertState(indexCreatedVersion.luceneVersion.major, + "some_field", length, length, 0, 50); // length = 20, no overlap + final long norm = similarity.computeNorm(state); + if (Long.compareUnsigned(previousNorm, norm) > 0) { + // esoteric similarity, skip this check + break; + } + LeafReader reader = new SingleNormLeafReader(norm); + SimScorer scorer = similarity.simScorer(simWeight, reader.getContext()); + float score = scorer.score(0, 1); + if (score > previousScore) { + DEPRECATION_LOGGER.deprecated("Similarity scores should not increase when norm increases:\n" + + previousScorer.explain(0, Explanation.match(1, "term freq")) + "\n" + + scorer.explain(0, Explanation.match(1, "term freq"))); + break; + } + previousScorer = scorer; + previousScore = score; + previousNorm = norm; + } + } + } diff --git a/server/src/test/java/org/elasticsearch/index/IndexModuleTests.java b/server/src/test/java/org/elasticsearch/index/IndexModuleTests.java index 6d4e18cd3f9c1..edb3637a79c80 100644 --- a/server/src/test/java/org/elasticsearch/index/IndexModuleTests.java +++ b/server/src/test/java/org/elasticsearch/index/IndexModuleTests.java @@ -78,6 +78,7 @@ import org.elasticsearch.test.engine.MockEngineFactory; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; +import org.hamcrest.Matchers; import java.io.IOException; import java.util.Collections; @@ -297,10 +298,11 @@ public void testAddSimilarity() throws IOException { IndexService indexService = newIndexService(module); SimilarityService similarityService = indexService.similarityService(); - assertNotNull(similarityService.getSimilarity("my_similarity")); - assertTrue(similarityService.getSimilarity("my_similarity").get() instanceof TestSimilarity); + Similarity similarity = similarityService.getSimilarity("my_similarity").get(); + assertNotNull(similarity); + assertThat(similarity, Matchers.instanceOf(TestSimilarity.class)); assertEquals("my_similarity", similarityService.getSimilarity("my_similarity").name()); - assertEquals("there is a key", ((TestSimilarity) similarityService.getSimilarity("my_similarity").get()).key); + assertEquals("there is a key", ((TestSimilarity) similarity).key); indexService.close("simon says", false); } diff --git a/server/src/test/java/org/elasticsearch/index/similarity/SimilarityServiceTests.java b/server/src/test/java/org/elasticsearch/index/similarity/SimilarityServiceTests.java index 5d18a595e9687..ae0a78e8f320e 100644 --- a/server/src/test/java/org/elasticsearch/index/similarity/SimilarityServiceTests.java +++ b/server/src/test/java/org/elasticsearch/index/similarity/SimilarityServiceTests.java @@ -19,7 +19,11 @@ package org.elasticsearch.index.similarity; import org.apache.lucene.search.similarities.BM25Similarity; +import org.apache.lucene.search.similarities.BasicStats; import org.apache.lucene.search.similarities.BooleanSimilarity; +import org.apache.lucene.search.similarities.Similarity; +import org.apache.lucene.search.similarities.SimilarityBase; +import org.elasticsearch.Version; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.test.ESTestCase; @@ -56,4 +60,48 @@ public void testOverrideDefaultSimilarity() { SimilarityService service = new SimilarityService(indexSettings, null, Collections.emptyMap()); assertTrue(service.getDefaultSimilarity() instanceof BooleanSimilarity); } + + public void testSimilarityValidation() { + Similarity negativeScoresSim = new SimilarityBase() { + @Override + public String toString() { + return "negativeScoresSim"; + } + @Override + protected float score(BasicStats stats, float freq, float docLen) { + return -1; + } + }; + SimilarityService.validateSimilarity(Version.V_6_5_0, negativeScoresSim); + assertWarnings("Similarities should not return negative scores:\n-1.0 = score(, doc=0, freq=1.0), computed from:\n"); + + Similarity decreasingScoresWithFreqSim = new SimilarityBase() { + @Override + public String toString() { + return "decreasingScoresWithFreqSim"; + } + @Override + protected float score(BasicStats stats, float freq, float docLen) { + return 1 / (freq + docLen); + } + }; + SimilarityService.validateSimilarity(Version.V_6_5_0, decreasingScoresWithFreqSim); + assertWarnings("Similarity scores should not decrease when term frequency increases:\n0.04761905 = score(, doc=0, freq=1.0), " + + "computed from:\n\n0.045454547 = score(, doc=0, freq=2.0), computed from:\n"); + + Similarity increasingScoresWithNormSim = new SimilarityBase() { + @Override + public String toString() { + return "increasingScoresWithNormSim"; + } + @Override + protected float score(BasicStats stats, float freq, float docLen) { + return freq + docLen; + } + }; + SimilarityService.validateSimilarity(Version.V_6_5_0, increasingScoresWithNormSim); + assertWarnings("Similarity scores should not increase when norm increases:\n2.0 = score(, doc=0, freq=1.0), " + + "computed from:\n\n3.0 = score(, doc=0, freq=1.0), computed from:\n"); + } + } diff --git a/server/src/test/java/org/elasticsearch/indices/IndicesServiceTests.java b/server/src/test/java/org/elasticsearch/indices/IndicesServiceTests.java index 35416c617fdd0..808cea4d0ee6e 100644 --- a/server/src/test/java/org/elasticsearch/indices/IndicesServiceTests.java +++ b/server/src/test/java/org/elasticsearch/indices/IndicesServiceTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.indices; import org.apache.lucene.search.similarities.BM25Similarity; +import org.apache.lucene.search.similarities.Similarity; import org.apache.lucene.store.AlreadyClosedException; import org.elasticsearch.Version; import org.elasticsearch.action.admin.indices.stats.CommonStatsFlags; @@ -448,8 +449,8 @@ public void testStandAloneMapperServiceWithPlugins() throws IOException { .build(); MapperService mapperService = indicesService.createIndexMapperService(indexMetaData); assertNotNull(mapperService.documentMapperParser().parserContext("type").typeParser("fake-mapper")); - assertThat(mapperService.documentMapperParser().parserContext("type").getSimilarity("test").get(), - instanceOf(BM25Similarity.class)); + Similarity sim = mapperService.documentMapperParser().parserContext("type").getSimilarity("test").get(); + assertThat(sim, instanceOf(BM25Similarity.class)); } public void testStatsByShardDoesNotDieFromExpectedExceptions() {