From fedc59e141518af9cbfe66225b33844c382a4981 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20B=C3=BCscher?= Date: Tue, 24 Jul 2018 12:01:52 +0200 Subject: [PATCH] Register ERR metric with NamedXContentRegistry This adds the ERR (expected reciprocal rank) metric to the named xContent parsers provided by the rank-eval module and to the high-level REST client registry. It also adds integration tests to make sure the metric is correctly registered and usable from the client. --- .../org/elasticsearch/client/RankEvalIT.java | 48 +++++++++++++++---- .../client/RestHighLevelClientTests.java | 10 ++-- .../rankeval/ExpectedReciprocalRank.java | 3 ++ .../RankEvalNamedXContentProvider.java | 5 ++ .../index/rankeval/RankEvalPlugin.java | 4 ++ .../rest-api-spec/test/rank_eval/10_basic.yml | 34 +++++++++++++ 6 files changed, 92 insertions(+), 12 deletions(-) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RankEvalIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RankEvalIT.java index 330afafd9ef96..0af270cb051ea 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RankEvalIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RankEvalIT.java @@ -22,7 +22,11 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.rankeval.DiscountedCumulativeGain; import org.elasticsearch.index.rankeval.EvalQueryQuality; +import org.elasticsearch.index.rankeval.EvaluationMetric; +import org.elasticsearch.index.rankeval.ExpectedReciprocalRank; +import org.elasticsearch.index.rankeval.MeanReciprocalRank; import org.elasticsearch.index.rankeval.PrecisionAtK; import org.elasticsearch.index.rankeval.RankEvalRequest; import org.elasticsearch.index.rankeval.RankEvalResponse; @@ -35,8 +39,10 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import 
java.util.Map; +import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -64,15 +70,7 @@ public void indexDocuments() throws IOException { * calculation where all unlabeled documents are treated as not relevant. */ public void testRankEvalRequest() throws IOException { - SearchSourceBuilder testQuery = new SearchSourceBuilder(); - testQuery.query(new MatchAllQueryBuilder()); - List<RatedDocument> amsterdamRatedDocs = createRelevant("index" , "amsterdam1", "amsterdam2", "amsterdam3", "amsterdam4"); - amsterdamRatedDocs.addAll(createRelevant("index2", "amsterdam0")); - RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query", amsterdamRatedDocs, testQuery); - RatedRequest berlinRequest = new RatedRequest("berlin_query", createRelevant("index", "berlin"), testQuery); - List<RatedRequest> specifications = new ArrayList<>(); - specifications.add(amsterdamRequest); - specifications.add(berlinRequest); + List<RatedRequest> specifications = createTestEvaluationSpec(); PrecisionAtK metric = new PrecisionAtK(1, false, 10); RankEvalSpec spec = new RankEvalSpec(specifications, metric); @@ -114,6 +112,38 @@ public void testRankEvalRequest() throws IOException { response = execute(rankEvalRequest, highLevelClient()::rankEval, highLevelClient()::rankEvalAsync); } + private static List<RatedRequest> createTestEvaluationSpec() { + SearchSourceBuilder testQuery = new SearchSourceBuilder(); + testQuery.query(new MatchAllQueryBuilder()); + List<RatedDocument> amsterdamRatedDocs = createRelevant("index" , "amsterdam1", "amsterdam2", "amsterdam3", "amsterdam4"); + amsterdamRatedDocs.addAll(createRelevant("index2", "amsterdam0")); + RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query", amsterdamRatedDocs, testQuery); + RatedRequest berlinRequest = new RatedRequest("berlin_query", createRelevant("index", "berlin"), testQuery); + List<RatedRequest> specifications = new ArrayList<>(); + specifications.add(amsterdamRequest); + specifications.add(berlinRequest); + return specifications; + } + + /** + * 
Test case checks that the default metrics are registered and usable + */ + public void testMetrics() throws IOException { + List<RatedRequest> specifications = createTestEvaluationSpec(); + List<Supplier<EvaluationMetric>> metrics = Arrays.asList(PrecisionAtK::new, MeanReciprocalRank::new, DiscountedCumulativeGain::new, + () -> new ExpectedReciprocalRank(1)); + double expectedScores[] = new double[] {0.4285714285714286, 0.75, 1.6408962261063627, 0.4407738095238095}; + int i = 0; + for (Supplier<EvaluationMetric> metricSupplier : metrics) { + RankEvalSpec spec = new RankEvalSpec(specifications, metricSupplier.get()); + + RankEvalRequest rankEvalRequest = new RankEvalRequest(spec, new String[] { "index", "index2" }); + RankEvalResponse response = execute(rankEvalRequest, highLevelClient()::rankEval, highLevelClient()::rankEvalAsync); + assertEquals(expectedScores[i], response.getMetricScore(), Double.MIN_VALUE); + i++; + } + } + private static List<RatedDocument> createRelevant(String indexName, String... docs) { return Stream.of(docs).map(s -> new RatedDocument(indexName, s, 1)).collect(Collectors.toList()); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 64a344790caa0..48934a9bed8cb 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -20,6 +20,7 @@ package org.elasticsearch.client; import com.fasterxml.jackson.core.JsonParseException; + import org.apache.http.HttpEntity; import org.apache.http.HttpHost; import org.apache.http.HttpResponse; @@ -60,6 +61,7 @@ import org.elasticsearch.common.xcontent.smile.SmileXContent; import org.elasticsearch.index.rankeval.DiscountedCumulativeGain; import org.elasticsearch.index.rankeval.EvaluationMetric; +import org.elasticsearch.index.rankeval.ExpectedReciprocalRank; import 
org.elasticsearch.index.rankeval.MeanReciprocalRank; import org.elasticsearch.index.rankeval.MetricDetail; import org.elasticsearch.index.rankeval.PrecisionAtK; @@ -616,7 +618,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(8, namedXContents.size()); + assertEquals(10, namedXContents.size()); Map<Class<?>, Integer> categories = new HashMap<>(); List<String> names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -630,14 +632,16 @@ public void testProvidedNamedXContents() { assertEquals(Integer.valueOf(2), categories.get(Aggregation.class)); assertTrue(names.contains(ChildrenAggregationBuilder.NAME)); assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME)); - assertEquals(Integer.valueOf(3), categories.get(EvaluationMetric.class)); + assertEquals(Integer.valueOf(4), categories.get(EvaluationMetric.class)); assertTrue(names.contains(PrecisionAtK.NAME)); assertTrue(names.contains(DiscountedCumulativeGain.NAME)); assertTrue(names.contains(MeanReciprocalRank.NAME)); - assertEquals(Integer.valueOf(3), categories.get(MetricDetail.class)); + assertTrue(names.contains(ExpectedReciprocalRank.NAME)); + assertEquals(Integer.valueOf(4), categories.get(MetricDetail.class)); assertTrue(names.contains(PrecisionAtK.NAME)); assertTrue(names.contains(MeanReciprocalRank.NAME)); assertTrue(names.contains(DiscountedCumulativeGain.NAME)); + assertTrue(names.contains(ExpectedReciprocalRank.NAME)); } public void testApiNamingConventions() throws Exception { diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/ExpectedReciprocalRank.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/ExpectedReciprocalRank.java index 4aac29f299d67..39e1266504d9a 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/ExpectedReciprocalRank.java +++ 
b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/ExpectedReciprocalRank.java @@ -65,6 +65,9 @@ public class ExpectedReciprocalRank implements EvaluationMetric { public static final String NAME = "expected_reciprocal_rank"; + /** + * @param maxRelevance the highest expected relevance in the data + */ public ExpectedReciprocalRank(int maxRelevance) { this(maxRelevance, null, DEFAULT_K); } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java index f2176113cdf9d..7eddcf9dff644 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java @@ -37,12 +37,17 @@ public List getNamedXContentParsers() { MeanReciprocalRank::fromXContent)); namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(DiscountedCumulativeGain.NAME), DiscountedCumulativeGain::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(ExpectedReciprocalRank.NAME), + ExpectedReciprocalRank::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(PrecisionAtK.NAME), PrecisionAtK.Detail::fromXContent)); namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(MeanReciprocalRank.NAME), MeanReciprocalRank.Detail::fromXContent)); namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(DiscountedCumulativeGain.NAME), DiscountedCumulativeGain.Detail::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(ExpectedReciprocalRank.NAME), + ExpectedReciprocalRank.Detail::fromXContent)); return namedXContent; } } diff --git 
a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java index 8ac2b7fbee528..0e5d754778f84 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java @@ -60,10 +60,14 @@ public List getNamedWriteables() { namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new)); namedWriteables.add( new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new)); + namedWriteables.add( + new NamedWriteableRegistry.Entry(EvaluationMetric.class, ExpectedReciprocalRank.NAME, ExpectedReciprocalRank::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, PrecisionAtK.NAME, PrecisionAtK.Detail::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new)); namedWriteables.add( new NamedWriteableRegistry.Entry(MetricDetail.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain.Detail::new)); + namedWriteables.add( + new NamedWriteableRegistry.Entry(MetricDetail.class, ExpectedReciprocalRank.NAME, ExpectedReciprocalRank.Detail::new)); return namedWriteables; } diff --git a/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/10_basic.yml b/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/10_basic.yml index fe877b37a68f4..ebe23ae53f411 100644 --- a/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/10_basic.yml +++ b/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/10_basic.yml @@ -161,3 +161,37 @@ setup: - match: {details.berlin_query.metric_details.mean_reciprocal_rank: {"first_relevant": 2}} - match: {details.berlin_query.unrated_docs: [ {"_index": 
"foo", "_id": "doc1"}]} +--- +"Expected Reciprocal Rank": + + - skip: + version: " - 6.3.99" + reason: ERR was introduced in 6.4 + + - do: + rank_eval: + body: { + "requests" : [ + { + "id": "amsterdam_query", + "request": { "query": { "match" : {"text" : "amsterdam" }}}, + "ratings": [{"_index": "foo", "_id": "doc4", "rating": 1}] + }, + { + "id" : "berlin_query", + "request": { "query": { "match" : { "text" : "berlin" } }, "size" : 10 }, + "ratings": [{"_index": "foo", "_id": "doc4", "rating": 1}] + } + ], + "metric" : { + "expected_reciprocal_rank": { + "maximum_relevance" : 1, + "k" : 5 + } + } + } + + - gt: {metric_score: 0.2083333} + - lt: {metric_score: 0.2083334} + - match: {details.amsterdam_query.metric_details.expected_reciprocal_rank.unrated_docs: 2} + - match: {details.berlin_query.metric_details.expected_reciprocal_rank.unrated_docs: 1}