Skip to content

Commit

Permalink
Add sparse vector query (#3140)
Browse files Browse the repository at this point in the history
  • Loading branch information
Philippus authored Aug 22, 2024
1 parent 377ae6e commit 1155e58
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import com.sksamuel.elastic4s.requests.searches.queries.compound.BoolQuery
import com.sksamuel.elastic4s.requests.searches.queries.funcscorer.FunctionScoreQuery
import com.sksamuel.elastic4s.requests.searches.queries.geo.{GeoBoundingBoxQuery, GeoDistanceQuery, GeoPolygonQuery, GeoShapeQuery, Shape}
import com.sksamuel.elastic4s.requests.searches.queries.matches.{MatchAllQuery, MatchBoolPrefixQuery, MatchNoneQuery, MatchPhrasePrefixQuery, MatchPhraseQuery, MatchQuery, MultiMatchQuery}
import com.sksamuel.elastic4s.requests.searches.queries.{ArtificialDocument, BoostingQuery, CombinedFieldsQuery, ConstantScore, DisMaxQuery, DistanceFeatureQuery, ExistsQuery, FuzzyQuery, HasChildQuery, HasParentQuery, IdQuery, IntervalsQuery, IntervalsRule, MoreLikeThisItem, MoreLikeThisQuery, MultiTermQuery, NestedQuery, PercolateQuery, PinnedQuery, PrefixQuery, Query, QueryStringQuery, RangeQuery, RankFeatureQuery, RawQuery, RegexQuery, ScriptQuery, ScriptScoreQuery, SimpleStringQuery}
import com.sksamuel.elastic4s.requests.searches.queries.{ArtificialDocument, BoostingQuery, CombinedFieldsQuery, ConstantScore, DisMaxQuery, DistanceFeatureQuery, ExistsQuery, FuzzyQuery, HasChildQuery, HasParentQuery, IdQuery, IntervalsQuery, IntervalsRule, MoreLikeThisItem, MoreLikeThisQuery, MultiTermQuery, NestedQuery, PercolateQuery, PinnedQuery, PrefixQuery, Query, QueryStringQuery, RangeQuery, RankFeatureQuery, RawQuery, RegexQuery, ScriptQuery, ScriptScoreQuery, SimpleStringQuery, SparseVectorQuery}
import com.sksamuel.elastic4s.requests.searches.span.{SpanContainingQuery, SpanFieldMaskingQuery, SpanFirstQuery, SpanMultiTermQuery, SpanNearQuery, SpanNotQuery, SpanOrQuery, SpanQuery, SpanTermQuery, SpanWithinQuery}
import com.sksamuel.elastic4s.requests.searches.term.{TermQuery, TermsLookupQuery, TermsQuery, TermsSetQuery, WildcardQuery}
import com.sksamuel.elastic4s.requests.searches.{GeoPoint, ScoreMode, TermsLookup, span, term}
Expand Down Expand Up @@ -269,4 +269,10 @@ trait QueryApi {
// Shortcut for a boolean query made up only of "not" clauses.
// Varargs form: starts from an empty BoolQuery() and forwards every given query to its `not`.
def not(queries: Query*): BoolQuery = BoolQuery().not(queries: _*)
// Iterable form of the "not" shortcut: starts from an empty BoolQuery() and passes the collection to its `not`.
def not(queries: Iterable[Query]): BoolQuery = BoolQuery().not(queries)

/** Creates a [[SparseVectorQuery]] on `field` that carries an inference id and a query text.
  *
  * Both values are stored as `Some(...)`; the query vector keeps its default (empty map) and
  * all other optional settings stay unset.
  *
  * @param field       name stored as the query's `field`
  * @param inferenceId value stored as the query's `inferenceId`
  * @param query       value stored as the query's `query`
  */
def sparseVectorQuery(field: String, inferenceId: String, query: String): SparseVectorQuery =
  SparseVectorQuery(field).copy(inferenceId = Some(inferenceId), query = Some(query))

/** Creates a [[SparseVectorQuery]] on `field` from an already computed token-to-weight map.
  *
  * Inference id and query text are left unset.
  *
  * @param field       name stored as the query's `field`
  * @param queryVector map stored as the query's `queryVector`
  */
def sparseVectorQuery(field: String, queryVector: Map[String, Double]): SparseVectorQuery =
  SparseVectorQuery(field).copy(queryVector = queryVector)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package com.sksamuel.elastic4s.requests.searches.queries

import com.sksamuel.elastic4s.JsonSugar
import com.sksamuel.elastic4s.api.QueryApi
import com.sksamuel.elastic4s.handlers.searches.queries.SparseVectorQueryBuilderFn
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers

/** Checks the JSON that [[SparseVectorQueryBuilderFn]] emits for the different ways a
  * [[SparseVectorQuery]] can be populated: field only, inference id + query text,
  * precomputed vector, pruning settings, and boost / query name.
  */
class SparseVectorQueryBuilderFnTest extends AnyFunSuite with QueryApi with Matchers with JsonSugar {

  // Runs the builder under test and returns the produced JSON text.
  private def sparseVectorJson(q: SparseVectorQuery): String =
    SparseVectorQueryBuilderFn(q).string

  test("Should correctly build minimal sparse vector query") {
    val fieldOnly = SparseVectorQuery("testfield")

    sparseVectorJson(fieldOnly) shouldBe """{"sparse_vector":{"field":"testfield"}}"""
  }

  test("Should correctly build sparse vector query using an nlp model") {
    val viaInference =
      sparseVectorQuery("ml.tokens", "the inference ID to produce the token weights", "the query string")

    val expected = """{"sparse_vector": {
| "field":"ml.tokens",
| "inference_id":"the inference ID to produce the token weights",
| "query":"the query string"
|}}""".stripMargin.replace("\n", "")

    sparseVectorJson(viaInference) should matchJson(expected)
  }

  test("Should correctly build sparse vector query using precomputed vectors") {
    val weights = Map("token1" -> 0.5D, "token2" -> 0.3D, "token3" -> 0.2D)
    val viaVector = sparseVectorQuery("ml.tokens", weights)

    val expected = """{"sparse_vector": {
| "field": "ml.tokens",
| "query_vector": { "token1": 0.5, "token2": 0.3, "token3": 0.2 }
|}}""".stripMargin.replace("\n", "")

    sparseVectorJson(viaVector) should matchJson(expected)
  }

  test("Should correctly build sparse vector query with pruning configuration") {
    val pruning = PruningConfig(
      tokensFreqRatioThreshold = Some(5),
      tokensWeighThreshold = Some(0.4F),
      onlyScorePrunedTokens = Some(false)
    )
    val pruned = sparseVectorQuery("ml.tokens", "my-elser-model", "How is the weather in Jamaica?")
      .prune(true)
      .pruningConfig(pruning)

    // The expected threshold is the Float 0.4F as it appears once widened to a double.
    val expected = """{"sparse_vector":{
| "field": "ml.tokens",
| "inference_id": "my-elser-model",
| "query":"How is the weather in Jamaica?",
| "prune": true,
| "pruning_config": {
| "tokens_freq_ratio_threshold": 5,
| "tokens_weight_threshold": 0.4000000059604645,
| "only_score_pruned_tokens": false
| }
|}}""".stripMargin.replace("\n", "")

    sparseVectorJson(pruned) should matchJson(expected)
  }

  test("Supports boost and queryName") {
    val decorated = SparseVectorQuery("testfield").boost(1.0D).queryName("abc")

    val expected = """{"sparse_vector": {
| "field": "testfield",
| "boost": 1.0,
| "_name": "abc"
|}}""".stripMargin.replace("\n", "")

    sparseVectorJson(decorated) should matchJson(expected)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.sksamuel.elastic4s.requests.searches.queries

import com.sksamuel.elastic4s.ext.OptionImplicits.RichOptionImplicits

/** Optional pruning settings attached to a [[SparseVectorQuery]].
  *
  * Every field defaults to `None`. SparseVectorQueryBuilderFn writes the set fields under
  * `pruning_config` as `tokens_freq_ratio_threshold`, `tokens_weight_threshold` and
  * `only_score_pruned_tokens`, and skips the object entirely when all three are `None`.
  *
  * NOTE(review): `tokensWeighThreshold` looks like a misspelling of "Weight" (the emitted key is
  * `tokens_weight_threshold`). It is left as is because callers pass it as a named argument.
  */
case class PruningConfig(tokensFreqRatioThreshold: Option[Int] = None,
tokensWeighThreshold: Option[Float] = None,
onlyScorePrunedTokens: Option[Boolean] = None)

/** Immutable description of a `sparse_vector` query.
  *
  * Only `field` is mandatory. The remaining members are optional and are emitted by
  * SparseVectorQueryBuilderFn only when set (or, for `queryVector`, when non-empty).
  * Each fluent method below returns a modified copy and leaves this instance untouched.
  *
  * @param field         target field name
  * @param inferenceId   optional inference id
  * @param query         optional query text
  * @param queryVector   token-to-weight map; empty by default
  * @param boost         optional boost
  * @param queryName     optional query name
  * @param prune         optional pruning flag
  * @param pruningConfig optional pruning settings
  */
case class SparseVectorQuery(
    field: String,
    inferenceId: Option[String] = None,
    query: Option[String] = None,
    queryVector: Map[String, Double] = Map.empty[String, Double],
    boost: Option[Double] = None,
    queryName: Option[String] = None,
    prune: Option[Boolean] = None,
    pruningConfig: Option[PruningConfig] = None
) extends Query {

  /** Copy of this query with the boost set. */
  def boost(boost: Double): SparseVectorQuery = {
    val wrappedBoost = boost.some
    copy(boost = wrappedBoost)
  }

  /** Copy of this query with the query name set. */
  def queryName(queryName: String): SparseVectorQuery = {
    val wrappedName = queryName.some
    copy(queryName = wrappedName)
  }

  /** Copy of this query with the pruning flag set. */
  def prune(prune: Boolean): SparseVectorQuery = {
    val wrappedFlag = prune.some
    copy(prune = wrappedFlag)
  }

  /** Copy of this query with the pruning settings set. */
  def pruningConfig(pruningConfig: PruningConfig): SparseVectorQuery = {
    val wrappedConfig = pruningConfig.some
    copy(pruningConfig = wrappedConfig)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import com.sksamuel.elastic4s.requests.searches.queries.compound.BoolQuery
import com.sksamuel.elastic4s.requests.searches.queries.funcscorer.FunctionScoreQuery
import com.sksamuel.elastic4s.requests.searches.queries.geo.{GeoBoundingBoxQuery, GeoDistanceQuery, GeoPolygonQuery, GeoShapeQuery}
import com.sksamuel.elastic4s.requests.searches.queries.matches.{MatchAllQuery, MatchBoolPrefixQuery, MatchNoneQuery, MatchPhrasePrefixQuery, MatchPhraseQuery, MatchQuery, MultiMatchQuery}
import com.sksamuel.elastic4s.requests.searches.queries.{BoostingQuery, CombinedFieldsQuery, ConstantScore, CustomQuery, DisMaxQuery, DistanceFeatureQuery, ExistsQuery, FuzzyQuery, HasChildQuery, HasParentQuery, IdQuery, IntervalsQuery, MoreLikeThisQuery, NestedQuery, NoopQuery, ParentIdQuery, PercolateQuery, PinnedQuery, PrefixQuery, Query, QueryStringQuery, RangeQuery, RankFeatureQuery, RawQuery, RegexQuery, ScriptQuery, ScriptScoreQuery, SimpleStringQuery}
import com.sksamuel.elastic4s.requests.searches.queries.{BoostingQuery, CombinedFieldsQuery, ConstantScore, CustomQuery, DisMaxQuery, DistanceFeatureQuery, ExistsQuery, FuzzyQuery, HasChildQuery, HasParentQuery, IdQuery, IntervalsQuery, MoreLikeThisQuery, NestedQuery, NoopQuery, ParentIdQuery, PercolateQuery, PinnedQuery, PrefixQuery, Query, QueryStringQuery, RangeQuery, RankFeatureQuery, RawQuery, RegexQuery, ScriptQuery, ScriptScoreQuery, SimpleStringQuery, SparseVectorQuery}
import com.sksamuel.elastic4s.requests.searches.span.{SpanContainingQuery, SpanFieldMaskingQuery, SpanFirstQuery, SpanMultiTermQuery, SpanNearQuery, SpanNotQuery, SpanOrQuery, SpanTermQuery, SpanWithinQuery}
import com.sksamuel.elastic4s.requests.searches.term.{TermQuery, TermsLookupQuery, TermsQuery, TermsSetQuery, WildcardQuery}

Expand Down Expand Up @@ -65,6 +65,7 @@ object QueryBuilderFn {
case s: SpanOrQuery => SpanOrQueryBodyFn(s)
case s: SpanTermQuery => SpanTermQueryBodyFn(s)
case s: SpanWithinQuery => SpanWithinQueryBodyFn(s)
case s: SparseVectorQuery => SparseVectorQueryBuilderFn(s)
case t: TermQuery => TermQueryBodyFn(t)
case t: TermsQuery[_] => TermsQueryBodyFn(t)
case t: TermsLookupQuery => TermsLookupQueryBodyFn(t)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package com.sksamuel.elastic4s.handlers.searches.queries

import com.sksamuel.elastic4s.json.{XContentBuilder, XContentFactory}
import com.sksamuel.elastic4s.requests.searches.queries.{PruningConfig, SparseVectorQuery}

/** Renders a [[SparseVectorQuery]] as a JSON object nested under the `sparse_vector` key. */
object SparseVectorQueryBuilderFn {

  /** Writes `q` into a fresh JSON builder.
    *
    * `field` is always written. Optional members are written only when defined, and
    * `query_vector` only when the map is non-empty. Keys are emitted in this order:
    * `field`, `inference_id`, `query`, `query_vector`, `boost`, `_name`, `prune`, `pruning_config`.
    *
    * @param q the query to render
    * @return the builder holding the rendered query
    */
  def apply(q: SparseVectorQuery): XContentBuilder = {
    val json = XContentFactory.jsonBuilder()
    json.startObject("sparse_vector")
    json.field("field", q.field)

    for (id <- q.inferenceId) json.field("inference_id", id)
    for (text <- q.query) json.field("query", text)

    // An empty vector is left out entirely instead of being written as an empty object.
    if (q.queryVector.nonEmpty) {
      json.startObject("query_vector")
      q.queryVector.foreach { case (token, weight) => json.field(token, weight) }
      json.endObject()
    }

    for (factor <- q.boost) json.field("boost", factor)
    for (name <- q.queryName) json.field("_name", name)
    for (flag <- q.prune) json.field("prune", flag)
    for (settings <- q.pruningConfig) writePruningConfig(json, settings)

    json.endObject()
    json
  }

  // Adds the `pruning_config` object, unless none of its three settings is defined.
  private def writePruningConfig(json: XContentBuilder, pc: PruningConfig): Unit = {
    val nothingSet =
      pc.tokensFreqRatioThreshold.isEmpty && pc.tokensWeighThreshold.isEmpty && pc.onlyScorePrunedTokens.isEmpty

    if (!nothingSet) {
      json.startObject("pruning_config")
      for (ratio <- pc.tokensFreqRatioThreshold) json.field("tokens_freq_ratio_threshold", ratio)
      for (weight <- pc.tokensWeighThreshold) json.field("tokens_weight_threshold", weight)
      for (onlyPruned <- pc.onlyScorePrunedTokens) json.field("only_score_pruned_tokens", onlyPruned)
      json.endObject()
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -971,4 +971,9 @@ class SearchDslTest extends AnyFlatSpec with MockitoSugar with JsonSugar with On
val req = search("music").matchAllQuery() docValues ("field1", "field2")
req.request.entity.get.get should matchJsonResource("/json/search/search_doc_values.json")
}

// A search request wrapping a precomputed-vector sparse query must serialise the query under "query".
it should "generate json for sparse vector query" in {
  val vectorQuery = sparseVectorQuery("test", Map("a" -> 0.3D))
  val req = search("index").query(vectorQuery)

  val expectedJson = """{"query":{"sparse_vector":{"field":"test", "query_vector":{"a":0.3}}}}"""
  req.request.entity.get.get should matchJson(expectedJson)
}
}

0 comments on commit 1155e58

Please sign in to comment.