Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add minimal sanity checks to custom/scripted similarities. (backport) #33893

Merged
merged 1 commit into from
Sep 25, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,35 @@

package org.elasticsearch.index.similarity;

import org.apache.logging.log4j.LogManager;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.index.Fields;
import org.apache.lucene.index.LeafMetaData;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.StoredFieldVisitor;
import org.apache.lucene.index.Terms;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.BooleanSimilarity;
import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.search.similarities.PerFieldSimilarityWrapper;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.search.similarities.Similarity.SimScorer;
import org.apache.lucene.search.similarities.Similarity.SimWeight;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.common.TriFunction;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.AbstractIndexComponent;
import org.elasticsearch.index.IndexModule;
Expand All @@ -36,6 +56,8 @@
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.script.ScriptService;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
Expand All @@ -44,7 +66,7 @@

public final class SimilarityService extends AbstractIndexComponent {

private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(Loggers.getLogger(SimilarityService.class));
private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(LogManager.getLogger(SimilarityService.class));
public static final String DEFAULT_SIMILARITY = "BM25";
private static final String CLASSIC_SIMILARITY = "classic";
private static final Map<String, Function<Version, Supplier<Similarity>>> DEFAULTS;
Expand Down Expand Up @@ -120,7 +142,8 @@ public SimilarityService(IndexSettings indexSettings, ScriptService scriptServic
}
TriFunction<Settings, Version, ScriptService, Similarity> defaultFactory = BUILT_IN.get(typeName);
TriFunction<Settings, Version, ScriptService, Similarity> factory = similarities.getOrDefault(typeName, defaultFactory);
final Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService);
Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService);
validateSimilarity(indexSettings.getIndexVersionCreated(), similarity);
providers.put(name, () -> similarity);
}
for (Map.Entry<String, Function<Version, Supplier<Similarity>>> entry : DEFAULTS.entrySet()) {
Expand All @@ -140,7 +163,7 @@ public Similarity similarity(MapperService mapperService) {
defaultSimilarity;
}


public SimilarityProvider getSimilarity(String name) {
Supplier<Similarity> sim = similarities.get(name);
if (sim == null) {
Expand Down Expand Up @@ -171,4 +194,231 @@ public Similarity get(String name) {
return (fieldType != null && fieldType.similarity() != null) ? fieldType.similarity().get() : defaultSimilarity;
}
}

static void validateSimilarity(Version indexCreatedVersion, Similarity similarity) {
try {
validateScoresArePositive(indexCreatedVersion, similarity);
validateScoresDoNotDecreaseWithFreq(indexCreatedVersion, similarity);
validateScoresDoNotIncreaseWithNorm(indexCreatedVersion, similarity);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}

private static class SingleNormLeafReader extends LeafReader {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do wonder if this class should be in a separate utility file somewhere? It's a third of the entire file length. On the other hand, this is all going away in master, so maybe it doesn't matter.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I'll leave things this way though as master doesn't have the issue so it's not a big deal like you said.

private final long norm;

SingleNormLeafReader(long norm) {
this.norm = norm;
}

@Override
public CacheHelper getCoreCacheHelper() {
return null;
}

@Override
public Terms terms(String field) throws IOException {
throw new UnsupportedOperationException();
}

@Override
public NumericDocValues getNumericDocValues(String field) throws IOException {
throw new UnsupportedOperationException();
}

@Override
public BinaryDocValues getBinaryDocValues(String field) throws IOException {
throw new UnsupportedOperationException();
}

@Override
public SortedDocValues getSortedDocValues(String field) throws IOException {
throw new UnsupportedOperationException();
}

@Override
public SortedNumericDocValues getSortedNumericDocValues(String field) throws IOException {
throw new UnsupportedOperationException();
}

@Override
public SortedSetDocValues getSortedSetDocValues(String field) throws IOException {
throw new UnsupportedOperationException();
}

@Override
public NumericDocValues getNormValues(String field) throws IOException {
return new NumericDocValues() {

int doc = -1;

@Override
public long longValue() throws IOException {
return norm;
}

@Override
public boolean advanceExact(int target) throws IOException {
doc = target;
return true;
}

@Override
public int docID() {
return doc;
}

@Override
public int nextDoc() throws IOException {
return advance(doc + 1);
}

@Override
public int advance(int target) throws IOException {
if (target == 0) {
return doc = 0;
} else {
return doc = NO_MORE_DOCS;
}
}

@Override
public long cost() {
return 1;
}

};
}

@Override
public FieldInfos getFieldInfos() {
throw new UnsupportedOperationException();
}

@Override
public Bits getLiveDocs() {
return null;
}

@Override
public PointValues getPointValues(String field) throws IOException {
throw new UnsupportedOperationException();
}

@Override
public void checkIntegrity() throws IOException {}

@Override
public LeafMetaData getMetaData() {
return new LeafMetaData(
org.apache.lucene.util.Version.LATEST.major,
org.apache.lucene.util.Version.LATEST,
null);
}

@Override
public Fields getTermVectors(int docID) throws IOException {
throw new UnsupportedOperationException();
}

@Override
public int numDocs() {
return 1;
}

@Override
public int maxDoc() {
return 1;
}

@Override
public void document(int docID, StoredFieldVisitor visitor) throws IOException {
throw new UnsupportedOperationException();
}

@Override
protected void doClose() throws IOException {
}

@Override
public CacheHelper getReaderCacheHelper() {
throw new UnsupportedOperationException();
}

}

private static void validateScoresArePositive(Version indexCreatedVersion, Similarity similarity) throws IOException {
CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
SimWeight simWeight = similarity.computeWeight(2f, collectionStats, termStats);
FieldInvertState state = new FieldInvertState(indexCreatedVersion.luceneVersion.major,
"some_field", 20, 20, 0, 50); // length = 20, no overlap
final long norm = similarity.computeNorm(state);
LeafReader reader = new SingleNormLeafReader(norm);
SimScorer scorer = similarity.simScorer(simWeight, reader.getContext());
for (int freq = 1; freq <= 10; ++freq) {
float score = scorer.score(0, freq);
if (score < 0) {
DEPRECATION_LOGGER.deprecated("Similarities should not return negative scores:\n" +
scorer.explain(0, Explanation.match(freq, "term freq")));
break;
}
}
}

private static void validateScoresDoNotDecreaseWithFreq(Version indexCreatedVersion, Similarity similarity) throws IOException {
CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
SimWeight simWeight = similarity.computeWeight(2f, collectionStats, termStats);
FieldInvertState state = new FieldInvertState(indexCreatedVersion.luceneVersion.major,
"some_field", 20, 20, 0, 50); // length = 20, no overlap
final long norm = similarity.computeNorm(state);
LeafReader reader = new SingleNormLeafReader(norm);
SimScorer scorer = similarity.simScorer(simWeight, reader.getContext());
float previousScore = Float.NEGATIVE_INFINITY;
for (int freq = 1; freq <= 10; ++freq) {
float score = scorer.score(0, freq);
if (score < previousScore) {
DEPRECATION_LOGGER.deprecated("Similarity scores should not decrease when term frequency increases:\n" +
scorer.explain(0, Explanation.match(freq - 1, "term freq")) + "\n" +
scorer.explain(0, Explanation.match(freq, "term freq")));
break;
}
previousScore = score;
}
}

private static void validateScoresDoNotIncreaseWithNorm(Version indexCreatedVersion, Similarity similarity) throws IOException {
CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
SimWeight simWeight = similarity.computeWeight(2f, collectionStats, termStats);

SimScorer previousScorer = null;
long previousNorm = 0;
float previousScore = Float.POSITIVE_INFINITY;
for (int length = 1; length <= 10; ++length) {
FieldInvertState state = new FieldInvertState(indexCreatedVersion.luceneVersion.major,
"some_field", length, length, 0, 50); // length = 20, no overlap
final long norm = similarity.computeNorm(state);
if (Long.compareUnsigned(previousNorm, norm) > 0) {
// esoteric similarity, skip this check
break;
}
LeafReader reader = new SingleNormLeafReader(norm);
SimScorer scorer = similarity.simScorer(simWeight, reader.getContext());
float score = scorer.score(0, 1);
if (score > previousScore) {
DEPRECATION_LOGGER.deprecated("Similarity scores should not increase when norm increases:\n" +
previousScorer.explain(0, Explanation.match(1, "term freq")) + "\n" +
scorer.explain(0, Explanation.match(1, "term freq")));
break;
}
previousScorer = scorer;
previousScore = score;
previousNorm = norm;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
import org.elasticsearch.test.engine.MockEngineFactory;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.hamcrest.Matchers;

import java.io.IOException;
import java.util.Collections;
Expand Down Expand Up @@ -297,10 +298,11 @@ public void testAddSimilarity() throws IOException {

IndexService indexService = newIndexService(module);
SimilarityService similarityService = indexService.similarityService();
assertNotNull(similarityService.getSimilarity("my_similarity"));
assertTrue(similarityService.getSimilarity("my_similarity").get() instanceof TestSimilarity);
Similarity similarity = similarityService.getSimilarity("my_similarity").get();
assertNotNull(similarity);
assertThat(similarity, Matchers.instanceOf(TestSimilarity.class));
assertEquals("my_similarity", similarityService.getSimilarity("my_similarity").name());
assertEquals("there is a key", ((TestSimilarity) similarityService.getSimilarity("my_similarity").get()).key);
assertEquals("there is a key", ((TestSimilarity) similarity).key);
indexService.close("simon says", false);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
package org.elasticsearch.index.similarity;

import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.BasicStats;
import org.apache.lucene.search.similarities.BooleanSimilarity;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.search.similarities.SimilarityBase;
import org.elasticsearch.Version;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.test.ESTestCase;
Expand Down Expand Up @@ -56,4 +60,48 @@ public void testOverrideDefaultSimilarity() {
SimilarityService service = new SimilarityService(indexSettings, null, Collections.emptyMap());
assertTrue(service.getDefaultSimilarity() instanceof BooleanSimilarity);
}

public void testSimilarityValidation() {
Similarity negativeScoresSim = new SimilarityBase() {
@Override
public String toString() {
return "negativeScoresSim";
}
@Override
protected float score(BasicStats stats, float freq, float docLen) {
return -1;
}
};
SimilarityService.validateSimilarity(Version.V_6_5_0, negativeScoresSim);
assertWarnings("Similarities should not return negative scores:\n-1.0 = score(, doc=0, freq=1.0), computed from:\n");

Similarity decreasingScoresWithFreqSim = new SimilarityBase() {
@Override
public String toString() {
return "decreasingScoresWithFreqSim";
}
@Override
protected float score(BasicStats stats, float freq, float docLen) {
return 1 / (freq + docLen);
}
};
SimilarityService.validateSimilarity(Version.V_6_5_0, decreasingScoresWithFreqSim);
assertWarnings("Similarity scores should not decrease when term frequency increases:\n0.04761905 = score(, doc=0, freq=1.0), " +
"computed from:\n\n0.045454547 = score(, doc=0, freq=2.0), computed from:\n");

Similarity increasingScoresWithNormSim = new SimilarityBase() {
@Override
public String toString() {
return "increasingScoresWithNormSim";
}
@Override
protected float score(BasicStats stats, float freq, float docLen) {
return freq + docLen;
}
};
SimilarityService.validateSimilarity(Version.V_6_5_0, increasingScoresWithNormSim);
assertWarnings("Similarity scores should not increase when norm increases:\n2.0 = score(, doc=0, freq=1.0), " +
"computed from:\n\n3.0 = score(, doc=0, freq=1.0), computed from:\n");
}

}
Loading