diff --git a/modules/lang-expression/src/main/java/org/elasticsearch/script/expression/ExpressionScriptEngine.java b/modules/lang-expression/src/main/java/org/elasticsearch/script/expression/ExpressionScriptEngine.java index b50eb788c6f57..1cde9c258b4f1 100644 --- a/modules/lang-expression/src/main/java/org/elasticsearch/script/expression/ExpressionScriptEngine.java +++ b/modules/lang-expression/src/main/java/org/elasticsearch/script/expression/ExpressionScriptEngine.java @@ -23,8 +23,10 @@ import org.apache.lucene.expressions.SimpleBindings; import org.apache.lucene.expressions.js.JavascriptCompiler; import org.apache.lucene.expressions.js.VariableContext; +import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.queries.function.valuesource.DoubleConstValueSource; +import org.apache.lucene.search.Scorer; import org.apache.lucene.search.SortField; import org.elasticsearch.SpecialPermission; import org.elasticsearch.common.Nullable; @@ -39,12 +41,14 @@ import org.elasticsearch.script.ClassPermission; import org.elasticsearch.script.ExecutableScript; import org.elasticsearch.script.FilterScript; +import org.elasticsearch.script.ScoreScript; import org.elasticsearch.script.ScriptContext; import org.elasticsearch.script.ScriptEngine; import org.elasticsearch.script.ScriptException; import org.elasticsearch.script.SearchScript; import org.elasticsearch.search.lookup.SearchLookup; +import java.io.IOException; import java.security.AccessControlContext; import java.security.AccessController; import java.security.PrivilegedAction; @@ -111,6 +115,9 @@ protected Class loadClass(String name, boolean resolve) throws ClassNotFoundE } else if (context.instanceClazz.equals(FilterScript.class)) { FilterScript.Factory factory = (p, lookup) -> newFilterScript(expr, lookup, p); return context.factoryClazz.cast(factory); + } else if (context.instanceClazz.equals(ScoreScript.class)) { + ScoreScript.Factory factory = (p, lookup) -> newScoreScript(expr, lookup, p); + return context.factoryClazz.cast(factory); } throw new IllegalArgumentException("expression engine does not know how to handle script context [" + context.name + "]"); } @@ -260,6 +267,42 @@ public void setDocument(int docid) { }; }; } + + private ScoreScript.LeafFactory newScoreScript(Expression expr, SearchLookup lookup, @Nullable Map vars) { + SearchScript.LeafFactory searchLeafFactory = newSearchScript(expr, lookup, vars); + return new ScoreScript.LeafFactory() { + @Override + public boolean needs_score() { + return searchLeafFactory.needs_score(); + } + + @Override + public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { + SearchScript script = searchLeafFactory.newInstance(ctx); + return new ScoreScript(vars, lookup, ctx) { + @Override + public double execute() { + return script.runAsDouble(); + } + + @Override + public void setDocument(int docid) { + script.setDocument(docid); + } + + @Override + public void setScorer(Scorer scorer) { + script.setScorer(scorer); + } + + @Override + public double get_score() { + return script.getScore(); + } + }; + } + }; + } /** * converts a ParseException at compile-time or link-time to a ScriptException diff --git a/plugins/examples/script-expert-scoring/src/main/java/org/elasticsearch/example/expertscript/ExpertScriptPlugin.java b/plugins/examples/script-expert-scoring/src/main/java/org/elasticsearch/example/expertscript/ExpertScriptPlugin.java index b910526ef3d98..cead97696a028 100644 --- a/plugins/examples/script-expert-scoring/src/main/java/org/elasticsearch/example/expertscript/ExpertScriptPlugin.java +++ b/plugins/examples/script-expert-scoring/src/main/java/org/elasticsearch/example/expertscript/ExpertScriptPlugin.java @@ -30,9 +30,9 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.ScriptPlugin; +import org.elasticsearch.script.ScoreScript; import org.elasticsearch.script.ScriptContext; import org.elasticsearch.script.ScriptEngine; -import org.elasticsearch.script.SearchScript; /** * An example script plugin that adds a {@link ScriptEngine} implementing expert scoring. @@ -54,12 +54,12 @@ public String getType() { @Override public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { - if (context.equals(SearchScript.SCRIPT_SCORE_CONTEXT) == false) { + if (context.equals(ScoreScript.CONTEXT) == false) { throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); } // we use the script "source" as the script identifier if ("pure_df".equals(scriptSource)) { - SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() { + ScoreScript.Factory factory = (p, lookup) -> new ScoreScript.LeafFactory() { final String field; final String term; { @@ -74,18 +74,18 @@ public T compile(String scriptName, String scriptSource, ScriptContext co } @Override - public SearchScript newInstance(LeafReaderContext context) throws IOException { + public ScoreScript newInstance(LeafReaderContext context) throws IOException { PostingsEnum postings = context.reader().postings(new Term(field, term)); if (postings == null) { // the field and/or term don't exist in this segment, so always return 0 - return new SearchScript(p, lookup, context) { + return new ScoreScript(p, lookup, context) { @Override - public double runAsDouble() { + public double execute() { return 0.0d; } }; } - return new SearchScript(p, lookup, context) { + return new ScoreScript(p, lookup, context) { int currentDocid = -1; @Override public void setDocument(int docid) { @@ -100,7 +100,7 @@ public void setDocument(int docid) { currentDocid = docid; } @Override - public double runAsDouble() { + public double execute() { if (postings.docID() != currentDocid) { // advance moved past the current doc, so this doc has no occurrences of the term return 0.0d; diff --git a/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java b/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java index bcca4c4a03580..7f8b10349bc7d 100644 --- a/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java +++ b/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java @@ -24,8 +24,8 @@ import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Scorer; import org.elasticsearch.script.ExplainableSearchScript; +import org.elasticsearch.script.ScoreScript; import org.elasticsearch.script.Script; -import org.elasticsearch.script.SearchScript; import java.io.IOException; import java.util.Objects; @@ -58,10 +58,10 @@ public DocIdSetIterator iterator() { private final Script sScript; - private final SearchScript.LeafFactory script; + private final ScoreScript.LeafFactory script; - public ScriptScoreFunction(Script sScript, SearchScript.LeafFactory script) { + public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script) { super(CombineFunction.REPLACE); this.sScript = sScript; this.script = script; @@ -69,7 +69,7 @@ public ScriptScoreFunction(Script sScript, SearchScript.LeafFactory script) { @Override public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOException { - final SearchScript leafScript = script.newInstance(ctx); + final ScoreScript leafScript = script.newInstance(ctx); final CannedScorer scorer = new CannedScorer(); leafScript.setScorer(scorer); return new LeafScoreFunction() { @@ -78,7 +78,7 @@ public double score(int docId, float subQueryScore) throws IOException { leafScript.setDocument(docId); scorer.docid = docId; scorer.score = subQueryScore; - double result = leafScript.runAsDouble(); + double result = leafScript.execute(); return result; } diff --git a/server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreFunctionBuilder.java b/server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreFunctionBuilder.java index ed4c5f5a26952..9592ffe0b1fe5 100644 --- a/server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreFunctionBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreFunctionBuilder.java @@ -28,6 +28,7 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.query.QueryShardContext; import org.elasticsearch.index.query.QueryShardException; +import org.elasticsearch.script.ScoreScript; import org.elasticsearch.script.Script; import org.elasticsearch.script.SearchScript; @@ -92,8 +93,8 @@ protected int doHashCode() { @Override protected ScoreFunction doToFunction(QueryShardContext context) { try { - SearchScript.Factory factory = context.getScriptService().compile(script, SearchScript.SCRIPT_SCORE_CONTEXT); - SearchScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup()); + ScoreScript.Factory factory = context.getScriptService().compile(script, ScoreScript.CONTEXT); + ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup()); return new ScriptScoreFunction(script, searchScript); } catch (Exception e) { throw new QueryShardException(context, "script_score: the script could not be loaded", e); diff --git a/server/src/main/java/org/elasticsearch/script/ScoreScript.java b/server/src/main/java/org/elasticsearch/script/ScoreScript.java new file mode 100644 index 0000000000000..d9e56d5573cae --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/ScoreScript.java @@ -0,0 +1,102 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.script; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.Scorer; +import org.elasticsearch.index.fielddata.ScriptDocValues; +import org.elasticsearch.search.lookup.LeafSearchLookup; +import org.elasticsearch.search.lookup.SearchLookup; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Map; +import java.util.function.DoubleSupplier; + +/** + * A script used for adjusting the score on a per document basis. + */ +public abstract class ScoreScript { + + public static final String[] PARAMETERS = new String[]{}; + + /** The generic runtime parameters for the script. */ + private final Map params; + + /** A leaf lookup for the bound segment this script will operate on. */ + private final LeafSearchLookup leafLookup; + + private DoubleSupplier scoreSupplier = () -> 0.0; + + public ScoreScript(Map params, SearchLookup lookup, LeafReaderContext leafContext) { + this.params = params; + this.leafLookup = lookup.getLeafSearchLookup(leafContext); + } + + public abstract double execute(); + + /** Return the parameters for this script. */ + public Map getParams() { + return params; + } + + /** The doc lookup for the Lucene segment this script was created for. */ + public final Map> getDoc() { + return leafLookup.doc(); + } + + /** Set the current document to run the script on next. */ + public void setDocument(int docid) { + leafLookup.setDocument(docid); + } + + public void setScorer(Scorer scorer) { + this.scoreSupplier = () -> { + try { + return scorer.score(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }; + } + + public double get_score() { + return scoreSupplier.getAsDouble(); + } + + /** A factory to construct {@link ScoreScript} instances. */ + public interface LeafFactory { + + /** + * Return {@code true} if the script needs {@code _score} calculated, or {@code false} otherwise. + */ + boolean needs_score(); + + ScoreScript newInstance(LeafReaderContext ctx) throws IOException; + } + + /** A factory to construct stateful {@link ScoreScript} factories for a specific index. */ + public interface Factory { + + ScoreScript.LeafFactory newFactory(Map params, SearchLookup lookup); + + } + + public static final ScriptContext CONTEXT = new ScriptContext<>("score", ScoreScript.Factory.class); +} diff --git a/server/src/main/java/org/elasticsearch/script/ScriptModule.java b/server/src/main/java/org/elasticsearch/script/ScriptModule.java index 583421be8e581..7074d3ad9fe44 100644 --- a/server/src/main/java/org/elasticsearch/script/ScriptModule.java +++ b/server/src/main/java/org/elasticsearch/script/ScriptModule.java @@ -42,7 +42,7 @@ public class ScriptModule { CORE_CONTEXTS = Stream.of( SearchScript.CONTEXT, SearchScript.AGGS_CONTEXT, - SearchScript.SCRIPT_SCORE_CONTEXT, + ScoreScript.CONTEXT, SearchScript.SCRIPT_SORT_CONTEXT, SearchScript.TERMS_SET_QUERY_CONTEXT, ExecutableScript.CONTEXT, diff --git a/server/src/main/java/org/elasticsearch/script/SearchScript.java b/server/src/main/java/org/elasticsearch/script/SearchScript.java index e5762adb1bbe9..43ea020aa6e24 100644 --- a/server/src/main/java/org/elasticsearch/script/SearchScript.java +++ b/server/src/main/java/org/elasticsearch/script/SearchScript.java @@ -162,8 +162,6 @@ public interface Factory { public static final ScriptContext AGGS_CONTEXT = new ScriptContext<>("aggs", Factory.class); // Can return a double. (For ScriptSortType#NUMBER only, for ScriptSortType#STRING normal CONTEXT should be used) public static final ScriptContext SCRIPT_SORT_CONTEXT = new ScriptContext<>("sort", Factory.class); - // Can return a float - public static final ScriptContext SCRIPT_SCORE_CONTEXT = new ScriptContext<>("score", Factory.class); // Can return a long public static final ScriptContext TERMS_SET_QUERY_CONTEXT = new ScriptContext<>("terms_set", Factory.class); } diff --git a/server/src/test/java/org/elasticsearch/search/functionscore/ExplainableScriptIT.java b/server/src/test/java/org/elasticsearch/search/functionscore/ExplainableScriptIT.java index 842748107d1d1..6657ad9823ffe 100644 --- a/server/src/test/java/org/elasticsearch/search/functionscore/ExplainableScriptIT.java +++ b/server/src/test/java/org/elasticsearch/search/functionscore/ExplainableScriptIT.java @@ -30,14 +30,14 @@ import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.ScriptPlugin; import org.elasticsearch.script.ExplainableSearchScript; +import org.elasticsearch.script.ScoreScript; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptContext; import org.elasticsearch.script.ScriptEngine; import org.elasticsearch.script.ScriptType; -import org.elasticsearch.script.SearchScript; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; -import org.elasticsearch.search.lookup.LeafDocLookup; +import org.elasticsearch.search.lookup.SearchLookup; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.ESIntegTestCase.ClusterScope; import org.elasticsearch.test.ESIntegTestCase.Scope; @@ -76,16 +76,17 @@ public String getType() { @Override public T compile(String scriptName, String scriptSource, ScriptContext context, Map params) { assert scriptSource.equals("explainable_script"); - assert context == SearchScript.SCRIPT_SCORE_CONTEXT; - SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() { - @Override - public SearchScript newInstance(LeafReaderContext context) throws IOException { - return new MyScript(lookup.doc().getLeafDocLookup(context)); - } + assert context == ScoreScript.CONTEXT; + ScoreScript.Factory factory = (params1, lookup) -> new ScoreScript.LeafFactory() { @Override public boolean needs_score() { return false; } + + @Override + public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { + return new MyScript(params1, lookup, ctx); + } }; return context.factoryClazz.cast(factory); } @@ -93,28 +94,21 @@ public boolean needs_score() { } } - static class MyScript extends SearchScript implements ExplainableSearchScript { - LeafDocLookup docLookup; + static class MyScript extends ScoreScript implements ExplainableSearchScript { - MyScript(LeafDocLookup docLookup) { - super(null, null, null); - this.docLookup = docLookup; + MyScript(Map params, SearchLookup lookup, LeafReaderContext leafContext) { + super(params, lookup, leafContext); } - - @Override - public void setDocument(int doc) { - docLookup.setDocument(doc); - } - + @Override public Explanation explain(Explanation subQueryScore) throws IOException { Explanation scoreExp = Explanation.match(subQueryScore.getValue(), "_score: ", subQueryScore); - return Explanation.match((float) (runAsDouble()), "This script returned " + runAsDouble(), scoreExp); + return Explanation.match((float) (execute()), "This script returned " + execute(), scoreExp); } @Override - public double runAsDouble() { - return ((Number) ((ScriptDocValues) docLookup.get("number_field")).getValues().get(0)).doubleValue(); + public double execute() { + return ((Number) ((ScriptDocValues) getDoc().get("number_field")).getValues().get(0)).doubleValue(); } } diff --git a/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java b/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java index 00303b344b92a..b86cb9ff29352 100644 --- a/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java +++ b/test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java @@ -25,7 +25,6 @@ import org.elasticsearch.index.similarity.ScriptedSimilarity.Field; import org.elasticsearch.index.similarity.ScriptedSimilarity.Query; import org.elasticsearch.index.similarity.ScriptedSimilarity.Term; -import org.elasticsearch.index.similarity.SimilarityService; import org.elasticsearch.search.aggregations.pipeline.movfn.MovingFunctionScript; import org.elasticsearch.search.aggregations.pipeline.movfn.MovingFunctions; import org.elasticsearch.search.lookup.LeafSearchLookup; @@ -36,7 +35,6 @@ import java.util.HashMap; import java.util.Map; import java.util.function.Function; -import java.util.function.Predicate; import static java.util.Collections.emptyMap; @@ -114,6 +112,9 @@ public String execute() { } else if (context.instanceClazz.equals(MovingFunctionScript.class)) { MovingFunctionScript.Factory factory = mockCompiled::createMovingFunctionScript; return context.factoryClazz.cast(factory); + } else if (context.instanceClazz.equals(ScoreScript.class)) { + ScoreScript.Factory factory = new MockScoreScript(script); + return context.factoryClazz.cast(factory); } throw new IllegalArgumentException("mock script engine does not know how to handle context [" + context.name + "]"); } @@ -342,5 +343,45 @@ public double execute(Map params, double[] values) { return MovingFunctions.unweightedAvg(values); } } + + public class MockScoreScript implements ScoreScript.Factory { + + private final Function, Object> scripts; + + MockScoreScript(Function, Object> scripts) { + this.scripts = scripts; + } + + @Override + public ScoreScript.LeafFactory newFactory(Map params, SearchLookup lookup) { + return new ScoreScript.LeafFactory() { + @Override + public boolean needs_score() { + return true; + } + + @Override + public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { + Scorer[] scorerHolder = new Scorer[1]; + return new ScoreScript(params, lookup, ctx) { + @Override + public double execute() { + Map vars = new HashMap<>(getParams()); + vars.put("doc", getDoc()); + if (scorerHolder[0] != null) { + vars.put("_score", new ScoreAccessor(scorerHolder[0])); + } + return ((Number) scripts.apply(vars)).doubleValue(); + } + + @Override + public void setScorer(Scorer scorer) { + scorerHolder[0] = scorer; + } + }; + } + }; + } + } }