diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RankEvalIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RankEvalIT.java index 330afafd9ef96..0af270cb051ea 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RankEvalIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RankEvalIT.java @@ -22,7 +22,11 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.rankeval.DiscountedCumulativeGain; import org.elasticsearch.index.rankeval.EvalQueryQuality; +import org.elasticsearch.index.rankeval.EvaluationMetric; +import org.elasticsearch.index.rankeval.ExpectedReciprocalRank; +import org.elasticsearch.index.rankeval.MeanReciprocalRank; import org.elasticsearch.index.rankeval.PrecisionAtK; import org.elasticsearch.index.rankeval.RankEvalRequest; import org.elasticsearch.index.rankeval.RankEvalResponse; @@ -35,8 +39,10 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -64,15 +70,7 @@ public void indexDocuments() throws IOException { * calculation where all unlabeled documents are treated as not relevant. 
*/ public void testRankEvalRequest() throws IOException { - SearchSourceBuilder testQuery = new SearchSourceBuilder(); - testQuery.query(new MatchAllQueryBuilder()); - List amsterdamRatedDocs = createRelevant("index" , "amsterdam1", "amsterdam2", "amsterdam3", "amsterdam4"); - amsterdamRatedDocs.addAll(createRelevant("index2", "amsterdam0")); - RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query", amsterdamRatedDocs, testQuery); - RatedRequest berlinRequest = new RatedRequest("berlin_query", createRelevant("index", "berlin"), testQuery); - List specifications = new ArrayList<>(); - specifications.add(amsterdamRequest); - specifications.add(berlinRequest); + List specifications = createTestEvaluationSpec(); PrecisionAtK metric = new PrecisionAtK(1, false, 10); RankEvalSpec spec = new RankEvalSpec(specifications, metric); @@ -114,6 +112,38 @@ public void testRankEvalRequest() throws IOException { response = execute(rankEvalRequest, highLevelClient()::rankEval, highLevelClient()::rankEvalAsync); } + private static List createTestEvaluationSpec() { + SearchSourceBuilder testQuery = new SearchSourceBuilder(); + testQuery.query(new MatchAllQueryBuilder()); + List amsterdamRatedDocs = createRelevant("index" , "amsterdam1", "amsterdam2", "amsterdam3", "amsterdam4"); + amsterdamRatedDocs.addAll(createRelevant("index2", "amsterdam0")); + RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query", amsterdamRatedDocs, testQuery); + RatedRequest berlinRequest = new RatedRequest("berlin_query", createRelevant("index", "berlin"), testQuery); + List specifications = new ArrayList<>(); + specifications.add(amsterdamRequest); + specifications.add(berlinRequest); + return specifications; + } + + /** + * Test case checks that the default metrics are registered and usable + */ + public void testMetrics() throws IOException { + List specifications = createTestEvaluationSpec(); + List> metrics = Arrays.asList(PrecisionAtK::new, MeanReciprocalRank::new, 
DiscountedCumulativeGain::new, + () -> new ExpectedReciprocalRank(1)); + double expectedScores[] = new double[] {0.4285714285714286, 0.75, 1.6408962261063627, 0.4407738095238095}; + int i = 0; + for (Supplier metricSupplier : metrics) { + RankEvalSpec spec = new RankEvalSpec(specifications, metricSupplier.get()); + + RankEvalRequest rankEvalRequest = new RankEvalRequest(spec, new String[] { "index", "index2" }); + RankEvalResponse response = execute(rankEvalRequest, highLevelClient()::rankEval, highLevelClient()::rankEvalAsync); + assertEquals(expectedScores[i], response.getMetricScore(), Double.MIN_VALUE); + i++; + } + } + private static List createRelevant(String indexName, String... docs) { return Stream.of(docs).map(s -> new RatedDocument(indexName, s, 1)).collect(Collectors.toList()); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 64a344790caa0..48934a9bed8cb 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -20,6 +20,7 @@ package org.elasticsearch.client; import com.fasterxml.jackson.core.JsonParseException; + import org.apache.http.HttpEntity; import org.apache.http.HttpHost; import org.apache.http.HttpResponse; @@ -60,6 +61,7 @@ import org.elasticsearch.common.xcontent.smile.SmileXContent; import org.elasticsearch.index.rankeval.DiscountedCumulativeGain; import org.elasticsearch.index.rankeval.EvaluationMetric; +import org.elasticsearch.index.rankeval.ExpectedReciprocalRank; import org.elasticsearch.index.rankeval.MeanReciprocalRank; import org.elasticsearch.index.rankeval.MetricDetail; import org.elasticsearch.index.rankeval.PrecisionAtK; @@ -616,7 +618,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List 
    /**
     * Creates the metric from the maximum relevance alone by delegating to the three-argument
     * constructor, passing {@code null} as its second argument and {@code DEFAULT_K} as its third.
     * NOTE(review): the meaning of the second argument is not visible here; confirm against the
     * three-argument constructor.
     *
     * @param maxRelevance the highest expected relevance in the data
     */
    public ExpectedReciprocalRank(int maxRelevance) {
        this(maxRelevance, null, DEFAULT_K);
    }
b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java @@ -60,10 +60,14 @@ public List getNamedWriteables() { namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new)); namedWriteables.add( new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new)); + namedWriteables.add( + new NamedWriteableRegistry.Entry(EvaluationMetric.class, ExpectedReciprocalRank.NAME, ExpectedReciprocalRank::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, PrecisionAtK.NAME, PrecisionAtK.Detail::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new)); namedWriteables.add( new NamedWriteableRegistry.Entry(MetricDetail.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain.Detail::new)); + namedWriteables.add( + new NamedWriteableRegistry.Entry(MetricDetail.class, ExpectedReciprocalRank.NAME, ExpectedReciprocalRank.Detail::new)); return namedWriteables; } diff --git a/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/10_basic.yml b/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/10_basic.yml index fe877b37a68f4..ebe23ae53f411 100644 --- a/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/10_basic.yml +++ b/modules/rank-eval/src/test/resources/rest-api-spec/test/rank_eval/10_basic.yml @@ -161,3 +161,37 @@ setup: - match: {details.berlin_query.metric_details.mean_reciprocal_rank: {"first_relevant": 2}} - match: {details.berlin_query.unrated_docs: [ {"_index": "foo", "_id": "doc1"}]} +--- +"Expected Reciprocal Rank": + + - skip: + version: " - 6.3.99" + reason: ERR was introduced in 6.4 + + - do: + rank_eval: + body: { + "requests" : [ + { + "id": "amsterdam_query", + "request": { "query": { "match" : {"text" : "amsterdam" }}}, + "ratings": [{"_index": "foo", "_id": 
"doc4", "rating": 1}] + }, + { + "id" : "berlin_query", + "request": { "query": { "match" : { "text" : "berlin" } }, "size" : 10 }, + "ratings": [{"_index": "foo", "_id": "doc4", "rating": 1}] + } + ], + "metric" : { + "expected_reciprocal_rank": { + "maximum_relevance" : 1, + "k" : 5 + } + } + } + + - gt: {metric_score: 0.2083333} + - lt: {metric_score: 0.2083334} + - match: {details.amsterdam_query.metric_details.expected_reciprocal_rank.unrated_docs: 2} + - match: {details.berlin_query.metric_details.expected_reciprocal_rank.unrated_docs: 1}