diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java index 22bef026523e9..8b6a2c4e7b078 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java @@ -229,6 +229,12 @@ protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, sourceBuilder.query(newQuery); } + addSort(sourceBuilder); + + return sourceBuilder; + } + + protected void addSort(SearchSourceBuilder sourceBuilder) { // Record the shard id in the sort result List> sortBuilders = sourceBuilder.sorts() != null ? new ArrayList<>(sourceBuilder.sorts()) : new ArrayList<>(); if (sortBuilders.isEmpty()) { @@ -236,7 +242,6 @@ protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, } sortBuilders.add(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME)); sourceBuilder.sort(sortBuilders); - return sourceBuilder; } private RankDoc[] getRankDocs(SearchResponse searchResponse) { diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java index 95757a70e28a0..6982ec65190ed 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java @@ -10,11 +10,8 @@ import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.common.ParsingException; import org.elasticsearch.features.NodeFeature; -import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.license.LicenseUtils; -import org.elasticsearch.logging.LogManager; -import org.elasticsearch.logging.Logger; import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.fetch.StoredFieldsContext; @@ -23,7 +20,9 @@ import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder; +import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.ScoreSortBuilder; +import org.elasticsearch.search.sort.SortBuilder; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; @@ -108,14 +107,12 @@ public QueryRuleRetrieverBuilder( Map matchCriteria, List retrieverSource, int rankWindowSize, - String retrieverName, - List preFilterQueryBuilders + String retrieverName ) { super(retrieverSource, rankWindowSize); this.rulesetIds = rulesetIds; this.matchCriteria = matchCriteria; this.retrieverName = retrieverName; - this.preFilterQueryBuilders = new ArrayList<>(preFilterQueryBuilders); } @Override @@ -125,35 +122,20 @@ public String getName() { @Override protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { - Logger logger = LogManager.getLogger(QueryRuleRetrieverBuilder.class); var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit) .trackTotalHits(false) .storedFields(new StoredFieldsContext(false)) .size(rankWindowSize); - if (preFilterQueryBuilders.isEmpty() == false) { - retrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders); - } retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true); QueryBuilder query = sourceBuilder.query(); if (query != null && query instanceof RuleQueryBuilder == false) { - QueryBuilder organicQuery = query; - query = new RuleQueryBuilder(organicQuery, matchCriteria, rulesetIds); - } - - // apply the pre-filters - if (preFilterQueryBuilders.size() > 0) { - BoolQueryBuilder newQuery = new BoolQueryBuilder(); - if (query != null) { - newQuery.must(query); - } - preFilterQueryBuilders.forEach(newQuery::filter); - sourceBuilder.query(newQuery); + QueryBuilder ruleQuery = new RuleQueryBuilder(query, matchCriteria, rulesetIds); + sourceBuilder.query(ruleQuery); } - sourceBuilder.sort(new ScoreSortBuilder()); + addSort(sourceBuilder); - logger.info("sourceBuilder: " + sourceBuilder); return sourceBuilder; } @@ -167,14 +149,7 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept @Override protected QueryRuleRetrieverBuilder clone(List newChildRetrievers) { - return new QueryRuleRetrieverBuilder( - rulesetIds, - matchCriteria, - newChildRetrievers, - rankWindowSize, - retrieverName, - preFilterQueryBuilders - ); + return new QueryRuleRetrieverBuilder(rulesetIds, matchCriteria, newChildRetrievers, rankWindowSize, retrieverName); } @Override