Skip to content

Commit

Permalink
Add details section for dcg ranking metric (#31177)
Browse files Browse the repository at this point in the history
While the other two ranking evaluation metrics (precision and reciprocal rank)
already provide a more detailed output for how their score is calculated, the
discounted cumulative gain metric (dcg) and its normalized variant are lacking
this until now. It's not really clear which level of detail might be useful for
debugging and understanding the final metric calculation, but this change adds a
`metric_details` section to REST output that contains some information about the
evaluation details.
  • Loading branch information
Christoph Büscher authored Jun 15, 2018
1 parent ca00deb commit a0d6c19
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package org.elasticsearch.client;

import com.fasterxml.jackson.core.JsonParseException;

import org.apache.http.Header;
import org.apache.http.HttpEntity;
import org.apache.http.HttpHost;
Expand Down Expand Up @@ -607,7 +608,7 @@ public void testDefaultNamedXContents() {

public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
assertEquals(7, namedXContents.size());
assertEquals(8, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
Expand All @@ -625,9 +626,10 @@ public void testProvidedNamedXContents() {
assertTrue(names.contains(PrecisionAtK.NAME));
assertTrue(names.contains(DiscountedCumulativeGain.NAME));
assertTrue(names.contains(MeanReciprocalRank.NAME));
assertEquals(Integer.valueOf(2), categories.get(MetricDetail.class));
assertEquals(Integer.valueOf(3), categories.get(MetricDetail.class));
assertTrue(names.contains(PrecisionAtK.NAME));
assertTrue(names.contains(MeanReciprocalRank.NAME));
assertTrue(names.contains(DiscountedCumulativeGain.NAME));
}

private static class TrackingActionListener implements ActionListener<Integer> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.Optional;
import java.util.stream.Collectors;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings;

Expand Down Expand Up @@ -129,26 +130,31 @@ public EvalQueryQuality evaluate(String taskId, SearchHit[] hits,
.collect(Collectors.toList());
List<RatedSearchHit> ratedHits = joinHitsWithRatings(hits, ratedDocs);
List<Integer> ratingsInSearchHits = new ArrayList<>(ratedHits.size());
int unratedResults = 0;
for (RatedSearchHit hit : ratedHits) {
// unknownDocRating might be null, which means it will be unrated docs are
// ignored in the dcg calculation
// we still need to add them as a placeholder so the rank of the subsequent
// ratings is correct
// unknownDocRating might be null, in which case unrated docs will be ignored in the dcg calculation.
// we still need to add them as a placeholder so the rank of the subsequent ratings is correct
ratingsInSearchHits.add(hit.getRating().orElse(unknownDocRating));
if (hit.getRating().isPresent() == false) {
unratedResults++;
}
}
double dcg = computeDCG(ratingsInSearchHits);
final double dcg = computeDCG(ratingsInSearchHits);
double result = dcg;
double idcg = 0;

if (normalize) {
Collections.sort(allRatings, Comparator.nullsLast(Collections.reverseOrder()));
double idcg = computeDCG(allRatings.subList(0, Math.min(ratingsInSearchHits.size(), allRatings.size())));
if (idcg > 0) {
dcg = dcg / idcg;
idcg = computeDCG(allRatings.subList(0, Math.min(ratingsInSearchHits.size(), allRatings.size())));
if (idcg != 0) {
result = dcg / idcg;
} else {
dcg = 0;
result = 0;
}
}
EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, dcg);
EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, result);
evalQueryQuality.addHitsAndRatings(ratedHits);
evalQueryQuality.setMetricDetails(new Detail(dcg, idcg, unratedResults));
return evalQueryQuality;
}

Expand All @@ -167,7 +173,7 @@ private static double computeDCG(List<Integer> ratings) {
private static final ParseField K_FIELD = new ParseField("k");
private static final ParseField NORMALIZE_FIELD = new ParseField("normalize");
private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating");
private static final ConstructingObjectParser<DiscountedCumulativeGain, Void> PARSER = new ConstructingObjectParser<>("dcg_at", false,
private static final ConstructingObjectParser<DiscountedCumulativeGain, Void> PARSER = new ConstructingObjectParser<>("dcg", false,
args -> {
Boolean normalized = (Boolean) args[0];
Integer optK = (Integer) args[2];
Expand Down Expand Up @@ -217,4 +223,118 @@ public final boolean equals(Object obj) {
public final int hashCode() {
return Objects.hash(normalize, unknownDocRating, k);
}

/**
 * Detail section for the discounted cumulative gain metric. It carries the raw dcg, the ideal dcg
 * (0 when it was not computed, e.g. because no normalization was requested) and the number of
 * unrated documents among the search hits.
 */
public static final class Detail implements MetricDetail {

    private static final ParseField DCG_FIELD = new ParseField("dcg");
    private static final ParseField IDCG_FIELD = new ParseField("ideal_dcg");
    private static final ParseField NDCG_FIELD = new ParseField("normalized_dcg");
    private static final ParseField UNRATED_FIELD = new ParseField("unrated_docs");
    private final double dcg;
    private final double idcg;
    private final int unratedDocs;

    /**
     * @param dcg the discounted cumulative gain
     * @param idcg the ideal discounted cumulative gain, 0 if it was not computed
     * @param unratedDocs the number of unrated documents in the search results
     */
    Detail(double dcg, double idcg, int unratedDocs) {
        this.dcg = dcg;
        this.idcg = idcg;
        this.unratedDocs = unratedDocs;
    }

    /**
     * Reads the detail from a stream. The field order must mirror {@link #writeTo(StreamOutput)}.
     */
    Detail(StreamInput in) throws IOException {
        this.dcg = in.readDouble();
        this.idcg = in.readDouble();
        this.unratedDocs = in.readVInt();
    }

    @Override
    public String getMetricName() {
        return NAME;
    }

    /**
     * Writes the dcg and the unrated document count. The ideal and normalized dcg are only written
     * when the ideal dcg is non-zero.
     */
    @Override
    public XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException {
        builder.field(DCG_FIELD.getPreferredName(), this.dcg);
        if (this.idcg != 0) {
            builder.field(IDCG_FIELD.getPreferredName(), this.idcg);
            builder.field(NDCG_FIELD.getPreferredName(), this.dcg / this.idcg);
        }
        builder.field(UNRATED_FIELD.getPreferredName(), this.unratedDocs);
        return builder;
    }

    // "normalized_dcg" is not declared on the parser: it is derived from dcg and ideal dcg (see getNDCG()).
    private static final ConstructingObjectParser<Detail, Void> PARSER = new ConstructingObjectParser<>(NAME, true, args -> {
        // the ideal dcg is optional in the output, a missing value maps back to 0
        final Double parsedIdcg = (Double) args[1];
        return new Detail((Double) args[0], parsedIdcg != null ? parsedIdcg : 0.0d, (Integer) args[2]);
    });

    static {
        PARSER.declareDouble(constructorArg(), DCG_FIELD);
        PARSER.declareDouble(optionalConstructorArg(), IDCG_FIELD);
        PARSER.declareInt(constructorArg(), UNRATED_FIELD);
    }

    /**
     * Parses a detail section from the fields written by
     * {@link #innerToXContent(XContentBuilder, Params)}.
     */
    public static Detail fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeDouble(this.dcg);
        out.writeDouble(this.idcg);
        out.writeVInt(this.unratedDocs);
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * @return the discounted cumulative gain
     */
    public double getDCG() {
        return this.dcg;
    }

    /**
     * @return the ideal discounted cumulative gain, can be 0 if nothing was computed, e.g. because no normalization was required
     */
    public double getIDCG() {
        return this.idcg;
    }

    /**
     * @return the normalized discounted cumulative gain, can be 0 if nothing was computed, e.g. because no normalization was required
     */
    public double getNDCG() {
        return (this.idcg != 0) ? this.dcg / this.idcg : 0;
    }

    /**
     * NOTE(review): the value is an int but the declared return type is Object; the signature is
     * kept as-is here so existing callers are unaffected.
     *
     * @return the number of unrated documents in the search results
     */
    public Object getUnratedDocs() {
        return this.unratedDocs;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        DiscountedCumulativeGain.Detail other = (DiscountedCumulativeGain.Detail) obj;
        // Double.compare keeps equals consistent with hashCode(): Objects.hash boxes the doubles, so
        // 0.0 and -0.0 hash differently and must not be equal, while NaN values hash alike
        return Double.compare(this.dcg, other.dcg) == 0
            && Double.compare(this.idcg, other.idcg) == 0
            && this.unratedDocs == other.unratedDocs;
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.dcg, this.idcg, this.unratedDocs);
    }
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
PrecisionAtK.Detail::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(MeanReciprocalRank.NAME),
MeanReciprocalRank.Detail::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(DiscountedCumulativeGain.NAME),
DiscountedCumulativeGain.Detail::fromXContent));
return namedXContent;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
namedWriteables.add(
new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, PrecisionAtK.NAME, PrecisionAtK.Detail::new));
namedWriteables
.add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new));
namedWriteables.add(
new NamedWriteableRegistry.Entry(MetricDetail.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain.Detail::new));
return namedWriteables;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.elasticsearch.index.rankeval;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.text.Text;
Expand Down Expand Up @@ -254,9 +255,8 @@ private void assertParsedCorrect(String xContent, Integer expectedUnknownDocRati

public static DiscountedCumulativeGain createTestItem() {
boolean normalize = randomBoolean();
Integer unknownDocRating = Integer.valueOf(randomIntBetween(0, 1000));

return new DiscountedCumulativeGain(normalize, unknownDocRating, 10);
Integer unknownDocRating = frequently() ? Integer.valueOf(randomIntBetween(0, 1000)) : null;
return new DiscountedCumulativeGain(normalize, unknownDocRating, randomIntBetween(1, 10));
}

public void testXContentRoundtrip() throws IOException {
Expand All @@ -283,7 +283,25 @@ public void testXContentParsingIsNotLenient() throws IOException {
parser.nextToken();
XContentParseException exception = expectThrows(XContentParseException.class,
() -> DiscountedCumulativeGain.fromXContent(parser));
assertThat(exception.getMessage(), containsString("[dcg_at] unknown field"));
assertThat(exception.getMessage(), containsString("[dcg] unknown field"));
}
}

/**
 * Checks the getters and the rendered JSON of a randomly populated dcg detail section, both with
 * and without an ideal dcg value.
 */
public void testMetricDetails() {
    final double gain = randomDoubleBetween(0, 1, true);
    final double idealGain = randomBoolean() ? 0.0 : randomDoubleBetween(0, 1, true);
    final boolean hasIdealGain = idealGain != 0;
    final double normalizedGain = hasIdealGain ? gain / idealGain : 0.0;
    final int unrated = randomIntBetween(0, 100);

    DiscountedCumulativeGain.Detail detail = new DiscountedCumulativeGain.Detail(gain, idealGain, unrated);
    assertEquals(gain, detail.getDCG(), 0.0);
    assertEquals(idealGain, detail.getIDCG(), 0.0);
    assertEquals(normalizedGain, detail.getNDCG(), 0.0);
    assertEquals(unrated, detail.getUnratedDocs());

    // the ideal and normalized values are only expected in the output when an ideal dcg is present
    String idealPart = "";
    if (hasIdealGain) {
        idealPart = ",\"ideal_dcg\":" + idealGain + ",\"normalized_dcg\":" + normalizedGain;
    }
    String expectedJson = "{\"dcg\":{\"dcg\":" + gain + idealPart + ",\"unrated_docs\":" + unrated + "}}";
    assertEquals(expectedJson, Strings.toString(detail));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,20 @@ public static EvalQueryQuality randomEvalQueryQuality() {
EvalQueryQuality evalQueryQuality = new EvalQueryQuality(randomAlphaOfLength(10),
randomDoubleBetween(0.0, 1.0, true));
if (randomBoolean()) {
if (randomBoolean()) {
int metricDetail = randomIntBetween(0, 2);
switch (metricDetail) {
case 0:
evalQueryQuality.setMetricDetails(new PrecisionAtK.Detail(randomIntBetween(0, 1000), randomIntBetween(0, 1000)));
} else {
break;
case 1:
evalQueryQuality.setMetricDetails(new MeanReciprocalRank.Detail(randomIntBetween(0, 1000)));
break;
case 2:
evalQueryQuality.setMetricDetails(new DiscountedCumulativeGain.Detail(randomDoubleBetween(0, 1, true),
randomBoolean() ? randomDoubleBetween(0, 1, true) : 0, randomInt()));
break;
default:
throw new IllegalArgumentException("illegal randomized value in test");
}
}
evalQueryQuality.addHitsAndRatings(ratedHits);
Expand Down

0 comments on commit a0d6c19

Please sign in to comment.