Skip to content

Commit

Permalink
Commit Jim's fix
Browse files Browse the repository at this point in the history
  • Loading branch information
kderusso committed Oct 15, 2024
1 parent bd23441 commit d80feee
Showing 1 changed file with 129 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,29 @@

package org.elasticsearch.xpack.application.rules.retriever;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.TransportSearchAction;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.StoredFieldsContext;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RankDocsRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.ScoreSortBuilder;
import org.elasticsearch.search.sort.ShardDocSortField;
import org.elasticsearch.search.sort.SortBuilder;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
Expand All @@ -29,6 +39,7 @@
import org.elasticsearch.xpack.core.XPackPlugin;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand All @@ -40,7 +51,7 @@
/**
* A query rule retriever applies query rules defined in one or more rulesets to the underlying retriever.
*/
public final class QueryRuleRetrieverBuilder extends CompoundRetrieverBuilder<QueryRuleRetrieverBuilder> {
public final class QueryRuleRetrieverBuilder extends RetrieverBuilder {

public static final String NAME = "rule";
public static final NodeFeature QUERY_RULE_RETRIEVERS_SUPPORTED = new NodeFeature("query_rule_retriever_supported");
Expand Down Expand Up @@ -86,29 +97,29 @@ public static QueryRuleRetrieverBuilder fromXContent(XContentParser parser, Retr

private final List<String> rulesetIds;
private final Map<String, Object> matchCriteria;
private final CompoundRetrieverBuilder.RetrieverSource subRetriever;
private final int rankWindowSize;
private boolean executed = false;

// Public constructor: creates a rule retriever for the given rulesets, match criteria and rank window size.
//
// NOTE(review): this constructor appears with both sides of a diff hunk merged and the +/- markers
// stripped, so it is not valid Java as written: it declares two RetrieverBuilder parameters and
// calls both super(...) and this(...). Presumably the `retrieverBuilder` parameter, the super(...)
// call and the two field assignments are the removed side, while the `subRetriever` parameter and
// the this(...) delegation are the added side -- confirm against the repository before editing.
public QueryRuleRetrieverBuilder(
List<String> rulesetIds,
Map<String, Object> matchCriteria,
RetrieverBuilder retrieverBuilder,
RetrieverBuilder subRetriever,
int rankWindowSize
) {
super(List.of(new RetrieverSource(retrieverBuilder, null)), rankWindowSize);
this.rulesetIds = rulesetIds;
this.matchCriteria = matchCriteria;
// Wraps the sub-retriever in a RetrieverSource with a null search source and delegates to the
// constructor that takes a RetrieverSource.
this(rulesetIds, matchCriteria, new CompoundRetrieverBuilder.RetrieverSource(subRetriever, null), rankWindowSize);
}

// Constructor taking an already-wrapped retriever source; it stores the constructor arguments in fields.
//
// NOTE(review): two signature lines (`public` and `private`), two parameter lists, a duplicate
// `rankWindowSize` parameter and a super(...) call next to plain field assignments indicate that the
// removed and added sides of a diff hunk are merged here. Presumably the `private` signature with
// (rulesetIds, matchCriteria, subRetriever, rankWindowSize) is the current one -- confirm against
// the repository before editing.
public QueryRuleRetrieverBuilder(
private QueryRuleRetrieverBuilder(
List<String> rulesetIds,
Map<String, Object> matchCriteria,
List<RetrieverSource> retrieverSource,
int rankWindowSize,
String retrieverName
CompoundRetrieverBuilder.RetrieverSource subRetriever,
int rankWindowSize
) {
super(retrieverSource, rankWindowSize);
this.subRetriever = subRetriever;
this.rulesetIds = rulesetIds;
this.matchCriteria = matchCriteria;
this.retrieverName = retrieverName;
this.rankWindowSize = rankWindowSize;
}

@Override
Expand All @@ -117,20 +128,98 @@ public String getName() {
}

/**
 * Reports whether this retriever still has to be treated as a compound retriever.
 *
 * @return {@code true} as long as the {@code executed} flag is unset, {@code false} once it has been set
 */
@Override
public boolean isCompound() {
    if (executed) {
        return false;
    }
    return true;
}

/**
 * Rewrites this retriever one step at a time until the sub-retriever search can be executed.
 * <p>
 * Each call performs at most one of the following steps and returns as soon as something changed:
 * <ol>
 * <li>rewrite the pre-filters;</li>
 * <li>rewrite the sub-retriever;</li>
 * <li>create the search source for the sub-retriever if none is attached yet, and rewrite it;</li>
 * <li>register the asynchronous search and hand over to a {@link RankDocsRetrieverBuilder}
 * that is fed with the collected rank docs.</li>
 * </ol>
 * Every copy created in steps 1-3 inherits the pre-filters and the retriever name of this
 * instance, so that they are still present when step 4 is finally reached.
 *
 * @param ctx the rewrite context; it must provide a point-in-time builder
 * @return this instance if the search was already registered, otherwise a rewritten copy or the
 *         rank-docs retriever that wraps the pending results
 * @throws IOException propagated from the nested rewrite calls
 * @throws IllegalStateException if {@code ctx} carries no point-in-time builder
 */
@Override
public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
    if (ctx.getPointInTimeBuilder() == null) {
        throw new IllegalStateException("PIT is required");
    }

    if (executed) {
        return this;
    }

    // Step 1: rewrite prefilters
    var newPreFilters = rewritePreFilters(ctx);
    if (newPreFilters != preFilterQueryBuilders) {
        var ret = withInheritedState(new QueryRuleRetrieverBuilder(rulesetIds, matchCriteria, subRetriever, rankWindowSize));
        ret.preFilterQueryBuilders = newPreFilters;
        return ret;
    }

    // Step 2: rewrite the sub-retriever; a changed retriever starts over without a search source
    var newRetriever = subRetriever.retriever().rewrite(ctx);
    if (newRetriever != subRetriever.retriever()) {
        return withInheritedState(new QueryRuleRetrieverBuilder(rulesetIds, matchCriteria, newRetriever, rankWindowSize));
    }

    // Step 3: build the search source on the first pass (none attached yet), then rewrite it
    var newSource = subRetriever.source() != null
        ? subRetriever.source()
        : createSearchSourceBuilder(ctx.getPointInTimeBuilder(), newRetriever);
    var rewrittenSource = newSource.rewrite(ctx);
    if (rewrittenSource != subRetriever.source()) {
        return withInheritedState(
            new QueryRuleRetrieverBuilder(
                rulesetIds,
                matchCriteria,
                new CompoundRetrieverBuilder.RetrieverSource(newRetriever, rewrittenSource),
                rankWindowSize
            )
        );
    }

    // Step 4: execute the search; the result is published through `results` once the action completes
    final SetOnce<RankDoc[]> results = new SetOnce<>();
    final SearchRequest searchRequest = new SearchRequest().source(subRetriever.source());
    // The can match phase can reorder shards, so we disable it to ensure the stable ordering
    searchRequest.setPreFilterShardSize(Integer.MAX_VALUE);
    ctx.registerAsyncAction((client, listener) -> {
        client.execute(TransportSearchAction.TYPE, searchRequest, new ActionListener<>() {
            @Override
            public void onResponse(SearchResponse resp) {
                var rankDocs = getRankDocs(resp);
                results.set(rankDocs);
                listener.onResponse(null);
            }

            @Override
            public void onFailure(Exception e) {
                listener.onFailure(e);
            }
        });
    });

    // Mark this instance as done so that later rewrite rounds return it unchanged
    executed = true;
    return new RankDocsRetrieverBuilder(rankWindowSize, List.of(this), results::get, newPreFilters);
}

/**
 * Copies the state that the constructors do not take (pre-filters and retriever name) from this
 * instance onto a freshly created copy.
 *
 * @param copy the builder that replaces this instance in the next rewrite round
 * @return {@code copy}, for chaining
 */
private QueryRuleRetrieverBuilder withInheritedState(QueryRuleRetrieverBuilder copy) {
    copy.preFilterQueryBuilders = preFilterQueryBuilders;
    copy.retrieverName = retrieverName;
    return copy;
}

/**
 * Always fails: this method is not expected to be invoked on this retriever.
 *
 * @throws IllegalStateException on every call
 */
@Override
public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
    final String reason = "Should not be called, missing a rewrite?";
    throw new IllegalStateException(reason);
}

/**
 * Builds the search source used to run the given retriever: it is bound to the given point in time,
 * does not track total hits, is sized to {@code rankWindowSize}, has its query wrapped in a
 * {@code RuleQueryBuilder} and ends its sort with the shard-doc field.
 *
 * @param pit the point-in-time builder to attach to the source
 * @param retrieverBuilder the retriever whose query and sort are extracted into the source
 * @return the populated search source
 */
protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
.trackTotalHits(false)
// presumably disables loading of stored fields -- confirm against StoredFieldsContext
.storedFields(new StoredFieldsContext(false))
.size(rankWindowSize);
retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true);
// TODO: ensure that the inner sort is by relevance and throw an error otherwise.

// Wrap the extracted query with the rule query, unless there is no query or it is already a rule query
QueryBuilder query = sourceBuilder.query();
if (query != null && query instanceof RuleQueryBuilder == false) {
QueryBuilder ruleQuery = new RuleQueryBuilder(query, matchCriteria, rulesetIds);
sourceBuilder.query(ruleQuery);
}

// NOTE(review): addSort is not defined in the visible source, and the statements right below also
// build the sort list. This call may be the removed side of a diff hunk whose +/- markers were
// lost -- confirm that the sort is not applied twice.
addSort(sourceBuilder);
// Record the shard id in the sort result
List<SortBuilder<?>> sortBuilders = sourceBuilder.sorts() != null ? new ArrayList<>(sourceBuilder.sorts()) : new ArrayList<>();
// Fall back to sorting by score when the retriever did not define any sort
if (sortBuilders.isEmpty()) {
sortBuilders.add(new ScoreSortBuilder());
}
// The shard-doc sort goes last; getRankDocs decodes the last sort value of every hit
sortBuilders.add(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME));
sourceBuilder.sort(sortBuilders);

return sourceBuilder;
}
Expand All @@ -144,36 +233,41 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept
}

// NOTE(review): from here to the closing brace of explainQuery() the removed and added sides of a
// diff are interleaved without +/- markers, so the braces do not balance and several method bodies
// run into each other. Presumably clone(), combineInnerRetrieverResults() and the RankDocsQueryBuilder
// return in explainQuery() are the removed side, and topDocsQuery() is the added side -- confirm
// against the repository before editing.
@Override
protected QueryRuleRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
return new QueryRuleRetrieverBuilder(rulesetIds, matchCriteria, newChildRetrievers, rankWindowSize, retrieverName);
}

@Override
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) {
assert rankResults.size() == 1;
ScoreDoc[] scoreDocs = rankResults.getFirst();
RankDoc[] rankDocs = new RankDoc[scoreDocs.length];
for (int i = 0; i < scoreDocs.length; i++) {
ScoreDoc scoreDoc = scoreDocs[i];
rankDocs[i] = new RankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex);
// topDocsQuery(): starts from the query of the sub-retriever's search source
public QueryBuilder topDocsQuery() {
QueryBuilder topDocsQuery = subRetriever.source().query();
// Without pre-filters the source query is tagged with the retriever name and returned directly
if (preFilterQueryBuilders.isEmpty()) {
topDocsQuery.queryName(this.retrieverName);
return topDocsQuery;
}
return rankDocs;
}

@Override
public QueryBuilder explainQuery() {
// the original matching set of the QueryRuleRetriever retriever is specified by its nested retriever
return new RankDocsQueryBuilder(rankDocs, new QueryBuilder[] { innerRetrievers.getFirst().retriever().explainQuery() }, true);
// NOTE(review): the next two statements build `ret`, a bool query that filters on the source query
// and on every pre-filter, but the return statement after them hands back the plain source query,
// so `ret` is discarded. If these lines are the pre-filter branch of topDocsQuery(), `return ret;`
// looks like the intended result -- confirm.
var ret = new BoolQueryBuilder().filter(topDocsQuery).queryName(this.retrieverName);
preFilterQueryBuilders.stream().forEach(ret::filter);
return subRetriever.source().query();
}

// Equality check on the retriever-specific state; the argument is cast to QueryRuleRetrieverBuilder
// without a type check.
//
// NOTE(review): two consecutive return statements are shown, which makes the second one unreachable.
// They look like the removed and the added side of a diff hunk whose +/- markers were lost;
// presumably the second one (ruleset ids, match criteria and sub-retriever) is current -- confirm.
@Override
public boolean doEquals(Object o) {
QueryRuleRetrieverBuilder that = (QueryRuleRetrieverBuilder) o;
return super.doEquals(o) && Objects.equals(rulesetIds, that.rulesetIds) && Objects.equals(matchCriteria, that.matchCriteria);
return Objects.equals(rulesetIds, that.rulesetIds)
&& Objects.equals(matchCriteria, that.matchCriteria)
&& subRetriever.equals(that.subRetriever);
}

// Hash code over the retriever-specific state.
//
// NOTE(review): two consecutive return statements are shown, which makes the second one unreachable.
// They look like the removed and the added side of a diff hunk whose +/- markers were lost;
// presumably the second one, which hashes the same three fields the second return in doEquals
// compares, is current -- confirm.
@Override
public int doHashCode() {
return Objects.hash(super.doHashCode(), rulesetIds, matchCriteria);
return Objects.hash(subRetriever, rulesetIds, matchCriteria);
}

/**
 * Converts the hits of a search response into rank docs, keeping the order of the hits.
 * <p>
 * The last raw sort value of every hit is decoded with {@link ShardDocSortField} into the doc id
 * and the shard request index; the rank of each entry is its 1-based position in the response.
 *
 * @param searchResponse the response whose hits are converted
 * @return one rank doc per hit, in response order
 */
private RankDoc[] getRankDocs(SearchResponse searchResponse) {
    final int hitCount = searchResponse.getHits().getHits().length;
    final RankDoc[] rankDocs = new RankDoc[hitCount];
    int position = 0;
    while (position < hitCount) {
        var searchHit = searchResponse.getHits().getAt(position);
        var rawSortValues = searchHit.getRawSortValues();
        // the shard-doc value is the last entry of the sort values
        long shardDocValue = (long) rawSortValues[rawSortValues.length - 1];
        int docId = ShardDocSortField.decodeDoc(shardDocValue);
        int shardIndex = ShardDocSortField.decodeShardRequestIndex(shardDocValue);
        RankDoc rankDoc = new RankDoc(docId, searchHit.getScore(), shardIndex);
        rankDoc.rank = position + 1;
        rankDocs[position] = rankDoc;
        position++;
    }
    return rankDocs;
}
}

0 comments on commit d80feee

Please sign in to comment.