// Review summary: this patch adds support for the `sparse_vector` query.
//
// Files touched:
//  - QueryApi.scala: imports SparseVectorQuery and adds two `sparseVectorQuery` overloads.
//    (field, inferenceId, query) wraps both strings in Some(...); (field, queryVector)
//    passes the Map[String, Double] through unchanged.
//  - SparseVectorQuery.scala (new): case classes PruningConfig and SparseVectorQuery
//    (extends Query). Every parameter except `field` defaults to None or an empty map.
//    boost / queryName / prune / pruningConfig each return a copy with that option set.
//  - SparseVectorQueryBuilderFn.scala (new): writes a "sparse_vector" object. "field" is
//    always written; inference_id, query, boost, _name and prune only when defined;
//    query_vector only when the map is non-empty; pruning_config only when at least one
//    of its three options is defined.
//  - QueryBuilderFn.scala: routes SparseVectorQuery to SparseVectorQueryBuilderFn.
//  - Tests: SparseVectorQueryBuilderFnTest (new, five cases) plus one SearchDslTest case.
//
// NOTE(review): `tokensWeighThreshold` is serialized under "tokens_weight_threshold"; the
//   field name looks like a typo for `tokensWeightThreshold` - confirm before release,
//   since renaming a public case-class field later breaks callers.
// NOTE(review): the pruning test expects 0.4F to be emitted as 0.4000000059604645, so the
//   Float appears to be widened to Double on output - confirm this is intended.
// NOTE(review): nothing in the case class prevents setting inferenceId/query together with
//   queryVector, or setting neither (the first test builds a field-only query) - presumably
//   validation is left to the server; verify.
diff --git a/elastic4s-core/src/main/scala/com/sksamuel/elastic4s/api/QueryApi.scala b/elastic4s-core/src/main/scala/com/sksamuel/elastic4s/api/QueryApi.scala index 3ebb514e9..0331250a1 100644 --- a/elastic4s-core/src/main/scala/com/sksamuel/elastic4s/api/QueryApi.scala +++ b/elastic4s-core/src/main/scala/com/sksamuel/elastic4s/api/QueryApi.scala @@ -7,7 +7,7 @@ import com.sksamuel.elastic4s.requests.searches.queries.compound.BoolQuery import com.sksamuel.elastic4s.requests.searches.queries.funcscorer.FunctionScoreQuery import com.sksamuel.elastic4s.requests.searches.queries.geo.{GeoBoundingBoxQuery, GeoDistanceQuery, GeoPolygonQuery, GeoShapeQuery, Shape} import com.sksamuel.elastic4s.requests.searches.queries.matches.{MatchAllQuery, MatchBoolPrefixQuery, MatchNoneQuery, MatchPhrasePrefixQuery, MatchPhraseQuery, MatchQuery, MultiMatchQuery} -import com.sksamuel.elastic4s.requests.searches.queries.{ArtificialDocument, BoostingQuery, CombinedFieldsQuery, ConstantScore, DisMaxQuery, DistanceFeatureQuery, ExistsQuery, FuzzyQuery, HasChildQuery, HasParentQuery, IdQuery, IntervalsQuery, IntervalsRule, MoreLikeThisItem, MoreLikeThisQuery, MultiTermQuery, NestedQuery, PercolateQuery, PinnedQuery, PrefixQuery, Query, QueryStringQuery, RangeQuery, RankFeatureQuery, RawQuery, RegexQuery, ScriptQuery, ScriptScoreQuery, SimpleStringQuery} +import com.sksamuel.elastic4s.requests.searches.queries.{ArtificialDocument, BoostingQuery, CombinedFieldsQuery, ConstantScore, DisMaxQuery, DistanceFeatureQuery, ExistsQuery, FuzzyQuery, HasChildQuery, HasParentQuery, IdQuery, IntervalsQuery, IntervalsRule, MoreLikeThisItem, MoreLikeThisQuery, MultiTermQuery, NestedQuery, PercolateQuery, PinnedQuery, PrefixQuery, Query, QueryStringQuery, RangeQuery, RankFeatureQuery, RawQuery, RegexQuery, ScriptQuery, ScriptScoreQuery, SimpleStringQuery, SparseVectorQuery} import com.sksamuel.elastic4s.requests.searches.span.{SpanContainingQuery, SpanFieldMaskingQuery, SpanFirstQuery, SpanMultiTermQuery, 
SpanNearQuery, SpanNotQuery, SpanOrQuery, SpanQuery, SpanTermQuery, SpanWithinQuery} import com.sksamuel.elastic4s.requests.searches.term.{TermQuery, TermsLookupQuery, TermsQuery, TermsSetQuery, WildcardQuery} import com.sksamuel.elastic4s.requests.searches.{GeoPoint, ScoreMode, TermsLookup, span, term} @@ -269,4 +269,10 @@ trait QueryApi { // short cut for a boolean query with nots def not(queries: Query*): BoolQuery = BoolQuery().not(queries: _*) def not(queries: Iterable[Query]): BoolQuery = BoolQuery().not(queries) + + def sparseVectorQuery(field: String, inferenceId: String, query: String): SparseVectorQuery = + SparseVectorQuery(field, inferenceId = Some(inferenceId), query = Some(query)) + + def sparseVectorQuery(field: String, queryVector: Map[String, Double]): SparseVectorQuery = + SparseVectorQuery(field, queryVector = queryVector) } diff --git a/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/searches/queries/SparseVectorQueryBuilderFnTest.scala b/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/searches/queries/SparseVectorQueryBuilderFnTest.scala new file mode 100644 index 000000000..566aa9aee --- /dev/null +++ b/elastic4s-core/src/test/scala/com/sksamuel/elastic4s/requests/searches/queries/SparseVectorQueryBuilderFnTest.scala @@ -0,0 +1,76 @@ +package com.sksamuel.elastic4s.requests.searches.queries + +import com.sksamuel.elastic4s.JsonSugar +import com.sksamuel.elastic4s.api.QueryApi +import com.sksamuel.elastic4s.handlers.searches.queries.SparseVectorQueryBuilderFn +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers + +class SparseVectorQueryBuilderFnTest extends AnyFunSuite with QueryApi with Matchers with JsonSugar { + test("Should correctly build minimal sparse vector query") { + val query = SparseVectorQuery("testfield") + + val queryBody = SparseVectorQueryBuilderFn(query) + + queryBody.string shouldBe """{"sparse_vector":{"field":"testfield"}}""" + } + + test("Should correctly 
build sparse vector query using an nlp model") { + val query = sparseVectorQuery("ml.tokens", "the inference ID to produce the token weights", "the query string") + + val queryBody = SparseVectorQueryBuilderFn(query) + + queryBody.string should matchJson("""{"sparse_vector": { + | "field":"ml.tokens", + | "inference_id":"the inference ID to produce the token weights", + | "query":"the query string" + |}}""".stripMargin.replace("\n", "")) + } + + test("Should correctly build sparse vector query using precomputed vectors") { + val query = sparseVectorQuery("ml.tokens", Map("token1" -> 0.5D, "token2" -> 0.3D, "token3" -> 0.2D)) + + val queryBody = SparseVectorQueryBuilderFn(query) + + queryBody.string should matchJson("""{"sparse_vector": { + | "field": "ml.tokens", + | "query_vector": { "token1": 0.5, "token2": 0.3, "token3": 0.2 } + |}}""".stripMargin.replace("\n", "")) + } + + test("Should correctly build sparse vector query with pruning configuration") { + val query = sparseVectorQuery("ml.tokens", "my-elser-model", "How is the weather in Jamaica?") + .prune(true) + .pruningConfig(PruningConfig( + tokensFreqRatioThreshold = Some(5), tokensWeighThreshold = Some(0.4F), onlyScorePrunedTokens = Some(false) + )) + + val queryBody = SparseVectorQueryBuilderFn(query) + + queryBody.string should matchJson("""{"sparse_vector":{ + | "field": "ml.tokens", + | "inference_id": "my-elser-model", + | "query":"How is the weather in Jamaica?", + | "prune": true, + | "pruning_config": { + | "tokens_freq_ratio_threshold": 5, + | "tokens_weight_threshold": 0.4000000059604645, + | "only_score_pruned_tokens": false + | } + |}}""".stripMargin.replace("\n", "")) + } + + test("Supports boost and queryName") { + val query = SparseVectorQuery("testfield") + .boost(1.0D) + .queryName("abc") + + val queryBody = SparseVectorQueryBuilderFn(query) + + queryBody.string should matchJson("""{"sparse_vector": { + | "field": "testfield", + | "boost": 1.0, + | "_name": "abc" + 
|}}""".stripMargin.replace("\n", "")) + } +} diff --git a/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/requests/searches/queries/SparseVectorQuery.scala b/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/requests/searches/queries/SparseVectorQuery.scala new file mode 100644 index 000000000..8d27c4b54 --- /dev/null +++ b/elastic4s-domain/src/main/scala/com/sksamuel/elastic4s/requests/searches/queries/SparseVectorQuery.scala @@ -0,0 +1,26 @@ +package com.sksamuel.elastic4s.requests.searches.queries + +import com.sksamuel.elastic4s.ext.OptionImplicits.RichOptionImplicits + +case class PruningConfig(tokensFreqRatioThreshold: Option[Int] = None, + tokensWeighThreshold: Option[Float] = None, + onlyScorePrunedTokens: Option[Boolean] = None) + +case class SparseVectorQuery(field: String, + inferenceId: Option[String] = None, + query: Option[String] = None, + queryVector: Map[String, Double] = Map.empty[String, Double], + boost: Option[Double] = None, + queryName: Option[String] = None, + prune: Option[Boolean] = None, + pruningConfig: Option[PruningConfig] = None) + extends Query { + + def boost(boost: Double): SparseVectorQuery = copy(boost = boost.some) + + def queryName(queryName: String): SparseVectorQuery = copy(queryName = queryName.some) + + def prune(prune: Boolean): SparseVectorQuery = copy(prune = prune.some) + + def pruningConfig(pruningConfig: PruningConfig): SparseVectorQuery = copy(pruningConfig = pruningConfig.some) +} diff --git a/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/searches/queries/QueryBuilderFn.scala b/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/searches/queries/QueryBuilderFn.scala index 45864219c..502062c81 100644 --- a/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/searches/queries/QueryBuilderFn.scala +++ b/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/searches/queries/QueryBuilderFn.scala @@ -12,7 +12,7 @@ import 
com.sksamuel.elastic4s.requests.searches.queries.compound.BoolQuery import com.sksamuel.elastic4s.requests.searches.queries.funcscorer.FunctionScoreQuery import com.sksamuel.elastic4s.requests.searches.queries.geo.{GeoBoundingBoxQuery, GeoDistanceQuery, GeoPolygonQuery, GeoShapeQuery} import com.sksamuel.elastic4s.requests.searches.queries.matches.{MatchAllQuery, MatchBoolPrefixQuery, MatchNoneQuery, MatchPhrasePrefixQuery, MatchPhraseQuery, MatchQuery, MultiMatchQuery} -import com.sksamuel.elastic4s.requests.searches.queries.{BoostingQuery, CombinedFieldsQuery, ConstantScore, CustomQuery, DisMaxQuery, DistanceFeatureQuery, ExistsQuery, FuzzyQuery, HasChildQuery, HasParentQuery, IdQuery, IntervalsQuery, MoreLikeThisQuery, NestedQuery, NoopQuery, ParentIdQuery, PercolateQuery, PinnedQuery, PrefixQuery, Query, QueryStringQuery, RangeQuery, RankFeatureQuery, RawQuery, RegexQuery, ScriptQuery, ScriptScoreQuery, SimpleStringQuery} +import com.sksamuel.elastic4s.requests.searches.queries.{BoostingQuery, CombinedFieldsQuery, ConstantScore, CustomQuery, DisMaxQuery, DistanceFeatureQuery, ExistsQuery, FuzzyQuery, HasChildQuery, HasParentQuery, IdQuery, IntervalsQuery, MoreLikeThisQuery, NestedQuery, NoopQuery, ParentIdQuery, PercolateQuery, PinnedQuery, PrefixQuery, Query, QueryStringQuery, RangeQuery, RankFeatureQuery, RawQuery, RegexQuery, ScriptQuery, ScriptScoreQuery, SimpleStringQuery, SparseVectorQuery} import com.sksamuel.elastic4s.requests.searches.span.{SpanContainingQuery, SpanFieldMaskingQuery, SpanFirstQuery, SpanMultiTermQuery, SpanNearQuery, SpanNotQuery, SpanOrQuery, SpanTermQuery, SpanWithinQuery} import com.sksamuel.elastic4s.requests.searches.term.{TermQuery, TermsLookupQuery, TermsQuery, TermsSetQuery, WildcardQuery} @@ -65,6 +65,7 @@ object QueryBuilderFn { case s: SpanOrQuery => SpanOrQueryBodyFn(s) case s: SpanTermQuery => SpanTermQueryBodyFn(s) case s: SpanWithinQuery => SpanWithinQueryBodyFn(s) + case s: SparseVectorQuery => 
SparseVectorQueryBuilderFn(s) case t: TermQuery => TermQueryBodyFn(t) case t: TermsQuery[_] => TermsQueryBodyFn(t) case t: TermsLookupQuery => TermsLookupQueryBodyFn(t) diff --git a/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/searches/queries/SparseVectorQueryBuilderFn.scala b/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/searches/queries/SparseVectorQueryBuilderFn.scala new file mode 100644 index 000000000..00fbe21e3 --- /dev/null +++ b/elastic4s-handlers/src/main/scala/com/sksamuel/elastic4s/handlers/searches/queries/SparseVectorQueryBuilderFn.scala @@ -0,0 +1,34 @@ +package com.sksamuel.elastic4s.handlers.searches.queries + +import com.sksamuel.elastic4s.json.{XContentBuilder, XContentFactory} +import com.sksamuel.elastic4s.requests.searches.queries.{PruningConfig, SparseVectorQuery} + +object SparseVectorQueryBuilderFn { + def apply(q: SparseVectorQuery): XContentBuilder = { + val builder = XContentFactory.jsonBuilder() + builder.startObject("sparse_vector") + builder.field("field", q.field) + + q.inferenceId.foreach(builder.field("inference_id", _)) + q.query.foreach(builder.field("query", _)) + if (q.queryVector.nonEmpty) { + builder.startObject("query_vector") + q.queryVector.foreach { case (k, v) => builder.field(k, v) } + builder.endObject() + } + q.boost.foreach(builder.field("boost", _)) + q.queryName.foreach(builder.field("_name", _)) + q.prune.foreach(builder.field("prune", _)) + q.pruningConfig.foreach { pc => + if (pc.tokensFreqRatioThreshold.nonEmpty || pc.tokensWeighThreshold.nonEmpty || pc.onlyScorePrunedTokens.nonEmpty) { + builder.startObject("pruning_config") + pc.tokensFreqRatioThreshold.foreach(builder.field("tokens_freq_ratio_threshold", _)) + pc.tokensWeighThreshold.foreach(builder.field("tokens_weight_threshold", _)) + pc.onlyScorePrunedTokens.foreach(builder.field("only_score_pruned_tokens", _)) + builder.endObject() + } + } + builder.endObject() + builder + } +} diff --git 
a/elastic4s-tests/src/test/scala/com/sksamuel/elastic4s/search/SearchDslTest.scala b/elastic4s-tests/src/test/scala/com/sksamuel/elastic4s/search/SearchDslTest.scala index d1fbc5415..b42c213fe 100644 --- a/elastic4s-tests/src/test/scala/com/sksamuel/elastic4s/search/SearchDslTest.scala +++ b/elastic4s-tests/src/test/scala/com/sksamuel/elastic4s/search/SearchDslTest.scala @@ -971,4 +971,9 @@ class SearchDslTest extends AnyFlatSpec with MockitoSugar with JsonSugar with On val req = search("music").matchAllQuery() docValues ("field1", "field2") req.request.entity.get.get should matchJsonResource("/json/search/search_doc_values.json") } + + it should "generate json for sparse vector query" in { + val req = search("index").query(sparseVectorQuery("test", Map("a" -> 0.3D))) + req.request.entity.get.get should matchJson("""{"query":{"sparse_vector":{"field":"test", "query_vector":{"a":0.3}}}}""") + } }