Skip to content

Commit

Permalink
Fix negative scores returned from multi_match query with `cross_fie…
Browse files Browse the repository at this point in the history
…lds` (opensearch-project#13829)

Under specific circumstances, when using `cross_fields` scoring on a
`multi_match` query, we can end up with negative scores from the inverse
document frequency calculation in the BM25 formula.

Specifically, the IDF is calculated as:

```
log(1 + (N - n + 0.5) / (n + 0.5))
```

where `N` is the number of documents containing the field and `n` is the
number of documents containing the given term in the field. Obviously,
`n` should always be less than or equal to `N`.

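Unfortunately, `cross_fields` makes up a new, blended value for `n` and tries to
use it across all fields. For a field that occurs in fewer documents than that
blended value, `n` ends up greater than `N`, so the argument of the logarithm
drops below 1 and the IDF (and with it the score) becomes negative.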

This change finds the (nonzero) value of `N` for each field and uses that as an
upper bound for the new value of `n`.

Signed-off-by: Michael Froh <froh@amazon.com>

---------

Signed-off-by: Michael Froh <froh@amazon.com>
  • Loading branch information
msfroh authored May 31, 2024
1 parent f50121a commit fffd101
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix get field mapping API returns 404 error in mixed cluster with multiple versions ([#13624](https://github.com/opensearch-project/OpenSearch/pull/13624))
- Allow clearing `remote_store.compatibility_mode` setting ([#13646](https://github.com/opensearch-project/OpenSearch/pull/13646))
- Fix ReplicaShardBatchAllocator to batch shards without duplicates ([#13710](https://github.com/opensearch-project/OpenSearch/pull/13710))
- Don't return negative scores from `multi_match` query with `cross_fields` type ([#13829](https://github.com/opensearch-project/OpenSearch/pull/13829))
- Pass parent filter to inner hit query ([#13903](https://github.com/opensearch-project/OpenSearch/pull/13903))

### Security
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# REST test: a `cross_fields` multi_match over a common field (`color`) and a rare, heavily
# boosted field (`shape`) must give every hit a positive score.
"Cross fields do not return negative scores":
# Skipped for every version up to and including 2.99.99 (see the reason below).
- skip:
version: " - 2.99.99"
reason: "This fix is in 2.15. Until we do the BWC dance, we need to skip all pre-3.0, though."
# Doc 1: has `color` only.
- do:
index:
index: test
id: 1
body: { "color" : "orange red yellow" }
# Doc 2: the only document that also has a `shape` field, and "red" occurs in it.
- do:
index:
index: test
id: 2
body: { "color": "orange red purple", "shape": "red square" }
# Doc 3: has `color` only.
- do:
index:
index: test
id: 3
body: { "color" : "orange red yellow purple" }
# Make the three documents visible to search.
- do:
indices.refresh: { }
# "red" occurs in `color` of all three docs but in `shape` of just one.
- do:
search:
index: test
body:
query:
multi_match:
query: "red"
type: "cross_fields"
fields: [ "color", "shape^100"]
tie_breaker: 0.1
explain: true
# All three documents match, the one with the boosted `shape` hit ranks first,
# and even the lowest-ranked hit has a score strictly greater than zero.
- match: { hits.total.value: 3 }
- match: { hits.hits.0._id: "2" }
- gt: { hits.hits.2._score: 0.0 }
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ protected void blend(final TermStates[] contexts, int maxDoc, IndexReader reader
}
int max = 0;
long minSumTTF = Long.MAX_VALUE;
int[] docCounts = new int[contexts.length];
for (int i = 0; i < contexts.length; i++) {
TermStates ctx = contexts[i];
int df = ctx.docFreq();
Expand All @@ -133,6 +134,7 @@ protected void blend(final TermStates[] contexts, int maxDoc, IndexReader reader
// we need to find out the minimum sumTTF to adjust the statistics
// otherwise the statistics don't match
minSumTTF = Math.min(minSumTTF, reader.getSumTotalTermFreq(terms[i].field()));
docCounts[i] = reader.getDocCount(terms[i].field());
}
}
if (maxDoc > minSumTTF) {
Expand Down Expand Up @@ -175,7 +177,11 @@ protected int compare(int i, int j) {
if (prev > current) {
actualDf++;
}
contexts[i] = ctx = adjustDF(reader.getContext(), ctx, Math.min(maxDoc, actualDf));
// Per field, we want to guarantee that the adjusted df does not exceed the number of docs with the field.
// That is, in the IDF formula (log(1 + (N - n + 0.5) / (n + 0.5))), we need to make sure that n (the
// adjusted df) is never bigger than N (the number of docs with the field).
int fieldMaxDoc = Math.min(maxDoc, docCounts[i]);
contexts[i] = ctx = adjustDF(reader.getContext(), ctx, Math.min(fieldMaxDoc, actualDf));
prev = current;
sumTTF += ctx.totalTermFreq();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import java.io.IOException;
import java.util.Map;

import static org.junit.Assert.fail;

/**
* Base class for executable sections that hold assertions
*/
Expand Down Expand Up @@ -79,6 +81,41 @@ protected final Object getActualValue(ClientYamlTestExecutionContext executionCo
return executionContext.response(field);
}

/**
 * Coerces the actual value of a comparison assertion to the type of the expected value, so that both sides
 * share one type before they are compared.
 * <p>
 * The actual value is returned unchanged when it is {@code null}, when {@code expectedValue} is {@code null}
 * (validation of the operands is then left to the caller), or when it already is an instance of the expected
 * value's class. Otherwise, if both values are numbers and the expected value is a {@link Float},
 * {@link Double}, {@link Integer} or {@link Long}, the actual value is re-parsed from its string form into
 * that type.
 *
 * @param actualValue   the value under test; may be {@code null}
 * @param expectedValue the value to compare against; its runtime class is the conversion target
 * @return {@code actualValue} itself, or a boxed number of the expected value's type
 * @throws AssertionError if the actual value cannot be represented in the expected numeric type (for example
 *                        a fractional value where an {@code Integer} is expected), or if no conversion is
 *                        known for this pair of types
 */
static Object convertActualValue(Object actualValue, Object expectedValue) {
    // Nothing to convert: a value is missing, or the actual value already has the expected type.
    if (actualValue == null || expectedValue == null || expectedValue.getClass().isInstance(actualValue)) {
        return actualValue;
    }
    if (actualValue instanceof Number && expectedValue instanceof Number) {
        // Re-parse from text (rather than narrowing via Number#intValue and friends) so that a value
        // that does not fit the target type is rejected instead of being silently truncated.
        final String actualText = actualValue.toString();
        try {
            if (expectedValue instanceof Float) {
                return Float.parseFloat(actualText);
            } else if (expectedValue instanceof Double) {
                return Double.parseDouble(actualText);
            } else if (expectedValue instanceof Integer) {
                return Integer.parseInt(actualText);
            } else if (expectedValue instanceof Long) {
                return Long.parseLong(actualText);
            }
        } catch (NumberFormatException e) {
            // E.g. actual 0.5 against expected 0: report a readable assertion failure and keep the cause.
            throw new AssertionError(
                "Type mismatch: Actual value ("
                    + actualValue
                    + ") of type "
                    + actualValue.getClass()
                    + " cannot be converted to "
                    + expectedValue.getClass()
                    + ", the type of the expected value ("
                    + expectedValue
                    + ").",
                e
            );
        }
    }
    // No conversion is known for this pair of types (the assignability check above already ruled out a
    // plain cast). Fail loudly, so developers can flesh out the above logic as needed.
    fail(
        "Type mismatch: Expected value ("
            + expectedValue
            + ") has type "
            + expectedValue.getClass()
            + ". "
            + "Actual value ("
            + actualValue
            + ") has type "
            + actualValue.getClass()
            + "."
    );
    // Only here to satisfy the compiler; fail(...) always throws.
    return actualValue;
}

@Override
public XContentLocation getLocation() {
return location;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ public GreaterThanAssertion(XContentLocation location, String field, Object expe
@Override
protected void doAssert(Object actualValue, Object expectedValue) {
logger.trace("assert that [{}] is greater than [{}] (field: [{}])", actualValue, expectedValue, getField());
actualValue = convertActualValue(actualValue, expectedValue);
assertThat(
"value of [" + getField() + "] is not comparable (got [" + safeClass(actualValue) + "])",
actualValue,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ public GreaterThanEqualToAssertion(XContentLocation location, String field, Obje
@Override
protected void doAssert(Object actualValue, Object expectedValue) {
logger.trace("assert that [{}] is greater than or equal to [{}] (field: [{}])", actualValue, expectedValue, getField());
actualValue = convertActualValue(actualValue, expectedValue);
assertThat(
"value of [" + getField() + "] is not comparable (got [" + safeClass(actualValue) + "])",
actualValue,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ public LessThanAssertion(XContentLocation location, String field, Object expecte
@Override
protected void doAssert(Object actualValue, Object expectedValue) {
logger.trace("assert that [{}] is less than [{}] (field: [{}])", actualValue, expectedValue, getField());
actualValue = convertActualValue(actualValue, expectedValue);
assertThat(
"value of [" + getField() + "] is not comparable (got [" + safeClass(actualValue) + "])",
actualValue,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ public LessThanOrEqualToAssertion(XContentLocation location, String field, Objec
@Override
protected void doAssert(Object actualValue, Object expectedValue) {
logger.trace("assert that [{}] is less than or equal to [{}] (field: [{}])", actualValue, expectedValue, getField());
actualValue = convertActualValue(actualValue, expectedValue);
assertThat(
"value of [" + getField() + "] is not comparable (got [" + safeClass(actualValue) + "])",
actualValue,
Expand Down

0 comments on commit fffd101

Please sign in to comment.