Skip to content

Commit

Permalink
Add minimal sanity checks to custom/scripted similarities. (#33564)
Browse files Browse the repository at this point in the history
Add minimal sanity checks to custom/scripted similarities.

Lucene 8 introduced more constraints on similarities, in particular:
 - scores must not be negative,
 - scores must not decrease when term freq increases,
 - scores must not increase when norm (interpreted as an unsigned long)
   increases.

We can't check every single case, but could at least run some sanity checks.

Relates #33309
  • Loading branch information
jpountz authored Sep 19, 2018
1 parent 7f473b6 commit c4261ba
Show file tree
Hide file tree
Showing 6 changed files with 339 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.index.similarity;

import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.Similarity;

/**
 * A {@link Similarity} wrapper that refuses negative scores, so that users get an error
 * rather than silently corrupt top hits. It is intended to wrap any custom or scripted
 * similarity.
 */
// public for testing
public final class NonNegativeScoresSimilarity extends Similarity {

    // Escape hatch: name of the system property that can switch the hard failure off
    private static final String ES_ENFORCE_POSITIVE_SCORES = "es.enforce.positive.scores";
    private static final boolean ENFORCE_POSITIVE_SCORES = readEnforcePositiveScores();

    /** The wrapped similarity that computes norms and raw scores. */
    private final Similarity delegate;

    /**
     * Wraps the given similarity.
     *
     * @param in the similarity whose scores are checked
     */
    public NonNegativeScoresSimilarity(Similarity in) {
        this.delegate = in;
    }

    /**
     * Resolves the enforcement flag from the system property: an unset property means
     * enforcement is on, the literal {@code false} turns it off, anything else is rejected.
     */
    private static boolean readEnforcePositiveScores() {
        final String configured = System.getProperty(ES_ENFORCE_POSITIVE_SCORES);
        if (configured == null) {
            return true;
        }
        if ("false".equals(configured)) {
            return false;
        }
        throw new IllegalArgumentException(ES_ENFORCE_POSITIVE_SCORES + " may only be unset or set to [false], but got [" +
            configured + "]");
    }

    /** Returns the wrapped similarity. */
    public Similarity getDelegate() {
        return delegate;
    }

    /** Norms are computed by the wrapped similarity, unchanged. */
    @Override
    public long computeNorm(FieldInvertState state) {
        return delegate.computeNorm(state);
    }

    /**
     * Builds the wrapped similarity's scorer and returns a scorer around it that checks
     * every score it produces.
     */
    @Override
    public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
        final SimScorer wrapped = delegate.scorer(boost, collectionStats, termStats);
        return new SimScorer() {

            /**
             * Returns the wrapped score when it is not negative. A negative score either
             * raises an {@link IllegalArgumentException} (enforcement on) or is replaced
             * by zero (enforcement off).
             */
            @Override
            public float score(float freq, long norm) {
                final float rawScore = wrapped.score(freq, norm);
                if (rawScore < 0f) {
                    if (ENFORCE_POSITIVE_SCORES == false) {
                        return 0f;
                    }
                    throw new IllegalArgumentException("Similarities must not produce negative scores, but got:\n" +
                        wrapped.explain(Explanation.match(freq, "term frequency"), norm));
                }
                return rawScore;
            }

            /**
             * Returns the wrapped explanation, except that a matching explanation with a
             * negative value is nested under a zero-valued "max of:" explanation.
             */
            @Override
            public Explanation explain(Explanation freq, long norm) {
                final Explanation inner = wrapped.explain(freq, norm);
                final boolean negativeMatch = inner.isMatch() && inner.getValue().floatValue() < 0;
                if (negativeMatch == false) {
                    return inner;
                }
                final Explanation floor = Explanation.match(0f, "Minimum allowed score");
                return Explanation.match(0f, "max of:", inner, floor);
            }
        };
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,22 @@

package org.elasticsearch.index.similarity;

import org.apache.logging.log4j.LogManager;
import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.BooleanSimilarity;
import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.search.similarities.PerFieldSimilarityWrapper;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.search.similarities.Similarity.SimScorer;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.common.TriFunction;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.AbstractIndexComponent;
import org.elasticsearch.index.IndexModule;
Expand All @@ -44,7 +51,7 @@

public final class SimilarityService extends AbstractIndexComponent {

private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(Loggers.getLogger(SimilarityService.class));
private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(LogManager.getLogger(SimilarityService.class));
public static final String DEFAULT_SIMILARITY = "BM25";
private static final String CLASSIC_SIMILARITY = "classic";
private static final Map<String, Function<Version, Supplier<Similarity>>> DEFAULTS;
Expand Down Expand Up @@ -131,8 +138,14 @@ public SimilarityService(IndexSettings indexSettings, ScriptService scriptServic
}
TriFunction<Settings, Version, ScriptService, Similarity> defaultFactory = BUILT_IN.get(typeName);
TriFunction<Settings, Version, ScriptService, Similarity> factory = similarities.getOrDefault(typeName, defaultFactory);
final Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService);
providers.put(name, () -> similarity);
Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService);
validateSimilarity(indexSettings.getIndexVersionCreated(), similarity);
if (BUILT_IN.containsKey(typeName) == false || "scripted".equals(typeName)) {
// We don't trust custom similarities
similarity = new NonNegativeScoresSimilarity(similarity);
}
final Similarity similarityF = similarity; // like similarity but final
providers.put(name, () -> similarityF);
}
for (Map.Entry<String, Function<Version, Supplier<Similarity>>> entry : DEFAULTS.entrySet()) {
providers.put(entry.getKey(), entry.getValue().apply(indexSettings.getIndexVersionCreated()));
Expand All @@ -151,7 +164,7 @@ public Similarity similarity(MapperService mapperService) {
defaultSimilarity;
}


public SimilarityProvider getSimilarity(String name) {
Supplier<Similarity> sim = similarities.get(name);
if (sim == null) {
Expand Down Expand Up @@ -182,4 +195,80 @@ public Similarity get(String name) {
return (fieldType != null && fieldType.similarity() != null) ? fieldType.similarity().get() : defaultSimilarity;
}
}

/**
 * Runs the minimal sanity checks on a similarity, in this order: scores are not negative,
 * scores do not decrease when term frequency grows, and scores do not increase when the
 * norm grows.
 *
 * @param indexCreatedVersion version the index was created with, passed on to each check
 * @param similarity          the similarity to check
 */
static void validateSimilarity(Version indexCreatedVersion, Similarity similarity) {
    // 1. non-negative scores
    validateScoresArePositive(indexCreatedVersion, similarity);
    // 2. monotonic in term frequency
    validateScoresDoNotDecreaseWithFreq(indexCreatedVersion, similarity);
    // 3. monotonic in norm
    validateScoresDoNotIncreaseWithNorm(indexCreatedVersion, similarity);
}

/**
 * Scores a made-up term against made-up collection statistics for term frequencies 1 to 10
 * and reports every negative score through {@code fail}.
 *
 * @param indexCreatedVersion version the index was created with
 * @param similarity          the similarity to check
 */
private static void validateScoresArePositive(Version indexCreatedVersion, Similarity similarity) {
    // Fixed, synthetic statistics: the check only needs some scorer to probe.
    final CollectionStatistics fieldStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
    final TermStatistics valueStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
    final SimScorer simScorer = similarity.scorer(2f, fieldStats, valueStats);

    // length = 20, no overlap
    final FieldInvertState invertState = new FieldInvertState(indexCreatedVersion.major, "some_field",
        IndexOptions.DOCS_AND_FREQS, 20, 20, 0, 50, 10, 3);
    final long norm = similarity.computeNorm(invertState);

    for (int freq = 1; freq <= 10; freq++) {
        final float currentScore = simScorer.score(freq, norm);
        if (currentScore < 0) {
            final Explanation details = simScorer.explain(Explanation.match(freq, "term freq"), norm);
            fail(indexCreatedVersion, "Similarities should not return negative scores:\n" + details);
        }
    }
}

/**
 * Scores a made-up term for term frequencies 1 to 10 with a fixed norm and reports,
 * through {@code fail}, every score that is lower than the one before it. The first
 * score is compared against zero.
 *
 * @param indexCreatedVersion version the index was created with
 * @param similarity          the similarity to check
 */
private static void validateScoresDoNotDecreaseWithFreq(Version indexCreatedVersion, Similarity similarity) {
    // Fixed, synthetic statistics: the check only needs some scorer to probe.
    final CollectionStatistics fieldStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
    final TermStatistics valueStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
    final SimScorer simScorer = similarity.scorer(2f, fieldStats, valueStats);

    // length = 20, no overlap
    final FieldInvertState invertState = new FieldInvertState(indexCreatedVersion.major, "some_field",
        IndexOptions.DOCS_AND_FREQS, 20, 20, 0, 50, 10, 3);
    final long norm = similarity.computeNorm(invertState);

    float lastScore = 0;
    for (int freq = 1; freq <= 10; freq++) {
        final float currentScore = simScorer.score(freq, norm);
        if (currentScore < lastScore) {
            // Show the explanation for the previous frequency next to the offending one.
            final Explanation before = simScorer.explain(Explanation.match(freq - 1, "term freq"), norm);
            final Explanation after = simScorer.explain(Explanation.match(freq, "term freq"), norm);
            fail(indexCreatedVersion, "Similarity scores should not decrease when term frequency increases:\n" +
                before + "\n" + after);
        }
        lastScore = currentScore;
    }
}

/**
 * Scores a made-up term with term frequency 1 for field lengths 1 to 10 and reports,
 * through {@code fail}, every score that is higher than the one obtained for the
 * previous length. Norms are compared as unsigned longs; if the similarity produces a
 * norm that is lower than the previous one, the check is abandoned.
 *
 * @param indexCreatedVersion version the index was created with
 * @param similarity          the similarity to check
 */
private static void validateScoresDoNotIncreaseWithNorm(Version indexCreatedVersion, Similarity similarity) {
    CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
    TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
    SimScorer scorer = similarity.scorer(2f, collectionStats, termStats);

    long previousNorm = 0;
    // Start from the largest finite value so that the first score has nothing finite above it to exceed.
    float previousScore = Float.MAX_VALUE;
    for (int length = 1; length <= 10; ++length) {
        FieldInvertState state = new FieldInvertState(indexCreatedVersion.major, "some_field",
            IndexOptions.DOCS_AND_FREQS, length, length, 0, 50, 10, 3); // length varies from 1 to 10, no overlap
        final long norm = similarity.computeNorm(state);
        if (Long.compareUnsigned(previousNorm, norm) > 0) {
            // esoteric similarity, skip this check
            break;
        }
        float score = scorer.score(1, norm);
        if (score > previousScore) {
            // Explain the score that was actually computed on the previous iteration: it used
            // previousNorm, which is not necessarily norm - 1 for an arbitrary similarity.
            fail(indexCreatedVersion, "Similarity scores should not increase when norm increases:\n" +
                scorer.explain(Explanation.match(1, "term freq"), previousNorm) + "\n" +
                scorer.explain(Explanation.match(1, "term freq"), norm));
        }
        previousScore = score;
        previousNorm = norm;
    }
}

/**
 * Reports a failed sanity check. Indices created on or after 7.0.0-alpha1 get an
 * {@link IllegalArgumentException}; indices created on or after 6.5.0 get a deprecation
 * log entry; older indices get nothing.
 *
 * @param indexCreatedVersion version the index was created with
 * @param message             description of the violated constraint
 */
private static void fail(Version indexCreatedVersion, String message) {
    if (indexCreatedVersion.onOrAfter(Version.V_7_0_0_alpha1)) {
        throw new IllegalArgumentException(message);
    }
    if (indexCreatedVersion.onOrAfter(Version.V_6_5_0)) {
        DEPRECATION_LOGGER.deprecated(message);
    }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.elasticsearch.index.shard.IndexingOperationListener;
import org.elasticsearch.index.shard.SearchOperationListener;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.index.similarity.NonNegativeScoresSimilarity;
import org.elasticsearch.index.similarity.SimilarityService;
import org.elasticsearch.index.store.IndexStore;
import org.elasticsearch.indices.IndicesModule;
Expand All @@ -77,6 +78,7 @@
import org.elasticsearch.test.engine.MockEngineFactory;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.hamcrest.Matchers;

import java.io.IOException;
import java.util.Collections;
Expand Down Expand Up @@ -295,10 +297,13 @@ public void testAddSimilarity() throws IOException {

IndexService indexService = newIndexService(module);
SimilarityService similarityService = indexService.similarityService();
assertNotNull(similarityService.getSimilarity("my_similarity"));
assertTrue(similarityService.getSimilarity("my_similarity").get() instanceof TestSimilarity);
Similarity similarity = similarityService.getSimilarity("my_similarity").get();
assertNotNull(similarity);
assertThat(similarity, Matchers.instanceOf(NonNegativeScoresSimilarity.class));
similarity = ((NonNegativeScoresSimilarity) similarity).getDelegate();
assertThat(similarity, Matchers.instanceOf(TestSimilarity.class));
assertEquals("my_similarity", similarityService.getSimilarity("my_similarity").name());
assertEquals("there is a key", ((TestSimilarity) similarityService.getSimilarity("my_similarity").get()).key);
assertEquals("there is a key", ((TestSimilarity) similarity).key);
indexService.close("simon says", false);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.index.similarity;

import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.search.similarities.Similarity.SimScorer;
import org.elasticsearch.test.ESTestCase;
import org.hamcrest.Matchers;

public class NonNegativeScoresSimilarityTests extends ESTestCase {

    /**
     * Wraps a similarity that scores {@code freq - 5} and checks that a positive score is
     * returned unchanged, while a negative score raises an {@link IllegalArgumentException}
     * carrying the expected message.
     */
    public void testBasics() {
        // Scores are negative for any freq below 5.
        final Similarity rawSimilarity = new Similarity() {

            @Override
            public long computeNorm(FieldInvertState state) {
                return state.getLength();
            }

            @Override
            public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
                return new SimScorer() {
                    @Override
                    public float score(float freq, long norm) {
                        return freq - 5;
                    }
                };
            }
        };

        final Similarity wrappedSimilarity = new NonNegativeScoresSimilarity(rawSimilarity);
        final SimScorer wrappedScorer = wrappedSimilarity.scorer(1f, null);

        // 7 - 5 = 2: passes through untouched
        assertEquals(2f, wrappedScorer.score(7f, 1L), 0f);

        // 2 - 5 = -3: must be rejected
        final IllegalArgumentException failure =
            expectThrows(IllegalArgumentException.class, () -> wrappedScorer.score(2f, 1L));
        assertThat(failure.getMessage(), Matchers.containsString("Similarities must not produce negative scores"));
    }

}
Loading

0 comments on commit c4261ba

Please sign in to comment.